From b4f3637e7886cf329aa0e0fbfc249584fbdb162e Mon Sep 17 00:00:00 2001
From: kupper <kupper@login.cl.uni-heidelberg.de>
Date: Sun, 23 Mar 2025 15:13:05 +0100
Subject: [PATCH] Classification with Word2Vec

---
 data/data_manager.py                         | 17 +++++++
 src/common_interface.py                      | 17 ++++++-
 src/experiments/NEC_evaluation/evaluation.py |  9 +++-
 src/models/Word2Vec.py                       | 47 +++++++++++++++++++-
 4 files changed, 86 insertions(+), 4 deletions(-)

diff --git a/data/data_manager.py b/data/data_manager.py
index 787b457..e218dd9 100644
--- a/data/data_manager.py
+++ b/data/data_manager.py
@@ -46,3 +46,20 @@ def get_annotated_sentences(dataset, test_instances=10):
         return random.sample(figer_interface.get_annotated_sentences_fine(), min(test_instances, len(figer_interface.get_annotated_sentences_fine())))
     else:
         raise Exception
+    
+def get_label_dict(dataset, test_instances=10):
+    """
+    Returns a dictionary of example entities for each label.
+    """
+    label_dict = {}
+    labels = get_labels(dataset)
+    for label in labels:
+        label_dict[label] = set()
+    annotated_sentences = get_annotated_sentences(dataset, test_instances)
+    for annotated_sentence in annotated_sentences:
+        for labeled_entity in annotated_sentence[1]:
+            if labeled_entity[1] not in label_dict:
+                label_dict[labeled_entity[1]] = set()
+            label_dict[labeled_entity[1]].add(labeled_entity[0])
+    print(f"Label dict: {label_dict}")
+    return label_dict
diff --git a/src/common_interface.py b/src/common_interface.py
index 56d4cc6..578b666 100644
--- a/src/common_interface.py
+++ b/src/common_interface.py
@@ -7,6 +7,8 @@ from src.models.GLiNER import find_entities as find_entities_gliner
 from src.models.GLiNER import classify_entity as classify_entity_gliner
 from src.models.T5_NLI import classify_entity as classify_entity_t5_nli
 from src.models.T5_MLM_label import classify_entity as classify_entity_t5_mlm_label
+from src.models.Word2Vec import classify_entity as classify_entity_word2vec
+from src.models.Word2Vec import set_label_dict as set_label_dict_word2vec
 from src.experiments.NER_with_LLMs.NER_with_LLMs import find_entities as find_entities_llm
 from src.experiments.NER_with_LLMs.NER_with_LLMs import classify_entity as classify_entities_llm
 
@@ -21,8 +23,12 @@ def classify_entity(model_name, sentence, entity, labels):
         return classify_entity_t5_mlm_label(sentence, entity, labels)
     elif model_name == "GLiNER":
         return classify_entity_gliner(sentence, entity, labels)
+    elif model_name == "Word2Vec":
+        return classify_entity_word2vec(entity, labels)
     elif model_name in llms:
         return classify_entities_llm(model_name, sentence, entity, labels)
+    else:
+        print(f"classify_entity not implemented for {model_name}")
 
 
 def find_entities(model_name, sentence, labels):
@@ -34,4 +40,13 @@ def find_entities(model_name, sentence, labels):
     elif model_name == "GLiNER":
         return find_entities_gliner(sentence, labels)
     else:
-        print("Not implemented")
+        print(f"find_entities not implemented for {model_name}")
+
+def set_label_dict(model_name, label_dict):
+    """
+    NER. Sets the label dictionary required for the Word2Vec model
+    """
+    if model_name == "Word2Vec":
+        return set_label_dict_word2vec(label_dict)
+    else:
+        print(f"set_label_dict not implemented for {model_name}")
diff --git a/src/experiments/NEC_evaluation/evaluation.py b/src/experiments/NEC_evaluation/evaluation.py
index 1442e57..748e56f 100644
--- a/src/experiments/NEC_evaluation/evaluation.py
+++ b/src/experiments/NEC_evaluation/evaluation.py
@@ -7,7 +7,7 @@ import datetime
 import pandas as pd
 from tqdm import tqdm
 import data.data_manager as data_manager
-from src.common_interface import classify_entity
+from src.common_interface import classify_entity, set_label_dict
 
 
 def run_NEC_tests(model_name, dataset, results_dir, test_instances=10):
@@ -27,6 +27,11 @@ def run_NEC_tests(model_name, dataset, results_dir, test_instances=10):
     labels = data_manager.get_labels(dataset)
     data = data_manager.get_annotated_sentences(dataset, test_instances)
 
+    if (model_name == "Word2Vec"):
+        label_dict = data_manager.get_label_dict(dataset, test_instances)
+        print(label_dict)
+        set_label_dict(model_name, label_dict)
+
     with open(csv_filename, mode="w", newline='', encoding="utf-8") as csv_file, \
             open(txt_filename, mode="w", encoding="utf-8") as txt_file:
         csv_writer = csv.writer(csv_file)
@@ -53,7 +58,7 @@ def run_NEC_tests(model_name, dataset, results_dir, test_instances=10):
 
 
 def run_NEC_tests_all():
-    models = ["GLiNER", "Llama-3.1-8B", "T5-NLI", "T5-MLM-label", "DeepSeek-R1-Distill-Qwen-32B"]
+    models = ["GLiNER", "Llama-3.1-8B", "T5-NLI", "T5-MLM-label", "Word2Vec", "DeepSeek-R1-Distill-Qwen-32B"]
     datasets = ["CoNLL", "FIGER-coarse", "FIGER-fine"]  # "Pile-NER-type"]
     for model in models:
         for dataset in datasets:
diff --git a/src/models/Word2Vec.py b/src/models/Word2Vec.py
index 43daee9..d749a5c 100644
--- a/src/models/Word2Vec.py
+++ b/src/models/Word2Vec.py
@@ -15,7 +15,9 @@ def train(sentences):
 def load_pretrained():
     global model
     print("Loading model: Word2Vec Pretrained Google News")
-    model = api.load('word2vec-google-news-300')
+    model = Word2Vec(vector_size=300, window=5, min_count=1)
+    model.wv = api.load('word2vec-google-news-300')
+    model.trainables = None
     print("Finished loading model: Word2Vec Pretrained Google News")
 
 def similarity(word1, word2):
@@ -30,4 +32,47 @@ def print_nearest(word):
         print(model.wv.most_similar("him"))
     else:
         print("word is not in vocabulary")
+
+# dictionary containing example values for each label
+def set_label_dict(label_dict):
+    global label_representatives
+    label_representatives = label_dict
+
     
+def classify_entity(entity, labels):
+    global label_representatives
+    best_label = None
+    best_similarity = float("inf")
+
+    # Keep track how many of the entities have word vectors
+    has_vecs = 0
+    has_no_vecs = 0
+
+    # Find label with highest probability of entailment
+    for label in labels:
+        representatives = label_representatives[label]
+
+        # do not cheat if entity is one of the representatives already
+        if entity in representatives:
+            representatives.remove(entity)
+
+        best_for_label = float("inf")
+
+        for representative in representatives:
+            sim = similarity(entity, representative)
+            if sim == float("inf"):
+                has_no_vecs += 1
+            else:
+                has_vecs += 1
+            
+            best_for_label = min(best_for_label, sim)
+
+        if (best_for_label < best_similarity):
+            best_similarity = best_for_label
+            best_label = label
+        
+    # print(f"{has_vecs} entities had word vectors, {has_no_vecs} entities did not have word vectors")
+
+    return best_label if best_label else labels[0]
+
+load_pretrained()
-- 
GitLab