Spaces:
Running
on
Zero
Running
on
Zero
Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
- src/f5_tts/runtime/triton_trtllm/Dockerfile.server +3 -0
- src/f5_tts/runtime/triton_trtllm/README.md +46 -0
- src/f5_tts/runtime/triton_trtllm/client_grpc.py +470 -0
- src/f5_tts/runtime/triton_trtllm/client_http.py +142 -0
- src/f5_tts/runtime/triton_trtllm/docker-compose.yml +20 -0
- src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py +431 -0
- src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py +275 -0
- src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt +81 -0
- src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/1/.gitkeep +0 -0
- src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/config.pbtxt +32 -0
- src/f5_tts/runtime/triton_trtllm/patch/__init__.py +196 -0
- src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py +225 -0
- src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py +410 -0
- src/f5_tts/runtime/triton_trtllm/run.sh +70 -0
- src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py +247 -0
- src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py +359 -0
- src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py +137 -0
- src/f5_tts/runtime/triton_trtllm/scripts/export_vocos_trt.sh +43 -0
- src/f5_tts/runtime/triton_trtllm/scripts/fill_template.py +36 -0
src/f5_tts/runtime/triton_trtllm/Dockerfile.server
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvcr.io/nvidia/tritonserver:24.12-py3
|
2 |
+
RUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 jieba pypinyin librosa vocos
|
3 |
+
WORKDIR /workspace
|
src/f5_tts/runtime/triton_trtllm/README.md
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Triton Inference Serving Best Practice for F5 TTS
|
2 |
+
|
3 |
+
### Quick Start
|
4 |
+
Directly launch the service using docker compose.
|
5 |
+
```sh
|
6 |
+
# TODO: support F5TTS_v1_Base
|
7 |
+
MODEL=F5TTS_Base docker compose up
|
8 |
+
```
|
9 |
+
|
10 |
+
### Build Image
|
11 |
+
Build the docker image from scratch.
|
12 |
+
```sh
|
13 |
+
docker build . -f Dockerfile.server -t soar97/triton-f5-tts:24.12
|
14 |
+
```
|
15 |
+
|
16 |
+
### Create Docker Container
|
17 |
+
```sh
|
18 |
+
your_mount_dir=/mnt:/mnt
|
19 |
+
docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-f5-tts:24.12
|
20 |
+
```
|
21 |
+
|
22 |
+
### Export Models to TensorRT-LLM and Launch Server
|
23 |
+
Inside docker container, we would follow the official guide of TensorRT-LLM to build qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper).
|
24 |
+
|
25 |
+
```sh
|
26 |
+
bash run.sh 0 4 F5TTS_Base
|
27 |
+
```
|
28 |
+
### HTTP Client
|
29 |
+
```sh
|
30 |
+
python3 client_http.py
|
31 |
+
```
|
32 |
+
### Benchmark using Dataset
|
33 |
+
```sh
|
34 |
+
num_task=2
|
35 |
+
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
|
36 |
+
```
|
37 |
+
|
38 |
+
### Benchmark Results
|
39 |
+
Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs.
|
40 |
+
|
41 |
+
| Model | Concurrency | Avg Latency | RTF |
|
42 |
+
|-------|-------------|-----------------|--|
|
43 |
+
| F5-TTS Base (Vocos) | 1 | 253 ms | 0.0394|
|
44 |
+
|
45 |
+
### Credits
|
46 |
+
1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
|
src/f5_tts/runtime/triton_trtllm/client_grpc.py
ADDED
@@ -0,0 +1,470 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
3 |
+
# 2023 Nvidia (authors: Yuekai Zhang)
|
4 |
+
# 2023 Recurrent.ai (authors: Songtao Shi)
|
5 |
+
# See LICENSE for clarification regarding multiple authors
|
6 |
+
#
|
7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
8 |
+
# you may not use this file except in compliance with the License.
|
9 |
+
# You may obtain a copy of the License at
|
10 |
+
#
|
11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12 |
+
#
|
13 |
+
# Unless required by applicable law or agreed to in writing, software
|
14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
16 |
+
# See the License for the specific language governing permissions and
|
17 |
+
# limitations under the License.
|
18 |
+
"""
|
19 |
+
This script supports to load dataset from huggingface and sends it to the server
|
20 |
+
for decoding, in parallel.
|
21 |
+
|
22 |
+
Usage:
|
23 |
+
num_task=2
|
24 |
+
|
25 |
+
# For offline F5-TTS
|
26 |
+
python3 client_grpc.py \
|
27 |
+
--server-addr localhost \
|
28 |
+
--model-name f5_tts \
|
29 |
+
--num-tasks $num_task \
|
30 |
+
--huggingface-dataset yuekai/seed_tts \
|
31 |
+
--split-name test_zh \
|
32 |
+
--log-dir ./log_concurrent_tasks_${num_task}
|
33 |
+
|
34 |
+
# For offline Spark-TTS-0.5B
|
35 |
+
python3 client_grpc.py \
|
36 |
+
--server-addr localhost \
|
37 |
+
--model-name spark_tts \
|
38 |
+
--num-tasks $num_task \
|
39 |
+
--huggingface-dataset yuekai/seed_tts \
|
40 |
+
--split-name wenetspeech4tts \
|
41 |
+
--log-dir ./log_concurrent_tasks_${num_task}
|
42 |
+
"""
|
43 |
+
|
44 |
+
import argparse
|
45 |
+
import asyncio
|
46 |
+
import json
|
47 |
+
|
48 |
+
import os
|
49 |
+
import time
|
50 |
+
import types
|
51 |
+
from pathlib import Path
|
52 |
+
|
53 |
+
import numpy as np
|
54 |
+
import soundfile as sf
|
55 |
+
import tritonclient
|
56 |
+
import tritonclient.grpc.aio as grpcclient
|
57 |
+
from tritonclient.utils import np_to_triton_dtype
|
58 |
+
|
59 |
+
|
60 |
+
def write_triton_stats(stats, summary_file):
|
61 |
+
with open(summary_file, "w") as summary_f:
|
62 |
+
model_stats = stats["model_stats"]
|
63 |
+
# write a note, the log is from triton_client.get_inference_statistics(), to better human readability
|
64 |
+
summary_f.write(
|
65 |
+
"The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n"
|
66 |
+
)
|
67 |
+
summary_f.write("To learn more about the log, please refer to: \n")
|
68 |
+
summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
|
69 |
+
summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
|
70 |
+
summary_f.write(
|
71 |
+
"To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
|
72 |
+
)
|
73 |
+
summary_f.write(
|
74 |
+
"However, there is a trade-off between the increased queue time and the increased batch size. \n"
|
75 |
+
)
|
76 |
+
summary_f.write(
|
77 |
+
"You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
|
78 |
+
)
|
79 |
+
summary_f.write(
|
80 |
+
"See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
|
81 |
+
)
|
82 |
+
for model_state in model_stats:
|
83 |
+
if "last_inference" not in model_state:
|
84 |
+
continue
|
85 |
+
summary_f.write(f"model name is {model_state['name']} \n")
|
86 |
+
model_inference_stats = model_state["inference_stats"]
|
87 |
+
total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
|
88 |
+
total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
|
89 |
+
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
|
90 |
+
total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
|
91 |
+
summary_f.write(
|
92 |
+
f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
|
93 |
+
)
|
94 |
+
model_batch_stats = model_state["batch_stats"]
|
95 |
+
for batch in model_batch_stats:
|
96 |
+
batch_size = int(batch["batch_size"])
|
97 |
+
compute_input = batch["compute_input"]
|
98 |
+
compute_output = batch["compute_output"]
|
99 |
+
compute_infer = batch["compute_infer"]
|
100 |
+
batch_count = int(compute_infer["count"])
|
101 |
+
assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
|
102 |
+
compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
|
103 |
+
compute_input_time_ms = int(compute_input["ns"]) / 1e6
|
104 |
+
compute_output_time_ms = int(compute_output["ns"]) / 1e6
|
105 |
+
summary_f.write(
|
106 |
+
f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" # noqa
|
107 |
+
)
|
108 |
+
summary_f.write(
|
109 |
+
f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " # noqa
|
110 |
+
)
|
111 |
+
summary_f.write(
|
112 |
+
f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" # noqa
|
113 |
+
)
|
114 |
+
|
115 |
+
|
116 |
+
def get_args():
|
117 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
118 |
+
|
119 |
+
parser.add_argument(
|
120 |
+
"--server-addr",
|
121 |
+
type=str,
|
122 |
+
default="localhost",
|
123 |
+
help="Address of the server",
|
124 |
+
)
|
125 |
+
|
126 |
+
parser.add_argument(
|
127 |
+
"--server-port",
|
128 |
+
type=int,
|
129 |
+
default=8001,
|
130 |
+
help="Grpc port of the triton server, default is 8001",
|
131 |
+
)
|
132 |
+
|
133 |
+
parser.add_argument(
|
134 |
+
"--reference-audio",
|
135 |
+
type=str,
|
136 |
+
default=None,
|
137 |
+
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
|
138 |
+
)
|
139 |
+
|
140 |
+
parser.add_argument(
|
141 |
+
"--reference-text",
|
142 |
+
type=str,
|
143 |
+
default="",
|
144 |
+
help="",
|
145 |
+
)
|
146 |
+
|
147 |
+
parser.add_argument(
|
148 |
+
"--target-text",
|
149 |
+
type=str,
|
150 |
+
default="",
|
151 |
+
help="",
|
152 |
+
)
|
153 |
+
|
154 |
+
parser.add_argument(
|
155 |
+
"--huggingface-dataset",
|
156 |
+
type=str,
|
157 |
+
default="yuekai/seed_tts",
|
158 |
+
help="dataset name in huggingface dataset hub",
|
159 |
+
)
|
160 |
+
|
161 |
+
parser.add_argument(
|
162 |
+
"--split-name",
|
163 |
+
type=str,
|
164 |
+
default="wenetspeech4tts",
|
165 |
+
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
|
166 |
+
help="dataset split name, default is 'test'",
|
167 |
+
)
|
168 |
+
|
169 |
+
parser.add_argument(
|
170 |
+
"--manifest-path",
|
171 |
+
type=str,
|
172 |
+
default=None,
|
173 |
+
help="Path to the manifest dir which includes wav.scp trans.txt files.",
|
174 |
+
)
|
175 |
+
|
176 |
+
parser.add_argument(
|
177 |
+
"--model-name",
|
178 |
+
type=str,
|
179 |
+
default="f5_tts",
|
180 |
+
choices=["f5_tts", "spark_tts"],
|
181 |
+
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
|
182 |
+
)
|
183 |
+
|
184 |
+
parser.add_argument(
|
185 |
+
"--num-tasks",
|
186 |
+
type=int,
|
187 |
+
default=1,
|
188 |
+
help="Number of concurrent tasks for sending",
|
189 |
+
)
|
190 |
+
|
191 |
+
parser.add_argument(
|
192 |
+
"--log-interval",
|
193 |
+
type=int,
|
194 |
+
default=5,
|
195 |
+
help="Controls how frequently we print the log.",
|
196 |
+
)
|
197 |
+
|
198 |
+
parser.add_argument(
|
199 |
+
"--compute-wer",
|
200 |
+
action="store_true",
|
201 |
+
default=False,
|
202 |
+
help="""True to compute WER.
|
203 |
+
""",
|
204 |
+
)
|
205 |
+
|
206 |
+
parser.add_argument(
|
207 |
+
"--log-dir",
|
208 |
+
type=str,
|
209 |
+
required=False,
|
210 |
+
default="./tmp",
|
211 |
+
help="log directory",
|
212 |
+
)
|
213 |
+
|
214 |
+
parser.add_argument(
|
215 |
+
"--batch-size",
|
216 |
+
type=int,
|
217 |
+
default=1,
|
218 |
+
help="Inference batch_size per request for offline mode.",
|
219 |
+
)
|
220 |
+
|
221 |
+
return parser.parse_args()
|
222 |
+
|
223 |
+
|
224 |
+
def load_audio(wav_path, target_sample_rate=16000):
|
225 |
+
assert target_sample_rate == 16000, "hard coding in server"
|
226 |
+
if isinstance(wav_path, dict):
|
227 |
+
waveform = wav_path["array"]
|
228 |
+
sample_rate = wav_path["sampling_rate"]
|
229 |
+
else:
|
230 |
+
waveform, sample_rate = sf.read(wav_path)
|
231 |
+
if sample_rate != target_sample_rate:
|
232 |
+
from scipy.signal import resample
|
233 |
+
|
234 |
+
num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
|
235 |
+
waveform = resample(waveform, num_samples)
|
236 |
+
return waveform, target_sample_rate
|
237 |
+
|
238 |
+
|
239 |
+
async def send(
|
240 |
+
manifest_item_list: list,
|
241 |
+
name: str,
|
242 |
+
triton_client: tritonclient.grpc.aio.InferenceServerClient,
|
243 |
+
protocol_client: types.ModuleType,
|
244 |
+
log_interval: int,
|
245 |
+
model_name: str,
|
246 |
+
padding_duration: int = None,
|
247 |
+
audio_save_dir: str = "./",
|
248 |
+
save_sample_rate: int = 16000,
|
249 |
+
):
|
250 |
+
total_duration = 0.0
|
251 |
+
latency_data = []
|
252 |
+
task_id = int(name[5:])
|
253 |
+
|
254 |
+
print(f"manifest_item_list: {manifest_item_list}")
|
255 |
+
for i, item in enumerate(manifest_item_list):
|
256 |
+
if i % log_interval == 0:
|
257 |
+
print(f"{name}: {i}/{len(manifest_item_list)}")
|
258 |
+
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
|
259 |
+
duration = len(waveform) / sample_rate
|
260 |
+
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
261 |
+
|
262 |
+
reference_text, target_text = item["reference_text"], item["target_text"]
|
263 |
+
|
264 |
+
estimated_target_duration = duration / len(reference_text) * len(target_text)
|
265 |
+
|
266 |
+
if padding_duration:
|
267 |
+
# padding to nearset 10 seconds
|
268 |
+
samples = np.zeros(
|
269 |
+
(
|
270 |
+
1,
|
271 |
+
padding_duration
|
272 |
+
* sample_rate
|
273 |
+
* ((int(estimated_target_duration + duration) // padding_duration) + 1),
|
274 |
+
),
|
275 |
+
dtype=np.float32,
|
276 |
+
)
|
277 |
+
|
278 |
+
samples[0, : len(waveform)] = waveform
|
279 |
+
else:
|
280 |
+
samples = waveform
|
281 |
+
|
282 |
+
samples = samples.reshape(1, -1).astype(np.float32)
|
283 |
+
|
284 |
+
inputs = [
|
285 |
+
protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
|
286 |
+
protocol_client.InferInput("reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)),
|
287 |
+
protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
|
288 |
+
protocol_client.InferInput("target_text", [1, 1], "BYTES"),
|
289 |
+
]
|
290 |
+
inputs[0].set_data_from_numpy(samples)
|
291 |
+
inputs[1].set_data_from_numpy(lengths)
|
292 |
+
|
293 |
+
input_data_numpy = np.array([reference_text], dtype=object)
|
294 |
+
input_data_numpy = input_data_numpy.reshape((1, 1))
|
295 |
+
inputs[2].set_data_from_numpy(input_data_numpy)
|
296 |
+
|
297 |
+
input_data_numpy = np.array([target_text], dtype=object)
|
298 |
+
input_data_numpy = input_data_numpy.reshape((1, 1))
|
299 |
+
inputs[3].set_data_from_numpy(input_data_numpy)
|
300 |
+
|
301 |
+
outputs = [protocol_client.InferRequestedOutput("waveform")]
|
302 |
+
|
303 |
+
sequence_id = 100000000 + i + task_id * 10
|
304 |
+
start = time.time()
|
305 |
+
response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)
|
306 |
+
|
307 |
+
audio = response.as_numpy("waveform").reshape(-1)
|
308 |
+
|
309 |
+
end = time.time() - start
|
310 |
+
|
311 |
+
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
|
312 |
+
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
|
313 |
+
|
314 |
+
latency_data.append((end, estimated_target_duration))
|
315 |
+
total_duration += estimated_target_duration
|
316 |
+
|
317 |
+
return total_duration, latency_data
|
318 |
+
|
319 |
+
|
320 |
+
def load_manifests(manifest_path):
|
321 |
+
with open(manifest_path, "r") as f:
|
322 |
+
manifest_list = []
|
323 |
+
for line in f:
|
324 |
+
assert len(line.strip().split("|")) == 4
|
325 |
+
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
|
326 |
+
utt = Path(utt).stem
|
327 |
+
# gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
|
328 |
+
if not os.path.isabs(prompt_wav):
|
329 |
+
prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
|
330 |
+
manifest_list.append(
|
331 |
+
{
|
332 |
+
"audio_filepath": prompt_wav,
|
333 |
+
"reference_text": prompt_text,
|
334 |
+
"target_text": gt_text,
|
335 |
+
"target_audio_path": utt,
|
336 |
+
}
|
337 |
+
)
|
338 |
+
return manifest_list
|
339 |
+
|
340 |
+
|
341 |
+
def split_data(data, k):
|
342 |
+
n = len(data)
|
343 |
+
if n < k:
|
344 |
+
print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
|
345 |
+
k = n
|
346 |
+
|
347 |
+
quotient = n // k
|
348 |
+
remainder = n % k
|
349 |
+
|
350 |
+
result = []
|
351 |
+
start = 0
|
352 |
+
for i in range(k):
|
353 |
+
if i < remainder:
|
354 |
+
end = start + quotient + 1
|
355 |
+
else:
|
356 |
+
end = start + quotient
|
357 |
+
|
358 |
+
result.append(data[start:end])
|
359 |
+
start = end
|
360 |
+
|
361 |
+
return result
|
362 |
+
|
363 |
+
|
364 |
+
async def main():
|
365 |
+
args = get_args()
|
366 |
+
url = f"{args.server_addr}:{args.server_port}"
|
367 |
+
|
368 |
+
triton_client = grpcclient.InferenceServerClient(url=url, verbose=False)
|
369 |
+
protocol_client = grpcclient
|
370 |
+
|
371 |
+
if args.reference_audio:
|
372 |
+
args.num_tasks = 1
|
373 |
+
args.log_interval = 1
|
374 |
+
manifest_item_list = [
|
375 |
+
{
|
376 |
+
"reference_text": args.reference_text,
|
377 |
+
"target_text": args.target_text,
|
378 |
+
"audio_filepath": args.reference_audio,
|
379 |
+
"target_audio_path": "test",
|
380 |
+
}
|
381 |
+
]
|
382 |
+
elif args.huggingface_dataset:
|
383 |
+
import datasets
|
384 |
+
|
385 |
+
dataset = datasets.load_dataset(
|
386 |
+
args.huggingface_dataset,
|
387 |
+
split=args.split_name,
|
388 |
+
trust_remote_code=True,
|
389 |
+
)
|
390 |
+
manifest_item_list = []
|
391 |
+
for i in range(len(dataset)):
|
392 |
+
manifest_item_list.append(
|
393 |
+
{
|
394 |
+
"audio_filepath": dataset[i]["prompt_audio"],
|
395 |
+
"reference_text": dataset[i]["prompt_text"],
|
396 |
+
"target_audio_path": dataset[i]["id"],
|
397 |
+
"target_text": dataset[i]["target_text"],
|
398 |
+
}
|
399 |
+
)
|
400 |
+
else:
|
401 |
+
manifest_item_list = load_manifests(args.manifest_path)
|
402 |
+
|
403 |
+
args.num_tasks = min(args.num_tasks, len(manifest_item_list))
|
404 |
+
manifest_item_list = split_data(manifest_item_list, args.num_tasks)
|
405 |
+
|
406 |
+
os.makedirs(args.log_dir, exist_ok=True)
|
407 |
+
tasks = []
|
408 |
+
start_time = time.time()
|
409 |
+
for i in range(args.num_tasks):
|
410 |
+
task = asyncio.create_task(
|
411 |
+
send(
|
412 |
+
manifest_item_list[i],
|
413 |
+
name=f"task-{i}",
|
414 |
+
triton_client=triton_client,
|
415 |
+
protocol_client=protocol_client,
|
416 |
+
log_interval=args.log_interval,
|
417 |
+
model_name=args.model_name,
|
418 |
+
audio_save_dir=args.log_dir,
|
419 |
+
padding_duration=1,
|
420 |
+
save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
|
421 |
+
)
|
422 |
+
)
|
423 |
+
tasks.append(task)
|
424 |
+
|
425 |
+
ans_list = await asyncio.gather(*tasks)
|
426 |
+
|
427 |
+
end_time = time.time()
|
428 |
+
elapsed = end_time - start_time
|
429 |
+
|
430 |
+
total_duration = 0.0
|
431 |
+
latency_data = []
|
432 |
+
for ans in ans_list:
|
433 |
+
total_duration += ans[0]
|
434 |
+
latency_data += ans[1]
|
435 |
+
|
436 |
+
rtf = elapsed / total_duration
|
437 |
+
|
438 |
+
s = f"RTF: {rtf:.4f}\n"
|
439 |
+
s += f"total_duration: {total_duration:.3f} seconds\n"
|
440 |
+
s += f"({total_duration / 3600:.2f} hours)\n"
|
441 |
+
s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
|
442 |
+
|
443 |
+
latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
|
444 |
+
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
|
445 |
+
latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
|
446 |
+
s += f"latency_variance: {latency_variance:.2f}\n"
|
447 |
+
s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
|
448 |
+
s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
|
449 |
+
s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
|
450 |
+
s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
|
451 |
+
s += f"average_latency_ms: {latency_ms:.2f}\n"
|
452 |
+
|
453 |
+
print(s)
|
454 |
+
if args.manifest_path:
|
455 |
+
name = Path(args.manifest_path).stem
|
456 |
+
elif args.split_name:
|
457 |
+
name = args.split_name
|
458 |
+
with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
|
459 |
+
f.write(s)
|
460 |
+
|
461 |
+
stats = await triton_client.get_inference_statistics(model_name="", as_json=True)
|
462 |
+
write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
|
463 |
+
|
464 |
+
metadata = await triton_client.get_model_config(model_name=args.model_name, as_json=True)
|
465 |
+
with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
|
466 |
+
json.dump(metadata, f, indent=4)
|
467 |
+
|
468 |
+
|
469 |
+
if __name__ == "__main__":
|
470 |
+
asyncio.run(main())
|
src/f5_tts/runtime/triton_trtllm/client_http.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Redistribution and use in source and binary forms, with or without
|
4 |
+
# modification, are permitted provided that the following conditions
|
5 |
+
# are met:
|
6 |
+
# * Redistributions of source code must retain the above copyright
|
7 |
+
# notice, this list of conditions and the following disclaimer.
|
8 |
+
# * Redistributions in binary form must reproduce the above copyright
|
9 |
+
# notice, this list of conditions and the following disclaimer in the
|
10 |
+
# documentation and/or other materials provided with the distribution.
|
11 |
+
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
12 |
+
# contributors may be used to endorse or promote products derived
|
13 |
+
# from this software without specific prior written permission.
|
14 |
+
#
|
15 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
16 |
+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
17 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
18 |
+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
19 |
+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
20 |
+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
21 |
+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
22 |
+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
23 |
+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
24 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
25 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
26 |
+
import requests
|
27 |
+
import soundfile as sf
|
28 |
+
import numpy as np
|
29 |
+
import argparse
|
30 |
+
|
31 |
+
|
32 |
+
def get_args():
|
33 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
34 |
+
|
35 |
+
parser.add_argument(
|
36 |
+
"--server-url",
|
37 |
+
type=str,
|
38 |
+
default="localhost:8000",
|
39 |
+
help="Address of the server",
|
40 |
+
)
|
41 |
+
|
42 |
+
parser.add_argument(
|
43 |
+
"--reference-audio",
|
44 |
+
type=str,
|
45 |
+
default="../../infer/examples/basic/basic_ref_en.wav",
|
46 |
+
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
|
47 |
+
)
|
48 |
+
|
49 |
+
parser.add_argument(
|
50 |
+
"--reference-text",
|
51 |
+
type=str,
|
52 |
+
default="Some call me nature, others call me mother nature.",
|
53 |
+
help="",
|
54 |
+
)
|
55 |
+
|
56 |
+
parser.add_argument(
|
57 |
+
"--target-text",
|
58 |
+
type=str,
|
59 |
+
default="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.",
|
60 |
+
help="",
|
61 |
+
)
|
62 |
+
|
63 |
+
parser.add_argument(
|
64 |
+
"--model-name",
|
65 |
+
type=str,
|
66 |
+
default="f5_tts",
|
67 |
+
choices=["f5_tts", "spark_tts"],
|
68 |
+
help="triton model_repo module name to request",
|
69 |
+
)
|
70 |
+
|
71 |
+
parser.add_argument(
|
72 |
+
"--output-audio",
|
73 |
+
type=str,
|
74 |
+
default="output.wav",
|
75 |
+
help="Path to save the output audio",
|
76 |
+
)
|
77 |
+
return parser.parse_args()
|
78 |
+
|
79 |
+
|
80 |
+
def prepare_request(
|
81 |
+
samples,
|
82 |
+
reference_text,
|
83 |
+
target_text,
|
84 |
+
sample_rate=16000,
|
85 |
+
audio_save_dir: str = "./",
|
86 |
+
):
|
87 |
+
assert len(samples.shape) == 1, "samples should be 1D"
|
88 |
+
lengths = np.array([[len(samples)]], dtype=np.int32)
|
89 |
+
samples = samples.reshape(1, -1).astype(np.float32)
|
90 |
+
|
91 |
+
data = {
|
92 |
+
"inputs": [
|
93 |
+
{"name": "reference_wav", "shape": samples.shape, "datatype": "FP32", "data": samples.tolist()},
|
94 |
+
{
|
95 |
+
"name": "reference_wav_len",
|
96 |
+
"shape": lengths.shape,
|
97 |
+
"datatype": "INT32",
|
98 |
+
"data": lengths.tolist(),
|
99 |
+
},
|
100 |
+
{"name": "reference_text", "shape": [1, 1], "datatype": "BYTES", "data": [reference_text]},
|
101 |
+
{"name": "target_text", "shape": [1, 1], "datatype": "BYTES", "data": [target_text]},
|
102 |
+
]
|
103 |
+
}
|
104 |
+
|
105 |
+
return data
|
106 |
+
|
107 |
+
|
108 |
+
def load_audio(wav_path, target_sample_rate=16000):
|
109 |
+
assert target_sample_rate == 16000, "hard coding in server"
|
110 |
+
if isinstance(wav_path, dict):
|
111 |
+
samples = wav_path["array"]
|
112 |
+
sample_rate = wav_path["sampling_rate"]
|
113 |
+
else:
|
114 |
+
samples, sample_rate = sf.read(wav_path)
|
115 |
+
if sample_rate != target_sample_rate:
|
116 |
+
from scipy.signal import resample
|
117 |
+
|
118 |
+
num_samples = int(len(samples) * (target_sample_rate / sample_rate))
|
119 |
+
samples = resample(samples, num_samples)
|
120 |
+
return samples, target_sample_rate
|
121 |
+
|
122 |
+
|
123 |
+
if __name__ == "__main__":
|
124 |
+
args = get_args()
|
125 |
+
server_url = args.server_url
|
126 |
+
if not server_url.startswith(("http://", "https://")):
|
127 |
+
server_url = f"http://{server_url}"
|
128 |
+
|
129 |
+
url = f"{server_url}/v2/models/{args.model_name}/infer"
|
130 |
+
samples, sr = load_audio(args.reference_audio)
|
131 |
+
assert sr == 16000, "sample rate hardcoded in server"
|
132 |
+
|
133 |
+
samples = np.array(samples, dtype=np.float32)
|
134 |
+
data = prepare_request(samples, args.reference_text, args.target_text)
|
135 |
+
|
136 |
+
rsp = requests.post(
|
137 |
+
url, headers={"Content-Type": "application/json"}, json=data, verify=False, params={"request_id": "0"}
|
138 |
+
)
|
139 |
+
result = rsp.json()
|
140 |
+
audio = result["outputs"][0]["data"]
|
141 |
+
audio = np.array(audio, dtype=np.float32)
|
142 |
+
sf.write(args.output_audio, audio, 24000, "PCM_16")
|
src/f5_tts/runtime/triton_trtllm/docker-compose.yml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
services:
|
2 |
+
tts:
|
3 |
+
image: soar97/triton-f5-tts:24.12
|
4 |
+
shm_size: '1gb'
|
5 |
+
ports:
|
6 |
+
- "8000:8000"
|
7 |
+
- "8001:8001"
|
8 |
+
- "8002:8002"
|
9 |
+
environment:
|
10 |
+
- PYTHONIOENCODING=utf-8
|
11 |
+
- MODEL_ID=${MODEL_ID}
|
12 |
+
deploy:
|
13 |
+
resources:
|
14 |
+
reservations:
|
15 |
+
devices:
|
16 |
+
- driver: nvidia
|
17 |
+
device_ids: ['0']
|
18 |
+
capabilities: [gpu]
|
19 |
+
command: >
|
20 |
+
/bin/bash -c "pip install vocos && rm -rf F5-TTS && git clone https://github.com/SWivid/F5-TTS.git && cd F5-TTS/src/f5_tts/runtime/triton_trtllm/ && bash run.sh 0 4 $MODEL"
|
src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py
ADDED
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorrt as trt
|
2 |
+
import os
|
3 |
+
import math
|
4 |
+
import time
|
5 |
+
from typing import List, Optional
|
6 |
+
from functools import wraps
|
7 |
+
|
8 |
+
import tensorrt_llm
|
9 |
+
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
|
10 |
+
from tensorrt_llm.logger import logger
|
11 |
+
from tensorrt_llm.runtime.session import Session
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
|
17 |
+
|
18 |
+
def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
|
19 |
+
# Audio tensor case: batch, seq_len, feature_len
|
20 |
+
# position_ids case: batch, seq_len
|
21 |
+
assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor"
|
22 |
+
|
23 |
+
# Initialize a list to collect valid sequences
|
24 |
+
valid_sequences = []
|
25 |
+
|
26 |
+
for i in range(input_tensor.shape[0]):
|
27 |
+
valid_length = input_tensor_lengths[i]
|
28 |
+
valid_sequences.append(input_tensor[i, :valid_length])
|
29 |
+
|
30 |
+
# Concatenate all valid sequences along the batch dimension
|
31 |
+
output_tensor = torch.cat(valid_sequences, dim=0).contiguous()
|
32 |
+
return output_tensor
|
33 |
+
|
34 |
+
|
35 |
+
class TextEmbedding(nn.Module):
|
36 |
+
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2, precompute_max_pos=4096):
|
37 |
+
super().__init__()
|
38 |
+
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
39 |
+
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
|
40 |
+
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
|
41 |
+
|
42 |
+
def forward(self, text):
|
43 |
+
# only keep tensors with value not -1
|
44 |
+
text_mask = text != -1
|
45 |
+
text_pad_cut_off_index = text_mask.sum(dim=1).max()
|
46 |
+
|
47 |
+
text = text[:, :text_pad_cut_off_index]
|
48 |
+
text = self.text_embed(text)
|
49 |
+
text = text + self.freqs_cis[: text.shape[1], :]
|
50 |
+
for block in self.text_blocks:
|
51 |
+
text = block(text)
|
52 |
+
# padding text to the original length
|
53 |
+
# text shape: B,seq_len,C
|
54 |
+
# pad at the second dimension
|
55 |
+
text = F.pad(text, (0, 0, 0, text_mask.shape[1] - text.shape[1], 0, 0), value=0)
|
56 |
+
return text
|
57 |
+
|
58 |
+
|
59 |
+
class GRN(nn.Module):
|
60 |
+
def __init__(self, dim):
|
61 |
+
super().__init__()
|
62 |
+
self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
|
63 |
+
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
|
64 |
+
|
65 |
+
def forward(self, x):
|
66 |
+
Gx = torch.norm(x, p=2, dim=1, keepdim=True)
|
67 |
+
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
68 |
+
return self.gamma * (x * Nx) + self.beta + x
|
69 |
+
|
70 |
+
|
71 |
+
class ConvNeXtV2Block(nn.Module):
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
dim: int,
|
75 |
+
intermediate_dim: int,
|
76 |
+
dilation: int = 1,
|
77 |
+
):
|
78 |
+
super().__init__()
|
79 |
+
padding = (dilation * (7 - 1)) // 2
|
80 |
+
self.dwconv = nn.Conv1d(
|
81 |
+
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
|
82 |
+
) # depthwise conv
|
83 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
84 |
+
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
|
85 |
+
self.act = nn.GELU()
|
86 |
+
self.grn = GRN(intermediate_dim)
|
87 |
+
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
88 |
+
|
89 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
90 |
+
residual = x
|
91 |
+
x = x.transpose(1, 2) # b n d -> b d n
|
92 |
+
x = self.dwconv(x)
|
93 |
+
x = x.transpose(1, 2) # b d n -> b n d
|
94 |
+
x = self.norm(x)
|
95 |
+
x = self.pwconv1(x)
|
96 |
+
x = self.act(x)
|
97 |
+
x = self.grn(x)
|
98 |
+
x = self.pwconv2(x)
|
99 |
+
return residual + x
|
100 |
+
|
101 |
+
|
102 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
|
103 |
+
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
104 |
+
# has some connection to NTK literature
|
105 |
+
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
106 |
+
# https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
|
107 |
+
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
108 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
109 |
+
t = torch.arange(end, device=freqs.device) # type: ignore
|
110 |
+
freqs = torch.outer(t, freqs).float() # type: ignore
|
111 |
+
freqs_cos = torch.cos(freqs) # real part
|
112 |
+
freqs_sin = torch.sin(freqs) # imaginary part
|
113 |
+
return torch.cat([freqs_cos, freqs_sin], dim=-1)
|
114 |
+
|
115 |
+
|
116 |
+
def load_checkpoint(ckpt_path, use_ema=True):
|
117 |
+
checkpoint = torch.load(ckpt_path, weights_only=True)
|
118 |
+
if use_ema:
|
119 |
+
checkpoint["model_state_dict"] = {
|
120 |
+
k.replace("ema_model.", ""): v
|
121 |
+
for k, v in checkpoint["ema_model_state_dict"].items()
|
122 |
+
if k not in ["initted", "step"]
|
123 |
+
}
|
124 |
+
dict_state = checkpoint["model_state_dict"]
|
125 |
+
text_embed_dict = {}
|
126 |
+
for key in dict_state.keys():
|
127 |
+
# transformer.text_embed.text_embed.weight -> text_embed.weight
|
128 |
+
if "text_embed" in key:
|
129 |
+
text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[key]
|
130 |
+
return text_embed_dict
|
131 |
+
|
132 |
+
|
133 |
+
class F5TTS(object):
|
134 |
+
def __init__(
|
135 |
+
self,
|
136 |
+
config,
|
137 |
+
debug_mode=True,
|
138 |
+
stream: Optional[torch.cuda.Stream] = None,
|
139 |
+
tllm_model_dir: Optional[str] = None,
|
140 |
+
model_path: Optional[str] = None,
|
141 |
+
vocab_size: Optional[int] = None,
|
142 |
+
):
|
143 |
+
self.dtype = config["pretrained_config"]["dtype"]
|
144 |
+
|
145 |
+
rank = tensorrt_llm.mpi_rank()
|
146 |
+
world_size = config["pretrained_config"]["mapping"]["world_size"]
|
147 |
+
cp_size = config["pretrained_config"]["mapping"]["cp_size"]
|
148 |
+
tp_size = config["pretrained_config"]["mapping"]["tp_size"]
|
149 |
+
pp_size = config["pretrained_config"]["mapping"]["pp_size"]
|
150 |
+
assert pp_size == 1
|
151 |
+
self.mapping = tensorrt_llm.Mapping(
|
152 |
+
world_size=world_size, rank=rank, cp_size=cp_size, tp_size=tp_size, pp_size=1, gpus_per_node=1
|
153 |
+
)
|
154 |
+
|
155 |
+
local_rank = rank % self.mapping.gpus_per_node
|
156 |
+
self.device = torch.device(f"cuda:{local_rank}")
|
157 |
+
|
158 |
+
torch.cuda.set_device(self.device)
|
159 |
+
|
160 |
+
self.stream = stream
|
161 |
+
if self.stream is None:
|
162 |
+
self.stream = torch.cuda.Stream(self.device)
|
163 |
+
torch.cuda.set_stream(self.stream)
|
164 |
+
|
165 |
+
engine_file = os.path.join(tllm_model_dir, f"rank{rank}.engine")
|
166 |
+
logger.info(f"Loading engine from {engine_file}")
|
167 |
+
with open(engine_file, "rb") as f:
|
168 |
+
engine_buffer = f.read()
|
169 |
+
|
170 |
+
assert engine_buffer is not None
|
171 |
+
|
172 |
+
self.session = Session.from_serialized_engine(engine_buffer)
|
173 |
+
|
174 |
+
self.debug_mode = debug_mode
|
175 |
+
|
176 |
+
self.inputs = {}
|
177 |
+
self.outputs = {}
|
178 |
+
self.buffer_allocated = False
|
179 |
+
|
180 |
+
expected_tensor_names = ["noise", "cond", "time", "rope_cos", "rope_sin", "input_lengths", "denoised"]
|
181 |
+
|
182 |
+
found_tensor_names = [self.session.engine.get_tensor_name(i) for i in range(self.session.engine.num_io_tensors)]
|
183 |
+
if not self.debug_mode and set(expected_tensor_names) != set(found_tensor_names):
|
184 |
+
logger.error(
|
185 |
+
f"The following expected tensors are not found: {set(expected_tensor_names).difference(set(found_tensor_names))}"
|
186 |
+
)
|
187 |
+
logger.error(
|
188 |
+
f"Those tensors in engine are not expected: {set(found_tensor_names).difference(set(expected_tensor_names))}"
|
189 |
+
)
|
190 |
+
logger.error(f"Expected tensor names: {expected_tensor_names}")
|
191 |
+
logger.error(f"Found tensor names: {found_tensor_names}")
|
192 |
+
raise RuntimeError("Tensor names in engine are not the same as expected.")
|
193 |
+
if self.debug_mode:
|
194 |
+
self.debug_tensors = list(set(found_tensor_names) - set(expected_tensor_names))
|
195 |
+
|
196 |
+
self.max_mel_len = 4096
|
197 |
+
self.text_embedding = TextEmbedding(
|
198 |
+
text_num_embeds=vocab_size, text_dim=512, conv_layers=4, precompute_max_pos=self.max_mel_len
|
199 |
+
).to(self.device)
|
200 |
+
self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True)
|
201 |
+
|
202 |
+
self.target_audio_sample_rate = 24000
|
203 |
+
self.target_rms = 0.15 # target rms for audio
|
204 |
+
self.n_fft = 1024
|
205 |
+
self.win_length = 1024
|
206 |
+
self.hop_length = 256
|
207 |
+
self.n_mel_channels = 100
|
208 |
+
# self.max_mel_len = 3000
|
209 |
+
self.head_dim = 64
|
210 |
+
self.base_rescale_factor = 1.0
|
211 |
+
self.interpolation_factor = 1.0
|
212 |
+
base = 10000.0 * self.base_rescale_factor ** (self.head_dim / (self.head_dim - 2))
|
213 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
|
214 |
+
freqs = torch.outer(torch.arange(self.max_mel_len, dtype=torch.float32), inv_freq) / self.interpolation_factor
|
215 |
+
self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0)
|
216 |
+
self.rope_cos = self.freqs.cos().half()
|
217 |
+
self.rope_sin = self.freqs.sin().half()
|
218 |
+
self.nfe_steps = 16
|
219 |
+
t = torch.linspace(0, 1, self.nfe_steps + 1, dtype=torch.float32)
|
220 |
+
time_step = t + (-1.0) * (torch.cos(torch.pi * 0.5 * t) - 1 + t)
|
221 |
+
delta_t = torch.diff(time_step)
|
222 |
+
# WAR: hard coding 256 here
|
223 |
+
tmp_dim = 256
|
224 |
+
time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32)
|
225 |
+
half_dim = tmp_dim // 2
|
226 |
+
emb_factor = math.log(10000) / (half_dim - 1)
|
227 |
+
emb_factor = 1000.0 * torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb_factor)
|
228 |
+
for i in range(self.nfe_steps):
|
229 |
+
emb = time_step[i] * emb_factor
|
230 |
+
time_expand[:, i, :] = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
231 |
+
self.time_expand = time_expand.to(self.device)
|
232 |
+
self.delta_t = torch.cat((delta_t, delta_t), dim=0).contiguous().to(self.device)
|
233 |
+
|
234 |
+
def _tensor_dtype(self, name):
|
235 |
+
# return torch dtype given tensor name for convenience
|
236 |
+
dtype = trt_dtype_to_torch(self.session.engine.get_tensor_dtype(name))
|
237 |
+
return dtype
|
238 |
+
|
239 |
+
def _setup(self, batch_size, seq_len):
|
240 |
+
for i in range(self.session.engine.num_io_tensors):
|
241 |
+
name = self.session.engine.get_tensor_name(i)
|
242 |
+
if self.session.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
|
243 |
+
shape = list(self.session.engine.get_tensor_shape(name))
|
244 |
+
shape[0] = batch_size
|
245 |
+
shape[1] = seq_len
|
246 |
+
self.outputs[name] = torch.empty(shape, dtype=self._tensor_dtype(name), device=self.device)
|
247 |
+
|
248 |
+
self.buffer_allocated = True
|
249 |
+
|
250 |
+
def cuda_stream_guard(func):
|
251 |
+
"""Sync external stream and set current stream to the one bound to the session. Reset on exit."""
|
252 |
+
|
253 |
+
@wraps(func)
|
254 |
+
def wrapper(self, *args, **kwargs):
|
255 |
+
external_stream = torch.cuda.current_stream()
|
256 |
+
if external_stream != self.stream:
|
257 |
+
external_stream.synchronize()
|
258 |
+
torch.cuda.set_stream(self.stream)
|
259 |
+
ret = func(self, *args, **kwargs)
|
260 |
+
if external_stream != self.stream:
|
261 |
+
self.stream.synchronize()
|
262 |
+
torch.cuda.set_stream(external_stream)
|
263 |
+
return ret
|
264 |
+
|
265 |
+
return wrapper
|
266 |
+
|
267 |
+
@cuda_stream_guard
|
268 |
+
def forward(
|
269 |
+
self,
|
270 |
+
noise: torch.Tensor,
|
271 |
+
cond: torch.Tensor,
|
272 |
+
time_expand: torch.Tensor,
|
273 |
+
rope_cos: torch.Tensor,
|
274 |
+
rope_sin: torch.Tensor,
|
275 |
+
input_lengths: torch.Tensor,
|
276 |
+
delta_t: torch.Tensor,
|
277 |
+
use_perf: bool = False,
|
278 |
+
):
|
279 |
+
if use_perf:
|
280 |
+
torch.cuda.nvtx.range_push("flow matching")
|
281 |
+
cfg_strength = 2.0
|
282 |
+
batch_size = noise.shape[0]
|
283 |
+
half_batch = batch_size // 2
|
284 |
+
noise_half = noise[:half_batch] # Store the initial half of noise
|
285 |
+
|
286 |
+
input_type = str_dtype_to_torch(self.dtype)
|
287 |
+
|
288 |
+
# Keep a copy of the initial tensors
|
289 |
+
cond = cond.to(input_type)
|
290 |
+
rope_cos = rope_cos.to(input_type)
|
291 |
+
rope_sin = rope_sin.to(input_type)
|
292 |
+
input_lengths = input_lengths.to(str_dtype_to_torch("int32"))
|
293 |
+
|
294 |
+
# Instead of iteratively updating noise within a single model context,
|
295 |
+
# we'll do a single forward pass for each iteration with fresh context setup
|
296 |
+
for i in range(self.nfe_steps):
|
297 |
+
# Re-setup the buffers for clean execution
|
298 |
+
self._setup(batch_size, noise.shape[1])
|
299 |
+
if not self.buffer_allocated:
|
300 |
+
raise RuntimeError("Buffer not allocated, please call setup first!")
|
301 |
+
|
302 |
+
# Re-create combined noises for this iteration
|
303 |
+
current_noise = torch.cat([noise_half, noise_half], dim=0).to(input_type)
|
304 |
+
|
305 |
+
# Get time step for this iteration
|
306 |
+
current_time = time_expand[:, i].to(input_type)
|
307 |
+
|
308 |
+
# Create fresh input dictionary for this iteration
|
309 |
+
current_inputs = {
|
310 |
+
"noise": current_noise,
|
311 |
+
"cond": cond,
|
312 |
+
"time": current_time,
|
313 |
+
"rope_cos": rope_cos,
|
314 |
+
"rope_sin": rope_sin,
|
315 |
+
"input_lengths": input_lengths,
|
316 |
+
}
|
317 |
+
|
318 |
+
# Update inputs and set shapes
|
319 |
+
self.inputs.clear() # Clear previous inputs
|
320 |
+
self.inputs.update(**current_inputs)
|
321 |
+
self.session.set_shapes(self.inputs)
|
322 |
+
|
323 |
+
if use_perf:
|
324 |
+
torch.cuda.nvtx.range_push(f"execute {i}")
|
325 |
+
ok = self.session.run(self.inputs, self.outputs, self.stream.cuda_stream)
|
326 |
+
assert ok, "Failed to execute model"
|
327 |
+
# self.session.context.execute_async_v3(self.stream.cuda_stream)
|
328 |
+
if use_perf:
|
329 |
+
torch.cuda.nvtx.range_pop()
|
330 |
+
# Process results
|
331 |
+
t_scale = delta_t[i].unsqueeze(0).to(input_type)
|
332 |
+
|
333 |
+
# Extract predictions
|
334 |
+
pred_cond = self.outputs["denoised"][:half_batch]
|
335 |
+
pred_uncond = self.outputs["denoised"][half_batch:]
|
336 |
+
|
337 |
+
# Apply classifier-free guidance with safeguards
|
338 |
+
guidance = pred_cond + (pred_cond - pred_uncond) * cfg_strength
|
339 |
+
# Calculate update for noise
|
340 |
+
noise_half = noise_half + guidance * t_scale
|
341 |
+
if use_perf:
|
342 |
+
torch.cuda.nvtx.range_pop()
|
343 |
+
return noise_half
|
344 |
+
|
345 |
+
def sample(
|
346 |
+
self,
|
347 |
+
text_pad_sequence: torch.Tensor,
|
348 |
+
ref_mel_batch: torch.Tensor,
|
349 |
+
ref_mel_len_batch: torch.Tensor,
|
350 |
+
estimated_reference_target_mel_len: List[int],
|
351 |
+
remove_input_padding: bool = False,
|
352 |
+
use_perf: bool = False,
|
353 |
+
):
|
354 |
+
if use_perf:
|
355 |
+
torch.cuda.nvtx.range_push("text embedding")
|
356 |
+
batch = text_pad_sequence.shape[0]
|
357 |
+
max_seq_len = ref_mel_batch.shape[1]
|
358 |
+
|
359 |
+
text_pad_sequence_drop = torch.cat(
|
360 |
+
(text_pad_sequence, torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(self.device)), dim=0
|
361 |
+
)
|
362 |
+
|
363 |
+
text_embedding_drop_list = []
|
364 |
+
for i in range(batch + 1):
|
365 |
+
text_embedding_drop_list.append(self.text_embedding(text_pad_sequence_drop[i].unsqueeze(0).to(self.device)))
|
366 |
+
text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0)
|
367 |
+
|
368 |
+
text_embedding = text_embedding_drop_condition[:-1]
|
369 |
+
# text_embedding_drop B,T,C batch should be the same
|
370 |
+
text_embedding_drop = text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1)
|
371 |
+
|
372 |
+
noise = torch.randn_like(ref_mel_batch).to(self.device)
|
373 |
+
rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1)
|
374 |
+
rope_sin = self.rope_sin[:, :max_seq_len, :].float().repeat(batch, 1, 1)
|
375 |
+
|
376 |
+
cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1)
|
377 |
+
cat_mel_text_drop = torch.cat(
|
378 |
+
(
|
379 |
+
torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device),
|
380 |
+
text_embedding_drop,
|
381 |
+
),
|
382 |
+
dim=-1,
|
383 |
+
)
|
384 |
+
|
385 |
+
time_expand = self.time_expand.repeat(2 * batch, 1, 1).contiguous()
|
386 |
+
|
387 |
+
# Convert estimated_reference_target_mel_len to tensor
|
388 |
+
input_lengths = torch.tensor(estimated_reference_target_mel_len, dtype=torch.int32)
|
389 |
+
|
390 |
+
# combine above along the batch dimension
|
391 |
+
inputs = {
|
392 |
+
"noise": torch.cat((noise, noise), dim=0).contiguous(),
|
393 |
+
"cond": torch.cat((cat_mel_text, cat_mel_text_drop), dim=0).contiguous(),
|
394 |
+
"time_expand": time_expand,
|
395 |
+
"rope_cos": torch.cat((rope_cos, rope_cos), dim=0).contiguous(),
|
396 |
+
"rope_sin": torch.cat((rope_sin, rope_sin), dim=0).contiguous(),
|
397 |
+
"input_lengths": torch.cat((input_lengths, input_lengths), dim=0).contiguous(),
|
398 |
+
"delta_t": self.delta_t,
|
399 |
+
}
|
400 |
+
if use_perf and remove_input_padding:
|
401 |
+
torch.cuda.nvtx.range_push("remove input padding")
|
402 |
+
if remove_input_padding:
|
403 |
+
max_seq_len = inputs["cond"].shape[1]
|
404 |
+
inputs["noise"] = remove_tensor_padding(inputs["noise"], inputs["input_lengths"])
|
405 |
+
inputs["cond"] = remove_tensor_padding(inputs["cond"], inputs["input_lengths"])
|
406 |
+
# for time_expand, convert from B,D to B,T,D by repeat
|
407 |
+
inputs["time_expand"] = inputs["time_expand"].unsqueeze(1).repeat(1, max_seq_len, 1, 1)
|
408 |
+
inputs["time_expand"] = remove_tensor_padding(inputs["time_expand"], inputs["input_lengths"])
|
409 |
+
inputs["rope_cos"] = remove_tensor_padding(inputs["rope_cos"], inputs["input_lengths"])
|
410 |
+
inputs["rope_sin"] = remove_tensor_padding(inputs["rope_sin"], inputs["input_lengths"])
|
411 |
+
if use_perf and remove_input_padding:
|
412 |
+
torch.cuda.nvtx.range_pop()
|
413 |
+
for key in inputs:
|
414 |
+
inputs[key] = inputs[key].to(self.device)
|
415 |
+
if use_perf:
|
416 |
+
torch.cuda.nvtx.range_pop()
|
417 |
+
start_time = time.time()
|
418 |
+
denoised = self.forward(**inputs, use_perf=use_perf)
|
419 |
+
cost_time = time.time() - start_time
|
420 |
+
if use_perf and remove_input_padding:
|
421 |
+
torch.cuda.nvtx.range_push("remove input padding output")
|
422 |
+
if remove_input_padding:
|
423 |
+
denoised_list = []
|
424 |
+
start_idx = 0
|
425 |
+
for i in range(batch):
|
426 |
+
denoised_list.append(denoised[start_idx : start_idx + inputs["input_lengths"][i]])
|
427 |
+
start_idx += inputs["input_lengths"][i]
|
428 |
+
if use_perf and remove_input_padding:
|
429 |
+
torch.cuda.nvtx.range_pop()
|
430 |
+
return denoised_list, cost_time
|
431 |
+
return denoised, cost_time
|
src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Redistribution and use in source and binary forms, with or without
|
4 |
+
# modification, are permitted provided that the following conditions
|
5 |
+
# are met:
|
6 |
+
# * Redistributions of source code must retain the above copyright
|
7 |
+
# notice, this list of conditions and the following disclaimer.
|
8 |
+
# * Redistributions in binary form must reproduce the above copyright
|
9 |
+
# notice, this list of conditions and the following disclaimer in the
|
10 |
+
# documentation and/or other materials provided with the distribution.
|
11 |
+
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
12 |
+
# contributors may be used to endorse or promote products derived
|
13 |
+
# from this software without specific prior written permission.
|
14 |
+
#
|
15 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
16 |
+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
17 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
18 |
+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
19 |
+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
20 |
+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
21 |
+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
22 |
+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
23 |
+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
24 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
25 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
26 |
+
import json
|
27 |
+
import torch
|
28 |
+
from torch.nn.utils.rnn import pad_sequence
|
29 |
+
import torch.nn.functional as F
|
30 |
+
from torch.utils.dlpack import from_dlpack, to_dlpack
|
31 |
+
import torchaudio
|
32 |
+
import jieba
|
33 |
+
import triton_python_backend_utils as pb_utils
|
34 |
+
from pypinyin import Style, lazy_pinyin
|
35 |
+
import os
|
36 |
+
from f5_tts_trtllm import F5TTS
|
37 |
+
|
38 |
+
|
39 |
+
def get_tokenizer(vocab_file_path: str):
|
40 |
+
"""
|
41 |
+
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
|
42 |
+
- "char" for char-wise tokenizer, need .txt vocab_file
|
43 |
+
- "byte" for utf-8 tokenizer
|
44 |
+
- "custom" if you're directly passing in a path to the vocab.txt you want to use
|
45 |
+
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
|
46 |
+
- if use "char", derived from unfiltered character & symbol counts of custom dataset
|
47 |
+
- if use "byte", set to 256 (unicode byte range)
|
48 |
+
"""
|
49 |
+
with open(vocab_file_path, "r", encoding="utf-8") as f:
|
50 |
+
vocab_char_map = {}
|
51 |
+
for i, char in enumerate(f):
|
52 |
+
vocab_char_map[char[:-1]] = i
|
53 |
+
vocab_size = len(vocab_char_map)
|
54 |
+
return vocab_char_map, vocab_size
|
55 |
+
|
56 |
+
|
57 |
+
def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
|
58 |
+
final_reference_target_texts_list = []
|
59 |
+
custom_trans = str.maketrans(
|
60 |
+
{";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
|
61 |
+
) # add custom trans here, to address oov
|
62 |
+
|
63 |
+
def is_chinese(c):
|
64 |
+
return "\u3100" <= c <= "\u9fff" # common chinese characters
|
65 |
+
|
66 |
+
for text in reference_target_texts_list:
|
67 |
+
char_list = []
|
68 |
+
text = text.translate(custom_trans)
|
69 |
+
for seg in jieba.cut(text):
|
70 |
+
seg_byte_len = len(bytes(seg, "UTF-8"))
|
71 |
+
if seg_byte_len == len(seg): # if pure alphabets and symbols
|
72 |
+
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
|
73 |
+
char_list.append(" ")
|
74 |
+
char_list.extend(seg)
|
75 |
+
elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
|
76 |
+
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
|
77 |
+
for i, c in enumerate(seg):
|
78 |
+
if is_chinese(c):
|
79 |
+
char_list.append(" ")
|
80 |
+
char_list.append(seg_[i])
|
81 |
+
else: # if mixed characters, alphabets and symbols
|
82 |
+
for c in seg:
|
83 |
+
if ord(c) < 256:
|
84 |
+
char_list.extend(c)
|
85 |
+
elif is_chinese(c):
|
86 |
+
char_list.append(" ")
|
87 |
+
char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
|
88 |
+
else:
|
89 |
+
char_list.append(c)
|
90 |
+
final_reference_target_texts_list.append(char_list)
|
91 |
+
|
92 |
+
return final_reference_target_texts_list
|
93 |
+
|
94 |
+
|
95 |
+
def list_str_to_idx(
|
96 |
+
text: list[str] | list[list[str]],
|
97 |
+
vocab_char_map: dict[str, int], # {char: idx}
|
98 |
+
padding_value=-1,
|
99 |
+
): # noqa: F722
|
100 |
+
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
|
101 |
+
return list_idx_tensors
|
102 |
+
|
103 |
+
|
104 |
+
class TritonPythonModel:
|
105 |
+
def initialize(self, args):
|
106 |
+
self.use_perf = True
|
107 |
+
self.device = torch.device("cuda")
|
108 |
+
self.target_audio_sample_rate = 24000
|
109 |
+
self.target_rms = 0.15 # target rms for audio
|
110 |
+
self.n_fft = 1024
|
111 |
+
self.win_length = 1024
|
112 |
+
self.hop_length = 256
|
113 |
+
self.n_mel_channels = 100
|
114 |
+
self.max_mel_len = 3000
|
115 |
+
self.head_dim = 64
|
116 |
+
|
117 |
+
parameters = json.loads(args["model_config"])["parameters"]
|
118 |
+
for key, value in parameters.items():
|
119 |
+
parameters[key] = value["string_value"]
|
120 |
+
|
121 |
+
self.vocab_char_map, self.vocab_size = get_tokenizer(parameters["vocab_file"])
|
122 |
+
self.reference_sample_rate = int(parameters["reference_audio_sample_rate"])
|
123 |
+
self.resampler = torchaudio.transforms.Resample(self.reference_sample_rate, self.target_audio_sample_rate)
|
124 |
+
|
125 |
+
self.tllm_model_dir = parameters["tllm_model_dir"]
|
126 |
+
config_file = os.path.join(self.tllm_model_dir, "config.json")
|
127 |
+
with open(config_file) as f:
|
128 |
+
config = json.load(f)
|
129 |
+
self.model = F5TTS(
|
130 |
+
config,
|
131 |
+
debug_mode=False,
|
132 |
+
tllm_model_dir=self.tllm_model_dir,
|
133 |
+
model_path=parameters["model_path"],
|
134 |
+
vocab_size=self.vocab_size,
|
135 |
+
)
|
136 |
+
|
137 |
+
self.vocoder = parameters["vocoder"]
|
138 |
+
assert self.vocoder in ["vocos", "bigvgan"]
|
139 |
+
if self.vocoder == "vocos":
|
140 |
+
self.mel_stft = torchaudio.transforms.MelSpectrogram(
|
141 |
+
sample_rate=self.target_audio_sample_rate,
|
142 |
+
n_fft=self.n_fft,
|
143 |
+
win_length=self.win_length,
|
144 |
+
hop_length=self.hop_length,
|
145 |
+
n_mels=self.n_mel_channels,
|
146 |
+
power=1,
|
147 |
+
center=True,
|
148 |
+
normalized=False,
|
149 |
+
norm=None,
|
150 |
+
).to(self.device)
|
151 |
+
self.compute_mel_fn = self.get_vocos_mel_spectrogram
|
152 |
+
elif self.vocoder == "bigvgan":
|
153 |
+
self.compute_mel_fn = self.get_bigvgan_mel_spectrogram
|
154 |
+
|
155 |
+
def get_vocos_mel_spectrogram(self, waveform):
|
156 |
+
mel = self.mel_stft(waveform)
|
157 |
+
mel = mel.clamp(min=1e-5).log()
|
158 |
+
return mel.transpose(1, 2)
|
159 |
+
|
160 |
+
def forward_vocoder(self, mel):
|
161 |
+
mel = mel.to(torch.float32).contiguous().cpu()
|
162 |
+
input_tensor_0 = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel))
|
163 |
+
|
164 |
+
inference_request = pb_utils.InferenceRequest(
|
165 |
+
model_name="vocoder", requested_output_names=["waveform"], inputs=[input_tensor_0]
|
166 |
+
)
|
167 |
+
inference_response = inference_request.exec()
|
168 |
+
if inference_response.has_error():
|
169 |
+
raise pb_utils.TritonModelException(inference_response.error().message())
|
170 |
+
else:
|
171 |
+
waveform = pb_utils.get_output_tensor_by_name(inference_response, "waveform")
|
172 |
+
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
|
173 |
+
|
174 |
+
return waveform
|
175 |
+
|
176 |
+
def execute(self, requests):
|
177 |
+
(
|
178 |
+
reference_text_list,
|
179 |
+
target_text_list,
|
180 |
+
reference_target_texts_list,
|
181 |
+
estimated_reference_target_mel_len,
|
182 |
+
reference_mel_len,
|
183 |
+
) = [], [], [], [], []
|
184 |
+
mel_features_list = []
|
185 |
+
if self.use_perf:
|
186 |
+
torch.cuda.nvtx.range_push("preprocess")
|
187 |
+
for request in requests:
|
188 |
+
wav_tensor = pb_utils.get_input_tensor_by_name(request, "reference_wav")
|
189 |
+
wav_lens = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
|
190 |
+
|
191 |
+
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
|
192 |
+
reference_text = reference_text[0][0].decode("utf-8")
|
193 |
+
reference_text_list.append(reference_text)
|
194 |
+
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
|
195 |
+
target_text = target_text[0][0].decode("utf-8")
|
196 |
+
target_text_list.append(target_text)
|
197 |
+
|
198 |
+
text = reference_text + target_text
|
199 |
+
reference_target_texts_list.append(text)
|
200 |
+
|
201 |
+
wav = from_dlpack(wav_tensor.to_dlpack())
|
202 |
+
wav_len = from_dlpack(wav_lens.to_dlpack())
|
203 |
+
wav_len = wav_len.squeeze()
|
204 |
+
assert wav.shape[0] == 1, "Only support batch size 1 for now."
|
205 |
+
wav = wav[:, :wav_len]
|
206 |
+
|
207 |
+
ref_rms = torch.sqrt(torch.mean(torch.square(wav)))
|
208 |
+
if ref_rms < self.target_rms:
|
209 |
+
wav = wav * self.target_rms / ref_rms
|
210 |
+
if self.reference_sample_rate != self.target_audio_sample_rate:
|
211 |
+
wav = self.resampler(wav)
|
212 |
+
wav = wav.to(self.device)
|
213 |
+
if self.use_perf:
|
214 |
+
torch.cuda.nvtx.range_push("compute_mel")
|
215 |
+
mel_features = self.compute_mel_fn(wav)
|
216 |
+
if self.use_perf:
|
217 |
+
torch.cuda.nvtx.range_pop()
|
218 |
+
mel_features_list.append(mel_features)
|
219 |
+
|
220 |
+
reference_mel_len.append(mel_features.shape[1])
|
221 |
+
estimated_reference_target_mel_len.append(
|
222 |
+
int(mel_features.shape[1] * (1 + len(target_text) / len(reference_text)))
|
223 |
+
)
|
224 |
+
|
225 |
+
max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len)
|
226 |
+
|
227 |
+
batch = len(requests)
|
228 |
+
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float16).to(self.device)
|
229 |
+
for i, mel in enumerate(mel_features_list):
|
230 |
+
mel_features[i, : mel.shape[1], :] = mel
|
231 |
+
|
232 |
+
reference_mel_len_tensor = torch.LongTensor(reference_mel_len).to(self.device)
|
233 |
+
|
234 |
+
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
|
235 |
+
text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map)
|
236 |
+
|
237 |
+
for i, item in enumerate(text_pad_sequence):
|
238 |
+
text_pad_sequence[i] = F.pad(
|
239 |
+
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
|
240 |
+
)
|
241 |
+
text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
|
242 |
+
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(self.device)
|
243 |
+
text_pad_sequence = F.pad(
|
244 |
+
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
|
245 |
+
)
|
246 |
+
if self.use_perf:
|
247 |
+
torch.cuda.nvtx.range_pop()
|
248 |
+
|
249 |
+
denoised, cost_time = self.model.sample(
|
250 |
+
text_pad_sequence,
|
251 |
+
mel_features,
|
252 |
+
reference_mel_len_tensor,
|
253 |
+
estimated_reference_target_mel_len,
|
254 |
+
remove_input_padding=False,
|
255 |
+
use_perf=self.use_perf,
|
256 |
+
)
|
257 |
+
if self.use_perf:
|
258 |
+
torch.cuda.nvtx.range_push("vocoder")
|
259 |
+
|
260 |
+
responses = []
|
261 |
+
for i in range(batch):
|
262 |
+
ref_me_len = reference_mel_len[i]
|
263 |
+
estimated_mel_len = estimated_reference_target_mel_len[i]
|
264 |
+
denoised_one_item = denoised[i, ref_me_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
|
265 |
+
audio = self.forward_vocoder(denoised_one_item)
|
266 |
+
rms = torch.sqrt(torch.mean(torch.square(audio)))
|
267 |
+
if rms < self.target_rms:
|
268 |
+
audio = audio * self.target_rms / rms
|
269 |
+
|
270 |
+
audio = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
|
271 |
+
inference_response = pb_utils.InferenceResponse(output_tensors=[audio])
|
272 |
+
responses.append(inference_response)
|
273 |
+
if self.use_perf:
|
274 |
+
torch.cuda.nvtx.range_pop()
|
275 |
+
return responses
|
src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
name: "f5_tts"
|
16 |
+
backend: "python"
|
17 |
+
max_batch_size: 4
|
18 |
+
dynamic_batching {
|
19 |
+
max_queue_delay_microseconds: 1000
|
20 |
+
}
|
21 |
+
parameters [
|
22 |
+
{
|
23 |
+
key: "vocab_file"
|
24 |
+
value: { string_value: "${vocab}"}
|
25 |
+
},
|
26 |
+
{
|
27 |
+
key: "model_path",
|
28 |
+
value: {string_value:"${model}"}
|
29 |
+
},
|
30 |
+
{
|
31 |
+
key: "tllm_model_dir",
|
32 |
+
value: {string_value:"${trtllm}"}
|
33 |
+
},
|
34 |
+
{
|
35 |
+
key: "reference_audio_sample_rate",
|
36 |
+
value: {string_value:"16000"}
|
37 |
+
},
|
38 |
+
{
|
39 |
+
key: "vocoder",
|
40 |
+
value: {string_value:"${vocoder}"}
|
41 |
+
}
|
42 |
+
]
|
43 |
+
|
44 |
+
input [
|
45 |
+
{
|
46 |
+
name: "reference_wav"
|
47 |
+
data_type: TYPE_FP32
|
48 |
+
dims: [-1]
|
49 |
+
optional: True
|
50 |
+
},
|
51 |
+
{
|
52 |
+
name: "reference_wav_len"
|
53 |
+
data_type: TYPE_INT32
|
54 |
+
dims: [1]
|
55 |
+
optional: True
|
56 |
+
},
|
57 |
+
{
|
58 |
+
name: "reference_text"
|
59 |
+
data_type: TYPE_STRING
|
60 |
+
dims: [1]
|
61 |
+
},
|
62 |
+
{
|
63 |
+
name: "target_text"
|
64 |
+
data_type: TYPE_STRING
|
65 |
+
dims: [1]
|
66 |
+
}
|
67 |
+
]
|
68 |
+
output [
|
69 |
+
{
|
70 |
+
name: "waveform"
|
71 |
+
data_type: TYPE_FP32
|
72 |
+
dims: [ -1 ]
|
73 |
+
}
|
74 |
+
]
|
75 |
+
|
76 |
+
instance_group [
|
77 |
+
{
|
78 |
+
count: 1
|
79 |
+
kind: KIND_GPU
|
80 |
+
}
|
81 |
+
]
|
src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/1/.gitkeep
ADDED
File without changes
|
src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/config.pbtxt
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: "vocoder"
|
2 |
+
backend: "tensorrt"
|
3 |
+
default_model_filename: "vocoder.plan"
|
4 |
+
max_batch_size: 4
|
5 |
+
|
6 |
+
input [
|
7 |
+
{
|
8 |
+
name: "mel"
|
9 |
+
data_type: TYPE_FP32
|
10 |
+
dims: [ 100, -1 ]
|
11 |
+
}
|
12 |
+
]
|
13 |
+
|
14 |
+
output [
|
15 |
+
{
|
16 |
+
name: "waveform"
|
17 |
+
data_type: TYPE_FP32
|
18 |
+
dims: [ -1 ]
|
19 |
+
}
|
20 |
+
]
|
21 |
+
|
22 |
+
dynamic_batching {
|
23 |
+
preferred_batch_size: [1, 2, 4]
|
24 |
+
max_queue_delay_microseconds: 1
|
25 |
+
}
|
26 |
+
|
27 |
+
instance_group [
|
28 |
+
{
|
29 |
+
count: 1
|
30 |
+
kind: KIND_GPU
|
31 |
+
}
|
32 |
+
]
|
src/f5_tts/runtime/triton_trtllm/patch/__init__.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
from .baichuan.model import BaichuanForCausalLM
|
16 |
+
from .bert.model import (BertForQuestionAnswering,
|
17 |
+
BertForSequenceClassification, BertModel,
|
18 |
+
RobertaForQuestionAnswering,
|
19 |
+
RobertaForSequenceClassification, RobertaModel)
|
20 |
+
from .bloom.model import BloomForCausalLM, BloomModel
|
21 |
+
from .chatglm.config import ChatGLMConfig
|
22 |
+
from .chatglm.model import ChatGLMForCausalLM, ChatGLMModel
|
23 |
+
from .cogvlm.config import CogVLMConfig
|
24 |
+
from .cogvlm.model import CogVLMForCausalLM
|
25 |
+
from .commandr.model import CohereForCausalLM
|
26 |
+
from .dbrx.config import DbrxConfig
|
27 |
+
from .dbrx.model import DbrxForCausalLM
|
28 |
+
from .deepseek_v1.model import DeepseekForCausalLM
|
29 |
+
from .deepseek_v2.model import DeepseekV2ForCausalLM
|
30 |
+
from .dit.model import DiT
|
31 |
+
from .eagle.model import EagleForCausalLM
|
32 |
+
from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder
|
33 |
+
from .falcon.config import FalconConfig
|
34 |
+
from .falcon.model import FalconForCausalLM, FalconModel
|
35 |
+
from .gemma.config import GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig
|
36 |
+
from .gemma.model import GemmaForCausalLM
|
37 |
+
from .gpt.config import GPTConfig
|
38 |
+
from .gpt.model import GPTForCausalLM, GPTModel
|
39 |
+
from .gptj.config import GPTJConfig
|
40 |
+
from .gptj.model import GPTJForCausalLM, GPTJModel
|
41 |
+
from .gptneox.model import GPTNeoXForCausalLM, GPTNeoXModel
|
42 |
+
from .grok.model import GrokForCausalLM
|
43 |
+
from .llama.config import LLaMAConfig
|
44 |
+
from .llama.model import LLaMAForCausalLM, LLaMAModel
|
45 |
+
from .mamba.model import MambaForCausalLM
|
46 |
+
from .medusa.config import MedusaConfig
|
47 |
+
from .medusa.model import MedusaForCausalLm
|
48 |
+
from .mllama.model import MLLaMAModel
|
49 |
+
from .modeling_utils import (PretrainedConfig, PretrainedModel,
|
50 |
+
SpeculativeDecodingMode)
|
51 |
+
from .mpt.model import MPTForCausalLM, MPTModel
|
52 |
+
from .nemotron_nas.model import DeciLMForCausalLM
|
53 |
+
from .opt.model import OPTForCausalLM, OPTModel
|
54 |
+
from .phi3.model import Phi3ForCausalLM, Phi3Model
|
55 |
+
from .phi.model import PhiForCausalLM, PhiModel
|
56 |
+
from .qwen.model import QWenForCausalLM
|
57 |
+
from .recurrentgemma.model import RecurrentGemmaForCausalLM
|
58 |
+
from .redrafter.model import ReDrafterForCausalLM
|
59 |
+
from .f5tts.model import F5TTS
|
60 |
+
|
61 |
+
__all__ = [
|
62 |
+
'BertModel',
|
63 |
+
'BertForQuestionAnswering',
|
64 |
+
'BertForSequenceClassification',
|
65 |
+
'RobertaModel',
|
66 |
+
'RobertaForQuestionAnswering',
|
67 |
+
'RobertaForSequenceClassification',
|
68 |
+
'BloomModel',
|
69 |
+
'BloomForCausalLM',
|
70 |
+
'DiT',
|
71 |
+
'DeepseekForCausalLM',
|
72 |
+
'FalconConfig',
|
73 |
+
'DeepseekV2ForCausalLM',
|
74 |
+
'FalconForCausalLM',
|
75 |
+
'FalconModel',
|
76 |
+
'GPTConfig',
|
77 |
+
'GPTModel',
|
78 |
+
'GPTForCausalLM',
|
79 |
+
'OPTForCausalLM',
|
80 |
+
'OPTModel',
|
81 |
+
'LLaMAConfig',
|
82 |
+
'LLaMAForCausalLM',
|
83 |
+
'LLaMAModel',
|
84 |
+
'MedusaConfig',
|
85 |
+
'MedusaForCausalLm',
|
86 |
+
'ReDrafterForCausalLM',
|
87 |
+
'GPTJConfig',
|
88 |
+
'GPTJModel',
|
89 |
+
'GPTJForCausalLM',
|
90 |
+
'GPTNeoXModel',
|
91 |
+
'GPTNeoXForCausalLM',
|
92 |
+
'PhiModel',
|
93 |
+
'PhiConfig',
|
94 |
+
'Phi3Model',
|
95 |
+
'Phi3Config',
|
96 |
+
'PhiForCausalLM',
|
97 |
+
'Phi3ForCausalLM',
|
98 |
+
'ChatGLMConfig',
|
99 |
+
'ChatGLMForCausalLM',
|
100 |
+
'ChatGLMModel',
|
101 |
+
'BaichuanForCausalLM',
|
102 |
+
'QWenConfig'
|
103 |
+
'QWenForCausalLM',
|
104 |
+
'QWenModel',
|
105 |
+
'EncoderModel',
|
106 |
+
'DecoderModel',
|
107 |
+
'PretrainedConfig',
|
108 |
+
'PretrainedModel',
|
109 |
+
'WhisperEncoder',
|
110 |
+
'MambaForCausalLM',
|
111 |
+
'MambaConfig',
|
112 |
+
'MPTForCausalLM',
|
113 |
+
'MPTModel',
|
114 |
+
'SkyworkForCausalLM',
|
115 |
+
'GemmaConfig',
|
116 |
+
'GemmaForCausalLM',
|
117 |
+
'DbrxConfig',
|
118 |
+
'DbrxForCausalLM',
|
119 |
+
'RecurrentGemmaForCausalLM',
|
120 |
+
'CogVLMConfig',
|
121 |
+
'CogVLMForCausalLM',
|
122 |
+
'EagleForCausalLM',
|
123 |
+
'SpeculativeDecodingMode',
|
124 |
+
'CohereForCausalLM',
|
125 |
+
'MLLaMAModel',
|
126 |
+
'F5TTS',
|
127 |
+
]
|
128 |
+
|
129 |
+
MODEL_MAP = {
|
130 |
+
'GPT2LMHeadModel': GPTForCausalLM,
|
131 |
+
'GPT2LMHeadCustomModel': GPTForCausalLM,
|
132 |
+
'GPTBigCodeForCausalLM': GPTForCausalLM,
|
133 |
+
'Starcoder2ForCausalLM': GPTForCausalLM,
|
134 |
+
'FuyuForCausalLM': GPTForCausalLM,
|
135 |
+
'Kosmos2ForConditionalGeneration': GPTForCausalLM,
|
136 |
+
'JAISLMHeadModel': GPTForCausalLM,
|
137 |
+
'GPTForCausalLM': GPTForCausalLM,
|
138 |
+
'NemotronForCausalLM': GPTForCausalLM,
|
139 |
+
'OPTForCausalLM': OPTForCausalLM,
|
140 |
+
'BloomForCausalLM': BloomForCausalLM,
|
141 |
+
'RWForCausalLM': FalconForCausalLM,
|
142 |
+
'FalconForCausalLM': FalconForCausalLM,
|
143 |
+
'PhiForCausalLM': PhiForCausalLM,
|
144 |
+
'Phi3ForCausalLM': Phi3ForCausalLM,
|
145 |
+
'Phi3VForCausalLM': Phi3ForCausalLM,
|
146 |
+
'Phi3SmallForCausalLM': Phi3ForCausalLM,
|
147 |
+
'PhiMoEForCausalLM': Phi3ForCausalLM,
|
148 |
+
'MambaForCausalLM': MambaForCausalLM,
|
149 |
+
'GPTNeoXForCausalLM': GPTNeoXForCausalLM,
|
150 |
+
'GPTJForCausalLM': GPTJForCausalLM,
|
151 |
+
'MPTForCausalLM': MPTForCausalLM,
|
152 |
+
'GLMModel': ChatGLMForCausalLM,
|
153 |
+
'ChatGLMModel': ChatGLMForCausalLM,
|
154 |
+
'ChatGLMForCausalLM': ChatGLMForCausalLM,
|
155 |
+
'LlamaForCausalLM': LLaMAForCausalLM,
|
156 |
+
'ExaoneForCausalLM': LLaMAForCausalLM,
|
157 |
+
'MistralForCausalLM': LLaMAForCausalLM,
|
158 |
+
'MixtralForCausalLM': LLaMAForCausalLM,
|
159 |
+
'ArcticForCausalLM': LLaMAForCausalLM,
|
160 |
+
'Grok1ModelForCausalLM': GrokForCausalLM,
|
161 |
+
'InternLMForCausalLM': LLaMAForCausalLM,
|
162 |
+
'InternLM2ForCausalLM': LLaMAForCausalLM,
|
163 |
+
'MedusaForCausalLM': MedusaForCausalLm,
|
164 |
+
'ReDrafterForCausalLM': ReDrafterForCausalLM,
|
165 |
+
'BaichuanForCausalLM': BaichuanForCausalLM,
|
166 |
+
'BaiChuanForCausalLM': BaichuanForCausalLM,
|
167 |
+
'SkyworkForCausalLM': LLaMAForCausalLM,
|
168 |
+
GEMMA_ARCHITECTURE: GemmaForCausalLM,
|
169 |
+
GEMMA2_ARCHITECTURE: GemmaForCausalLM,
|
170 |
+
'QWenLMHeadModel': QWenForCausalLM,
|
171 |
+
'QWenForCausalLM': QWenForCausalLM,
|
172 |
+
'Qwen2ForCausalLM': QWenForCausalLM,
|
173 |
+
'Qwen2MoeForCausalLM': QWenForCausalLM,
|
174 |
+
'Qwen2ForSequenceClassification': QWenForCausalLM,
|
175 |
+
'Qwen2VLForConditionalGeneration': QWenForCausalLM,
|
176 |
+
'WhisperEncoder': WhisperEncoder,
|
177 |
+
'EncoderModel': EncoderModel,
|
178 |
+
'DecoderModel': DecoderModel,
|
179 |
+
'DbrxForCausalLM': DbrxForCausalLM,
|
180 |
+
'RecurrentGemmaForCausalLM': RecurrentGemmaForCausalLM,
|
181 |
+
'CogVLMForCausalLM': CogVLMForCausalLM,
|
182 |
+
'DiT': DiT,
|
183 |
+
'DeepseekForCausalLM': DeepseekForCausalLM,
|
184 |
+
'DeciLMForCausalLM': DeciLMForCausalLM,
|
185 |
+
'DeepseekV2ForCausalLM': DeepseekV2ForCausalLM,
|
186 |
+
'EagleForCausalLM': EagleForCausalLM,
|
187 |
+
'CohereForCausalLM': CohereForCausalLM,
|
188 |
+
'MllamaForConditionalGeneration': MLLaMAModel,
|
189 |
+
'BertForQuestionAnswering': BertForQuestionAnswering,
|
190 |
+
'BertForSequenceClassification': BertForSequenceClassification,
|
191 |
+
'BertModel': BertModel,
|
192 |
+
'RobertaModel': RobertaModel,
|
193 |
+
'RobertaForQuestionAnswering': RobertaForQuestionAnswering,
|
194 |
+
'RobertaForSequenceClassification': RobertaForSequenceClassification,
|
195 |
+
'F5TTS': F5TTS
|
196 |
+
}
|
src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
import sys
|
3 |
+
import os
|
4 |
+
|
5 |
+
import tensorrt as trt
|
6 |
+
from collections import OrderedDict
|
7 |
+
from ..._utils import str_dtype_to_trt
|
8 |
+
from ...plugin import current_all_reduce_helper
|
9 |
+
from ..modeling_utils import PretrainedConfig, PretrainedModel
|
10 |
+
from ...functional import Tensor, concat
|
11 |
+
from ...module import Module, ModuleList
|
12 |
+
from tensorrt_llm._common import default_net
|
13 |
+
from ...layers import Linear
|
14 |
+
|
15 |
+
from .modules import (
|
16 |
+
TimestepEmbedding,
|
17 |
+
ConvPositionEmbedding,
|
18 |
+
DiTBlock,
|
19 |
+
AdaLayerNormZero_Final,
|
20 |
+
)
|
21 |
+
|
22 |
+
current_file_path = os.path.abspath(__file__)
|
23 |
+
parent_dir = os.path.dirname(current_file_path)
|
24 |
+
sys.path.append(parent_dir)
|
25 |
+
|
26 |
+
|
27 |
+
class InputEmbedding(Module):
|
28 |
+
def __init__(self, mel_dim, text_dim, out_dim):
|
29 |
+
super().__init__()
|
30 |
+
self.proj = Linear(mel_dim * 2 + text_dim, out_dim)
|
31 |
+
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
|
32 |
+
|
33 |
+
def forward(self, x, cond):
|
34 |
+
x = self.proj(concat([x, cond], dim=-1))
|
35 |
+
return self.conv_pos_embed(x) + x
|
36 |
+
|
37 |
+
|
38 |
+
class F5TTS(PretrainedModel):
|
39 |
+
def __init__(self, config: PretrainedConfig):
|
40 |
+
super().__init__(config)
|
41 |
+
self.dtype = str_dtype_to_trt(config.dtype)
|
42 |
+
|
43 |
+
self.time_embed = TimestepEmbedding(config.hidden_size)
|
44 |
+
self.input_embed = InputEmbedding(config.mel_dim, config.text_dim, config.hidden_size)
|
45 |
+
|
46 |
+
self.dim = config.hidden_size
|
47 |
+
self.depth = config.num_hidden_layers
|
48 |
+
self.transformer_blocks = ModuleList(
|
49 |
+
[
|
50 |
+
DiTBlock(
|
51 |
+
dim=self.dim,
|
52 |
+
heads=config.num_attention_heads,
|
53 |
+
dim_head=config.dim_head,
|
54 |
+
ff_mult=config.ff_mult,
|
55 |
+
dropout=config.dropout,
|
56 |
+
)
|
57 |
+
for _ in range(self.depth)
|
58 |
+
]
|
59 |
+
)
|
60 |
+
|
61 |
+
self.norm_out = AdaLayerNormZero_Final(config.hidden_size) # final modulation
|
62 |
+
self.proj_out = Linear(config.hidden_size, config.mel_dim)
|
63 |
+
|
64 |
+
def forward(
|
65 |
+
self,
|
66 |
+
noise, # nosied input audio
|
67 |
+
cond, # masked cond audio
|
68 |
+
time, # time step
|
69 |
+
rope_cos,
|
70 |
+
rope_sin,
|
71 |
+
input_lengths,
|
72 |
+
scale=1.0,
|
73 |
+
):
|
74 |
+
t = self.time_embed(time)
|
75 |
+
x = self.input_embed(noise, cond)
|
76 |
+
for block in self.transformer_blocks:
|
77 |
+
x = block(x, t, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
|
78 |
+
denoise = self.proj_out(self.norm_out(x, t))
|
79 |
+
denoise.mark_output("denoised", self.dtype)
|
80 |
+
return denoise
|
81 |
+
|
82 |
+
def prepare_inputs(self, **kwargs):
|
83 |
+
max_batch_size = kwargs["max_batch_size"]
|
84 |
+
batch_size_range = [2, 2, max_batch_size]
|
85 |
+
mel_size = 100
|
86 |
+
max_seq_len = 3000
|
87 |
+
num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size]
|
88 |
+
hidden_size = 512
|
89 |
+
concat_feature_dim = mel_size + hidden_size
|
90 |
+
freq_embed_dim = 256
|
91 |
+
head_dim = 64
|
92 |
+
mapping = self.config.mapping
|
93 |
+
if mapping.tp_size > 1:
|
94 |
+
current_all_reduce_helper().set_workspace_tensor(mapping, 1)
|
95 |
+
if default_net().plugin_config.remove_input_padding:
|
96 |
+
noise = Tensor(
|
97 |
+
name="noise",
|
98 |
+
dtype=self.dtype,
|
99 |
+
shape=[-1, mel_size],
|
100 |
+
dim_range=OrderedDict(
|
101 |
+
[
|
102 |
+
("num_frames", [num_frames_range]),
|
103 |
+
("n_mels", [mel_size]),
|
104 |
+
]
|
105 |
+
),
|
106 |
+
)
|
107 |
+
cond = Tensor(
|
108 |
+
name="cond",
|
109 |
+
dtype=self.dtype,
|
110 |
+
shape=[-1, concat_feature_dim],
|
111 |
+
dim_range=OrderedDict(
|
112 |
+
[
|
113 |
+
("num_frames", [num_frames_range]),
|
114 |
+
("embeded_length", [concat_feature_dim]),
|
115 |
+
]
|
116 |
+
),
|
117 |
+
)
|
118 |
+
time = Tensor(
|
119 |
+
name="time",
|
120 |
+
dtype=self.dtype,
|
121 |
+
shape=[-1, freq_embed_dim],
|
122 |
+
dim_range=OrderedDict(
|
123 |
+
[
|
124 |
+
("num_frames", [num_frames_range]),
|
125 |
+
("freq_dim", [freq_embed_dim]),
|
126 |
+
]
|
127 |
+
),
|
128 |
+
)
|
129 |
+
rope_cos = Tensor(
|
130 |
+
name="rope_cos",
|
131 |
+
dtype=self.dtype,
|
132 |
+
shape=[-1, head_dim],
|
133 |
+
dim_range=OrderedDict(
|
134 |
+
[
|
135 |
+
("num_frames", [num_frames_range]),
|
136 |
+
("head_dim", [head_dim]),
|
137 |
+
]
|
138 |
+
),
|
139 |
+
)
|
140 |
+
rope_sin = Tensor(
|
141 |
+
name="rope_sin",
|
142 |
+
dtype=self.dtype,
|
143 |
+
shape=[-1, head_dim],
|
144 |
+
dim_range=OrderedDict(
|
145 |
+
[
|
146 |
+
("num_frames", [num_frames_range]),
|
147 |
+
("head_dim", [head_dim]),
|
148 |
+
]
|
149 |
+
),
|
150 |
+
)
|
151 |
+
|
152 |
+
else:
|
153 |
+
noise = Tensor(
|
154 |
+
name="noise",
|
155 |
+
dtype=self.dtype,
|
156 |
+
shape=[-1, -1, mel_size],
|
157 |
+
dim_range=OrderedDict(
|
158 |
+
[
|
159 |
+
("batch_size", [batch_size_range]),
|
160 |
+
("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]),
|
161 |
+
("n_mels", [mel_size]),
|
162 |
+
]
|
163 |
+
),
|
164 |
+
)
|
165 |
+
cond = Tensor(
|
166 |
+
name="cond",
|
167 |
+
dtype=self.dtype,
|
168 |
+
shape=[-1, -1, concat_feature_dim],
|
169 |
+
dim_range=OrderedDict(
|
170 |
+
[
|
171 |
+
("batch_size", [batch_size_range]),
|
172 |
+
("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]),
|
173 |
+
("embeded_length", [concat_feature_dim]),
|
174 |
+
]
|
175 |
+
),
|
176 |
+
)
|
177 |
+
time = Tensor(
|
178 |
+
name="time",
|
179 |
+
dtype=self.dtype,
|
180 |
+
shape=[-1, freq_embed_dim],
|
181 |
+
dim_range=OrderedDict(
|
182 |
+
[
|
183 |
+
("batch_size", [batch_size_range]),
|
184 |
+
("freq_dim", [freq_embed_dim]),
|
185 |
+
]
|
186 |
+
),
|
187 |
+
)
|
188 |
+
rope_cos = Tensor(
|
189 |
+
name="rope_cos",
|
190 |
+
dtype=self.dtype,
|
191 |
+
shape=[-1, -1, head_dim],
|
192 |
+
dim_range=OrderedDict(
|
193 |
+
[
|
194 |
+
("batch_size", [batch_size_range]),
|
195 |
+
("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]),
|
196 |
+
("head_dim", [head_dim]),
|
197 |
+
]
|
198 |
+
),
|
199 |
+
)
|
200 |
+
rope_sin = Tensor(
|
201 |
+
name="rope_sin",
|
202 |
+
dtype=self.dtype,
|
203 |
+
shape=[-1, -1, head_dim],
|
204 |
+
dim_range=OrderedDict(
|
205 |
+
[
|
206 |
+
("batch_size", [batch_size_range]),
|
207 |
+
("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]),
|
208 |
+
("head_dim", [head_dim]),
|
209 |
+
]
|
210 |
+
),
|
211 |
+
)
|
212 |
+
input_lengths = Tensor(
|
213 |
+
name="input_lengths",
|
214 |
+
dtype=trt.int32,
|
215 |
+
shape=[-1],
|
216 |
+
dim_range=OrderedDict([("batch_size", [batch_size_range])]),
|
217 |
+
)
|
218 |
+
return {
|
219 |
+
"noise": noise,
|
220 |
+
"cond": cond,
|
221 |
+
"time": time,
|
222 |
+
"rope_cos": rope_cos,
|
223 |
+
"rope_sin": rope_sin,
|
224 |
+
"input_lengths": input_lengths,
|
225 |
+
}
|
src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py
ADDED
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import math
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
from tensorrt_llm._common import default_net
|
11 |
+
from ..._utils import trt_dtype_to_np, str_dtype_to_trt
|
12 |
+
from ...functional import (
|
13 |
+
Tensor,
|
14 |
+
chunk,
|
15 |
+
concat,
|
16 |
+
constant,
|
17 |
+
expand,
|
18 |
+
shape,
|
19 |
+
silu,
|
20 |
+
slice,
|
21 |
+
permute,
|
22 |
+
expand_mask,
|
23 |
+
expand_dims_like,
|
24 |
+
unsqueeze,
|
25 |
+
matmul,
|
26 |
+
softmax,
|
27 |
+
squeeze,
|
28 |
+
cast,
|
29 |
+
gelu,
|
30 |
+
)
|
31 |
+
from ...functional import expand_dims, view, bert_attention
|
32 |
+
from ...layers import LayerNorm, Linear, Conv1d, Mish, RowLinear, ColumnLinear
|
33 |
+
from ...module import Module
|
34 |
+
|
35 |
+
|
36 |
+
class FeedForward(Module):
|
37 |
+
def __init__(self, dim, dim_out=None, mult=4, dropout=0.0):
|
38 |
+
super().__init__()
|
39 |
+
inner_dim = int(dim * mult)
|
40 |
+
dim_out = dim_out if dim_out is not None else dim
|
41 |
+
|
42 |
+
self.project_in = Linear(dim, inner_dim)
|
43 |
+
self.ff = Linear(inner_dim, dim_out)
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
return self.ff(gelu(self.project_in(x)))
|
47 |
+
|
48 |
+
|
49 |
+
class AdaLayerNormZero(Module):
|
50 |
+
def __init__(self, dim):
|
51 |
+
super().__init__()
|
52 |
+
|
53 |
+
self.linear = Linear(dim, dim * 6)
|
54 |
+
self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
55 |
+
|
56 |
+
def forward(self, x, emb=None):
|
57 |
+
emb = self.linear(silu(emb))
|
58 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunk(emb, 6, dim=1)
|
59 |
+
x = self.norm(x)
|
60 |
+
ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
|
61 |
+
if default_net().plugin_config.remove_input_padding:
|
62 |
+
x = x * (ones + scale_msa) + shift_msa
|
63 |
+
else:
|
64 |
+
x = x * (ones + unsqueeze(scale_msa, 1)) + unsqueeze(shift_msa, 1)
|
65 |
+
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
66 |
+
|
67 |
+
|
68 |
+
class AdaLayerNormZero_Final(Module):
|
69 |
+
def __init__(self, dim):
|
70 |
+
super().__init__()
|
71 |
+
|
72 |
+
self.linear = Linear(dim, dim * 2)
|
73 |
+
|
74 |
+
self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
75 |
+
|
76 |
+
def forward(self, x, emb):
|
77 |
+
emb = self.linear(silu(emb))
|
78 |
+
scale, shift = chunk(emb, 2, dim=1)
|
79 |
+
ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
|
80 |
+
if default_net().plugin_config.remove_input_padding:
|
81 |
+
x = self.norm(x) * (ones + scale) + shift
|
82 |
+
else:
|
83 |
+
x = self.norm(x) * unsqueeze((ones + scale), 1)
|
84 |
+
x = x + unsqueeze(shift, 1)
|
85 |
+
return x
|
86 |
+
|
87 |
+
|
88 |
+
class ConvPositionEmbedding(Module):
|
89 |
+
def __init__(self, dim, kernel_size=31, groups=16):
|
90 |
+
super().__init__()
|
91 |
+
assert kernel_size % 2 != 0
|
92 |
+
self.conv1d1 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
|
93 |
+
self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
|
94 |
+
self.mish = Mish()
|
95 |
+
|
96 |
+
def forward(self, x, mask=None): # noqa: F722
|
97 |
+
if default_net().plugin_config.remove_input_padding:
|
98 |
+
x = unsqueeze(x, 0)
|
99 |
+
x = permute(x, [0, 2, 1])
|
100 |
+
x = self.mish(self.conv1d2(self.mish(self.conv1d1(x))))
|
101 |
+
out = permute(x, [0, 2, 1])
|
102 |
+
if default_net().plugin_config.remove_input_padding:
|
103 |
+
out = squeeze(out, 0)
|
104 |
+
return out
|
105 |
+
|
106 |
+
|
107 |
+
class Attention(Module):
|
108 |
+
def __init__(
|
109 |
+
self,
|
110 |
+
processor: AttnProcessor,
|
111 |
+
dim: int,
|
112 |
+
heads: int = 16,
|
113 |
+
dim_head: int = 64,
|
114 |
+
dropout: float = 0.0,
|
115 |
+
context_dim: Optional[int] = None, # if not None -> joint attention
|
116 |
+
context_pre_only=None,
|
117 |
+
):
|
118 |
+
super().__init__()
|
119 |
+
|
120 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
121 |
+
raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
122 |
+
|
123 |
+
self.processor = processor
|
124 |
+
|
125 |
+
self.dim = dim # hidden_size
|
126 |
+
self.heads = heads
|
127 |
+
self.inner_dim = dim_head * heads
|
128 |
+
self.dropout = dropout
|
129 |
+
self.attention_head_size = dim_head
|
130 |
+
self.context_dim = context_dim
|
131 |
+
self.context_pre_only = context_pre_only
|
132 |
+
self.tp_size = 1
|
133 |
+
self.num_attention_heads = heads // self.tp_size
|
134 |
+
self.num_attention_kv_heads = heads // self.tp_size # 8
|
135 |
+
self.dtype = str_dtype_to_trt("float32")
|
136 |
+
self.attention_hidden_size = self.attention_head_size * self.num_attention_heads
|
137 |
+
self.to_q = ColumnLinear(
|
138 |
+
dim,
|
139 |
+
self.tp_size * self.num_attention_heads * self.attention_head_size,
|
140 |
+
bias=True,
|
141 |
+
dtype=self.dtype,
|
142 |
+
tp_group=None,
|
143 |
+
tp_size=self.tp_size,
|
144 |
+
)
|
145 |
+
self.to_k = ColumnLinear(
|
146 |
+
dim,
|
147 |
+
self.tp_size * self.num_attention_heads * self.attention_head_size,
|
148 |
+
bias=True,
|
149 |
+
dtype=self.dtype,
|
150 |
+
tp_group=None,
|
151 |
+
tp_size=self.tp_size,
|
152 |
+
)
|
153 |
+
self.to_v = ColumnLinear(
|
154 |
+
dim,
|
155 |
+
self.tp_size * self.num_attention_heads * self.attention_head_size,
|
156 |
+
bias=True,
|
157 |
+
dtype=self.dtype,
|
158 |
+
tp_group=None,
|
159 |
+
tp_size=self.tp_size,
|
160 |
+
)
|
161 |
+
|
162 |
+
if self.context_dim is not None:
|
163 |
+
self.to_k_c = Linear(context_dim, self.inner_dim)
|
164 |
+
self.to_v_c = Linear(context_dim, self.inner_dim)
|
165 |
+
if self.context_pre_only is not None:
|
166 |
+
self.to_q_c = Linear(context_dim, self.inner_dim)
|
167 |
+
|
168 |
+
self.to_out = RowLinear(
|
169 |
+
self.tp_size * self.num_attention_heads * self.attention_head_size,
|
170 |
+
dim,
|
171 |
+
bias=True,
|
172 |
+
dtype=self.dtype,
|
173 |
+
tp_group=None,
|
174 |
+
tp_size=self.tp_size,
|
175 |
+
)
|
176 |
+
|
177 |
+
if self.context_pre_only is not None and not self.context_pre_only:
|
178 |
+
self.to_out_c = Linear(self.inner_dim, dim)
|
179 |
+
|
180 |
+
def forward(
|
181 |
+
self,
|
182 |
+
x, # noised input x
|
183 |
+
rope_cos,
|
184 |
+
rope_sin,
|
185 |
+
input_lengths,
|
186 |
+
c=None, # context c
|
187 |
+
scale=1.0,
|
188 |
+
rope=None,
|
189 |
+
c_rope=None, # rotary position embedding for c
|
190 |
+
) -> torch.Tensor:
|
191 |
+
if c is not None:
|
192 |
+
return self.processor(self, x, c=c, input_lengths=input_lengths, scale=scale, rope=rope, c_rope=c_rope)
|
193 |
+
else:
|
194 |
+
return self.processor(
|
195 |
+
self, x, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale
|
196 |
+
)
|
197 |
+
|
198 |
+
|
199 |
+
def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
|
200 |
+
shape_tensor = concat(
|
201 |
+
[shape(tensor, i) / 2 if i == (tensor.ndim() - 1) else shape(tensor, i) for i in range(tensor.ndim())]
|
202 |
+
)
|
203 |
+
if default_net().plugin_config.remove_input_padding:
|
204 |
+
assert tensor.ndim() == 2
|
205 |
+
x1 = slice(tensor, [0, 0], shape_tensor, [1, 2])
|
206 |
+
x2 = slice(tensor, [0, 1], shape_tensor, [1, 2])
|
207 |
+
x1 = expand_dims(x1, 2)
|
208 |
+
x2 = expand_dims(x2, 2)
|
209 |
+
zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))))
|
210 |
+
x2 = zero - x2
|
211 |
+
x = concat([x2, x1], 2)
|
212 |
+
out = view(x, concat([shape(x, 0), shape(x, 1) * 2]))
|
213 |
+
else:
|
214 |
+
assert tensor.ndim() == 3
|
215 |
+
|
216 |
+
x1 = slice(tensor, [0, 0, 0], shape_tensor, [1, 1, 2])
|
217 |
+
x2 = slice(tensor, [0, 0, 1], shape_tensor, [1, 1, 2])
|
218 |
+
x1 = expand_dims(x1, 3)
|
219 |
+
x2 = expand_dims(x2, 3)
|
220 |
+
zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))))
|
221 |
+
x2 = zero - x2
|
222 |
+
x = concat([x2, x1], 3)
|
223 |
+
out = view(x, concat([shape(x, 0), shape(x, 1), shape(x, 2) * 2]))
|
224 |
+
|
225 |
+
return out
|
226 |
+
|
227 |
+
|
228 |
+
def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin):
|
229 |
+
if default_net().plugin_config.remove_input_padding:
|
230 |
+
rot_dim = shape(rope_cos, -1) # 64
|
231 |
+
new_t_shape = concat([shape(x, 0), rot_dim]) # (-1, 64)
|
232 |
+
x_ = slice(x, [0, 0], new_t_shape, [1, 1])
|
233 |
+
end_dim = shape(x, -1) - shape(rope_cos, -1)
|
234 |
+
new_t_unrotated_shape = concat([shape(x, 0), end_dim]) # (2, -1, 960)
|
235 |
+
x_unrotated = slice(x, concat([0, rot_dim]), new_t_unrotated_shape, [1, 1])
|
236 |
+
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
|
237 |
+
else:
|
238 |
+
rot_dim = shape(rope_cos, 2) # 64
|
239 |
+
new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim]) # (2, -1, 64)
|
240 |
+
x_ = slice(x, [0, 0, 0], new_t_shape, [1, 1, 1])
|
241 |
+
end_dim = shape(x, 2) - shape(rope_cos, 2)
|
242 |
+
new_t_unrotated_shape = concat([shape(x, 0), shape(x, 1), end_dim]) # (2, -1, 960)
|
243 |
+
x_unrotated = slice(x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1])
|
244 |
+
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
|
245 |
+
return out
|
246 |
+
|
247 |
+
|
248 |
+
class AttnProcessor:
|
249 |
+
def __init__(self):
|
250 |
+
pass
|
251 |
+
|
252 |
+
def __call__(
|
253 |
+
self,
|
254 |
+
attn,
|
255 |
+
x, # noised input x
|
256 |
+
rope_cos,
|
257 |
+
rope_sin,
|
258 |
+
input_lengths,
|
259 |
+
scale=1.0,
|
260 |
+
rope=None,
|
261 |
+
) -> torch.FloatTensor:
|
262 |
+
query = attn.to_q(x)
|
263 |
+
key = attn.to_k(x)
|
264 |
+
value = attn.to_v(x)
|
265 |
+
# k,v,q all (2,1226,1024)
|
266 |
+
query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin)
|
267 |
+
key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin)
|
268 |
+
|
269 |
+
# attention
|
270 |
+
inner_dim = key.shape[-1]
|
271 |
+
norm_factor = math.sqrt(attn.attention_head_size)
|
272 |
+
q_scaling = 1.0 / norm_factor
|
273 |
+
mask = None
|
274 |
+
if not default_net().plugin_config.remove_input_padding:
|
275 |
+
N = shape(x, 1)
|
276 |
+
B = shape(x, 0)
|
277 |
+
seq_len_2d = concat([1, N])
|
278 |
+
max_position_embeddings = 4096
|
279 |
+
# create position ids
|
280 |
+
position_ids_buffer = constant(np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0))
|
281 |
+
tmp_position_ids = slice(position_ids_buffer, starts=[0, 0], sizes=seq_len_2d)
|
282 |
+
tmp_position_ids = expand(tmp_position_ids, concat([B, N])) # BxL
|
283 |
+
tmp_input_lengths = unsqueeze(input_lengths, 1) # Bx1
|
284 |
+
tmp_input_lengths = expand(tmp_input_lengths, concat([B, N])) # BxL
|
285 |
+
mask = tmp_position_ids < tmp_input_lengths # BxL
|
286 |
+
mask = mask.cast("int32")
|
287 |
+
|
288 |
+
if default_net().plugin_config.bert_attention_plugin:
|
289 |
+
qkv = concat([query, key, value], dim=-1)
|
290 |
+
# TRT plugin mode
|
291 |
+
assert input_lengths is not None
|
292 |
+
if default_net().plugin_config.remove_input_padding:
|
293 |
+
qkv = qkv.view(concat([-1, 3 * inner_dim]))
|
294 |
+
max_input_length = constant(
|
295 |
+
np.zeros(
|
296 |
+
[
|
297 |
+
2048,
|
298 |
+
],
|
299 |
+
dtype=np.int32,
|
300 |
+
)
|
301 |
+
)
|
302 |
+
else:
|
303 |
+
max_input_length = None
|
304 |
+
context = bert_attention(
|
305 |
+
qkv,
|
306 |
+
input_lengths,
|
307 |
+
attn.num_attention_heads,
|
308 |
+
attn.attention_head_size,
|
309 |
+
q_scaling=q_scaling,
|
310 |
+
max_input_length=max_input_length,
|
311 |
+
)
|
312 |
+
else:
|
313 |
+
assert not default_net().plugin_config.remove_input_padding
|
314 |
+
|
315 |
+
def transpose_for_scores(x):
|
316 |
+
new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size])
|
317 |
+
|
318 |
+
y = x.view(new_x_shape)
|
319 |
+
y = y.transpose(1, 2)
|
320 |
+
return y
|
321 |
+
|
322 |
+
def transpose_for_scores_k(x):
|
323 |
+
new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size])
|
324 |
+
|
325 |
+
y = x.view(new_x_shape)
|
326 |
+
y = y.permute([0, 2, 3, 1])
|
327 |
+
return y
|
328 |
+
|
329 |
+
query = transpose_for_scores(query)
|
330 |
+
key = transpose_for_scores_k(key)
|
331 |
+
value = transpose_for_scores(value)
|
332 |
+
|
333 |
+
attention_scores = matmul(query, key, use_fp32_acc=False)
|
334 |
+
|
335 |
+
if mask is not None:
|
336 |
+
attention_mask = expand_mask(mask, shape(query, 2))
|
337 |
+
attention_mask = cast(attention_mask, attention_scores.dtype)
|
338 |
+
attention_scores = attention_scores + attention_mask
|
339 |
+
|
340 |
+
attention_probs = softmax(attention_scores, dim=-1)
|
341 |
+
|
342 |
+
context = matmul(attention_probs, value, use_fp32_acc=False).transpose(1, 2)
|
343 |
+
context = context.view(concat([shape(context, 0), shape(context, 1), attn.attention_hidden_size]))
|
344 |
+
context = attn.to_out(context)
|
345 |
+
if mask is not None:
|
346 |
+
mask = mask.view(concat([shape(mask, 0), shape(mask, 1), 1]))
|
347 |
+
mask = expand_dims_like(mask, context)
|
348 |
+
mask = cast(mask, context.dtype)
|
349 |
+
context = context * mask
|
350 |
+
return context
|
351 |
+
|
352 |
+
|
353 |
+
# DiT Block
|
354 |
+
class DiTBlock(Module):
|
355 |
+
def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1):
|
356 |
+
super().__init__()
|
357 |
+
|
358 |
+
self.attn_norm = AdaLayerNormZero(dim)
|
359 |
+
self.attn = Attention(
|
360 |
+
processor=AttnProcessor(),
|
361 |
+
dim=dim,
|
362 |
+
heads=heads,
|
363 |
+
dim_head=dim_head,
|
364 |
+
dropout=dropout,
|
365 |
+
)
|
366 |
+
|
367 |
+
self.ff_norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
368 |
+
self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout)
|
369 |
+
|
370 |
+
def forward(
|
371 |
+
self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=ModuleNotFoundError
|
372 |
+
): # x: noised input, t: time embedding
|
373 |
+
# pre-norm & modulation for attention input
|
374 |
+
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
|
375 |
+
# attention
|
376 |
+
# norm ----> (2,1226,1024)
|
377 |
+
attn_output = self.attn(x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
|
378 |
+
|
379 |
+
# process attention output for input x
|
380 |
+
if default_net().plugin_config.remove_input_padding:
|
381 |
+
x = x + gate_msa * attn_output
|
382 |
+
else:
|
383 |
+
x = x + unsqueeze(gate_msa, 1) * attn_output
|
384 |
+
ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
|
385 |
+
if default_net().plugin_config.remove_input_padding:
|
386 |
+
norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
|
387 |
+
else:
|
388 |
+
norm = self.ff_norm(x) * (ones + unsqueeze(scale_mlp, 1)) + unsqueeze(shift_mlp, 1)
|
389 |
+
# norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
|
390 |
+
ff_output = self.ff(norm)
|
391 |
+
if default_net().plugin_config.remove_input_padding:
|
392 |
+
x = x + gate_mlp * ff_output
|
393 |
+
else:
|
394 |
+
x = x + unsqueeze(gate_mlp, 1) * ff_output
|
395 |
+
|
396 |
+
return x
|
397 |
+
|
398 |
+
|
399 |
+
class TimestepEmbedding(Module):
|
400 |
+
def __init__(self, dim, freq_embed_dim=256, dtype=None):
|
401 |
+
super().__init__()
|
402 |
+
# self.time_embed = SinusPositionEmbedding(freq_embed_dim)
|
403 |
+
self.mlp1 = Linear(freq_embed_dim, dim, bias=True, dtype=dtype)
|
404 |
+
self.mlp2 = Linear(dim, dim, bias=True, dtype=dtype)
|
405 |
+
|
406 |
+
def forward(self, timestep):
|
407 |
+
t_freq = self.mlp1(timestep)
|
408 |
+
t_freq = silu(t_freq)
|
409 |
+
t_emb = self.mlp2(t_freq)
|
410 |
+
return t_emb
|
src/f5_tts/runtime/triton_trtllm/run.sh
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
stage=$1
|
2 |
+
stop_stage=$2
|
3 |
+
model=$3 # F5TTS_Base
|
4 |
+
if [ -z "$model" ]; then
|
5 |
+
echo "Model is none"
|
6 |
+
exit 1
|
7 |
+
fi
|
8 |
+
echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model"
|
9 |
+
export CUDA_VISIBLE_DEVICES=0
|
10 |
+
|
11 |
+
F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS
|
12 |
+
F5_TTS_TRT_LLM_CHECKPOINT_PATH=./trtllm_ckpt
|
13 |
+
F5_TTS_TRT_LLM_ENGINE_PATH=./f5_trt_llm_engine
|
14 |
+
|
15 |
+
vocoder_trt_engine_path=vocos_vocoder.plan
|
16 |
+
model_repo=./model_repo
|
17 |
+
|
18 |
+
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
19 |
+
echo "Downloading f5 tts from huggingface"
|
20 |
+
huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH
|
21 |
+
|
22 |
+
fi
|
23 |
+
|
24 |
+
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
25 |
+
echo "Converting checkpoint"
|
26 |
+
python3 ./scripts/convert_checkpoint.py \
|
27 |
+
--timm_ckpt "$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt" \
|
28 |
+
--output_dir "$F5_TTS_TRT_LLM_CHECKPOINT_PATH" --model_name $model
|
29 |
+
python_package_path=/usr/local/lib/python3.12/dist-packages
|
30 |
+
cp -r patch/* $python_package_path/tensorrt_llm/models
|
31 |
+
trtllm-build --checkpoint_dir $F5_TTS_TRT_LLM_CHECKPOINT_PATH \
|
32 |
+
--max_batch_size 8 \
|
33 |
+
--output_dir $F5_TTS_TRT_LLM_ENGINE_PATH --remove_input_padding disable
|
34 |
+
fi
|
35 |
+
|
36 |
+
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
37 |
+
echo "Exporting vocos vocoder"
|
38 |
+
onnx_vocoder_path=vocos_vocoder.onnx
|
39 |
+
python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $onnx_vocoder_path
|
40 |
+
bash scripts/export_vocos_trt.sh $onnx_vocoder_path $vocoder_trt_engine_path
|
41 |
+
fi
|
42 |
+
|
43 |
+
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
44 |
+
echo "Building triton server"
|
45 |
+
rm -r $model_repo
|
46 |
+
cp -r ./model_repo_f5_tts $model_repo
|
47 |
+
python3 scripts/fill_template.py -i $model_repo/f5_tts/config.pbtxt vocab:$F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt,model:$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt,trtllm:$F5_TTS_TRT_LLM_ENGINE_PATH,vocoder:vocos
|
48 |
+
cp $vocoder_trt_engine_path $model_repo/vocoder/1/vocoder.plan
|
49 |
+
fi
|
50 |
+
|
51 |
+
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
52 |
+
echo "Starting triton server"
|
53 |
+
tritonserver --model-repository=$model_repo
|
54 |
+
fi
|
55 |
+
|
56 |
+
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
57 |
+
echo "Testing triton server"
|
58 |
+
num_task=1
|
59 |
+
log_dir=./log_concurrent_tasks_${num_task}
|
60 |
+
rm -r $log_dir
|
61 |
+
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir
|
62 |
+
fi
|
63 |
+
|
64 |
+
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
65 |
+
echo "Testing http client"
|
66 |
+
audio=../../infer/examples/basic/basic_ref_en.wav
|
67 |
+
reference_text="Some call me nature, others call me mother nature."
|
68 |
+
target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
|
69 |
+
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text"
|
70 |
+
fi
|
src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from https://github.com/echocatzh/conv-stft/blob/master/conv_stft/conv_stft.py
|
2 |
+
|
3 |
+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
# MIT License
|
18 |
+
|
19 |
+
# Copyright (c) 2020 Shimin Zhang
|
20 |
+
|
21 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
22 |
+
# of this software and associated documentation files (the "Software"), to deal
|
23 |
+
# in the Software without restriction, including without limitation the rights
|
24 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
25 |
+
# copies of the Software, and to permit persons to whom the Software is
|
26 |
+
# furnished to do so, subject to the following conditions:
|
27 |
+
|
28 |
+
# The above copyright notice and this permission notice shall be included in all
|
29 |
+
# copies or substantial portions of the Software.
|
30 |
+
|
31 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
32 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
33 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
34 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
35 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
36 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
37 |
+
# SOFTWARE.
|
38 |
+
|
39 |
+
import torch as th
|
40 |
+
import torch.nn.functional as F
|
41 |
+
from scipy.signal import check_COLA, get_window
|
42 |
+
|
43 |
+
support_clp_op = None
|
44 |
+
if th.__version__ >= "1.7.0":
|
45 |
+
from torch.fft import rfft as fft
|
46 |
+
|
47 |
+
support_clp_op = True
|
48 |
+
else:
|
49 |
+
from torch import rfft as fft
|
50 |
+
|
51 |
+
|
52 |
+
class STFT(th.nn.Module):
|
53 |
+
def __init__(
|
54 |
+
self,
|
55 |
+
win_len=1024,
|
56 |
+
win_hop=512,
|
57 |
+
fft_len=1024,
|
58 |
+
enframe_mode="continue",
|
59 |
+
win_type="hann",
|
60 |
+
win_sqrt=False,
|
61 |
+
pad_center=True,
|
62 |
+
):
|
63 |
+
"""
|
64 |
+
Implement of STFT using 1D convolution and 1D transpose convolutions.
|
65 |
+
Implement of framing the signal in 2 ways, `break` and `continue`.
|
66 |
+
`break` method is a kaldi-like framing.
|
67 |
+
`continue` method is a librosa-like framing.
|
68 |
+
|
69 |
+
More information about `perfect reconstruction`:
|
70 |
+
1. https://ww2.mathworks.cn/help/signal/ref/stft.html
|
71 |
+
2. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.get_window.html
|
72 |
+
|
73 |
+
Args:
|
74 |
+
win_len (int): Number of points in one frame. Defaults to 1024.
|
75 |
+
win_hop (int): Number of framing stride. Defaults to 512.
|
76 |
+
fft_len (int): Number of DFT points. Defaults to 1024.
|
77 |
+
enframe_mode (str, optional): `break` and `continue`. Defaults to 'continue'.
|
78 |
+
win_type (str, optional): The type of window to create. Defaults to 'hann'.
|
79 |
+
win_sqrt (bool, optional): using square root window. Defaults to True.
|
80 |
+
pad_center (bool, optional): `perfect reconstruction` opts. Defaults to True.
|
81 |
+
"""
|
82 |
+
super(STFT, self).__init__()
|
83 |
+
assert enframe_mode in ["break", "continue"]
|
84 |
+
assert fft_len >= win_len
|
85 |
+
self.win_len = win_len
|
86 |
+
self.win_hop = win_hop
|
87 |
+
self.fft_len = fft_len
|
88 |
+
self.mode = enframe_mode
|
89 |
+
self.win_type = win_type
|
90 |
+
self.win_sqrt = win_sqrt
|
91 |
+
self.pad_center = pad_center
|
92 |
+
self.pad_amount = self.fft_len // 2
|
93 |
+
|
94 |
+
en_k, fft_k, ifft_k, ola_k = self.__init_kernel__()
|
95 |
+
self.register_buffer("en_k", en_k)
|
96 |
+
self.register_buffer("fft_k", fft_k)
|
97 |
+
self.register_buffer("ifft_k", ifft_k)
|
98 |
+
self.register_buffer("ola_k", ola_k)
|
99 |
+
|
100 |
+
def __init_kernel__(self):
|
101 |
+
"""
|
102 |
+
Generate enframe_kernel, fft_kernel, ifft_kernel and overlap-add kernel.
|
103 |
+
** enframe_kernel: Using conv1d layer and identity matrix.
|
104 |
+
** fft_kernel: Using linear layer for matrix multiplication. In fact,
|
105 |
+
enframe_kernel and fft_kernel can be combined, But for the sake of
|
106 |
+
readability, I took the two apart.
|
107 |
+
** ifft_kernel, pinv of fft_kernel.
|
108 |
+
** overlap-add kernel, just like enframe_kernel, but transposed.
|
109 |
+
|
110 |
+
Returns:
|
111 |
+
tuple: four kernels.
|
112 |
+
"""
|
113 |
+
enframed_kernel = th.eye(self.fft_len)[:, None, :]
|
114 |
+
if support_clp_op:
|
115 |
+
tmp = fft(th.eye(self.fft_len))
|
116 |
+
fft_kernel = th.stack([tmp.real, tmp.imag], dim=2)
|
117 |
+
else:
|
118 |
+
fft_kernel = fft(th.eye(self.fft_len), 1)
|
119 |
+
if self.mode == "break":
|
120 |
+
enframed_kernel = th.eye(self.win_len)[:, None, :]
|
121 |
+
fft_kernel = fft_kernel[: self.win_len]
|
122 |
+
fft_kernel = th.cat((fft_kernel[:, :, 0], fft_kernel[:, :, 1]), dim=1)
|
123 |
+
ifft_kernel = th.pinverse(fft_kernel)[:, None, :]
|
124 |
+
window = get_window(self.win_type, self.win_len)
|
125 |
+
|
126 |
+
self.perfect_reconstruct = check_COLA(window, self.win_len, self.win_len - self.win_hop)
|
127 |
+
window = th.FloatTensor(window)
|
128 |
+
if self.mode == "continue":
|
129 |
+
left_pad = (self.fft_len - self.win_len) // 2
|
130 |
+
right_pad = left_pad + (self.fft_len - self.win_len) % 2
|
131 |
+
window = F.pad(window, (left_pad, right_pad))
|
132 |
+
if self.win_sqrt:
|
133 |
+
self.padded_window = window
|
134 |
+
window = th.sqrt(window)
|
135 |
+
else:
|
136 |
+
self.padded_window = window**2
|
137 |
+
|
138 |
+
fft_kernel = fft_kernel.T * window
|
139 |
+
ifft_kernel = ifft_kernel * window
|
140 |
+
ola_kernel = th.eye(self.fft_len)[: self.win_len, None, :]
|
141 |
+
if self.mode == "continue":
|
142 |
+
ola_kernel = th.eye(self.fft_len)[:, None, : self.fft_len]
|
143 |
+
return enframed_kernel, fft_kernel, ifft_kernel, ola_kernel
|
144 |
+
|
145 |
+
def is_perfect(self):
|
146 |
+
"""
|
147 |
+
Whether the parameters win_len, win_hop and win_sqrt
|
148 |
+
obey constants overlap-add(COLA)
|
149 |
+
|
150 |
+
Returns:
|
151 |
+
bool: Return true if parameters obey COLA.
|
152 |
+
"""
|
153 |
+
return self.perfect_reconstruct and self.pad_center
|
154 |
+
|
155 |
+
def transform(self, inputs, return_type="complex"):
|
156 |
+
"""Take input data (audio) to STFT domain.
|
157 |
+
|
158 |
+
Args:
|
159 |
+
inputs (tensor): Tensor of floats, with shape (num_batch, num_samples)
|
160 |
+
return_type (str, optional): return (mag, phase) when `magphase`,
|
161 |
+
return (real, imag) when `realimag` and complex(real, imag) when `complex`.
|
162 |
+
Defaults to 'complex'.
|
163 |
+
|
164 |
+
Returns:
|
165 |
+
tuple: (mag, phase) when `magphase`, return (real, imag) when
|
166 |
+
`realimag`. Defaults to 'complex', each elements with shape
|
167 |
+
[num_batch, num_frequencies, num_frames]
|
168 |
+
"""
|
169 |
+
assert return_type in ["magphase", "realimag", "complex"]
|
170 |
+
if inputs.dim() == 2:
|
171 |
+
inputs = th.unsqueeze(inputs, 1)
|
172 |
+
self.num_samples = inputs.size(-1)
|
173 |
+
if self.pad_center:
|
174 |
+
inputs = F.pad(inputs, (self.pad_amount, self.pad_amount), mode="reflect")
|
175 |
+
enframe_inputs = F.conv1d(inputs, self.en_k, stride=self.win_hop)
|
176 |
+
outputs = th.transpose(enframe_inputs, 1, 2)
|
177 |
+
outputs = F.linear(outputs, self.fft_k)
|
178 |
+
outputs = th.transpose(outputs, 1, 2)
|
179 |
+
dim = self.fft_len // 2 + 1
|
180 |
+
real = outputs[:, :dim, :]
|
181 |
+
imag = outputs[:, dim:, :]
|
182 |
+
if return_type == "realimag":
|
183 |
+
return real, imag
|
184 |
+
elif return_type == "complex":
|
185 |
+
assert support_clp_op
|
186 |
+
return th.complex(real, imag)
|
187 |
+
else:
|
188 |
+
mags = th.sqrt(real**2 + imag**2)
|
189 |
+
phase = th.atan2(imag, real)
|
190 |
+
return mags, phase
|
191 |
+
|
192 |
+
def inverse(self, input1, input2=None, input_type="magphase"):
|
193 |
+
"""Call the inverse STFT (iSTFT), given tensors produced
|
194 |
+
by the `transform` function.
|
195 |
+
|
196 |
+
Args:
|
197 |
+
input1 (tensors): Magnitude/Real-part of STFT with shape
|
198 |
+
[num_batch, num_frequencies, num_frames]
|
199 |
+
input2 (tensors): Phase/Imag-part of STFT with shape
|
200 |
+
[num_batch, num_frequencies, num_frames]
|
201 |
+
input_type (str, optional): Mathematical meaning of input tensor's.
|
202 |
+
Defaults to 'magphase'.
|
203 |
+
|
204 |
+
Returns:
|
205 |
+
tensors: Reconstructed audio given magnitude and phase. Of
|
206 |
+
shape [num_batch, num_samples]
|
207 |
+
"""
|
208 |
+
assert input_type in ["magphase", "realimag"]
|
209 |
+
if input_type == "realimag":
|
210 |
+
real, imag = None, None
|
211 |
+
if support_clp_op and th.is_complex(input1):
|
212 |
+
real, imag = input1.real, input1.imag
|
213 |
+
else:
|
214 |
+
real, imag = input1, input2
|
215 |
+
else:
|
216 |
+
real = input1 * th.cos(input2)
|
217 |
+
imag = input1 * th.sin(input2)
|
218 |
+
inputs = th.cat([real, imag], dim=1)
|
219 |
+
outputs = F.conv_transpose1d(inputs, self.ifft_k, stride=self.win_hop)
|
220 |
+
t = (self.padded_window[None, :, None]).repeat(1, 1, inputs.size(-1))
|
221 |
+
t = t.to(inputs.device)
|
222 |
+
coff = F.conv_transpose1d(t, self.ola_k, stride=self.win_hop)
|
223 |
+
|
224 |
+
num_frames = input1.size(-1)
|
225 |
+
num_samples = num_frames * self.win_hop
|
226 |
+
|
227 |
+
rm_start, rm_end = self.pad_amount, self.pad_amount + num_samples
|
228 |
+
|
229 |
+
outputs = outputs[..., rm_start:rm_end]
|
230 |
+
coff = coff[..., rm_start:rm_end]
|
231 |
+
coffidx = th.where(coff > 1e-8)
|
232 |
+
outputs[coffidx] = outputs[coffidx] / (coff[coffidx])
|
233 |
+
return outputs.squeeze(dim=1)
|
234 |
+
|
235 |
+
def forward(self, inputs):
|
236 |
+
"""Take input data (audio) to STFT domain and then back to audio.
|
237 |
+
|
238 |
+
Args:
|
239 |
+
inputs (tensor): Tensor of floats, with shape [num_batch, num_samples]
|
240 |
+
|
241 |
+
Returns:
|
242 |
+
tensor: Reconstructed audio given magnitude and phase.
|
243 |
+
Of shape [num_batch, num_samples]
|
244 |
+
"""
|
245 |
+
mag, phase = self.transform(inputs)
|
246 |
+
rec_wav = self.inverse(mag, phase)
|
247 |
+
return rec_wav
|
src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py
ADDED
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import time
|
6 |
+
import traceback
|
7 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
8 |
+
|
9 |
+
import safetensors.torch
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from tensorrt_llm import str_dtype_to_torch
|
13 |
+
from tensorrt_llm.mapping import Mapping
|
14 |
+
from tensorrt_llm.models.convert_utils import split, split_matrix_tp
|
15 |
+
|
16 |
+
|
17 |
+
def split_q_tp(v, n_head, n_hidden, tensor_parallel, rank):
|
18 |
+
split_v = split(v, tensor_parallel, rank, dim=1)
|
19 |
+
return split_v.contiguous()
|
20 |
+
|
21 |
+
|
22 |
+
def split_q_bias_tp(v, n_head, n_hidden, tensor_parallel, rank):
|
23 |
+
split_v = split(v, tensor_parallel, rank, dim=0)
|
24 |
+
return split_v.contiguous()
|
25 |
+
|
26 |
+
|
27 |
+
FACEBOOK_DIT_NAME_MAPPING = {
|
28 |
+
"^time_embed.time_mlp.0.weight$": "time_embed.mlp1.weight",
|
29 |
+
"^time_embed.time_mlp.0.bias$": "time_embed.mlp1.bias",
|
30 |
+
"^time_embed.time_mlp.2.weight$": "time_embed.mlp2.weight",
|
31 |
+
"^time_embed.time_mlp.2.bias$": "time_embed.mlp2.bias",
|
32 |
+
"^input_embed.conv_pos_embed.conv1d.0.weight$": "input_embed.conv_pos_embed.conv1d1.weight",
|
33 |
+
"^input_embed.conv_pos_embed.conv1d.0.bias$": "input_embed.conv_pos_embed.conv1d1.bias",
|
34 |
+
"^input_embed.conv_pos_embed.conv1d.2.weight$": "input_embed.conv_pos_embed.conv1d2.weight",
|
35 |
+
"^input_embed.conv_pos_embed.conv1d.2.bias$": "input_embed.conv_pos_embed.conv1d2.bias",
|
36 |
+
"^transformer_blocks.0.attn.to_out.0.weight$": "transformer_blocks.0.attn.to_out.weight",
|
37 |
+
"^transformer_blocks.0.attn.to_out.0.bias$": "transformer_blocks.0.attn.to_out.bias",
|
38 |
+
"^transformer_blocks.1.attn.to_out.0.weight$": "transformer_blocks.1.attn.to_out.weight",
|
39 |
+
"^transformer_blocks.1.attn.to_out.0.bias$": "transformer_blocks.1.attn.to_out.bias",
|
40 |
+
"^transformer_blocks.2.attn.to_out.0.weight$": "transformer_blocks.2.attn.to_out.weight",
|
41 |
+
"^transformer_blocks.2.attn.to_out.0.bias$": "transformer_blocks.2.attn.to_out.bias",
|
42 |
+
"^transformer_blocks.3.attn.to_out.0.weight$": "transformer_blocks.3.attn.to_out.weight",
|
43 |
+
"^transformer_blocks.3.attn.to_out.0.bias$": "transformer_blocks.3.attn.to_out.bias",
|
44 |
+
"^transformer_blocks.4.attn.to_out.0.weight$": "transformer_blocks.4.attn.to_out.weight",
|
45 |
+
"^transformer_blocks.4.attn.to_out.0.bias$": "transformer_blocks.4.attn.to_out.bias",
|
46 |
+
"^transformer_blocks.5.attn.to_out.0.weight$": "transformer_blocks.5.attn.to_out.weight",
|
47 |
+
"^transformer_blocks.5.attn.to_out.0.bias$": "transformer_blocks.5.attn.to_out.bias",
|
48 |
+
"^transformer_blocks.6.attn.to_out.0.weight$": "transformer_blocks.6.attn.to_out.weight",
|
49 |
+
"^transformer_blocks.6.attn.to_out.0.bias$": "transformer_blocks.6.attn.to_out.bias",
|
50 |
+
"^transformer_blocks.7.attn.to_out.0.weight$": "transformer_blocks.7.attn.to_out.weight",
|
51 |
+
"^transformer_blocks.7.attn.to_out.0.bias$": "transformer_blocks.7.attn.to_out.bias",
|
52 |
+
"^transformer_blocks.8.attn.to_out.0.weight$": "transformer_blocks.8.attn.to_out.weight",
|
53 |
+
"^transformer_blocks.8.attn.to_out.0.bias$": "transformer_blocks.8.attn.to_out.bias",
|
54 |
+
"^transformer_blocks.9.attn.to_out.0.weight$": "transformer_blocks.9.attn.to_out.weight",
|
55 |
+
"^transformer_blocks.9.attn.to_out.0.bias$": "transformer_blocks.9.attn.to_out.bias",
|
56 |
+
"^transformer_blocks.10.attn.to_out.0.weight$": "transformer_blocks.10.attn.to_out.weight",
|
57 |
+
"^transformer_blocks.10.attn.to_out.0.bias$": "transformer_blocks.10.attn.to_out.bias",
|
58 |
+
"^transformer_blocks.11.attn.to_out.0.weight$": "transformer_blocks.11.attn.to_out.weight",
|
59 |
+
"^transformer_blocks.11.attn.to_out.0.bias$": "transformer_blocks.11.attn.to_out.bias",
|
60 |
+
"^transformer_blocks.12.attn.to_out.0.weight$": "transformer_blocks.12.attn.to_out.weight",
|
61 |
+
"^transformer_blocks.12.attn.to_out.0.bias$": "transformer_blocks.12.attn.to_out.bias",
|
62 |
+
"^transformer_blocks.13.attn.to_out.0.weight$": "transformer_blocks.13.attn.to_out.weight",
|
63 |
+
"^transformer_blocks.13.attn.to_out.0.bias$": "transformer_blocks.13.attn.to_out.bias",
|
64 |
+
"^transformer_blocks.14.attn.to_out.0.weight$": "transformer_blocks.14.attn.to_out.weight",
|
65 |
+
"^transformer_blocks.14.attn.to_out.0.bias$": "transformer_blocks.14.attn.to_out.bias",
|
66 |
+
"^transformer_blocks.15.attn.to_out.0.weight$": "transformer_blocks.15.attn.to_out.weight",
|
67 |
+
"^transformer_blocks.15.attn.to_out.0.bias$": "transformer_blocks.15.attn.to_out.bias",
|
68 |
+
"^transformer_blocks.16.attn.to_out.0.weight$": "transformer_blocks.16.attn.to_out.weight",
|
69 |
+
"^transformer_blocks.16.attn.to_out.0.bias$": "transformer_blocks.16.attn.to_out.bias",
|
70 |
+
"^transformer_blocks.17.attn.to_out.0.weight$": "transformer_blocks.17.attn.to_out.weight",
|
71 |
+
"^transformer_blocks.17.attn.to_out.0.bias$": "transformer_blocks.17.attn.to_out.bias",
|
72 |
+
"^transformer_blocks.18.attn.to_out.0.weight$": "transformer_blocks.18.attn.to_out.weight",
|
73 |
+
"^transformer_blocks.18.attn.to_out.0.bias$": "transformer_blocks.18.attn.to_out.bias",
|
74 |
+
"^transformer_blocks.19.attn.to_out.0.weight$": "transformer_blocks.19.attn.to_out.weight",
|
75 |
+
"^transformer_blocks.19.attn.to_out.0.bias$": "transformer_blocks.19.attn.to_out.bias",
|
76 |
+
"^transformer_blocks.20.attn.to_out.0.weight$": "transformer_blocks.20.attn.to_out.weight",
|
77 |
+
"^transformer_blocks.20.attn.to_out.0.bias$": "transformer_blocks.20.attn.to_out.bias",
|
78 |
+
"^transformer_blocks.21.attn.to_out.0.weight$": "transformer_blocks.21.attn.to_out.weight",
|
79 |
+
"^transformer_blocks.21.attn.to_out.0.bias$": "transformer_blocks.21.attn.to_out.bias",
|
80 |
+
"^transformer_blocks.0.ff.ff.0.0.weight$": "transformer_blocks.0.ff.project_in.weight",
|
81 |
+
"^transformer_blocks.0.ff.ff.0.0.bias$": "transformer_blocks.0.ff.project_in.bias",
|
82 |
+
"^transformer_blocks.0.ff.ff.2.weight$": "transformer_blocks.0.ff.ff.weight",
|
83 |
+
"^transformer_blocks.0.ff.ff.2.bias$": "transformer_blocks.0.ff.ff.bias",
|
84 |
+
"^transformer_blocks.1.ff.ff.0.0.weight$": "transformer_blocks.1.ff.project_in.weight",
|
85 |
+
"^transformer_blocks.1.ff.ff.0.0.bias$": "transformer_blocks.1.ff.project_in.bias",
|
86 |
+
"^transformer_blocks.1.ff.ff.2.weight$": "transformer_blocks.1.ff.ff.weight",
|
87 |
+
"^transformer_blocks.1.ff.ff.2.bias$": "transformer_blocks.1.ff.ff.bias",
|
88 |
+
"^transformer_blocks.2.ff.ff.0.0.weight$": "transformer_blocks.2.ff.project_in.weight",
|
89 |
+
"^transformer_blocks.2.ff.ff.0.0.bias$": "transformer_blocks.2.ff.project_in.bias",
|
90 |
+
"^transformer_blocks.2.ff.ff.2.weight$": "transformer_blocks.2.ff.ff.weight",
|
91 |
+
"^transformer_blocks.2.ff.ff.2.bias$": "transformer_blocks.2.ff.ff.bias",
|
92 |
+
"^transformer_blocks.3.ff.ff.0.0.weight$": "transformer_blocks.3.ff.project_in.weight",
|
93 |
+
"^transformer_blocks.3.ff.ff.0.0.bias$": "transformer_blocks.3.ff.project_in.bias",
|
94 |
+
"^transformer_blocks.3.ff.ff.2.weight$": "transformer_blocks.3.ff.ff.weight",
|
95 |
+
"^transformer_blocks.3.ff.ff.2.bias$": "transformer_blocks.3.ff.ff.bias",
|
96 |
+
"^transformer_blocks.4.ff.ff.0.0.weight$": "transformer_blocks.4.ff.project_in.weight",
|
97 |
+
"^transformer_blocks.4.ff.ff.0.0.bias$": "transformer_blocks.4.ff.project_in.bias",
|
98 |
+
"^transformer_blocks.4.ff.ff.2.weight$": "transformer_blocks.4.ff.ff.weight",
|
99 |
+
"^transformer_blocks.4.ff.ff.2.bias$": "transformer_blocks.4.ff.ff.bias",
|
100 |
+
"^transformer_blocks.5.ff.ff.0.0.weight$": "transformer_blocks.5.ff.project_in.weight",
|
101 |
+
"^transformer_blocks.5.ff.ff.0.0.bias$": "transformer_blocks.5.ff.project_in.bias",
|
102 |
+
"^transformer_blocks.5.ff.ff.2.weight$": "transformer_blocks.5.ff.ff.weight",
|
103 |
+
"^transformer_blocks.5.ff.ff.2.bias$": "transformer_blocks.5.ff.ff.bias",
|
104 |
+
"^transformer_blocks.6.ff.ff.0.0.weight$": "transformer_blocks.6.ff.project_in.weight",
|
105 |
+
"^transformer_blocks.6.ff.ff.0.0.bias$": "transformer_blocks.6.ff.project_in.bias",
|
106 |
+
"^transformer_blocks.6.ff.ff.2.weight$": "transformer_blocks.6.ff.ff.weight",
|
107 |
+
"^transformer_blocks.6.ff.ff.2.bias$": "transformer_blocks.6.ff.ff.bias",
|
108 |
+
"^transformer_blocks.7.ff.ff.0.0.weight$": "transformer_blocks.7.ff.project_in.weight",
|
109 |
+
"^transformer_blocks.7.ff.ff.0.0.bias$": "transformer_blocks.7.ff.project_in.bias",
|
110 |
+
"^transformer_blocks.7.ff.ff.2.weight$": "transformer_blocks.7.ff.ff.weight",
|
111 |
+
"^transformer_blocks.7.ff.ff.2.bias$": "transformer_blocks.7.ff.ff.bias",
|
112 |
+
"^transformer_blocks.8.ff.ff.0.0.weight$": "transformer_blocks.8.ff.project_in.weight",
|
113 |
+
"^transformer_blocks.8.ff.ff.0.0.bias$": "transformer_blocks.8.ff.project_in.bias",
|
114 |
+
"^transformer_blocks.8.ff.ff.2.weight$": "transformer_blocks.8.ff.ff.weight",
|
115 |
+
"^transformer_blocks.8.ff.ff.2.bias$": "transformer_blocks.8.ff.ff.bias",
|
116 |
+
"^transformer_blocks.9.ff.ff.0.0.weight$": "transformer_blocks.9.ff.project_in.weight",
|
117 |
+
"^transformer_blocks.9.ff.ff.0.0.bias$": "transformer_blocks.9.ff.project_in.bias",
|
118 |
+
"^transformer_blocks.9.ff.ff.2.weight$": "transformer_blocks.9.ff.ff.weight",
|
119 |
+
"^transformer_blocks.9.ff.ff.2.bias$": "transformer_blocks.9.ff.ff.bias",
|
120 |
+
"^transformer_blocks.10.ff.ff.0.0.weight$": "transformer_blocks.10.ff.project_in.weight",
|
121 |
+
"^transformer_blocks.10.ff.ff.0.0.bias$": "transformer_blocks.10.ff.project_in.bias",
|
122 |
+
"^transformer_blocks.10.ff.ff.2.weight$": "transformer_blocks.10.ff.ff.weight",
|
123 |
+
"^transformer_blocks.10.ff.ff.2.bias$": "transformer_blocks.10.ff.ff.bias",
|
124 |
+
"^transformer_blocks.11.ff.ff.0.0.weight$": "transformer_blocks.11.ff.project_in.weight",
|
125 |
+
"^transformer_blocks.11.ff.ff.0.0.bias$": "transformer_blocks.11.ff.project_in.bias",
|
126 |
+
"^transformer_blocks.11.ff.ff.2.weight$": "transformer_blocks.11.ff.ff.weight",
|
127 |
+
"^transformer_blocks.11.ff.ff.2.bias$": "transformer_blocks.11.ff.ff.bias",
|
128 |
+
"^transformer_blocks.12.ff.ff.0.0.weight$": "transformer_blocks.12.ff.project_in.weight",
|
129 |
+
"^transformer_blocks.12.ff.ff.0.0.bias$": "transformer_blocks.12.ff.project_in.bias",
|
130 |
+
"^transformer_blocks.12.ff.ff.2.weight$": "transformer_blocks.12.ff.ff.weight",
|
131 |
+
"^transformer_blocks.12.ff.ff.2.bias$": "transformer_blocks.12.ff.ff.bias",
|
132 |
+
"^transformer_blocks.13.ff.ff.0.0.weight$": "transformer_blocks.13.ff.project_in.weight",
|
133 |
+
"^transformer_blocks.13.ff.ff.0.0.bias$": "transformer_blocks.13.ff.project_in.bias",
|
134 |
+
"^transformer_blocks.13.ff.ff.2.weight$": "transformer_blocks.13.ff.ff.weight",
|
135 |
+
"^transformer_blocks.13.ff.ff.2.bias$": "transformer_blocks.13.ff.ff.bias",
|
136 |
+
"^transformer_blocks.14.ff.ff.0.0.weight$": "transformer_blocks.14.ff.project_in.weight",
|
137 |
+
"^transformer_blocks.14.ff.ff.0.0.bias$": "transformer_blocks.14.ff.project_in.bias",
|
138 |
+
"^transformer_blocks.14.ff.ff.2.weight$": "transformer_blocks.14.ff.ff.weight",
|
139 |
+
"^transformer_blocks.14.ff.ff.2.bias$": "transformer_blocks.14.ff.ff.bias",
|
140 |
+
"^transformer_blocks.15.ff.ff.0.0.weight$": "transformer_blocks.15.ff.project_in.weight",
|
141 |
+
"^transformer_blocks.15.ff.ff.0.0.bias$": "transformer_blocks.15.ff.project_in.bias",
|
142 |
+
"^transformer_blocks.15.ff.ff.2.weight$": "transformer_blocks.15.ff.ff.weight",
|
143 |
+
"^transformer_blocks.15.ff.ff.2.bias$": "transformer_blocks.15.ff.ff.bias",
|
144 |
+
"^transformer_blocks.16.ff.ff.0.0.weight$": "transformer_blocks.16.ff.project_in.weight",
|
145 |
+
"^transformer_blocks.16.ff.ff.0.0.bias$": "transformer_blocks.16.ff.project_in.bias",
|
146 |
+
"^transformer_blocks.16.ff.ff.2.weight$": "transformer_blocks.16.ff.ff.weight",
|
147 |
+
"^transformer_blocks.16.ff.ff.2.bias$": "transformer_blocks.16.ff.ff.bias",
|
148 |
+
"^transformer_blocks.17.ff.ff.0.0.weight$": "transformer_blocks.17.ff.project_in.weight",
|
149 |
+
"^transformer_blocks.17.ff.ff.0.0.bias$": "transformer_blocks.17.ff.project_in.bias",
|
150 |
+
"^transformer_blocks.17.ff.ff.2.weight$": "transformer_blocks.17.ff.ff.weight",
|
151 |
+
"^transformer_blocks.17.ff.ff.2.bias$": "transformer_blocks.17.ff.ff.bias",
|
152 |
+
"^transformer_blocks.18.ff.ff.0.0.weight$": "transformer_blocks.18.ff.project_in.weight",
|
153 |
+
"^transformer_blocks.18.ff.ff.0.0.bias$": "transformer_blocks.18.ff.project_in.bias",
|
154 |
+
"^transformer_blocks.18.ff.ff.2.weight$": "transformer_blocks.18.ff.ff.weight",
|
155 |
+
"^transformer_blocks.18.ff.ff.2.bias$": "transformer_blocks.18.ff.ff.bias",
|
156 |
+
"^transformer_blocks.19.ff.ff.0.0.weight$": "transformer_blocks.19.ff.project_in.weight",
|
157 |
+
"^transformer_blocks.19.ff.ff.0.0.bias$": "transformer_blocks.19.ff.project_in.bias",
|
158 |
+
"^transformer_blocks.19.ff.ff.2.weight$": "transformer_blocks.19.ff.ff.weight",
|
159 |
+
"^transformer_blocks.19.ff.ff.2.bias$": "transformer_blocks.19.ff.ff.bias",
|
160 |
+
"^transformer_blocks.20.ff.ff.0.0.weight$": "transformer_blocks.20.ff.project_in.weight",
|
161 |
+
"^transformer_blocks.20.ff.ff.0.0.bias$": "transformer_blocks.20.ff.project_in.bias",
|
162 |
+
"^transformer_blocks.20.ff.ff.2.weight$": "transformer_blocks.20.ff.ff.weight",
|
163 |
+
"^transformer_blocks.20.ff.ff.2.bias$": "transformer_blocks.20.ff.ff.bias",
|
164 |
+
"^transformer_blocks.21.ff.ff.0.0.weight$": "transformer_blocks.21.ff.project_in.weight",
|
165 |
+
"^transformer_blocks.21.ff.ff.0.0.bias$": "transformer_blocks.21.ff.project_in.bias",
|
166 |
+
"^transformer_blocks.21.ff.ff.2.weight$": "transformer_blocks.21.ff.ff.weight",
|
167 |
+
"^transformer_blocks.21.ff.ff.2.bias$": "transformer_blocks.21.ff.ff.bias",
|
168 |
+
}
|
169 |
+
|
170 |
+
|
171 |
+
def parse_arguments():
|
172 |
+
parser = argparse.ArgumentParser()
|
173 |
+
parser.add_argument(
|
174 |
+
"--model_name",
|
175 |
+
type=str,
|
176 |
+
default="F5TTS_Base",
|
177 |
+
choices=[
|
178 |
+
"F5TTS_Base",
|
179 |
+
],
|
180 |
+
) # TODO: support F5TTS_v1_Base
|
181 |
+
parser.add_argument("--timm_ckpt", type=str, default="./ckpts/model_1200000.pt")
|
182 |
+
parser.add_argument(
|
183 |
+
"--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint"
|
184 |
+
)
|
185 |
+
parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
|
186 |
+
parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
|
187 |
+
parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
|
188 |
+
parser.add_argument("--cfg_scale", type=float, default=4.0)
|
189 |
+
parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size")
|
190 |
+
parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size")
|
191 |
+
parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size")
|
192 |
+
parser.add_argument("--dtype", type=str, default="float16", choices=["float32", "bfloat16", "float16"])
|
193 |
+
parser.add_argument("--fp8_linear", action="store_true", help="Whether use FP8 for linear layers")
|
194 |
+
parser.add_argument(
|
195 |
+
"--workers", type=int, default=1, help="The number of workers for converting checkpoint in parallel"
|
196 |
+
)
|
197 |
+
args = parser.parse_args()
|
198 |
+
return args
|
199 |
+
|
200 |
+
|
201 |
+
def convert_timm_dit(args, mapping, dtype="float32"):
|
202 |
+
weights = {}
|
203 |
+
tik = time.time()
|
204 |
+
torch_dtype = str_dtype_to_torch(dtype)
|
205 |
+
tensor_parallel = mapping.tp_size
|
206 |
+
|
207 |
+
model_params = dict(torch.load(args.timm_ckpt))
|
208 |
+
model_params = {
|
209 |
+
k: v for k, v in model_params["ema_model_state_dict"].items() if k.startswith("ema_model.transformer")
|
210 |
+
}
|
211 |
+
prefix = "ema_model.transformer."
|
212 |
+
model_params = {key[len(prefix) :] if key.startswith(prefix) else key: value for key, value in model_params.items()}
|
213 |
+
|
214 |
+
timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING
|
215 |
+
|
216 |
+
def get_trtllm_name(timm_name):
|
217 |
+
for k, v in timm_to_trtllm_name.items():
|
218 |
+
m = re.match(k, timm_name)
|
219 |
+
if m is not None:
|
220 |
+
if "*" in v:
|
221 |
+
v = v.replace("*", m.groups()[0])
|
222 |
+
return v
|
223 |
+
return timm_name
|
224 |
+
|
225 |
+
weights = dict()
|
226 |
+
for name, param in model_params.items():
|
227 |
+
if name == "input_embed.conv_pos_embed.conv1d.0.weight" or name == "input_embed.conv_pos_embed.conv1d.2.weight":
|
228 |
+
weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype).unsqueeze(-1)
|
229 |
+
else:
|
230 |
+
weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype)
|
231 |
+
|
232 |
+
assert len(weights) == len(model_params)
|
233 |
+
|
234 |
+
# new_prefix = 'f5_transformer.'
|
235 |
+
new_prefix = ""
|
236 |
+
weights = {new_prefix + key: value for key, value in weights.items()}
|
237 |
+
import math
|
238 |
+
|
239 |
+
scale_factor = math.pow(64, -0.25)
|
240 |
+
for k, v in weights.items():
|
241 |
+
if re.match("^transformer_blocks.*.attn.to_k.weight$", k):
|
242 |
+
weights[k] *= scale_factor
|
243 |
+
weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
|
244 |
+
|
245 |
+
elif re.match("^transformer_blocks.*.attn.to_k.bias$", k):
|
246 |
+
weights[k] *= scale_factor
|
247 |
+
weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
|
248 |
+
|
249 |
+
elif re.match("^transformer_blocks.*.attn.to_q.weight$", k):
|
250 |
+
weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
|
251 |
+
weights[k] *= scale_factor
|
252 |
+
|
253 |
+
elif re.match("^transformer_blocks.*.attn.to_q.bias$", k):
|
254 |
+
weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
|
255 |
+
weights[k] *= scale_factor
|
256 |
+
|
257 |
+
elif re.match("^transformer_blocks.*.attn.to_v.weight$", k):
|
258 |
+
weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
|
259 |
+
|
260 |
+
elif re.match("^transformer_blocks.*.attn.to_v.bias$", k):
|
261 |
+
weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
|
262 |
+
|
263 |
+
elif re.match("^transformer_blocks.*.attn.to_out.weight$", k):
|
264 |
+
weights[k] = split_matrix_tp(v, tensor_parallel, mapping.tp_rank, dim=1)
|
265 |
+
|
266 |
+
tok = time.time()
|
267 |
+
t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
|
268 |
+
print(f"Weights loaded. Total time: {t}")
|
269 |
+
return weights
|
270 |
+
|
271 |
+
|
272 |
+
def save_config(args):
|
273 |
+
if not os.path.exists(args.output_dir):
|
274 |
+
os.makedirs(args.output_dir)
|
275 |
+
config = {
|
276 |
+
"architecture": "F5TTS",
|
277 |
+
"dtype": args.dtype,
|
278 |
+
"hidden_size": 1024,
|
279 |
+
"num_hidden_layers": 22,
|
280 |
+
"num_attention_heads": 16,
|
281 |
+
"dim_head": 64,
|
282 |
+
"dropout": 0.1,
|
283 |
+
"ff_mult": 2,
|
284 |
+
"mel_dim": 100,
|
285 |
+
"text_num_embeds": 256,
|
286 |
+
"text_dim": 512,
|
287 |
+
"conv_layers": 4,
|
288 |
+
"long_skip_connection": False,
|
289 |
+
"mapping": {
|
290 |
+
"world_size": args.cp_size * args.tp_size * args.pp_size,
|
291 |
+
"cp_size": args.cp_size,
|
292 |
+
"tp_size": args.tp_size,
|
293 |
+
"pp_size": args.pp_size,
|
294 |
+
},
|
295 |
+
}
|
296 |
+
if args.fp8_linear:
|
297 |
+
config["quantization"] = {
|
298 |
+
"quant_algo": "FP8",
|
299 |
+
# TODO: add support for exclude modules.
|
300 |
+
# 'exclude_modules': "*final_layer*",
|
301 |
+
}
|
302 |
+
|
303 |
+
with open(os.path.join(args.output_dir, "config.json"), "w") as f:
|
304 |
+
json.dump(config, f, indent=4)
|
305 |
+
|
306 |
+
|
307 |
+
def covert_and_save(args, rank):
|
308 |
+
if rank == 0:
|
309 |
+
save_config(args)
|
310 |
+
|
311 |
+
mapping = Mapping(
|
312 |
+
world_size=args.cp_size * args.tp_size * args.pp_size,
|
313 |
+
rank=rank,
|
314 |
+
cp_size=args.cp_size,
|
315 |
+
tp_size=args.tp_size,
|
316 |
+
pp_size=args.pp_size,
|
317 |
+
)
|
318 |
+
|
319 |
+
weights = convert_timm_dit(args, mapping, dtype=args.dtype)
|
320 |
+
|
321 |
+
safetensors.torch.save_file(weights, os.path.join(args.output_dir, f"rank{rank}.safetensors"))
|
322 |
+
|
323 |
+
|
324 |
+
def execute(workers, func, args):
|
325 |
+
if workers == 1:
|
326 |
+
for rank, f in enumerate(func):
|
327 |
+
f(args, rank)
|
328 |
+
else:
|
329 |
+
with ThreadPoolExecutor(max_workers=workers) as p:
|
330 |
+
futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
|
331 |
+
exceptions = []
|
332 |
+
for future in as_completed(futures):
|
333 |
+
try:
|
334 |
+
future.result()
|
335 |
+
except Exception as e:
|
336 |
+
traceback.print_exc()
|
337 |
+
exceptions.append(e)
|
338 |
+
assert len(exceptions) == 0, "Checkpoint conversion failed, please check error log."
|
339 |
+
|
340 |
+
|
341 |
+
def main():
|
342 |
+
args = parse_arguments()
|
343 |
+
world_size = args.cp_size * args.tp_size * args.pp_size
|
344 |
+
|
345 |
+
assert args.pp_size == 1, "PP is not supported yet."
|
346 |
+
|
347 |
+
tik = time.time()
|
348 |
+
if args.timm_ckpt is None:
|
349 |
+
return
|
350 |
+
print("start execute")
|
351 |
+
execute(args.workers, [covert_and_save] * world_size, args)
|
352 |
+
|
353 |
+
tok = time.time()
|
354 |
+
t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
|
355 |
+
print(f"Total time of converting checkpoints: {t}")
|
356 |
+
|
357 |
+
|
358 |
+
if __name__ == "__main__":
|
359 |
+
main()
|
src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
from huggingface_hub import hf_hub_download
|
18 |
+
|
19 |
+
from conv_stft import STFT
|
20 |
+
from vocos import Vocos
|
21 |
+
import argparse
|
22 |
+
|
23 |
+
opset_version = 17
|
24 |
+
|
25 |
+
|
26 |
+
def get_args():
|
27 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
28 |
+
parser.add_argument(
|
29 |
+
"--vocoder",
|
30 |
+
type=str,
|
31 |
+
default="vocos",
|
32 |
+
choices=["vocos", "bigvgan"],
|
33 |
+
help="Vocoder to export",
|
34 |
+
)
|
35 |
+
parser.add_argument(
|
36 |
+
"--output-path",
|
37 |
+
type=str,
|
38 |
+
default="./vocos_vocoder.onnx",
|
39 |
+
help="Output path",
|
40 |
+
)
|
41 |
+
return parser.parse_args()
|
42 |
+
|
43 |
+
|
44 |
+
class ISTFTHead(nn.Module):
|
45 |
+
def __init__(self, n_fft: int, hop_length: int):
|
46 |
+
super().__init__()
|
47 |
+
self.out = None
|
48 |
+
self.stft = STFT(fft_len=n_fft, win_hop=hop_length, win_len=n_fft)
|
49 |
+
|
50 |
+
def forward(self, x: torch.Tensor):
|
51 |
+
x = self.out(x).transpose(1, 2)
|
52 |
+
mag, p = x.chunk(2, dim=1)
|
53 |
+
mag = torch.exp(mag)
|
54 |
+
mag = torch.clip(mag, max=1e2)
|
55 |
+
real = mag * torch.cos(p)
|
56 |
+
imag = mag * torch.sin(p)
|
57 |
+
audio = self.stft.inverse(input1=real, input2=imag, input_type="realimag")
|
58 |
+
return audio
|
59 |
+
|
60 |
+
|
61 |
+
class VocosVocoder(nn.Module):
|
62 |
+
def __init__(self, vocos_vocoder):
|
63 |
+
super(VocosVocoder, self).__init__()
|
64 |
+
self.vocos_vocoder = vocos_vocoder
|
65 |
+
istft_head_out = self.vocos_vocoder.head.out
|
66 |
+
n_fft = self.vocos_vocoder.head.istft.n_fft
|
67 |
+
hop_length = self.vocos_vocoder.head.istft.hop_length
|
68 |
+
istft_head_for_export = ISTFTHead(n_fft, hop_length)
|
69 |
+
istft_head_for_export.out = istft_head_out
|
70 |
+
self.vocos_vocoder.head = istft_head_for_export
|
71 |
+
|
72 |
+
def forward(self, mel):
|
73 |
+
waveform = self.vocos_vocoder.decode(mel)
|
74 |
+
return waveform
|
75 |
+
|
76 |
+
|
77 |
+
def export_VocosVocoder(vocos_vocoder, output_path, verbose):
|
78 |
+
vocos_vocoder = VocosVocoder(vocos_vocoder).cuda()
|
79 |
+
vocos_vocoder.eval()
|
80 |
+
|
81 |
+
dummy_batch_size = 8
|
82 |
+
dummy_input_length = 500
|
83 |
+
|
84 |
+
dummy_mel = torch.randn(dummy_batch_size, 100, dummy_input_length).cuda()
|
85 |
+
|
86 |
+
with torch.no_grad():
|
87 |
+
dummy_waveform = vocos_vocoder(mel=dummy_mel)
|
88 |
+
print(dummy_waveform.shape)
|
89 |
+
|
90 |
+
dummy_input = dummy_mel
|
91 |
+
|
92 |
+
torch.onnx.export(
|
93 |
+
vocos_vocoder,
|
94 |
+
dummy_input,
|
95 |
+
output_path,
|
96 |
+
opset_version=opset_version,
|
97 |
+
do_constant_folding=True,
|
98 |
+
input_names=["mel"],
|
99 |
+
output_names=["waveform"],
|
100 |
+
dynamic_axes={
|
101 |
+
"mel": {0: "batch_size", 2: "input_length"},
|
102 |
+
"waveform": {0: "batch_size", 1: "output_length"},
|
103 |
+
},
|
104 |
+
verbose=verbose,
|
105 |
+
)
|
106 |
+
|
107 |
+
print("Exported to {}".format(output_path))
|
108 |
+
|
109 |
+
|
110 |
+
def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device="cpu", hf_cache_dir=None):
|
111 |
+
if vocoder_name == "vocos":
|
112 |
+
# vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
|
113 |
+
if is_local:
|
114 |
+
print(f"Load vocos from local path {local_path}")
|
115 |
+
config_path = f"{local_path}/config.yaml"
|
116 |
+
model_path = f"{local_path}/pytorch_model.bin"
|
117 |
+
else:
|
118 |
+
print("Download Vocos from huggingface charactr/vocos-mel-24khz")
|
119 |
+
repo_id = "charactr/vocos-mel-24khz"
|
120 |
+
config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
|
121 |
+
model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
|
122 |
+
vocoder = Vocos.from_hparams(config_path)
|
123 |
+
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
|
124 |
+
vocoder.load_state_dict(state_dict)
|
125 |
+
vocoder = vocoder.eval().to(device)
|
126 |
+
elif vocoder_name == "bigvgan":
|
127 |
+
raise NotImplementedError("BigVGAN is not supported yet")
|
128 |
+
vocoder.remove_weight_norm()
|
129 |
+
vocoder = vocoder.eval().to(device)
|
130 |
+
return vocoder
|
131 |
+
|
132 |
+
|
133 |
+
if __name__ == "__main__":
|
134 |
+
args = get_args()
|
135 |
+
vocoder = load_vocoder(vocoder_name=args.vocoder, device="cpu", hf_cache_dir=None)
|
136 |
+
if args.vocoder == "vocos":
|
137 |
+
export_VocosVocoder(vocoder, args.output_path, verbose=False)
|
src/f5_tts/runtime/triton_trtllm/scripts/export_vocos_trt.sh
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
TRTEXEC="/usr/src/tensorrt/bin/trtexec"
|
17 |
+
|
18 |
+
ONNX_PATH=$1
|
19 |
+
ENGINE_PATH=$2
|
20 |
+
echo "ONNX_PATH: $ONNX_PATH"
|
21 |
+
echo "ENGINE_PATH: $ENGINE_PATH"
|
22 |
+
PRECISION="fp32"
|
23 |
+
|
24 |
+
|
25 |
+
MIN_BATCH_SIZE=1
|
26 |
+
OPT_BATCH_SIZE=1
|
27 |
+
MAX_BATCH_SIZE=8
|
28 |
+
|
29 |
+
MIN_INPUT_LENGTH=1
|
30 |
+
OPT_INPUT_LENGTH=1000
|
31 |
+
MAX_INPUT_LENGTH=3000
|
32 |
+
|
33 |
+
MEL_MIN_SHAPE="${MIN_BATCH_SIZE}x100x${MIN_INPUT_LENGTH}"
|
34 |
+
MEL_OPT_SHAPE="${OPT_BATCH_SIZE}x100x${OPT_INPUT_LENGTH}"
|
35 |
+
MEL_MAX_SHAPE="${MAX_BATCH_SIZE}x100x${MAX_INPUT_LENGTH}"
|
36 |
+
|
37 |
+
${TRTEXEC} \
|
38 |
+
--minShapes="mel:${MEL_MIN_SHAPE}" \
|
39 |
+
--optShapes="mel:${MEL_OPT_SHAPE}" \
|
40 |
+
--maxShapes="mel:${MEL_MAX_SHAPE}" \
|
41 |
+
--onnx=${ONNX_PATH} \
|
42 |
+
--saveEngine=${ENGINE_PATH}
|
43 |
+
|
src/f5_tts/runtime/triton_trtllm/scripts/fill_template.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! /usr/bin/env python3
|
2 |
+
from argparse import ArgumentParser
|
3 |
+
from string import Template
|
4 |
+
|
5 |
+
|
6 |
+
def main(file_path, substitutions, in_place, participant_ids):
|
7 |
+
with open(file_path) as f:
|
8 |
+
pbtxt = Template(f.read())
|
9 |
+
|
10 |
+
sub_dict = {"max_queue_size": 0}
|
11 |
+
sub_dict["participant_ids"] = participant_ids
|
12 |
+
for sub in substitutions.split(","):
|
13 |
+
key, value = sub.split(":")
|
14 |
+
sub_dict[key] = value
|
15 |
+
|
16 |
+
pbtxt = pbtxt.safe_substitute(sub_dict)
|
17 |
+
|
18 |
+
if in_place:
|
19 |
+
with open(file_path, "w") as f:
|
20 |
+
f.write(pbtxt)
|
21 |
+
else:
|
22 |
+
print(pbtxt)
|
23 |
+
|
24 |
+
|
25 |
+
if __name__ == "__main__":
|
26 |
+
parser = ArgumentParser()
|
27 |
+
parser.add_argument("file_path", help="path of the .pbtxt to modify")
|
28 |
+
parser.add_argument(
|
29 |
+
"substitutions",
|
30 |
+
help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2...",
|
31 |
+
)
|
32 |
+
parser.add_argument("--in_place", "-i", action="store_true", help="do the operation in-place")
|
33 |
+
parser.add_argument("--participant_ids", help="Participant IDs for the model", default="")
|
34 |
+
args = parser.parse_args()
|
35 |
+
|
36 |
+
main(**vars(args))
|