diff --git a/joeynmt_server/config/default.py b/joeynmt_server/config/default.py index 55f12d81cbf98b1be3a8f2bed3b0f300f19a4b89..61aa4578c1dfcd68565ef035d149603e05e6fbab 100644 --- a/joeynmt_server/config/default.py +++ b/joeynmt_server/config/default.py @@ -17,3 +17,9 @@ USE_CUDA_TRANSLATE = True USE_CUDA_TRAIN = 'gpu' in socket.gethostname() TRAIN_AFTER_FEEDBACK = False + +ONLINE_IMMEDIATE_ITERATIONS = 5 +ONLINE_IMMEDIATE_BATCH_SIZE = 5 +ONLINE_MEMORY_BATCH_SIZE = 5 +ONLINE_OLD_DATA_BATCH_SIZE = 5 +ONLINE_RUNNING_VALIDATION = False diff --git a/joeynmt_server/trainer.py b/joeynmt_server/trainer.py index 87599ec4a554cc4ed7eeac6e7e720edf4d2a1402..2515f4b4a2b76c1a3486bcbaa2f46a8925a1cc81 100644 --- a/joeynmt_server/trainer.py +++ b/joeynmt_server/trainer.py @@ -29,10 +29,10 @@ def make_dataset_from_feedback(feedback, model): def train(config_basename, smallest_usage_count, segment_1, segment_2): - segment_1_threshold = 5 - segment_1_batch_size = 5 - segment_2_batch_size = 5 - segment_3_batch_size = 15 + segment_1_threshold = current_app.config.get('ONLINE_IMMEDIATE_ITERATIONS', 5) + segment_1_batch_size = current_app.config.get('ONLINE_IMMEDIATE_BATCH_SIZE', 5) + segment_2_batch_size = current_app.config.get('ONLINE_MEMORY_BATCH_SIZE', 5) + segment_3_batch_size = current_app.config.get('ONLINE_OLD_DATA_BATCH_SIZE', 5) joey_dir = current_app.config.get('JOEY_DIR') config_file = joey_dir / 'configs' / config_basename @@ -103,7 +103,7 @@ def train(config_basename, smallest_usage_count, segment_1, segment_2): def get_feedback_segments(config_basename): - segment_1_threshold = 5 + segment_1_threshold = current_app.config.get('ONLINE_IMMEDIATE_ITERATIONS', 5) feedback = Feedback.query.all() train_segment_1 = [] train_segment_2 = [] @@ -168,7 +168,7 @@ def train_n_rounds(config_basename, min_rounds=10): model = train(config_basename, smallest_usage_count, train1, train2) - if dev and current_app.config.get('RUNNING_VALIDATION', True): + if dev and current_app.config.get('ONLINE_RUNNING_VALIDATION', True): dev_set = make_dataset_from_feedback(dev, model) logging.info('Validating on {} feedback pieces.' .format(len(dev_set)))