From 9661b84fb485648041923939485b51179e2383ff Mon Sep 17 00:00:00 2001 From: Simon Will <will@cl.uni-heidelberg.de> Date: Fri, 2 Apr 2021 22:39:24 +0200 Subject: [PATCH] Improve training simulation script --- simulate_training.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/simulate_training.py b/simulate_training.py index a6c3053..ea9dce4 100644 --- a/simulate_training.py +++ b/simulate_training.py @@ -58,7 +58,7 @@ def load_data(dataset_dir): class NLMapsMT: - def __init__(self, base_url, model, user_id=1): + def __init__(self, base_url, model, user_id=None): parsed = urllib.parse.urlparse(base_url) self.scheme = parsed.scheme self.netloc = parsed.netloc @@ -77,7 +77,10 @@ class NLMapsMT: url = self._make_url('/save_feedback') payload = {'model': self.model, 'nl': instance.nl, 'correct_lin': instance.lin, 'split': split, - 'system_lin': '', 'user_id': self.user_id} + 'system_lin': ''} + if self.user_id: + payload['user_id'] = self.user_id + logging.info('POST {} to {}'.format(payload, url)) resp = requests.post(url, json=payload) if resp.status_code != 200: @@ -105,17 +108,22 @@ class NLMapsMT: def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False, - test_as_dev=False, base_url=NLMAPS_MT_BASE_URL, user_id=1): + test_as_dev=False, base_url=NLMAPS_MT_BASE_URL, user_id=None): data = load_data(dataset_dir) - nlmaps_mt = NLMapsMT(base_url, model=model, user_id=user_id) + for train_idx, (train, dev, test) in enumerate( - zip_longest(data['train'], data['dev'], data['test'])): + zip_longest(data['train'], data['dev'], data['test']), start=1): if train: nlmaps_mt.save_feedback(train, 'train') time.sleep(wait_time) + if dev: + nlmaps_mt.save_feedback(dev, 'dev') + if test: + split = 'dev' if test_as_dev else 'test' + nlmaps_mt.save_feedback(test, split) - if train_idx % validation_freq == 0: + if validation_freq and train_idx % validation_freq == 0: while nlmaps_mt.is_training(): time.sleep(1) nlmaps_mt.validate('dev') @@ -126,12 +134,6 @@ def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False, while nlmaps_mt.is_training(): time.sleep(1) - if dev: - nlmaps_mt.save_feedback(dev, 'dev') - if test: - split = 'dev' if test_as_dev else 'test' - nlmaps_mt.save_feedback(test, split) - def parse_args(): parser = argparse.ArgumentParser(description='Simulate online training') @@ -150,7 +152,7 @@ def parse_args(): parser.add_argument('--test-as-dev', default=False, action='store_true') parser.add_argument('--base-url', default=NLMAPS_MT_BASE_URL, help='Base_Url of the NLMaps MT server.') - parser.add_argument('--user-id', type=int, default=1, + parser.add_argument('--user-id', type=int, help='User ID to use with the NLMaps MT server.') args = parser.parse_args() return args -- GitLab