From 6f58580d364a22ff84f587165bd4107b8dfecabd Mon Sep 17 00:00:00 2001
From: Simon Will <will@cl.uni-heidelberg.de>
Date: Wed, 24 Mar 2021 01:03:53 +0100
Subject: [PATCH] Fix a number of mistakes with simulated online training

---
Review notes (this section is not part of the commit message):
* Added the closing parenthesis that was missing from the new
  logging.info(...) call in joeynmt_server/trainer.py; without it the
  patched module does not parse.
* Leading whitespace of the hunk lines was reconstructed from a
  whitespace-collapsed copy of this patch; verify it against the tree
  before applying.

 joeynmt-server-on-cluster.sh         |  5 +++--
 joeynmt_server/config/development.py |  4 ++--
 joeynmt_server/trainer.py            |  4 ++--
 simulate_training.py                 | 11 +++++++++--
 4 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/joeynmt-server-on-cluster.sh b/joeynmt-server-on-cluster.sh
index a00aa67..fc6ad83 100755
--- a/joeynmt-server-on-cluster.sh
+++ b/joeynmt-server-on-cluster.sh
@@ -1,7 +1,7 @@
 #!/usr/bin/env bash
 #SBATCH --job-name=joeynmt-server
-#SBATCH --partition=compute
-#SBATCH --nodelist=node40
+#SBATCH --partition=students
+#SBATCH --gres=gpu:1
 #SBATCH --nodes=1
 #SBATCH --mem=5GB
 #SBATCH --time=3-00:00:00
@@ -14,6 +14,7 @@ CLUSTER_MAIN_NODE=node00
 export FLASK_APP=joeynmt_server.fullapp:app
 export FLASK_ENV=development
 export FLASK_DEBUG=true
+export JOEYNMT_SERVER_REPO="$HOME/ma/joeynmt-server"
 export ASSETS="$JOEYNMT_SERVER_REPO/dev-assets"
 
 if [ -z "$CONDA_DEFAULT_ENV" ]; then
diff --git a/joeynmt_server/config/development.py b/joeynmt_server/config/development.py
index f28e662..1bff0b0 100644
--- a/joeynmt_server/config/development.py
+++ b/joeynmt_server/config/development.py
@@ -18,7 +18,7 @@ SQLALCHEMY_DATABASE_URI = 'sqlite:///{}'.format(
 with open(ASSETS_DIR / 'secret_key.txt') as f:
     SECRET_KEY = f.read().strip()
 
-TRAIN_AFTER_FEEDBACK = False
+TRAIN_AFTER_FEEDBACK = True
 
 logging.config.dictConfig({
     'version': 1,
@@ -39,7 +39,7 @@ logging.config.dictConfig({
         },
     },
     'root': {
-        'level': 'DEBUG',
+        'level': 'INFO',
         'handlers': ['stdout', 'logfile']
     }
 })
diff --git a/joeynmt_server/trainer.py b/joeynmt_server/trainer.py
index da613b1..e6da1e9 100644
--- a/joeynmt_server/trainer.py
+++ b/joeynmt_server/trainer.py
@@ -235,14 +235,14 @@ def validate(config_basename, dataset_name='dev'):
         logging.error(msg)
         raise ValueError(msg)
 
-    logging.info('Validating on dev set.')
+    logging.info('Validating on dataset {}.'.format(dataset_name))
     results = model.validate(dev_set)
     accuracy = results['score'] / 100
     total = len(dev_set)
     correct = round(accuracy * total)
     logging.info('Got validation result: {}/{} = {}.'
                  .format(correct, total, accuracy))
-    evr = EvaluationResult(label='file_dev', correct=correct,
+    evr = EvaluationResult(label=dataset_name, correct=correct,
                            total=total)
     db.session.add(evr)
     db.session.commit()
diff --git a/simulate_training.py b/simulate_training.py
index e95785a..e220aad 100644
--- a/simulate_training.py
+++ b/simulate_training.py
@@ -13,10 +13,15 @@ Instance = namedtuple('Instance', ('nl', 'lin'))
 
 NLMAPS_MT_BASE_URL = 'http://localhost:5050'
 
+logging.basicConfig(
+    format='[%(asctime)s] %(levelname)s in %(module)s: %(message)s',
+    level=logging.INFO
+)
+
 
 def find_and_read_file(dataset_dir, basenames, split, suffix):
     end = '{}.{}'.format(split, suffix)
-    suitable_basenames = [name.endswith(end) for name in basenames]
+    suitable_basenames = [name for name in basenames if name.endswith(end)]
     if len(suitable_basenames) == 0:
         logging.warning('Found no file matching *{}'.format(end))
         return []
@@ -63,7 +68,9 @@ class NLMapsMT:
 
     def _make_url(self, path):
         parsed = urllib.parse.ParseResult(
-            scheme=self.scheme, netloc=self.netloc, path=path)
+            scheme=self.scheme, netloc=self.netloc, path=path,
+            params=None, query=None, fragment=None
+        )
         return parsed.geturl()
 
     def save_feedback(self, instance: Instance, split):
-- 
GitLab