Skip to content
Snippets Groups Projects
Commit 09bd9def authored by Simon Will's avatar Simon Will
Browse files

Add steps column to evaluation_results table (no migration yet)

parent f9a60cac
No related branches found
No related tags found
No related merge requests found
...@@ -125,6 +125,12 @@ class JoeyModel: ...@@ -125,6 +125,12 @@ class JoeyModel:
self._train_dataset = self._load_train_dataset() self._train_dataset = self._load_train_dataset()
return self._train_dataset return self._train_dataset
@property
def steps(self):
if self.train_manager:
return self.train_manager.stats.steps
return None
def is_still_latest(self): def is_still_latest(self):
train_config = self.config['training'] train_config = self.config['training']
latest_ckpt = get_latest_checkpoint(train_config['model_dir']) latest_ckpt = get_latest_checkpoint(train_config['model_dir'])
......
...@@ -8,6 +8,7 @@ class EvaluationResult(BaseModel): ...@@ -8,6 +8,7 @@ class EvaluationResult(BaseModel):
label = db.Column(db.Unicode(50), nullable=False) label = db.Column(db.Unicode(50), nullable=False)
model = db.Column(db.Unicode(500), nullable=False) model = db.Column(db.Unicode(500), nullable=False)
steps = db.Column(db.Integer, nullable=True)
correct = db.Column(db.Integer, nullable=False) correct = db.Column(db.Integer, nullable=False)
total = db.Column(db.Integer, nullable=False) total = db.Column(db.Integer, nullable=False)
......
...@@ -168,7 +168,7 @@ def train_n_rounds(config_basename, min_rounds=10): ...@@ -168,7 +168,7 @@ def train_n_rounds(config_basename, min_rounds=10):
model = train(config_basename, smallest_usage_count, train1, model = train(config_basename, smallest_usage_count, train1,
train2) train2)
if dev: if dev and current_app.config.get('RUNNING_VALIDATION', True):
dev_set = make_dataset_from_feedback(dev, model) dev_set = make_dataset_from_feedback(dev, model)
logging.info('Validating on {} feedback pieces.' logging.info('Validating on {} feedback pieces.'
.format(len(dev_set))) .format(len(dev_set)))
...@@ -179,8 +179,8 @@ def train_n_rounds(config_basename, min_rounds=10): ...@@ -179,8 +179,8 @@ def train_n_rounds(config_basename, min_rounds=10):
logging.info('Got validation result: {}/{} = {}.' logging.info('Got validation result: {}/{} = {}.'
.format(correct, total, accuracy)) .format(correct, total, accuracy))
evr = EvaluationResult( evr = EvaluationResult(
label='running_dev', model=config_basename, label='running_dev', steps=model.steps,
correct=correct, total=total model=config_basename, correct=correct, total=total
) )
db.session.add(evr) db.session.add(evr)
db.session.commit() db.session.commit()
...@@ -247,8 +247,9 @@ def validate(config_basename, dataset_name='dev'): ...@@ -247,8 +247,9 @@ def validate(config_basename, dataset_name='dev'):
correct = round(accuracy * total) correct = round(accuracy * total)
logging.info('Got validation result: {}/{} = {}.' logging.info('Got validation result: {}/{} = {}.'
.format(correct, total, accuracy)) .format(correct, total, accuracy))
evr = EvaluationResult(label=dataset_name, model=config_basename, evr = EvaluationResult(label=dataset_name, steps=model.steps,
correct=correct, total=total) model=config_basename, correct=correct,
total=total)
db.session.add(evr) db.session.add(evr)
db.session.commit() db.session.commit()
except: except:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment