Skip to content
Snippets Groups Projects
Commit c1b2b3e5 authored by Simon Will's avatar Simon Will
Browse files

Add config options for online setup

parent 39477bce
No related branches found
No related tags found
No related merge requests found
...@@ -17,3 +17,9 @@ USE_CUDA_TRANSLATE = True ...@@ -17,3 +17,9 @@ USE_CUDA_TRANSLATE = True
USE_CUDA_TRAIN = 'gpu' in socket.gethostname() USE_CUDA_TRAIN = 'gpu' in socket.gethostname()
TRAIN_AFTER_FEEDBACK = False TRAIN_AFTER_FEEDBACK = False
ONLINE_IMMEDIATE_ITERATIONS = 5
ONLINE_IMMEDIATE_BATCH_SIZE = 5
ONLINE_MEMORY_BATCH_SIZE = 5
ONLINE_OLD_DATA_BATCH_SIZE = 5
ONLINE_RUNNING_VALIDATION = False
...@@ -29,10 +29,10 @@ def make_dataset_from_feedback(feedback, model): ...@@ -29,10 +29,10 @@ def make_dataset_from_feedback(feedback, model):
def train(config_basename, smallest_usage_count, segment_1, segment_2): def train(config_basename, smallest_usage_count, segment_1, segment_2):
segment_1_threshold = 5 segment_1_threshold = current_app.config.get('ONLINE_IMMEDIATE_ITERATIONS', 5)
segment_1_batch_size = 5 segment_1_batch_size = current_app.config.get('ONLINE_IMMEDIATE_BATCH_SIZE', 5)
segment_2_batch_size = 5 segment_2_batch_size = current_app.config.get('ONLINE_MEMORY_BATCH_SIZE', 5)
segment_3_batch_size = 15 segment_3_batch_size = current_app.config.get('ONLINE_OLD_DATA_BATCH_SIZE', 5)
joey_dir = current_app.config.get('JOEY_DIR') joey_dir = current_app.config.get('JOEY_DIR')
config_file = joey_dir / 'configs' / config_basename config_file = joey_dir / 'configs' / config_basename
...@@ -103,7 +103,7 @@ def train(config_basename, smallest_usage_count, segment_1, segment_2): ...@@ -103,7 +103,7 @@ def train(config_basename, smallest_usage_count, segment_1, segment_2):
def get_feedback_segments(config_basename): def get_feedback_segments(config_basename):
segment_1_threshold = 5 segment_1_threshold = current_app.config.get('ONLINE_IMMEDIATE_ITERATIONS', 5)
feedback = Feedback.query.all() feedback = Feedback.query.all()
train_segment_1 = [] train_segment_1 = []
train_segment_2 = [] train_segment_2 = []
...@@ -168,7 +168,7 @@ def train_n_rounds(config_basename, min_rounds=10): ...@@ -168,7 +168,7 @@ def train_n_rounds(config_basename, min_rounds=10):
model = train(config_basename, smallest_usage_count, train1, model = train(config_basename, smallest_usage_count, train1,
train2) train2)
if dev and current_app.config.get('RUNNING_VALIDATION', True): if dev and current_app.config.get('ONLINE_RUNNING_VALIDATION', True):
dev_set = make_dataset_from_feedback(dev, model) dev_set = make_dataset_from_feedback(dev, model)
logging.info('Validating on {} feedback pieces.' logging.info('Validating on {} feedback pieces.'
.format(len(dev_set))) .format(len(dev_set)))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment