Skip to content
Snippets Groups Projects
Commit 9661b84f authored by Simon Will
Browse files

Improve training simulation script

parent 176c53cb
No related branches found
No related tags found
No related merge requests found
...@@ -58,7 +58,7 @@ def load_data(dataset_dir): ...@@ -58,7 +58,7 @@ def load_data(dataset_dir):
class NLMapsMT: class NLMapsMT:
def __init__(self, base_url, model, user_id=1): def __init__(self, base_url, model, user_id=None):
parsed = urllib.parse.urlparse(base_url) parsed = urllib.parse.urlparse(base_url)
self.scheme = parsed.scheme self.scheme = parsed.scheme
self.netloc = parsed.netloc self.netloc = parsed.netloc
...@@ -77,7 +77,10 @@ class NLMapsMT: ...@@ -77,7 +77,10 @@ class NLMapsMT:
url = self._make_url('/save_feedback') url = self._make_url('/save_feedback')
payload = {'model': self.model, 'nl': instance.nl, payload = {'model': self.model, 'nl': instance.nl,
'correct_lin': instance.lin, 'split': split, 'correct_lin': instance.lin, 'split': split,
'system_lin': '', 'user_id': self.user_id} 'system_lin': ''}
if self.user_id:
payload['user_id'] = self.user_id
logging.info('POST {} to {}'.format(payload, url)) logging.info('POST {} to {}'.format(payload, url))
resp = requests.post(url, json=payload) resp = requests.post(url, json=payload)
if resp.status_code != 200: if resp.status_code != 200:
...@@ -105,17 +108,22 @@ class NLMapsMT: ...@@ -105,17 +108,22 @@ class NLMapsMT:
def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False, def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False,
test_as_dev=False, base_url=NLMAPS_MT_BASE_URL, user_id=1): test_as_dev=False, base_url=NLMAPS_MT_BASE_URL, user_id=None):
data = load_data(dataset_dir) data = load_data(dataset_dir)
nlmaps_mt = NLMapsMT(base_url, model=model, user_id=user_id) nlmaps_mt = NLMapsMT(base_url, model=model, user_id=user_id)
for train_idx, (train, dev, test) in enumerate( for train_idx, (train, dev, test) in enumerate(
zip_longest(data['train'], data['dev'], data['test'])): zip_longest(data['train'], data['dev'], data['test']), start=1):
if train: if train:
nlmaps_mt.save_feedback(train, 'train') nlmaps_mt.save_feedback(train, 'train')
time.sleep(wait_time) time.sleep(wait_time)
if dev:
nlmaps_mt.save_feedback(dev, 'dev')
if test:
split = 'dev' if test_as_dev else 'test'
nlmaps_mt.save_feedback(test, split)
if train_idx % validation_freq == 0: if validation_freq and train_idx % validation_freq == 0:
while nlmaps_mt.is_training(): while nlmaps_mt.is_training():
time.sleep(1) time.sleep(1)
nlmaps_mt.validate('dev') nlmaps_mt.validate('dev')
...@@ -126,12 +134,6 @@ def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False, ...@@ -126,12 +134,6 @@ def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False,
while nlmaps_mt.is_training(): while nlmaps_mt.is_training():
time.sleep(1) time.sleep(1)
if dev:
nlmaps_mt.save_feedback(dev, 'dev')
if test:
split = 'dev' if test_as_dev else 'test'
nlmaps_mt.save_feedback(test, split)
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Simulate online training') parser = argparse.ArgumentParser(description='Simulate online training')
...@@ -150,7 +152,7 @@ def parse_args(): ...@@ -150,7 +152,7 @@ def parse_args():
parser.add_argument('--test-as-dev', default=False, action='store_true') parser.add_argument('--test-as-dev', default=False, action='store_true')
parser.add_argument('--base-url', default=NLMAPS_MT_BASE_URL, parser.add_argument('--base-url', default=NLMAPS_MT_BASE_URL,
help='Base_Url of the NLMaps MT server.') help='Base_Url of the NLMaps MT server.')
parser.add_argument('--user-id', type=int, default=1, parser.add_argument('--user-id', type=int,
help='User ID to use with the NLMaps MT server.') help='User ID to use with the NLMaps MT server.')
args = parser.parse_args() args = parser.parse_args()
return args return args
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment