Skip to content
Snippets Groups Projects
Commit 21be4a37 authored by Simon Will's avatar Simon Will
Browse files

Make training work by using non-deterministic random shuffler

parent 64d115a6
No related branches found
No related tags found
No related merge requests found
......@@ -217,7 +217,10 @@ class JoeyModel:
for i, batch in enumerate(batches):
logging.info('Training on batch {}.'.format(i + 1))
logging.info('IDs in batch: {}'.format(batch.id))
logging.info('Size of batch: {}'.format(len(batch)))
if hasattr(batch, 'id'):
logging.info('IDs in batch: {}'.format(batch.id))
joey_batch = Batch(batch, self.model.pad_index,
use_cuda=trainer.use_cuda)
trainer._train_step(joey_batch)
......
......@@ -74,6 +74,9 @@ def train(config_basename, smallest_usage_count, segment_1, segment_2):
train_iterator = merge_iterators(*iterators)
def increment_train_usages(batch):
if not hasattr(batch, 'id'):
return
ids = {int(id_) for id_ in batch.id if id_ >= 0}
existing_usages = {
usage.feedback_id: usage
......
import itertools
import random
from torchtext.data import Batch, BucketIterator, Dataset, Example, Field
from torchtext.data.utils import RandomShuffler
ID_FIELD = Field(sequential=False, use_vocab=False, batch_first=True)
......@@ -21,12 +23,21 @@ def make_dataset(src, src_field, trg_field=None, trg=None, ids=None):
return dataset
class MyRandomShuffler(RandomShuffler):
    """Shuffler whose permutation is drawn via the module-level ``random.sample``.

    ``MyBucketIterator`` installs an instance of this class as its
    ``random_shuffler`` when it is constructed with ``deterministic=False``.
    """

    def __call__(self, data):
        """Return a new list holding every element of ``data`` in random order.

        ``data`` must support ``len()``; it is sampled without replacement
        using its full length, so the result is a permutation of the input
        and the input itself is not modified.
        """
        sample_size = len(data)
        shuffled = random.sample(data, sample_size)
        return shuffled
class MyBucketIterator(BucketIterator):
def __init__(self, *args, max_epochs=None, **kwargs):
def __init__(self, *args, max_epochs=None, deterministic=False, **kwargs):
    """Set up the iterator and choose which shuffler it uses.

    Args:
        *args: Positional arguments forwarded unchanged to the parent
            constructor.
        max_epochs: Value stored on the instance as ``self.max_epochs``;
            defaults to ``None``.
        deterministic: When false (the default), ``self.random_shuffler``
            is replaced with a ``MyRandomShuffler`` after the parent
            constructor has run. When true, the instance is left exactly
            as the parent constructor configured it.
        **kwargs: Keyword arguments forwarded unchanged to the parent
            constructor.
    """
    # Assigned before the parent constructor runs, as in the original order.
    self.max_epochs = max_epochs
    super().__init__(*args, **kwargs)

    if deterministic:
        # Nothing to override: keep the parent's configuration untouched.
        return

    self.random_shuffler = MyRandomShuffler()
def iter_batches_as_lists(self):
"""Mostly a copy of BucketIterator.__iter__, but yielding the
batches as lists of sentences instead of Batch objects.
......
import traceback
from flask import current_app, jsonify, render_template, request
from flask_login import current_user
@current_app.errorhandler(403)
......
......@@ -109,7 +109,7 @@ def query_feedback():
'id': piece.id, 'created': piece.created.isoformat(timespec='seconds'),
'nl': piece.nl, 'correct_lin': piece.correct_lin,
'original_model': piece.model, 'original_lin': piece.system_lin,
'parent_id': piece.parent_id,
'parent_id': piece.parent_id, 'split': piece.split,
'model': model, 'model_lin': model_lin
}
joined_feedback.append(joined)
......@@ -130,7 +130,7 @@ def edit_feedback():
elif not isinstance(data['id'], int):
return jsonify({'error': 'Feedback id is not an int'}), 400
elif not {'editor_id', 'id', 'nl', 'system_lin', 'correct_lin',
'model'}.issuperset(data.keys()):
'model', 'split'}.issuperset(data.keys()):
return jsonify({'error': 'Illegal keys given'}), 400
id = data.pop('id')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment