From c1b2b3e582866c737ef38521388119fad1df158c Mon Sep 17 00:00:00 2001 From: Simon Will <will@cl.uni-heidelberg.de> Date: Thu, 22 Apr 2021 20:39:44 +0200 Subject: [PATCH] Add config options for online setup --- joeynmt_server/config/default.py | 6 ++++++ joeynmt_server/trainer.py | 12 ++++++------ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/joeynmt_server/config/default.py b/joeynmt_server/config/default.py index 55f12d8..61aa457 100644 --- a/joeynmt_server/config/default.py +++ b/joeynmt_server/config/default.py @@ -17,3 +17,9 @@ USE_CUDA_TRANSLATE = True USE_CUDA_TRAIN = 'gpu' in socket.gethostname() TRAIN_AFTER_FEEDBACK = False + +ONLINE_IMMEDIATE_ITERATIONS = 5 +ONLINE_IMMEDIATE_BATCH_SIZE = 5 +ONLINE_MEMORY_BATCH_SIZE = 5 +ONLINE_OLD_DATA_BATCH_SIZE = 5 +ONLINE_RUNNING_VALIDATION = False diff --git a/joeynmt_server/trainer.py b/joeynmt_server/trainer.py index 87599ec..2515f4b 100644 --- a/joeynmt_server/trainer.py +++ b/joeynmt_server/trainer.py @@ -29,10 +29,10 @@ def make_dataset_from_feedback(feedback, model): def train(config_basename, smallest_usage_count, segment_1, segment_2): - segment_1_threshold = 5 - segment_1_batch_size = 5 - segment_2_batch_size = 5 - segment_3_batch_size = 15 + segment_1_threshold = current_app.config.get('ONLINE_IMMEDIATE_ITERATIONS', 5) + segment_1_batch_size = current_app.config.get('ONLINE_IMMEDIATE_BATCH_SIZE', 5) + segment_2_batch_size = current_app.config.get('ONLINE_MEMORY_BATCH_SIZE', 5) + segment_3_batch_size = current_app.config.get('ONLINE_OLD_DATA_BATCH_SIZE', 5) joey_dir = current_app.config.get('JOEY_DIR') config_file = joey_dir / 'configs' / config_basename @@ -103,7 +103,7 @@ def train(config_basename, smallest_usage_count, segment_1, segment_2): def get_feedback_segments(config_basename): - segment_1_threshold = 5 + segment_1_threshold = current_app.config.get('ONLINE_IMMEDIATE_ITERATIONS', 5) feedback = Feedback.query.all() train_segment_1 = [] train_segment_2 = [] @@ -168,7 +168,7 @@ def train_n_rounds(config_basename, min_rounds=10): model = train(config_basename, smallest_usage_count, train1, train2) - if dev and current_app.config.get('RUNNING_VALIDATION', True): + if dev and current_app.config.get('ONLINE_RUNNING_VALIDATION', True): dev_set = make_dataset_from_feedback(dev, model) logging.info('Validating on {} feedback pieces.' .format(len(dev_set))) -- GitLab