Commit 61b39a78 authored by kulcsar

demo saved model

parent 152b69bc
import json
import torch
import preprocess
import models
from transformers import BertTokenizer, RobertaTokenizer, BertModel, RobertaModel, RobertaPreTrainedModel, RobertaConfig, BertConfig, BertPreTrainedModel, PreTrainedModel, AutoConfig, AutoModel, AutoTokenizer
import re
import models
import train
from torch.utils.data import DataLoader, RandomSampler
@@ -14,11 +14,16 @@ sentence = sentence.split()
print("Specify the target word position using square brackets (e.g. [0,2])")
target_pos = input()
target_json = json.loads(target_pos)
print(type(target_json))
print("Enter the label: 0 for literal, 1 for non-literal")
label = int(input())
#label_json=json.loads(label)
print("Is this your target word: ", sentence[target_json[0]: target_json[1]])
data_sample = {"sentence": sentence, "pos": target_pos, "label": label}
data_sample = [{"sentence": sentence, "pos": target_json, "label": label}]
print(data_sample)
tokenizer=AutoTokenizer.from_pretrained("bert-base-uncased")
@@ -26,10 +31,10 @@ input_as_dataset=preprocess.tokenizer_new(tokenizer, data_sample, max_length=512
# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model=models.BertForWordClassification.from_pretrained("bert-base-uncased")
model_path = "saved_models/bert_baseline.pth"
model=models.WordClassificationModel.from_pretrained("bert-base-uncased")
model_path = "saved_models/bert_.pth"
model = torch.load(model_path, map_location=device)
model.eval()
train_sampler = RandomSampler(data_sample)
@@ -46,5 +51,10 @@ for batch in train_dataloader:
start_positions=batch[3]
end_positions=batch[4]
outputs=model(**inputs)
prediction=torch.argmax(outputs[0])
if prediction == 1:
print("metonymy")
elif prediction == 0:
print("literal")
print("Outputs: ", outputs)
#print("Outputs: ",
@@ -22,7 +22,7 @@ torch.cuda.empty_cache()
#with torch.autocast("cuda"):
def train(model, name, imdb, seed,mixup,lambda_value, mixepoch, tmix, mixlayer, train_dataset, test_dataset, num_epochs, learning_rate, batch_size, test_batch_size):
def train(model, name, imdb, seed,mixup,lambda_value, mixepoch, tmix, mixlayer, train_dataset, test_dataset, num_epochs, learning_rate, batch_size, test_batch_size, model_save_path=None):
"""Train loop for models. Iterates over epochs and batches and gives inputs to model. After training, call evaluation.py for evaluation of finetuned model.
Params:
@@ -107,7 +107,7 @@ def train(model, name, imdb, seed,mixup,lambda_value, mixepoch, tmix, mixlayer,
model.zero_grad()
if epoch==mixepoch:
print("mixepoch")
#print("mixepoch")
if mixup == True:
#calculate new last hidden states and predictions(logits)
new_matrix_batch, new_labels_batch = mixup_function(outputs[2], labels, lambda_value, threshold)
@@ -127,12 +127,12 @@ def train(model, name, imdb, seed,mixup,lambda_value, mixepoch, tmix, mixlayer,
lr_scheduler.step()
optimizer.zero_grad()
model.zero_grad()
torch.save(model, "saved_models/bert_baseline.pth")
if model_save_path is not None:
torch.save(model, model_save_path)
#evaluate trained model
evaluation_test = evaluation.evaluate_model(model, name, test_dataset, test_batch_size, imdb)
evaluation_train = evaluation.evaluate_model(model, name, train_dataset, test_batch_size, imdb)
evaluation_test = Code.evaluation.evaluate_model(model, name, test_dataset, test_batch_size, imdb)
evaluation_train = Code.evaluation.evaluate_model(model, name, train_dataset, test_batch_size, imdb)
print("TEST: ", evaluation_test)
print("TRAIN: ", evaluation_train)
@@ -26,16 +26,16 @@ def run(raw_args):
print("Loading tokenizers and initializing model...")
tokenizer=AutoTokenizer.from_pretrained(args.architecture)
models.set_seed(args.random_seed)
Code.models.set_seed(args.random_seed)
if args.model_type == "separate":
if args.architecture=="bert-base-uncased":
model=models.BertForWordClassification.from_pretrained(args.architecture).to("cuda")
model=Code.models.BertForWordClassification.from_pretrained(args.architecture).to("cuda")
elif args.architecture=="roberta-base":
model=models.RobertaForWordClassification.from_pretrained(args.architecture).to("cuda")
model=Code.models.RobertaForWordClassification.from_pretrained(args.architecture).to("cuda")
else:
print("non eligible model type selected")
elif args.model_type == "one":
model=models.WordClassificationModel(args.architecture, args.tmix, args.imdb).to("cuda")
model=Code.models.WordClassificationModel(args.architecture, args.tmix, args.imdb).to("cuda")
else:
print("non eligible model type selected")
@@ -43,22 +43,22 @@ def run(raw_args):
#preprocess...
print("preprocessing datasets...")
if args.imdb==True:
train_dataset=preprocess.tokenizer_imdb(tokenizer, data_train, args.max_length)
test_dataset=preprocess.tokenizer_imdb(tokenizer, data_test, args.max_length)
train_dataset=Code.preprocess.tokenizer_imdb(tokenizer, data_train, args.max_length)
test_dataset=Code.preprocess.tokenizer_imdb(tokenizer, data_test, args.max_length)
elif args.tokenizer=="salami":
train_dataset=preprocess.salami_tokenizer(tokenizer, data_train, args.max_length, masked=args.masking) #no context implemented
test_dataset=preprocess.salami_tokenizer(tokenizer, data_test, args.max_length, masked=args.masking)
train_dataset=Code.preprocess.salami_tokenizer(tokenizer, data_train, args.max_length, masked=args.masking) #no context implemented
test_dataset=Code.preprocess.salami_tokenizer(tokenizer, data_test, args.max_length, masked=args.masking)
elif args.tokenizer=="swp":
print("train dataset preprocessing ")
print(args.tcontext)
train_dataset=preprocess.tokenizer_new(tokenizer, data_train, args.max_length, masked=args.masking, old_dataset=args.tcontext)
test_dataset=preprocess.tokenizer_new(tokenizer, data_test, args.max_length, masked=args.masking, old_dataset=False)
train_dataset=Code.preprocess.tokenizer_new(tokenizer, data_train, args.max_length, masked=args.masking, old_dataset=args.tcontext)
test_dataset=Code.preprocess.tokenizer_new(tokenizer, data_test, args.max_length, masked=args.masking, old_dataset=False)
elif args.tokenizer=="li":
train_dataset=preprocess.tokenizer_new(tokenizer, data_train, args.max_length, masked=args.masking) #no context implemented
test_dataset=preprocess.tokenizer_new(tokenizer, data_test, args.max_length, masked=args.masking)
train_dataset=Code.preprocess.tokenizer_new(tokenizer, data_train, args.max_length, masked=args.masking) #no context implemented
test_dataset=Code.preprocess.tokenizer_new(tokenizer, data_test, args.max_length, masked=args.masking)
else:
print("non eligible tokenizer selected")
@@ -66,9 +66,9 @@ def run(raw_args):
#train&evaluate...
print("training..")
if args.train_loop=="swp":
evaluation_test, evaluation_train = train.train(model, args.architecture, args.imdb, args.random_seed, args.mix_up, args.lambda_value, args.mixepoch, args.tmix, args.mixlayer, train_dataset, test_dataset, args.epochs, args.learning_rate, args.batch_size, args.test_batch_size)
evaluation_test, evaluation_train = Code.train.train(model, args.architecture, args.imdb, args.random_seed, args.mix_up, args.lambda_value, args.mixepoch, args.tmix, args.mixlayer, train_dataset, test_dataset, args.epochs, args.learning_rate, args.batch_size, args.test_batch_size, args.model_save_path)
elif args.train_loop=="salami":
evaluation_test = train.train_salami(model,args.random_seed, train_dataset, test_dataset, args.batch_size, args.test_batch_size, args.learning_rate, args.epochs)
evaluation_test = Code.train.train_salami(model,args.random_seed, train_dataset, test_dataset, args.batch_size, args.test_batch_size, args.learning_rate, args.epochs)
else:
print("no eligible train loop selected")
@@ -203,7 +203,7 @@ if __name__ == "__main__":
"--mixepoch",
help="specify the epoch(s) in which to apply mixup",
type=int,
default=1)
default=None)
#Test arguments
@@ -219,6 +219,11 @@ if __name__ == "__main__":
"-sd",
"--save_directory",
help="Directory to save run")
parser.add_argument(
"-msp",
"--model_save_path",
help="path to save model")
args = parser.parse_args()
run(args)
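A hedged usage sketch of the new option: only --model_save_path/-msp is added in this diff, so the other flag spellings below are assumptions inferred from the args.* attributes referenced in run().

# Illustrative only: flags other than --model_save_path are assumed, not confirmed by this diff.
args = parser.parse_args([
    "--architecture", "bert-base-uncased",
    "--model_type", "one",
    "--tokenizer", "swp",
    "--train_loop", "swp",
    "--model_save_path", "saved_models/bert_demo.pth",
])
run(args)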