From 91d68179402158cf9df33bc4fa737ab165d7b0fb Mon Sep 17 00:00:00 2001 From: finn <finn@hillengass.de> Date: Thu, 29 Feb 2024 14:27:51 +0100 Subject: [PATCH] Fix handling of nan values --- metrics/slot_accuracy/slot_accuracy.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/metrics/slot_accuracy/slot_accuracy.py b/metrics/slot_accuracy/slot_accuracy.py index ca53d83..0e794da 100644 --- a/metrics/slot_accuracy/slot_accuracy.py +++ b/metrics/slot_accuracy/slot_accuracy.py @@ -55,9 +55,17 @@ def slot_accuracy(domain_knowledge, annotation): Returns: float: Slot accuracy score. """ + + if pd.isna(domain_knowledge) or pd.isna(annotation): + # Handle the nan values + return 0 + + domain_knowledge = domain_knowledge.replace('null', 'None') + annotation = annotation.replace('null', 'None') + domain_knowledge = ast.literal_eval(domain_knowledge) annotation = ast.literal_eval(annotation) - + count = {"True": 0, "False": 0, "Total": 0} for detail in domain_knowledge["Details"]: @@ -88,6 +96,7 @@ def main(input_file, output_file): input_file (str): Path to the input CSV file. output_file (str): Path to the output CSV file. """ + df = pd.read_csv(input_file, sep=',', quoting=csv.QUOTE_NONE, escapechar='/', index_col=False) df['slot_accuracy_original'] = df.apply(lambda row: slot_accuracy(row['Domain Knowledge'], row['original_annotation']), axis=1) df['slot_accuracy_generated'] = df.apply(lambda row: slot_accuracy(row['Domain Knowledge'], row['generated_annotation']), axis=1) -- GitLab