PyTorch
ssl-aasist
custom_code
ash56's picture
Add files using upload-large-folder tool
9742bb8 verified
raw
history blame
2.26 kB
#!/usr/bin/env python3
# -*- encoding: utf8 -*-
import argparse
import itertools
import os
import re
import sys
from pathlib import Path
import whisper
from tqdm import tqdm
# Command-line interface. Options that the script always dereferences are
# marked required so that a missing one is reported by argparse instead of
# surfacing later as a TypeError from open(None) / load_model(None).
parser = argparse.ArgumentParser(
    description=(
        "Decode audio files with Whisper, forcing the decoding language of "
        "each pass, and write the hypotheses and their scores under --dst."
    )
)
parser.add_argument(
    "--wavs",
    type=str,
    required=True,
    help="Text file with one audio path per line.",
)
parser.add_argument(
    "--lids",
    type=str,
    required=True,
    help=(
        "Text file with one language ID per line; it must contain --n "
        "consecutive lines for every line of --wavs."
    ),
)
parser.add_argument(
    "--dst",
    type=str,
    required=True,
    help="Output directory for the 'nbest_asr_hyp' and 'asr_score' files.",
)
parser.add_argument(
    "--beam_size",
    type=int,
    default=1,
    help="Beam size passed to whisper.DecodingOptions.",
)
parser.add_argument(
    "--model",
    type=str,
    required=True,
    help="Model identifier passed to whisper.load_model.",
)
parser.add_argument(
    "--mapping",
    type=str,
    default="whisper/lid_mapping.txt",
    help=(
        "Text file with one '<whisper_language_code>;<lid_code>' pair per "
        "line; language IDs found in it are translated before decoding."
    ),
)
parser.add_argument(
    "--n",
    type=int,
    default=10,
    help="Number of times each --wavs entry is decoded (once per language ID).",
)
args = parser.parse_args()
def read_lines(path):
    """Return the lines of the text file at *path*, whitespace-stripped.

    Blank lines are kept (as empty strings) so that line positions stay
    aligned between the parallel list files.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]


def load_lid_mapping(path):
    """Load the language-ID -> Whisper-language-code table stored at *path*.

    Every non-blank line must look like ``<whisper_code>;<lid_code>``; only
    the first ``;`` acts as the separator. The returned dict is keyed by
    ``lid_code``. ``None`` is returned when *path* is ``None``.

    Raises:
        ValueError: if a non-blank line contains no ``;``.
    """
    if path is None:
        return None
    mapping = {}
    with open(path, "r", encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                # Tolerate blank / trailing empty lines.
                continue
            whisper_code, sep, lid_code = line.partition(";")
            if not sep:
                raise ValueError(
                    f"{path}:{lineno}: expected '<whisper_code>;<lid_code>', "
                    f"got {line!r}"
                )
            mapping[lid_code] = whisper_code
    return mapping


def main(args):
    """Decode every (audio, language ID) pair and write text and score files.

    Each path listed in ``args.wavs`` is decoded ``args.n`` times, once for
    each of the corresponding consecutive entries of ``args.lids``. For every
    pass one line is written to ``<dst>/nbest_asr_hyp`` (the decoded text)
    and one to ``<dst>/asr_score`` (average log-probability multiplied by
    the number of decoded tokens).

    Raises:
        ValueError: if the number of language IDs does not equal
            ``args.n`` times the number of audio paths.
    """
    model = whisper.load_model(args.model)
    print(args)

    # Repeat each audio path n times so it lines up with its n language IDs.
    wavs = [wav for wav in read_lines(args.wavs) for _ in range(args.n)]
    lids = read_lines(args.lids)
    if len(wavs) != len(lids):
        # Explicit check: an assert would vanish under `python -O` and zip()
        # would then silently truncate to the shorter list.
        raise ValueError(
            f"expected {len(wavs)} language IDs ({args.n} per audio file), "
            f"found {len(lids)} in {args.lids}"
        )

    # Maps language-ID codes from --lids to Whisper language codes.
    mapping = load_lid_mapping(args.mapping)

    dst = Path(args.dst)
    dst.mkdir(parents=True, exist_ok=True)

    # Opening in "w" mode clears results from any previous run; the files are
    # kept open for the whole loop and flushed after each item so partial
    # results survive an interrupted run.
    with open(dst / "nbest_asr_hyp", "w", encoding="utf-8") as hyp_file, \
            open(dst / "asr_score", "w", encoding="utf-8") as score_file:
        for wav, lang in tqdm(zip(wavs, lids), total=len(wavs)):
            # Load audio and pad/trim it to fit 30 seconds.
            audio = whisper.pad_or_trim(whisper.load_audio(wav))

            # Make the log-Mel spectrogram with the number of mel bins this
            # checkpoint expects, and move it to the same device as the model.
            mel = whisper.log_mel_spectrogram(
                audio, n_mels=model.dims.n_mels
            ).to(model.device)

            # Fall back to the raw language ID when it has no mapping entry.
            lang_code = lang if mapping is None else mapping.get(lang, lang)

            # Decode the audio with the language forced to lang_code.
            options = whisper.DecodingOptions(
                beam_size=args.beam_size, language=lang_code
            )
            output = whisper.decode(model, mel, options)

            # avg_logprob * token count gives a length-weighted score.
            score = output.avg_logprob * len(output.tokens)

            hyp_file.write(output.text + "\n")
            score_file.write(str(score) + "\n")
            hyp_file.flush()
            score_file.flush()


if __name__ == "__main__":
    main(args)