Commit 40ac340b authored by Liezl Puzon, committed by Facebook Github Bot
Browse files

Eval and log on a subset of directions for multimodel training (#605)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/605

Eval and log on a subset of directions for multimodel training

This reduces code duplication in PyTorch Translate's semi_supervised task and will enable clean multitask setups in the future.

Reviewed By: pipibjc, dpacgopinath

Differential Revision: D14672779

fbshipit-source-id: 1342c71781f0824cc56a38ad1c1822e34eaef337
parent f492db25
Loading
Loading
Loading
Loading
+8 −2
Original line number Diff line number Diff line
@@ -77,6 +77,12 @@ class MultilingualTranslationTask(FairseqTask):
        super().__init__(args)
        self.dicts = dicts
        self.lang_pairs = args.lang_pairs
        # eval_lang_pairs for multilingual translation is usually all of the
        # lang_pairs. However, for other multitask settings, or when we want to
        # optimize for certain languages, we want to use a different subset.
        # Thus the eval_lang_pairs instance attribute is provided for classes
        # that extend this class.
        self.eval_lang_pairs = args.lang_pairs
        self.langs = list(dicts.keys())
        self.training = training

@@ -205,7 +211,7 @@ class MultilingualTranslationTask(FairseqTask):
        model.eval()
        with torch.no_grad():
            agg_loss, agg_sample_size, agg_logging_output = 0., 0., {}
            for lang_pair in self.args.lang_pairs:
            for lang_pair in self.eval_lang_pairs:
                if sample[lang_pair] is None or len(sample[lang_pair]) == 0:
                    continue
                loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair])
@@ -236,7 +242,7 @@ class MultilingualTranslationTask(FairseqTask):
            lang_pair: criterion.__class__.aggregate_logging_outputs([
                logging_output.get(lang_pair, {}) for logging_output in logging_outputs
            ])
            for lang_pair in self.args.lang_pairs
            for lang_pair in self.eval_lang_pairs
        }

        def sum_over_languages(key):