# evaluation.py
import numpy as np
import torch
import evaluate
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

torch.cuda.empty_cache()  # free any cached GPU memory before evaluation starts


def evaluate_model(model, name, test_dataset, batch_size, imdb=False):
	"""Evaluation for a trained model. Iterates over the test set and computes
	accuracy, F1, precision and recall from the model outputs and the true labels.

	Params:
	model: nn.Module            -> model trained in the train loop
	name: str                   -> name of the model (determines the input format)
	test_dataset: TensorDataset -> test dataset (column order must match the unpacking below)
	batch_size: int             -> batch size for the test dataloader
	imdb: bool                  -> whether the IMDb dataset is used

	Returns: accuracy, F1, precision and recall
	"""
	torch.cuda.empty_cache()
	metric = evaluate.combine(["accuracy", "f1", "precision", "recall"])
	model.eval()

	eval_sampler = SequentialSampler(test_dataset)
	eval_dataloader = DataLoader(test_dataset, sampler=eval_sampler, batch_size=batch_size)

	print("Evaluating " + name + (" on imdb" if imdb else ""))
	for batch in eval_dataloader:
		with torch.no_grad():
			if name[0] == "b":  # BERT-style inputs (include token_type_ids)
				if not imdb:
					inputs = {'input_ids': batch[0],
							  'attention_mask': batch[1],
							  'token_type_ids': batch[2],
							  'start_position': batch[3],
							  'end_position': batch[4],
							  'labels': batch[5]}
				else:
					inputs = {'input_ids': batch[0],
							  'attention_mask': batch[1],
							  'token_type_ids': batch[2],
							  'labels': batch[3]}
			elif name[0] == "r":  # RoBERTa-style inputs (no token_type_ids)
				inputs = {'input_ids': batch[0],
						  'attention_mask': batch[1],
						  'start_position': batch[2],
						  'end_position': batch[3],
						  'labels': batch[4]}

			outputs = model(**inputs)  # forward pass; outputs[1] holds the logits
			prediction = torch.argmax(outputs[1], dim=-1)  # predicted class per example
			# The label tensor sits at a different batch position depending on the input format.
			if name[0] == "b":
				if not imdb:
					metric.add_batch(predictions=prediction, references=batch[5])
				else:
					metric.add_batch(predictions=prediction, references=batch[3])
			elif name[0] == "r":
				metric.add_batch(predictions=prediction, references=batch[4])

	res = metric.compute()
	return res
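

# A minimal usage sketch (not part of the original pipeline): a stand-in model
# and a toy IMDb-style TensorDataset, just to show the expected call signature
# and column order. All names below are illustrative assumptions.
def _example_evaluate():
	class _DummyModel(torch.nn.Module):
		def forward(self, input_ids=None, attention_mask=None,
					token_type_ids=None, labels=None):
			# Return a (loss, logits) tuple, matching the outputs[1] indexing above.
			return torch.zeros(1), torch.randn(input_ids.size(0), 2)

	n = 8
	dataset = TensorDataset(
		torch.randint(0, 100, (n, 16)),        # input_ids
		torch.ones(n, 16, dtype=torch.long),   # attention_mask
		torch.zeros(n, 16, dtype=torch.long),  # token_type_ids
		torch.randint(0, 2, (n,)),             # labels
	)
	return evaluate_model(_DummyModel(), "bert", dataset, batch_size=4, imdb=True)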


def compute_metrics(eval_pred):
	"""Compute-metrics function that maps predictions and references to
	accuracy, F1, precision and recall; also used in the salami train loop.

	Params:
	eval_pred: tuple -> (logits, labels)

	Returns: accuracy, F1, precision and recall"""
	metric = evaluate.combine(["accuracy", "f1", "precision", "recall"])
	logits, labels = eval_pred
	predictions = np.argmax(logits, axis=-1)
	return metric.compute(predictions=predictions, references=labels)
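

# Quick self-check (a sketch, not from the original file): feeds a toy
# (logits, labels) pair through compute_metrics, mirroring the tuple that
# the Hugging Face Trainer passes to its compute_metrics hook.
if __name__ == "__main__":
	toy_logits = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
	toy_labels = np.array([1, 0, 1])
	print(compute_metrics((toy_logits, toy_labels)))  # all four metrics == 1.0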