# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Checkpoint regression test for the speech-to-speech (S2S) transformer."""

import unittest

from tests.speech import TestFairseqSpeech
from fairseq import utils

S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq/"


class TestS2STransformer(TestFairseqSpeech):
    """Scores a reduced Fisher S2U transformer checkpoint with BLEU."""

    def setUp(self):
        """Delegate fixture setup to the inherited ``_set_up`` helper.

        The list holds the file names handed to ``_set_up``; presumably
        these are the assets it makes available for the test -- confirm
        against ``TestFairseqSpeech._set_up``.
        """
        required_files = [
            "dev_shuf200.tsv",
            "src_feat.zip",
            "config_specaug_lb.yaml",
            "vocoder",
            "vocoder_config.json",
        ]
        self._set_up("s2s", "speech_tests/s2s", required_files)

    def test_s2s_transformer_checkpoint(self):
        """Run ``base_test`` on the checkpoint with a 38.3 BLEU reference."""
        # Argument overrides passed through to ``base_test`` unchanged.
        overrides = {
            "config_yaml": "config_specaug_lb.yaml",
            "multitask_config_yaml": None,
            "target_is_code": True,
            "target_code_size": 100,
            "eval_inference": False,
        }
        self.base_test(
            ckpt_name="s2u_transformer_reduced_fisher.pt",
            reference_score=38.3,
            dataset="dev_shuf200",
            arg_overrides=overrides,
            score_type="bleu",
            strict=False,
        )

    def postprocess_tokens(self, task, target, hypo_tokens):
        """Render target and hypothesis tokens as strings.

        Padding is stripped from ``target`` (and the result is cast to int
        and moved to CPU) before it is stringified; ``hypo_tokens`` is
        passed to the dictionary as-is.

        Returns:
            A ``(target_string, hypothesis_string)`` tuple, both produced
            by ``task.tgt_dict.string``.
        """
        dictionary = task.tgt_dict
        unpadded_target = utils.strip_pad(target, dictionary.pad())
        reference_tokens = unpadded_target.int().cpu()
        reference_text = dictionary.string(reference_tokens)
        hypothesis_text = dictionary.string(hypo_tokens)
        return reference_text, hypothesis_text


if __name__ == "__main__":
    unittest.main()