# evaluation.py
import torch
import numpy as np
import evaluate
from torch.utils.data import DataLoader, SequentialSampler

# Free any cached GPU memory left over from a previous run.
torch.cuda.empty_cache()


def evaluate_model(model, test_dataset, learning_rate, batch_size):
	"""Evaluate a fine-tuned model on test_dataset and return accuracy, F1, precision and recall."""
	torch.cuda.empty_cache()
	metric = evaluate.combine(["accuracy", "f1", "precision", "recall"])
	model.eval()
	device = next(model.parameters()).device

	eval_sampler = SequentialSampler(test_dataset)
	eval_dataloader = DataLoader(test_dataset, sampler=eval_sampler, batch_size=batch_size)

	# Decide the batch layout once, up front: BERT batches carry token_type_ids,
	# RoBERTa batches do not. The check mirrors the original heuristic of looking
	# at the first character of model.name_or_path ("b..." vs. "r...").
	is_bert = model.name_or_path[0] == "b"
	is_roberta = model.name_or_path[0] == "r"
	if not (is_bert or is_roberta):
		raise ValueError(f"Cannot infer model type from name_or_path: {model.name_or_path}")
	print("Evaluating {} model".format("BERT" if is_bert else "RoBERTa"))

	for batch in eval_dataloader:
		# Move every tensor in the batch to the model's device.
		batch = tuple(t.to(device) for t in batch)
		with torch.no_grad():
			if is_bert:
				inputs = {'input_ids': batch[0],
						  'attention_mask': batch[1],
						  'token_type_ids': batch[2],
						  'start_position': batch[3],
						  'end_position': batch[4],
						  'labels': batch[5]}
			else:
				inputs = {'input_ids': batch[0],
						  'attention_mask': batch[1],
						  'start_position': batch[2],
						  'end_position': batch[3],
						  'labels': batch[4]}

			outputs = model(**inputs)
			# outputs[1] holds the classification logits; argmax yields class ids.
			prediction = torch.argmax(outputs[1], dim=-1)
			metric.add_batch(predictions=prediction.cpu(),
							 references=inputs['labels'].cpu())

	res = metric.compute()
	#print(f"learning rate {learning_rate}: ", res)

	return res
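
# A minimal usage sketch (the checkpoint path and dataset variable below are
# hypothetical; `model` must be this project's custom model class, since its
# forward() accepts start_position/end_position/labels as keyword arguments):
#
#   model = torch.load("finetuned_roberta.pt")  # assumed custom checkpoint
#   results = evaluate_model(model, test_dataset, learning_rate=2e-5, batch_size=16)
#   print(results)  # {'accuracy': ..., 'f1': ..., 'precision': ..., 'recall': ...}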


def compute_metrics(eval_pred):
	"""Metric callback in the (logits, labels) format used by transformers.Trainer."""
	metric = evaluate.combine(["accuracy", "f1", "precision", "recall"])
	logits, labels = eval_pred
	predictions = np.argmax(logits, axis=-1)
	return metric.compute(predictions=predictions, references=labels)
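
# compute_metrics is shaped for transformers.Trainer, which passes an
# EvalPrediction of (logits, labels). A minimal sketch of wiring it in
# (`model`, `eval_dataset`, and the TrainingArguments values are assumptions,
# not defined in this file):
#
#   from transformers import Trainer, TrainingArguments
#   trainer = Trainer(
#       model=model,
#       args=TrainingArguments(output_dir="out", per_device_eval_batch_size=16),
#       eval_dataset=eval_dataset,
#       compute_metrics=compute_metrics,
#   )
#   print(trainer.evaluate())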