# inference.py — interactive inference script (web-scrape page chrome removed)
"""Interactive inference script.

Reads a sentence, a target-word span, and a gold label from stdin, tokenizes
the single example with the project's preprocessing pipeline, loads a saved
BERT-based word-classification model, and prints whether the target word is
used literally or as metonymy.
"""

import json
import re

import torch
from torch.utils.data import DataLoader, RandomSampler
from transformers import (AutoConfig, AutoModel, AutoTokenizer, BertConfig,
                          BertModel, BertPreTrainedModel, BertTokenizer,
                          PreTrainedModel, RobertaConfig, RobertaModel,
                          RobertaPreTrainedModel, RobertaTokenizer)

import models
import preprocess
import train

# --- Gather user input ---
print("Enter a sentence: ")
sentence = input().split()

print("Specify the target word position using square brackets (e.g. [0,2])")
# json.loads turns e.g. "[0,2]" into a list of two ints: [start, end).
target_json = json.loads(input())

print("Enter the label: 0 for literal, 1 for non-literal")
label = int(input())

print("Is this your target word: ", sentence[target_json[0]: target_json[1]])

data_sample = [{"sentence": sentence, "pos": target_json, "label": label}]
print(data_sample)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Tokenized dataset; presumably a TensorDataset of
# (input_ids, attention_mask, token_type_ids, start_pos, end_pos, label)
# given the positional indexing below — TODO confirm against preprocess.tokenizer_new.
input_as_dataset = preprocess.tokenizer_new(tokenizer, data_sample, max_length=512)

# --- Load model ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The checkpoint stores the full model object, so torch.load returns the model
# directly.  (The original code also built a fresh
# models.WordClassificationModel.from_pretrained(...) here, which was
# immediately overwritten — dead work, removed.)
model_path = "saved_models/bert_.pth"
model = torch.load(model_path, map_location=device)
model.eval()

# BUG FIX: iterate over the tokenized dataset, not the raw list of dicts.
# Default collation of dicts produces a dict batch, making batch[0] a KeyError;
# the positional unpacking below only works on the tokenized tensors.
sampler = RandomSampler(input_as_dataset)
dataloader = DataLoader(input_as_dataset, sampler=sampler, batch_size=1)

with torch.no_grad():  # inference only — skip gradient bookkeeping
    for batch in dataloader:
        inputs = {'input_ids': batch[0],
                  'attention_mask': batch[1],
                  'token_type_ids': batch[2],
                  'start_position': batch[3],
                  'end_position': batch[4],
                  'labels': batch[5]}
        outputs = model(**inputs)
        # BUG FIX: original compared against the misspelled name `prediciton`,
        # which is undefined and raised NameError on the first batch.
        prediction = torch.argmax(outputs[0])
        if prediction == 1:
            print("metonymy")
        elif prediction == 0:
            print("literal")