From 39477bceb4334c1fe189279e1df714543b793457 Mon Sep 17 00:00:00 2001
From: Simon Will <will@cl.uni-heidelberg.de>
Date: Thu, 22 Apr 2021 20:33:43 +0200
Subject: [PATCH] Add simulation option for saving one feedback at a time

---
 simulate_training.py | 32 +++++++++++++++++---------------
 1 file changed, 17 insertions(+), 15 deletions(-)

diff --git a/simulate_training.py b/simulate_training.py
index 5539572..1431676 100644
--- a/simulate_training.py
+++ b/simulate_training.py
@@ -107,27 +107,31 @@ class NLMapsMT:
             logging.error('Response status code: {}'.format(resp.status_code))
 
 
-def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False,
+def main(dataset_dir, model, wait_time=None, validation_freq=10, dev2=False,
          test_as_dev=False, base_url=NLMAPS_MT_BASE_URL, user_id=None):
     data = load_data(dataset_dir)
     nlmaps_mt = NLMapsMT(base_url, model=model, user_id=user_id)
 
-    if validation_freq:
+    def wait_until_free():
         while nlmaps_mt.is_training():
             time.sleep(1)
+
+    if validation_freq:
+        wait_until_free()
         nlmaps_mt.validate('dev')
-        while nlmaps_mt.is_training():
-            time.sleep(1)
+        wait_until_free()
         if dev2:
             nlmaps_mt.validate('dev2')
-            while nlmaps_mt.is_training():
-                time.sleep(1)
+            wait_until_free()
 
     for train_idx, (train, dev, test) in enumerate(
             zip_longest(data['train'], data['dev'], data['test']), start=1):
         if train:
             nlmaps_mt.save_feedback(train, 'train')
-            time.sleep(wait_time)
+            if wait_time is None:
+                wait_until_free()
+            else:
+                time.sleep(wait_time)
         if dev:
             nlmaps_mt.save_feedback(dev, 'dev')
         if test:
@@ -135,15 +139,12 @@ def main(dataset_dir, model, wait_time=3, validation_freq=10, dev2=False,
             nlmaps_mt.save_feedback(test, split)
 
         if validation_freq and train_idx % validation_freq == 0:
-            while nlmaps_mt.is_training():
-                time.sleep(1)
+            wait_until_free()
             nlmaps_mt.validate('dev')
-            while nlmaps_mt.is_training():
-                time.sleep(1)
+            wait_until_free()
             if dev2:
                 nlmaps_mt.validate('dev2')
-                while nlmaps_mt.is_training():
-                    time.sleep(1)
+                wait_until_free()
 
 
 def parse_args():
@@ -151,9 +152,10 @@ def parse_args():
     parser.add_argument('dataset_dir', help='Dataset directory. Must contain'
                         ' six files matching in *{train,dev,test}.{en,lin}')
     parser.add_argument('model', help='Path to model config yaml file')
-    parser.add_argument('--wait-time', type=int, default=3,
+    parser.add_argument('--wait-time', type=int,
                         help='Number of seconds to wait between saving pieces'
-                        ' of feedback.')
+                        ' of feedback. Default is waiting until training is'
+                        ' done.')
     parser.add_argument('--validation-freq', type=int, default=10,
                         help='Validate on dev from model config after every N'
                         ' pieces of feedback.')
-- 
GitLab