Skip to content
Snippets Groups Projects
Commit 2814e9a7 authored by Thomas Wolf's avatar Thomas Wolf
Browse files

NLI with LLMs

parent a191f4d7
No related branches found
No related tags found
No related merge requests found
...@@ -8,6 +8,7 @@ from src.models.GLiNER import classify_entity as classify_entity_gliner ...@@ -8,6 +8,7 @@ from src.models.GLiNER import classify_entity as classify_entity_gliner
from src.models.T5_NLI import classify_entity as classify_entity_t5_nli from src.models.T5_NLI import classify_entity as classify_entity_t5_nli
from src.models.T5_MLM_label import classify_entity as classify_entity_t5_mlm_label from src.models.T5_MLM_label import classify_entity as classify_entity_t5_mlm_label
from src.experiments.NER_with_LLMs.NER_with_LLMs import find_entities as find_entities_llm from src.experiments.NER_with_LLMs.NER_with_LLMs import find_entities as find_entities_llm
from src.experiments.NER_with_LLMs.NER_with_LLMs import find_entities_nli as find_entities_nli_llm
def classify_entity(model_name, sentence, entity, labels): def classify_entity(model_name, sentence, entity, labels):
...@@ -28,10 +29,12 @@ def predict_mask_mlm(model_name, masked_sentence, labels): ...@@ -28,10 +29,12 @@ def predict_mask_mlm(model_name, masked_sentence, labels):
""" """
def predict_mask_nli(model_name, masked_sentence, labels): def predict_mask_nli(model_name, sentence, entity, labels):
""" """
Barack Obama was the president of the US. -> Barack Obama is/was a [Mask] Return prediction(s) of model for [Mask] Barack Obama was the president of the US. -> Barack Obama is/was a [Mask] Return prediction(s) of model for [Mask]
""" """
if model_name in llms:
return find_entities_nli_llm(model_name, sentence, entity, labels)
def find_entities(model_name, sentence, labels): def find_entities(model_name, sentence, labels):
......
""" """
Evaluates GLiNER as SotA and plots results using reusable functions in plotter.py. Evaluates GLiNER as SotA.
""" """
from src.metrics import NER_metrics, read_NER_metrics from src.metrics import NER_metrics, read_NER_metrics
......
...@@ -20,16 +20,29 @@ If there is no valid entity in the target sentence, answer "<answer>[]</answer>" ...@@ -20,16 +20,29 @@ If there is no valid entity in the target sentence, answer "<answer>[]</answer>"
""" """
# System prompt for the NLI-style classification path: given one target entity,
# the model must pick exactly one of the supplied labels and wrap it in
# <answer> tags so it can be regex-extracted downstream.
# Fix: the original said "no not invent your own labels", inverting the
# intended instruction; corrected to "do not invent".
nli_system_prompt = """You are part of a named entity classification pipeline.
Given an entity, find the best fitting label of the provided labels (do not invent your own labels!).
Choose the label so that the sentence "Target entity is a chosen_label" makes the most sense.
Mark your result like this for easy extraction: <answer>predicted_class</answer>.
Example:
Labels == ['person', 'organization', 'location', 'miscellaneous']
Sentence: 'Europe rejects German call to boycott British lamb.'
Target Entity: Europe
Desired Result: <answer>organization</answer>
"""
def find_entities(model_name, sentence, labels): def find_entities(model_name, sentence, labels):
"""Gets answer from model as list.""" """Gets answer from model as list using the naive approach of asking the model to find all entities."""
model = LLM.create(model_name) model = LLM.create(model_name)
# We add the system prompt directly to the prompt, because DeepSeek documentation says that this is better. # We add the system prompt directly to the prompt, because DeepSeek documentation says that this is better.
prompt = system_prompt + "\nLabels: " + str(labels) + "\nTarget sentence: " + sentence prompt = system_prompt + "\nLabels: " + str(labels) + "\nTarget sentence: " + sentence
answer = model(prompt) answer = model(prompt)
answer = answer.encode("utf-8").decode("utf-8-sig") answer = answer.encode("utf-8").decode("utf-8-sig")
# Extract content between <answer> tags # Extract content between <answer> tags
match = re.search(r"<answer>(.*?)</answer>", answer, re.DOTALL) # todo: sometimes mysteriously does not work. Why?? match = re.search(r"<answer>(.*?)</answer>", answer, re.DOTALL)
if match: if match:
extracted_text = match.group(1).strip() extracted_text = match.group(1).strip()
extracted_text = extracted_text.replace("\n", "").replace(" ", " ") extracted_text = extracted_text.replace("\n", "").replace(" ", " ")
...@@ -42,4 +55,19 @@ def find_entities(model_name, sentence, labels): ...@@ -42,4 +55,19 @@ def find_entities(model_name, sentence, labels):
print(e) print(e)
return "No answer was marked, here is the model's response: " + answer return "No answer was marked, here is the model's response: " + answer
# todo find entities with NLI / MLM
def find_entities_nli(model_name, sentence, entity, labels):
    """
    Uses NLI to classify a specified entity in a sentence. Returns string entity class.
    """
    llm = LLM.create(model_name)
    # We add the system prompt directly to the prompt, because DeepSeek
    # documentation says that this is better.
    prompt = (
        f"{nli_system_prompt}\nLabels: {labels}"
        f"\nTarget sentence: {sentence}\nTarget entity: {entity}"
    )
    # Strip a possible UTF-8 BOM from the model output before matching.
    raw = llm(prompt).encode("utf-8").decode("utf-8-sig")
    # The predicted class is wrapped in <answer>...</answer> tags by the prompt.
    tagged = re.search(r"<answer>(.*?)</answer>", raw, re.DOTALL)
    if tagged is None:
        return "No answer was marked, here is the model's response: " + raw
    return tagged.group(1).strip()
from src.common_interface import predict_mask_nli
from src.metrics import precision, recall, f1_score
# Manual end-to-end check of the NLI-based entity classification path,
# scoring each tested LLM against one hand-labelled example sentence.
tested_models = ["Llama-3.1-8B", "DeepSeek-R1-Distill-Qwen-32B"]
test_labels = ["person", "organization", "time", "location", "miscellaneous"]
test_sentence = "Apollo 11 was a spaceflight conducted in July 1969 by the United States and launched " \
                "by NASA, sending the astronauts Neil Armstrong and Buzz Aldrin to become the first humans to walk on the moon."
# Gold (entity, label) pairs for the test sentence.
true_labels = [('Apollo 11', 'miscellaneous'), ('July 1969', 'time'),
               ('United States', 'organization'), ('NASA', 'organization'), ('Neil Armstrong', 'person'),
               ('Buzz Aldrin', 'person'), ('moon', 'location')]


def main():
    """Classify every gold entity with each tested model and print P/R/F1."""
    print("Test sentence:\n" + test_sentence)
    for model in tested_models:
        print(f"Testing model {model}...")
        # The gold entity span is given; only its label is predicted.
        predicted_entities = [(entity, predict_mask_nli(model, test_sentence, entity, test_labels))
                              for entity, _ in true_labels]
        print(f"{model} found entities: \n{predicted_entities}")
        print(f"Precision: {precision(true_labels, predicted_entities)}")
        print(f"Recall: {recall(true_labels, predicted_entities)}")
        print(f"F1-score: {f1_score(true_labels, predicted_entities)}\n")


# Guard so that importing this module (e.g. for its constants) does not
# trigger the model calls; previously the loop ran at import time.
if __name__ == "__main__":
    main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment