# Source snapshot: revision 23b1952, file size 1,537 bytes.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from tests.speech import TestFairseqSpeech
from fairseq import utils
# Base URL for fairseq files hosted on dl.fbaipublicfiles.com. It is not
# referenced anywhere else in this file. NOTE(review): presumably kept for
# parity with sibling speech tests; confirm before removing.
S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq/"
class TestS2STransformer(TestFairseqSpeech):
    """Regression test for an S2S Transformer checkpoint whose targets are codes.

    The checkpoint is evaluated with ``target_is_code=True`` and a code
    vocabulary of 100. Its BLEU score is compared against a stored reference
    by the base class's ``base_test``.
    """

    def setUp(self):
        """Prepare the ``speech_tests/s2s`` fixture files via ``_set_up``."""
        fixture_files = [
            "dev_shuf200.tsv",
            "src_feat.zip",
            "config_specaug_lb.yaml",
            "vocoder",
            "vocoder_config.json",
        ]
        self._set_up("s2s", "speech_tests/s2s", fixture_files)

    def test_s2s_transformer_checkpoint(self):
        """Score the reduced-Fisher checkpoint on ``dev_shuf200`` with BLEU.

        The reference score is 38.3, and the comparison is run with
        ``strict=False``. The exact tolerance semantics are defined by
        ``base_test``.
        """
        overrides = dict(
            config_yaml="config_specaug_lb.yaml",
            multitask_config_yaml=None,
            target_is_code=True,
            target_code_size=100,
            eval_inference=False,
        )
        self.base_test(
            ckpt_name="s2u_transformer_reduced_fisher.pt",
            reference_score=38.3,
            dataset="dev_shuf200",
            arg_overrides=overrides,
            score_type="bleu",
            strict=False,
        )

    def postprocess_tokens(self, task, target, hypo_tokens):
        """Decode the target and hypothesis token sequences into strings.

        Presumably a hook invoked by ``base_test``; confirm against the base
        class. Padding is removed from ``target`` with ``utils.strip_pad``
        before it is cast with ``.int()`` and moved with ``.cpu()``.
        ``hypo_tokens`` is decoded as-is.

        Returns:
            A ``(target_string, hypothesis_string)`` tuple produced by
            ``task.tgt_dict.string``.
        """
        dictionary = task.tgt_dict
        unpadded_target = utils.strip_pad(target, dictionary.pad())
        target_text = dictionary.string(unpadded_target.int().cpu())
        hypothesis_text = dictionary.string(hypo_tokens)
        return target_text, hypothesis_text
if __name__ == "__main__":
    # Running this file directly executes its tests with the stdlib runner.
    # module="__main__" is unittest.main's default, spelled out for clarity.
    unittest.main(module="__main__")
|