From 39477bceb4334c1fe189279e1df714543b793457 Mon Sep 17 00:00:00 2001 From: Simon Will <will@cl.uni-heidelberg.de> Date: Thu, 22 Apr 2021 20:33:43 +0200 Subject: [PATCH] Add simulation option for saving one feedback at a time --- simulate_training.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/simulate_training.py b/simulate_training.py index 5539572..1431676 100644 --- a/simulate_training.py +++ b/simulate_training.py @@ -107,27 +107,31 @@ class NLMapsMT: logging.error('Response status code: {}'.format(resp.status_code)) -def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False, +def main(dataset_dir, model, wait_time=None, validation_freq=10, dev2=False, test_as_dev=False, base_url=NLMAPS_MT_BASE_URL, user_id=None): data = load_data(dataset_dir) nlmaps_mt = NLMapsMT(base_url, model=model, user_id=user_id) - if validation_freq: + def wait_until_free(): while nlmaps_mt.is_training(): time.sleep(1) + + if validation_freq: + wait_until_free() nlmaps_mt.validate('dev') - while nlmaps_mt.is_training(): - time.sleep(1) + wait_until_free() if dev2: nlmaps_mt.validate('dev2') - while nlmaps_mt.is_training(): - time.sleep(1) + wait_until_free() for train_idx, (train, dev, test) in enumerate( zip_longest(data['train'], data['dev'], data['test']), start=1): if train: nlmaps_mt.save_feedback(train, 'train') - time.sleep(wait_time) + if wait_time is None: + wait_until_free() + else: + time.sleep(wait_time) if dev: nlmaps_mt.save_feedback(dev, 'dev') if test: @@ -135,15 +139,12 @@ def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False, nlmaps_mt.save_feedback(test, split) if validation_freq and train_idx % validation_freq == 0: - while nlmaps_mt.is_training(): - time.sleep(1) + wait_until_free() nlmaps_mt.validate('dev') - while nlmaps_mt.is_training(): - time.sleep(1) + wait_until_free() if dev2: nlmaps_mt.validate('dev2') - while nlmaps_mt.is_training(): - time.sleep(1) + wait_until_free() def parse_args(): @@ -151,9 +152,10 @@ def parse_args(): parser.add_argument('dataset_dir', help='Dataset directory. Must contain' ' six files matching in *{train,dev,test}.{en,lin}') parser.add_argument('model', help='Path to model config yaml file') - parser.add_argument('--wait-time', type=int, default=3, + parser.add_argument('--wait-time', type=int, help='Number of seconds to wait between saving pieces' - ' of feedback.') + ' of feedback. Default is waiting until training is' + ' done.') parser.add_argument('--validation-freq', type=int, default=10, help='Validate on dev from model config after every N' ' pieces of feedback.') -- GitLab