PyTorch
ssl-aasist
custom_code
ash56's picture
Add files using upload-large-folder tool
9742bb8 verified
raw
history blame
2.52 kB
import os
import tempfile
import re
import librosa
import torch
import json
import numpy as np
import argparse
from tqdm import tqdm
import math
from transformers import Wav2Vec2ForCTC, AutoProcessor
from lib import falign_ext
# Command-line interface.
#
# Every path-like option is marked required: the script dereferences all of
# them unconditionally, so leaving one out used to surface as an opaque
# TypeError (e.g. ``None + str``) instead of a usage message.
parser = argparse.ArgumentParser(
    description=(
        "Write one alignment score per (audio, text) pair using a "
        "Wav2Vec2 CTC model."
    )
)
parser.add_argument(
    "--uroman_txt",
    type=str,
    required=True,
    help=(
        "Text file with one whitespace-separated symbol sequence per line; "
        "it must hold --n consecutive lines for every line of --wav."
    ),
)
parser.add_argument(
    "--wav",
    type=str,
    required=True,
    help="Text file listing one audio file path per line.",
)
parser.add_argument(
    "--dst",
    type=str,
    required=True,
    help="Output directory; scores are written to <dst>/uasr_score.",
)
parser.add_argument(
    "--model",
    type=str,
    required=True,
    help="Directory prefix; the model is loaded from <model>/upload/mms_zs.",
)
parser.add_argument(
    "--n",
    type=int,
    default=10,
    help="Number of text lines scored against each audio file (default: 10).",
)
args = parser.parse_args()

ASR_SAMPLING_RATE = 16_000

# NOTE: MODEL_ID starts with a slash and is appended to --model by plain
# string concatenation; os.path.join would discard the prefix for such a value.
MODEL_ID = "/upload/mms_zs"
_model_dir = args.model + MODEL_ID

# The processor, the model and the token list all live in the same directory.
processor = AutoProcessor.from_pretrained(_model_dir)
model = Wav2Vec2ForCTC.from_pretrained(_model_dir)
token_file = _model_dir + "/tokens.txt"
def _read_lines(path):
    """Return every line of the text file at *path*, stripped of surrounding whitespace."""
    with open(path, "r") as handle:
        return [line.strip() for line in handle]


def _select_device():
    """Return the CUDA device if available, else Apple MPS if usable, else the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if (
        hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
        and torch.backends.mps.is_built()
    ):
        return torch.device("mps")
    return torch.device("cpu")


def main():
    """Score each text line against its audio file and write the scores to disk.

    Reads the token list, the text lines and the audio paths named by the
    module-level ``token_file`` and ``args``. Text lines ``n*i .. n*i + n-1``
    belong to audio file ``i``. For each pair, the score is the largest entry
    of the last row of the ``alphas`` returned by ``falign_ext.falign``; if
    that computation fails, ``log(1e-9)`` is written instead. One score per
    line is written to ``<dst>/uasr_score``, in input order.

    Raises:
        ValueError: if the number of text lines is not ``n`` times the number
            of audio files, or if a text symbol is missing from the token list.
    """
    os.makedirs(args.dst, exist_ok=True)

    tokens = _read_lines(token_file)
    txts = _read_lines(args.uroman_txt)
    wavs = _read_lines(args.wav)

    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if len(txts) != args.n * len(wavs):
        raise ValueError(
            f"expected {args.n} text lines per audio file "
            f"({args.n * len(wavs)} in total), found {len(txts)}"
        )

    # Constant-time symbol -> id lookup. setdefault keeps the first occurrence
    # of a duplicated token, which is what list.index would have returned.
    token_to_id = {}
    for token_id, token in enumerate(tokens):
        token_to_id.setdefault(token, token_id)

    device = _select_device()
    model.to(device)

    # Score written when alignment fails for a pair.
    fallback_score = math.log(0.000000001)

    # Opening with "w" truncates any previous results; the handle is kept open
    # for the whole run and flushed after every score so partial output
    # survives an interruption.
    with open(os.path.join(args.dst, "uasr_score"), "w") as score_file:
        for i, wav_path in tqdm(enumerate(wavs), total=len(wavs)):
            audio_samples = librosa.load(
                wav_path, sr=ASR_SAMPLING_RATE, mono=True
            )[0]
            inputs = processor(
                audio_samples,
                sampling_rate=ASR_SAMPLING_RATE,
                return_tensors="pt",
            )
            inputs = inputs.to(device)
            with torch.no_grad():
                outputs = model(**inputs).logits
            # NOTE(review): the bare squeeze() drops every size-1 dimension,
            # not only a leading one -- confirm the shape falign expects.
            emissions = outputs.log_softmax(dim=-1).squeeze()

            for j in range(args.n):
                idx = args.n * i + j
                chars = txts[idx].split()
                try:
                    token_sequence = [token_to_id[c] for c in chars]
                except KeyError as err:
                    # Same exception type list.index raised for an unknown symbol.
                    raise ValueError(
                        f"symbol {err.args[0]!r} on text line {idx + 1} "
                        f"is not in {token_file}"
                    ) from err

                try:
                    _, alphas, _ = falign_ext.falign(
                        emissions,
                        torch.tensor(token_sequence, device=device).int(),
                        False,
                    )
                    aligned_alpha = max(alphas[-1]).item()
                except Exception:
                    # Best effort: a pair that cannot be aligned gets the
                    # fallback score. KeyboardInterrupt and SystemExit are
                    # deliberately not caught here.
                    aligned_alpha = fallback_score

                score_file.write(str(aligned_alpha) + "\n")
                score_file.flush()


if __name__ == "__main__":
    main()