From c1b2b3e582866c737ef38521388119fad1df158c Mon Sep 17 00:00:00 2001
From: Simon Will <will@cl.uni-heidelberg.de>
Date: Thu, 22 Apr 2021 20:39:44 +0200
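Subject: [PATCH] Add config options for online setup

Replace the hard-coded segment settings in joeynmt_server/trainer.py
with lookups in current_app.config and add defaults for them to
joeynmt_server/config/default.py:

  ONLINE_IMMEDIATE_ITERATIONS  -> segment_1_threshold
  ONLINE_IMMEDIATE_BATCH_SIZE  -> segment_1_batch_size
  ONLINE_MEMORY_BATCH_SIZE     -> segment_2_batch_size
  ONLINE_OLD_DATA_BATCH_SIZE   -> segment_3_batch_size

Behaviour changes to be aware of:

  * segment_3_batch_size now defaults to 5; it was hard-coded to 15.
  * The validation switch is renamed from RUNNING_VALIDATION to
    ONLINE_RUNNING_VALIDATION. default.py sets it to False, while the
    fallback used in train_n_rounds() when the key is absent is True.
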
---
 joeynmt_server/config/default.py |  6 ++++++
 joeynmt_server/trainer.py        | 12 ++++++------
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/joeynmt_server/config/default.py b/joeynmt_server/config/default.py
index 55f12d8..61aa457 100644
--- a/joeynmt_server/config/default.py
+++ b/joeynmt_server/config/default.py
@@ -17,3 +17,9 @@ USE_CUDA_TRANSLATE = True
 USE_CUDA_TRAIN = 'gpu' in socket.gethostname()
 
 TRAIN_AFTER_FEEDBACK = False
+
+ONLINE_IMMEDIATE_ITERATIONS = 5  # trainer.py: segment_1_threshold
+ONLINE_IMMEDIATE_BATCH_SIZE = 5  # trainer.py: segment_1_batch_size
+ONLINE_MEMORY_BATCH_SIZE = 5  # trainer.py: segment_2_batch_size
+ONLINE_OLD_DATA_BATCH_SIZE = 5  # trainer.py: segment_3_batch_size
+ONLINE_RUNNING_VALIDATION = False  # gates validation in train_n_rounds()
diff --git a/joeynmt_server/trainer.py b/joeynmt_server/trainer.py
index 87599ec..2515f4b 100644
--- a/joeynmt_server/trainer.py
+++ b/joeynmt_server/trainer.py
@@ -29,10 +29,10 @@ def make_dataset_from_feedback(feedback, model):
 
 
 def train(config_basename, smallest_usage_count, segment_1, segment_2):
-    segment_1_threshold = 5
-    segment_1_batch_size = 5
-    segment_2_batch_size = 5
-    segment_3_batch_size = 15
+    segment_1_threshold = current_app.config.get('ONLINE_IMMEDIATE_ITERATIONS', 5)
+    segment_1_batch_size = current_app.config.get('ONLINE_IMMEDIATE_BATCH_SIZE', 5)
+    segment_2_batch_size = current_app.config.get('ONLINE_MEMORY_BATCH_SIZE', 5)
+    segment_3_batch_size = current_app.config.get('ONLINE_OLD_DATA_BATCH_SIZE', 5)
 
     joey_dir = current_app.config.get('JOEY_DIR')
     config_file = joey_dir / 'configs' / config_basename
@@ -103,7 +103,7 @@ def train(config_basename, smallest_usage_count, segment_1, segment_2):
 
 
 def get_feedback_segments(config_basename):
-    segment_1_threshold = 5
+    segment_1_threshold = current_app.config.get('ONLINE_IMMEDIATE_ITERATIONS', 5)
     feedback = Feedback.query.all()
     train_segment_1 = []
     train_segment_2 = []
@@ -168,7 +168,7 @@ def train_n_rounds(config_basename, min_rounds=10):
             model = train(config_basename, smallest_usage_count, train1,
                           train2)
 
-            if dev and current_app.config.get('RUNNING_VALIDATION', True):
+            if dev and current_app.config.get('ONLINE_RUNNING_VALIDATION', True):
                 dev_set = make_dataset_from_feedback(dev, model)
                 logging.info('Validating on {} feedback pieces.'
                              .format(len(dev_set)))
-- 
GitLab