From 6f58580d364a22ff84f587165bd4107b8dfecabd Mon Sep 17 00:00:00 2001
From: Simon Will <will@cl.uni-heidelberg.de>
Date: Wed, 24 Mar 2021 01:03:53 +0100
Subject: [PATCH] Fix a number of mistakes with simulated online training

---
 joeynmt-server-on-cluster.sh         |  5 +++--
 joeynmt_server/config/development.py |  4 ++--
 joeynmt_server/trainer.py            |  4 ++--
 simulate_training.py                 | 11 +++++++++--
 4 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/joeynmt-server-on-cluster.sh b/joeynmt-server-on-cluster.sh
index a00aa67..fc6ad83 100755
--- a/joeynmt-server-on-cluster.sh
+++ b/joeynmt-server-on-cluster.sh
@@ -1,7 +1,7 @@
 #!/usr/bin/env bash
 #SBATCH --job-name=joeynmt-server
-#SBATCH --partition=compute
-#SBATCH --nodelist=node40
+#SBATCH --partition=students
+#SBATCH --gres=gpu:1
 #SBATCH --nodes=1
 #SBATCH --mem=5GB
 #SBATCH --time=3-00:00:00
@@ -14,6 +14,7 @@ CLUSTER_MAIN_NODE=node00
 export FLASK_APP=joeynmt_server.fullapp:app
 export FLASK_ENV=development
 export FLASK_DEBUG=true
+export JOEYNMT_SERVER_REPO="$HOME/ma/joeynmt-server"
 export ASSETS="$JOEYNMT_SERVER_REPO/dev-assets"
 
 if [ -z "$CONDA_DEFAULT_ENV" ]; then
diff --git a/joeynmt_server/config/development.py b/joeynmt_server/config/development.py
index f28e662..1bff0b0 100644
--- a/joeynmt_server/config/development.py
+++ b/joeynmt_server/config/development.py
@@ -18,7 +18,7 @@ SQLALCHEMY_DATABASE_URI = 'sqlite:///{}'.format(
 with open(ASSETS_DIR / 'secret_key.txt') as f:
     SECRET_KEY = f.read().strip()
 
-TRAIN_AFTER_FEEDBACK = False
+TRAIN_AFTER_FEEDBACK = True
 
 logging.config.dictConfig({
     'version': 1,
@@ -39,7 +39,7 @@ logging.config.dictConfig({
         },
     },
     'root': {
-        'level': 'DEBUG',
+        'level': 'INFO',
         'handlers': ['stdout', 'logfile']
     }
 })
diff --git a/joeynmt_server/trainer.py b/joeynmt_server/trainer.py
index da613b1..e6da1e9 100644
--- a/joeynmt_server/trainer.py
+++ b/joeynmt_server/trainer.py
@@ -235,14 +235,14 @@ def validate(config_basename, dataset_name='dev'):
             logging.error(msg)
             raise ValueError(msg)
 
-        logging.info('Validating on dev set.')
+        logging.info('Validating on dataset {}.'.format(dataset_name))
         results = model.validate(dev_set)
         accuracy = results['score'] / 100
         total = len(dev_set)
         correct = round(accuracy * total)
         logging.info('Got validation result: {}/{} = {}.'
                      .format(correct, total, accuracy))
-        evr = EvaluationResult(label='file_dev', correct=correct,
+        evr = EvaluationResult(label=dataset_name, correct=correct,
                                total=total)
         db.session.add(evr)
         db.session.commit()
diff --git a/simulate_training.py b/simulate_training.py
index e95785a..e220aad 100644
--- a/simulate_training.py
+++ b/simulate_training.py
@@ -13,10 +13,15 @@ Instance = namedtuple('Instance', ('nl', 'lin'))
 
 NLMAPS_MT_BASE_URL = 'http://localhost:5050'
 
+logging.basicConfig(
+    format='[%(asctime)s] %(levelname)s in %(module)s: %(message)s',
+    level=logging.INFO
+)
+
 
 def find_and_read_file(dataset_dir, basenames, split, suffix):
     end = '{}.{}'.format(split, suffix)
-    suitable_basenames = [name.endswith(end) for name in basenames]
+    suitable_basenames = [name for name in basenames if name.endswith(end)]
     if len(suitable_basenames) == 0:
         logging.warning('Found no file matching *{}'.format(end))
         return []
@@ -63,7 +68,9 @@ class NLMapsMT:
 
     def _make_url(self, path):
         parsed = urllib.parse.ParseResult(
-            scheme=self.scheme, netloc=self.netloc, path=path)
+            scheme=self.scheme, netloc=self.netloc, path=path,
+            params=None, query=None, fragment=None
+        )
         return parsed.geturl()
 
     def save_feedback(self, instance: Instance, split):
-- 
GitLab