Commit 586ba564 authored by friebolin

Update inference

parent ef32f86f
import argparse
import torch
import preprocess
import train
import models
from transformers import BertTokenizer, RobertaTokenizer, BertModel, RobertaModel, RobertaPreTrainedModel, RobertaConfig, BertConfig, BertPreTrainedModel, PreTrainedModel, AutoConfig, AutoModel, AutoTokenizer
import re
from torch.utils.data import DataLoader, RandomSampler
# Get user input
print("Enter a sentence: ")
sentence = input()
sentence = sentence.split()
# ... (lines collapsed in the diff view) ...
target_pos = input()
print("Enter the label: 0 for literal, 1 for non-literal")
label = int(input())
data_sample = {"sentence": sentence, "pos": target_pos, "label": label}
print(data_sample)
# Load the fine-tuned model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Option 1: instantiate the architecture and load the saved weights (state dict)
filepath = "./saved_models/bert_baseline.pt"
model = models.BertForWordClassification.from_pretrained("bert-base-uncased")
#tokenizer = AutoTokenizer.from_pretrained(args.architecture)
model.load_state_dict(torch.load(filepath, map_location=device))

# Option 2: load the fully serialized model object (this replaces the model loaded above)
model_path = "saved_models/bert_baseline.pth"
model = torch.load(model_path, map_location=device)

model.eval()  # put the model in evaluation mode for inference
# A hard-coded sample in the expected input format, kept as a quick sanity check
train_dataset = [{"sentence": ["Yet", "how", "many", "times", "has", "America", "sided", "with", "Israeli", "aggression", "against", "the", "people", "of", "Palestine?"], "pos": [5, 6], "label": 1}]

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Build a dataloader over the single user-provided sample; the dict is wrapped
# in a list so the sampler and dataloader see a dataset with one item
train_sampler = RandomSampler([data_sample])
train_dataloader = DataLoader([data_sample], sampler=train_sampler, batch_size=1)
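# Illustrative sketch, assuming the collapsed part of the diff performs the
# tokenization that the visible code never does: with the Hugging Face tokenizer,
# a pre-tokenized sentence can be encoded roughly like this ("example_encoding"
# is a hypothetical name used only for illustration):
example_encoding = tokenizer(data_sample["sentence"],
                             is_split_into_words=True,
                             return_tensors="pt",
                             truncation=True)
# example_encoding["input_ids"] and example_encoding["attention_mask"] are the
# kind of tensors that the positional indexing in the loop below (batch[0],
# batch[1], ...) expects the dataloader to deliver.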
# Run the model on the batch and print the raw outputs
for batch in train_dataloader:
    inputs = {'input_ids': batch[0],
              # ... (the remaining entries of the inputs dict are collapsed in the diff view)
    start_positions = batch[3]
    end_positions = batch[4]
    outputs = model(**inputs)
    print("Outputs: ", outputs)