Commit 602ebabb authored by friebolin

Update

parent bb0f97ee
@@ -3,7 +3,7 @@ import torch
 import tqdm
 import numpy as np
 import evaluation
-import evaluate
+#import evaluate
 import json
 import random
 import math
@@ -21,7 +21,7 @@ import sklearn
 import torch
 import tqdm
 import numpy as np
-import evaluate
+#import evaluate
 import json
 import random
 import math
@@ -48,16 +48,17 @@ from transformers import BertTokenizer, RobertaTokenizer, BertModel, RobertaModel
 filepath = "bert_baseline.pt"
-model=models.BertForWordClassification.from_pretrained("bert-base-uncased").to("cuda")
+model=models.BertForWordClassification.from_pretrained("bert-base-uncased")
 #tokenizer=AutoTokenizer.from_pretrained(args.architecture)
 model.load_state_dict(torch.load(filepath))
-model.eval()
+model.eval() #loads saved model
 train_dataset = [{"sentence": ["Yet", "how", "many", "times", "has", "America", "sided", "with", "Israeli", "aggression", "against", "the", "people", "of", "Palestine?"], "pos": [5, 6], "label": 1}]
 train_sampler = RandomSampler(train_dataset)
 train_dataloader=DataLoader(train_dataset, sampler=train_sampler, batch_size=1)
 for batch in train_dataloader:
     inputs = {'input_ids': batch[0],
               'attention_mask': batch[1],
@@ -68,7 +69,5 @@ for batch in train_dataloader:
     labels=batch[5]
     start_positions=batch[3]
     end_positions=batch[4]
     outputs=model(**inputs)
     print("Outputs: ", outputs)
@@ -190,7 +190,7 @@ def train(model, name, imdb, seed,gradient_accumulation_steps,mixup, threshold,
     #progress_bar.update(1)
     #print("one epoch done")
-    torch.save(model.state_dict(), "./Code/saved_models/bert_baseline.pt")
+    torch.save(model.state_dict(), "bert_baseline.pt")
     #print(model_name)
     evaluation_test = evaluation.evaluate_model(model, name, test_dataset, learning_rate, test_batch_size, imdb)
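Note: moving the checkpoint from ./Code/saved_models/ to the working directory sidesteps torch.save failing with FileNotFoundError when the target directory does not exist. If the nested path is preferred, creating it first works as well (a sketch, not part of the commit):

import os
import torch

save_path = "./Code/saved_models/bert_baseline.pt"
os.makedirs(os.path.dirname(save_path), exist_ok=True)  # create parent dirs if missing
torch.save(model.state_dict(), save_path)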