# Helpers added to joeynmt_server/joey_model.py by this patch
# (index cca698d..58844ca), written out as plain Python.
#
# The same patch also changes two call sites inside ``class JoeyModel``.
# Their enclosing methods are not fully visible, so they are described
# here instead of being reconstructed:
#   * right after ``config['training']['load_model'] = ckpt``, the three
#     hard-coded assignments (reset_best_ckpt=False, reset_scheduler=True,
#     reset_optimizer=True) are replaced by
#     ``set_reset_config(config['training'], ckpt)``;
#   * after the training loop, ``trainer._save_checkpoint()`` becomes
#     ``trainer._save_checkpoint(saver='nlmaps')``.
#     NOTE(review): presumably that call is what writes the ``.saver``
#     file read below -- not visible here, confirm against the trainer.
import os


def get_saver(ckpt):
    """Return the saver tag stored next to the checkpoint ``ckpt``.

    The tag is read from ``<stem>.saver`` in the same directory as
    ``ckpt``, where ``<stem>`` is the checkpoint's base name without its
    last extension.

    Args:
        ckpt: Path to a checkpoint file. The checkpoint itself is never
            opened; only its directory and base name are used.

    Returns:
        The contents of the ``.saver`` file with surrounding whitespace
        removed, or ``None`` when no such regular file exists.
    """
    stem, _ = os.path.splitext(os.path.basename(ckpt))
    saver_path = os.path.join(os.path.dirname(ckpt),
                              '{}.saver'.format(stem))
    if not os.path.isfile(saver_path):
        return None
    with open(saver_path) as f:
        # Strip so that a tag written with a trailing newline (editor,
        # ``echo``, ``print``) still compares equal to the bare tag.
        return f.read().strip()


def set_reset_config(training_config, ckpt):
    """Set the ``reset_*`` training options according to who saved ``ckpt``.

    When the checkpoint's saver tag is ``'nlmaps'``, scheduler and
    optimizer state are kept (both reset flags ``False``). For any other
    tag, or when there is no tag, both are reset (``True``).
    ``reset_best_ckpt`` is set to ``False`` in every case.

    Args:
        training_config: Mutable mapping that is updated in place.
        ckpt: Path to the checkpoint that is about to be loaded.

    Returns:
        None. The result is the mutation of ``training_config``.
    """
    saved_by_nlmaps = get_saver(ckpt) == 'nlmaps'
    training_config['reset_best_ckpt'] = False
    training_config['reset_scheduler'] = not saved_by_nlmaps
    training_config['reset_optimizer'] = not saved_by_nlmaps