Skip to content
Snippets Groups Projects
Commit fa52d202 authored by Max Ryabinin's avatar Max Ryabinin Committed by Facebook Github Bot
Browse files

Fix generation with --no-early-stop (#627)

Summary:
Because the size of `unfinalized_scores` is equal to current `bsz` and not initial batch size, we need to index it by `unfin_idx` instead of `sent` in `is_finished`.
Fixes #588.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/627

Differential Revision: D15034641

Pulled By: myleott

fbshipit-source-id: 2638e68e877ae01256cac7d8e69b5b7fec8f7017
parent d63477e1
No related branches found
No related tags found
No related merge requests found
......@@ -178,7 +178,7 @@ class SequenceGenerator(object):
buffers[name] = type_of.new()
return buffers[name]
def is_finished(sent, step, unfinalized_scores=None):
def is_finished(sent, step, unfin_idx, unfinalized_scores=None):
"""
Check whether we've finished generation for a given sentence, by
comparing the worst score among finalized hypotheses to the best
......@@ -190,7 +190,7 @@ class SequenceGenerator(object):
return True
# stop if the best unfinalized score is worse than the worst
# finalized one
best_unfinalized_score = unfinalized_scores[sent].max()
best_unfinalized_score = unfinalized_scores[unfin_idx].max()
if self.normalize_scores:
best_unfinalized_score /= max_len ** self.len_penalty
if worst_finalized[sent]['score'] >= best_unfinalized_score:
......@@ -287,7 +287,7 @@ class SequenceGenerator(object):
newly_finished = []
for sent, unfin_idx in sents_seen:
# check termination conditions for this sentence
if not finished[sent] and is_finished(sent, step, unfinalized_scores):
if not finished[sent] and is_finished(sent, step, unfin_idx, unfinalized_scores):
finished[sent] = True
newly_finished.append(unfin_idx)
return newly_finished
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment