Skip to content
Snippets Groups Projects
Commit 99ab7156 authored by Simon Will's avatar Simon Will
Browse files

Don’t reset scheduler and optimizer if loading a checkpoint from online learning

parent 76c2a6ce
No related branches found
No related tags found
No related merge requests found
......@@ -37,6 +37,29 @@ def make_config_absolute(config, root):
return config
def get_saver(ckpt):
saver = None
base_without_ext, _ = os.path.splitext(os.path.basename(ckpt))
saver_path = os.path.join(os.path.dirname(ckpt),
'{}.saver'.format(base_without_ext))
if os.path.isfile(saver_path):
with open(saver_path) as f:
saver = f.read()
return saver
def set_reset_config(training_config, ckpt):
saver = get_saver(ckpt)
if saver == 'nlmaps':
training_config['reset_best_ckpt'] = False
training_config['reset_scheduler'] = False
training_config['reset_optimizer'] = False
else:
training_config['reset_best_ckpt'] = False
training_config['reset_scheduler'] = True
training_config['reset_optimizer'] = True
class JoeyModel:
def __init__(self, config, model, train_manager, src_field, trg_field,
......@@ -106,9 +129,7 @@ class JoeyModel:
ckpt = get_latest_checkpoint(model_dir)
ckpt_ctime = os.path.getctime(ckpt)
config['training']['load_model'] = ckpt
config['training']['reset_best_ckpt'] = False
config['training']['reset_scheduler'] = True
config['training']['reset_optimizer'] = True
set_reset_config(config['training'], ckpt)
train_manager = TrainManager(model=model, config=config)
#model_checkpoint = load_checkpoint(ckpt, use_cuda=test_args['use_cuda'])
......@@ -243,7 +264,7 @@ class JoeyModel:
batch_callback(batch)
logging.info('Finished training after {} batches'.format(i + 1))
trainer._save_checkpoint()
trainer._save_checkpoint(saver='nlmaps')
if dev_set:
dev_results = self.validate(dev_set)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment