Skip to content
Snippets Groups Projects
Commit 75de4530 authored by Thomas Wolf
Browse files

NEC, fixed misunderstanding, removed unnecessary code

parent 2814e9a7
No related branches found
No related tags found
No related merge requests found
......@@ -8,7 +8,7 @@ from src.models.GLiNER import classify_entity as classify_entity_gliner
from src.models.T5_NLI import classify_entity as classify_entity_t5_nli
from src.models.T5_MLM_label import classify_entity as classify_entity_t5_mlm_label
from src.experiments.NER_with_LLMs.NER_with_LLMs import find_entities as find_entities_llm
from src.experiments.NER_with_LLMs.NER_with_LLMs import find_entities_nli as find_entities_nli_llm
from src.experiments.NER_with_LLMs.NER_with_LLMs import classify_entity as classify_entities_llm
def classify_entity(model_name, sentence, entity, labels):
......@@ -21,20 +21,8 @@ def classify_entity(model_name, sentence, entity, labels):
return classify_entity_t5_mlm_label(sentence, entity, labels)
elif model_name == "GLiNER":
return classify_entity_gliner(sentence, entity, labels)
def predict_mask_mlm(model_name, masked_sentence, labels):
    """Predict fillers for the mask token in *masked_sentence* using an MLM.

    Example input: "[Mask] was the president of the US" -> the intended
    result is a list of candidate answers (possibly with probabilities).

    NOTE(review): not implemented yet — the function body is empty, so it
    currently returns None for every input.
    """
def predict_mask_nli(model_name, sentence, entity, labels):
    """
    Barack Obama was the president of the US. -> Barack Obama is/was a [Mask] Return prediction(s) of model for [Mask]
    """
    # Dispatch to the LLM backend when the model is one of the configured LLMs.
    # `llms`, `find_entities_nli_llm` and `classify_entities_llm` are defined
    # elsewhere in this module/file — not visible here.
    if model_name in llms:
        return find_entities_nli_llm(model_name, sentence, entity, labels)
    # NOTE(review): this condition is identical to the one above, so this
    # branch is unreachable; likely a merge/diff artifact — confirm which
    # call is intended. Non-LLM models fall through and return None.
    elif model_name in llms:
        return classify_entities_llm(model_name, sentence, entity, labels)
def find_entities(model_name, sentence, labels):
......
"""
Evaluates GLiNER as SotA.
Evaluates GLiNER as SotA (NER). This is mostly useful for demonstrating dataset shortcomings.
NEC evaluation for GLiNER is part of /NEC_evaluation.
"""
from src.metrics import NER_metrics, read_NER_metrics
......
......@@ -21,8 +21,8 @@ def run_NEC_tests(model_name, dataset, results_dir, test_instances=10):
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
# Define file paths with model name and timestamp in the file names
csv_filename = os.path.join(results_dir, f"{timestamp}_{model_name}_metrics.csv")
txt_filename = os.path.join(results_dir, f"{timestamp}_{model_name}_results.txt")
csv_filename = os.path.join(results_dir, f"{timestamp}_{model_name}_{dataset}_metrics.csv")
txt_filename = os.path.join(results_dir, f"{timestamp}_{model_name}_{dataset}_results.txt")
labels = data_manager.get_labels(dataset)
data = data_manager.get_annotated_sentences(dataset, test_instances)
......@@ -53,16 +53,17 @@ def run_NEC_tests(model_name, dataset, results_dir, test_instances=10):
def run_NEC_tests_all():
    """Run the NEC benchmark for every configured model/dataset pair.

    For each (model, dataset) combination, delegates to run_NEC_tests with
    100 test instances, writing result files into the "results" directory.
    """
    # Fix: the previous version assigned `models` and `datasets` twice in a
    # row; the first assignments (old model/dataset lists) were dead code that
    # was immediately overwritten. Only the effective values are kept.
    models = ["GLiNER", "Llama-3.1-8B", "T5-NLI", "T5-MLM-label", "DeepSeek-R1-Distill-Qwen-32B"]
    datasets = ["CoNLL", "FIGER-coarse", "FIGER-fine"]  # "Pile-NER-type"]
    for model in models:
        for dataset in datasets:
            print(f"Testing {model} on {dataset} dataset...")
            run_NEC_tests(model, dataset, "results", 100)
def read_NEC_metrics(directory):
"""
Reads all CSV files in the given directory and prints the accuracye.
Reads all CSV files in the given directory and prints the accuracy.
"""
metrics = {}
......@@ -91,4 +92,5 @@ def read_NEC_metrics(directory):
print(f"Model: {model}, Dataset: {dataset}, Accuracy: {avg_accuracy:.2f}%")
# Script entry point: run the full NEC benchmark across all configured
# models/datasets, then read the CSVs under "results" and print the
# per-(model, dataset) accuracy summary.
# NOTE(review): these run on import since there is no __main__ guard —
# confirm this module is only ever executed as a script.
run_NEC_tests_all()
read_NEC_metrics("results")
......@@ -56,7 +56,7 @@ def find_entities(model_name, sentence, labels):
return "No answer was marked, here is the model's response: " + answer
def find_entities_nli(model_name, sentence, entity, labels):
def classify_entity(model_name, sentence, entity, labels):
"""
Uses NLI to classify a specified entity in a sentence. Returns string entity class.
"""
......
File moved
from src.common_interface import predict_mask_nli
from src.common_interface import classify_entity
from src.metrics import precision, recall, f1_score
......@@ -20,7 +20,7 @@ for model in tested_models:
predicted_entities = []
for pair in true_labels:
entity = pair[0]
predicted_label = predict_mask_nli(model, test_sentence, entity, test_labels)
predicted_label = classify_entity(model, test_sentence, entity, test_labels)
predicted_entities.append((entity, predicted_label))
print(f"{model} found entities: \n{predicted_entities}")
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment