Skip to content
Snippets Groups Projects
Commit 39477bce authored by Simon Will's avatar Simon Will
Browse files

Add simulation option for saving one feedback at a time

parent 99ab7156
No related branches found
No related tags found
No related merge requests found
...@@ -107,27 +107,31 @@ class NLMapsMT: ...@@ -107,27 +107,31 @@ class NLMapsMT:
logging.error('Response status code: {}'.format(resp.status_code)) logging.error('Response status code: {}'.format(resp.status_code))
def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False, def main(dataset_dir, model, wait_time=None, validation_freq=10, dev2=False,
test_as_dev=False, base_url=NLMAPS_MT_BASE_URL, user_id=None): test_as_dev=False, base_url=NLMAPS_MT_BASE_URL, user_id=None):
data = load_data(dataset_dir) data = load_data(dataset_dir)
nlmaps_mt = NLMapsMT(base_url, model=model, user_id=user_id) nlmaps_mt = NLMapsMT(base_url, model=model, user_id=user_id)
if validation_freq: def wait_until_free():
while nlmaps_mt.is_training(): while nlmaps_mt.is_training():
time.sleep(1) time.sleep(1)
if validation_freq:
wait_until_free()
nlmaps_mt.validate('dev') nlmaps_mt.validate('dev')
while nlmaps_mt.is_training(): wait_until_free()
time.sleep(1)
if dev2: if dev2:
nlmaps_mt.validate('dev2') nlmaps_mt.validate('dev2')
while nlmaps_mt.is_training(): wait_until_free()
time.sleep(1)
for train_idx, (train, dev, test) in enumerate( for train_idx, (train, dev, test) in enumerate(
zip_longest(data['train'], data['dev'], data['test']), start=1): zip_longest(data['train'], data['dev'], data['test']), start=1):
if train: if train:
nlmaps_mt.save_feedback(train, 'train') nlmaps_mt.save_feedback(train, 'train')
time.sleep(wait_time) if wait_time is None:
wait_until_free()
else:
time.sleep(wait_time)
if dev: if dev:
nlmaps_mt.save_feedback(dev, 'dev') nlmaps_mt.save_feedback(dev, 'dev')
if test: if test:
...@@ -135,15 +139,12 @@ def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False, ...@@ -135,15 +139,12 @@ def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False,
nlmaps_mt.save_feedback(test, split) nlmaps_mt.save_feedback(test, split)
if validation_freq and train_idx % validation_freq == 0: if validation_freq and train_idx % validation_freq == 0:
while nlmaps_mt.is_training(): wait_until_free()
time.sleep(1)
nlmaps_mt.validate('dev') nlmaps_mt.validate('dev')
while nlmaps_mt.is_training(): wait_until_free()
time.sleep(1)
if dev2: if dev2:
nlmaps_mt.validate('dev2') nlmaps_mt.validate('dev2')
while nlmaps_mt.is_training(): wait_until_free()
time.sleep(1)
def parse_args(): def parse_args():
...@@ -151,9 +152,10 @@ def parse_args(): ...@@ -151,9 +152,10 @@ def parse_args():
parser.add_argument('dataset_dir', help='Dataset directory. Must contain' parser.add_argument('dataset_dir', help='Dataset directory. Must contain'
' six files matching in *{train,dev,test}.{en,lin}') ' six files matching in *{train,dev,test}.{en,lin}')
parser.add_argument('model', help='Path to model config yaml file') parser.add_argument('model', help='Path to model config yaml file')
parser.add_argument('--wait-time', type=int, default=3, parser.add_argument('--wait-time', type=int,
help='Number of seconds to wait between saving pieces' help='Number of seconds to wait between saving pieces'
' of feedback.') ' of feedback. Default is waiting until training is'
' done.')
parser.add_argument('--validation-freq', type=int, default=10, parser.add_argument('--validation-freq', type=int, default=10,
help='Validate on dev from model config after every N' help='Validate on dev from model config after every N'
' pieces of feedback.') ' pieces of feedback.')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment