PyTorch
ssl-aasist
custom_code
ash56's picture
Add files using upload-large-folder tool
9742bb8 verified
raw
history blame
5.66 kB
import math
import os
import re
import subprocess
import tempfile
from dataclasses import dataclass

import torch
from torchaudio.models import wav2vec2_model
# ISO-style language codes for which uroman has language-specific rules;
# get_uroman_tokens passes "-l <iso>" to uroman.pl only for codes listed here.
# (The previous comma-separated string listed "ell" and "eng" twice.)
special_isos_uroman = [
    "ara", "bel", "bul", "deu", "ell", "eng", "fas", "grc", "heb",
    "kaz", "kir", "lav", "lit", "mkd", "mkd2", "oss", "pnt", "pus",
    "rus", "srp", "srp2", "tur", "uig", "ukr", "yid",
]
def normalize_uroman(text):
    """Lower-case *text* and keep only ASCII letters, apostrophes and single spaces.

    Any character other than ``a``-``z``, ``'`` or a space is replaced by a
    space, runs of spaces are collapsed into one, and surrounding whitespace
    is stripped.
    """
    result = text.lower()
    # Apply the two substitutions in order: blank out disallowed characters,
    # then squeeze repeated spaces.
    for pattern, replacement in (("([^a-z' ])", " "), (' +', ' ')):
        result = re.sub(pattern, replacement, result)
    return result.strip()
def get_uroman_tokens(norm_transcripts, uroman_root_dir, iso=None):
    """Romanize transcripts with uroman and return them as spaced letter strings.

    The transcripts are fed to ``uroman.pl`` one per line. Each output line is
    turned into its characters separated by single spaces, whitespace runs are
    collapsed, and the result is passed through ``normalize_uroman``.

    Args:
        norm_transcripts: sized iterable of transcript strings.
            NOTE(review): a transcript that itself contains a newline would
            become several uroman lines and trip the length check below.
        uroman_root_dir: directory that contains ``uroman.pl``.
        iso: language code. ``-l <iso>`` is passed to uroman only when the
            code is listed in ``special_isos_uroman``.

    Returns:
        list of str, one per transcript, in the same order.

    Raises:
        AssertionError: if ``uroman.pl`` is missing, or if uroman returns a
            different number of lines than transcripts were given.
        FileNotFoundError: if the ``perl`` executable cannot be found.
        subprocess.CalledProcessError: if perl exits with a non-zero status.
    """
    uroman_script = f"{uroman_root_dir}/uroman.pl"
    assert os.path.exists(uroman_script), "uroman not found"

    # Argument list instead of a shell string: directories containing spaces
    # or shell metacharacters work, and nothing is interpreted by a shell.
    cmd = ["perl", uroman_script]
    if iso in special_isos_uroman:
        cmd += ["-l", iso]

    # Stream through stdin/stdout rather than temporary files (which were
    # never closed and cannot be re-opened by name on Windows). Encode
    # explicitly as UTF-8 instead of the locale default -- assumes uroman.pl
    # consumes/produces UTF-8, which non-Latin transcripts require anyway.
    # check=True turns a perl failure into an error instead of empty output.
    stdin_text = "".join(t + "\n" for t in norm_transcripts)
    completed = subprocess.run(
        cmd,
        input=stdin_text,
        stdout=subprocess.PIPE,
        encoding="utf-8",
        check=True,
    )

    raw_lines = completed.stdout.split("\n")
    if raw_lines[-1] == "":
        # Drop the empty piece after the final newline (a file iterator
        # would not have yielded it either).
        raw_lines.pop()

    outtexts = []
    for line in raw_lines:
        # Separate every character by a space, then squeeze whitespace runs.
        spaced = " ".join(line.strip())
        outtexts.append(re.sub(r"\s+", " ", spaced).strip())
    assert len(outtexts) == len(norm_transcripts)
    return [normalize_uroman(ot) for ot in outtexts]
@dataclass
class Segment:
    """A labelled run of frame indices.

    ``merge_repeats`` stores ``end`` as the index of the run's last frame
    (inclusive), although ``__repr__`` prints it with a half-open ``)``.
    """

    label: str  # token label, e.g. a letter or "<blank>"
    start: int  # first frame index of the run
    end: int  # end frame index, as stored by whoever created the segment

    def __repr__(self):
        # e.g. "a: [    3,     7)" -- both indices right-aligned to width 5
        return "{}: [{:5d}, {:5d})".format(self.label, self.start, self.end)

    @property
    def length(self):
        """Return ``end - start``."""
        return self.end - self.start
def merge_repeats(path, idx_to_token_map):
    """Collapse runs of equal consecutive entries of *path* into ``Segment``s.

    Args:
        path: indexable sequence of token indices, compared with ``==``.
        idx_to_token_map: lookup turning a path entry into its label.

    Returns:
        list of ``Segment(label, first_index, last_index)``, one per run, in
        order. ``end`` is the inclusive index of the run's last element.
    """
    merged = []
    total = len(path)
    run_start = 0
    for pos in range(1, total + 1):
        # Keep extending the current run while the value stays the same.
        if pos < total and path[run_start] == path[pos]:
            continue
        merged.append(Segment(idx_to_token_map[path[run_start]], run_start, pos - 1))
        run_start = pos
    return merged
def time_to_frame(time, stride_msec=20):
    """Convert a time in seconds to a frame index.

    Args:
        time: time in seconds.
        stride_msec: frame stride in milliseconds (default 20, i.e. 50
            frames per second, the value previously hard-coded).

    Returns:
        ``int(time * frames_per_sec)``: for non-negative times, the index of
        the frame that starts at or before ``time``.
    """
    frames_per_sec = 1000 / stride_msec
    # Round away binary floating-point noise before truncating; otherwise
    # e.g. 0.58 * 50 == 28.999999999999996 would land in frame 28, not 29.
    return int(round(time * frames_per_sec, 6))
def load_model_dict(cache_dir=None):
    """Download (if needed) and load the multilingual uroman CTC alignment model.

    The checkpoint and its dictionary are cached in ``cache_dir`` and reused
    on later calls.

    Args:
        cache_dir: directory for the cached files. ``None`` (default) selects
            ``tempfile.gettempdir()`` -- ``/tmp`` on a typical Linux setup --
            instead of a hard-coded ``/tmp`` that does not exist on Windows.
            The directory is created if missing.

    Returns:
        ``(model, dictionary)``: a torchaudio ``wav2vec2_model`` with the
        downloaded weights, switched to eval mode, and a dict mapping each
        stripped line of the dictionary file to its 0-based line number.

    Raises:
        AssertionError: if a file is still missing after the download step.
    """
    if cache_dir is None:
        cache_dir = tempfile.gettempdir()
    os.makedirs(cache_dir, exist_ok=True)
    model_path_name = os.path.join(cache_dir, "ctc_alignment_mling_uroman_model.pt")
    dict_path_name = os.path.join(cache_dir, "ctc_alignment_mling_uroman_model.dict")

    print("Downloading model and dictionary...")
    _download_if_missing(
        "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt",
        model_path_name,
        "Model path already exists. Skipping downloading....",
    )
    # NOTE(security): torch.load unpickles the file, and this file comes from
    # the network / a shared cache directory. Consider weights_only=True on
    # torch versions that support it, since only a state_dict is expected.
    state_dict = torch.load(model_path_name, map_location="cpu")
    model = wav2vec2_model(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=[
            (512, 10, 5),
            (512, 3, 2),
            (512, 3, 2),
            (512, 3, 2),
            (512, 3, 2),
            (512, 2, 2),
            (512, 2, 2),
        ],
        extractor_conv_bias=True,
        encoder_embed_dim=1024,
        encoder_projection_dropout=0.0,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=24,
        encoder_num_heads=16,
        encoder_attention_dropout=0.0,
        encoder_ff_interm_features=4096,
        encoder_ff_interm_dropout=0.1,
        encoder_dropout=0.0,
        encoder_layer_norm_first=True,
        encoder_layer_drop=0.1,
        aux_num_out=31,  # NOTE(review): presumably the dictionary size; confirm
    )
    model.load_state_dict(state_dict)
    model.eval()

    _download_if_missing(
        "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/dictionary.txt",
        dict_path_name,
        "Dictionary path already exists. Skipping downloading....",
    )
    with open(dict_path_name, encoding="utf-8") as f:
        # Later duplicate lines overwrite earlier ones, as before.
        dictionary = {line.strip(): idx for idx, line in enumerate(f)}
    return model, dictionary


def _download_if_missing(url, path, exists_message):
    """Download *url* to *path* unless *path* exists; assert it exists afterwards."""
    if os.path.exists(path):
        print(exists_message)
    else:
        torch.hub.download_url_to_file(url, path)
    assert os.path.exists(path)
def get_spans(tokens, segments):
    """Group alignment segments into one list of segments per matched token.

    Walks ``segments`` in order (presumably the output of ``merge_repeats``)
    and matches their non-blank labels, one by one, against the
    space-separated letters of each entry of ``tokens``. For every token the
    matched run of segments is collected and then padded with the adjacent
    ``"<blank>"`` segments: the first span takes the whole preceding blank and
    the last span the whole following blank, while inner spans split a blank
    at its midpoint with their neighbour.

    Args:
        tokens: list of strings whose letters are separated by single spaces
            (e.g. ``"h e l l o"``). An empty string directly after a matched
            token gets an interval made of that token's last segment.
        segments: list of ``Segment`` objects (``label``, ``start``, ``end``).

    Returns:
        list of spans, one per interval; each span is a list of ``Segment``
        objects, possibly starting and/or ending with a blank pad segment.

    Raises:
        AssertionError: if a non-blank segment label differs from the expected
            letter, or if anything other than a single final blank segment
            remains after all tokens are consumed.

    NOTE(review): a leading empty token ('' splits to ['']) makes the letter
    assertion fail on the first non-blank segment. If the segments run out
    before all tokens are matched, the remaining tokens get no span and no
    error is raised.
    """
    ltr_idx = 0  # index of the next expected letter inside the current token
    tokens_idx = 0  # index of the token currently being matched
    intervals = []  # (first_segment_idx, last_segment_idx) per token
    start, end = (0, 0)  # `end` is only rebound by the second loop below
    sil = "<blank>"
    for (seg_idx, seg) in enumerate(segments):
        if(tokens_idx == len(tokens)):
            # All tokens consumed: only one final blank segment may remain.
            assert(seg_idx == len(segments) - 1)
            assert(seg.label == '<blank>')
            continue
        cur_token = tokens[tokens_idx].split(' ')
        ltr = cur_token[ltr_idx]
        # Blank segments between letters are skipped.
        if seg.label == "<blank>": continue
        assert(seg.label == ltr)
        # Remember where the current token's first letter was matched.
        if(ltr_idx) == 0: start = seg_idx
        if ltr_idx == len(cur_token) - 1:
            # Last letter of the token matched: close its interval.
            ltr_idx = 0
            tokens_idx += 1
            intervals.append((start, seg_idx))
            # Empty tokens that follow reuse this token's last segment.
            while tokens_idx < len(tokens) and len(tokens[tokens_idx]) == 0:
                intervals.append((seg_idx, seg_idx))
                tokens_idx += 1
        else:
            ltr_idx += 1
    spans = []
    for (idx, (start, end)) in enumerate(intervals):
        span = segments[start:end + 1]
        if start > 0:
            prev_seg = segments[start - 1]
            if prev_seg.label == sil:
                # First span takes the whole preceding blank; later spans
                # start at its midpoint (the other half belongs to the
                # previous span).
                pad_start = prev_seg.start if (idx == 0) else int((prev_seg.start + prev_seg.end)/2)
                span = [Segment(sil, pad_start, span[0].start)] + span
        if end+1 < len(segments):
            next_seg = segments[end+1]
            if next_seg.label == sil:
                # Last span takes the whole following blank; earlier spans
                # stop at its midpoint.
                pad_end = next_seg.end if (idx == len(intervals) - 1) else math.floor((next_seg.start + next_seg.end) / 2)
                span = span + [Segment(sil, span[-1].end, pad_end)]
        spans.append(span)
    return spans