From 655d7e114bd699a2af5c5d459562cf591abaab23 Mon Sep 17 00:00:00 2001
From: friebolin <friebolin@cl.uni-heidelberg.de>
Date: Fri, 24 Feb 2023 12:18:46 +0100
Subject: [PATCH] Update inference

---
 Code/inference.py       |  65 ++++++++++++++++++++++++++++++----------
 documentation/.DS_Store | Bin 8196 -> 6148 bytes
 2 files changed, 50 insertions(+), 15 deletions(-)

diff --git a/Code/inference.py b/Code/inference.py
index 441d9e2..212d4c5 100644
--- a/Code/inference.py
+++ b/Code/inference.py
@@ -1,3 +1,5 @@
+"""Demo for inference: User enters a sentence and our trained BERT model predicts if the target word is literal or non-literal"""
+
 import json
 import torch
 import preprocess
@@ -8,34 +10,67 @@ import train
 from torch.utils.data import DataLoader, RandomSampler
 
 # Get user input
-print("Enter a sentence: ")
+print("Enter a sentence and enclose the target word(s) between asterisks (e.g. \"I love *New York*\"): ")
 sentence = input()
-sentence = sentence.split()
 
-print("Specify the target word position using square brackets (e.g. [0,2])")
-target_pos = input()
-target_json = json.loads(target_pos)
-print(type(target_json))
 
-print("Enter the label: 0 for literal, 1 for non-literal")
+def extract_target_words(input_string):
+    target_words = []
+    pattern = r'\*(.*?)\*'
+    matches = re.findall(pattern, input_string)
+    for match in matches:
+        target_words.append(match.strip())
+    return target_words
+
+target_word = extract_target_words(sentence)
+split_target = target_word[0].split()
+
+
+def remove_asterisks_and_split(input_string):
+    pattern = r"\*"
+    # Remove asterisks and split the input string into a list of words
+    words = re.sub(pattern, "", input_string).split()
+    return words
+
+split_sentence = remove_asterisks_and_split(sentence)
+
+
+def find_target_position(split_sentence, split_target):
+    start = -1
+    end = -1
+    for i in range(len(split_sentence)):
+        if split_sentence[i:i+len(split_target)] == split_target:
+            start = i
+            end = i+len(split_target)-1
+            break
+    return [start, end+1]
+
+
+pos = find_target_position(split_sentence, split_target)
+target_json = json.loads(pos)
+print(f"The target word is {target_word} at position {pos}.")
+
+print("Now enter the label: 0 for literal, 1 for non-literal")
 label = int(input())
-#label_json=json.loads(label)
 
-print("Is this your target word: ", sentence[target_json[0]: target_json[1]])
 
+# Convert to data sample for BERT
 data_sample = [{"sentence": sentence, "pos": target_json, "label": label}]
 print(data_sample)
-tokenizer=AutoTokenizer.from_pretrained("bert-base-uncased")
 
+tokenizer=AutoTokenizer.from_pretrained("bert-base-uncased")
 input_as_dataset=preprocess.tokenizer_new(tokenizer, data_sample, max_length=512)
+
+
 # Load model
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 model=models.WordClassificationModel.from_pretrained("bert-base-uncased")
-model_path = "saved_models/bert_.pth"
+
+model_path = "saved_models/bert.pth"
 model = torch.load(model_path, map_location=device)
-model.eval()
 
+model.eval()
 
 train_sampler = RandomSampler(data_sample)
 train_dataloader=DataLoader(data_sample, sampler=train_sampler, batch_size=1)
@@ -52,9 +87,9 @@ for batch in train_dataloader:
 	end_positions=batch[4]
 	outputs=model(**inputs)
 	prediction=torch.argmax(outputs[0])
-	if prediciton == 1:
+	if prediction == 1:
 		print("metonymy")
-	elif prediciton == 0:
+	elif prediction == 0:
 		print("literal")
 
-	#print("Outputs: ", 
+
diff --git a/documentation/.DS_Store b/documentation/.DS_Store
index 2895fc477ea9375256239866164b5716e6cc9944..9f8fbe41b9791d8babd3fac3c4f272fb61652b0a 100644
GIT binary patch
delta 232
zcmZp1XfcprU|?W$DortDU=RQ@Ie-{Mvv5r;6q~50$jG@dU^g=(=Vl&(WsIsk4EYR2
z4CxGs40#Nh3{^n71W4yI<Rz6C7bNB6CjoWtnCvZFs4h`mZD?evqhMrcT&tr{ZE0kn
zqhMlYQd`T(A*!rz9TcCPlbe^{HTkuGyk<907fzGQf{XHU^7GPxY8W>b1~4vW=im@z
m2J(PFfE!4-f*iZC@H_Klei=`Y(;1i`-T?WGVRJms9A*Hp2rx1L

delta 675
zcmZoMXmOBWU|?W$DortDU;r^WfEYvza8E20o2aMAD7-OXH}hr%jz7$c**Q2SHn1=X
zZ{}fH##qnI;LDH-ge44_40#Oc48;t347m&`o;mr+NjdpRATxm25s3Bvg8`5QGQWT!
zhanNDLl4LT>O)n}D2t|EfT0LzM==m50$Irn84S7*6Y9Z6U|P5y%>a1}1LEOMh-WAV
zI=%#GXgtvP0w7EUnwkO>s{lJMg&_^qc}yTL=!1Pz4<vycA)wpJff_*uB!j~uA7})!
zA2AGJ>_&4UCs1EHP**%d8qn$-kbhCF1&87qu(cpN85n>%c#ySBc40YHkHwgrbi?4}
z{M-VtCJ2E=Rc^kE3pC0&_A>}7GtW2ziZ-k&Q*f%}Lik{EJ*&7h7VTxhMR_^-dFh)E
o@c1!vNpJ%#aRnu{&4L`?nJ4p$cuw}`;ox9|q(X+x@jP>w0aUZ0s{jB1

-- 
GitLab