Skip to content
Snippets Groups Projects
Commit c1fe38c2 authored by Simon Will's avatar Simon Will
Browse files

Add script for simulating online training

parent a9181cf6
No related branches found
No related tags found
No related merge requests found
......@@ -156,7 +156,9 @@ class JoeyModel:
return dataset
def get_config_dataset(self, name):
return self._load_train_dataset(self.config['data'][name])
if name in self.config['data']:
return self._load_train_dataset(self.config['data'][name])
return None
def validate(self, dataset, **kwargs):
valid_kwargs = {k: v for k, v in self.test_args.items()
......
......@@ -215,7 +215,7 @@ def train_until_finished(config_basename):
return train_n_rounds(config_basename, min_rounds=None)
def validate(config_basename):
def validate(config_basename, dataset_name='dev'):
joey_dir = current_app.config.get('JOEY_DIR')
config_file = joey_dir / 'configs' / config_basename
use_cuda_train = current_app.config.get('USE_CUDA_TRAIN', False)
......@@ -229,7 +229,11 @@ def validate(config_basename):
try:
model = JoeyModel.from_config_file(config_file, joey_dir,
use_cuda=use_cuda)
dev_set = model.get_config_dataset('dev')
dev_set = model.get_config_dataset(dataset_name)
if not dev_set:
msg = 'No such dataset: {}'.format(dataset_name)
logging.error(msg)
raise ValueError(msg)
logging.info('Validating on dev set.')
results = model.validate(dev_set)
......
......@@ -24,18 +24,21 @@ def save_feedback():
if 'split' in data:
data['split'] = data['split'][:50]
split_was_explicitly_set = True
else:
data['split'] = 'train'
split_was_explicitly_set = False
fb = Feedback(**data)
db.session.add(fb)
db.session.commit()
if fb.id % 5 == 0:
fb.split = 'test'
elif fb.id % 5 == 4:
fb.split = 'dev'
db.session.commit()
if not split_was_explicitly_set:
if fb.id % 5 == 0:
fb.split = 'test'
elif fb.id % 5 == 4:
fb.split = 'dev'
db.session.commit()
def train_in_thread():
app = create_app()
......@@ -49,7 +52,7 @@ def save_feedback():
response = fb.json_ready_dict()
if (current_app.config.get('TRAIN_AFTER_FEEDBACK') and config_basename
and fb.correct_lin):
and fb.correct_lin and fb.split == 'train'):
thread = threading.Thread(target=train_in_thread)
thread.start()
time.sleep(0.1)
......
......@@ -56,11 +56,13 @@ def validate():
data = request.json
config_basename = data.get('model')
dataset = data.get('dataset', 'dev')
def validate_in_thread():
app = create_app()
with app.app_context():
try:
validate_on_data(config_basename)
validate_on_data(config_basename, dataset)
except:
logging.error('Training failed.')
logging.error(traceback.format_exc())
......
import argparse
from collections import namedtuple
from itertools import zip_longest
import logging
import os
import sys
import time
import urllib
import requests
# A single parallel example: the natural-language query (`nl`) and its
# linearised query counterpart (`lin`) — presumably matching the *.en /
# *.lin dataset files loaded below.
Instance = namedtuple('Instance', ('nl', 'lin'))
# Default base URL of the locally running NLMaps MT server.
NLMAPS_MT_BASE_URL = 'http://localhost:5050'
def find_and_read_file(dataset_dir, basenames, split, suffix):
    """Find the unique file in ``basenames`` ending in ``<split>.<suffix>``
    and return its lines, stripped of surrounding whitespace.

    :param dataset_dir: Directory the basenames live in.
    :param basenames: Candidate file basenames (e.g. from os.listdir).
    :param split: Dataset split name ('train', 'dev' or 'test').
    :param suffix: File suffix ('en' or 'lin').
    :return: List of stripped lines; empty list if no file matches.

    Exits the process with status 1 if more than one file matches.
    """
    end = '{}.{}'.format(split, suffix)
    # Bug fix: keep the matching names themselves; the original built a
    # list of booleans (one per basename), which broke both the match
    # counting and the later indexing.
    suitable_basenames = [name for name in basenames if name.endswith(end)]
    if len(suitable_basenames) == 0:
        logging.warning('Found no file matching *{}'.format(end))
        return []
    if len(suitable_basenames) > 1:
        logging.error('Found more than one file matching *{}'.format(end))
        sys.exit(1)
    path = os.path.join(dataset_dir, suitable_basenames[0])
    with open(path) as f:
        return [line.strip() for line in f]
def load_data(dataset_dir):
    """Load the train/dev/test splits from ``dataset_dir``.

    Expects one ``<split>.en`` and one ``<split>.lin`` file per split.

    :param dataset_dir: Directory containing the dataset files.
    :return: Dict mapping each split name to a list of Instance pairs.

    Exits the process with status 2 when a split's files are missing or
    have differing numbers of lines.
    """
    basenames = os.listdir(dataset_dir)
    data = {'train': [], 'dev': [], 'test': []}
    for split in data:
        en = find_and_read_file(dataset_dir, basenames, split, 'en')
        # Bug fix: the lin side must be read from the *.lin files, not
        # from the *.en files a second time.
        lin = find_and_read_file(dataset_dir, basenames, split, 'lin')
        if len(en) != len(lin):
            if not en:
                logging.error('Could not load {} dataset. {} file missing'
                              .format(split, 'en'))
            if not lin:
                logging.error('Could not load {} dataset. {} file missing'
                              .format(split, 'lin'))
            # Bug fix: report the lengths, not the full file contents.
            logging.error('Lengths of en and lin do not match ({} vs. {})'
                          .format(len(en), len(lin)))
            sys.exit(2)
        data[split] = [Instance(en_, lin_) for en_, lin_ in zip(en, lin)]
    return data
