From 91d68179402158cf9df33bc4fa737ab165d7b0fb Mon Sep 17 00:00:00 2001
From: finn <finn@hillengass.de>
Date: Thu, 29 Feb 2024 14:27:51 +0100
Subject: [PATCH] Fix handling of nan values

---
 metrics/slot_accuracy/slot_accuracy.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/metrics/slot_accuracy/slot_accuracy.py b/metrics/slot_accuracy/slot_accuracy.py
index ca53d83..0e794da 100644
--- a/metrics/slot_accuracy/slot_accuracy.py
+++ b/metrics/slot_accuracy/slot_accuracy.py
@@ -55,9 +55,17 @@ def slot_accuracy(domain_knowledge, annotation):
     Returns:
     float: Slot accuracy score.
     """
+
+    if pd.isna(domain_knowledge) or pd.isna(annotation):
+        # Handle the nan values 
+        return 0
+        
+    domain_knowledge = domain_knowledge.replace('null', 'None')
+    annotation = annotation.replace('null', 'None')
+
     domain_knowledge = ast.literal_eval(domain_knowledge)
     annotation = ast.literal_eval(annotation)
-
+        
     count = {"True": 0, "False": 0, "Total": 0}
     
     for detail in domain_knowledge["Details"]:
@@ -88,6 +96,7 @@ def main(input_file, output_file):
     input_file (str): Path to the input CSV file.
     output_file (str): Path to the output CSV file.
     """
+    
     df = pd.read_csv(input_file, sep=',', quoting=csv.QUOTE_NONE, escapechar='/', index_col=False)
     df['slot_accuracy_original'] = df.apply(lambda row: slot_accuracy(row['Domain Knowledge'], row['original_annotation']), axis=1)
     df['slot_accuracy_generated'] = df.apply(lambda row: slot_accuracy(row['Domain Knowledge'], row['generated_annotation']), axis=1)
-- 
GitLab