From a4b4d13d2d72177b925fa0025f66d1f0a454f482 Mon Sep 17 00:00:00 2001
From: Marton Kulcsar <kulcsar@cl.uni-heidelberg.de>
Date: Fri, 24 Feb 2023 18:56:48 +0100
Subject: [PATCH] add TMix support and MLP classifier head to models.py, train.py and main.py; clean up preprocess.py

---
 Code/models.py     |  57 ++++++------
 Code/preprocess.py | 213 ++++++++-------------------------------------
 Code/train.py      |  55 +++++++-----
 main.py            |  21 ++++-
 4 files changed, 114 insertions(+), 232 deletions(-)

diff --git a/Code/models.py b/Code/models.py
index c903d11..f488662 100644
--- a/Code/models.py
+++ b/Code/models.py
@@ -58,20 +58,26 @@ class WordClassificationModel(torch.nn.Module):
                  model and the computed loss value.
 
     """
-    def __init__(self, config_name, tmix=False, imdb=False): #mixlayer=-1, lambda_value=0.0):
+    def __init__(self, config_name, tmix=False, imdb=False, mlp_flag=False): #mixlayer=-1, lambda_value=0.0):
         super(WordClassificationModel, self).__init__()
         self.tmix=tmix
         self.imdb=imdb
+        self.mlp_flag=mlp_flag
         #self.mixlayer=mixlayer
         if tmix:
             print("initializing BertModelTMix")
-            self.embedding_model=BertModelTMix(config=AutoConfig.from_pretrained(config_name)).to(device)
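+            # load pretrained weights into the TMix variant instead of initializing it randomly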
+            self.embedding_model=BertModelTMix.from_pretrained(config_name, config=AutoConfig.from_pretrained(config_name)).to(device)
         else:
             self.embedding_model=AutoModel.from_pretrained(config_name, config=AutoConfig.from_pretrained(config_name)).to(device)
         
-
-        self.dropout=nn.Dropout(0.1)
-        self.classifier = nn.Linear(768, 2)  
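+        # classification head: a single linear layer or a small two-layer MLP (768 -> 128 -> 2)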
+        if mlp_flag:
+            print("Using two-layer multi-layer perceptron classifier")
+            self.classifier=nn.Sequential(nn.Linear(768, 128), nn.Tanh(), nn.Linear(128, 2))
+        else:
+            print("Using linear classifier")
+            self.classifier=nn.Linear(768, 2)
+
+        self.dropout=nn.Dropout(0.1)
+
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, start_position=None, end_position=None, labels=None, mixepoch=False, mixlayer=None, lambda_value=None):
 
         if self.tmix==True:
@@ -82,9 +88,8 @@ class WordClassificationModel(torch.nn.Module):
                                     position_ids=position_ids,
                                     head_mask=head_mask,
                                     return_dict=False,
-                                        output_hidden_states=False,
-                                       labels=labels,
-                                    mixepoch=mixepoch, 
+                                    output_hidden_states=False,
+                                    labels=labels, 
                                     mixlayer=mixlayer,
                                     lambda_value=lambda_value)
         else:
@@ -107,14 +112,11 @@ class WordClassificationModel(torch.nn.Module):
             logits = self.classifier(span_output)
 
         else:
-            span_output=torch.randn(output.shape[0], output.shape[-1]).to(output.device)
-            for i in range(output.shape[0]):
-                span_output[i]=output[i].mean(dim=0)
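+            # mean-pool over the sequence dimension in a single vectorized call instead of looping over the batch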
+            span_output=torch.mean(output, 1)
             logits=self.classifier(span_output)
 
         if self.tmix==True and mixepoch == True:
             outputs = (logits,) + outputs[2:]
-
             loss = train.cross_entropy(logits[:math.floor((logits.size()[0]/2))], outputs[1][:math.floor((outputs[1].size()[0]/2))], lambda_value) #special CEL for soft labels 
             outputs = (loss,) + outputs
         
@@ -152,7 +154,7 @@ class BertForWordClassification(BertPreTrainedModel):
 
         self.bert=BertModel(config)
         self.dropout=nn.Dropout(config.hidden_dropout_prob)
-        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels) #selbst machen!!
+        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
 
         self.init_weights()
 
@@ -167,7 +169,7 @@ class BertForWordClassification(BertPreTrainedModel):
             head_mask=head_mask)
         
         output = outputs[0] 
-        output = self.dropout(output) #apply droput
+        output = self.dropout(output)
         span_output = torch.randn(output.shape[0],output.shape[-1]).to(output.device)
         for i in range(output.shape[0]):
                 span_output[i] = output[i][start_position[i]:end_position[i]].mean(dim=0)
@@ -208,7 +210,7 @@ class RobertaForWordClassification(RobertaPreTrainedModel):
 
         self.roberta=RobertaModel(config)
         self.dropout=nn.Dropout(config.hidden_dropout_prob)
-        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels) #selbst machen!!
+        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels) 
 
         self.init_weights()
 
@@ -220,8 +222,8 @@ class RobertaForWordClassification(RobertaPreTrainedModel):
                                 position_ids=position_ids,
                                 head_mask=head_mask)
 
-        output = outputs[0] #get outputs from bert
-        output = self.dropout(output) #apply droput
+        output = outputs[0]
+        output = self.dropout(output)
         span_output = torch.randn(output.shape[0],output.shape[-1]).to(output.device)
         for i in range(output.shape[0]):
             span_output[i] = output[i][start_position[i]:end_position[i]].mean(dim=0)
@@ -305,7 +307,6 @@ class BertModelTMix(BertPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels=None,
-        mixepoch=False,
         mixlayer=None,
         lambda_value=None
     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
@@ -406,7 +407,6 @@ class BertModelTMix(BertPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             labels=labels,
-            mixepoch=mixepoch,
             mixlayer=mixlayer,
             lambda_value=lambda_value
         )
@@ -460,12 +460,10 @@ class BertTMixEncoder(torch.nn.Module):
         output_hidden_states: Optional[bool] = False,
         return_dict: Optional[bool] = True,
         mixlayer: int = None,
-        lambda_value: float=0.0,
-        mixepoch: bool = False) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+        lambda_value: float=0.0) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
-        all_labels=()
         next_decoder_cache = () if use_cache else None
         for i, layer_module in enumerate(self.layer):
 
@@ -509,8 +507,7 @@ class BertTMixEncoder(torch.nn.Module):
                     output_attentions,
                     lambda_value,
                     mixlayer=mixlayer,
-                    nowlayer=i,
-                    mixepoch=mixepoch
+                    nowlayer=i
                 )
 
             hidden_states = layer_outputs[0]
@@ -573,10 +570,9 @@ def forward_new(forward):
             return_dict: Optional[bool] = True,
             lambda_value: float=0.4,
             mixlayer: list=None,
-            nowlayer: int=0,
-            mixepoch: bool=False)-> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+            nowlayer: int=0)-> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
         new_matrices=[]
-        if nowlayer == mixlayer and mixepoch==True:
+        if nowlayer == mixlayer:
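+            # TMix: at the selected mix layer, interpolate hidden states of paired examples with lambda_value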
             runs = math.floor(hidden_states.size()[0]/2)
             counter=0
             new_attention_masks=[]
@@ -599,12 +595,10 @@ def forward_new(forward):
                 try: 
                     index1=((attention_mask_1[0][0]== -10000.).nonzero(as_tuple=False)[0]).item()
                 except IndexError:
-                    print(attention_mask_1.size())
                     index1=attention_mask_1.size()[0]
                 try: 
                     index2=((attention_mask_2[0][0]== -10000.).nonzero(as_tuple=False)[0]).item()
                 except IndexError:
-                    print(attention_mask_2.size())
                     index2=attention_mask_2.size()[0]
                 if index1>= index2:
                     selected_attention_mask=attention_mask_1
@@ -625,9 +619,8 @@ def forward_new(forward):
             new_attention_masks=torch.stack(new_attention_masks).to(device)
             new_labels=torch.Tensor(new_labels).to(device)
 
-            #when performing interpolation, pass back th new hidden states and labels
-            outputs=forward(self, hidden_states=new_matrices, head_mask=head_mask, attention_mask=new_attention_masks, encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_values, output_attentions=output_attentions) #I"m a bit confused here... do we have to add self or rather not? 
+            #when performing interpolation, pass back the new hidden states and labels
+            outputs=[new_matrices, new_attention_masks]
             labels=copy.deepcopy(new_labels)
         else:
 
diff --git a/Code/preprocess.py b/Code/preprocess.py
index a3586e1..7c179ca 100644
--- a/Code/preprocess.py
+++ b/Code/preprocess.py
@@ -16,16 +16,16 @@ import os
 import pandas as pd
 import sklearn
 
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 #metric=evaluate.load("accuracy")
 torch.cuda.empty_cache()
 
 
 def reposition(dp, old_dataset=False):
-	"""Reposition fucntion to find the character level indices of the metonymy (to map back in tokenier_new
+	"""Reposition function to find the character level indices of the metonymy (to map back in tokenizer_new
 	function by char_to_tokens)
 	params:
-		dp -> json readin of li et al shaped dataset
+		dp -> JSON read-in of a Li et al. shaped dataset (or the original dataset)
+		old_dataset: bool -> whether the original dataset is used instead of the Li et al. datasets (True: original, False: Li et al.)
 
 	returns:
 		new_start -> int: new start position of metonymy on character level (including whitespaces)
@@ -34,7 +34,6 @@ def reposition(dp, old_dataset=False):
 	new_end=0
 	if old_dataset ==False:
 		new_dp= " ".join(dp["sentence"]).lower()
-	
 		if dp["pos"][0]==0:
 			new_start=len(" ".join(dp["sentence"][:dp["pos"][0]]))
 		else:
@@ -86,62 +85,35 @@ def tokenizer_new(tokenizer, input, max_length, masked=False, old_dataset=False,
 			dp["sentence"][dp["pos"][0] : dp["pos"][1]] == "<mask>" 
 	
 		
-		#if old_dataset == False:
-		#find new char-pos for metonymic word
-		#new_start_pos, new_end_pos = reposition(dp)
-		#old_target=" ".join(dp["sentence"][dp["pos"][0]:dp["pos"][1]]).lower()
-
-		#### implement rest: tokenize metonymic sentence and encode the rest and pad with it
-
-
-		#old_target=new_dp[dp["pos"][0]: dp["pos"][1]].lower() #old target already on character level
-
-		#assert new_dp[new_start_pos:new_end_pos].strip() == "".join(dp["sentence"][dp["pos"][0]: dp["pos"][1]]).lower()
 
 		if old_dataset == False:
+			#if Li et al dataset: recompute positions on character level, encode sentence and extract target
 			new_start_pos, new_end_pos = reposition(dp, old_dataset=False)
 			new_dp= " ".join(dp["sentence"]).lower()
 			encoded_inp=tokenizer.encode_plus(new_dp, add_special_tokens=True, max_length=max_length, padding="max_length", truncation=True)
-			#tf_tokens=tokenizer.convert_ids_to_tokens(encoded_inp["input_ids"])
-			#print(tf_tokens)
-			#print(typ(encoded_inp))
-			#print("start pos: ", new_start_pos)
-			#print("end pos: ", new_end_pos)
-			#print("sentence: ", new_dp)
 			old_target="".join(dp["sentence"][dp["pos"][0]:dp["pos"][1]]).lower()
 		else:
-			#print("new dataset")
+			#if old dataset (original Markert SemEval): reposition and add context
 			new_start_pos, new_end_pos = reposition(dp, old_dataset=True)
 			new_dp= " ".join(dp["sentence"][1]).lower()
 			encoded_inp=tokenizer.encode_plus(new_dp, add_special_tokens=True) #dont add max length and padding so we can do it manually
 			length_metonymies = len(encoded_inp["input_ids"])
 			context_len=max_length - length_metonymies #length of how much context tokens we can add. We add from left to right
-			#print("metonymy sentece: ", new_dp)
-			#print("metonymy sentence inputs: ", encoded_inp)
-			#print("context length: ", context_len)
 			inp_before=" ".join(dp["sentence"][0]).lower()
-			#print("input before: ", inp_before)
 			encoded_inp_before=tokenizer.encode_plus(inp_before, add_special_tokens=True) #encode before and after context
-
-			#print("encoded inputs before: ", encoded_inp_before)
-			#print("\n")
 			inp_after=" ".join(dp["sentence"][2]).lower()
-			#print("input after: ", inp_after)
 			encoded_inp_after=tokenizer.encode_plus(inp_after , add_special_tokens=True)
-			#print("encoded inputs after: ", encoded_inp_after)
-			#print("\n")
-			#print("\n")
 			
 			#Preprare input for new dictionary with context
 			context_input_ids=[] 
 			context_attention_masks=[]
-			if tokenizer.name_or_path[0] == "b": #BER Tokenizer has token type ids too
+			if tokenizer.name_or_path[0] == "b": #BERT Tokenizer has token type ids too
 				context_token_type_ids=[]
 
 			length_before=len(encoded_inp_before["input_ids"])
 			length_after=len(encoded_inp_after["input_ids"])
 
-
+			#both contexts are long enough: split the context budget evenly between before and after
 			if length_before>=context_len/2 and length_after>=context_len/2:
 				index_before=int(context_len/2)
 				index_after=int(context_len/2)
@@ -155,6 +127,7 @@ def tokenizer_new(tokenizer, input, max_length, masked=False, old_dataset=False,
 				else:
 					index_after=int(math.ceil(wanted_from_after))
 
+			#after-context is shorter than half the budget: take all of it and fill the rest from before
 			elif length_after<context_len/2 and length_before>=context_len/2:
 				index_after=length_after
 				difference_after=(context_len/2)-length_after
@@ -168,108 +141,37 @@ def tokenizer_new(tokenizer, input, max_length, masked=False, old_dataset=False,
 				index_before=length_before
 				index_after=length_after
 
-			#print("len before: ", length_before)
-			#print("len after: ", length_after)
-			#print("index_before: ", index_before)
-			#print("index_after: ", index_after)
-			#print("not used: ", context_len-index_before-index_after)
-
-			#Use the calculated indices to append the right tokens and pad to 512 if needed, recalculate metonymy position and prepare for decoding metonymy
-
-			#before_decoded="".join(tokenizer.decode(encoded_inp_before["input_ids"][length_before-index_before:length_before]))
-			#if tokenizer.name_or_path[0]=="b":
-			#	before_decoded.replace("[CLS]", "") #.replace(" [SEP]", "")
-			#print(before_decoded)
-
 			context_input_ids=context_input_ids + encoded_inp_before["input_ids"][length_before-index_before:length_before]
 			context_input_ids=context_input_ids + encoded_inp["input_ids"]
 			context_input_ids=context_input_ids + encoded_inp_after["input_ids"][length_after-index_after:length_after] 
 			context_input_ids=context_input_ids+([0]*(512-len(context_input_ids))) #pad
-			#print("new input ids: ", len(context_input_ids))
 
 
 			context_attention_masks= context_attention_masks+encoded_inp_before["attention_mask"][length_before-index_before:length_before]
 			context_attention_masks=context_attention_masks+encoded_inp["attention_mask"]
 			context_attention_masks=context_attention_masks+encoded_inp_after["attention_mask"][length_after-index_after:length_after]
 			context_attention_masks=context_attention_masks+([0]* (512-len(context_attention_masks))) #pad
-			#print("new attention maks: ", len(context_attention_masks))
 
-			if tokenizer.name_or_path[0] == "b": #BER Tokenizer has token type ids too
+			if tokenizer.name_or_path[0] == "b": #BERT Tokenizer has token type ids too
 				context_token_type_ids=context_token_type_ids + encoded_inp_before["token_type_ids"][length_before-index_before:length_before]
 				context_token_type_ids=context_token_type_ids +encoded_inp["token_type_ids"]
 				context_token_type_ids=context_token_type_ids +encoded_inp_after["token_type_ids"][length_after-index_after:length_after]
 				context_token_type_ids=context_token_type_ids+([0]*(512-len(context_token_type_ids)))
-				#print("new token type ids: ", len(context_token_type_ids))
 				assert len(context_token_type_ids) == 512
-
+			
+			#make sure everything is padded to the maximum length
 			assert len(context_input_ids) == 512 and len(context_attention_masks) == 512
-			print(len(context_input_ids))
-
-			#get tokeniized words for before sentence and the metonymy sentence
+			
 
+			#get tokenized words for before sentence and the metonymy sentence
 			tokenized_before=[]
 			for i in range(len(" ".join(dp["sentence"][0]).lower())):
 				tokenized_before.append((encoded_inp.char_to_token(i, sequence_index=0)))
-			#print(tokenized_before)
-
-			#tokenized_words = []
-			#for i in range(len(new_dp)): #range(len(new_dp))
-			#	tokenized_words.append((encoded_inp.char_to_token(i, sequence_index=0)))
-			#print(tokenized_words)
-
-			#span=[]
-
-			#for i in tokenized_words[new_start_pos:new_end_pos]:
-			#	if i is not None:
-			#		span.append(i+len(encoded_inp_before["input_ids"]))
-
-			#new_start_pos=new_start_pos+len(encoded_inp_before["input_ids"]) #update inces by adding the number of tokens that are in before sentence
-			#new_end_pos=new_end_pos+len(encoded_inp_before["input_ids"])
-			#print(span)
-			#indices_to_tokens=list(set(span))
-			#indices_to_tokens.sort()
-			#print(indices_to_tokens)
-			#if len(indices_to_tokens)==1:
-			#	print("decoding 1")
-			#	decoded="".join(tokenizer.decode(context_input_ids[indices_to_tokens[0]])).strip().replace(" ", "")
-			#else:
-			#	print("decoding 2")
-			#	#print("indices_to_tokens: ", indices_to_tokens)
-			#	decoded="".join(tokenizer.decode(context_input_ids[indices_to_tokens[0]:indices_to_tokens[-1]+1])).strip().replace(" ", "")
-			#print(decoded)
 
 			old_target="".join(dp["sentence"][1][dp["pos"][0]:dp["pos"][1]]).lower()
-			#print("old_target: ", old_target)
-			#make an encoded_inp dictionary -> not needed, because we use lists directly
-			#encoded_inp={"input_ids": context_input_ids, "attention_mask": context_attention_masks}
-
-			#if tokenizer.name_or_path[0] =="b":
-			#	encoded_inp["token_type_ids"]=context_token_type_ids
-			#print(encoded_inp)
-
-
-		#print(len(encoded_inp["input_ids"]))
-
-		#li et al approach
-		"""
-		if old_dataset==False:
-			orig_to_tok_index2=[]
-			all_tokens2 = ['[CLS]'] 
-			for (i, token) in enumerate(dp["sentence"]):
-				orig_to_tok_index2.append(len(all_tokens2))
-				sub_tokens = tokenizer#.tokenize(token)
-				for sub_token in sub_tokens:
-					all_tokens2.append(sub_token)
-			orig_to_tok_index2.append(len(all_tokens2))
-			new_target="".join(tf_tokens[orig_to_tok_index2[dp["pos"][0]]:orig_to_tok_index2[dp["pos"][1]]]).replace("##", "").lower()
-			print("orig to tok index: ", [orig_to_tok_index2[dp["pos"][0]], orig_to_tok_index2[dp["pos"][1]]])
-			print("new_target: ", repr(new_target))
-		"""
 
 		tokenized_words = []
-		for i in range(len(new_dp)): #range(len(new_dp))
-			#if(new_dp[i])==" ":
-			#	continue #spaces are connected with the words with the roberta tokenizer and are thus always mapped to None
+		for i in range(len(new_dp)): 
 			tokenized_words.append((encoded_inp.char_to_token(i, sequence_index=0)))
 
 		span=[]
@@ -279,68 +181,35 @@ def tokenizer_new(tokenizer, input, max_length, masked=False, old_dataset=False,
 					span.append(i+index_before)
 				else:
 					span.append(i)
-		#if old_dataset==True:
-		#	new_start_pos=new_start_pos+len(encoded_inp_before["input_ids"]) #update inces by adding the number of tokens that are in before sentence
-		#	new_end_pos=new_end_pos+len(encoded_inp_before["input_ids"])
 
 		indices_to_tokens=list(set(span))
 		indices_to_tokens.sort()
-		#print(indices_to_tokens)
-		#print("indices to tokens: ", indices_to_tokens)
 
+		#decode the repositioned tokens to check for a wrong mapping
 		if old_dataset==False: 
 			if len(indices_to_tokens)==1:
-				#print("decoding 1")
 				decoded="".join(tokenizer.decode(encoded_inp["input_ids"][indices_to_tokens[0]])).strip().replace(" ", "")
 			else:
-				#print("decoding 2")
-				#print("indices_to_tokens: ", indices_to_tokens)
 				decoded="".join(tokenizer.decode(encoded_inp["input_ids"][indices_to_tokens[0]:indices_to_tokens[-1]+1])).strip().replace(" ", "")
 		else:
 			if len(indices_to_tokens)==1:
-				#print("decoding 1")
 				decoded="".join(tokenizer.decode(context_input_ids[indices_to_tokens[0]])).strip().replace(" ", "")
 			else:
-				#print("decoding 2")
-				#print("indices_to_tokens: ", indices_to_tokens)
 				decoded="".join(tokenizer.decode(context_input_ids[indices_to_tokens[0]:indices_to_tokens[-1]+1])).strip().replace(" ", "")
-			
-			#print("newly_decoded: ", decoded)
-		
-		#old_dp=" ".join(dp["sentence"]).lower()
-		#print(old_dp)
-		#old_target="".join(old_dp[dp["pos"][0]: dp["pos"][1]]).lower()
-		#old_target="".join(dp["sentence"][dp["pos"][0]:dp["pos"][1]]).lower()
+
 		if old_target!=decoded:
 			print("wrong mapping")
-			if old_dataset == True:
-				print("new_start_pos: ", new_start_pos)
-				print("lenght of before: ", len(encoded_inp_before["input_ids"]))
-				print("lengh of after: ", len(encoded_inp_after["input_ids"]))
-				print("after input ids: ", encoded_inp_after["input_ids"])
-				print("Used from before: ", index_before)
-				print("Used from after: ", index_after)
-				print("metonomy sentence length: ", len(encoded_inp["input_ids"]))
-				print("left for filling: ", context_len)
-			print("indices to tokens: ", indices_to_tokens)
-			print("decoded: ", decoded)
-			print("old target: ", old_target)
-			print(dp)
-			#print(old_dp)
-			#mapping_counter+=1
 			continue
 
-
+		
 		all_start_positions.append(indices_to_tokens[0])
 		all_end_positions.append(indices_to_tokens[-1]+1)
 		all_labels.append(dp["label"])
 		if old_dataset==False:
 			all_input_ids.append(encoded_inp["input_ids"])
-			#print("len input ids: ", len(all_input_ids))
 			all_attention_masks.append(encoded_inp["attention_mask"])
 		else:
 			all_input_ids.append(context_input_ids)
-			#print("len input ids: ", len(all_input_ids))
 			all_attention_masks.append(context_attention_masks)
 
 		if tokenizer.name_or_path[0] == "b":
@@ -349,37 +218,33 @@ def tokenizer_new(tokenizer, input, max_length, masked=False, old_dataset=False,
 			else:
 				all_token_type_ids.append(context_token_type_ids)
 
-	#if tokenizer.name_or_path[0] == "b":
-	#	print(len(all_start_positions))
 
-	
-	#print("len end pos: ", len(all_end_positions))
-	#print("len all labels: ", len(all_labels))
-	#print("len attention masks: ", len(all_attention_masks[0]))
-	#print("len start pos: ", len(all_start_positions))
-	#print("len toke type ids: ", len(all_token_type_ids[0]))
-	if tokenizer.name_or_path[0] == "r": #if tokenizer is roberta we dont have token_type ids
+	if tokenizer.name_or_path[0] == "r": #roberta tokenizer has no token_type_ids
 		print("roberta tokenizer")
-		dataset=TensorDataset(torch.tensor(all_input_ids, dtype=torch.long).to(device) , 
-							torch.tensor(all_attention_masks, dtype=torch.long).to(device) ,
-							torch.tensor(all_start_positions,dtype=torch.long).to(device),
-							torch.tensor(all_end_positions, dtype=torch.long).to(device),
-							torch.tensor(all_labels,dtype=torch.long).to(device))
+		dataset=TensorDataset(torch.tensor(all_input_ids, dtype=torch.long).to("cuda") , 
+							torch.tensor(all_attention_masks, dtype=torch.long).to("cuda") ,
+							torch.tensor(all_start_positions,dtype=torch.long).to("cuda"),
+							torch.tensor(all_end_positions, dtype=torch.long).to("cuda"),
+							torch.tensor(all_labels,dtype=torch.long).to("cuda"))
 
 	if tokenizer.name_or_path[0] =="b":
 		print("bert tokenizer")
-		dataset=TensorDataset(torch.tensor(all_input_ids, dtype=torch.long).to(device), 
-					torch.tensor(all_attention_masks, dtype=torch.long).to(device),
-					torch.tensor(all_token_type_ids, dtype=torch.long).to(device),
-					torch.tensor(all_start_positions,dtype=torch.long).to(device),
-					torch.tensor(all_end_positions, dtype=torch.long).to(device),
-					torch.tensor(all_labels,dtype=torch.long).to(device))
+		dataset=TensorDataset(torch.tensor(all_input_ids, dtype=torch.long).to("cuda"), 
+					torch.tensor(all_attention_masks, dtype=torch.long).to("cuda"),
+					torch.tensor(all_token_type_ids, dtype=torch.long).to("cuda"),
+					torch.tensor(all_start_positions,dtype=torch.long).to("cuda"),
+					torch.tensor(all_end_positions, dtype=torch.long).to("cuda"),
+					torch.tensor(all_labels,dtype=torch.long).to("cuda"))
 	print("created dataset")
-	#print(mapping_counter)
 
 	return dataset
 
 def tokenizer_imdb(tokenizer, dataset, max_length):
+	"""Tokenizer for the imdb dataset (used to validate our tmix implementation).
+
+	Params:
+	tokenizer: AutoTokenizer -> tokenizer (in our case BERT base uncased)
+	dataset: list of dicts   -> dataset (imdb from huggingface) to be preprocessed
+	max_length: int          -> maximum length for padding/truncation"""
 	all_input_ids=[]
 	all_attention_masks=[]
 	all_token_type_ids=[]
@@ -387,17 +252,12 @@ def tokenizer_imdb(tokenizer, dataset, max_length):
 
 	for dp in dataset:
 		encoded_inp=tokenizer.encode_plus(dp["text"], add_special_tokens=True, max_length=max_length, truncation=True, padding="max_length")
-		#print("encoded input:",encoded_inp)
 		all_labels.append(dp["label"])
 		all_input_ids.append(encoded_inp["input_ids"])
 		all_attention_masks.append(encoded_inp["attention_mask"])
 		all_token_type_ids.append(encoded_inp["token_type_ids"])
 	
-	print("labels: ", len(all_labels))
-	print("input_ids: ", len(all_input_ids))
-	print("token_type_ids: ", len(all_token_type_ids))
-	print("attention_masks: ", len(all_attention_masks))
-	dataset=TensorDataset(torch.tensor(all_input_ids, dtype=torch.long).to(device), torch.tensor(all_attention_masks, dtype=torch.long).to(device), torch.tensor(all_token_type_ids, dtype=torch.long).to(device), torch.tensor(all_labels, dtype=torch.long).to(device))
+	dataset=TensorDataset(torch.tensor(all_input_ids, dtype=torch.long).to("cuda"), torch.tensor(all_attention_masks, dtype=torch.long).to("cuda"), torch.tensor(all_token_type_ids, dtype=torch.long).to("cuda"), torch.tensor(all_labels, dtype=torch.long).to("cuda"))
 	print("created imdb dataset")
 	return dataset
 
@@ -405,6 +265,7 @@ def tokenizer_imdb(tokenizer, dataset, max_length):
 
 class EncodedTokenDataset(torch.utils.data.Dataset):
     """
+    Dataset used by the salami tokenizer.
     A dataset, containing encoded sentences, integer labels and
     the starting and ending position of the target word.
     """
@@ -431,7 +292,7 @@ class EncodedTokenDataset(torch.utils.data.Dataset):
 
 
 def salami_tokenizer(tokenizer, input, max_length, masked=False):
-	
+	"""Salami tokenizer for input sentences (used together with EncodedTokenDataset)"""
 	print("salami tokenizer")
 	bots_token, eots_token = "[bots]", "[eots]"
 	tokenizer.add_tokens([bots_token, eots_token])
diff --git a/Code/train.py b/Code/train.py
index 919fb06..a215b15 100644
--- a/Code/train.py
+++ b/Code/train.py
@@ -11,7 +11,7 @@ from transformers import BertTokenizer, RobertaTokenizer, BertModel, RobertaMode
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
 from transformers import AdamW, get_scheduler
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, DataParallel
 import matplotlib.pyplot as plt
 import os
 import pandas as pd
@@ -23,49 +23,64 @@ torch.cuda.empty_cache()
 
 #with torch.autocast("cuda"):
 
-def train(model, name, imdb, seed,mixup,lambda_value, mixepoch, tmix, mixlayer, train_dataset, test_dataset, num_epochs, learning_rate, batch_size, test_batch_size, model_save_path=None):
+def train(model, name, train_dataset, test_dataset, seed, batch_size, test_batch_size, num_epochs, imdb=False, mixup=False, lambda_value=None, mixepoch=None, tmix=False, mixlayer=None, learning_rate=None, mlp_learning_rate=None, model_save_path=None):
 	"""Train loop for models. Iterates over epochs and batches and gives inputs to model. After training, call evaluation.py for evaluation of finetuned model.
 	
 	Params:
 	model: model out of models.py
 	name: str
-	imdb: bool
+	train_dataset: Dataset 
+	test_dataset: Dataset
 	seed: int
+	batch_size: int
+	test_batch_size: int
+	num_epochs: int
+	imdb: bool
 	mixup: bool
 	lambda_value: float
 	mixepoch:int
 	tmix: bool
 	mixlayer: int in {0, 11}
-	train_dataset: Dataset 
-	test_dataset: Dataset
-	num_epochs: int
 	learning_rate: float
-	batch_size: 
-	test_batch_size:
+	mlp_learning_rate: float
+	
 	
-	Returns:"""
+	Returns: evaluation results (accuracy, F1, precision and recall) for the train and test datasets"""
 	model.train().to(device)
 	train_sampler = RandomSampler(train_dataset)
-	train_dataloader=DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)
+	train_dataloader=DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size) #RandomSampler already shuffles; shuffle is mutually exclusive with a sampler
 	num_training_steps=num_epochs*len(train_dataloader)
 
-	optimizer=AdamW(model.parameters(), lr=learning_rate, eps=1e-8, weight_decay=0.1)
+	if mlp_learning_rate==None:
+		print("initializing one learning rate")
+		optimizer=AdamW(model.parameters(), lr=learning_rate, eps=1e-8, weight_decay=0.1)
+	else:
+		print("initializing separate learning rates")
+		model=nn.DataParallel(model)
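+		# two parameter groups: encoder parameters use learning_rate, the classifier head uses mlp_learning_rate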
+		optimizer=AdamW([
+			{'params': model.module.embedding_model.parameters(), 'lr': learning_rate},
+			{'params': model.module.classifier.parameters(), 'lr': mlp_learning_rate}
+		])
 	lr_scheduler=get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=10, num_training_steps=num_training_steps)
 
 	model.zero_grad()
 	for epoch in range(num_epochs):
-		index=0
-		
 		for batch in train_dataloader:
 			print(len(batch))
 			if name[0] == "b":
 				if tmix==False:
-					inputs = {'input_ids': batch[0],
-							'attention_mask': batch[1],
-							'token_type_ids': batch[2],
-							'start_position': batch[3],
-							'end_position': batch[4],
-							'labels': batch[5]}
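+					# imdb batches contain only input ids, attention masks, token type ids and labels (no span positions)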
+					if imdb==False:
+						inputs = {'input_ids': batch[0],
+								'attention_mask': batch[1],
+								'token_type_ids': batch[2],
+								'start_position': batch[3],
+								'end_position': batch[4],
+								'labels': batch[5]}
+					if imdb==True:
+						inputs={'input_ids':batch[0],
+								'attention_mask': batch[1],
+								'token_type_ids': batch[2],
+								'labels': batch[3]}
 				if tmix==True:
 					if imdb == False:
 						print("this is mixup epoch")
@@ -100,7 +115,7 @@ def train(model, name, imdb, seed,mixup,lambda_value, mixepoch, tmix, mixlayer,
 				end_positions=batch[3]
 			outputs=model(**inputs)
 			loss=outputs[0]
-			print("Loss: ", loss)
+			print("Epoch: {0} Loss: {1}".format(epoch, loss))
 			loss.backward()
 			optimizer.step()
 			lr_scheduler.step()
diff --git a/main.py b/main.py
index f969a35..bdf7daa 100644
--- a/main.py
+++ b/main.py
@@ -51,8 +51,7 @@ def run(raw_args):
 		test_dataset=Code.preprocess.salami_tokenizer(tokenizer, data_test, args.max_length, masked=args.masking)
 	
 	elif args.tokenizer=="swp":
-		print("train dataset preprocessing ")        
-		print(args.tcontext)
+		print("train dataset preprocessing ")
 		train_dataset=Code.preprocess.tokenizer_new(tokenizer, data_train, args.max_length, masked=args.masking, old_dataset=args.tcontext)
 		test_dataset=Code.preprocess.tokenizer_new(tokenizer, data_test, args.max_length, masked=args.masking, old_dataset=False) 
 	
@@ -66,7 +65,7 @@ def run(raw_args):
 	#train&evaluate...
 	print("training..")
 	if args.train_loop=="swp":
-		evaluation_test, evaluation_train = Code.train.train(model, args.architecture, args.imdb, args.random_seed, args.mix_up, args.lambda_value, args.mixepoch, args.tmix, args.mixlayer, train_dataset, test_dataset, args.epochs, args.learning_rate, args.batch_size, args.test_batch_size, args.model_save_path)
+		evaluation_test, evaluation_train = Code.train.train(model, args.architecture, train_dataset, test_dataset, args.random_seed, args.batch_size, args.test_batch_size, args.epochs, args.imdb, args.mix_up, args.lambda_value, args.mixepoch, args.tmix, args.mixlayer, args.learning_rate, args.second_learning_rate, args.model_save_path)
 	elif args.train_loop=="salami":
 		evaluation_test = Code.train.train_salami(model,args.random_seed, train_dataset, test_dataset, args.batch_size, args.test_batch_size, args.learning_rate, args.epochs)
 	else:
@@ -111,6 +110,12 @@ if __name__ == "__main__":
 		action="store_true"
 	)
 
+	parser.add_argument(
+		"--mlp", 
+		help="use a two-layer multi-layer perceptron as the classification head (if not set, a linear classifier is used)",
+		action="store_true"
+	)
+
 	#Datasets
 	parser.add_argument(
 		"-t",
@@ -150,7 +155,7 @@ if __name__ == "__main__":
 		"-max",
 		"--max_length",
         type=int, 
-		help="How big is max length when tokenizing the sentences?")	
+		help="Maximum sequence length when tokenizing the sentences")
 
 
 	#Train arguments
@@ -170,6 +175,14 @@ if __name__ == "__main__":
 		"--learning_rate",
 		type=float,
 		help="Learning rate for training")
+	
+	parser.add_argument(
+		"-lrtwo", 
+		"--second_learning_rate",
+		type=float,
+		help="Separate learning rate for the multi-layer perceptron classifier",
+		default=None
+	)
 
 	parser.add_argument(
 		"-rs",
-- 
GitLab