# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from tests.speech import TestFairseqSpeech
from fairseq.data.data_utils import post_process
from fairseq import utils
from omegaconf import open_dict

S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq"


@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestWav2Vec2(TestFairseqSpeech):
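    """Decoding regression tests for wav2vec 2.0 ASR checkpoints on LibriSpeech test-other.

    setUp stages the manifests, letter transcripts, dictionaries, and zipped audio
    shared by both checkpoints; each test hands base_test a reference_score and a
    score_delta tolerance for the decoding run.
    """
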
    def setUp(self):
        self._set_up(
            "librispeech_w2v2",
            "conformer/wav2vec2/librispeech",
            [
                "test_librispeech-other.ltr",
                "test_librispeech-other.tsv",
                "test_librispeech-other_small.ltr_100",
                "test_librispeech-other_small.tsv",
                "test-other.zip",
                "dict.ltr.txt",
                "dict.ltr_100.txt",
            ],
        )
        self.unzip_files(
            "test-other.zip",
        )

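    # Small transformer wav2vec 2.0 checkpoint (fine-tuned on 100h, per the checkpoint
    # name), decoded on the full test-other manifest with "ltr" letter targets.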
    def test_transformer_w2v2(self):
        self.base_test(
            ckpt_name="transformer_oss_small_100h.pt",
            reference_score=38,
            score_delta=1,
            dataset="test_librispeech-other",
            max_tokens=1000000,
            max_positions=(700000, 1000),
            arg_overrides={
                "task": "audio_finetuning",
                "labels": "ltr",
                "nbest": 1,
                "tpu": False,
            },
            strict=False,
        )

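    # Conformer wav2vec 2.0 checkpoint (LibriSpeech pretrained and fine-tuned, rope
    # positional encoding, per the checkpoint name), decoded on the smaller
    # test-other subset with "ltr_100" targets.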
    def test_conformer_w2v2(self):
        self.base_test(
            ckpt_name="conformer_LS_PT_LS_FT_rope.pt",
            reference_score=4.5,
            score_delta=1,
            dataset="test_librispeech-other_small",
            max_tokens=1000000,
            max_positions=(700000, 1000),
            arg_overrides={
                "task": "audio_finetuning",
                "labels": "ltr_100",
                "nbest": 1,
                "tpu": False,
            },
            strict=True,
        )

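    # Decode with the flashlight Viterbi decoder from the speech_recognition
    # example; the tests are skipped if the flashlight python bindings are missing.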
    def build_generator(self, task, models, cfg):
        try:
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
        except Exception:
            raise unittest.SkipTest("Cannot run this test without flashlight dependency")
        with open_dict(cfg):
            cfg.nbest = 1
        return W2lViterbiDecoder(cfg, task.target_dictionary)

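    # Strip padding from the reference, then apply "letter" post-processing to both
    # reference and hypothesis (drop intra-word spaces, map "|" back to word
    # boundaries) so they compare as word strings.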
    def postprocess_tokens(self, task, target, hypo_tokens):
        tgt_tokens = utils.strip_pad(target, task.target_dictionary.pad()).int().cpu()
        tgt_str = task.target_dictionary.string(tgt_tokens)
        tgt_str = post_process(tgt_str, "letter")
        hypo_pieces = task.target_dictionary.string(hypo_tokens)
        hypo_str = post_process(hypo_pieces, "letter")
        return tgt_str, hypo_str


if __name__ == "__main__":
    unittest.main()