From c691bc2a4593add2fe5df4ce7135a4be9f045b99 Mon Sep 17 00:00:00 2001 From: friebolin <friebolin@cl.uni-heidelberg.de> Date: Fri, 24 Feb 2023 16:00:12 +0100 Subject: [PATCH] Add print prediction --- inference.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/inference.py b/inference.py index 9785f35..7117c11 100644 --- a/inference.py +++ b/inference.py @@ -25,35 +25,35 @@ sentence = input() def extract_target_words(input_string): - target_words = [] - pattern = r'\*(.*?)\*' - matches = re.findall(pattern, input_string) - for match in matches: - target_words.append(match.strip()) - return target_words + target_words = [] + pattern = r'\*(.*?)\*' + matches = re.findall(pattern, input_string) + for match in matches: + target_words.append(match.strip()) + return target_words target_word = extract_target_words(sentence) split_target = target_word[0].split() def remove_asterisks_and_split(input_string): - pattern = r"\*" - # Remove asterisks and split the input string into a list of words - words = re.sub(pattern, "", input_string).split() - return words + pattern = r"\*" + # Remove asterisks and split the input string into a list of words + words = re.sub(pattern, "", input_string).split() + return words split_sentence = remove_asterisks_and_split(sentence) def find_target_position(split_sentence, split_target): - start = -1 - end = -1 - for i in range(len(split_sentence)): - if split_sentence[i:i+len(split_target)] == split_target: - start = i - end = i+len(split_target)-1 - break - return [start, end+1] + start = -1 + end = -1 + for i in range(len(split_sentence)): + if split_sentence[i:i+len(split_target)] == split_target: + start = i + end = i+len(split_target)-1 + break + return [start, end+1] pos = find_target_position(split_sentence, split_target) @@ -97,6 +97,7 @@ for batch in train_dataloader: end_positions = batch[4] outputs = model(**inputs) prediction = torch.argmax(outputs[0]) + print(prediction) if prediction == 1: print("metonymy") elif prediction == 0: -- GitLab