diff --git a/joeynmt_server/joey_model.py b/joeynmt_server/joey_model.py index f944aa8cf013ad4b619eb35297d6da63dad21ec8..cca698d8481525ec697d608912872f4ea64c5563 100644 --- a/joeynmt_server/joey_model.py +++ b/joeynmt_server/joey_model.py @@ -125,6 +125,12 @@ class JoeyModel: self._train_dataset = self._load_train_dataset() return self._train_dataset + @property + def steps(self): + if self.train_manager: + return self.train_manager.stats.steps + return None + def is_still_latest(self): train_config = self.config['training'] latest_ckpt = get_latest_checkpoint(train_config['model_dir']) diff --git a/joeynmt_server/models/evaluation_results.py b/joeynmt_server/models/evaluation_results.py index 799e948363432fcdf4c49e471e16e28b1924758b..e971b1d6fe39849b06f33a904daf04ac9089151a 100644 --- a/joeynmt_server/models/evaluation_results.py +++ b/joeynmt_server/models/evaluation_results.py @@ -8,6 +8,7 @@ class EvaluationResult(BaseModel): label = db.Column(db.Unicode(50), nullable=False) model = db.Column(db.Unicode(500), nullable=False) + steps = db.Column(db.Integer, nullable=True) correct = db.Column(db.Integer, nullable=False) total = db.Column(db.Integer, nullable=False) diff --git a/joeynmt_server/trainer.py b/joeynmt_server/trainer.py index f1df7ee0bb9b8cd9be578b2470532e0e113a4410..87599ec4a554cc4ed7eeac6e7e720edf4d2a1402 100644 --- a/joeynmt_server/trainer.py +++ b/joeynmt_server/trainer.py @@ -168,7 +168,7 @@ def train_n_rounds(config_basename, min_rounds=10): model = train(config_basename, smallest_usage_count, train1, train2) - if dev: + if dev and current_app.config.get('RUNNING_VALIDATION', True): dev_set = make_dataset_from_feedback(dev, model) logging.info('Validating on {} feedback pieces.' .format(len(dev_set))) @@ -179,8 +179,8 @@ def train_n_rounds(config_basename, min_rounds=10): logging.info('Got validation result: {}/{} = {}.' .format(correct, total, accuracy)) evr = EvaluationResult( - label='running_dev', model=config_basename, - correct=correct, total=total + label='running_dev', steps=model.steps, + model=config_basename, correct=correct, total=total ) db.session.add(evr) db.session.commit() @@ -247,8 +247,9 @@ def validate(config_basename, dataset_name='dev'): correct = round(accuracy * total) logging.info('Got validation result: {}/{} = {}.' .format(correct, total, accuracy)) - evr = EvaluationResult(label=dataset_name, model=config_basename, - correct=correct, total=total) + evr = EvaluationResult(label=dataset_name, steps=model.steps, + model=config_basename, correct=correct, + total=total) db.session.add(evr) db.session.commit() except: