Commit fa52d202 authored by Max Ryabinin's avatar Max Ryabinin Committed by Facebook Github Bot
Browse files

Fix generation with --no-early-stop (#627)

Summary:
Because the size of `unfinalized_scores` is equal to current `bsz` and not initial batch size, we need to index it by `unfin_idx` instead of `sent` in `is_finished`.
Fixes #588.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/627

Differential Revision: D15034641

Pulled By: myleott

fbshipit-source-id: 2638e68e877ae01256cac7d8e69b5b7fec8f7017
parent d63477e1
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -178,7 +178,7 @@ class SequenceGenerator(object):
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfinalized_scores=None):
        def is_finished(sent, step, unfin_idx, unfinalized_scores=None):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
@@ -190,7 +190,7 @@ class SequenceGenerator(object):
                    return True
                # stop if the best unfinalized score is worse than the worst
                # finalized one
                best_unfinalized_score = unfinalized_scores[sent].max()
                best_unfinalized_score = unfinalized_scores[unfin_idx].max()
                if self.normalize_scores:
                    best_unfinalized_score /= max_len ** self.len_penalty
                if worst_finalized[sent]['score'] >= best_unfinalized_score:
@@ -287,7 +287,7 @@ class SequenceGenerator(object):
            newly_finished = []
            for sent, unfin_idx in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step, unfinalized_scores):
                if not finished[sent] and is_finished(sent, step, unfin_idx, unfinalized_scores):
                    finished[sent] = True
                    newly_finished.append(unfin_idx)
            return newly_finished