Skip to content
Snippets Groups Projects
Commit c691bc2a authored by friebolin's avatar friebolin
Browse files

Add print prediction

parent 6ea968d4
No related branches found
No related tags found
No related merge requests found
......@@ -25,35 +25,35 @@ sentence = input()
def extract_target_words(input_string):
target_words = []
pattern = r'\*(.*?)\*'
matches = re.findall(pattern, input_string)
for match in matches:
target_words.append(match.strip())
return target_words
target_words = []
pattern = r'\*(.*?)\*'
matches = re.findall(pattern, input_string)
for match in matches:
target_words.append(match.strip())
return target_words
target_word = extract_target_words(sentence)
split_target = target_word[0].split()
def remove_asterisks_and_split(input_string):
pattern = r"\*"
# Remove asterisks and split the input string into a list of words
words = re.sub(pattern, "", input_string).split()
return words
pattern = r"\*"
# Remove asterisks and split the input string into a list of words
words = re.sub(pattern, "", input_string).split()
return words
split_sentence = remove_asterisks_and_split(sentence)
def find_target_position(split_sentence, split_target):
start = -1
end = -1
for i in range(len(split_sentence)):
if split_sentence[i:i+len(split_target)] == split_target:
start = i
end = i+len(split_target)-1
break
return [start, end+1]
start = -1
end = -1
for i in range(len(split_sentence)):
if split_sentence[i:i+len(split_target)] == split_target:
start = i
end = i+len(split_target)-1
break
return [start, end+1]
pos = find_target_position(split_sentence, split_target)
......@@ -97,6 +97,7 @@ for batch in train_dataloader:
end_positions = batch[4]
outputs = model(**inputs)
prediction = torch.argmax(outputs[0])
print(prediction)
if prediction == 1:
print("metonymy")
elif prediction == 0:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment