From 40e914d976c2d7f80f4e68ca6e8c3e9c10985076 Mon Sep 17 00:00:00 2001 From: Simon Will <will@cl.uni-heidelberg.de> Date: Mon, 29 Mar 2021 20:07:04 +0200 Subject: [PATCH] =?UTF-8?q?Revert=20"Don=E2=80=99t=20reset=20scheduler=20a?= =?UTF-8?q?nd=20optimizer=20if=20loading=20a=20checkpoint=20from=20online?= =?UTF-8?q?=20learning"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 6f7602303ce757f0b60a4670515e6d83d6139f8a. --- joeynmt_server/joey_model.py | 29 ++++------------------------- 1 file changed, 4 insertions(+), 25 deletions(-) diff --git a/joeynmt_server/joey_model.py b/joeynmt_server/joey_model.py index c38b3f3..d7545f8 100644 --- a/joeynmt_server/joey_model.py +++ b/joeynmt_server/joey_model.py @@ -37,29 +37,6 @@ def make_config_absolute(config, root): return config -def get_saver(ckpt): - saver = None - base_without_ext, _ = os.path.splitext(os.path.basename(ckpt)) - saver_path = os.path.join(os.path.dirname(ckpt), - '{}.saver'.format(base_without_ext)) - if os.path.isfile(saver_path): - with open(saver_path) as f: - saver = f.read() - return saver - - -def set_reset_config(training_config, ckpt): - saver = get_saver(ckpt) - if saver == 'nlmaps': - training_config['reset_best_ckpt'] = False - training_config['reset_scheduler'] = False - training_config['reset_optimizer'] = False - else: - training_config['reset_best_ckpt'] = False - training_config['reset_scheduler'] = True - training_config['reset_optimizer'] = True - - class JoeyModel: def __init__(self, config, model, train_manager, src_field, trg_field, @@ -129,7 +106,9 @@ class JoeyModel: ckpt = get_latest_checkpoint(model_dir) ckpt_ctime = os.path.getctime(ckpt) config['training']['load_model'] = ckpt - set_reset_config(config['training'], ckpt) + config['training']['reset_best_ckpt'] = False + config['training']['reset_scheduler'] = True + config['training']['reset_optimizer'] = True train_manager = TrainManager(model=model, config=config) #model_checkpoint = load_checkpoint(ckpt, use_cuda=test_args['use_cuda']) @@ -255,7 +234,7 @@ class JoeyModel: batch_callback(batch) logging.info('Finished training after {} batches'.format(i + 1)) - trainer._save_checkpoint(saver='nlmaps') + trainer._save_checkpoint() if dev_set: dev_results = self.validate(dev_set) -- GitLab