diff --git a/src/models/T5.py b/src/models/T5.py
index 08ee89c5930a31615fe83326f0595b623a762701..8e4cfd9958e6d4c95b853f866ed4401314865d27 100644
--- a/src/models/T5.py
+++ b/src/models/T5.py
@@ -1,3 +1,5 @@
+import torch
+from torch.nn.functional import softmax
 from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq
 from datasets import Dataset, DatasetDict
 
@@ -14,33 +16,55 @@ print("Finished loading model: T5 NLI")
 
 def infer_nli(premise, hypothesis):
     input_text = f"nli hypothesis: {hypothesis} premise: {premise}"
-
-    # print("tokenize")
     inputs = tokenizer(input_text, return_tensors="pt")
-
-    # print("generate")
     output_ids = model.generate(**inputs)
-
-    # print("decode")
     result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
     return result
 
+def infer_nli_probabilities(premise, hypothesis):
+    input_text = f"nli hypothesis: {hypothesis} premise: {premise}"
+
+    inputs = tokenizer(input_text, return_tensors="pt")
+
+    decoder_input_ids = torch.tensor([[tokenizer.pad_token_id]])  # decoder start token: score only the first generated token
+
+    with torch.no_grad():
+        outputs = model(**inputs, decoder_input_ids=decoder_input_ids)
+
+    logits = outputs.logits
+
+    # Use this code if we also want to get the classification:
+    # predicted_ids = torch.argmax(logits, dim=-1)
+    # predicted_label = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
+
+    probabilities = softmax(logits, dim=-1)  # distribution over the vocabulary at the first decoder position
+
+    label_tokens = ["entailment", "contradiction"]
+    label_ids = [tokenizer.encode(label, add_special_tokens=False)[0] for label in label_tokens]  # first sub-token of each label
+
+    label_probs = {label: probabilities[0, 0, token_id].item() for label, token_id in zip(label_tokens, label_ids)}
+
+    return label_probs
+
 
 def classify_entity(sentence, entity, labels):
     best_label = None
+    best_prob = 0
 
+    # Find the label with the highest probability of entailment
     for label in labels:
-        # print(f"Label: {label}")
         hypothesis = f"{entity} is {label}"
-        result = infer_nli(sentence, hypothesis)
-        # print(f"Hypothesis: {hypothesis}, Result: {result}")
-        if result == "entailment":
-            return label  # Return immediately if entailment is found
-        elif result == "neutral" and best_label is None:
-            best_label = label  # Store the first neutral label as a fallback
+        label_probs = infer_nli_probabilities(sentence, hypothesis)
 
+        # print(f"label: {label}, entailment prob: {label_probs['entailment']}, contradiction prob: {label_probs['contradiction']}")
+
+        entailment_prob = label_probs["entailment"]
+        if entailment_prob > best_prob:
+            best_prob = entailment_prob
+            best_label = label
+
     return best_label if best_label else labels[0]
 
 
 def preprocess_data(sample):
     input_text = f"nli hypothesis: {sample['hypothesis']} premise: {sample['premise']}"