From 6f7602303ce757f0b60a4670515e6d83d6139f8a Mon Sep 17 00:00:00 2001
From: Simon Will <will@cl.uni-heidelberg.de>
Date: Mon, 29 Mar 2021 17:48:26 +0200
Subject: [PATCH] =?UTF-8?q?Don=E2=80=99t=20reset=20scheduler=20and=20optim?=
 =?UTF-8?q?izer=20if=20loading=20a=20checkpoint=20from=20online=20learning?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 joeynmt_server/joey_model.py | 29 +++++++++++++++++++++++++----
 1 file changed, 25 insertions(+), 4 deletions(-)

diff --git a/joeynmt_server/joey_model.py b/joeynmt_server/joey_model.py
index d7545f8..c38b3f3 100644
--- a/joeynmt_server/joey_model.py
+++ b/joeynmt_server/joey_model.py
@@ -37,6 +37,29 @@ def make_config_absolute(config, root):
     return config
 
 
+def get_saver(ckpt):
+    saver = None
+    base_without_ext, _ = os.path.splitext(os.path.basename(ckpt))
+    saver_path = os.path.join(os.path.dirname(ckpt),
+                              '{}.saver'.format(base_without_ext))
+    if os.path.isfile(saver_path):
+        with open(saver_path) as f:
+            saver = f.read()
+    return saver
+
+
+def set_reset_config(training_config, ckpt):
+    saver = get_saver(ckpt)
+    if saver == 'nlmaps':
+        training_config['reset_best_ckpt'] = False
+        training_config['reset_scheduler'] = False
+        training_config['reset_optimizer'] = False
+    else:
+        training_config['reset_best_ckpt'] = False
+        training_config['reset_scheduler'] = True
+        training_config['reset_optimizer'] = True
+
+
 class JoeyModel:
 
     def __init__(self, config, model, train_manager, src_field, trg_field,
@@ -106,9 +129,7 @@ class JoeyModel:
         ckpt = get_latest_checkpoint(model_dir)
         ckpt_ctime = os.path.getctime(ckpt)
         config['training']['load_model'] = ckpt
-        config['training']['reset_best_ckpt'] = False
-        config['training']['reset_scheduler'] = True
-        config['training']['reset_optimizer'] = True
+        set_reset_config(config['training'], ckpt)
         train_manager = TrainManager(model=model, config=config)
 
         #model_checkpoint = load_checkpoint(ckpt, use_cuda=test_args['use_cuda'])
@@ -234,7 +255,7 @@ class JoeyModel:
                 batch_callback(batch)
         logging.info('Finished training after {} batches'.format(i + 1))
 
-        trainer._save_checkpoint()
+        trainer._save_checkpoint(saver='nlmaps')
 
         if dev_set:
             dev_results = self.validate(dev_set)
-- 
GitLab