# ssl-aasist/fairseq/examples/mms/data_prep/align_and_segment.py
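"""Align a long audio file to a line-per-utterance transcript with the MMS
forced aligner, then cut the audio into per-line segments and write a
JSON-lines manifest (manifest.json) describing each segment."""
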
import os
import torch
import torchaudio
import sox
import json
import argparse
from examples.mms.data_prep.text_normalization import text_normalize
from examples.mms.data_prep.align_utils import (
get_uroman_tokens,
time_to_frame,
load_model_dict,
merge_repeats,
get_spans,
)
import torchaudio.functional as F

SAMPLING_FREQ = 16000
EMISSION_INTERVAL = 30  # seconds of audio scored per forward pass
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def generate_emissions(model, audio_file):
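    """Compute frame-level CTC emissions for an arbitrarily long audio file.

    The audio is processed in EMISSION_INTERVAL-second windows, each padded
    with 10% context on both sides; the context frames are trimmed from every
    window's emissions before concatenation. Returns log-softmax emissions of
    shape (T, num_tokens) and the frame stride in milliseconds.
    """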
waveform, _ = torchaudio.load(audio_file) # waveform: channels X T
waveform = waveform.to(DEVICE)
total_duration = sox.file_info.duration(audio_file)
audio_sf = sox.file_info.sample_rate(audio_file)
    assert audio_sf == SAMPLING_FREQ, f"Expected {SAMPLING_FREQ} Hz audio, got {audio_sf}"
emissions_arr = []
with torch.inference_mode():
i = 0
while i < total_duration:
segment_start_time, segment_end_time = (i, i + EMISSION_INTERVAL)
            # Pad the window with 10% context on each side; the extra frames
            # are trimmed from the emissions below.
            context = EMISSION_INTERVAL * 0.1
input_start_time = max(segment_start_time - context, 0)
input_end_time = min(segment_end_time + context, total_duration)
            waveform_split = waveform[
                :,
                int(SAMPLING_FREQ * input_start_time) : int(SAMPLING_FREQ * input_end_time),
            ]
model_outs, _ = model(waveform_split)
emissions_ = model_outs[0]
emission_start_frame = time_to_frame(segment_start_time)
emission_end_frame = time_to_frame(segment_end_time)
offset = time_to_frame(input_start_time)
emissions_ = emissions_[
emission_start_frame - offset : emission_end_frame - offset, :
]
emissions_arr.append(emissions_)
i += EMISSION_INTERVAL
emissions = torch.cat(emissions_arr, dim=0).squeeze()
emissions = torch.log_softmax(emissions, dim=-1)
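    # Frame stride in milliseconds: audio duration in ms divided by the
    # number of emission frames.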
stride = float(waveform.size(1) * 1000 / emissions.size(0) / SAMPLING_FREQ)
return emissions, stride


def get_alignments(
audio_file,
tokens,
model,
dictionary,
use_star,
):
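    """Force-align uroman tokens to an audio file with torchaudio's CTC
    forced aligner.

    Returns (segments, stride): segments map token labels to frame ranges
    (repeated CTC frames merged), stride is the frame duration in ms.
    """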
# Generate emissions
emissions, stride = generate_emissions(model, audio_file)
T, N = emissions.size()
if use_star:
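        # Appending a zero column gives <star> log-probability 0 (the maximum
        # possible), letting it absorb audio not covered by the transcript.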
emissions = torch.cat([emissions, torch.zeros(T, 1).to(DEVICE)], dim=1)
# Force Alignment
if tokens:
token_indices = [dictionary[c] for c in " ".join(tokens).split(" ") if c in dictionary]
else:
print(f"Empty transcript!!!!! for audio file {audio_file}")
token_indices = []
blank = dictionary["<blank>"]
targets = torch.tensor(token_indices, dtype=torch.int32).to(DEVICE)
input_lengths = torch.tensor(emissions.shape[0]).unsqueeze(-1)
target_lengths = torch.tensor(targets.shape[0]).unsqueeze(-1)
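    # torchaudio's CTC forced aligner returns the most likely frame-level
    # token path; merge_repeats collapses runs of identical tokens into
    # per-token segments with start/end frames.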
path, _ = F.forced_align(
emissions.unsqueeze(0), targets.unsqueeze(0), input_lengths, target_lengths, blank=blank
)
path = path.squeeze().to("cpu").tolist()
segments = merge_repeats(path, {v: k for k, v in dictionary.items()})
return segments, stride


def main(args):
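    """Normalize and romanize the transcript, align it to the audio, cut one
    audio segment per transcript line with sox, and write a JSON-lines
    manifest into args.outdir."""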
assert not os.path.exists(
args.outdir
), f"Error: Output path exists already {args.outdir}"
transcripts = []
with open(args.text_filepath) as f:
transcripts = [line.strip() for line in f]
print("Read {} lines from {}".format(len(transcripts), args.text_filepath))
norm_transcripts = [text_normalize(line.strip(), args.lang) for line in transcripts]
tokens = get_uroman_tokens(norm_transcripts, args.uroman_path, args.lang)
model, dictionary = load_model_dict()
model = model.to(DEVICE)
if args.use_star:
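        # Register <star> in the dictionary and prepend it to every transcript
        # view so leading untranscribed audio is absorbed by the star span.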
dictionary["<star>"] = len(dictionary)
tokens = ["<star>"] + tokens
transcripts = ["<star>"] + transcripts
norm_transcripts = ["<star>"] + norm_transcripts
segments, stride = get_alignments(
args.audio_filepath,
tokens,
model,
dictionary,
args.use_star,
)
# Get spans of each line in input text file
spans = get_spans(tokens, segments)
os.makedirs(args.outdir)
    with open(f"{args.outdir}/manifest.json", "w") as f:
for i, t in enumerate(transcripts):
span = spans[i]
seg_start_idx = span[0].start
seg_end_idx = span[-1].end
output_file = f"{args.outdir}/segment{i}.flac"
audio_start_sec = seg_start_idx * stride / 1000
audio_end_sec = seg_end_idx * stride / 1000
tfm = sox.Transformer()
            tfm.trim(audio_start_sec, audio_end_sec)
tfm.build_file(args.audio_filepath, output_file)
sample = {
"audio_start_sec": audio_start_sec,
"audio_filepath": str(output_file),
"duration": audio_end_sec - audio_start_sec,
"text": t,
"normalized_text":norm_transcripts[i],
"uroman_tokens": tokens[i],
}
f.write(json.dumps(sample) + "\n")
return segments, stride


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Align and segment long audio files")
    parser.add_argument(
        "-a", "--audio_filepath", type=str, required=True, help="Path to the input audio file"
    )
    parser.add_argument(
        "-t", "--text_filepath", type=str, required=True, help="Path to the input text file"
    )
parser.add_argument(
"-l", "--lang", type=str, default="eng", help="ISO code of the language"
)
    parser.add_argument(
        "-u", "--uroman_path", type=str, required=True, help="Path to the uroman/bin directory"
    )
parser.add_argument(
"-s",
"--use_star",
action="store_true",
help="Use star at the start of transcript",
)
    parser.add_argument(
        "-o",
        "--outdir",
        type=str,
        required=True,
        help="Output directory for the segmented audio files and manifest",
    )
print("Using torch version:", torch.__version__)
print("Using torchaudio version:", torchaudio.__version__)
print("Using device: ", DEVICE)
args = parser.parse_args()
main(args)
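
# Example invocation (placeholder paths, shown for illustration only):
#   python align_and_segment.py \
#       -a /path/to/long_audio.wav \
#       -t /path/to/transcript.txt \
#       -l eng \
#       -u /path/to/uroman/bin \
#       -o /path/to/outdir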