Skip to content
Snippets Groups Projects
Commit 99b95166 authored by kulcsar
Browse files

main update

parent 41cf3ed2
No related branches found
No related tags found
No related merge requests found
......@@ -9,12 +9,14 @@ from typing import List
def run(raw_args):
print("parsing")
args=_parse_args(raw_args)
print("parsed arguments")
#load test and train dataset as well as tokenizers and models...
#Datasets
print("opened datasets...")
with open(args.train_dataset) as f:
data_train=json.loads(f.read())
......@@ -23,6 +25,7 @@ def run(raw_args):
#Tokenizers & Models
print("Loading tokenizers and initializing model...")
tokenizer=AutoTokenizer.from_pretrained(args.architecture)
models.set_seed(args.random_seed)
......@@ -40,6 +43,7 @@ def run(raw_args):
#preprocess...
print("preprocessing datasets...")
if args.tokenizer=="salami":
train_dataset=preprocess.salami_tokenizer(tokenizer, data_train, args.max_length, masked=args.masking) #no context implemented
test_dataset=preprocess.salami_tokenizer(tokenizer, data_test, args.max_length, masked=args.masking)
......@@ -55,6 +59,7 @@ def run(raw_args):
print("non eligible tokenizer selected")
#train...
print("training..")
if args.train_loop=="swp":
evaluation_test, evaluation_train = train.train(model, args.random_seed, train_dataset, test_dataset, args.epochs, args.learning_rate, args.batch_size, args.test_batch_size)
elif args.train_loop=="salami":
......@@ -144,5 +149,5 @@ def _parse_args(raw_args: List[str]) -> argparse.Namespace:
type=int,
default=64)
return parser.parse_args(raw_args)
......@@ -2,6 +2,7 @@ import torch
import tqdm
import numpy as np
import evaluation
import evaluate
import json
import random
import math
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment