|
|
|
|
|
|
|
|
|
|
|
import unittest |
|
import torch |
|
from tests.speech import TestFairseqSpeech |
|
from fairseq.data.data_utils import post_process |
|
from fairseq import utils |
|
from omegaconf import open_dict |
|
|
|
# Root URL under which fairseq's public files are hosted.
# NOTE(review): nothing in the visible part of this module reads this
# constant; confirm whether other code imports it before removing it.
S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq"
|
|
|
|
|
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestWav2Vec2(TestFairseqSpeech):
    """Score wav2vec 2.0 checkpoints against fixed reference values.

    The whole class is skipped when CUDA is unavailable. The fixture
    handling and the scoring loop live in ``TestFairseqSpeech``; this class
    only supplies the file list, the per-checkpoint settings, and the two
    hooks ``build_generator`` and ``postprocess_tokens``.
    """

    def setUp(self):
        """Declare the fixture files for this suite and unzip the archive.

        NOTE(review): the meaning of the first two positional arguments of
        ``_set_up`` is defined by the base class and is not visible here --
        presumably a dataset identifier and a remote sub-directory; confirm.
        """
        self._set_up(
            "librispeech_w2v2",
            "conformer/wav2vec2/librispeech",
            [
                "test_librispeech-other.ltr",
                "test_librispeech-other.tsv",
                "test_librispeech-other_small.ltr_100",
                "test_librispeech-other_small.tsv",
                "test-other.zip",
                "dict.ltr.txt",
                "dict.ltr_100.txt",
            ],
        )
        self.unzip_files(
            "test-other.zip",
        )

    def test_transformer_w2v2(self):
        """Run ``base_test`` on ``transformer_oss_small_100h.pt``.

        Uses the ``test_librispeech-other`` dataset with ``ltr`` labels, a
        reference score of 38 with a delta of 1, and ``strict=False``.
        NOTE(review): the score metric and the exact effect of ``strict``
        are defined by ``base_test`` in the base class -- confirm there.
        """
        self.base_test(
            ckpt_name="transformer_oss_small_100h.pt",
            reference_score=38,
            score_delta=1,
            dataset="test_librispeech-other",
            max_tokens=1000000,
            max_positions=(700000, 1000),
            arg_overrides={
                "task": "audio_finetuning",
                "labels": "ltr",
                "nbest": 1,
                "tpu": False,
            },
            strict=False,
        )

    def test_conformer_w2v2(self):
        """Run ``base_test`` on ``conformer_LS_PT_LS_FT_rope.pt``.

        Uses the ``test_librispeech-other_small`` dataset with ``ltr_100``
        labels, a reference score of 4.5 with a delta of 1, and
        ``strict=True``.
        """
        self.base_test(
            ckpt_name="conformer_LS_PT_LS_FT_rope.pt",
            reference_score=4.5,
            score_delta=1,
            dataset="test_librispeech-other_small",
            max_tokens=1000000,
            max_positions=(700000, 1000),
            arg_overrides={
                "task": "audio_finetuning",
                "labels": "ltr_100",
                "nbest": 1,
                "tpu": False,
            },
            strict=True,
        )

    def build_generator(self, task, models, cfg):
        """Return a ``W2lViterbiDecoder`` for ``task``'s target dictionary.

        Args:
            task: object exposing ``target_dictionary``.
            models: not used by this implementation.
            cfg: config object; its ``nbest`` entry is set to 1 in place.

        Raises:
            Exception: if importing ``W2lViterbiDecoder`` fails for any
                reason. The original error is attached as ``__cause__``.
        """
        # Imported lazily so the module can still be collected when the
        # decoder (and whatever it depends on) is not installed.
        try:
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
        except Exception as err:
            # Chain explicitly so the real import failure stays visible in
            # the traceback as the direct cause.
            raise Exception(
                "Cannot run this test without flashlight dependency"
            ) from err
        # open_dict permits writing the key even if it is absent from a
        # struct-mode config.
        with open_dict(cfg):
            cfg.nbest = 1
        return W2lViterbiDecoder(cfg, task.target_dictionary)

    def postprocess_tokens(self, task, target, hypo_tokens):
        """Convert target and hypothesis token ids to post-processed strings.

        Padding is stripped from ``target`` before it is rendered; both
        sequences are rendered with ``task.target_dictionary.string`` and
        then passed through ``post_process(..., "letter")``.

        Returns:
            A ``(tgt_str, hypo_str)`` tuple.
        """
        dictionary = task.target_dictionary

        tgt_tokens = utils.strip_pad(target, dictionary.pad()).int().cpu()
        tgt_str = post_process(dictionary.string(tgt_tokens), "letter")

        hypo_pieces = dictionary.string(hypo_tokens)
        hypo_str = post_process(hypo_pieces, "letter")
        return tgt_str, hypo_str
|
|
|
|
|
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
|