Commit 84e129bc authored by kulcsar

add medalpaca diagnosis generation

parent 03e7d5a0
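
# Evaluation script for generative diagnosis prediction: loads a fine-tuned causal
# language model (e.g. medalpaca), generates diagnoses or ICD codes for a test set,
# and reports precision, recall, and F1.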
import re
from tqdm import tqdm
from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoTokenizer, AutoModelForCausalLM, BioGptForCausalLM #, GenerationConfig
import evaluate
import torch
from statistics import mean
import pickle as pkl
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
import logging
from transformers import get_scheduler, AdamW
from t5_model import preprocess
#from accelerate import Accelerator
import argparse
logging.basicConfig(filename='log_t5.log', level=logging.DEBUG)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#accelerator=Accelerator()
#device=accelerator.device
def run():
    print(args.saved_model)
    print(args.tokenizer)
    logging.basicConfig(filename=args.log, level=logging.DEBUG)
    logging.info("performing evaluation")
    logging.info("loading saved model")
    #checkpoint=torch.load(args.saved_model)
    #model.load_state_dict(checkpoint)
    #model=BioGptForCausalLM.from_pretrained(args.saved_model)
    model=AutoModelForCausalLM.from_pretrained(args.saved_model)
    #model=T5ForConditionalGeneration.from_pretrained("t5-small")
    #model.load_state_dict(torch.load(args.saved_model))
    #model=accelerator.load("./t5_small_deepspeed_train_test_2.pt")
    #model=accelerator.load_state("./t5_small_deepspee_train_test.pt")
    tokenizer=AutoTokenizer.from_pretrained(args.tokenizer)
    tokenizer.padding_side="left"
    tokenizer.pad_token=tokenizer.eos_token
    model.config.pad_token_id=model.config.eos_token_id
    test_dataset=preprocess(tokenizer, args.test_dataset)
    #with open(args.test_dataset, "rb") as f:
    #    data=pkl.load(f)
    #test_dataset=DiagnosesDataset(data, tokenizer)
    logging.info("running evaluation")
    res=evaluate_model_loop(model, args.config_name, test_dataset, args.batch_size, tokenizer, args.topk, args.temp, args.num_beams, args.early_stopping, args.no_rep_ngram, args.num_return_sequences, args.metrics, args.do_sample, args.generative, args.icd_codes)
    logging.info(res)
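
# Dataset wrapper (not used by run() in this file, which loads data via preprocess()):
# it tokenizes prompt+label pairs for causal-LM training/evaluation and masks the
# prompt tokens with -100 so that only the label span contributes to the loss.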
class DiagnosesDataset(torch.utils.data.Dataset):
    def __init__(self, instances, tokenizer):
        self.instances=instances
        #self.labels=labels
        self.tokenizer=tokenizer

    def __getitem__(self, idx):
        instance=self.instances[idx]
        prompt=instance["prompt"]
        labels=instance["label"]
        # tokenize prompt+label together, then overwrite the label positions so that
        # only the label tokens are kept as targets (prompt positions stay -100)
        item=self.tokenize(prompt+labels)
        tokenized_instruction=self.tokenize(prompt)
        label_instruction=self.tokenizer(labels)
        i=len(tokenized_instruction["input_ids"])-1
        item["labels"][i:]=label_instruction["input_ids"]
        #item.pop("token_type_ids")
        #we now need to pad to 2048
        return item

    def tokenize(self, prompt):
        result_prompt=self.tokenizer(prompt,
                                     truncation=True,
                                     max_length=1024,
                                     padding=False,
                                     return_tensors=None)
        #result_labels=self.tokenizer(labels,
        #                             truncation=True,
        #                             max_length=1024,
        #                             padding=False,
        #                             return_tensors=None)
        #old_labels=result_labels["input_ids"].copy()
        #result_prompt["labels"]=[-100 for i in result_prompt["input_ids"]] + result_labels["input_ids"]
        #result_prompt["input_ids"]=result_prompt["input_ids"] + old_labels
        #assert len(result_prompt["input_ids"]) == len(result_prompt["labels"])
        # labels are initialised to -100 (ignored by the loss); __getitem__ fills in the label span
        result_prompt["labels"]=[-100]*len(result_prompt["input_ids"])
        return result_prompt

    def __len__(self):
        return len(self.instances)
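
# Generation-based evaluation: for every test instance the model generates a
# continuation, the newly generated tokens are decoded and split into individual
# diagnoses (or ICD codes), and compared against the gold labels. When "accuracy"
# is among the requested metrics, per-sample precision/recall are averaged and F1
# is derived from those averages; otherwise the predictions are fed to the
# corresponding `evaluate` metric(s).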
def evaluate_model_loop(model, config_name, test_dataset, batch_size, tokenizer, top_k, temp, num_beams, early_stopping, no_rep, num_return_sequences, metrics, do_sample, generative=False, icd_codes=False):
    torch.cuda.empty_cache()
    print("Testing model")
    print("Metric: ", metrics)
    if len(metrics) > 1:
        metric=evaluate.combine(metrics)
    else:
        metric=evaluate.load(metrics[0])
    model.eval().to(device)
    print("Batch size: ", batch_size)
    print("tokenizer: ", tokenizer)
    print("topk: ", top_k)
    print("temp: ", temp)
    print("num_beams: ", num_beams)
    print("early stopping: ", early_stopping)
    print("no rep: ", no_rep)
    print("num_return_sequences: ", num_return_sequences)
    print("metrics: ", metrics)
    print("generative? ", generative)
    #config=GenerationConfig.from_pretrained(config_name, top_k=top_k, temperature=temp, num_beams=num_beams, early_stopping=early_stopping, no_repeat_ngram_size=no_rep, num_return_sequences=num_return_sequences, max_length=512)
    eval_sampler=SequentialSampler(test_dataset)
    eval_dataloader=DataLoader(test_dataset, sampler=eval_sampler, batch_size=batch_size)
    accuracies=[]
    f1s=[]
    recalls=[]
    precs=[]
    for index, batch in tqdm(enumerate(eval_dataloader)):
        with torch.no_grad():
            print(len(batch[0]))
            input_ids=batch[0]
            attention_mask=batch[1]
            labels=batch[2]
            # number of non-padding label tokens; used to cap the generation length
            labels_len=len(batch[2][batch[2] != tokenizer.pad_token_id])
            input_ids_len=len(batch[0][0])
            print("Len input ids: ", input_ids_len)
            print("Len labels: ", labels_len)
            #outputs=model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=512, top_k=top_k, temperature=temp, num_beams=num_beams, early_stopping=early_stopping, no_repeat_ngram_size=no_rep, num_return_sequences=num_return_sequences)
            # do_sample is forwarded so that top_k/temperature actually take effect
            outputs=model.generate(input_ids=input_ids, attention_mask=attention_mask, top_k=top_k, temperature=temp, num_beams=num_beams, early_stopping=early_stopping, no_repeat_ngram_size=no_rep, num_return_sequences=num_return_sequences, do_sample=do_sample, max_new_tokens=labels_len+8)  # , length_penalty=-0.8
            # decode only the newly generated tokens (everything after the prompt)
            preds=tokenizer.decode(outputs[0][input_ids_len:], skip_special_tokens=True)
            if icd_codes == False:
                #preds=re.findall('[A-Z]{1}[^A-Z]{1,}', preds)
                preds=preds.split(", ")
            elif icd_codes == True:
                preds=preds.split(",")
            labels_decoded=tokenizer.decode(labels[0], skip_special_tokens=True)
            if icd_codes == False:
                #labels_decoded=re.findall('[A-Z]{1}[^A-Z]{1,}', labels_decoded)
                labels_decoded=labels_decoded.split(", ")
            elif icd_codes == True:
                labels_decoded=labels_decoded.split(", ")
            preds_clean=[p.strip(" ").lower() for p in preds]
            labels_clean=[l.strip(" ").lower() for l in labels_decoded]
            print("Predictions ", preds_clean)
            print("Labels: ", labels_clean)
            if "accuracy" not in metrics:
                metric.add_batch(predictions=[preds], references=[labels_decoded])
            else:
                # set-overlap counts between predicted and gold diagnoses
                TP=len([t for t in preds_clean if t in labels_clean])
                print(TP)
                FP=len([f for f in preds_clean if f not in labels_clean])
                print(FP)
                FN=len([fn for fn in labels_clean if fn not in preds_clean])
                print(FN)
                p=TP/(TP+FP) if TP+FP != 0 else 0
                r=TP/(TP+FN) if TP+FN != 0 else 0
                print("Precision: ", p)
                print("Recall: ", r)
                print("F1: ", TP/(TP+0.5*(FP+FN)) if TP+FP+FN != 0 else 0)
                precs.append(p)
                recalls.append(r)
    if "accuracy" in metrics:
        # average per-sample precision and recall, then derive F1 from the averages
        prec=mean(precs)
        recs=mean(recalls)
        f1_v=(2*prec*recs)/(prec+recs) if prec+recs != 0 else 0
        print([prec, recs, f1_v])
        res="Precision: "+str(prec)+", Recall: "+str(recs)+", F1: "+str(f1_v)
    else:
        res=metric.compute()
    print("RESULTS: ", res)
    return res
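
# Command-line entry point: parse model, data, and generation hyperparameters,
# then run the evaluation defined above.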
if __name__=="__main__":
    #tokenizer=T5Tokenizer.from_pretrained("t5-small")
    #dataset=preprocess(tokenizer, "./test/test_dataset_diagnoses_icd.pkl")
    #model=torch.load("./t5_small_10000_icd.pt")
    #res=evaluate_model(model, dataset, 2, tokenizer)
    parser=argparse.ArgumentParser()
    parser.add_argument("--saved_model", "-sm", help="Path to saved model")
    parser.add_argument("--tokenizer", "-t", help="Tokenizer to use (name on Hugging Face)")
    parser.add_argument("--test_dataset", "-td", help="Path to test dataset")
    parser.add_argument("--batch_size", "-b", type=int, help="Batch size")
    parser.add_argument("--log", "-l", help="Path to save log")
    parser.add_argument("-topk", type=int, help="Top-k value for generation")
    parser.add_argument("-temp", type=float, help="Temperature for generation")
    parser.add_argument("-num_beams", type=int, help="Number of beams for generation")
    parser.add_argument("-early_stopping", action="store_true", help="Use early stopping in generation")
    parser.add_argument("-no_rep_ngram", type=int, help="Size of n-grams that may not be repeated during generation")
    parser.add_argument("-num_return_sequences", type=int, default=1, help="Number of returned sequences")
    parser.add_argument("-do_sample", action="store_true", help="Whether to sample during generation")
    parser.add_argument("--metrics", "-m", nargs="+", help="List all the metrics you want to use for evaluation")
    parser.add_argument("--icd_codes", action="store_true", help="Whether outputs are ICD codes")
    #parser.add_argument(
    #    "--output_type",
    #    "-ot",
    #    choices=["classification", "generative_diag", "generative_ds"],
    #    help="Output type of the model that is evaluated"
    #)
    parser.add_argument("--generative", "-g", action="store_true", help="Is the model a generative diagnosis model?")
    parser.add_argument("--config_name", help="Name of the config to use for generation")
    args=parser.parse_args()
    run()
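
# Hypothetical example invocation (the script filename and all paths below are
# placeholders, not taken from this repository; adjust to the actual locations):
#   python evaluate_diagnoses.py \
#       --saved_model <path-or-hub-id-of-finetuned-model> \
#       --tokenizer medalpaca/medalpaca-7b \
#       --test_dataset <path-to-test_dataset.pkl> \
#       --batch_size 1 --log eval.log \
#       -topk 30 -temp 0.9 -num_beams 4 -no_rep_ngram 2 -do_sample \
#       --metrics accuracy --config_name medalpaca/medalpaca-7b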
#!/bin/bash
#
#SBATCH --job-name=medalpaca
#SBATCH --output=output_train_medalpaca_lora_min_400.txt
#SBATCH --mail-user=kulcsar@cl.uni-heidelberg.de
#SBATCH --mail-type=ALL
#SBATCH --partition=single
#SBATCH --nodes=1
#SBATCH --time=120:00:00
#SBATCH --mem 80G
#SBATCH --gres=gpu:1
#SBATCH --ntasks=1
#JOB STEPS
#srun hostname
#cd /home/students/kulcsar/
#source /home/students/kulcsar/anaconda3/etc/profile.d/conda.sh
#conda activate software_bubble_updated_accelerate
#cd /home/students/kulcsar/Bachelor/for_dataset/10000_diagnoses
#accelerate config
#python -m torch.distributed.launch --nproc_per_node=2 --use_env t5_model.py --model luqh/ClinicalT5-large --tokenizer luqh/ClinicalT5-large --dataset ./dev/dev_dataset_diagnoses_icd_48_pref.pkl --ff --seed 42 --batch_size 8 --epochs 3 --learning_rate 0.000005 --model_save_path t5_small_10000_icd_48_rs42_prefix.pt --do_eval --test_dataset ./test/test_dataset_diagnoses_icd_48_pref.pkl --topk 3 --temp 0.9 --num_beams 4 --no_rep_ngram 2 --do_sample --metrics accuracy --log log_clinicalt5.log
CUDA_LAUNCH_BLOCKING=1 python falcon_model_peft.py --model medalpaca/medalpaca-7b --tokenizer medalpaca/medalpaca-7b --dataset ./train_dataset_diagnoses_shortened_min_400_split.pkl --test_dataset ./test_dataset_diagnoses_shortened_min_400_split.pkl --seed 42 --batch_size 1 --gradient_accumulation_steps 4 --epochs 100 --learning_rate 2e-5 --model_save_path medalpaca_promp_tuned_5_shot.pt
#--do_eval --test_dataset ./test/test_dataset_diagnoses_48.pkl --topk 30 --temp 0.9 --num_beams 4 --no_rep_ngram 2 --do_sample --metrics accuracy --num_return_sequences 1