# Provenance (text left over from the Hugging Face file page this was copied from):
# uploaded by mrfakename, "Sync from GitHub repo", commit 0bebc31 (verified), 21.8 kB.
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
# 2025 (authors: Yuekai Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
""" Example Usage
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
"""
import argparse
import json
import os
import time
from typing import Dict, List, Union
import datasets
import jieba
import tensorrt as trt
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from datasets import load_dataset
from f5_tts_trtllm import F5TTS
from huggingface_hub import hf_hub_download
from pypinyin import Style, lazy_pinyin
from tensorrt_llm._utils import trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session, TensorInfo
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from vocos import Vocos
# Seed torch's global RNG at import time so torch-based randomness in this
# script (presumably including the DiT sampler's initial noise) is reproducible
# across benchmark runs.
torch.manual_seed(0)
def get_args():
    """Parse the benchmark's command-line options.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="extract speech code")
    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
        help="huggingface dataset split name",
    )
    parser.add_argument("--output-dir", required=True, type=str, help="dir to save result")
    parser.add_argument(
        "--vocab-file",
        required=True,
        type=str,
        help="vocab file",
    )
    parser.add_argument(
        "--model-path",
        required=True,
        type=str,
        help="model path, to load text embedding",
    )
    parser.add_argument(
        "--tllm-model-dir",
        required=True,
        type=str,
        help="tllm model dir",
    )
    parser.add_argument(
        "--batch-size",
        required=True,
        type=int,
        help="batch size (per-device) for inference",
    )
    parser.add_argument("--num-workers", type=int, default=0, help="workers for dataloader")
    parser.add_argument("--prefetch", type=int, default=None, help="prefetch for dataloader")
    parser.add_argument(
        "--vocoder",
        default="vocos",
        type=str,
        help="vocoder name",
    )
    parser.add_argument(
        "--vocoder-trt-engine-path",
        default=None,
        type=str,
        help="vocoder trt engine path",
    )
    parser.add_argument("--enable-warmup", action="store_true")
    parser.add_argument("--remove-input-padding", action="store_true")
    parser.add_argument("--use-perf", action="store_true", help="use nvtx to record performance")
    parser.add_argument(
        "--backend-type",
        type=str,
        # The default must be one of `choices`: argparse never validates a
        # default against `choices`, and main() only builds a model for
        # "trt" or "pytorch".
        default="trt",
        choices=["trt", "pytorch"],
        help="backend type",
    )
    args = parser.parse_args()
    return args
def padded_mel_batch(ref_mels, max_seq_len):
    """Zero-pad mel spectrograms along the time axis and stack them into a batch.

    Args:
        ref_mels: sequence of 2-D tensors shaped (frames, n_mels), as produced by
            data_collator; frame counts may differ but n_mels must match.
        max_seq_len: number of frames every item is padded to.

    Returns:
        Tensor of shape (len(ref_mels), max_seq_len, n_mels).
    """
    # F.pad reads the pad tuple from the last dim backwards: (0, 0) leaves the
    # mel-bin dim untouched and (0, n) appends n zero frames to the time dim
    # (dim 0). A negative n (mel longer than max_seq_len) would crop instead.
    return torch.stack([F.pad(mel, (0, 0, 0, max_seq_len - mel.shape[0]), value=0) for mel in ref_mels])
def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
    """Collate raw dataset rows into padded DiT inputs.

    For every row the prompt audio is RMS-boosted (if quiet), resampled to 24 kHz
    and turned into a vocos mel spectrogram. The total (reference + target) mel
    length is estimated by scaling the reference frame count by the UTF-8 byte
    length ratio of target_text to prompt_text. prompt_text + target_text is
    tokenized with convert_char_to_pinyin / list_str_to_idx.

    Args:
        batch: list of dataset rows with keys "id", "prompt_text", "target_text"
            and "prompt_audio" (a dict with "array" and "sampling_rate").
        vocab_char_map: token -> index mapping from get_tokenizer.
        device: device used for mel extraction and for the returned token tensor.
        use_perf: push NVTX ranges around the collation work.

    Returns:
        dict with
          "ids": list of row ids,
          "ref_mel_batch": (B, max_len, 100) zero-padded reference mels,
          "ref_mel_len_batch": (B,) LongTensor of reference frame counts,
          "text_pad_sequence": (B, max_len) token ids shifted by +1, where 0 pads
              up to each row's estimated length and -1 pads beyond it,
          "estimated_reference_target_mel_len": list of int estimated lengths.
    """
    if use_perf:
        torch.cuda.nvtx.range_push("data_collator")
    target_sample_rate = 24000
    target_rms = 0.1
    ids = []
    ref_mel_list = []
    ref_mel_len_list = []
    estimated_reference_target_mel_len = []
    reference_target_texts_list = []
    for i, item in enumerate(batch):
        item_id, prompt_text, target_text = item["id"], item["prompt_text"], item["target_text"]
        ids.append(item_id)
        reference_target_texts_list.append(prompt_text + target_text)

        # Assumes "array" is a 1-D mono numpy array -- TODO confirm for every split.
        ref_audio = torch.from_numpy(item["prompt_audio"]["array"]).unsqueeze(0).float()
        ref_sr = item["prompt_audio"]["sampling_rate"]
        # Boost quiet prompts up to target_rms. All-zero prompts are left alone:
        # scaling by target_rms / 0 would fill the audio with NaN/inf.
        ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
        if 0 < ref_rms < target_rms:
            ref_audio = ref_audio * target_rms / ref_rms
        if ref_sr != target_sample_rate:
            ref_audio = torchaudio.transforms.Resample(ref_sr, target_sample_rate)(ref_audio)

        if use_perf:
            torch.cuda.nvtx.range_push(f"mel_spectrogram {i}")
        # Honour the caller's device instead of a hard-coded "cuda".
        ref_mel = mel_spectrogram(ref_audio, vocoder="vocos", device=device)
        if use_perf:
            torch.cuda.nvtx.range_pop()
        # (1, frames, 100) -> (frames, 100). squeeze(0) rather than squeeze() so
        # a one-frame prompt keeps its time axis.
        ref_mel = ref_mel.squeeze(0)
        ref_mel_len = ref_mel.shape[0]
        assert ref_mel.shape[1] == 100
        ref_mel_list.append(ref_mel)
        ref_mel_len_list.append(ref_mel_len)

        # Total length = reference frames scaled by the target/prompt byte ratio.
        prompt_byte_len = len(prompt_text.encode("utf-8"))
        target_byte_len = len(target_text.encode("utf-8"))
        estimated_reference_target_mel_len.append(int(ref_mel_len * (1 + target_byte_len / prompt_byte_len)))

    max_seq_len = max(estimated_reference_target_mel_len)
    ref_mel_batch = padded_mel_batch(ref_mel_list, max_seq_len)
    ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)

    pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
    text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map)
    for i, tokens in enumerate(text_pad_sequence):
        # Pad (or crop) each row to its own estimated mel length with -1, then
        # shift by +1 so real ids are >= 1 and this padding becomes 0.
        # WAR: 0 is reserved for padding token, hard coding in F5-TTS.
        padded = F.pad(tokens, (0, estimated_reference_target_mel_len[i] - len(tokens)), mode="constant", value=-1)
        text_pad_sequence[i] = padded + 1
    # Rows shorter than the longest estimate are padded with -1 up to max_seq_len.
    text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(device)
    text_pad_sequence = F.pad(
        text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
    )
    if use_perf:
        torch.cuda.nvtx.range_pop()
    return {
        "ids": ids,
        "ref_mel_batch": ref_mel_batch,
        "ref_mel_len_batch": ref_mel_len_batch,
        "text_pad_sequence": text_pad_sequence,
        "estimated_reference_target_mel_len": estimated_reference_target_mel_len,
    }
def init_distributed():
    """Bind this process to its GPU and join the NCCL process group.

    Rank information comes from the WORLD_SIZE, LOCAL_RANK and RANK environment
    variables (as exported by torchrun). The defaults are 1, 0 and 0, which
    describe a single-process launch.

    Returns:
        (world_size, local_rank, rank) as ints.
    """
    env = os.environ
    world_size = int(env.get("WORLD_SIZE", 1))
    local_rank = int(env.get("LOCAL_RANK", 0))
    rank = int(env.get("RANK", 0))

    print(f"Inference on multiple gpus, this gpu {local_rank}, rank {rank}, world_size {world_size}")

    # Select the GPU first; init_process_group is then called with only the
    # backend name (no init_method or device_id arguments).
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    return world_size, local_rank, rank
def get_tokenizer(vocab_file_path: str):
    """Load a vocab file into a token -> index map.

    Each line of the file is one token, and its 0-based line number is its index.
    Only the trailing newline is removed, so whitespace tokens such as a single
    space are preserved. If a token appears on several lines, the last line wins
    (and vocab_size is then smaller than the line count).

    Args:
        vocab_file_path: path to a UTF-8 text file with one token per line.

    Returns:
        (vocab_char_map, vocab_size) where vocab_size == len(vocab_char_map).
    """
    vocab_char_map = {}
    with open(vocab_file_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            # Strip exactly one trailing "\n" (not line[:-1]) so a final line
            # without a newline keeps its last character.
            token = line[:-1] if line.endswith("\n") else line
            vocab_char_map[token] = i
    vocab_size = len(vocab_char_map)
    return vocab_char_map, vocab_size
def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
    """Convert each text into a token list for the pinyin/char vocab.

    Each text is first normalized with a small character translation table and
    then segmented with jieba. Each segment is handled as follows:
      * pure-ASCII segments (UTF-8 byte length == char count) are kept character
        by character. A " " separator is inserted before a multi-character
        segment unless the previous token is already a space, colon or quote.
      * when ``polyphone`` is true and the segment's UTF-8 byte length is exactly
        3x its character count (typical for CJK text), the whole segment is
        passed to ``lazy_pinyin`` (presumably so word context can disambiguate
        polyphones). Each character in the ``is_chinese`` range is preceded by " ".
      * otherwise characters are handled one at a time: code points < 256 are
        kept, ``is_chinese`` characters become " " + their TONE3 pinyin, and
        anything else is kept as-is.

    Args:
        reference_target_texts_list: list of input strings.
        polyphone: enable the whole-segment pinyin path described above.

    Returns:
        list (one entry per input text) of token lists.
    """
    final_reference_target_texts_list = []
    custom_trans = str.maketrans(
        {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
    )  # add custom trans here, to address oov

    def is_chinese(c):
        # NOTE(review): U+3100..U+9FFF also spans some non-Han blocks (e.g.
        # Bopomofo, Hangul compatibility jamo); confirm this breadth is intended.
        return "\u3100" <= c <= "\u9fff"  # common chinese characters

    for text in reference_target_texts_list:
        char_list = []
        text = text.translate(custom_trans)
        for seg in jieba.cut(text):
            seg_byte_len = len(bytes(seg, "UTF-8"))
            if seg_byte_len == len(seg):  # if pure alphabets and symbols
                if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
                    char_list.append(" ")
                char_list.extend(seg)
            elif polyphone and seg_byte_len == 3 * len(seg):  # if pure east asian characters
                seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
                # NOTE(review): seg_ is indexed by character position, which
                # assumes lazy_pinyin returns exactly one item per character of
                # seg; confirm this holds for non-Han 3-byte characters (e.g.
                # full-width punctuation).
                for i, c in enumerate(seg):
                    if is_chinese(c):
                        char_list.append(" ")
                    char_list.append(seg_[i])
            else:  # if mixed characters, alphabets and symbols
                for c in seg:
                    if ord(c) < 256:
                        char_list.extend(c)
                    elif is_chinese(c):
                        char_list.append(" ")
                        char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
                    else:
                        char_list.append(c)
        final_reference_target_texts_list.append(char_list)
    return final_reference_target_texts_list
def list_str_to_idx(
    text: Union[List[str], List[List[str]]],
    vocab_char_map: Dict[str, int],  # {char: idx}
    padding_value=-1,
):
    """Map each token sequence to a 1-D tensor of vocab indices.

    Tokens missing from ``vocab_char_map`` map to index 0.

    Args:
        text: batch of token sequences (a plain string is iterated per character).
        vocab_char_map: token -> index mapping.
        padding_value: unused; kept for call compatibility. Sequences are returned
            unpadded, and callers pad them themselves.

    Returns:
        list of ``torch.long`` tensors, one per input sequence.
    """
    # Explicit dtype: torch.tensor([]) would otherwise be float32, so an empty
    # sequence would not be an index tensor.
    return [torch.tensor([vocab_char_map.get(c, 0) for c in t], dtype=torch.long) for t in text]
def load_vocoder(
    vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None
):
    """Load the vocoder that turns generated mels into waveforms.

    Args:
        vocoder_name: only "vocos" is supported.
        is_local: load the PyTorch Vocos checkpoint from ``local_path`` instead of
            downloading "charactr/vocos-mel-24khz" from the Hugging Face Hub.
        local_path: directory containing config.yaml and pytorch_model.bin.
        device: device for the PyTorch Vocos model (ignored on the TensorRT path).
        hf_cache_dir: cache directory passed to hf_hub_download.
        vocoder_trt_engine_path: if set, return a VocosTensorRT wrapper around
            this engine instead of the PyTorch model.

    Returns:
        A VocosTensorRT instance, or an eval-mode Vocos module on ``device``.

    Raises:
        NotImplementedError: for "bigvgan".
        ValueError: for any other unsupported ``vocoder_name``.
    """
    if vocoder_name == "bigvgan":
        raise NotImplementedError("BigVGAN is not implemented yet")
    if vocoder_name != "vocos":
        raise ValueError(f"Unsupported vocoder: {vocoder_name!r}")

    if vocoder_trt_engine_path is not None:
        return VocosTensorRT(engine_path=vocoder_trt_engine_path)

    if is_local:
        print(f"Load vocos from local path {local_path}")
        config_path = f"{local_path}/config.yaml"
        model_path = f"{local_path}/pytorch_model.bin"
    else:
        print("Download Vocos from huggingface charactr/vocos-mel-24khz")
        repo_id = "charactr/vocos-mel-24khz"
        config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
        model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
    vocoder = Vocos.from_hparams(config_path)
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    from vocos.feature_extractors import EncodecFeatures

    # Copy the Encodec feature-extractor weights from the freshly built model
    # into state_dict before load_state_dict (presumably because the checkpoint
    # does not contain them).
    if isinstance(vocoder.feature_extractor, EncodecFeatures):
        encodec_parameters = {
            "feature_extractor.encodec." + key: value
            for key, value in vocoder.feature_extractor.encodec.state_dict().items()
        }
        state_dict.update(encodec_parameters)
    vocoder.load_state_dict(state_dict)
    return vocoder.eval().to(device)
def mel_spectrogram(waveform, vocoder="vocos", device="cuda"):
    """Compute the log-mel spectrogram used as DiT conditioning.

    Only the vocos front end is implemented: 24 kHz sample rate, n_fft=1024,
    win_length=1024, hop_length=256, 100 mel bins, magnitude (power=1)
    spectrogram, then natural log of values clamped to >= 1e-5.

    Args:
        waveform: audio tensor, assumed (batch/channels, samples) at 24 kHz --
            TODO confirm for non-mono inputs.
        vocoder: front-end configuration to use; only "vocos" is supported.
        device: device the transform and the waveform are moved to.

    Returns:
        The log-mel tensor with dims 1 and 2 swapped, i.e. (batch, frames, 100)
        for a (batch, samples) input.

    Raises:
        NotImplementedError: for any vocoder other than "vocos".
    """
    if vocoder != "vocos":
        raise NotImplementedError(f"mel_spectrogram only supports vocoder='vocos', got {vocoder!r}")
    mel_stft = torchaudio.transforms.MelSpectrogram(
        sample_rate=24000,
        n_fft=1024,
        win_length=1024,
        hop_length=256,
        n_mels=100,
        power=1,
        center=True,
        normalized=False,
        norm=None,
    ).to(device)
    mel = mel_stft(waveform.to(device))
    mel = mel.clamp(min=1e-5).log()
    return mel.transpose(1, 2)
class VocosTensorRT:
    """Vocos vocoder backed by a serialized TensorRT engine run through a TensorRT-LLM Session.

    ``decode`` feeds the engine an input named "mel" and returns its output
    named "waveform", so the engine must expose tensors with those names.
    """

    def __init__(self, engine_path="./vocos_vocoder.plan", stream=None):
        """Deserialize the engine stored at ``engine_path``.

        Args:
            engine_path: path to the serialized engine file.
            stream: raw CUDA stream handle to run on; defaults to the handle of
                the current torch CUDA stream at construction time.
        """
        TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
        trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
        logger.info(f"Loading vocoder engine from {engine_path}")
        self.engine_path = engine_path
        with open(engine_path, "rb") as f:
            engine_buffer = f.read()
        self.session = Session.from_serialized_engine(engine_buffer)
        self.stream = stream if stream is not None else torch.cuda.current_stream().cuda_stream

    def decode(self, mels):
        """Run the engine on a mel batch and return its "waveform" output.

        Args:
            mels: CUDA tensor of mel features. In this script callers pass
                (batch, 100, frames).

        Returns:
            The "waveform" output tensor, allocated on ``mels.device``.

        Raises:
            RuntimeError: if the TensorRT session reports a failed run.
        """
        # The input is declared to TensorRT as FLOAT below, so the buffer must be
        # contiguous float32.
        mels = mels.to(torch.float32).contiguous()
        inputs = {"mel": mels}
        output_info = self.session.infer_shapes([TensorInfo("mel", trt.DataType.FLOAT, mels.shape)])
        # Allocate outputs next to the input instead of on the default "cuda" device.
        outputs = {
            t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device=mels.device)
            for t in output_info
        }
        ok = self.session.run(inputs, outputs, self.stream)
        # An explicit raise (not assert) so the check survives `python -O`.
        if not ok:
            raise RuntimeError("Runtime execution failed for vocoder session")
        return outputs["waveform"]
def _sample_batch(model, batch, args, device, use_perf=False):
    """Run the DiT backend selected by ``args.backend_type`` on one collated batch.

    Shared by the warm-up pass and the timed loop in ``main`` so both feed the
    model identical inputs.

    Returns:
        (generated, cost_time, ref_mel_lens, total_mel_lens): the backend's
        generated mels, DiT time in seconds, reference mel lengths on ``device``,
        and total (reference + target) mel lengths. The total lengths are a list
        of int for "trt" and a tensor on ``device`` for "pytorch".

    Raises:
        ValueError: if ``args.backend_type`` is neither "trt" nor "pytorch".
    """
    if use_perf:
        torch.cuda.nvtx.range_push("data sample")
    ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
    text_pad_seq = batch["text_pad_sequence"].to(device)
    total_mel_lens = batch["estimated_reference_target_mel_len"]
    if use_perf:
        torch.cuda.nvtx.range_pop()

    if args.backend_type == "trt":
        generated, cost_time = model.sample(
            text_pad_seq,
            ref_mels,
            ref_mel_lens,
            total_mel_lens,
            remove_input_padding=args.remove_input_padding,
            use_perf=use_perf,
        )
    elif args.backend_type == "pytorch":
        total_mel_lens = torch.tensor(total_mel_lens, device=device)
        with torch.inference_mode():
            start_time = time.time()
            # data_collator shifted ids by +1 (0 = in-row padding, -1 = batch
            # padding). Undo the shift and use -1 for all padding.
            text_pad_seq -= 1
            text_pad_seq[text_pad_seq == -2] = -1
            generated, _ = model.sample(
                cond=ref_mels,
                text=text_pad_seq,
                duration=total_mel_lens,
                lens=ref_mel_lens,
                steps=16,
                cfg_strength=2.0,
                sway_sampling_coef=-1,
            )
            cost_time = time.time() - start_time
    else:
        raise ValueError(f"Unsupported backend type: {args.backend_type}")
    return generated, cost_time, ref_mel_lens, total_mel_lens


def main():
    """Benchmark F5-TTS (DiT + vocoder) on a yuekai/seed_tts split and report RTF.

    Each rank writes one ``<id>.wav`` per utterance into ``--output-dir``. Rank 0
    also prints timing statistics and writes them to ``rtf.txt``.
    """
    args = get_args()
    os.makedirs(args.output_dir, exist_ok=True)

    if not torch.cuda.is_available():
        raise RuntimeError("This benchmark requires a CUDA device")
    world_size, local_rank, rank = init_distributed()
    device = torch.device(f"cuda:{local_rank}")

    vocab_char_map, vocab_size = get_tokenizer(args.vocab_file)

    tllm_model_dir = args.tllm_model_dir
    config_file = os.path.join(tllm_model_dir, "config.json")
    with open(config_file) as f:
        config = json.load(f)
    if args.backend_type == "trt":
        model = F5TTS(
            config, debug_mode=False, tllm_model_dir=tllm_model_dir, model_path=args.model_path, vocab_size=vocab_size
        )
    elif args.backend_type == "pytorch":
        import sys

        sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")

        from f5_tts.infer.utils_infer import load_model
        from f5_tts.model import DiT

        F5TTS_model_cfg = dict(
            dim=1024,
            depth=22,
            heads=16,
            ff_mult=2,
            text_dim=512,
            conv_layers=4,
            pe_attn_head=1,
            text_mask_padding=False,
        )
        model = load_model(DiT, F5TTS_model_cfg, args.model_path)
    else:
        # Fail fast instead of hitting an unbound `model` later.
        raise ValueError(f"Unsupported backend type: {args.backend_type}")

    vocoder = load_vocoder(
        vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path
    )

    dataset = load_dataset(
        "yuekai/seed_tts",
        split=args.split_name,
        trust_remote_code=True,
    )

    def add_estimated_duration(example):
        prompt_audio_len = example["prompt_audio"]["array"].shape[0]
        scale_factor = 1 + len(example["target_text"]) / len(example["prompt_text"])
        estimated_duration = prompt_audio_len * scale_factor
        example["estimated_duration"] = estimated_duration / example["prompt_audio"]["sampling_rate"]
        return example

    # Longest-first order, presumably so batches group utterances of similar
    # length and need less padding.
    dataset = dataset.map(add_estimated_duration)
    dataset = dataset.sort("estimated_duration", reverse=True)
    if args.use_perf:
        # Profile one fixed utterance (index 24 of the sorted split, seq_len 719)
        # repeated 8 times instead of the whole split.
        dataset_list_short = [dataset.select([24]) for _ in range(8)]
        dataset = datasets.concatenate_datasets(dataset_list_short)

    if world_size > 1:
        # DistributedSampler shuffles by default, which would undo the
        # duration sort above; keep the sorted order.
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
    else:
        # No sampler + shuffle=False iterates in dataset order.
        sampler = None

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch,
        collate_fn=lambda x: data_collator(x, vocab_char_map, use_perf=args.use_perf),
    )

    total_steps = len(dataset)

    if args.enable_warmup:
        # One untimed pass over the whole dataloader through the same code path
        # as the benchmark.
        for batch in dataloader:
            _sample_batch(model, batch, args, device, use_perf=False)

    if rank == 0:
        progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")

    target_rms = 0.1
    target_sample_rate = 24_000
    decoding_time = 0
    vocoder_time = 0
    total_duration = 0
    if args.use_perf:
        torch.cuda.cudart().cudaProfilerStart()
    total_decoding_time = time.time()
    for batch in dataloader:
        generated, cost_time, ref_mel_lens, total_mel_lens = _sample_batch(
            model, batch, args, device, use_perf=args.use_perf
        )
        decoding_time += cost_time

        vocoder_start_time = time.time()
        for i, gen in enumerate(generated):
            # Drop the reference prefix; keep only the newly generated frames.
            gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
            gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
            if args.vocoder == "vocos":
                if args.use_perf:
                    torch.cuda.nvtx.range_push("vocoder decode")
                generated_wave = vocoder.decode(gen_mel_spec).cpu()
                if args.use_perf:
                    torch.cuda.nvtx.range_pop()
            else:
                generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()

            # Boost quiet outputs up to target_rms; leave silent ones alone
            # (scaling by target_rms / 0 would produce NaN/inf).
            rms = torch.sqrt(torch.mean(torch.square(generated_wave)))
            if 0 < rms < target_rms:
                generated_wave = generated_wave * target_rms / rms
            utt = batch["ids"][i]
            torchaudio.save(
                f"{args.output_dir}/{utt}.wav",
                generated_wave,
                target_sample_rate,
            )
            total_duration += generated_wave.shape[1] / target_sample_rate
        vocoder_time += time.time() - vocoder_start_time
        if rank == 0:
            progress_bar.update(world_size * len(batch["ids"]))
    total_decoding_time = time.time() - total_decoding_time
    if args.use_perf:
        # Close the capture range opened by cudaProfilerStart above.
        torch.cuda.cudart().cudaProfilerStop()

    if rank == 0:
        progress_bar.close()
        # Avoid ZeroDivisionError when no audio was produced.
        rtf = total_decoding_time / total_duration if total_duration > 0 else float("nan")
        s = f"RTF: {rtf:.4f}\n"
        s += f"total_duration: {total_duration:.3f} seconds\n"
        s += f"({total_duration / 3600:.2f} hours)\n"
        s += f"DiT time: {decoding_time:.3f} seconds ({decoding_time / 3600:.2f} hours)\n"
        s += f"Vocoder time: {vocoder_time:.3f} seconds ({vocoder_time / 3600:.2f} hours)\n"
        s += f"total decoding time: {total_decoding_time:.3f} seconds ({total_decoding_time / 3600:.2f} hours)\n"
        s += f"batch size: {args.batch_size}\n"
        print(s)

        with open(f"{args.output_dir}/rtf.txt", "w") as f:
            f.write(s)

    dist.barrier()
    dist.destroy_process_group()
# Script entry point; intended to be launched with torchrun, as in the module
# docstring's example usage.
if __name__ == "__main__":
    main()