From a9181cf61009441b21af4c975584264552a29df3 Mon Sep 17 00:00:00 2001
From: Simon Will <will@cl.uni-heidelberg.de>
Date: Mon, 22 Mar 2021 18:29:55 +0100
Subject: [PATCH] Use sequence_accuracy as eval_metric

---
 joeynmt_server/joey_model.py |  2 +-
 joeynmt_server/trainer.py    | 14 ++++++--------
 2 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/joeynmt_server/joey_model.py b/joeynmt_server/joey_model.py
index d541449..cf089d9 100644
--- a/joeynmt_server/joey_model.py
+++ b/joeynmt_server/joey_model.py
@@ -162,10 +162,10 @@ class JoeyModel:
         valid_kwargs = {k: v for k, v in self.test_args.items()
                         if k not in ['decoding_description', 'tokenizer_info',
                                      'tag_dict_file']}
-        valid_kwargs['eval_metric'] = ''
         valid_kwargs['compute_loss'] = False
         valid_kwargs['data'] = dataset
         valid_kwargs.update(kwargs)
+        valid_kwargs['eval_metric'] = 'sequence_accuracy'
 
         (score, loss, ppl, sources, sources_raw, references, hypotheses,
          hypotheses_raw, valid_attention_scores) = validate_on_data(
diff --git a/joeynmt_server/trainer.py b/joeynmt_server/trainer.py
index 238b0da..12fa0da 100644
--- a/joeynmt_server/trainer.py
+++ b/joeynmt_server/trainer.py
@@ -170,10 +170,11 @@ def train_n_rounds(config_basename, min_rounds=10):
                 logging.info('Validating on {} feedback pieces.'
                              .format(len(dev_set)))
                 results = model.validate(dev_set)
+                accuracy = results['score'] / 100
                 total = len(dev_set)
-                correct = results['score'] * total
+                correct = round(accuracy * total)
                 logging.info('Got validation result: {}/{} = {}.'
-                             .format(correct, total, results['score']))
+                             .format(correct, total, accuracy))
                 evr = EvaluationResult(label='running_dev', correct=correct,
                                        total=total)
                 db.session.add(evr)
@@ -232,10 +233,11 @@ def validate(config_basename):
 
         logging.info('Validating on dev set.')
         results = model.validate(dev_set)
+        accuracy = results['score'] / 100
         total = len(dev_set)
-        correct = results['score'] * total
+        correct = round(accuracy * total)
         logging.info('Got validation result: {}/{} = {}.'
-                     .format(correct, total, results['score']))
+                     .format(correct, total, accuracy))
         evr = EvaluationResult(label='file_dev', correct=correct,
                                total=total)
         db.session.add(evr)
@@ -248,7 +250,3 @@ def validate(config_basename):
 
     db.session.delete(lock)
     db.session.commit()
-
-    correct = results['score']
-    total = len(dev_set)
-    EvaluationResult(label='changing_dev', correct=correct, total=total)
-- 
GitLab