"""Utilities for CTC forced alignment with the multilingual uroman model.

Covers romanizing transcripts with the external ``uroman.pl`` script,
collapsing a frame-level label path into segments, downloading/loading the
published wav2vec2 alignment model and its dictionary, and grouping aligned
segments back into per-token spans.
"""

import io
import math
import os
import re
import subprocess
import tempfile  # noqa: F401  (kept: part of this module's historical namespace)
from dataclasses import dataclass

import torch

try:
    from torchaudio.models import wav2vec2_model
except ImportError:  # only load_model_dict() needs torchaudio
    wav2vec2_model = None

# Language codes with specialized rules in uroman; forwarded as ``-l <code>``.
special_isos_uroman = [
    iso.strip()
    for iso in (
        "ara, bel, bul, deu, ell, eng, fas, grc, heb, kaz, kir, lav, lit, mkd, "
        "mkd2, oss, pnt, pus, rus, srp, srp2, tur, uig, ukr, yid"
    ).split(",")
]


def normalize_uroman(text):
    """Lower-case *text* and keep only ``a-z``, apostrophes and single spaces.

    Every other character is replaced by a space, runs of spaces collapse to
    one, and leading/trailing spaces are stripped.
    """
    text = text.lower()
    text = re.sub(r"([^a-z' ])", " ", text)
    text = re.sub(r" +", " ", text)
    return text.strip()


def get_uroman_tokens(norm_transcripts, uroman_root_dir, iso=None):
    """Romanize transcripts with uroman and return them as spaced characters.

    Args:
        norm_transcripts: sized sequence of transcript strings; each is written
            as one input line for uroman.
        uroman_root_dir: directory containing ``uroman.pl``.
        iso: optional language code, forwarded as ``-l iso`` only when it is in
            ``special_isos_uroman``.

    Returns:
        One string per transcript: uroman's output line with every character
        separated by a single space, then passed through ``normalize_uroman``.

    Raises:
        AssertionError: if ``uroman.pl`` is missing, or uroman does not produce
            exactly one output line per transcript.
        subprocess.CalledProcessError: if perl exits with a non-zero status.
        OSError: if the ``perl`` executable cannot be started.
    """
    uroman_script = f"{uroman_root_dir}/uroman.pl"
    assert os.path.exists(uroman_script), "uroman not found"

    # Invoke perl directly (no shell): paths with spaces or shell
    # metacharacters are passed verbatim, I/O goes through pipes instead of
    # reopening named temp files, and a failing perl run raises instead of
    # being silently ignored. UTF-8 is explicit so non-ASCII transcripts do not
    # depend on the process locale.
    cmd = ["perl", uroman_script]
    if iso in special_isos_uroman:
        cmd += ["-l", iso]
    uroman_input = "".join(t + "\n" for t in norm_transcripts)
    completed = subprocess.run(
        cmd,
        input=uroman_input,
        stdout=subprocess.PIPE,
        encoding="utf-8",
        check=True,
    )

    outtexts = []
    # Iterate like a text file so line splitting matches reading a file.
    for line in io.StringIO(completed.stdout):
        line = " ".join(line.strip())  # space-separate every character
        line = re.sub(r"\s+", " ", line).strip()
        outtexts.append(line)
    assert len(outtexts) == len(norm_transcripts)

    return [normalize_uroman(ot) for ot in outtexts]


@dataclass
class Segment:
    """A labelled run of frames.

    ``merge_repeats`` stores ``end`` as the index of the run's last frame
    (inclusive), while the silence padding built in ``get_spans`` uses the
    neighbouring segment's ``start`` as ``end``. ``__repr__`` prints the pair
    as ``[start, end)`` in both cases.
    """

    label: str  # token string; "" is the blank/silence label used by get_spans
    start: int  # first frame index
    end: int  # see the class docstring for the inclusive/exclusive caveat

    def __repr__(self):
        return f"{self.label}: [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        """Return ``end - start``."""
        return self.end - self.start


def merge_repeats(path, idx_to_token_map):
    """Collapse runs of identical consecutive ids in *path* into Segments.

    Each segment's label is ``idx_to_token_map[id]``; ``start`` is the index of
    the run's first element and ``end`` the index of its last (inclusive).
    """
    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        while i2 < len(path) and path[i1] == path[i2]:
            i2 += 1
        segments.append(Segment(idx_to_token_map[path[i1]], i1, i2 - 1))
        i1 = i2
    return segments


def time_to_frame(time):
    """Convert a time in seconds to a frame index at a 20 ms stride.

    The result is truncated toward zero by ``int``.
    """
    stride_msec = 20
    frames_per_sec = 1000 / stride_msec
    return int(time * frames_per_sec)


def load_model_dict():
    """Download (if needed) and load the uroman CTC alignment model and dictionary.

    Both files are cached at fixed ``/tmp`` paths and reused when present.

    Returns:
        ``(model, dictionary)``. ``model`` is the torchaudio wav2vec2 model with
        the weights loaded (mapped to CPU) and put in eval mode. ``dictionary``
        maps each stripped line of the dictionary file to its 0-based line
        index.

    Raises:
        ImportError: if torchaudio is not installed.
    """
    if wav2vec2_model is None:
        raise ImportError("torchaudio is required to build the alignment model")

    model_path_name = "/tmp/ctc_alignment_mling_uroman_model.pt"

    print("Downloading model and dictionary...")
    if os.path.exists(model_path_name):
        print("Model path already exists. Skipping downloading....")
    else:
        torch.hub.download_url_to_file(
            "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt",
            model_path_name,
        )
        assert os.path.exists(model_path_name)
    # NOTE(review): torch.load unpickles the file, and the fixed /tmp path can
    # be pre-created by another local user. Consider a private cache directory
    # and weights_only=True on torch versions that support it.
    state_dict = torch.load(model_path_name, map_location="cpu")

    model = wav2vec2_model(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=[
            (512, 10, 5),
            (512, 3, 2),
            (512, 3, 2),
            (512, 3, 2),
            (512, 3, 2),
            (512, 2, 2),
            (512, 2, 2),
        ],
        extractor_conv_bias=True,
        encoder_embed_dim=1024,
        encoder_projection_dropout=0.0,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=24,
        encoder_num_heads=16,
        encoder_attention_dropout=0.0,
        encoder_ff_interm_features=4096,
        encoder_ff_interm_dropout=0.1,
        encoder_dropout=0.0,
        encoder_layer_norm_first=True,
        encoder_layer_drop=0.1,
        aux_num_out=31,
    )
    model.load_state_dict(state_dict)
    model.eval()

    dict_path_name = "/tmp/ctc_alignment_mling_uroman_model.dict"
    if os.path.exists(dict_path_name):
        print("Dictionary path already exists. Skipping downloading....")
    else:
        torch.hub.download_url_to_file(
            "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/dictionary.txt",
            dict_path_name,
        )
        assert os.path.exists(dict_path_name)
    with open(dict_path_name, encoding="utf-8") as f:
        dictionary = {line.strip(): i for i, line in enumerate(f)}

    return model, dictionary


def get_spans(tokens, segments):
    """Group aligned segments into one span of Segments per token.

    Args:
        tokens: list of strings. The space-separated pieces of each string must
            match, in order, the labels of the non-blank (``""``) segments; this
            is checked with assertions.
        segments: list of Segment, e.g. as produced by ``merge_repeats``.

    Returns:
        A list with one list of Segments per interval found. Each span is the
        token's segments plus silence padding taken from an adjacent blank
        segment: the first span takes the whole preceding blank and the last
        span the whole following blank, while interior boundaries split the
        blank at its midpoint. An empty token that follows a completed token
        gets the interval ``(seg_idx, seg_idx)`` of that token's last segment.

    NOTE(review): empty tokens at the very start of *tokens* are not skipped;
    the first non-blank segment would then fail the label assertion.
    """
    ltr_idx = 0
    tokens_idx = 0
    intervals = []
    start = 0
    sil = ""
    for seg_idx, seg in enumerate(segments):
        if tokens_idx == len(tokens):
            # All tokens consumed: only a trailing blank segment may remain.
            assert seg_idx == len(segments) - 1
            assert seg.label == ''
            continue
        cur_token = tokens[tokens_idx].split(' ')
        ltr = cur_token[ltr_idx]
        if seg.label == "":
            continue
        assert seg.label == ltr
        if ltr_idx == 0:
            start = seg_idx
        if ltr_idx == len(cur_token) - 1:
            ltr_idx = 0
            tokens_idx += 1
            intervals.append((start, seg_idx))
            while tokens_idx < len(tokens) and len(tokens[tokens_idx]) == 0:
                intervals.append((seg_idx, seg_idx))
                tokens_idx += 1
        else:
            ltr_idx += 1

    spans = []
    for idx, (start, end) in enumerate(intervals):
        span = segments[start:end + 1]
        if start > 0:
            prev_seg = segments[start - 1]
            if prev_seg.label == sil:
                pad_start = (
                    prev_seg.start
                    if idx == 0
                    else math.floor((prev_seg.start + prev_seg.end) / 2)
                )
                span = [Segment(sil, pad_start, span[0].start)] + span
        if end + 1 < len(segments):
            next_seg = segments[end + 1]
            if next_seg.label == sil:
                pad_end = (
                    next_seg.end
                    if idx == len(intervals) - 1
                    else math.floor((next_seg.start + next_seg.end) / 2)
                )
                span = span + [Segment(sil, span[-1].end, pad_end)]
        spans.append(span)
    return spans