diff --git a/inference.py b/inference.py index 9785f35c5a56b19eb3afcfcb0d7c7364494af537..7117c1106731499f0d9480be4fe8d87a0925299e 100644 --- a/inference.py +++ b/inference.py @@ -25,35 +25,35 @@ sentence = input() def extract_target_words(input_string): - target_words = [] - pattern = r'\*(.*?)\*' - matches = re.findall(pattern, input_string) - for match in matches: - target_words.append(match.strip()) - return target_words + target_words = [] + pattern = r'\*(.*?)\*' + matches = re.findall(pattern, input_string) + for match in matches: + target_words.append(match.strip()) + return target_words target_word = extract_target_words(sentence) split_target = target_word[0].split() def remove_asterisks_and_split(input_string): - pattern = r"\*" - # Remove asterisks and split the input string into a list of words - words = re.sub(pattern, "", input_string).split() - return words + pattern = r"\*" + # Remove asterisks and split the input string into a list of words + words = re.sub(pattern, "", input_string).split() + return words split_sentence = remove_asterisks_and_split(sentence) def find_target_position(split_sentence, split_target): - start = -1 - end = -1 - for i in range(len(split_sentence)): - if split_sentence[i:i+len(split_target)] == split_target: - start = i - end = i+len(split_target)-1 - break - return [start, end+1] + start = -1 + end = -1 + for i in range(len(split_sentence)): + if split_sentence[i:i+len(split_target)] == split_target: + start = i + end = i+len(split_target)-1 + break + return [start, end+1] pos = find_target_position(split_sentence, split_target) @@ -97,6 +97,7 @@ for batch in train_dataloader: end_positions = batch[4] outputs = model(**inputs) prediction = torch.argmax(outputs[0]) + print(prediction) if prediction == 1: print("metonymy") elif prediction == 0: