Commit b4f3637e authored by kupper

Classification with Word2Vec

parent 4e71dced
@@ -46,3 +46,20 @@ def get_annotated_sentences(dataset, test_instances=10):
         return random.sample(figer_interface.get_annotated_sentences_fine(), min(test_instances, len(figer_interface.get_annotated_sentences_fine())))
     else:
         raise Exception
+
+def get_label_dict(dataset, test_instances=10):
+    """
+    Returns a dictionary of example entities for each label.
+    """
+    label_dict = {}
+    labels = get_labels(dataset)
+    for label in labels:
+        label_dict[label] = set()
+    annotated_sentences = get_annotated_sentences(dataset, test_instances)
+    for annotated_sentence in annotated_sentences:
+        for labeled_entity in annotated_sentence[1]:
+            if labeled_entity[1] not in label_dict:
+                label_dict[labeled_entity[1]] = set()
+            label_dict[labeled_entity[1]].add(labeled_entity[0])
+    print(f"Label dict: {label_dict}")
+    return label_dict
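For orientation, here is a minimal sketch of the structure get_label_dict builds. The toy sentences and labels are invented, and the tuple layout (sentence, [(entity, label), ...]) is only inferred from the indexing above:

# Hypothetical annotated sentences in the (sentence, [(entity, label), ...]) layout.
annotated_sentences = [
    ("Angela Merkel visited Paris.", [("Angela Merkel", "person"), ("Paris", "location")]),
    ("Bayer is based in Leverkusen.", [("Bayer", "organization"), ("Leverkusen", "location")]),
]

label_dict = {}
for _sentence, entities in annotated_sentences:
    for entity, label in entities:
        label_dict.setdefault(label, set()).add(entity)

# label_dict == {"person": {"Angela Merkel"},
#                "location": {"Paris", "Leverkusen"},
#                "organization": {"Bayer"}}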
@@ -7,6 +7,8 @@ from src.models.GLiNER import find_entities as find_entities_gliner
 from src.models.GLiNER import classify_entity as classify_entity_gliner
 from src.models.T5_NLI import classify_entity as classify_entity_t5_nli
 from src.models.T5_MLM_label import classify_entity as classify_entity_t5_mlm_label
+from src.models.Word2Vec import classify_entity as classify_entity_word2vec
+from src.models.Word2Vec import set_label_dict as set_label_dict_word2vec
 from src.experiments.NER_with_LLMs.NER_with_LLMs import find_entities as find_entities_llm
 from src.experiments.NER_with_LLMs.NER_with_LLMs import classify_entity as classify_entities_llm
 
@@ -21,8 +23,12 @@ def classify_entity(model_name, sentence, entity, labels):
         return classify_entity_t5_mlm_label(sentence, entity, labels)
     elif model_name == "GLiNER":
         return classify_entity_gliner(sentence, entity, labels)
+    elif model_name == "Word2Vec":
+        return classify_entity_word2vec(entity, labels)
     elif model_name in llms:
        return classify_entities_llm(model_name, sentence, entity, labels)
+    else:
+        print(f"classify_entity not implemented for {model_name}")
 
 
 def find_entities(model_name, sentence, labels):
@@ -34,4 +40,13 @@ def find_entities(model_name, sentence, labels):
     elif model_name == "GLiNER":
         return find_entities_gliner(sentence, labels)
     else:
-        print("Not implemented")
+        print(f"find_entities not implemented for {model_name}")
+
+def set_label_dict(model_name, label_dict):
+    """
+    Sets the label dictionary required by the Word2Vec model.
+    """
+    if model_name == "Word2Vec":
+        return set_label_dict_word2vec(label_dict)
+    else:
+        print(f"set_label_dict not implemented for {model_name}")
@@ -7,7 +7,7 @@ import datetime
 import pandas as pd
 from tqdm import tqdm
 import data.data_manager as data_manager
-from src.common_interface import classify_entity
+from src.common_interface import classify_entity, set_label_dict
 
 
 def run_NEC_tests(model_name, dataset, results_dir, test_instances=10):
@@ -27,6 +27,11 @@ def run_NEC_tests(model_name, dataset, results_dir, test_instances=10):
     labels = data_manager.get_labels(dataset)
     data = data_manager.get_annotated_sentences(dataset, test_instances)
 
+    if model_name == "Word2Vec":
+        label_dict = data_manager.get_label_dict(dataset, test_instances)
+        print(label_dict)
+        set_label_dict(model_name, label_dict)
+
     with open(csv_filename, mode="w", newline='', encoding="utf-8") as csv_file, \
             open(txt_filename, mode="w", encoding="utf-8") as txt_file:
         csv_writer = csv.writer(csv_file)
@@ -53,7 +58,7 @@ def run_NEC_tests(model_name, dataset, results_dir, test_instances=10):
 
 
 def run_NEC_tests_all():
-    models = ["GLiNER", "Llama-3.1-8B", "T5-NLI", "T5-MLM-label", "DeepSeek-R1-Distill-Qwen-32B"]
+    models = ["GLiNER", "Llama-3.1-8B", "T5-NLI", "T5-MLM-label", "Word2Vec", "DeepSeek-R1-Distill-Qwen-32B"]
     datasets = ["CoNLL", "FIGER-coarse", "FIGER-fine"] # "Pile-NER-type"]
     for model in models:
         for dataset in datasets:
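As a hypothetical invocation (the results directory and instance count are illustrative), the Word2Vec branch above only adds the label-dictionary setup before the usual per-sentence test loop:

run_NEC_tests("Word2Vec", "FIGER-coarse", "results/", test_instances=10)
run_NEC_tests_all()  # now also covers the Word2Vec baseline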
@@ -15,7 +15,9 @@ def train(sentences):
 def load_pretrained():
     global model
     print("Loading model: Word2Vec Pretrained Google News")
-    model = api.load('word2vec-google-news-300')
+    model = Word2Vec(vector_size=300, window=5, min_count=1)
+    model.wv = api.load('word2vec-google-news-300')
+    model.trainables = None
     print("Finished loading model: Word2Vec Pretrained Google News")
 
 def similarity(word1, word2):
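The body of similarity() is not part of this diff. Judging from how classify_entity uses it below (min(), float("inf") for missing words, smaller values win), it presumably returns a distance-style score; assigning the pretrained KeyedVectors returned by api.load to model.wv keeps such calls uniform for the trained and the pretrained model. A sketch consistent with that contract, not the repository's actual implementation:

def similarity(word1, word2):
    # Distance-style score: 1 - cosine similarity, so smaller means closer.
    if word1 in model.wv and word2 in model.wv:
        return model.wv.distance(word1, word2)
    # Signal that at least one word has no vector in the vocabulary.
    return float("inf")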
@@ -30,4 +32,47 @@ def print_nearest(word):
         print(model.wv.most_similar("him"))
     else:
         print("word is not in vocabulary")
 
+
+# dictionary containing example entities for each label
+def set_label_dict(label_dict):
+    global label_representatives
+    label_representatives = label_dict
+
+
+def classify_entity(entity, labels):
+    global label_representatives
+    best_label = None
+    best_similarity = float("inf")
+
+    # Track how many similarity lookups found word vectors and how many did not
+    has_vecs = 0
+    has_no_vecs = 0
+
+    # Find the label whose closest representative minimizes similarity();
+    # smaller values mean closer words, float("inf") means no vector was available
+    for label in labels:
+        # copy the representative set so the shared label dict is not mutated across calls
+        representatives = set(label_representatives[label])
+        # do not cheat if the entity is one of the representatives already
+        if entity in representatives:
+            representatives.remove(entity)
+
+        best_for_label = float("inf")
+        for representative in representatives:
+            sim = similarity(entity, representative)
+            if sim == float("inf"):
+                has_no_vecs += 1
+            else:
+                has_vecs += 1
+            best_for_label = min(best_for_label, sim)
+
+        if best_for_label < best_similarity:
+            best_similarity = best_for_label
+            best_label = label
+
+    # print(f"{has_vecs} entities had word vectors, {has_no_vecs} entities did not have word vectors")
+    return best_label if best_label else labels[0]
+
+
+load_pretrained()
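To see the decision rule of classify_entity in isolation, here is a self-contained toy version of the nearest-representative classification. The 2-d vectors, entities, and labels are invented for illustration and stand in for the real word embeddings:

import math

# Hand-made 2-d "embeddings", purely illustrative.
vectors = {
    "Paris": (0.9, 0.1), "Berlin": (0.85, 0.2),
    "Einstein": (0.1, 0.95), "Curie": (0.15, 0.9),
    "Hamburg": (0.88, 0.15),
}

label_representatives = {"location": {"Paris", "Berlin"}, "person": {"Einstein", "Curie"}}

def cosine_distance(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norms = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(x * x for x in b))
    return 1 - dot / norms

def classify(entity, labels):
    # Pick the label whose closest representative is nearest to the entity.
    return min(
        labels,
        key=lambda lbl: min(
            cosine_distance(vectors[entity], vectors[r]) for r in label_representatives[lbl]
        ),
    )

print(classify("Hamburg", ["location", "person"]))  # -> "location"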