From 40e914d976c2d7f80f4e68ca6e8c3e9c10985076 Mon Sep 17 00:00:00 2001
From: Simon Will <will@cl.uni-heidelberg.de>
Date: Mon, 29 Mar 2021 20:07:04 +0200
Subject: [PATCH] =?UTF-8?q?Revert=20"Don=E2=80=99t=20reset=20scheduler=20a?=
 =?UTF-8?q?nd=20optimizer=20if=20loading=20a=20checkpoint=20from=20online?=
 =?UTF-8?q?=20learning"?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This reverts commit 6f7602303ce757f0b60a4670515e6d83d6139f8a.
---
 joeynmt_server/joey_model.py | 29 ++++-------------------------
 1 file changed, 4 insertions(+), 25 deletions(-)

diff --git a/joeynmt_server/joey_model.py b/joeynmt_server/joey_model.py
index c38b3f3..d7545f8 100644
--- a/joeynmt_server/joey_model.py
+++ b/joeynmt_server/joey_model.py
@@ -37,29 +37,6 @@ def make_config_absolute(config, root):
     return config
 
 
-def get_saver(ckpt):
-    saver = None
-    base_without_ext, _ = os.path.splitext(os.path.basename(ckpt))
-    saver_path = os.path.join(os.path.dirname(ckpt),
-                              '{}.saver'.format(base_without_ext))
-    if os.path.isfile(saver_path):
-        with open(saver_path) as f:
-            saver = f.read()
-    return saver
-
-
-def set_reset_config(training_config, ckpt):
-    saver = get_saver(ckpt)
-    if saver == 'nlmaps':
-        training_config['reset_best_ckpt'] = False
-        training_config['reset_scheduler'] = False
-        training_config['reset_optimizer'] = False
-    else:
-        training_config['reset_best_ckpt'] = False
-        training_config['reset_scheduler'] = True
-        training_config['reset_optimizer'] = True
-
-
 class JoeyModel:
 
     def __init__(self, config, model, train_manager, src_field, trg_field,
@@ -129,7 +106,9 @@ class JoeyModel:
         ckpt = get_latest_checkpoint(model_dir)
         ckpt_ctime = os.path.getctime(ckpt)
         config['training']['load_model'] = ckpt
-        set_reset_config(config['training'], ckpt)
+        config['training']['reset_best_ckpt'] = False
+        config['training']['reset_scheduler'] = True
+        config['training']['reset_optimizer'] = True
         train_manager = TrainManager(model=model, config=config)
 
         #model_checkpoint = load_checkpoint(ckpt, use_cuda=test_args['use_cuda'])
@@ -255,7 +234,7 @@ class JoeyModel:
                 batch_callback(batch)
         logging.info('Finished training after {} batches'.format(i + 1))
 
-        trainer._save_checkpoint(saver='nlmaps')
+        trainer._save_checkpoint()
 
         if dev_set:
             dev_results = self.validate(dev_set)
-- 
GitLab