Skip to content
Snippets Groups Projects
Commit 6834a6e8 authored by Simon Will's avatar Simon Will
Browse files

Fix a few issues with online training simulation

parent 0919708d
No related branches found
No related tags found
No related merge requests found
...@@ -39,7 +39,7 @@ def load_data(dataset_dir): ...@@ -39,7 +39,7 @@ def load_data(dataset_dir):
data = {'train': [], 'dev': [], 'test': []} data = {'train': [], 'dev': [], 'test': []}
for split in data: for split in data:
en = find_and_read_file(dataset_dir, basenames, split, 'en') en = find_and_read_file(dataset_dir, basenames, split, 'en')
lin = find_and_read_file(dataset_dir, basenames, split, 'en') lin = find_and_read_file(dataset_dir, basenames, split, 'lin')
if len(en) != len(lin): if len(en) != len(lin):
if not en: if not en:
...@@ -85,7 +85,7 @@ class NLMapsMT: ...@@ -85,7 +85,7 @@ class NLMapsMT:
def is_training(self): def is_training(self):
url = self._make_url('/train_status') url = self._make_url('/train_status')
logging.info('GET to {}'.format(url)) logging.debug('GET to {}'.format(url))
resp = requests.get(url) resp = requests.get(url)
if resp.status_code != 200: if resp.status_code != 200:
logging.error('Response status code: {}'.format(resp.status_code)) logging.error('Response status code: {}'.format(resp.status_code))
...@@ -105,7 +105,7 @@ class NLMapsMT: ...@@ -105,7 +105,7 @@ class NLMapsMT:
def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False, def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False,
base_url=NLMAPS_MT_BASE_URL, user_id=1): test_as_dev=False, base_url=NLMAPS_MT_BASE_URL, user_id=1):
data = load_data(dataset_dir) data = load_data(dataset_dir)
nlmaps_mt = NLMapsMT(base_url, model=model, user_id=user_id) nlmaps_mt = NLMapsMT(base_url, model=model, user_id=user_id)
...@@ -129,7 +129,8 @@ def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False, ...@@ -129,7 +129,8 @@ def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False,
if dev: if dev:
nlmaps_mt.save_feedback(dev, 'dev') nlmaps_mt.save_feedback(dev, 'dev')
if test: if test:
nlmaps_mt.save_feedback(test, 'test') split = 'dev' if test_as_dev else 'test'
nlmaps_mt.save_feedback(test, split)
def parse_args(): def parse_args():
...@@ -146,6 +147,7 @@ def parse_args(): ...@@ -146,6 +147,7 @@ def parse_args():
parser.add_argument('--dev2', default=False, action='store_true', parser.add_argument('--dev2', default=False, action='store_true',
help='At validation time, additionally use dev2' help='At validation time, additionally use dev2'
' from model config to validate on.') ' from model config to validate on.')
parser.add_argument('--test-as-dev', default=False, action='store_true')
parser.add_argument('--base-url', default=NLMAPS_MT_BASE_URL, parser.add_argument('--base-url', default=NLMAPS_MT_BASE_URL,
help='Base_Url of the NLMaps MT server.') help='Base_Url of the NLMaps MT server.')
parser.add_argument('--user-id', type=int, default=1, parser.add_argument('--user-id', type=int, default=1,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment