Skip to content
Snippets Groups Projects
Commit b445b9a7 authored by kulcsar's avatar kulcsar
Browse files

add stuff

parent 7d943072
No related branches found
No related tags found
No related merge requests found
......@@ -16,28 +16,15 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#accelerator=Accelerator()
#device=accelerator.device
def run():
    """Evaluate a saved causal language model on a test dataset.

    Loads the model checkpoint and tokenizer named in the module-level
    ``args`` namespace (assumed to be parsed before this is called),
    preprocesses the test set, runs ``evaluate_model_loop`` with the
    generation settings from ``args``, and logs the resulting metrics.

    Returns:
        Whatever ``evaluate_model_loop`` returns (the evaluation metrics).
    """
    # Configure file logging first so every subsequent message lands in the
    # log file (the original printed the paths to stdout before this point).
    logging.basicConfig(filename=args.log, level=logging.DEBUG)
    logging.info("performing evaluation")
    logging.info("loading saved model %s (tokenizer: %s)",
                 args.saved_model, args.tokenizer)

    model = AutoModelForCausalLM.from_pretrained(args.saved_model)
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)

    # Decoder-only models require left padding for batched generation, and
    # GPT-style tokenizers ship without a pad token, so the EOS token is
    # reused as the padding token on both tokenizer and model config.
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id

    test_dataset = preprocess(tokenizer, args.test_dataset)

    logging.info("running evaluation")
    res = evaluate_model_loop(
        model, args.config_name, test_dataset, args.batch_size, tokenizer,
        args.topk, args.temp, args.num_beams, args.early_stopping,
        args.no_rep_ngram, args.num_return_sequences, args.metrics,
        args.do_sample, args.generative, args.icd_codes,
    )
    logging.info(res)
    return res
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment