Skip to content
Snippets Groups Projects
Commit 655d7e11 authored by friebolin's avatar friebolin
Browse files

Update inference

parent 4466681b
No related branches found
No related tags found
No related merge requests found
"""Demo for inference: User enters a sentence and our trained BERT model predicts if the target word is literal or non-literal"""
import json
import torch
import preprocess
......@@ -8,34 +10,67 @@ import train
from torch.utils.data import DataLoader, RandomSampler
# Get user input
print("Enter a sentence: ")
print("Enter a sentence and enclose the target word(s) between asteriks (e.g. \"I love *New York*\"): ")
sentence = input()
sentence = sentence.split()
print("Specify the target word position using square brackets (e.g. [0,2])")
target_pos = input()
target_json = json.loads(target_pos)
print(type(target_json))
print("Enter the label: 0 for literal, 1 for non-literal")
def extract_target_words(input_string):
    """Return every substring of *input_string* enclosed in asterisks.

    Args:
        input_string: Text in which target spans are marked as ``*target*``.

    Returns:
        A list with one entry per ``*...*`` pair, in order of appearance,
        each stripped of leading and trailing whitespace. The list is empty
        when the text contains no pair of asterisks.
    """
    # Non-greedy group so that "*a* and *b*" yields two spans, not one.
    return [span.strip() for span in re.findall(r'\*(.*?)\*', input_string)]
target_word = extract_target_words(sentence)
split_target = target_word[0].split()
def remove_asterisks_and_split(input_string):
    """Drop every asterisk from *input_string* and split it on whitespace.

    Args:
        input_string: Text that may contain ``*`` marker characters.

    Returns:
        The list of whitespace-separated words that remain once all
        asterisks have been deleted. Asterisks are deleted, not replaced by
        a space, so text on both sides of a marker is joined.
    """
    # A literal one-character removal needs no regular expression.
    return input_string.replace("*", "").split()
split_sentence = remove_asterisks_and_split(sentence)
def find_target_position(split_sentence, split_target):
    """Locate the first occurrence of *split_target* inside *split_sentence*.

    Args:
        split_sentence: The sentence as a list of words.
        split_target: The target expression as a list of words.

    Returns:
        ``[start, end]`` where ``start`` is the index of the first target
        word and ``end`` is one past the index of the last target word, so
        that ``split_sentence[start:end] == split_target``. When the target
        does not occur, or is empty, ``[-1, 0]`` is returned.
    """
    width = len(split_target)
    # An empty target would trivially "match" at index 0; report it as
    # not found so the result does not depend on the sentence length.
    if width == 0:
        return [-1, 0]
    # Windows starting past this bound are shorter than the target and
    # therefore can never be equal to it.
    for start in range(len(split_sentence) - width + 1):
        if split_sentence[start:start + width] == split_target:
            return [start, start + width]
    # Not-found value kept as [-1, 0] for compatibility with existing callers.
    return [-1, 0]
pos = find_target_position(split_sentence, split_target)
# find_target_position already returns a Python list. json.loads() accepts
# only str, bytes or bytearray and raises TypeError when handed a list, so
# the position is used directly.
target_json = pos
print(f"The target word is {target_word} and at the position {pos}.")
print("Now enter the label: 0 for literal, 1 for non-literal")
label = int(input())
#label_json=json.loads(label)
print("Is this your target word: ", sentence[target_json[0]: target_json[1]])
# Convert to data sample for BERT
data_sample = [{"sentence": sentence, "pos": target_json, "label": label}]
print(data_sample)
tokenizer=AutoTokenizer.from_pretrained("bert-base-uncased")
tokenizer=AutoTokenizer.from_pretrained("bert-base-uncased")
input_as_dataset=preprocess.tokenizer_new(tokenizer, data_sample, max_length=512)
# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model=models.WordClassificationModel.from_pretrained("bert-base-uncased")
model_path = "saved_models/bert_.pth"
model_path = "saved_models/bert.pth"
model = torch.load(model_path, map_location=device)
model.eval()
model.eval()
train_sampler = RandomSampler(data_sample)
train_dataloader=DataLoader(data_sample, sampler=train_sampler, batch_size=1)
......@@ -52,9 +87,9 @@ for batch in train_dataloader:
end_positions=batch[4]
outputs=model(**inputs)
prediction=torch.argmax(outputs[0])
if prediciton == 1:
if prediction == 1:
print("metonymy")
elif prediciton == 0:
elif prediction == 0:
print("literal")
#print("Outputs: ",
No preview for this file type
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment