import json

import torch
from torch.utils.data import DataLoader, RandomSampler
from transformers import BertTokenizer

import models
import preprocess
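# The post defines the input sentence and tokenizer before this point;
# a minimal stand-in, assuming the BERT tokenizer that matches the saved model:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

print("Enter the sentence containing the target word:")
sentence = input()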
print("Specify the target word position using square brackets (e.g. [0,2])")
target_pos = input()
target_json = json.loads(target_pos)
print(type(target_json))
print("Enter the label: 0 for literal, 1 for non-literal")
label = int(input())
#label_json=json.loads(label)
print("Is this your target word: ", sentence[target_json[0]: target_json[1]])
data_sample = [{"sentence": sentence, "pos": target_json, "label": label}]
input_as_dataset=preprocess.tokenizer_new(tokenizer, data_sample, max_length=512)
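# preprocess.tokenizer_new is assumed (from how `batch` is unpacked below) to
# return a TensorDataset ordered as (input_ids, attention_mask, token_type_ids,
# start_position, end_position, label).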
# Load the fine-tuned model weights saved after training.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = models.WordClassificationModel.from_pretrained("bert-base-uncased")
model_path = "saved_models/bert_.pth"
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()

# Wrap the tokenized sample in a DataLoader (a single example, batch size 1).
eval_sampler = RandomSampler(input_as_dataset)
eval_dataloader = DataLoader(input_as_dataset, sampler=eval_sampler, batch_size=1)
for batch in eval_dataloader:
    batch = tuple(t.to(device) for t in batch)
    inputs = {'input_ids': batch[0],
              'attention_mask': batch[1],
              'token_type_ids': batch[2],
              'start_position': batch[3],
              'end_position': batch[4],
              'labels': batch[5]}
    with torch.no_grad():
        outputs = model(**inputs)
    # As in the original snippet, the first element of the model output is
    # treated as the class logits.
    prediction = torch.argmax(outputs[0])
    if prediction == 1:
        print("metonymy")
    elif prediction == 0:
        print("literal")