From b4f3637e7886cf329aa0e0fbfc249584fbdb162e Mon Sep 17 00:00:00 2001 From: kupper <kupper@login.cl.uni-heidelberg.de> Date: Sun, 23 Mar 2025 15:13:05 +0100 Subject: [PATCH] Classification with Word2Vec --- data/data_manager.py | 17 +++++++ src/common_interface.py | 17 ++++++- src/experiments/NEC_evaluation/evaluation.py | 9 +++- src/models/Word2Vec.py | 47 +++++++++++++++++++- 4 files changed, 86 insertions(+), 4 deletions(-) diff --git a/data/data_manager.py b/data/data_manager.py index 787b457..e218dd9 100644 --- a/data/data_manager.py +++ b/data/data_manager.py @@ -46,3 +46,20 @@ def get_annotated_sentences(dataset, test_instances=10): return random.sample(figer_interface.get_annotated_sentences_fine(), min(test_instances, len(figer_interface.get_annotated_sentences_fine()))) else: raise Exception + +def get_label_dict(dataset, test_instances=10): + """ + Returns a dictionary of example entities for each label. + """ + label_dict = {} + labels = get_labels(dataset) + for label in labels: + label_dict[label] = set() + annotated_sentences = get_annotated_sentences(dataset, test_instances) + for annotated_sentence in annotated_sentences: + for labeled_entity in annotated_sentence[1]: + if labeled_entity[1] not in label_dict: + label_dict[labeled_entity[1]] = set() + label_dict[labeled_entity[1]].add(labeled_entity[0]) + print(f"Label dict: {label_dict}") + return label_dict diff --git a/src/common_interface.py b/src/common_interface.py index 56d4cc6..578b666 100644 --- a/src/common_interface.py +++ b/src/common_interface.py @@ -7,6 +7,8 @@ from src.models.GLiNER import find_entities as find_entities_gliner from src.models.GLiNER import classify_entity as classify_entity_gliner from src.models.T5_NLI import classify_entity as classify_entity_t5_nli from src.models.T5_MLM_label import classify_entity as classify_entity_t5_mlm_label +from src.models.Word2Vec import classify_entity as classify_entity_word2vec +from src.models.Word2Vec import set_label_dict as set_label_dict_word2vec from src.experiments.NER_with_LLMs.NER_with_LLMs import find_entities as find_entities_llm from src.experiments.NER_with_LLMs.NER_with_LLMs import classify_entity as classify_entities_llm @@ -21,8 +23,12 @@ def classify_entity(model_name, sentence, entity, labels): return classify_entity_t5_mlm_label(sentence, entity, labels) elif model_name == "GLiNER": return classify_entity_gliner(sentence, entity, labels) + elif model_name == "Word2Vec": + return classify_entity_word2vec(entity, labels) elif model_name in llms: return classify_entities_llm(model_name, sentence, entity, labels) + else: + print(f"classify_entity not implemented for {model_name}") def find_entities(model_name, sentence, labels): @@ -34,4 +40,13 @@ def find_entities(model_name, sentence, labels): elif model_name == "GLiNER": return find_entities_gliner(sentence, labels) else: - print("Not implemented") + print(f"find_entities not implemented for {model_name}") + +def set_label_dict(model_name, label_dict): + """ + NER. Sets the label dictionary required for the Word2Vec model + """ + if model_name == "Word2Vec": + return set_label_dict_word2vec(label_dict) + else: + print(f"set_label_dict not implemented for {model_name}") diff --git a/src/experiments/NEC_evaluation/evaluation.py b/src/experiments/NEC_evaluation/evaluation.py index 1442e57..748e56f 100644 --- a/src/experiments/NEC_evaluation/evaluation.py +++ b/src/experiments/NEC_evaluation/evaluation.py @@ -7,7 +7,7 @@ import datetime import pandas as pd from tqdm import tqdm import data.data_manager as data_manager -from src.common_interface import classify_entity +from src.common_interface import classify_entity, set_label_dict def run_NEC_tests(model_name, dataset, results_dir, test_instances=10): @@ -27,6 +27,11 @@ def run_NEC_tests(model_name, dataset, results_dir, test_instances=10): labels = data_manager.get_labels(dataset) data = data_manager.get_annotated_sentences(dataset, test_instances) + if (model_name == "Word2Vec"): + label_dict = data_manager.get_label_dict(dataset, test_instances) + print(label_dict) + set_label_dict(model_name, label_dict) + with open(csv_filename, mode="w", newline='', encoding="utf-8") as csv_file, \ open(txt_filename, mode="w", encoding="utf-8") as txt_file: csv_writer = csv.writer(csv_file) @@ -53,7 +58,7 @@ def run_NEC_tests(model_name, dataset, results_dir, test_instances=10): def run_NEC_tests_all(): - models = ["GLiNER", "Llama-3.1-8B", "T5-NLI", "T5-MLM-label", "DeepSeek-R1-Distill-Qwen-32B"] + models = ["GLiNER", "Llama-3.1-8B", "T5-NLI", "T5-MLM-label", "Word2Vec", "DeepSeek-R1-Distill-Qwen-32B"] datasets = ["CoNLL", "FIGER-coarse", "FIGER-fine"] # "Pile-NER-type"] for model in models: for dataset in datasets: diff --git a/src/models/Word2Vec.py b/src/models/Word2Vec.py index 43daee9..d749a5c 100644 --- a/src/models/Word2Vec.py +++ b/src/models/Word2Vec.py @@ -15,7 +15,9 @@ def train(sentences): def load_pretrained(): global model print("Loading model: Word2Vec Pretrained Google News") - model = api.load('word2vec-google-news-300') + model = Word2Vec(vector_size=300, window=5, min_count=1) + model.wv = api.load('word2vec-google-news-300') + model.trainables = None print("Finished loading model: Word2Vec Pretrained Google News") def similarity(word1, word2): @@ -30,4 +32,47 @@ def print_nearest(word): print(model.wv.most_similar("him")) else: print("word is not in vocabulary") + +# dictionary containing example values for each label +def set_label_dict(label_dict): + global label_representatives + label_representatives = label_dict + +def classify_entity(entity, labels): + global label_representatives + best_label = None + best_similarity = float("inf") + + # Keep track how many of the entities have word vectors + has_vecs = 0 + has_no_vecs = 0 + + # Find label with highest probability of entailment + for label in labels: + representatives = label_representatives[label] + + # do not cheat if entity is one of the representatives already + if entity in representatives: + representatives.remove(entity) + + best_for_label = float("inf") + + for representative in representatives: + sim = similarity(entity, representative) + if sim == float("inf"): + has_no_vecs += 1 + else: + has_vecs += 1 + + best_for_label = min(best_for_label, sim) + + if (best_for_label < best_similarity): + best_similarity = best_for_label + best_label = label + + # print(f"{has_vecs} entities had word vectors, {has_no_vecs} entities did not have word vectors") + + return best_label if best_label else labels[0] + +load_pretrained() -- GitLab