|
import os |
|
import tempfile |
|
import re |
|
import librosa |
|
import torch |
|
import json |
|
import numpy as np |
|
import argparse |
|
from tqdm import tqdm |
|
import math |
|
|
|
from transformers import Wav2Vec2ForCTC, AutoProcessor |
|
|
|
from lib import falign_ext |
|
|
|
# Command-line interface.  Every path option is required: each one is
# dereferenced unconditionally further down, so omitting one used to surface
# as an opaque TypeError instead of a usage message.
parser = argparse.ArgumentParser(
    description=(
        "Score candidate transcripts against audio with forced alignment "
        "and write one score per transcript line."
    )
)
parser.add_argument(
    "--uroman_txt",
    type=str,
    required=True,
    help=(
        "Text file with one transcript per line as whitespace-separated "
        "tokens; --n consecutive lines belong to each audio file."
    ),
)
parser.add_argument(
    "--wav",
    type=str,
    required=True,
    help="Text file listing one audio file path per line.",
)
parser.add_argument(
    "--dst",
    type=str,
    required=True,
    help="Output directory; scores are written to <dst>/uasr_score.",
)
parser.add_argument(
    "--model",
    type=str,
    required=True,
    help="Path prefix to which '/upload/mms_zs' is appended to locate the model.",
)
parser.add_argument(
    "--n",
    type=int,
    default=10,
    help="Number of transcript lines per audio file (default: 10).",
)
args = parser.parse_args()

# Sampling rate (Hz) used both when loading audio and when calling the processor.
ASR_SAMPLING_RATE = 16_000

# Sub-path, relative to --model, holding the model files and tokens.txt.
# Plain string concatenation is kept on purpose: the value starts with "/",
# so os.path.join would discard the --model prefix.
MODEL_ID = "/upload/mms_zs"

processor = AutoProcessor.from_pretrained(args.model + MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(args.model + MODEL_ID)

# Vocabulary file: one token per line, where the line index is the token id.
token_file = args.model + MODEL_ID + "/tokens.txt"
|
|
|
def _read_lines(path):
    """Return every line of the text file at *path*, whitespace-stripped."""
    # The context manager closes the handle deterministically instead of
    # leaving it to the garbage collector.
    with open(path, "r") as handle:
        return [line.strip() for line in handle]


def _build_token_index(tokens):
    """Map each token to the position of its first occurrence in *tokens*.

    Keeping the first occurrence mirrors ``list.index``, so a duplicated
    vocabulary entry resolves to the same id a linear search would return.
    """
    token_to_id = {}
    for position, token in enumerate(tokens):
        token_to_id.setdefault(token, position)
    return token_to_id


def _encode_tokens(chars, token_to_id):
    """Translate a sequence of token strings into their vocabulary ids.

    Raises:
        ValueError: if a token is missing from the vocabulary (the same
            exception type ``list.index`` raises for an unknown item).
    """
    try:
        return [token_to_id[char] for char in chars]
    except KeyError as err:
        raise ValueError(
            f"token {err.args[0]!r} is not in the vocabulary"
        ) from None


def _select_device():
    """Pick CUDA when available, then Apple MPS, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if (
        hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
        and torch.backends.mps.is_built()
    ):
        return torch.device("mps")
    return torch.device("cpu")


# Score recorded for a transcript whose forced alignment raised an error.
_FAILED_ALIGNMENT_SCORE = math.log(0.000000001)


if __name__ == "__main__":
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.dst, exist_ok=True)

    tokens = _read_lines(token_file)
    # Build the lookup table once; list.index would rescan the vocabulary
    # for every token of every transcript.
    token_to_id = _build_token_index(tokens)

    txts = _read_lines(args.uroman_txt)
    wavs = _read_lines(args.wav)
    # Raised explicitly because ``assert`` is stripped under ``python -O``.
    if len(txts) != args.n * len(wavs):
        raise ValueError(
            f"expected {args.n} transcript lines per audio file "
            f"({args.n * len(wavs)} in total) but found {len(txts)}"
        )

    device = _select_device()
    model.to(device)

    # Open the score file once (truncating any previous run) and flush after
    # every line, so partial results stay on disk if a later file fails.
    with open(os.path.join(args.dst, "uasr_score"), "w") as f1:
        # Wrapping the list (not the enumerate object) lets tqdm show a total.
        for i, w in enumerate(tqdm(wavs)):
            audio_samples = librosa.load(w, sr=ASR_SAMPLING_RATE, mono=True)[0]

            inputs = processor(
                audio_samples,
                sampling_rate=ASR_SAMPLING_RATE,
                return_tensors="pt",
            )
            inputs = inputs.to(device)

            with torch.no_grad():
                outputs = model(**inputs).logits

            # Wav2Vec2ForCTC logits are documented as (batch, frames, vocab).
            # Drop only the batch axis: a bare squeeze() would also remove
            # the frame axis for a single-frame utterance.
            emissions = outputs.log_softmax(dim=-1).squeeze(0)

            for j in range(args.n):
                # Transcripts are laid out as n consecutive lines per audio.
                idx = (args.n * i) + j
                chars = txts[idx].split()
                # An unknown token is a data error and aborts the run; it is
                # deliberately outside the try block below.
                token_sequence = _encode_tokens(chars, token_to_id)

                try:
                    _, alphas, _ = falign_ext.falign(
                        emissions,
                        torch.tensor(token_sequence, device=device).int(),
                        False,
                    )
                    aligned_alpha = max(alphas[-1]).item()
                except Exception:
                    # Best effort: an alignment failure gets a floor score
                    # rather than stopping the run.  Exception (not a bare
                    # except) keeps Ctrl-C and SystemExit working.
                    aligned_alpha = _FAILED_ALIGNMENT_SCORE

                f1.write(str(aligned_alpha) + "\n")
                f1.flush()