diff --git a/simulate_training.py b/simulate_training.py index ea9dce439fbb37d44418656cac8c0792c8d9e141..5539572de4a6f5748141e601e4295e72f9c40927 100644 --- a/simulate_training.py +++ b/simulate_training.py @@ -112,6 +112,17 @@ def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False, data = load_data(dataset_dir) nlmaps_mt = NLMapsMT(base_url, model=model, user_id=user_id) + if validation_freq: + while nlmaps_mt.is_training(): + time.sleep(1) + nlmaps_mt.validate('dev') + while nlmaps_mt.is_training(): + time.sleep(1) + if dev2: + nlmaps_mt.validate('dev2') + while nlmaps_mt.is_training(): + time.sleep(1) + for train_idx, (train, dev, test) in enumerate( zip_longest(data['train'], data['dev'], data['test']), start=1): if train: