From 655d7e114bd699a2af5c5d459562cf591abaab23 Mon Sep 17 00:00:00 2001
From: friebolin <friebolin@cl.uni-heidelberg.de>
Date: Fri, 24 Feb 2023 12:18:46 +0100
Subject: [PATCH] Update inference

---
 Code/inference.py       | 66 ++++++++++++++++++++++++++++++----------
 documentation/.DS_Store | Bin 8196 -> 6148 bytes
 2 files changed, 51 insertions(+), 15 deletions(-)

diff --git a/Code/inference.py b/Code/inference.py
index 441d9e2..212d4c5 100644
--- a/Code/inference.py
+++ b/Code/inference.py
@@ -1,3 +1,6 @@
+"""Demo for inference: the user enters a sentence and our trained BERT model predicts whether the target word is literal or non-literal."""
+
+import re
 import json
 import torch
 import preprocess
@@ -8,34 +11,67 @@ import train
 from torch.utils.data import DataLoader, RandomSampler
 
 # Get user input
-print("Enter a sentence: ")
+print("Enter a sentence and enclose the target word(s) in asterisks (e.g. \"I love *New York*\"): ")
 sentence = input()
-sentence = sentence.split()
-print("Specify the target word position using square brackets (e.g. [0,2])")
-target_pos = input()
-target_json = json.loads(target_pos)
-print(type(target_json))
-print("Enter the label: 0 for literal, 1 for non-literal")
+
+def extract_target_words(input_string):
+    target_words = []
+    pattern = r'\*(.*?)\*'
+    matches = re.findall(pattern, input_string)
+    for match in matches:
+        target_words.append(match.strip())
+    return target_words
+
+target_word = extract_target_words(sentence)
+split_target = target_word[0].split()
+
+
+def remove_asterisks_and_split(input_string):
+    pattern = r"\*"
+    # Remove asterisks and split the input string into a list of words
+    words = re.sub(pattern, "", input_string).split()
+    return words
+
+split_sentence = remove_asterisks_and_split(sentence)
+
+
+def find_target_position(split_sentence, split_target):
+    start = -1
+    end = -1
+    for i in range(len(split_sentence)):
+        if split_sentence[i:i+len(split_target)] == split_target:
+            start = i
+            end = i+len(split_target)-1
+            break
+    return [start, end+1]
+
+
+pos = find_target_position(split_sentence, split_target)
+target_json = pos  # already a Python list; no json.loads needed
+print(f"The target word is '{target_word[0]}' at position {pos}.")
+
+print("Now enter the label: 0 for literal, 1 for non-literal")
 label = int(input())
-#label_json=json.loads(label)
-print("Is this your target word: ", sentence[target_json[0]: target_json[1]])
 
+# Convert to data sample for BERT
 data_sample = [{"sentence": sentence, "pos": target_json, "label": label}]
 print(data_sample)
 
-tokenizer=AutoTokenizer.from_pretrained("bert-base-uncased")
+tokenizer=AutoTokenizer.from_pretrained("bert-base-uncased")
 input_as_dataset=preprocess.tokenizer_new(tokenizer, data_sample, max_length=512)
+
+
 # Load model
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 model=models.WordClassificationModel.from_pretrained("bert-base-uncased")
-model_path = "saved_models/bert_.pth"
+
+model_path = "saved_models/bert.pth"
 model = torch.load(model_path, map_location=device)
-model.eval()
+model.eval()
 
 train_sampler = RandomSampler(data_sample)
 train_dataloader=DataLoader(data_sample, sampler=train_sampler, batch_size=1)
@@ -52,9 +88,9 @@ for batch in train_dataloader:
         end_positions=batch[4]
     outputs=model(**inputs)
     prediction=torch.argmax(outputs[0])
-    if prediciton == 1:
+    if prediction == 1:
         print("metonymy")
-    elif prediciton == 0:
+    elif prediction == 0:
         print("literal")
-    #print("Outputs: ",
+
diff --git a/documentation/.DS_Store b/documentation/.DS_Store
index 2895fc477ea9375256239866164b5716e6cc9944..9f8fbe41b9791d8babd3fac3c4f272fb61652b0a 100644
Binary files a/documentation/.DS_Store and b/documentation/.DS_Store differ
--
GitLab
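
For readers who want to try the new target-marking flow outside the full pipeline, the sketch below mirrors the three helper functions added in this patch and shows what they compute for a sample input. The `__main__` harness and the example sentence are illustrative only and are not part of the repository.

import re

def extract_target_words(input_string):
    # Collect every span wrapped in *...*, e.g. "I love *New York*" -> ["New York"]
    return [match.strip() for match in re.findall(r'\*(.*?)\*', input_string)]

def remove_asterisks_and_split(input_string):
    # Drop the asterisk markers and tokenize the sentence on whitespace
    return re.sub(r"\*", "", input_string).split()

def find_target_position(split_sentence, split_target):
    # Slide a window over the tokens and return [start, end) of the target span,
    # or [-1, 0] if the target is not found (same convention as the script)
    for i in range(len(split_sentence)):
        if split_sentence[i:i + len(split_target)] == split_target:
            return [i, i + len(split_target)]
    return [-1, 0]

if __name__ == "__main__":
    sentence = "I love *New York*"                         # illustrative input
    target = extract_target_words(sentence)                # ['New York']
    tokens = remove_asterisks_and_split(sentence)          # ['I', 'love', 'New', 'York']
    pos = find_target_position(tokens, target[0].split())  # [2, 4]
    print(target, tokens, pos)

Note that the script stores the raw `sentence` (asterisks included) in `data_sample` while `pos` is computed on the asterisk-free token list; if `preprocess.tokenizer_new` tokenizes on whitespace as well, passing the cleaned sentence instead may keep the positions aligned.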