Skip to content
Snippets Groups Projects
Commit 9661b84f authored by Simon Will's avatar Simon Will
Browse files

Improve training simulation script

parent 176c53cb
No related branches found
No related tags found
No related merge requests found
......@@ -58,7 +58,7 @@ def load_data(dataset_dir):
class NLMapsMT:
def __init__(self, base_url, model, user_id=1):
def __init__(self, base_url, model, user_id=None):
parsed = urllib.parse.urlparse(base_url)
self.scheme = parsed.scheme
self.netloc = parsed.netloc
......@@ -77,7 +77,10 @@ class NLMapsMT:
url = self._make_url('/save_feedback')
payload = {'model': self.model, 'nl': instance.nl,
'correct_lin': instance.lin, 'split': split,
'system_lin': '', 'user_id': self.user_id}
'system_lin': ''}
if self.user_id:
payload['user_id'] = self.user_id
logging.info('POST {} to {}'.format(payload, url))
resp = requests.post(url, json=payload)
if resp.status_code != 200:
......@@ -105,17 +108,22 @@ class NLMapsMT:
def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False,
test_as_dev=False, base_url=NLMAPS_MT_BASE_URL, user_id=1):
test_as_dev=False, base_url=NLMAPS_MT_BASE_URL, user_id=None):
data = load_data(dataset_dir)
nlmaps_mt = NLMapsMT(base_url, model=model, user_id=user_id)
for train_idx, (train, dev, test) in enumerate(
zip_longest(data['train'], data['dev'], data['test'])):
zip_longest(data['train'], data['dev'], data['test']), start=1):
if train:
nlmaps_mt.save_feedback(train, 'train')
time.sleep(wait_time)
if dev:
nlmaps_mt.save_feedback(dev, 'dev')
if test:
split = 'dev' if test_as_dev else 'test'
nlmaps_mt.save_feedback(test, split)
if train_idx % validation_freq == 0:
if validation_freq and train_idx % validation_freq == 0:
while nlmaps_mt.is_training():
time.sleep(1)
nlmaps_mt.validate('dev')
......@@ -126,12 +134,6 @@ def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False,
while nlmaps_mt.is_training():
time.sleep(1)
if dev:
nlmaps_mt.save_feedback(dev, 'dev')
if test:
split = 'dev' if test_as_dev else 'test'
nlmaps_mt.save_feedback(test, split)
def parse_args():
parser = argparse.ArgumentParser(description='Simulate online training')
......@@ -150,7 +152,7 @@ def parse_args():
parser.add_argument('--test-as-dev', default=False, action='store_true')
parser.add_argument('--base-url', default=NLMAPS_MT_BASE_URL,
help='Base_Url of the NLMaps MT server.')
parser.add_argument('--user-id', type=int, default=1,
parser.add_argument('--user-id', type=int,
help='User ID to use with the NLMaps MT server.')
args = parser.parse_args()
return args
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment