Skip to content
Snippets Groups Projects
Commit 152b69bc authored by kulcsar's avatar kulcsar
Browse files

add changes for documentation of code

parent 8f6bc356
No related branches found
No related tags found
No related merge requests found
......@@ -23,7 +23,26 @@ torch.cuda.empty_cache()
#with torch.autocast("cuda"):
def train(model, name, imdb, seed,mixup,lambda_value, mixepoch, tmix, mixlayer, train_dataset, test_dataset, num_epochs, learning_rate, batch_size, test_batch_size):
"""Train loop for models. Iterates over epochs and batches and gives inputs to model. After training, call evaluation.py for evaluation of finetuned model."""
"""Train loop for models. Iterates over epochs and batches and gives inputs to model. After training, call evaluation.py for evaluation of finetuned model.
Params:
model: model out of models.py
name: str
imdb: bool
seed: int
mixup: bool
lambda_value: float
mixepoch:int
tmix: bool
mixlayer: int in {0, 11}
train_dataset: Dataset
test_dataset: Dataset
num_epochs: int
learning_rate: float
batch_size:
test_batch_size:
Returns:"""
model.train().to("cuda")
train_sampler = RandomSampler(train_dataset)
train_dataloader=DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)
......
......@@ -3,3 +3,8 @@ numpy==1.23.5
pandas==1.5.2
torch==1.13.0+cu116
tqdm==4.64.1
evaluate ==0.3.0
matplotlib==3.5.2
scikit_lean==1.2.1
transformers==4.26.1
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment