From c1fe38c245d1a28026191fc10852aa79a247e024 Mon Sep 17 00:00:00 2001
From: Simon Will <will@cl.uni-heidelberg.de>
Date: Wed, 24 Mar 2021 00:17:21 +0100
Subject: [PATCH] Add script for simulating online training

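The new simulate_training.py replays a parallel dataset against a
running NLMaps MT server as if users were submitting feedback online:
every training instance is posted to /save_feedback and, every
--validation-freq instances, the script waits for training to finish
and calls /validate on the dev set (and optionally on dev2). Dev and
test instances are posted with an explicit split so that the server
stores them without training on them.

To support this, /save_feedback now respects an explicitly provided
split and only starts online training for train-split feedback, and
/validate accepts an optional dataset name.

Example invocation (paths and config name are illustrative):

    python simulate_training.py data/nlmaps_v2 model_config.yaml \
        --wait-time 3 --validation-freq 10 --dev2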
---
 joeynmt_server/joey_model.py     |   4 +-
 joeynmt_server/trainer.py        |   8 +-
 joeynmt_server/views/feedback.py |  15 +--
 joeynmt_server/views/train.py    |   4 +-
 simulate_training.py             | 162 +++++++++++++++++++++++++++++++
 5 files changed, 183 insertions(+), 10 deletions(-)
 create mode 100644 simulate_training.py

diff --git a/joeynmt_server/joey_model.py b/joeynmt_server/joey_model.py
index cf089d9..d7545f8 100644
--- a/joeynmt_server/joey_model.py
+++ b/joeynmt_server/joey_model.py
@@ -156,7 +156,9 @@ class JoeyModel:
         return dataset
 
     def get_config_dataset(self, name):
-        return self._load_train_dataset(self.config['data'][name])
+        if name in self.config['data']:
+            return self._load_train_dataset(self.config['data'][name])
+        return None
 
     def validate(self, dataset, **kwargs):
         valid_kwargs = {k: v for k, v in self.test_args.items()
diff --git a/joeynmt_server/trainer.py b/joeynmt_server/trainer.py
index 12fa0da..da613b1 100644
--- a/joeynmt_server/trainer.py
+++ b/joeynmt_server/trainer.py
@@ -215,7 +215,7 @@ def train_until_finished(config_basename):
     return train_n_rounds(config_basename, min_rounds=None)
 
 
-def validate(config_basename):
+def validate(config_basename, dataset_name='dev'):
     joey_dir = current_app.config.get('JOEY_DIR')
     config_file = joey_dir / 'configs' / config_basename
     use_cuda_train = current_app.config.get('USE_CUDA_TRAIN', False)
@@ -229,7 +229,11 @@ def validate(config_basename):
     try:
         model = JoeyModel.from_config_file(config_file, joey_dir,
                                            use_cuda=use_cuda)
-        dev_set = model.get_config_dataset('dev')
+        dev_set = model.get_config_dataset(dataset_name)
+        if dev_set is None:
+            msg = 'No such dataset: {}'.format(dataset_name)
+            logging.error(msg)
+            raise ValueError(msg)
 
         logging.info('Validating on dev set.')
         results = model.validate(dev_set)
diff --git a/joeynmt_server/views/feedback.py b/joeynmt_server/views/feedback.py
index bf61ba0..a2bf709 100644
--- a/joeynmt_server/views/feedback.py
+++ b/joeynmt_server/views/feedback.py
@@ -24,18 +24,21 @@ def save_feedback():
 
     if 'split' in data:
         data['split'] = data['split'][:50]
+        split_was_explicitly_set = True
     else:
         data['split'] = 'train'
+        split_was_explicitly_set = False
 
     fb = Feedback(**data)
     db.session.add(fb)
     db.session.commit()
 
-    if fb.id % 5 == 0:
-        fb.split = 'test'
-    elif fb.id % 5 == 4:
-        fb.split = 'dev'
-    db.session.commit()
+    if not split_was_explicitly_set:
+        if fb.id % 5 == 0:
+            fb.split = 'test'
+        elif fb.id % 5 == 4:
+            fb.split = 'dev'
+        db.session.commit()
 
     def train_in_thread():
         app = create_app()
@@ -49,7 +52,7 @@ def save_feedback():
     response = fb.json_ready_dict()
 
     if (current_app.config.get('TRAIN_AFTER_FEEDBACK') and config_basename
-            and fb.correct_lin):
+            and fb.correct_lin and fb.split == 'train'):
         thread = threading.Thread(target=train_in_thread)
         thread.start()
         time.sleep(0.1)
diff --git a/joeynmt_server/views/train.py b/joeynmt_server/views/train.py
index e68893c..b1b72d3 100644
--- a/joeynmt_server/views/train.py
+++ b/joeynmt_server/views/train.py
@@ -56,11 +56,13 @@ def validate():
     data = request.json
     config_basename = data.get('model')
 
+    dataset = data.get('dataset', 'dev')
+
     def validate_in_thread():
         app = create_app()
         with app.app_context():
             try:
-                validate_on_data(config_basename)
+                validate_on_data(config_basename, dataset)
             except:
                 logging.error('Training failed.')
                 logging.error(traceback.format_exc())
diff --git a/simulate_training.py b/simulate_training.py
new file mode 100644
index 0000000..8593b37
--- /dev/null
+++ b/simulate_training.py
@@ -0,0 +1,162 @@
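+"""Simulate online training against a running NLMaps MT server.
+
+Every training instance is posted to /save_feedback, which (with
+TRAIN_AFTER_FEEDBACK enabled on the server) triggers a round of online
+training.  Every --validation-freq instances, the script waits for
+training to finish and calls /validate on the dev set (and, with
+--dev2, on dev2).  Dev and test instances are posted with an explicit
+split so the server stores them without training on them.
+"""
+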
+import argparse
+from collections import namedtuple
+from itertools import zip_longest
+import logging
+import os
+import sys
+import time
+import urllib.parse
+
+import requests
+
+Instance = namedtuple('Instance', ('nl', 'lin'))
+
+NLMAPS_MT_BASE_URL = 'http://localhost:5050'
+
+
+def find_and_read_file(dataset_dir, basenames, split, suffix):
+    end = '{}.{}'.format(split, suffix)
+    suitable_basenames = [name for name in basenames if name.endswith(end)]
+    if len(suitable_basenames) == 0:
+        logging.warning('Found no file matching *{}'.format(end))
+        return []
+    if len(suitable_basenames) > 1:
+        logging.error('Found more than one file matching *{}'.format(end))
+        sys.exit(1)
+
+    path = os.path.join(dataset_dir, suitable_basenames[0])
+    with open(path) as f:
+        return [line.strip() for line in f]
+
+
+def load_data(dataset_dir):
+    basenames = os.listdir(dataset_dir)
+    data = {'train': [], 'dev': [], 'test': []}
+    for split in data:
+        en = find_and_read_file(dataset_dir, basenames, split, 'en')
+        lin = find_and_read_file(dataset_dir, basenames, split, 'lin')
+
+        if len(en) != len(lin):
+            if not en:
+                logging.error('Could not load {} dataset. {} file missing'
+                              .format(split, 'en'))
+            if not lin:
+                logging.error('Could not load {} dataset. {} file missing'
+                              .format(split, 'lin'))
+            logging.error('Lengths of en and lin do not match ({} vs. {})'
+                          .format(len(en), len(lin)))
+            sys.exit(2)
+
+        data[split] = [Instance(en_, lin_) for en_, lin_ in zip(en, lin)]
+    return data
+
+
+class NLMapsMT:
+
+    def __init__(self, base_url, model, user_id=1):
+        parsed = urllib.parse.urlparse(base_url)
+        self.scheme = parsed.scheme
+        self.netloc = parsed.netloc
+
+        self.model = model
+        self.user_id = user_id
+
+    def _make_url(self, path):
+        # urlunparse expects (scheme, netloc, path, params, query, fragment).
+        return urllib.parse.urlunparse(
+            (self.scheme, self.netloc, path, '', '', ''))
+
+    def save_feedback(self, instance: Instance, split):
+        url = self._make_url('/save_feedback')
+        payload = {'model': self.model, 'nl': instance.nl,
+                   'correct_lin': instance.lin, 'split': split,
+                   'system_lin': '', 'user_id': self.user_id}
+        logging.info('POST {} to {}'.format(payload, url))
+        resp = requests.post(url, json=payload)
+        if resp.status_code != 200:
+            logging.error('Response status code: {}'.format(resp.status_code))
+
+    def is_training(self):
+        url = self._make_url('/train_status')
+        logging.info('GET to {}'.format(url))
+        resp = requests.get(url)
+        if resp.status_code != 200:
+            logging.error('Response status code: {}'.format(resp.status_code))
+            return True
+        data = resp.json()
+        return bool(data.get('still_training'))
+
+    def validate(self, dataset=None):
+        url = self._make_url('/validate')
+        payload = {'model': self.model}
+        if dataset:
+            payload['dataset'] = dataset
+        logging.info('POST {} to {}'.format(payload, url))
+        resp = requests.post(url, json=payload)
+        if resp.status_code != 200:
+            logging.error('Response status code: {}'.format(resp.status_code))
+
+
+def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False,
+         nlmaps_mt_base_url=NLMAPS_MT_BASE_URL, user_id=1):
+    data = load_data(dataset_dir)
+
+    nlmaps_mt = NLMapsMT(nlmaps_mt_base_url, model=model, user_id=user_id)
+    for train_idx, (train, dev, test) in enumerate(
+            zip_longest(data['train'], data['dev'], data['test'])):
+        if train:
+            nlmaps_mt.save_feedback(train, 'train')
+            time.sleep(wait_time)
+
+        if train_idx % validation_freq == 0:
+            while nlmaps_mt.is_training():
+                time.sleep(1)
+            nlmaps_mt.validate('dev')
+            while nlmaps_mt.is_training():
+                time.sleep(1)
+            if dev2:
+                nlmaps_mt.validate('dev2')
+                while nlmaps_mt.is_training():
+                    time.sleep(1)
+
+        if dev:
+            nlmaps_mt.save_feedback(dev, 'dev')
+        if test:
+            nlmaps_mt.save_feedback(test, 'test')
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Simulate online training')
+    parser.add_argument('dataset_dir', help='Dataset directory. Must contain'
+                        ' six files matching *{train,dev,test}.{en,lin}')
+    parser.add_argument('model', help='Path to model config yaml file')
+    parser.add_argument('--wait-time', type=int, default=3,
+                        help='Number of seconds to wait between saving pieces'
+                        ' of feedback.')
+    parser.add_argument('--validation-freq', type=int, default=10,
+                        help='Validate on dev from model config after every N'
+                        ' pieces of feedback.')
+    parser.add_argument('--dev2', default=False, action='store_true',
+                        help='At validation time, additionally validate on'
+                        ' the dev2 dataset from the model config.')
+    parser.add_argument('--nlmaps-mt-base-url', default=NLMAPS_MT_BASE_URL,
+                        help='Base URL of the NLMaps MT server.')
+    parser.add_argument('--user-id', type=int, default=1,
+                        help='User ID to use with the NLMaps MT server.')
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    ARGS = parse_args()
+    main(**vars(ARGS))
-- 
GitLab