Commit 887ffc93 authored by Thomas Wolf

Finished NEC evaluation (accuracy only)

parent 93d60fed
@@ -89,7 +89,7 @@ def get_annotated_sentences():
             # Append the annotated sentence.
             # Each entry is a list: [ [sentence], [list of (entity, label) tuples] ]
-            annotated_sentences.append([[sentence], entities])
+            annotated_sentences.append([sentence, entities])
     print("Finished processing dataset: CoNLL2003")
@@ -103,6 +103,6 @@ def get_annotated_sentences():
                 i += 1
             else:
                 i += 1
-    annotated_sentences.append([[full_text], annotations])
+    annotated_sentences.append([full_text, annotations])
     return annotated_sentences
"""
This file evaluates all NEC approaches.
"""
import os
import csv
import datetime
import pandas as pd
from tqdm import tqdm
import data.data_manager as data_manager
from src.common_interface import classify_entity
# todo: perform tests on datasets and store results
# todo read results and compare the models / plot the results
def run_NEC_tests(model_name, dataset, results_dir, test_instances=10):
"""
Evaluates a model on a dataset on the NEC task for a chosen number of test instances and stores the
results in the given directory
"""
os.makedirs(results_dir, exist_ok=True)
# Get current timestamp for unique file naming
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
# Define file paths with model name and timestamp in the file names
csv_filename = os.path.join(results_dir, f"{timestamp}_{model_name}_metrics.csv")
txt_filename = os.path.join(results_dir, f"{timestamp}_{model_name}_results.txt")
labels = data_manager.get_labels(dataset)
data = data_manager.get_annotated_sentences(dataset, test_instances)
with open(csv_filename, mode="w", newline='', encoding="utf-8") as csv_file, \
open(txt_filename, mode="w", encoding="utf-8") as txt_file:
csv_writer = csv.writer(csv_file)
csv_writer.writerow(
["Instance", "Model", "Dataset", "Sentence", "Entity", "True Annotation", "Predicted Annotation", "Correct"])
for idx, entry in enumerate(tqdm(data, desc="Processing Instances", unit="instance")):
sentence = entry[0]
for entity in entry[1]:
predicted_label = classify_entity(model_name, sentence, entity[0], labels)
correct = predicted_label == entity[1]
csv_writer.writerow([idx, model_name, dataset, sentence, entity[0], entity[1], predicted_label, correct])
# Log stuff
txt_file.write(f"Instance {idx}:\n")
txt_file.write(f"Sentence: {sentence}\n")
txt_file.write(f"Entity: {entity[0]}\n")
txt_file.write(f"True Label: {entity[1]}\n")
txt_file.write(f"Predicted Label: {predicted_label}\n")
txt_file.write("-" * 50 + "\n\n")
print(f"Results saved to:\n CSV: {csv_filename}\n TXT: {txt_filename}")
def run_NEC_tests_all():
models = ["GLiNER", "T5"]
datasets = ["CoNLL"] # , "FIGER-coarse", "FIGER-fine"]
for model in models:
for dataset in datasets:
run_NEC_tests(model, dataset, "results", 100)
def read_NEC_metrics(directory):
"""
Reads all CSV files in the given directory and prints the accuracye.
"""
metrics = {}
for filename in os.listdir(directory):
if filename.endswith(".csv"):
csv_filename = os.path.join(directory, filename)
try:
df = pd.read_csv(csv_filename)
required_columns = {"Model", "Dataset", "Correct"}
if not required_columns.issubset(df.columns):
print(f"Skipping {filename} due to missing columns.")
continue
# Group by model and dataset, then compute accuracy
for (model, dataset), group in df.groupby(["Model", "Dataset"]):
accuracy = group["Correct"].mean() * 100
metrics.setdefault((model, dataset), []).append(accuracy)
except Exception as e:
print(f"Error reading {filename}: {e}")
print("\nModel Performance Summary:")
for (model, dataset), accuracies in metrics.items():
avg_accuracy = sum(accuracies) / len(accuracies)
print(f"Model: {model}, Dataset: {dataset}, Accuracy: {avg_accuracy:.2f}%")
read_NEC_metrics("results")
@@ -71,7 +71,7 @@ def NER_metrics(model_name, dataset, results_dir, test_instances=10):
             ["Instance", "Dataset", "Sentence", "True Annotation", "Predicted Annotation", "Precision", "Recall", "F1 Score"])
         for idx, entry in enumerate(tqdm(data, desc="Processing Instances", unit="instance")):
-            sentence = entry[0][0]
+            sentence = entry[0]
             true_annotation = entry[1]
             predicted_annotation = find_entities(model_name, sentence, labels)
@@ -24,5 +24,5 @@ def classify_entity(sentence, entity, labels):
         if e[0] == entity:
             return e[1]  # Return label
-    return ""
+    return "Not an entity. (acc. to GLiNER)"
@@ -30,7 +30,7 @@ def infer_nli(premise, hypothesis):
 def classify_entity(sentence, entity, labels):
     best_label = None
     for label in labels:
-        # rint(f"Label: {label}")
+        # print(f"Label: {label}")
         hypothesis = f"{entity} is {label}"
         result = infer_nli(sentence, hypothesis)
         # print(f"Hypothesis: {hypothesis}, Result: {result}")
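The hunk above only shows how each hypothesis is built; the selection of best_label falls outside the displayed context. A minimal sketch of how the loop could pick a label, assuming infer_nli returns a mapping with an "entailment" probability (that return format is an assumption, not confirmed by the diff):

    def classify_entity(sentence, entity, labels):
        best_label, best_score = None, float("-inf")
        for label in labels:
            hypothesis = f"{entity} is {label}"
            result = infer_nli(sentence, hypothesis)  # premise = sentence, hypothesis = "{entity} is {label}"
            score = result["entailment"]              # assumed key; depends on the NLI model wrapper
            if score > best_score:
                best_label, best_score = label, score
        return best_label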