# train.py
import torch
import tqdm
import numpy as np
import evaluation
import evaluate
import json
import random
import math
from tqdm.auto import tqdm
from transformers import BertTokenizer, RobertaTokenizer, BertModel, RobertaModel, RobertaPreTrainedModel, RobertaConfig,  BertConfig, BertPreTrainedModel, PreTrainedModel, AutoModel, AutoTokenizer
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from transformers import AdamW, get_scheduler, Trainer, TrainingArguments
from torch import nn
from torch.nn import CrossEntropyLoss
import matplotlib.pyplot as plt
import os
import pandas as pd
import sklearn

metric=evaluate.load("accuracy")
torch.cuda.empty_cache()

#with torch.autocast("cuda"):

def train(model, name, imdb, seed, mixup, lambda_value, mixepoch, tmix, mixlayer, train_dataset, test_dataset, num_epochs, learning_rate, batch_size, test_batch_size):
	"""Training loop for the models. Iterates over epochs and batches and feeds the inputs to the model.
	After training, evaluation.evaluate_model is called to evaluate the fine-tuned model on the test and train sets."""
	model.train().to("cuda")
	train_sampler = RandomSampler(train_dataset)
	train_dataloader=DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)
	num_training_steps=num_epochs*len(train_dataloader)

	optimizer=AdamW(model.parameters(), lr=learning_rate, eps=1e-8, weight_decay=0.1)
	lr_scheduler=get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=10, num_training_steps=num_training_steps)

	model.zero_grad()
	for epoch in range(num_epochs):
		index=0
		
		for batch in train_dataloader:
			#move all tensors in the batch to the same device as the model
			batch = tuple(t.to("cuda") for t in batch)
			print(len(batch))
			if name[0] == "b":
				if tmix==False:
					inputs = {'input_ids': batch[0],
							'attention_mask': batch[1],
							'token_type_ids': batch[2],
							'start_position': batch[3],
							'end_position': batch[4],
							'labels': batch[5]}
					labels=batch[5]
					start_positions=batch[3]
					end_positions=batch[4]
				if tmix==True:
					#print("Hello, tmix is set as true")
					if epoch == mixepoch:
						if imdb == False:
							print("this is miuxup epoch")
							#print(batch[5])
							#print("mixlayer: ", mixlayer)
							#print("lambda: ", lambda_value)
                        
							inputs={'input_ids': batch[0],
										'attention_mask': batch[1],
										'token_type_ids': batch[2],
										'start_position': batch[3],
										'end_position': batch[4],
										'labels': batch[5],
										'mixepoch': True,
										'mixlayer':mixlayer,
										'lambda_value':lambda_value}
						if imdb==True:
							print("this is a mixup epoch with imdb")
							inputs={'input_ids':batch[0],
									'attention_mask': batch[1],
									'token_type_ids': batch[2],
									'labels': batch[3],
									'mixepoch': True,
									'mixlayer': mixlayer,
									'lambda_value': lambda_value}
							
					else:
						if imdb == False:
							print("this is a non mixup epoch")
							#print(batch[5])
							inputs={'input_ids': batch[0],
										'attention_mask': batch[1],
										'token_type_ids': batch[2],
										'start_position': batch[3],
										'end_position': batch[4],
										'labels': batch[5],
										'mixepoch': False,
										'mixlayer':mixlayer,
										'lambda_value':lambda_value}
						elif imdb == True:
							print("non mixup epoch with imbd")
							inputs={'input_ids': batch[0],
									'attention_mask': batch[1],
									'token_type_ids': batch[2],
									'labels': batch[3],
									'mixepoch': False,
									'mixlayer': mixlayer,
									'lambda_value':lambda_value}
				

			if name[0] == "r":
				inputs = {'input_ids': batch[0],
						  'attention_mask': batch[1],
						  'start_position': batch[2],
						  'end_position': batch[3],
						  'labels': batch[4]}
				labels = batch[4]
				start_positions=batch[2]
				end_positions=batch[3]
			outputs=model(**inputs)
			loss=outputs[0]
			print("Loss: ", loss)
			loss.backward()
			optimizer.step()
			lr_scheduler.step()
			optimizer.zero_grad()
			model.zero_grad()

			if epoch==mixepoch:
				print("mixepoch")
				if mixup == True:
					#calculate new (mixed) last hidden states and soft labels
					new_matrix_batch, new_labels_batch = mixup_function(outputs[2], labels, lambda_value)
					new_matrix_batch = new_matrix_batch.to("cuda")
					new_labels_batch = new_labels_batch.to("cuda")
					span_output=torch.randn(new_matrix_batch.shape[0], new_matrix_batch.shape[-1]).to("cuda")
					for i in range(new_matrix_batch.shape[0]):
						span_output[i]=new_matrix_batch[i][start_positions[i]:end_positions[i]].mean(dim=0)
					logits=model.classifier(span_output.detach())
					logits = logits.view(-1, 2).to("cuda")
					target = new_labels_batch.view(-1).to("cuda")
					loss_2 = cross_entropy(logits, target, lambda_value)
					
					#update entire model
					loss_2.backward()
					optimizer.step()
					lr_scheduler.step()
					optimizer.zero_grad()
					model.zero_grad()
	
	os.makedirs("saved_models", exist_ok=True)
	torch.save(model, "saved_models/bert_baseline.pth")

	#evaluate trained model
	evaluation_test = evaluation.evaluate_model(model, name,  test_dataset, test_batch_size, imdb)
	evaluation_train = evaluation.evaluate_model(model, name, train_dataset, test_batch_size, imdb)
 
	print("TEST: ", evaluation_test)
	print("TRAIN: ", evaluation_train)

	return evaluation_test, evaluation_train
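
# Example (hypothetical) call of train(); the model, dataset objects and hyperparameter values
# below are assumptions for illustration, not taken from this repository:
#   train(model, name="bert-base-uncased", imdb=False, seed=42, mixup=True,
#         lambda_value=0.4, mixepoch=1, tmix=False, mixlayer=None,
#         train_dataset=train_tensor_dataset, test_dataset=test_tensor_dataset,
#         num_epochs=3, learning_rate=2e-5, batch_size=16, test_batch_size=32)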

def cross_entropy(logits, target, l):
	"""Cross entropy over a batch that also handles soft (mixed) labels.
	logits: (batch_size, 2) classifier outputs; target: labels in [0, 1]; l: lambda value (kept for the interface)."""
	results = torch.tensor([], device='cuda')
	for i in range(logits.shape[0]):
		lg = logits[i:i+1, :] #select the i-th logit row, keeping the batch dimension
		t = target[i]
		#convert the logits into log probabilities (log softmax)
		logprobs = torch.nn.functional.log_softmax(lg, dim=1)
		value = t.item() #label value, either hard (0.0/1.0) or mixed (e.g. 0.4)
		if value == 1 or value == 0:
			one_hot = torch.tensor([1-value,value], device='cuda:0') #creating one-hot vector e.g. [0. ,1.]
			#hard label: standard cross entropy against the one-hot target
			loss_clear_labels = -((one_hot[0] * logprobs[0][0]) + (one_hot[1] * logprobs[0][1]))
			results = torch.cat((loss_clear_labels.view(1), results), dim=0)
		else:
			value_r = round(value, 1) #round to one decimal place so it matches the lambda value, e.g. 0.4
			#soft-label target vector for the mixed example
			mixed_vec = torch.tensor([value_r, 1-value_r])
			print("Mixed Vec: ", mixed_vec)
			print("Log:", logprobs)
			#loss_mixed_labels = -torch.mul(mixed_vec, logprobs).sum()
			loss_mixed_labels = -((mixed_vec[0] * logprobs[0][0]) + (mixed_vec[1] * logprobs[0][1]))
			print("Loss Mixed Labels: ", loss_mixed_labels)
			results = torch.cat((loss_mixed_labels.view(1), results), dim=0)
			print("Results Mixed 1: ", results)
	print("ALL BATCH Results: ", results)
	batch_loss = results.mean() #compute average
	#print("Batch Loss: ", batch_loss)
	return batch_loss
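
# Worked example for the soft-label branch above (illustrative numbers): a mixed label of 0.4
# yields the target vector [0.4, 0.6], so the per-example loss is
#   -(0.4 * log p(class 0) + 0.6 * log p(class 1)),
# i.e. cross entropy against a soft target instead of a one-hot vector.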



def mixup_function(batch_of_matrices, batch_of_labels, l):
	"""Perform mixup on a batch of matrices and labels with a given lambda l:
	consecutive pairs of examples are linearly interpolated into new examples with soft labels."""
	runs = math.floor(batch_of_matrices.size()[0]/2)
	counter=0
	results=[]
	result_labels=[]
	for i in range(runs):
		#get matrices and labels out of batch
		matrix1=batch_of_matrices[counter]
		label1=batch_of_labels[counter]
		matrix2=batch_of_matrices[counter+1]
		label2=batch_of_labels[counter+1]

		#do interpolation
		new_matrix=matrix1*l + (1-l)*matrix2
		new_label=l*label1 + (1-l)*label2
		
		if new_matrix is not None:
			results.append(new_matrix)
			result_labels.append(new_label)
		counter+=2
	results=torch.stack(results)
	result_labels= torch.stack(result_labels) #torch.LongTensor(result_labels)
	return results, result_labels
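
# Worked example of the interpolation above (illustrative numbers): with lambda = 0.7,
# two hidden-state matrices h1, h2 with labels 1 and 0 are combined into
#   0.7 * h1 + 0.3 * h2   with the soft label   0.7 * 1 + 0.3 * 0 = 0.7.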

	
def train_salami(model, seed, train_set, test_set, batch_size, test_batch_size, learning_rate, epochs):
	"""Train loop of the salami group"""
	results=[]
	training_args = TrainingArguments(
		output_dir="./results",  # output directory
		num_train_epochs=epochs,  # total # of training epochs
		per_device_train_batch_size=batch_size,  # batch size per device during training
		per_device_eval_batch_size=test_batch_size,  # batch size for evaluation
		warmup_steps=10,  # number of warmup steps for learning rate scheduler
		weight_decay=0.1,  # strength of weight decay
		learning_rate=learning_rate,
		evaluation_strategy="no",  # evaluates never, per epoch, or every eval_steps
		eval_steps=10,
		logging_dir="./logs",  # directory for storing logs
		seed=seed,  # explicitly set seed
		save_strategy="no",  # do not save checkpoints
	)


	trainer=Trainer(
		model=model,
		train_dataset=train_set,
		eval_dataset=test_set,
		args=training_args,
		compute_metrics=evaluation.evaluate_model
		)

	trainer.train()
	test_set_results=trainer.evaluate()
	results.append(test_set_results)
	print(test_set_results)

	return results
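

# Minimal demo sketch (illustrative only, not part of the training pipeline): exercises
# mixup_function on random tensors and, if a GPU is available, the soft-label cross_entropy
# defined above. The tensor shapes and the lambda value are assumptions made for this demo.
if __name__ == "__main__":
	dummy_hidden = torch.randn(4, 8, 16)            # (batch, sequence length, hidden size)
	dummy_labels = torch.tensor([0., 1., 1., 0.])   # hard binary labels
	mixed_hidden, mixed_labels = mixup_function(dummy_hidden, dummy_labels, 0.7)
	print("mixed hidden states:", mixed_hidden.shape)   # torch.Size([2, 8, 16])
	print("mixed labels:", mixed_labels)                 # tensor([0.3000, 0.7000])
	if torch.cuda.is_available():
		dummy_logits = torch.randn(2, 2, device="cuda")
		print("soft-label loss:", cross_entropy(dummy_logits, mixed_labels.to("cuda"), 0.7))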
