From c691bc2a4593add2fe5df4ce7135a4be9f045b99 Mon Sep 17 00:00:00 2001
From: friebolin <friebolin@cl.uni-heidelberg.de>
Date: Fri, 24 Feb 2023 16:00:12 +0100
Subject: [PATCH] Print model prediction during inference

---
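Notes (kept below the --- line so they stay out of the commit message): the hunks add a raw print of the model prediction before it is mapped to a label, and switch the three helper functions to tab indentation to match the rest of inference.py. For reference, here is a minimal standalone sketch of how those helpers and the new print fit together; the example sentence and the dummy logits are hypothetical, and it assumes outputs[0] holds the classification logits for a single example, with class 1 meaning metonymy as in the loop below.

import re
import torch

def extract_target_words(input_string):
	# Collect the text between each pair of asterisks, e.g. "*White House*".
	target_words = []
	pattern = r'\*(.*?)\*'
	matches = re.findall(pattern, input_string)
	for match in matches:
		target_words.append(match.strip())
	return target_words

def remove_asterisks_and_split(input_string):
	# Drop the asterisk markers and tokenize on whitespace.
	return re.sub(r"\*", "", input_string).split()

def find_target_position(split_sentence, split_target):
	# Return [start, end) token indices of the target span, or [-1, 0] if absent.
	start = -1
	end = -1
	for i in range(len(split_sentence)):
		if split_sentence[i:i+len(split_target)] == split_target:
			start = i
			end = i + len(split_target) - 1
			break
	return [start, end + 1]

sentence = "The *White House* issued a statement."          # hypothetical input
split_target = extract_target_words(sentence)[0].split()    # ['White', 'House']
split_sentence = remove_asterisks_and_split(sentence)
print(find_target_position(split_sentence, split_target))   # [1, 3]

# What the new print(prediction) line shows, given dummy 2-class logits:
logits = torch.tensor([[0.2, 1.3]])
prediction = torch.argmax(logits)
print(prediction)            # tensor(1)
if prediction == 1:
	print("metonymy")        # same branch as in the loop below
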
 inference.py | 37 +++++++++++++++++++------------------
 1 file changed, 19 insertions(+), 18 deletions(-)

diff --git a/inference.py b/inference.py
index 9785f35..7117c11 100644
--- a/inference.py
+++ b/inference.py
@@ -25,35 +25,35 @@ sentence = input()
 
 
 def extract_target_words(input_string):
-    target_words = []
-    pattern = r'\*(.*?)\*'
-    matches = re.findall(pattern, input_string)
-    for match in matches:
-        target_words.append(match.strip())
-    return target_words
+	target_words = []
+	pattern = r'\*(.*?)\*'
+	matches = re.findall(pattern, input_string)
+	for match in matches:
+		target_words.append(match.strip())
+	return target_words
 
 target_word = extract_target_words(sentence)
 split_target = target_word[0].split()
 
 
 def remove_asterisks_and_split(input_string):
-    pattern = r"\*"
-    # Remove asterisks and split the input string into a list of words
-    words = re.sub(pattern, "", input_string).split()
-    return words
+	pattern = r"\*"
+	# Remove asterisks and split the input string into a list of words
+	words = re.sub(pattern, "", input_string).split()
+	return words
 
 split_sentence = remove_asterisks_and_split(sentence)
 
 
 def find_target_position(split_sentence, split_target):
-    start = -1
-    end = -1
-    for i in range(len(split_sentence)):
-        if split_sentence[i:i+len(split_target)] == split_target:
-            start = i
-            end = i+len(split_target)-1
-            break
-    return [start, end+1]
+	start = -1
+	end = -1
+	for i in range(len(split_sentence)):
+		if split_sentence[i:i+len(split_target)] == split_target:
+			start = i
+			end = i+len(split_target)-1
+			break
+	return [start, end+1]
 
 
 pos = find_target_position(split_sentence, split_target)
@@ -97,6 +97,7 @@ for batch in train_dataloader:
 	end_positions = batch[4]
 	outputs = model(**inputs)
 	prediction = torch.argmax(outputs[0])
+	print(prediction)
 	if prediction == 1:
 		print("metonymy")
 	elif prediction == 0:
-- 
GitLab