From 21be4a378eb03c7545f6f58f1b6daba313cc8555 Mon Sep 17 00:00:00 2001 From: Simon Will <will@cl.uni-heidelberg.de> Date: Tue, 6 Apr 2021 13:07:29 +0200 Subject: [PATCH] Make training work by using non-determinstic random shuffler --- joeynmt_server/joey_model.py | 5 ++++- joeynmt_server/trainer.py | 3 +++ joeynmt_server/utils/batching.py | 13 ++++++++++++- joeynmt_server/views/errors.py | 1 - joeynmt_server/views/feedback.py | 4 ++-- 5 files changed, 21 insertions(+), 5 deletions(-) diff --git a/joeynmt_server/joey_model.py b/joeynmt_server/joey_model.py index d7545f8..f944aa8 100644 --- a/joeynmt_server/joey_model.py +++ b/joeynmt_server/joey_model.py @@ -217,7 +217,10 @@ class JoeyModel: for i, batch in enumerate(batches): logging.info('Training on batch {}.'.format(i + 1)) - logging.info('IDs in batch: {}'.format(batch.id)) + logging.info('Size of batch: {}'.format(len(batch))) + if hasattr(batch, 'id'): + logging.info('IDs in batch: {}'.format(batch.id)) + joey_batch = Batch(batch, self.model.pad_index, use_cuda=trainer.use_cuda) trainer._train_step(joey_batch) diff --git a/joeynmt_server/trainer.py b/joeynmt_server/trainer.py index ee08747..f1df7ee 100644 --- a/joeynmt_server/trainer.py +++ b/joeynmt_server/trainer.py @@ -74,6 +74,9 @@ def train(config_basename, smallest_usage_count, segment_1, segment_2): train_iterator = merge_iterators(*iterators) def increment_train_usages(batch): + if not hasattr(batch, 'id'): + return + ids = {int(id_) for id_ in batch.id if id_ >= 0} existing_usages = { usage.feedback_id: usage diff --git a/joeynmt_server/utils/batching.py b/joeynmt_server/utils/batching.py index 8595f2f..dcc333f 100644 --- a/joeynmt_server/utils/batching.py +++ b/joeynmt_server/utils/batching.py @@ -1,6 +1,8 @@ import itertools +import random from torchtext.data import Batch, BucketIterator, Dataset, Example, Field +from torchtext.data.utils import RandomShuffler ID_FIELD = Field(sequential=False, use_vocab=False, batch_first=True) @@ -21,12 +23,21 @@ def make_dataset(src, src_field, trg_field=None, trg=None, ids=None): return dataset +class MyRandomShuffler(RandomShuffler): + + def __call__(self, data): + return random.sample(data, len(data)) + + class MyBucketIterator(BucketIterator): - def __init__(self, *args, max_epochs=None, **kwargs): + def __init__(self, *args, max_epochs=None, deterministic=False, **kwargs): self.max_epochs = max_epochs super().__init__(*args, **kwargs) + if not deterministic: + self.random_shuffler = MyRandomShuffler() + def iter_batches_as_lists(self): """Mostly a copy of BucketIterator.__iter__, but yielding the batches as lists of sentences instead of Batch objects. diff --git a/joeynmt_server/views/errors.py b/joeynmt_server/views/errors.py index f8a78f8..434d489 100644 --- a/joeynmt_server/views/errors.py +++ b/joeynmt_server/views/errors.py @@ -1,7 +1,6 @@ import traceback from flask import current_app, jsonify, render_template, request -from flask_login import current_user @current_app.errorhandler(403) diff --git a/joeynmt_server/views/feedback.py b/joeynmt_server/views/feedback.py index 50b6d51..06e7a40 100644 --- a/joeynmt_server/views/feedback.py +++ b/joeynmt_server/views/feedback.py @@ -109,7 +109,7 @@ def query_feedback(): 'id': piece.id, 'created': piece.created.isoformat(timespec='seconds'), 'nl': piece.nl, 'correct_lin': piece.correct_lin, 'original_model': piece.model, 'original_lin': piece.system_lin, - 'parent_id': piece.parent_id, + 'parent_id': piece.parent_id, 'split': piece.split, 'model': model, 'model_lin': model_lin } joined_feedback.append(joined) @@ -130,7 +130,7 @@ def edit_feedback(): elif not isinstance(data['id'], int): return jsonify({'error': 'Feedback id is not an int'}), 400 elif not {'editor_id', 'id', 'nl', 'system_lin', 'correct_lin', - 'model'}.issuperset(data.keys()): + 'model', 'split'}.issuperset(data.keys()): return jsonify({'error': 'Illegal keys given'}), 400 id = data.pop('id') -- GitLab