"""Demo for inference.

The user enters a sentence with the target word(s) enclosed in asterisks and
our trained BERT model predicts if the target word is literal or non-literal.
"""

import json
import re


def extract_target_words(input_string):
    """Return every asterisk-delimited span of *input_string*.

    Each span is returned without its surrounding whitespace, in order of
    appearance. An input without a pair of asterisks yields an empty list.

    >>> extract_target_words("I love *New York*")
    ['New York']
    """
    return [span.strip() for span in re.findall(r"\*(.*?)\*", input_string)]


def remove_asterisks_and_split(input_string):
    """Return the whitespace-separated words of *input_string*, asterisks removed.

    >>> remove_asterisks_and_split("I love *New York*")
    ['I', 'love', 'New', 'York']
    """
    return input_string.replace("*", "").split()


def find_target_position(split_sentence, split_target):
    """Locate the first occurrence of *split_target* in *split_sentence*.

    Args:
        split_sentence: The sentence as a list of words.
        split_target: The target expression as a list of words.

    Returns:
        ``[start, end]`` such that ``split_sentence[start:end]`` equals
        ``split_target`` (``end`` is exclusive). When the target does not
        occur, the sentinel ``[-1, 0]`` is returned; callers must check for
        it before using the result as a slice.
    """
    width = len(split_target)
    for start in range(len(split_sentence)):
        if split_sentence[start:start + width] == split_target:
            return [start, start + width]
    return [-1, 0]


if __name__ == "__main__":
    # torch and the project modules are only needed when the demo actually
    # runs; importing them here keeps the helper functions above importable
    # on their own (e.g. from unit tests).
    import torch
    import preprocess
    import train
    from torch.utils.data import DataLoader, RandomSampler

    # Get user input
    print("Enter a sentence and enclose the target word(s) between asterisks (e.g. \"I love *New York*\"): ")
    sentence = input()

    target_word = extract_target_words(sentence)
    if not target_word or not target_word[0]:
        # Without a marked, non-empty target there is nothing to classify.
        raise SystemExit("No target word found: enclose it between asterisks, e.g. \"I love *New York*\".")

    # Only the first marked span is used as the target.
    split_target = target_word[0].split()
    split_sentence = remove_asterisks_and_split(sentence)

    pos = find_target_position(split_sentence, split_target)
    if pos[0] < 0:
        # The words are compared exactly, so a marker placed inside a word or
        # glued to punctuation prevents the match.
        raise SystemExit(f"Could not locate the target word {target_word[0]!r} in the sentence.")

    # pos is already a list of two ints; json.loads() only accepts
    # str/bytes/bytearray and would raise TypeError here.
    target_json = pos
    print(f"The target word is {target_word} and at the position {pos}.")

    print("Now enter the label: 0 for literal, 1 for non-literal")
    label = int(input())
-#label_json=json.loads(label) -print("Is this your target word: ", sentence[target_json[0]: target_json[1]]) +# Convert to data sample for BERT data_sample = [{"sentence": sentence, "pos": target_json, "label": label}] print(data_sample) -tokenizer=AutoTokenizer.from_pretrained("bert-base-uncased") +tokenizer=AutoTokenizer.from_pretrained("bert-base-uncased") input_as_dataset=preprocess.tokenizer_new(tokenizer, data_sample, max_length=512) + + # Load model device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model=models.WordClassificationModel.from_pretrained("bert-base-uncased") -model_path = "saved_models/bert_.pth" + +model_path = "saved_models/bert.pth" model = torch.load(model_path, map_location=device) -model.eval() +model.eval() train_sampler = RandomSampler(data_sample) train_dataloader=DataLoader(data_sample, sampler=train_sampler, batch_size=1) @@ -52,9 +87,9 @@ for batch in train_dataloader: end_positions=batch[4] outputs=model(**inputs) prediction=torch.argmax(outputs[0]) - if prediciton == 1: + if prediction == 1: print("metonymy") - elif prediciton == 0: + elif prediction == 0: print("literal") - #print("Outputs: ", + diff --git a/documentation/.DS_Store b/documentation/.DS_Store index 2895fc477ea9375256239866164b5716e6cc9944..9f8fbe41b9791d8babd3fac3c4f272fb61652b0a 100644 Binary files a/documentation/.DS_Store and b/documentation/.DS_Store differ