Commit bf8b2cca authored by kulcsar

add variables to mixup

parent cc68526f
@@ -67,7 +67,7 @@ def run(raw_args):
     #train...
     print("training..")
     if args.train_loop=="swp":
-        evaluation_test, evaluation_train = train.train(model, args.architecture, args.random_seed, args.gradient_accumulation_steps, args.mix_up, train_dataset, test_dataset, args.epochs, args.learning_rate, args.batch_size, args.test_batch_size)
+        evaluation_test, evaluation_train = train.train(model, args.architecture, args.random_seed, args.gradient_accumulation_steps, args.mix_up, args.threshold, args.lambda_value, train_dataset, test_dataset, args.epochs, args.learning_rate, args.batch_size, args.test_batch_size)
     elif args.train_loop=="salami":
         evaluation_test = train.train_salami(model, args.random_seed, train_dataset, test_dataset, args.batch_size, args.test_batch_size, args.learning_rate, args.epochs)
     else:
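The two new hyperparameters are threaded through train.train() positionally, between args.mix_up and train_dataset. For illustration only (the concrete values and the architecture name are assumptions, not taken from the commit), a resolved call would look like:

    train.train(model, "bert", 42, 1, True, 0.05, 0.4, train_dataset, test_dataset, 3, 2e-5, 16, 16)

Here 0.05 is the mixup threshold and 0.4 the lambda value, matching the argparse defaults added below.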
@@ -187,6 +187,20 @@ if __name__ == "__main__":
         help="whether or not to apply mixup during training",
         action="store_true")
+    parser.add_argument(
+        "-threshold",
+        "--threshold",
+        help="specifies the value for the mixup threshold",
+        type=float,
+        default=0.05)
+    parser.add_argument(
+        "-lambda",
+        "--lambda_value",
+        help="specifies the lambda value for mixup",
+        type=float,
+        default=0.4)
     #Test arguments
     parser.add_argument(
......
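For reference, a hypothetical invocation using the new flags (the script name, the exact spelling of the mixup flag, and the other arguments are assumptions; the diff does not show them):

    python main.py --train_loop swp --mix_up --threshold 0.05 --lambda_value 0.4

Both new flags are optional; omitting them falls back to the defaults of 0.05 and 0.4 defined above.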
@@ -22,7 +22,7 @@ torch.cuda.empty_cache()
-def train(model, name, seed, gradient_accumulation_steps, mixup, train_dataset, test_dataset, num_epochs, learning_rate, batch_size, test_batch_size):
+def train(model, name, seed, gradient_accumulation_steps, mixup, threshold, lambda_value, train_dataset, test_dataset, num_epochs, learning_rate, batch_size, test_batch_size):
     """Training loop for the model on a given train dataset."""
     #set_seed(seed)
     #if model_name[0] == "b":
@@ -85,7 +85,7 @@ def train(model, name, seed, gradient_accumulation_steps, mixup, train_dataset, te
             # # print("outputs {0}: {1}".format(i, outputs[i].size()))
             if mixup:
                 #print("length of outputs: ", len(outputs))
-                new_matrix_batch, new_labels_batch = mixup_function(outputs[2], labels)
+                new_matrix_batch, new_labels_batch = mixup_function(outputs[2], labels, lambda_value, threshold)
                 #for matrix in new_matrix_batch
                 new_matrix_batch = new_matrix_batch.to("cuda")  # tensor.to() is not in-place, so reassign
                 new_labels_batch = new_labels_batch.to("cuda")
@@ -125,18 +125,18 @@ def train(model, name, seed, gradient_accumulation_steps, mixup, train_dataset, te
-def mixup_function(batch_of_matrices, batch_of_labels):
+def mixup_function(batch_of_matrices, batch_of_labels, l, t):
     runs = math.floor(batch_of_matrices.size()[0]/2)
     counter = 0
     results = []
     result_labels = []
     for i in range(runs):
-        print("doing interpolation...")
+        print("doing interpolation with lambda: {0} and threshold: {1}...".format(l, t))
         matrix1 = batch_of_matrices[counter]
         label1 = batch_of_labels[counter]
         matrix2 = batch_of_matrices[counter+1]
         label2 = batch_of_labels[counter+1]
-        new_matrix, new_label = interpolate(matrix1, label1, matrix2, label2, 0.4, 0.05)
+        new_matrix, new_label = interpolate(matrix1, label1, matrix2, label2, l, t)
         if new_matrix is not None:
             results.append(new_matrix)
             result_labels.append(new_label)
......
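The interpolate() helper itself is outside this diff, so its threshold semantics are unknown. As a point of reference only, a minimal sketch of standard mixup interpolation (Zhang et al., 2018), with an assumed use of the threshold t to reject pairs that were barely mixed, might look like:

    def interpolate(matrix1, label1, matrix2, label2, l, t):
        # Hypothetical sketch -- the real interpolate() is not shown in this commit.
        # Standard mixup: convex combination of the two representations and of the
        # two labels, weighted by lambda = l. Inputs are assumed to be torch tensors.
        new_matrix = l * matrix1 + (1.0 - l) * matrix2
        new_label = l * label1 + (1.0 - l) * label2
        # Assumed threshold semantics: if the mixed soft label sits within t of a
        # hard 0/1 label, the pair is effectively unmixed, so signal the caller to
        # skip it by returning None.
        if min(float(new_label), 1.0 - float(new_label)) < t:
            return None, None
        return new_matrix, new_label

The is-not-None check in mixup_function matches this skip-on-None contract.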