From a9181cf61009441b21af4c975584264552a29df3 Mon Sep 17 00:00:00 2001
From: Simon Will <will@cl.uni-heidelberg.de>
Date: Mon, 22 Mar 2021 18:29:55 +0100
Subject: [PATCH] Use sequence_accuracy as eval_metric

---
 joeynmt_server/joey_model.py |  2 +-
 joeynmt_server/trainer.py    | 14 ++++++--------
 2 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/joeynmt_server/joey_model.py b/joeynmt_server/joey_model.py
index d541449..cf089d9 100644
--- a/joeynmt_server/joey_model.py
+++ b/joeynmt_server/joey_model.py
@@ -162,10 +162,10 @@ class JoeyModel:
         valid_kwargs = {k: v for k, v in self.test_args.items()
                         if k not in ['decoding_description', 'tokenizer_info',
                                      'tag_dict_file']}
-        valid_kwargs['eval_metric'] = ''
         valid_kwargs['compute_loss'] = False
         valid_kwargs['data'] = dataset
         valid_kwargs.update(kwargs)
+        valid_kwargs['eval_metric'] = 'sequence_accuracy'
 
         (score, loss, ppl, sources, sources_raw, references, hypotheses,
          hypotheses_raw, valid_attention_scores) = validate_on_data(
diff --git a/joeynmt_server/trainer.py b/joeynmt_server/trainer.py
index 238b0da..12fa0da 100644
--- a/joeynmt_server/trainer.py
+++ b/joeynmt_server/trainer.py
@@ -170,10 +170,11 @@ def train_n_rounds(config_basename, min_rounds=10):
         logging.info('Validating on {} feedback pieces.'
                      .format(len(dev_set)))
         results = model.validate(dev_set)
+        accuracy = results['score'] / 100
         total = len(dev_set)
-        correct = results['score'] * total
+        correct = round(accuracy * total)
         logging.info('Got validation result: {}/{} = {}.'
-                     .format(correct, total, results['score']))
+                     .format(correct, total, accuracy))
         evr = EvaluationResult(label='running_dev', correct=correct,
                                total=total)
         db.session.add(evr)
@@ -232,10 +233,11 @@ def validate(config_basename):
     logging.info('Validating on dev set.')
     results = model.validate(dev_set)
+    accuracy = results['score'] / 100
     total = len(dev_set)
-    correct = results['score'] * total
+    correct = round(accuracy * total)
     logging.info('Got validation result: {}/{} = {}.'
-                 .format(correct, total, results['score']))
+                 .format(correct, total, accuracy))
     evr = EvaluationResult(label='file_dev', correct=correct,
                            total=total)
     db.session.add(evr)
@@ -248,7 +250,3 @@ def validate(config_basename):
 
     db.session.delete(lock)
     db.session.commit()
-
-    correct = results['score']
-    total = len(dev_set)
-    EvaluationResult(label='changing_dev', correct=correct, total=total)
-- 
GitLab
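
Note (not part of the patch): the division by 100 in trainer.py suggests that the sequence_accuracy score returned by model.validate() is reported on a 0-100 scale. A minimal sketch of the conversion the trainer now performs, using hypothetical values for results['score'] and the dev-set size:

    # Sketch only; 'results' and 'total' hold made-up example values, not data
    # from the repository. Assumes results['score'] is a percentage (0-100).
    results = {'score': 87.5}          # e.g. what model.validate(dev_set) might return
    total = 40                         # number of items in the dev set
    accuracy = results['score'] / 100  # -> 0.875
    correct = round(accuracy * total)  # -> 35 correctly predicted sequences
    print('Got validation result: {}/{} = {}.'.format(correct, total, accuracy))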