PyTorch
ssl-aasist
custom_code
ssl-aasist / fairseq /tests /speech /test_s2s_transformer.py
ash56's picture
Add files using upload-large-folder tool
23b1952 verified
raw
history blame
1.54 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from tests.speech import TestFairseqSpeech
from fairseq import utils
# Base URL of the public fairseq file host. It is not referenced by the code
# shown in this module; presumably kept for parity with sibling speech tests
# or for external importers -- TODO confirm before removing.
S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq/"
class TestS2STransformer(TestFairseqSpeech):
    """Checkpoint regression test for the speech-to-speech ("s2s") Transformer.

    Asset setup, checkpoint evaluation and scoring are delegated to the
    ``_set_up`` and ``base_test`` helpers inherited from ``TestFairseqSpeech``
    (defined outside this module).
    """

    def setUp(self):
        """Configure the "s2s" task and the test assets it depends on.

        The file names below are handed to ``TestFairseqSpeech._set_up``
        together with the "speech_tests/s2s" location; presumably that helper
        fetches or locates them -- TODO confirm against the base class.
        """
        required_assets = [
            "dev_shuf200.tsv",
            "src_feat.zip",
            "config_specaug_lb.yaml",
            "vocoder",
            "vocoder_config.json",
        ]
        self._set_up("s2s", "speech_tests/s2s", required_assets)

    def test_s2s_transformer_checkpoint(self):
        """Score ``s2u_transformer_reduced_fisher.pt`` on the ``dev_shuf200`` split.

        The result is scored as BLEU against a reference of 38.3. ``strict=False``
        is forwarded to ``base_test``; presumably it relaxes the comparison to
        the reference score -- verify in the base class.
        """
        task_overrides = {
            "config_yaml": "config_specaug_lb.yaml",
            "multitask_config_yaml": None,
            # Targets are treated as discrete codes; target_code_size=100
            # presumably sets the code vocabulary size -- confirm.
            "target_is_code": True,
            "target_code_size": 100,
            "eval_inference": False,
        }
        self.base_test(
            ckpt_name="s2u_transformer_reduced_fisher.pt",
            reference_score=38.3,
            dataset="dev_shuf200",
            arg_overrides=task_overrides,
            score_type="bleu",
            strict=False,
        )

    def postprocess_tokens(self, task, target, hypo_tokens):
        """Decode reference and hypothesis token sequences into strings.

        ``target`` has its padding stripped and is cast to int and moved to CPU
        before decoding; ``hypo_tokens`` is decoded as given. No BPE/post-processing
        symbol is passed to ``task.tgt_dict.string``.

        Returns:
            A ``(reference_string, hypothesis_string)`` tuple.
        """
        vocab = task.tgt_dict
        unpadded_target = utils.strip_pad(target, vocab.pad())
        reference = vocab.string(unpadded_target.int().cpu())
        hypothesis = vocab.string(hypo_tokens)
        return reference, hypothesis
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()