class NLMapsMT:
    """Thin HTTP client for the NLMaps MT server.

    Only the scheme and network location of ``base_url`` are kept; any
    path component of the given base URL is discarded.
    """

    def __init__(self, base_url, model, user_id=1):
        parsed = urllib.parse.urlparse(base_url)
        self.scheme = parsed.scheme
        self.netloc = parsed.netloc
        self.model = model
        self.user_id = user_id

    def _make_url(self, path):
        """Return an absolute URL for ``path`` on the configured server."""
        # Bug fix: ParseResult is a 6-field namedtuple; constructing it
        # with only scheme/netloc/path raises TypeError. Supply the
        # remaining fields explicitly as empty strings.
        parsed = urllib.parse.ParseResult(
            scheme=self.scheme, netloc=self.netloc, path=path,
            params='', query='', fragment='')
        return parsed.geturl()

    def save_feedback(self, instance: 'Instance', split):
        """POST one feedback instance for the given dataset split."""
        url = self._make_url('/save_feedback')
        payload = {'model': self.model, 'nl': instance.nl,
                   'correct_lin': instance.lin, 'split': split,
                   'system_lin': '', 'user_id': self.user_id}
        logging.info('POST {} to {}'.format(payload, url))
        resp = requests.post(url, json=payload)
        if resp.status_code != 200:
            logging.error('Response status code: {}'.format(resp.status_code))

    def is_training(self):
        """Return whether the server reports an ongoing training round.

        Conservatively returns True on a non-200 response so callers
        keep waiting instead of acting on an unknown server state.
        """
        url = self._make_url('/train_status')
        logging.info('GET to {}'.format(url))
        resp = requests.get(url)
        if resp.status_code != 200:
            logging.error('Response status code: {}'.format(resp.status_code))
            return True
        data = resp.json()
        return bool(data.get('still_training'))

    def validate(self, dataset=None):
        """Trigger server-side validation, optionally on a named dataset."""
        url = self._make_url('/validate')
        payload = {'model': self.model}
        if dataset:
            payload['dataset'] = dataset
        logging.info('POST {} to {}'.format(payload, url))
        resp = requests.post(url, json=payload)
        if resp.status_code != 200:
            logging.error('Response status code: {}'.format(resp.status_code))
def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False,
         nlmaps_mt_base_url=NLMAPS_MT_BASE_URL, user_id=1):
    """Replay a dataset against the NLMaps MT server as online feedback.

    Train instances are posted one by one with a pause of ``wait_time``
    seconds; every ``validation_freq``-th train instance triggers a
    validation round on 'dev' (and 'dev2' if requested). Dev and test
    instances are posted as feedback for their respective splits.
    """
    data = load_data(dataset_dir)
    client = NLMapsMT(nlmaps_mt_base_url, model=model, user_id=user_id)

    def wait_until_idle():
        # Poll until the server no longer reports an ongoing training round.
        while client.is_training():
            time.sleep(1)

    paired = zip_longest(data['train'], data['dev'], data['test'])
    for round_no, (train_inst, dev_inst, test_inst) in enumerate(paired):
        if train_inst:
            client.save_feedback(train_inst, 'train')
            time.sleep(wait_time)
            if round_no % validation_freq == 0:
                wait_until_idle()
                client.validate('dev')
                wait_until_idle()
                if dev2:
                    client.validate('dev2')
                    wait_until_idle()
        if dev_inst:
            client.save_feedback(dev_inst, 'dev')
        if test_inst:
            client.save_feedback(test_inst, 'test')
def parse_args():
    """Parse command-line arguments for the online-training simulation."""
    parser = argparse.ArgumentParser(description='Simulate online training')
    parser.add_argument('dataset_dir', help='Dataset directory. Must contain'
                        ' six files matching *{train,dev,test}.{en,lin}')
    parser.add_argument('model', help='Path to model config yaml file')
    parser.add_argument('--wait-time', type=int, default=3,
                        help='Number of seconds to wait between saving pieces'
                        ' of feedback.')
    parser.add_argument('--validation-freq', type=int, default=10,
                        help='Validate on dev from model config after every N'
                        ' pieces of feedback.')
    parser.add_argument('--dev2', default=False, action='store_true',
                        help='At validation time, additionally use dev2'
                        ' from model config to validate on.')
    # Consistently-dashed spelling first (it determines the dest,
    # which stays 'nlmaps_mt_base_url'); the original mixed-style
    # '--nlmaps-mt-base_url' is kept as a backward-compatible alias.
    parser.add_argument('--nlmaps-mt-base-url', '--nlmaps-mt-base_url',
                        default=NLMAPS_MT_BASE_URL,
                        help='Base URL of the NLMaps MT server.')
    parser.add_argument('--user-id', type=int, default=1,
                        help='User ID to use with the NLMaps MT server.')
    args = parser.parse_args()
    return args
if __name__ == '__main__':
    # Script entry point: forward all parsed CLI options to main().
    main(**vars(parse_args()))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment