From c1fe38c245d1a28026191fc10852aa79a247e024 Mon Sep 17 00:00:00 2001 From: Simon Will <will@cl.uni-heidelberg.de> Date: Wed, 24 Mar 2021 00:17:21 +0100 Subject: [PATCH] Add script for simulating online training --- joeynmt_server/joey_model.py | 4 +- joeynmt_server/trainer.py | 8 +- joeynmt_server/views/feedback.py | 15 +-- joeynmt_server/views/train.py | 4 +- simulate_training.py | 152 +++++++++++++++++++++++++++++++ 5 files changed, 173 insertions(+), 10 deletions(-) create mode 100644 simulate_training.py diff --git a/joeynmt_server/joey_model.py b/joeynmt_server/joey_model.py index cf089d9..d7545f8 100644 --- a/joeynmt_server/joey_model.py +++ b/joeynmt_server/joey_model.py @@ -156,7 +156,9 @@ class JoeyModel: return dataset def get_config_dataset(self, name): - return self._load_train_dataset(self.config['data'][name]) + if name in self.config['data']: + return self._load_train_dataset(self.config['data'][name]) + return None def validate(self, dataset, **kwargs): valid_kwargs = {k: v for k, v in self.test_args.items() diff --git a/joeynmt_server/trainer.py b/joeynmt_server/trainer.py index 12fa0da..da613b1 100644 --- a/joeynmt_server/trainer.py +++ b/joeynmt_server/trainer.py @@ -215,7 +215,7 @@ def train_until_finished(config_basename): return train_n_rounds(config_basename, min_rounds=None) -def validate(config_basename): +def validate(config_basename, dataset_name='dev'): joey_dir = current_app.config.get('JOEY_DIR') config_file = joey_dir / 'configs' / config_basename use_cuda_train = current_app.config.get('USE_CUDA_TRAIN', False) @@ -229,7 +229,11 @@ def validate(config_basename): try: model = JoeyModel.from_config_file(config_file, joey_dir, use_cuda=use_cuda) - dev_set = model.get_config_dataset('dev') + dev_set = model.get_config_dataset(dataset_name) + if not dev_set: + msg = 'No such dataset: {}'.format(dataset_name) + logging.error(msg) + raise ValueError(msg) logging.info('Validating on dev set.') results = model.validate(dev_set) diff --git 
a/joeynmt_server/views/feedback.py b/joeynmt_server/views/feedback.py index bf61ba0..a2bf709 100644 --- a/joeynmt_server/views/feedback.py +++ b/joeynmt_server/views/feedback.py @@ -24,18 +24,21 @@ def save_feedback(): if 'split' in data: data['split'] = data['split'][:50] + split_was_explicitly_set = True else: data['split'] = 'train' + split_was_explicitly_set = False fb = Feedback(**data) db.session.add(fb) db.session.commit() - if fb.id % 5 == 0: - fb.split = 'test' - elif fb.id % 5 == 4: - fb.split = 'dev' - db.session.commit() + if not split_was_explicitly_set: + if fb.id % 5 == 0: + fb.split = 'test' + elif fb.id % 5 == 4: + fb.split = 'dev' + db.session.commit() def train_in_thread(): app = create_app() @@ -49,7 +52,7 @@ def save_feedback(): response = fb.json_ready_dict() if (current_app.config.get('TRAIN_AFTER_FEEDBACK') and config_basename - and fb.correct_lin): + and fb.correct_lin and fb.split == 'train'): thread = threading.Thread(target=train_in_thread) thread.start() time.sleep(0.1) diff --git a/joeynmt_server/views/train.py b/joeynmt_server/views/train.py index e68893c..b1b72d3 100644 --- a/joeynmt_server/views/train.py +++ b/joeynmt_server/views/train.py @@ -56,11 +56,13 @@ def validate(): data = request.json config_basename = data.get('model') + dataset = data.get('dataset', 'dev') + def validate_in_thread(): app = create_app() with app.app_context(): try: - validate_on_data(config_basename) + validate_on_data(config_basename, dataset) except: logging.error('Training failed.') logging.error(traceback.format_exc()) diff --git a/simulate_training.py b/simulate_training.py new file mode 100644 index 0000000..8593b37 --- /dev/null +++ b/simulate_training.py @@ -0,0 +1,152 @@ +import argparse +from collections import namedtuple +from itertools import zip_longest +import logging +import os +import sys +import time +import urllib + +import requests + +Instance = namedtuple('Instance', ('nl', 'lin')) + +NLMAPS_MT_BASE_URL = 'http://localhost:5050' + + 
+def find_and_read_file(dataset_dir, basenames, split, suffix):
+    """Read the dataset file whose name ends in '<split>.<suffix>'.
+
+    Args:
+        dataset_dir: Directory that contains the files in ``basenames``.
+        basenames: File names (without directory) to search.
+        split: Split name the file must belong to, e.g. 'train'.
+        suffix: File extension without the dot, e.g. 'en' or 'lin'.
+
+    Returns:
+        The whitespace-stripped lines of the matching file, or an empty
+        list (after logging a warning) when no name matches.
+
+    Exits the process with status 1 when more than one name matches.
+    """
+    end = '{}.{}'.format(split, suffix)
+    # Keep the matching names themselves (not the boolean test results):
+    # the first match is joined onto dataset_dir below.
+    suitable_basenames = [name for name in basenames if name.endswith(end)]
+    if not suitable_basenames:
+        logging.warning('Found no file matching *{}'.format(end))
+        return []
+    if len(suitable_basenames) > 1:
+        logging.error('Found more than one file matching *{}'.format(end))
+        sys.exit(1)
+
+    path = os.path.join(dataset_dir, suitable_basenames[0])
+    # Explicit encoding so the result does not depend on the locale.
+    with open(path, encoding='utf-8') as f:
+        return [line.strip() for line in f]
+
+
+def load_data(dataset_dir):
+    """Load the train, dev and test splits found in ``dataset_dir``.
+
+    For every split the '*<split>.en' and '*<split>.lin' files are read
+    and paired line by line.
+
+    Returns:
+        A dict mapping 'train', 'dev' and 'test' to lists of ``Instance``
+        tuples (``nl`` taken from the .en file, ``lin`` from the .lin file).
+        A split for which neither file yields any line maps to an empty
+        list.
+
+    Exits the process with status 2 when the two files of a split have a
+    different number of lines, which includes the case where only one of
+    them is missing or empty.
+    """
+    basenames = os.listdir(dataset_dir)
+    data = {'train': [], 'dev': [], 'test': []}
+    for split in data:
+        en = find_and_read_file(dataset_dir, basenames, split, 'en')
+        lin = find_and_read_file(dataset_dir, basenames, split, 'lin')
+
+        if len(en) != len(lin):
+            if not en:
+                logging.error('Could not load {} dataset. {} file missing'
+                              .format(split, 'en'))
+            if not lin:
+                logging.error('Could not load {} dataset. {} file missing'
+                              .format(split, 'lin'))
+            # Report the sizes, not the (potentially huge) line lists.
+            logging.error('Lengths of en and lin do not match ({} vs. {})'
+                          .format(len(en), len(lin)))
+            sys.exit(2)
+
+        data[split] = [Instance(en_, lin_) for en_, lin_ in zip(en, lin)]
+    return data
+
+
+class NLMapsMT:
+    """Small HTTP client for the NLMaps MT server endpoints used below."""
+
+    def __init__(self, base_url, model, user_id=1):
+        """Remember scheme and host of ``base_url`` plus model and user.
+
+        Only the scheme and netloc of ``base_url`` are kept; any path,
+        query or fragment it carries is discarded.
+        """
+        # Import the submodule explicitly: a bare ``import urllib`` does
+        # not guarantee that ``urllib.parse`` has been loaded.
+        from urllib.parse import urlparse
+        parsed = urlparse(base_url)
+        self.scheme = parsed.scheme
+        self.netloc = parsed.netloc
+
+        self.model = model
+        self.user_id = user_id
+
+    def _make_url(self, path):
+        """Return the absolute URL of ``path`` on the configured server."""
+        from urllib.parse import urlunparse
+        # All six URL components are required; params, query and fragment
+        # are intentionally left empty.
+        return urlunparse((self.scheme, self.netloc, path, '', '', ''))
+
+    def save_feedback(self, instance: Instance, split):
+        """POST ``instance`` to /save_feedback as feedback for ``split``.
+
+        A non-200 response is logged as an error; nothing is raised or
+        returned.
+        """
+        url = self._make_url('/save_feedback')
+        payload = {'model': self.model, 'nl': instance.nl,
+                   'correct_lin': instance.lin, 'split': split,
+                   'system_lin': '', 'user_id': self.user_id}
+        logging.info('POST {} to {}'.format(payload, url))
+        resp = requests.post(url, json=payload)
+        if resp.status_code != 200:
+            logging.error('Response status code: {}'.format(resp.status_code))
+
+    def is_training(self):
+        """Return whether /train_status reports 'still_training'.
+
+        A non-200 response is logged and treated as "still training", so
+        callers that poll keep waiting instead of proceeding.
+        """
+        url = self._make_url('/train_status')
+        logging.info('GET to {}'.format(url))
+        resp = requests.get(url)
+        if resp.status_code != 200:
+            logging.error('Response status code: {}'.format(resp.status_code))
+            return True
+        data = resp.json()
+        return bool(data.get('still_training'))
+
+    def validate(self, dataset=None):
+        """POST a validation request for the model to /validate.
+
+        ``dataset`` is only included in the payload when it is truthy. A
+        non-200 response is logged as an error.
+        """
+        url = self._make_url('/validate')
+        payload = {'model': self.model}
+        if dataset:
+            payload['dataset'] = dataset
+        logging.info('POST {} to {}'.format(payload, url))
+        resp = requests.post(url, json=payload)
+        if resp.status_code != 200:
+            logging.error('Response status code: {}'.format(resp.status_code))
+
+
+def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False,
+         nlmaps_mt_base_url=NLMAPS_MT_BASE_URL, user_id=1):
+    """Replay a dataset against the NLMaps MT server as user feedback.
+
+    The train, dev and test splits are walked in lockstep. Each train
+    instance is saved as feedback, followed by a pause of ``wait_time``
+    seconds. On every ``validation_freq``-th iteration (including the
+    first) the function polls until the server no longer reports
+    training, then requests validation on 'dev' and, if ``dev2`` is set,
+    on 'dev2', polling again after each request. Dev and test instances
+    are saved as feedback for their own split.
+    """
+    data = load_data(dataset_dir)
+
+    nlmaps_mt = NLMapsMT(nlmaps_mt_base_url, model=model, user_id=user_id)
+    # NOTE(review): the nesting below was reconstructed from a patch whose
+    # indentation was lost; confirm it against the repository.
+    # zip_longest pads the shorter splits with None, hence the truth tests.
+    for train_idx, (train, dev, test) in enumerate(
+            zip_longest(data['train'], data['dev'], data['test'])):
+        if train:
+            nlmaps_mt.save_feedback(train, 'train')
+            time.sleep(wait_time)
+
+        if train_idx % validation_freq == 0:
+            while nlmaps_mt.is_training():
+                time.sleep(1)
+            nlmaps_mt.validate('dev')
+            while nlmaps_mt.is_training():
+                time.sleep(1)
+            if dev2:
+                nlmaps_mt.validate('dev2')
+                while nlmaps_mt.is_training():
+                    time.sleep(1)
+
+        if dev:
+            nlmaps_mt.save_feedback(dev, 'dev')
+        if test:
+            nlmaps_mt.save_feedback(test, 'test')
+
+
+def parse_args():
+    """Parse the command line; attribute names match main()'s parameters."""
+    parser = argparse.ArgumentParser(description='Simulate online training')
+    parser.add_argument('dataset_dir', help='Dataset directory. Must contain'
+                        ' six files matching in *{train,dev,test}.{en,lin}')
+    parser.add_argument('model', help='Path to model config yaml file')
+    parser.add_argument('--wait-time', type=int, default=3,
+                        help='Number of seconds to wait between saving pieces'
+                        ' of feedback.')
+    parser.add_argument('--validation-freq', type=int, default=10,
+                        help='Validate on dev from model config after every N'
+                        ' pieces of feedback.')
+    parser.add_argument('--dev2', default=False, action='store_true',
+                        help='At validation time, additionally use dev2'
+                        ' from model config to validate on.')
+    # The consistently dashed spelling is offered first; the original mixed
+    # spelling stays accepted. dest is pinned because main() is called
+    # with **vars(args).
+    parser.add_argument('--nlmaps-mt-base-url', '--nlmaps-mt-base_url',
+                        dest='nlmaps_mt_base_url',
+                        default=NLMAPS_MT_BASE_URL,
+                        help='Base_Url of the NLMaps MT server.')
+    parser.add_argument('--user-id', type=int, default=1,
+                        help='User ID to use with the NLMaps MT server.')
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    ARGS = parse_args()
+    main(**vars(ARGS))
-- 
GitLab