Skip to content
Snippets Groups Projects
Commit a9cdbec8 authored by Thomas Wolf's avatar Thomas Wolf
Browse files

Cleanup, PEP 8 conventions, comments

parent c877e95e
No related branches found
No related tags found
No related merge requests found
Showing
with 76 additions and 51 deletions
...@@ -23,10 +23,11 @@ Scripts for executing code on the Computerlinguistik cluster or BwUniCluster are ...@@ -23,10 +23,11 @@ Scripts for executing code on the Computerlinguistik cluster or BwUniCluster are
The experiments conducted as part of this project are located in [`/src/experiments`](src/experiments). The experiments conducted as part of this project are located in [`/src/experiments`](src/experiments).
The results for the named entity classification experiment are located in [`/src/experiments/NEC_evaluation/results/`](src/experiments/NEC_evaluation/results/). The evaluation results for the named entity classification experiment are located in [`/src/experiments/NEC_evaluation/results/`](src/experiments/NEC_evaluation/results/),
the achieved model accuracies can be displayed by running `read_NEC_metrics("results")` in [`/src/experiments/NEC_evaluation/evaluation.py`](src/experiments/NEC_evaluation/evaluation.py).
## Setup and Requirements ## Setup and Requirements
Note: A CUDA-enabled GPU is required to run the finetuning and LLM-based experiments Note: A CUDA-enabled GPU is required to run the finetuning and LLM-based experiments.
1. Ensure you have installed Python 3.8 or newer 1. Ensure you have installed Python 3.8 or newer
2. Run `pip install requirements.txt` to install necessary dependencies 2. Run `pip install requirements.txt` to install necessary dependencies
......
...@@ -28,6 +28,7 @@ fine_labels = ["actor", "architect", "artist", "athlete", "author", "coach", "di ...@@ -28,6 +28,7 @@ fine_labels = ["actor", "architect", "artist", "athlete", "author", "coach", "di
entries_coarse = [] entries_coarse = []
entries_fine = [] entries_fine = []
def download_and_unpack(): def download_and_unpack():
gzip_path = os.path.join(base_dir, "train.data.gz") gzip_path = os.path.join(base_dir, "train.data.gz")
data_path = os.path.join(base_dir, "train.data") data_path = os.path.join(base_dir, "train.data")
...@@ -54,11 +55,12 @@ def download_and_unpack(): ...@@ -54,11 +55,12 @@ def download_and_unpack():
return data_path return data_path
def load_or_preprocess(entries_coarse, entries_fine): def load_or_preprocess(entries_coarse, entries_fine):
coarse_serialized_path = os.path.join(base_dir, "figer-coarse.pickle") coarse_serialized_path = os.path.join(base_dir, "figer-coarse.pickle")
fine_serialized_path = os.path.join(base_dir, "figer-fine.pickle") fine_serialized_path = os.path.join(base_dir, "figer-fine.pickle")
#serialized already preprocessed data not found. Pre-process it not # serialized already preprocessed data not found. Pre-process it not
if not os.path.exists(coarse_serialized_path) or not os.path.exists(fine_serialized_path): if not os.path.exists(coarse_serialized_path) or not os.path.exists(fine_serialized_path):
if os.path.exists(coarse_serialized_path): if os.path.exists(coarse_serialized_path):
os.remove(coarse_serialized_path) os.remove(coarse_serialized_path)
...@@ -122,13 +124,13 @@ def load_or_preprocess(entries_coarse, entries_fine): ...@@ -122,13 +124,13 @@ def load_or_preprocess(entries_coarse, entries_fine):
coarse_label_set = set() coarse_label_set = set()
fine_label_set = set() fine_label_set = set()
for label in mention.labels: for label in mention.labels:
if (new_label := label_mapping.get(label)): if new_label := label_mapping.get(label):
if new_label[0] != "": if new_label[0] != "":
coarse_label_set.add(new_label[0]) coarse_label_set.add(new_label[0])
if new_label[1] != "": if new_label[1] != "":
fine_label_set.add(new_label[1]) fine_label_set.add(new_label[1])
#only add entries that have both a coarse and a fine label # only add entries that have both a coarse and a fine label
#this way both coarse and fine datasets are identical except for the label # this way both coarse and fine datasets are identical except for the label
if len(coarse_label_set) > 0 and len(fine_label_set) > 0: if len(coarse_label_set) > 0 and len(fine_label_set) > 0:
coarse_label_list = list(map(lambda x: (entity_name, x), coarse_label_set)) coarse_label_list = list(map(lambda x: (entity_name, x), coarse_label_set))
entries_coarse.append((sentence, coarse_label_list)) entries_coarse.append((sentence, coarse_label_list))
...@@ -139,7 +141,8 @@ def load_or_preprocess(entries_coarse, entries_fine): ...@@ -139,7 +141,8 @@ def load_or_preprocess(entries_coarse, entries_fine):
i += 1 i += 1
print("") print("")
print(f"Figer dataset: Preprocessing completed. Writing result into {coarse_serialized_path} and {fine_serialized_path} now...") print(
f"Figer dataset: Preprocessing completed. Writing result into {coarse_serialized_path} and {fine_serialized_path} now...")
with open(coarse_serialized_path, "wb") as coarse_file: with open(coarse_serialized_path, "wb") as coarse_file:
pickle.dump(entries_coarse, coarse_file) pickle.dump(entries_coarse, coarse_file)
with open(fine_serialized_path, "wb") as fine_file: with open(fine_serialized_path, "wb") as fine_file:
...@@ -148,9 +151,10 @@ def load_or_preprocess(entries_coarse, entries_fine): ...@@ -148,9 +151,10 @@ def load_or_preprocess(entries_coarse, entries_fine):
os.remove(data_path) os.remove(data_path)
print("Figer dataset: Writing complete.") print("Figer dataset: Writing complete.")
#just load already preprocessed data from pickle file # just load already preprocessed data from pickle file
else: else:
print(f"Figer dataset: Importing already preprocessed data from {coarse_serialized_path} and {fine_serialized_path} now...") print(
f"Figer dataset: Importing already preprocessed data from {coarse_serialized_path} and {fine_serialized_path} now...")
with open(coarse_serialized_path, "rb") as coarse_file: with open(coarse_serialized_path, "rb") as coarse_file:
entries_coarse += pickle.load(coarse_file) entries_coarse += pickle.load(coarse_file)
with open(fine_serialized_path, "rb") as fine_file: with open(fine_serialized_path, "rb") as fine_file:
......
...@@ -37,23 +37,28 @@ def get_annotated_sentences(dataset, test_instances=10, train_test_split=1.0, tr ...@@ -37,23 +37,28 @@ def get_annotated_sentences(dataset, test_instances=10, train_test_split=1.0, tr
random.seed(42) random.seed(42)
if dataset == "CoNLL": if dataset == "CoNLL":
sentences = random.sample(conll_interface.get_annotated_sentences(), min(test_instances, len(conll_interface.get_annotated_sentences()))) sentences = random.sample(conll_interface.get_annotated_sentences(),
min(test_instances, len(conll_interface.get_annotated_sentences())))
elif dataset == "Pile-NER-type": elif dataset == "Pile-NER-type":
sentences = random.sample(pile_interface.get_annotated_sentences(), min(test_instances, len(pile_interface.get_annotated_sentences()))) sentences = random.sample(pile_interface.get_annotated_sentences(),
min(test_instances, len(pile_interface.get_annotated_sentences())))
elif dataset == "FIGER-coarse": elif dataset == "FIGER-coarse":
sentences = random.sample(figer_interface.get_annotated_sentences_coarse(), min(test_instances, len(figer_interface.get_annotated_sentences_coarse()))) sentences = random.sample(figer_interface.get_annotated_sentences_coarse(),
min(test_instances, len(figer_interface.get_annotated_sentences_coarse())))
elif dataset == "FIGER-fine": elif dataset == "FIGER-fine":
sentences = random.sample(figer_interface.get_annotated_sentences_fine(), min(test_instances, len(figer_interface.get_annotated_sentences_fine()))) sentences = random.sample(figer_interface.get_annotated_sentences_fine(),
min(test_instances, len(figer_interface.get_annotated_sentences_fine())))
else: else:
raise Exception raise Exception
split_pos = max(0, min(len(sentences), int(train_test_split * len(sentences)))) split_pos = max(0, min(len(sentences), int(train_test_split * len(sentences))))
if train: if train:
return sentences[split_pos:] return sentences[split_pos:]
else: else:
return sentences[:split_pos] return sentences[:split_pos]
def get_label_dict(dataset, test_instances=10): def get_label_dict(dataset, test_instances=10):
""" """
Returns a dictionary of example entities for each label. Returns a dictionary of example entities for each label.
...@@ -70,4 +75,3 @@ def get_label_dict(dataset, test_instances=10): ...@@ -70,4 +75,3 @@ def get_label_dict(dataset, test_instances=10):
label_dict[labeled_entity[1]].add(labeled_entity[0]) label_dict[labeled_entity[1]].add(labeled_entity[0])
print(f"Label dict: {label_dict}") print(f"Label dict: {label_dict}")
return label_dict return label_dict
https://paperswithcode.com/dataset/conll-2003 # CoNLL https://paperswithcode.com/dataset/conll-2003 # CoNLL
https://paperswithcode.com/sota/entity-typing-on-figer # Figer https://paperswithcode.com/sota/entity-typing-on-figer # Figer
https://www.cs.utexas.edu/~eunsol/html_pages/open_entity.html # Ultra-Fine Entity Typing data https://www.cs.utexas.edu/~eunsol/html_pages/open_entity.html # Ultra-Fine Entity Typing data
(Sources may not be final, did not try to download the datasets yet)
...@@ -10,3 +10,4 @@ tqdm ...@@ -10,3 +10,4 @@ tqdm
gensim gensim
accelerate accelerate
matplotlib matplotlib
torch
""" """
Provides functions for probing / NER that can be called with model name and the required data. Provides functions for probing / NER that can be called with model name and the required data.
Makes evaluating models easier. Makes evaluating multiple models easier.
""" """
from src.models.LLM_interface import available_models as llms from src.models.LLM_interface import available_models as llms
from src.models.GLiNER import find_entities as find_entities_gliner from src.models.GLiNER import find_entities as find_entities_gliner
...@@ -49,7 +49,7 @@ def find_entities(model_name, sentence, labels): ...@@ -49,7 +49,7 @@ def find_entities(model_name, sentence, labels):
def set_label_dict(model_name, label_dict): def set_label_dict(model_name, label_dict):
""" """
NER. Sets the label dictionary required for the Word2Vec model Sets the label dictionary required for the Word2Vec model
""" """
if model_name == "Word2Vec": if model_name == "Word2Vec":
return set_label_dict_word2vec(label_dict) return set_label_dict_word2vec(label_dict)
......
...@@ -6,6 +6,6 @@ NEC evaluation for GLiNER is part of /NEC_evaluation. ...@@ -6,6 +6,6 @@ NEC evaluation for GLiNER is part of /NEC_evaluation.
from src.metrics import NER_metrics, read_NER_metrics from src.metrics import NER_metrics, read_NER_metrics
for dataset in ["CoNLL", "FIGER-coarse", "FIGER-fine"]: for dataset in ["CoNLL", "FIGER-coarse", "FIGER-fine"]:
#NER_metrics("GLiNER", dataset, f"results_{dataset}", test_instances=1000) # NER_metrics("GLiNER", dataset, f"results_{dataset}", test_instances=1000)
print(f"\nResults {dataset}:") print(f"\nResults {dataset}:")
read_NER_metrics(f"results_{dataset}") read_NER_metrics(f"results_{dataset}")
...@@ -41,7 +41,7 @@ def run_NEC_tests(model_name, dataset, results_dir, test_instances=10): ...@@ -41,7 +41,7 @@ def run_NEC_tests(model_name, dataset, results_dir, test_instances=10):
for idx, entry in enumerate(tqdm(data, desc="Processing Instances", unit="instance")): for idx, entry in enumerate(tqdm(data, desc="Processing Instances", unit="instance")):
sentence = entry[0] sentence = entry[0]
entity_dict = {} # If there are multiple labels assigned to one entity, entity_dict = {} # If there are multiple labels assigned to one entity,
# it should count as correct if the model predicts one of them. # it should count as correct if the model prediction is one of them.
for entity in entry[1]: for entity in entry[1]:
entity_name = entity[0] entity_name = entity[0]
true_label = entity[1] true_label = entity[1]
...@@ -69,7 +69,7 @@ def run_NEC_tests(model_name, dataset, results_dir, test_instances=10): ...@@ -69,7 +69,7 @@ def run_NEC_tests(model_name, dataset, results_dir, test_instances=10):
def run_NEC_tests_all(): def run_NEC_tests_all():
models = ["GLiNER", "Llama-3.1-8B", "T5-NLI", "T5-MLM-label", "T5-MLM-entity", "DeepSeek-R1-Distill-Qwen-32B"] # "Word2Vec" models = ["DeepSeek-R1-Distill-Qwen-32B", "GLiNER", "Llama-3.1-8B", "T5-NLI", "T5-MLM-label", "T5-MLM-entity", "Word2Vec"]
datasets = ["CoNLL", "FIGER-coarse", "FIGER-fine"] # "Pile-NER-type"] datasets = ["CoNLL", "FIGER-coarse", "FIGER-fine"] # "Pile-NER-type"]
for model in models: for model in models:
for dataset in datasets: for dataset in datasets:
......
...@@ -22,7 +22,7 @@ Your Task: ...@@ -22,7 +22,7 @@ Your Task:
""" """
nli_system_prompt = """You are part of a named entity classification pipeline. nli_system_prompt = """You are part of a named entity classification pipeline.
Given an entity, find the best fitting label of the provided labels (no not invent your own labels!). Given an entity, find the best fitting label of the provided labels (do not invent your own labels!).
Choose the label so that the sentence "Target entity is a chosen_label" makes the most sense. Choose the label so that the sentence "Target entity is a chosen_label" makes the most sense.
Mark your result like this for easy extraction: <answer>predicted_class</answer>. Mark your result like this for easy extraction: <answer>predicted_class</answer>.
......
""" """
Uses the csv file generated in NER_with_LLMs.py for evaluation. Evaluates LLMs on the NER task.
""" """
from src.metrics import NER_metrics, read_NER_metrics from src.metrics import NER_metrics, read_NER_metrics
...@@ -13,5 +13,5 @@ def run_test(): ...@@ -13,5 +13,5 @@ def run_test():
NER_metrics(model, dataset, results_dir="results", test_instances=100) NER_metrics(model, dataset, results_dir="results", test_instances=100)
#run_test() # run_test()
read_NER_metrics("results") read_NER_metrics("results")
import data.data_manager as data_manager import data.data_manager as data_manager
from src.models.T5_MLM_entity import finetune_model, set_label_dict from src.models.T5_MLM_entity import finetune_model, set_label_dict
def finetune_t5(dataset): def finetune_t5(dataset):
print("start") print("start")
annotated_sentences = data_manager.get_annotated_sentences(dataset, 2000, train_test_split=0.5, train=True) annotated_sentences = data_manager.get_annotated_sentences(dataset, 2000, train_test_split=0.5, train=True)
labels = data_manager.get_labels(dataset)
sentences = [] sentences = []
entities = [] entities = []
...@@ -25,9 +25,11 @@ def finetune_t5(dataset): ...@@ -25,9 +25,11 @@ def finetune_t5(dataset):
epochs = 150 epochs = 150
label_dict = data_manager.get_label_dict(dataset, 1000) label_dict = data_manager.get_label_dict(dataset, 1000)
set_label_dict(label_dict) set_label_dict(label_dict)
finetune_model(sentences, entities, labels, output_dir=f"./src/models/t5_mlm_entity_finetuned_model/pretrained_{dataset}_epoch{epochs}", epochs=epochs) finetune_model(sentences, entities, labels,
output_dir=f"./src/models/t5_mlm_entity_finetuned_model/pretrained_{dataset}_epoch{epochs}",
epochs=epochs)
finetune_t5("FIGER-coarse") finetune_t5("FIGER-coarse")
import data.data_manager as data_manager import data.data_manager as data_manager
from src.models.T5_MLM_label import finetune_model from src.models.T5_MLM_label import finetune_model
def finetune_t5(dataset): def finetune_t5(dataset):
print("start") print("start")
annotated_sentences = data_manager.get_annotated_sentences(dataset, 2000, train_test_split=0.5, train=True) annotated_sentences = data_manager.get_annotated_sentences(dataset, 2000, train_test_split=0.5, train=True)
labels = data_manager.get_labels(dataset)
sentences = [] sentences = []
entities = [] entities = []
...@@ -25,6 +25,9 @@ def finetune_t5(dataset): ...@@ -25,6 +25,9 @@ def finetune_t5(dataset):
epochs = 150 epochs = 150
finetune_model(sentences, entities, labels, output_dir=f"./src/models/t5_mlm_label_finetuned_model/pretrained_{dataset}_epoch{epochs}", epochs=epochs) finetune_model(sentences, entities, labels,
output_dir=f"./src/models/t5_mlm_label_finetuned_model/pretrained_{dataset}_epoch{epochs}",
epochs=epochs)
finetune_t5("FIGER-coarse") finetune_t5("FIGER-coarse")
import data.data_manager as data_manager import data.data_manager as data_manager
from src.models.T5_NLI import finetune_model from src.models.T5_NLI import finetune_model
def finetune_t5(dataset): def finetune_t5(dataset):
print("start") print("start")
annotated_sentences = data_manager.get_annotated_sentences(dataset, 2000, train_test_split=0.5, train=True) annotated_sentences = data_manager.get_annotated_sentences(dataset, 2000, train_test_split=0.5, train=True)
...@@ -27,6 +28,8 @@ def finetune_t5(dataset): ...@@ -27,6 +28,8 @@ def finetune_t5(dataset):
epochs = 150 epochs = 150
finetune_model(premises, hypotheses, entailment, output_dir=f"./src/models/t5_nli_finetuned_model/pretrained_{dataset}_epoch{epochs}", epochs=epochs) finetune_model(premises, hypotheses, entailment,
output_dir=f"./src/models/t5_nli_finetuned_model/pretrained_{dataset}_epoch{epochs}", epochs=epochs)
finetune_t5("FIGER-coarse") finetune_t5("FIGER-coarse")
""" """
This file contains functions calculating precision, recall and f1 score for NER. (todo NLI) This file contains functions calculating precision, recall and F1 score for NER.
Input format: two lists like this: [('entity', 'label'), ('entity', 'label'), ('entity', 'label')]] Input format: two lists like this: [('entity', 'label'), ('entity', 'label'), ('entity', 'label')]]
""" """
import os import os
......
from transformers import pipeline
unmasker = pipeline('fill-mask', model='bert-base-uncased')
print(unmasker("Hello I'm a [MASK] model."))
...@@ -10,6 +10,7 @@ print("Finished loading model: GLiNER") ...@@ -10,6 +10,7 @@ print("Finished loading model: GLiNER")
def find_entities(sentence, labels): def find_entities(sentence, labels):
"""Performs NER."""
entities = model.predict_entities(sentence, labels) entities = model.predict_entities(sentence, labels)
entity_list = [] entity_list = []
for entity in entities: for entity in entities:
...@@ -19,10 +20,10 @@ def find_entities(sentence, labels): ...@@ -19,10 +20,10 @@ def find_entities(sentence, labels):
def classify_entity(sentence, entity, labels): def classify_entity(sentence, entity, labels):
"""Performs NEC using NER."""
entity_list = find_entities(sentence, labels) entity_list = find_entities(sentence, labels)
for e in entity_list: for e in entity_list:
if e[0] == entity: if e[0] == entity:
return e[1] # Return label return e[1] # Return label
return "Target entity not found during NER." return "Target entity not found during NER."
import torch
import random import random
import numpy as np import numpy as np
from torch.nn.functional import softmax
from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq
from datasets import Dataset, DatasetDict from datasets import Dataset, DatasetDict
model_name = "google-t5/t5-base" model_name = "google-t5/t5-base"
def load_base(): def load_base():
global model global model
global tokenizer global tokenizer
...@@ -14,7 +13,7 @@ def load_base(): ...@@ -14,7 +13,7 @@ def load_base():
tokenizer = T5Tokenizer.from_pretrained(model_name) tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name) model = T5ForConditionalGeneration.from_pretrained(model_name)
print("Finished loading model: T5 MLM entity") print("Finished loading model: T5 MLM entity")
def load_finetuned(input_dir): def load_finetuned(input_dir):
global model global model
...@@ -29,6 +28,7 @@ def set_label_dict(label_dict): ...@@ -29,6 +28,7 @@ def set_label_dict(label_dict):
global label_representatives global label_representatives
label_representatives = label_dict label_representatives = label_dict
def classify_entity(sentence, entity, labels): def classify_entity(sentence, entity, labels):
global label_representatives global label_representatives
...@@ -127,5 +127,6 @@ def finetune_model(sentences, entities, labels, output_dir, epochs=10): ...@@ -127,5 +127,6 @@ def finetune_model(sentences, entities, labels, output_dir, epochs=10):
model.save_pretrained(output_dir) model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir)
load_base() load_base()
# load_finetuned("./src/models/t5_mlm_entity_finetuned_model/checkpoints/checkpoint-12200") # load_finetuned("./src/models/t5_mlm_entity_finetuned_model/checkpoints/checkpoint-12200")
import torch
import numpy as np import numpy as np
from torch.nn.functional import softmax
from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq
from datasets import Dataset, DatasetDict from datasets import Dataset, DatasetDict
model_name = "google-t5/t5-base" model_name = "google-t5/t5-base"
def load_base(): def load_base():
global model global model
global tokenizer global tokenizer
...@@ -13,7 +12,7 @@ def load_base(): ...@@ -13,7 +12,7 @@ def load_base():
tokenizer = T5Tokenizer.from_pretrained(model_name) tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name) model = T5ForConditionalGeneration.from_pretrained(model_name)
print("Finished loading model: T5 MLM label") print("Finished loading model: T5 MLM label")
def load_finetuned(input_dir): def load_finetuned(input_dir):
global model global model
...@@ -103,5 +102,6 @@ def finetune_model(sentences, entities, labels, output_dir, epochs=10): ...@@ -103,5 +102,6 @@ def finetune_model(sentences, entities, labels, output_dir, epochs=10):
model.save_pretrained(output_dir) model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir)
load_base() load_base()
# load_finetuned("./src/models/t5_mlm_label_finetuned_model/checkpoints/checkpoint-9638") # load_finetuned("./src/models/t5_mlm_label_finetuned_model/checkpoints/checkpoint-9638")
...@@ -8,6 +8,7 @@ label_map = {True: "entailment", False: "contradiction"} ...@@ -8,6 +8,7 @@ label_map = {True: "entailment", False: "contradiction"}
# Use t5-base for testing because it is smaller # Use t5-base for testing because it is smaller
model_name = "google-t5/t5-base" model_name = "google-t5/t5-base"
# model_name = "google/t5_xxl_true_nli_mixture" # model_name = "google/t5_xxl_true_nli_mixture"
def load_base(): def load_base():
...@@ -17,7 +18,7 @@ def load_base(): ...@@ -17,7 +18,7 @@ def load_base():
tokenizer = T5Tokenizer.from_pretrained(model_name) tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name) model = T5ForConditionalGeneration.from_pretrained(model_name)
print("Finished loading model: T5 NLI") print("Finished loading model: T5 NLI")
def load_finetuned(input_dir): def load_finetuned(input_dir):
global model global model
...@@ -27,6 +28,7 @@ def load_finetuned(input_dir): ...@@ -27,6 +28,7 @@ def load_finetuned(input_dir):
model = T5ForConditionalGeneration.from_pretrained(input_dir) model = T5ForConditionalGeneration.from_pretrained(input_dir)
print(f"Finished loading model: T5 NLI finetuned") print(f"Finished loading model: T5 NLI finetuned")
def infer_nli(premise, hypothesis): def infer_nli(premise, hypothesis):
input_text = f"nli hypothesis: {hypothesis} premise: {premise}" input_text = f"nli hypothesis: {hypothesis} premise: {premise}"
inputs = tokenizer(input_text, return_tensors="pt") inputs = tokenizer(input_text, return_tensors="pt")
...@@ -35,6 +37,7 @@ def infer_nli(premise, hypothesis): ...@@ -35,6 +37,7 @@ def infer_nli(premise, hypothesis):
return result return result
def infer_nli_probabilities(premise, hypothesis): def infer_nli_probabilities(premise, hypothesis):
input_text = f"nli hypothesis: {hypothesis} premise: {premise}" input_text = f"nli hypothesis: {hypothesis} premise: {premise}"
...@@ -79,6 +82,7 @@ def classify_entity(sentence, entity, labels): ...@@ -79,6 +82,7 @@ def classify_entity(sentence, entity, labels):
return best_label if best_label else labels[0] return best_label if best_label else labels[0]
def preprocess_data(sample): def preprocess_data(sample):
input_text = f"nli hypothesis: {sample['hypothesis']} premise: {sample['premise']}" input_text = f"nli hypothesis: {sample['hypothesis']} premise: {sample['premise']}"
target_text = label_map[bool(sample['entailment'])] target_text = label_map[bool(sample['entailment'])]
...@@ -138,5 +142,6 @@ def finetune_model(premises, hypotheses, entailment, output_dir, epochs=10): ...@@ -138,5 +142,6 @@ def finetune_model(premises, hypotheses, entailment, output_dir, epochs=10):
model.save_pretrained(output_dir) model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir)
load_base() load_base()
# load_finetuned("./src/models/t5_nli_finetuned_model/checkpoints/checkpoint-85500") # load_finetuned("./src/models/t5_nli_finetuned_model/checkpoints/checkpoint-85500")
...@@ -2,16 +2,19 @@ from gensim.models import Word2Vec ...@@ -2,16 +2,19 @@ from gensim.models import Word2Vec
import gensim.downloader as api import gensim.downloader as api
import string import string
def train(sentences): def train(sentences):
global model global model
# make lowercase, remove punctuation, split at whitespace # make lowercase, remove punctuation, split at whitespace
sentences = [sentence.lower().translate(str.maketrans('', '', string.punctuation)).split() for sentence in sentences] sentences = [sentence.lower().translate(str.maketrans('', '', string.punctuation)).split() for sentence in
sentences]
model = Word2Vec(vector_size=100, window=5, min_count=1, workers=4) model = Word2Vec(vector_size=100, window=5, min_count=1, workers=4)
model.build_vocab(sentences) model.build_vocab(sentences)
model.train(sentences, total_examples=model.corpus_count, epochs=model.epochs) model.train(sentences, total_examples=model.corpus_count, epochs=model.epochs)
print(model.wv.most_similar("him")) print(model.wv.most_similar("him"))
def load_pretrained(): def load_pretrained():
global model global model
print("Loading model: Word2Vec Pretrained Google News") print("Loading model: Word2Vec Pretrained Google News")
...@@ -20,12 +23,14 @@ def load_pretrained(): ...@@ -20,12 +23,14 @@ def load_pretrained():
model.trainables = None model.trainables = None
print("Finished loading model: Word2Vec Pretrained Google News") print("Finished loading model: Word2Vec Pretrained Google News")
def similarity(word1, word2): def similarity(word1, word2):
if word1 in model.wv.key_to_index and word2 in model.wv.key_to_index: if word1 in model.wv.key_to_index and word2 in model.wv.key_to_index:
return model.wv.similarity(word1, word2) return model.wv.similarity(word1, word2)
else: else:
return float("inf") return float("inf")
def print_nearest(word): def print_nearest(word):
global model global model
if word in model.wv.key_to_index: if word in model.wv.key_to_index:
...@@ -33,12 +38,13 @@ def print_nearest(word): ...@@ -33,12 +38,13 @@ def print_nearest(word):
else: else:
print("word is not in vocabulary") print("word is not in vocabulary")
# dictionary containing example values for each label # dictionary containing example values for each label
def set_label_dict(label_dict): def set_label_dict(label_dict):
global label_representatives global label_representatives
label_representatives = label_dict label_representatives = label_dict
def classify_entity(entity, labels): def classify_entity(entity, labels):
global label_representatives global label_representatives
best_label = None best_label = None
...@@ -64,15 +70,15 @@ def classify_entity(entity, labels): ...@@ -64,15 +70,15 @@ def classify_entity(entity, labels):
has_no_vecs += 1 has_no_vecs += 1
else: else:
has_vecs += 1 has_vecs += 1
best_for_label = min(best_for_label, sim) best_for_label = min(best_for_label, sim)
if (best_for_label < best_similarity): if best_for_label < best_similarity:
best_similarity = best_for_label best_similarity = best_for_label
best_label = label best_label = label
# print(f"{has_vecs} entities had word vectors, {has_no_vecs} entities did not have word vectors") # print(f"{has_vecs} entities had word vectors, {has_no_vecs} entities did not have word vectors")
return best_label if best_label else labels[0] return best_label if best_label else labels[0]
#load_pretrained() # load_pretrained()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment