Skip to content
Snippets Groups Projects
Commit 5352425a authored by ih322@uni-heidelberg.de's avatar ih322@uni-heidelberg.de
Browse files

T5 NEC use highest probability label

parent 887ffc93
No related branches found
No related tags found
No related merge requests found
import torch
from torch.nn.functional import softmax
from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq
from datasets import Dataset, DatasetDict
......@@ -14,33 +16,55 @@ print("Finished loading model: T5 NLI")
def infer_nli(premise, hypothesis):
input_text = f"nli hypothesis: {hypothesis} premise: {premise}"
# print("tokenize")
inputs = tokenizer(input_text, return_tensors="pt")
# print("generate")
output_ids = model.generate(**inputs)
# print("decode")
result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
return result
def infer_nli_probabilities(premise, hypothesis):
input_text = f"nli hypothesis: {hypothesis} premise: {premise}"
inputs = tokenizer(input_text, return_tensors="pt")
decoder_input_ids = torch.tensor([[tokenizer.pad_token_id]])
with torch.no_grad():
outputs = model(**inputs, decoder_input_ids=decoder_input_ids)
logits = outputs.logits
# Use this code if we also want to get the classification:
# predicted_ids = torch.argmax(logits, dim=-1)
# predicted_label = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
probabilities = softmax(logits, dim=-1)
label_tokens = ["entailment", "contradiction"]
label_ids = [tokenizer.encode(label, add_special_tokens=False)[0] for label in label_tokens]
label_probs = {label: probabilities[0, 0, token_id].item() for label, token_id in zip(label_tokens, label_ids)}
return label_probs
def classify_entity(sentence, entity, labels):
best_label = None
best_prob = 0
# Find label with highest probability of entailment
for label in labels:
# print(f"Label: {label}")
hypothesis = f"{entity} is {label}"
result = infer_nli(sentence, hypothesis)
# print(f"Hypothesis: {hypothesis}, Result: {result}")
if result == "entailment":
return label # Return immediately if entailment is found
elif result == "neutral" and best_label is None:
best_label = label # Store the first neutral label as a fallback
label_probs = infer_nli_probabilities(sentence, hypothesis)
return best_label if best_label else labels[0]
# print(f"label: {label}, entailment prob: {label_probs["entailment"]}, contradiction prob: {label_probs["contradiction"]}")
entailment_prob = label_probs["entailment"]
if entailment_prob > best_prob:
best_prob = entailment_prob
best_label = label
return best_label if best_label else labels[0]
def preprocess_data(sample):
input_text = f"nli hypothesis: {sample['hypothesis']} premise: {sample['premise']}"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment