import os
import json
import argparse

import sox
import torch
import torchaudio
import torchaudio.functional as F

from examples.mms.data_prep.text_normalization import text_normalize
from examples.mms.data_prep.align_utils import (
    get_uroman_tokens,
    time_to_frame,
    load_model_dict,
    merge_repeats,
    get_spans,
)

SAMPLING_FREQ = 16000
EMISSION_INTERVAL = 30  # seconds of audio scored per forward pass
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def generate_emissions(model, audio_file):
    waveform, _ = torchaudio.load(audio_file)  # waveform: channels x T
    waveform = waveform.to(DEVICE)
    total_duration = sox.file_info.duration(audio_file)

    audio_sf = sox.file_info.sample_rate(audio_file)
    assert audio_sf == SAMPLING_FREQ, f"Audio must be sampled at {SAMPLING_FREQ} Hz"

    # Score long audio in EMISSION_INTERVAL-second windows. Each window is padded
    # with 10% context on either side so the model sees the surrounding audio,
    # and the emissions for the context frames are trimmed off afterwards.
    emissions_arr = []
    with torch.inference_mode():
        i = 0
        while i < total_duration:
            segment_start_time, segment_end_time = (i, i + EMISSION_INTERVAL)

            context = EMISSION_INTERVAL * 0.1
            input_start_time = max(segment_start_time - context, 0)
            input_end_time = min(segment_end_time + context, total_duration)
            waveform_split = waveform[
                :,
                int(SAMPLING_FREQ * input_start_time) : int(
                    SAMPLING_FREQ * input_end_time
                ),
            ]

            model_outs, _ = model(waveform_split)
            emissions_ = model_outs[0]
            emission_start_frame = time_to_frame(segment_start_time)
            emission_end_frame = time_to_frame(segment_end_time)
            offset = time_to_frame(input_start_time)

            # Keep only the frames belonging to the current window, dropping
            # the emissions produced for the extra context.
            emissions_ = emissions_[
                emission_start_frame - offset : emission_end_frame - offset, :
            ]
            emissions_arr.append(emissions_)
            i += EMISSION_INTERVAL

    emissions = torch.cat(emissions_arr, dim=0).squeeze()
    emissions = torch.log_softmax(emissions, dim=-1)

    # Duration of one emission frame, in milliseconds.
    stride = float(waveform.size(1) * 1000 / emissions.size(0) / SAMPLING_FREQ)

    return emissions, stride


def get_alignments(
    audio_file,
    tokens,
    model,
    dictionary,
    use_star,
):
    # Generate emissions
    emissions, stride = generate_emissions(model, audio_file)
    T, N = emissions.size()
    if use_star:
        # Append an extra emission column for the <star> (wildcard) token.
        emissions = torch.cat([emissions, torch.zeros(T, 1).to(DEVICE)], dim=1)

    # Force alignment
    if tokens:
        token_indices = [
            dictionary[c] for c in " ".join(tokens).split(" ") if c in dictionary
        ]
    else:
        print(f"Empty transcript for audio file {audio_file}")
        token_indices = []

    blank = dictionary["<blank>"]
    targets = torch.tensor(token_indices, dtype=torch.int32).to(DEVICE)

    input_lengths = torch.tensor(emissions.shape[0]).unsqueeze(-1)
    target_lengths = torch.tensor(targets.shape[0]).unsqueeze(-1)
    path, _ = F.forced_align(
        emissions.unsqueeze(0),
        targets.unsqueeze(0),
        input_lengths,
        target_lengths,
        blank=blank,
    )
    path = path.squeeze().to("cpu").tolist()

    # Collapse consecutive repeated tokens in the alignment path into segments.
    segments = merge_repeats(path, {v: k for k, v in dictionary.items()})
    return segments, stride


def main(args):
    assert not os.path.exists(
        args.outdir
    ), f"Error: Output path exists already {args.outdir}"

    with open(args.text_filepath) as f:
        transcripts = [line.strip() for line in f]
    print(f"Read {len(transcripts)} lines from {args.text_filepath}")

    norm_transcripts = [
        text_normalize(line.strip(), args.lang) for line in transcripts
    ]
    tokens = get_uroman_tokens(norm_transcripts, args.uroman_path, args.lang)

    model, dictionary = load_model_dict()
    model = model.to(DEVICE)

    if args.use_star:
        # Prepend a <star> wildcard token that can absorb untranscribed audio
        # at the start of the recording.
        dictionary["<star>"] = len(dictionary)
        tokens = ["<star>"] + tokens
        transcripts = ["<star>"] + transcripts
        norm_transcripts = ["<star>"] + norm_transcripts

    segments, stride = get_alignments(
        args.audio_filepath,
        tokens,
        model,
        dictionary,
        args.use_star,
    )
    # Get the span of aligned segments covered by each line of the input text file.
    spans = get_spans(tokens, segments)

    os.makedirs(args.outdir)
    with open(f"{args.outdir}/manifest.json", "w") as f:
        for i, t in enumerate(transcripts):
            span = spans[i]
            seg_start_idx = span[0].start
            seg_end_idx = span[-1].end

            output_file = f"{args.outdir}/segment{i}.flac"

            # stride is milliseconds per frame; convert frame indices to seconds.
            audio_start_sec = seg_start_idx * stride / 1000
            audio_end_sec = seg_end_idx * stride / 1000

            tfm = sox.Transformer()
            tfm.trim(audio_start_sec, audio_end_sec)
            tfm.build_file(args.audio_filepath, output_file)

            sample = {
                "audio_start_sec": audio_start_sec,
                "audio_filepath": str(output_file),
                "duration": audio_end_sec - audio_start_sec,
                "text": t,
                "normalized_text": norm_transcripts[i],
                "uroman_tokens": tokens[i],
            }
            f.write(json.dumps(sample) + "\n")

    return segments, stride


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Align and segment long audio files")
    parser.add_argument(
        "-a", "--audio_filepath", type=str, help="Path to input audio file"
    )
    parser.add_argument(
        "-t", "--text_filepath", type=str, help="Path to input text file"
    )
    parser.add_argument(
        "-l", "--lang", type=str, default="eng", help="ISO code of the language"
    )
    parser.add_argument(
        "-u", "--uroman_path", type=str, default=None, help="Path to uroman/bin"
    )
    parser.add_argument(
        "-s",
        "--use_star",
        action="store_true",
        help="Use star at the start of transcript",
    )
    parser.add_argument(
        "-o",
        "--outdir",
        type=str,
        help="Output directory to store segmented audio files",
    )
    print("Using torch version:", torch.__version__)
    print("Using torchaudio version:", torchaudio.__version__)
    print("Using device:", DEVICE)
    args = parser.parse_args()
    main(args)
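

# A minimal usage sketch (assuming this file is saved as align_and_segment.py,
# the fairseq MMS data_prep modules imported above are on PYTHONPATH, and the
# aligner checkpoint is available to load_model_dict; all paths below are
# hypothetical placeholders):
#
#   python align_and_segment.py \
#       --audio_filepath /path/to/audio.wav \
#       --text_filepath /path/to/transcript.txt \
#       --lang eng \
#       --uroman_path /path/to/uroman/bin \
#       --outdir /path/to/outdir
#
# The input audio must be 16 kHz mono (see the SAMPLING_FREQ assertion), and the
# text file should contain one transcript line per desired segment. On success,
# the output directory holds one trimmed segment<i>.flac per transcript line and
# a manifest.json with one JSON record (timing, text, uroman tokens) per segment.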