mrfakename committed on
Commit 1674828 · verified · 1 Parent(s): 8f2de78

Sync from GitHub repo


This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions there.

src/f5_tts/runtime/triton_trtllm/Dockerfile.server ADDED
@@ -0,0 +1,3 @@
+ FROM nvcr.io/nvidia/tritonserver:24.12-py3
+ RUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 jieba pypinyin librosa vocos
+ WORKDIR /workspace
src/f5_tts/runtime/triton_trtllm/README.md ADDED
@@ -0,0 +1,46 @@
+ ## Triton Inference Serving Best Practice for F5-TTS
+
+ ### Quick Start
+ Launch the service directly with docker compose.
+ ```sh
+ # TODO: support F5TTS_v1_Base
+ MODEL=F5TTS_Base docker compose up
+ ```
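+
+ On the first run the container clones the GitHub repo and executes run.sh, which is expected to export the model and launch the Triton server, so startup can take a while. To follow the progress, you can tail the logs of the `tts` service defined in docker-compose.yml:
+ ```sh
+ docker compose logs -f tts
+ ```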
+
+ ### Build Image
+ Build the docker image from scratch.
+ ```sh
+ docker build . -f Dockerfile.server -t soar97/triton-f5-tts:24.12
+ ```
+
+ ### Create Docker Container
+ ```sh
+ your_mount_dir=/mnt:/mnt
+ docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-f5-tts:24.12
+ ```
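+
+ Inside the container it can be worth confirming that the GPUs are visible before exporting any engines. This is only a sanity check and not something the scripts require:
+ ```sh
+ nvidia-smi
+ python3 -c "import torch; print(torch.cuda.is_available())"
+ ```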
+
+ ### Export Models to TensorRT-LLM and Launch Server
+ Inside the docker container, export the F5-TTS model to a TensorRT-LLM engine and launch the Triton server. The workflow follows the official TensorRT-LLM examples; see [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper) for reference.
+
+ ```sh
+ bash run.sh 0 4 F5TTS_Base
+ ```
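+
+ Once run.sh finishes, the Triton server should be listening on the default ports (HTTP 8000, gRPC 8001, metrics 8002, the same ports exposed in docker-compose.yml). Assuming the standard Triton HTTP endpoint, readiness can be checked with:
+ ```sh
+ curl -sf http://localhost:8000/v2/health/ready && echo "server is ready"
+ ```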
+ ### HTTP Client
+ ```sh
+ python3 client_http.py
+ ```
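+
+ client_http.py ships with a default reference audio and default texts. To synthesize your own text, the flags exposed by its argument parser can be passed explicitly, for example (paths and texts below are placeholders):
+ ```sh
+ python3 client_http.py \
+     --reference-audio ../../infer/examples/basic/basic_ref_en.wav \
+     --reference-text "Some call me nature, others call me mother nature." \
+     --target-text "A quick smoke test of the Triton F5-TTS deployment." \
+     --output-audio output.wav
+ ```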
+ ### Benchmark using Dataset
+ ```sh
+ num_task=2
+ python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
+ ```
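+
+ client_grpc.py can also be pointed at a single utterance instead of a dataset. Based on its argument parser, a one-off gRPC request looks roughly like this (reference audio and texts are placeholders):
+ ```sh
+ python3 client_grpc.py \
+     --server-addr localhost --server-port 8001 \
+     --model-name f5_tts \
+     --reference-audio ../../infer/examples/basic/basic_ref_en.wav \
+     --reference-text "Some call me nature, others call me mother nature." \
+     --target-text "Hello from the gRPC client." \
+     --log-dir ./single_utterance_log
+ ```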
+
+ ### Benchmark Results
+ Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs.
+
+ | Model | Concurrency | Avg Latency | RTF |
+ |-------|-------------|-------------|-----|
+ | F5-TTS Base (Vocos) | 1 | 253 ms | 0.0394 |
+
+ ### Credits
+ 1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
src/f5_tts/runtime/triton_trtllm/client_grpc.py ADDED
@@ -0,0 +1,470 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
3
+ # 2023 Nvidia (authors: Yuekai Zhang)
4
+ # 2023 Recurrent.ai (authors: Songtao Shi)
5
+ # See LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+ """
19
+ This script loads a dataset from the Hugging Face Hub and sends it to the server
20
+ for decoding, in parallel.
21
+
22
+ Usage:
23
+ num_task=2
24
+
25
+ # For offline F5-TTS
26
+ python3 client_grpc.py \
27
+ --server-addr localhost \
28
+ --model-name f5_tts \
29
+ --num-tasks $num_task \
30
+ --huggingface-dataset yuekai/seed_tts \
31
+ --split-name test_zh \
32
+ --log-dir ./log_concurrent_tasks_${num_task}
33
+
34
+ # For offline Spark-TTS-0.5B
35
+ python3 client_grpc.py \
36
+ --server-addr localhost \
37
+ --model-name spark_tts \
38
+ --num-tasks $num_task \
39
+ --huggingface-dataset yuekai/seed_tts \
40
+ --split-name wenetspeech4tts \
41
+ --log-dir ./log_concurrent_tasks_${num_task}
42
+ """
43
+
44
+ import argparse
45
+ import asyncio
46
+ import json
47
+
48
+ import os
49
+ import time
50
+ import types
51
+ from pathlib import Path
52
+
53
+ import numpy as np
54
+ import soundfile as sf
55
+ import tritonclient
56
+ import tritonclient.grpc.aio as grpcclient
57
+ from tritonclient.utils import np_to_triton_dtype
58
+
59
+
60
+ def write_triton_stats(stats, summary_file):
61
+ with open(summary_file, "w") as summary_f:
62
+ model_stats = stats["model_stats"]
63
+ # add a note: the log comes from triton_client.get_inference_statistics(), reformatted for readability
64
+ summary_f.write(
65
+ "The log is parsed from triton_client.get_inference_statistics() and reformatted for readability. \n"
66
+ )
67
+ summary_f.write("To learn more about the log, please refer to: \n")
68
+ summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
69
+ summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
70
+ summary_f.write(
71
+ "To improve throughput, we usually let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
72
+ )
73
+ summary_f.write(
74
+ "However, there is a trade-off between the increased queue time and the increased batch size. \n"
75
+ )
76
+ summary_f.write(
77
+ "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
78
+ )
79
+ summary_f.write(
80
+ "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
81
+ )
82
+ for model_state in model_stats:
83
+ if "last_inference" not in model_state:
84
+ continue
85
+ summary_f.write(f"model name is {model_state['name']} \n")
86
+ model_inference_stats = model_state["inference_stats"]
87
+ total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
88
+ total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
89
+ total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
90
+ total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
91
+ summary_f.write(
92
+ f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
93
+ )
94
+ model_batch_stats = model_state["batch_stats"]
95
+ for batch in model_batch_stats:
96
+ batch_size = int(batch["batch_size"])
97
+ compute_input = batch["compute_input"]
98
+ compute_output = batch["compute_output"]
99
+ compute_infer = batch["compute_infer"]
100
+ batch_count = int(compute_infer["count"])
101
+ assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
102
+ compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
103
+ compute_input_time_ms = int(compute_input["ns"]) / 1e6
104
+ compute_output_time_ms = int(compute_output["ns"]) / 1e6
105
+ summary_f.write(
106
+ f"execute inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" # noqa
107
+ )
108
+ summary_f.write(
109
+ f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " # noqa
110
+ )
111
+ summary_f.write(
112
+ f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" # noqa
113
+ )
114
+
115
+
116
+ def get_args():
117
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
118
+
119
+ parser.add_argument(
120
+ "--server-addr",
121
+ type=str,
122
+ default="localhost",
123
+ help="Address of the server",
124
+ )
125
+
126
+ parser.add_argument(
127
+ "--server-port",
128
+ type=int,
129
+ default=8001,
130
+ help="Grpc port of the triton server, default is 8001",
131
+ )
132
+
133
+ parser.add_argument(
134
+ "--reference-audio",
135
+ type=str,
136
+ default=None,
137
+ help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
138
+ )
139
+
140
+ parser.add_argument(
141
+ "--reference-text",
142
+ type=str,
143
+ default="",
144
+ help="",
145
+ )
146
+
147
+ parser.add_argument(
148
+ "--target-text",
149
+ type=str,
150
+ default="",
151
+ help="",
152
+ )
153
+
154
+ parser.add_argument(
155
+ "--huggingface-dataset",
156
+ type=str,
157
+ default="yuekai/seed_tts",
158
+ help="dataset name in huggingface dataset hub",
159
+ )
160
+
161
+ parser.add_argument(
162
+ "--split-name",
163
+ type=str,
164
+ default="wenetspeech4tts",
165
+ choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
166
+ help="dataset split name, default is 'wenetspeech4tts'",
167
+ )
168
+
169
+ parser.add_argument(
170
+ "--manifest-path",
171
+ type=str,
172
+ default=None,
173
+ help="Path to the manifest dir which includes wav.scp trans.txt files.",
174
+ )
175
+
176
+ parser.add_argument(
177
+ "--model-name",
178
+ type=str,
179
+ default="f5_tts",
180
+ choices=["f5_tts", "spark_tts"],
181
+ help="triton model_repo module name to request",
182
+ )
183
+
184
+ parser.add_argument(
185
+ "--num-tasks",
186
+ type=int,
187
+ default=1,
188
+ help="Number of concurrent tasks for sending",
189
+ )
190
+
191
+ parser.add_argument(
192
+ "--log-interval",
193
+ type=int,
194
+ default=5,
195
+ help="Controls how frequently we print the log.",
196
+ )
197
+
198
+ parser.add_argument(
199
+ "--compute-wer",
200
+ action="store_true",
201
+ default=False,
202
+ help="""True to compute WER.
203
+ """,
204
+ )
205
+
206
+ parser.add_argument(
207
+ "--log-dir",
208
+ type=str,
209
+ required=False,
210
+ default="./tmp",
211
+ help="log directory",
212
+ )
213
+
214
+ parser.add_argument(
215
+ "--batch-size",
216
+ type=int,
217
+ default=1,
218
+ help="Inference batch_size per request for offline mode.",
219
+ )
220
+
221
+ return parser.parse_args()
222
+
223
+
224
+ def load_audio(wav_path, target_sample_rate=16000):
225
+ assert target_sample_rate == 16000, "hard coding in server"
226
+ if isinstance(wav_path, dict):
227
+ waveform = wav_path["array"]
228
+ sample_rate = wav_path["sampling_rate"]
229
+ else:
230
+ waveform, sample_rate = sf.read(wav_path)
231
+ if sample_rate != target_sample_rate:
232
+ from scipy.signal import resample
233
+
234
+ num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
235
+ waveform = resample(waveform, num_samples)
236
+ return waveform, target_sample_rate
237
+
238
+
239
+ async def send(
240
+ manifest_item_list: list,
241
+ name: str,
242
+ triton_client: tritonclient.grpc.aio.InferenceServerClient,
243
+ protocol_client: types.ModuleType,
244
+ log_interval: int,
245
+ model_name: str,
246
+ padding_duration: int = None,
247
+ audio_save_dir: str = "./",
248
+ save_sample_rate: int = 16000,
249
+ ):
250
+ total_duration = 0.0
251
+ latency_data = []
252
+ task_id = int(name[5:])
253
+
254
+ print(f"manifest_item_list: {manifest_item_list}")
255
+ for i, item in enumerate(manifest_item_list):
256
+ if i % log_interval == 0:
257
+ print(f"{name}: {i}/{len(manifest_item_list)}")
258
+ waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
259
+ duration = len(waveform) / sample_rate
260
+ lengths = np.array([[len(waveform)]], dtype=np.int32)
261
+
262
+ reference_text, target_text = item["reference_text"], item["target_text"]
263
+
264
+ estimated_target_duration = duration / len(reference_text) * len(target_text)
265
+
266
+ if padding_duration:
267
+ # pad to the nearest multiple of padding_duration seconds
268
+ samples = np.zeros(
269
+ (
270
+ 1,
271
+ padding_duration
272
+ * sample_rate
273
+ * ((int(estimated_target_duration + duration) // padding_duration) + 1),
274
+ ),
275
+ dtype=np.float32,
276
+ )
277
+
278
+ samples[0, : len(waveform)] = waveform
279
+ else:
280
+ samples = waveform
281
+
282
+ samples = samples.reshape(1, -1).astype(np.float32)
283
+
284
+ inputs = [
285
+ protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
286
+ protocol_client.InferInput("reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)),
287
+ protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
288
+ protocol_client.InferInput("target_text", [1, 1], "BYTES"),
289
+ ]
290
+ inputs[0].set_data_from_numpy(samples)
291
+ inputs[1].set_data_from_numpy(lengths)
292
+
293
+ input_data_numpy = np.array([reference_text], dtype=object)
294
+ input_data_numpy = input_data_numpy.reshape((1, 1))
295
+ inputs[2].set_data_from_numpy(input_data_numpy)
296
+
297
+ input_data_numpy = np.array([target_text], dtype=object)
298
+ input_data_numpy = input_data_numpy.reshape((1, 1))
299
+ inputs[3].set_data_from_numpy(input_data_numpy)
300
+
301
+ outputs = [protocol_client.InferRequestedOutput("waveform")]
302
+
303
+ sequence_id = 100000000 + i + task_id * 10
304
+ start = time.time()
305
+ response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)
306
+
307
+ audio = response.as_numpy("waveform").reshape(-1)
308
+
309
+ end = time.time() - start
310
+
311
+ audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
312
+ sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
313
+
314
+ latency_data.append((end, estimated_target_duration))
315
+ total_duration += estimated_target_duration
316
+
317
+ return total_duration, latency_data
318
+
319
+
320
+ def load_manifests(manifest_path):
321
+ with open(manifest_path, "r") as f:
322
+ manifest_list = []
323
+ for line in f:
324
+ assert len(line.strip().split("|")) == 4
325
+ utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
326
+ utt = Path(utt).stem
327
+ # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
328
+ if not os.path.isabs(prompt_wav):
329
+ prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
330
+ manifest_list.append(
331
+ {
332
+ "audio_filepath": prompt_wav,
333
+ "reference_text": prompt_text,
334
+ "target_text": gt_text,
335
+ "target_audio_path": utt,
336
+ }
337
+ )
338
+ return manifest_list
339
+
340
+
341
+ def split_data(data, k):
342
+ n = len(data)
343
+ if n < k:
344
+ print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
345
+ k = n
346
+
347
+ quotient = n // k
348
+ remainder = n % k
349
+
350
+ result = []
351
+ start = 0
352
+ for i in range(k):
353
+ if i < remainder:
354
+ end = start + quotient + 1
355
+ else:
356
+ end = start + quotient
357
+
358
+ result.append(data[start:end])
359
+ start = end
360
+
361
+ return result
362
+
363
+
364
+ async def main():
365
+ args = get_args()
366
+ url = f"{args.server_addr}:{args.server_port}"
367
+
368
+ triton_client = grpcclient.InferenceServerClient(url=url, verbose=False)
369
+ protocol_client = grpcclient
370
+
371
+ if args.reference_audio:
372
+ args.num_tasks = 1
373
+ args.log_interval = 1
374
+ manifest_item_list = [
375
+ {
376
+ "reference_text": args.reference_text,
377
+ "target_text": args.target_text,
378
+ "audio_filepath": args.reference_audio,
379
+ "target_audio_path": "test",
380
+ }
381
+ ]
382
+ elif args.huggingface_dataset:
383
+ import datasets
384
+
385
+ dataset = datasets.load_dataset(
386
+ args.huggingface_dataset,
387
+ split=args.split_name,
388
+ trust_remote_code=True,
389
+ )
390
+ manifest_item_list = []
391
+ for i in range(len(dataset)):
392
+ manifest_item_list.append(
393
+ {
394
+ "audio_filepath": dataset[i]["prompt_audio"],
395
+ "reference_text": dataset[i]["prompt_text"],
396
+ "target_audio_path": dataset[i]["id"],
397
+ "target_text": dataset[i]["target_text"],
398
+ }
399
+ )
400
+ else:
401
+ manifest_item_list = load_manifests(args.manifest_path)
402
+
403
+ args.num_tasks = min(args.num_tasks, len(manifest_item_list))
404
+ manifest_item_list = split_data(manifest_item_list, args.num_tasks)
405
+
406
+ os.makedirs(args.log_dir, exist_ok=True)
407
+ tasks = []
408
+ start_time = time.time()
409
+ for i in range(args.num_tasks):
410
+ task = asyncio.create_task(
411
+ send(
412
+ manifest_item_list[i],
413
+ name=f"task-{i}",
414
+ triton_client=triton_client,
415
+ protocol_client=protocol_client,
416
+ log_interval=args.log_interval,
417
+ model_name=args.model_name,
418
+ audio_save_dir=args.log_dir,
419
+ padding_duration=1,
420
+ save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
421
+ )
422
+ )
423
+ tasks.append(task)
424
+
425
+ ans_list = await asyncio.gather(*tasks)
426
+
427
+ end_time = time.time()
428
+ elapsed = end_time - start_time
429
+
430
+ total_duration = 0.0
431
+ latency_data = []
432
+ for ans in ans_list:
433
+ total_duration += ans[0]
434
+ latency_data += ans[1]
435
+
436
+ rtf = elapsed / total_duration
437
+
438
+ s = f"RTF: {rtf:.4f}\n"
439
+ s += f"total_duration: {total_duration:.3f} seconds\n"
440
+ s += f"({total_duration / 3600:.2f} hours)\n"
441
+ s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
442
+
443
+ latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
444
+ latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
445
+ latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
446
+ s += f"latency_variance: {latency_variance:.2f}\n"
447
+ s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
448
+ s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
449
+ s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
450
+ s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
451
+ s += f"average_latency_ms: {latency_ms:.2f}\n"
452
+
453
+ print(s)
454
+ if args.manifest_path:
455
+ name = Path(args.manifest_path).stem
456
+ elif args.split_name:
457
+ name = args.split_name
458
+ with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
459
+ f.write(s)
460
+
461
+ stats = await triton_client.get_inference_statistics(model_name="", as_json=True)
462
+ write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
463
+
464
+ metadata = await triton_client.get_model_config(model_name=args.model_name, as_json=True)
465
+ with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
466
+ json.dump(metadata, f, indent=4)
467
+
468
+
469
+ if __name__ == "__main__":
470
+ asyncio.run(main())
src/f5_tts/runtime/triton_trtllm/client_http.py ADDED
@@ -0,0 +1,142 @@
1
+ # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+ import requests
27
+ import soundfile as sf
28
+ import numpy as np
29
+ import argparse
30
+
31
+
32
+ def get_args():
33
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
34
+
35
+ parser.add_argument(
36
+ "--server-url",
37
+ type=str,
38
+ default="localhost:8000",
39
+ help="Address of the server",
40
+ )
41
+
42
+ parser.add_argument(
43
+ "--reference-audio",
44
+ type=str,
45
+ default="../../infer/examples/basic/basic_ref_en.wav",
46
+ help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
47
+ )
48
+
49
+ parser.add_argument(
50
+ "--reference-text",
51
+ type=str,
52
+ default="Some call me nature, others call me mother nature.",
53
+ help="",
54
+ )
55
+
56
+ parser.add_argument(
57
+ "--target-text",
58
+ type=str,
59
+ default="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.",
60
+ help="",
61
+ )
62
+
63
+ parser.add_argument(
64
+ "--model-name",
65
+ type=str,
66
+ default="f5_tts",
67
+ choices=["f5_tts", "spark_tts"],
68
+ help="triton model_repo module name to request",
69
+ )
70
+
71
+ parser.add_argument(
72
+ "--output-audio",
73
+ type=str,
74
+ default="output.wav",
75
+ help="Path to save the output audio",
76
+ )
77
+ return parser.parse_args()
78
+
79
+
80
+ def prepare_request(
81
+ samples,
82
+ reference_text,
83
+ target_text,
84
+ sample_rate=16000,
85
+ audio_save_dir: str = "./",
86
+ ):
87
+ assert len(samples.shape) == 1, "samples should be 1D"
88
+ lengths = np.array([[len(samples)]], dtype=np.int32)
89
+ samples = samples.reshape(1, -1).astype(np.float32)
90
+
91
+ data = {
92
+ "inputs": [
93
+ {"name": "reference_wav", "shape": samples.shape, "datatype": "FP32", "data": samples.tolist()},
94
+ {
95
+ "name": "reference_wav_len",
96
+ "shape": lengths.shape,
97
+ "datatype": "INT32",
98
+ "data": lengths.tolist(),
99
+ },
100
+ {"name": "reference_text", "shape": [1, 1], "datatype": "BYTES", "data": [reference_text]},
101
+ {"name": "target_text", "shape": [1, 1], "datatype": "BYTES", "data": [target_text]},
102
+ ]
103
+ }
104
+
105
+ return data
106
+
107
+
108
+ def load_audio(wav_path, target_sample_rate=16000):
109
+ assert target_sample_rate == 16000, "hard coding in server"
110
+ if isinstance(wav_path, dict):
111
+ samples = wav_path["array"]
112
+ sample_rate = wav_path["sampling_rate"]
113
+ else:
114
+ samples, sample_rate = sf.read(wav_path)
115
+ if sample_rate != target_sample_rate:
116
+ from scipy.signal import resample
117
+
118
+ num_samples = int(len(samples) * (target_sample_rate / sample_rate))
119
+ samples = resample(samples, num_samples)
120
+ return samples, target_sample_rate
121
+
122
+
123
+ if __name__ == "__main__":
124
+ args = get_args()
125
+ server_url = args.server_url
126
+ if not server_url.startswith(("http://", "https://")):
127
+ server_url = f"http://{server_url}"
128
+
129
+ url = f"{server_url}/v2/models/{args.model_name}/infer"
130
+ samples, sr = load_audio(args.reference_audio)
131
+ assert sr == 16000, "sample rate hardcoded in server"
132
+
133
+ samples = np.array(samples, dtype=np.float32)
134
+ data = prepare_request(samples, args.reference_text, args.target_text)
135
+
136
+ rsp = requests.post(
137
+ url, headers={"Content-Type": "application/json"}, json=data, verify=False, params={"request_id": "0"}
138
+ )
139
+ result = rsp.json()
140
+ audio = result["outputs"][0]["data"]
141
+ audio = np.array(audio, dtype=np.float32)
142
+ sf.write(args.output_audio, audio, 24000, "PCM_16")
src/f5_tts/runtime/triton_trtllm/docker-compose.yml ADDED
@@ -0,0 +1,20 @@
+ services:
+   tts:
+     image: soar97/triton-f5-tts:24.12
+     shm_size: '1gb'
+     ports:
+       - "8000:8000"
+       - "8001:8001"
+       - "8002:8002"
+     environment:
+       - PYTHONIOENCODING=utf-8
+       - MODEL_ID=${MODEL_ID}
+     deploy:
+       resources:
+         reservations:
+           devices:
+             - driver: nvidia
+               device_ids: ['0']
+               capabilities: [gpu]
+     command: >
+       /bin/bash -c "pip install vocos && rm -rf F5-TTS && git clone https://github.com/SWivid/F5-TTS.git && cd F5-TTS/src/f5_tts/runtime/triton_trtllm/ && bash run.sh 0 4 $MODEL"
src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ADDED
@@ -0,0 +1,431 @@
1
+ import tensorrt as trt
2
+ import os
3
+ import math
4
+ import time
5
+ from typing import List, Optional
6
+ from functools import wraps
7
+
8
+ import tensorrt_llm
9
+ from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
10
+ from tensorrt_llm.logger import logger
11
+ from tensorrt_llm.runtime.session import Session
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+
18
+ def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
19
+ # Audio tensor case: batch, seq_len, feature_len
20
+ # position_ids case: batch, seq_len
21
+ assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor"
22
+
23
+ # Initialize a list to collect valid sequences
24
+ valid_sequences = []
25
+
26
+ for i in range(input_tensor.shape[0]):
27
+ valid_length = input_tensor_lengths[i]
28
+ valid_sequences.append(input_tensor[i, :valid_length])
29
+
30
+ # Concatenate all valid sequences along the batch dimension
31
+ output_tensor = torch.cat(valid_sequences, dim=0).contiguous()
32
+ return output_tensor
33
+
34
+
35
+ class TextEmbedding(nn.Module):
36
+ def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2, precompute_max_pos=4096):
37
+ super().__init__()
38
+ self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
39
+ self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
40
+ self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
41
+
42
+ def forward(self, text):
43
+ # only keep tensors with value not -1
44
+ text_mask = text != -1
45
+ text_pad_cut_off_index = text_mask.sum(dim=1).max()
46
+
47
+ text = text[:, :text_pad_cut_off_index]
48
+ text = self.text_embed(text)
49
+ text = text + self.freqs_cis[: text.shape[1], :]
50
+ for block in self.text_blocks:
51
+ text = block(text)
52
+ # padding text to the original length
53
+ # text shape: B,seq_len,C
54
+ # pad at the second dimension
55
+ text = F.pad(text, (0, 0, 0, text_mask.shape[1] - text.shape[1], 0, 0), value=0)
56
+ return text
57
+
58
+
59
+ class GRN(nn.Module):
60
+ def __init__(self, dim):
61
+ super().__init__()
62
+ self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
63
+ self.beta = nn.Parameter(torch.zeros(1, 1, dim))
64
+
65
+ def forward(self, x):
66
+ Gx = torch.norm(x, p=2, dim=1, keepdim=True)
67
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
68
+ return self.gamma * (x * Nx) + self.beta + x
69
+
70
+
71
+ class ConvNeXtV2Block(nn.Module):
72
+ def __init__(
73
+ self,
74
+ dim: int,
75
+ intermediate_dim: int,
76
+ dilation: int = 1,
77
+ ):
78
+ super().__init__()
79
+ padding = (dilation * (7 - 1)) // 2
80
+ self.dwconv = nn.Conv1d(
81
+ dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
82
+ ) # depthwise conv
83
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
84
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
85
+ self.act = nn.GELU()
86
+ self.grn = GRN(intermediate_dim)
87
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
88
+
89
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
90
+ residual = x
91
+ x = x.transpose(1, 2) # b n d -> b d n
92
+ x = self.dwconv(x)
93
+ x = x.transpose(1, 2) # b d n -> b n d
94
+ x = self.norm(x)
95
+ x = self.pwconv1(x)
96
+ x = self.act(x)
97
+ x = self.grn(x)
98
+ x = self.pwconv2(x)
99
+ return residual + x
100
+
101
+
102
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
103
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
104
+ # has some connection to NTK literature
105
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
106
+ # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
107
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
108
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
109
+ t = torch.arange(end, device=freqs.device) # type: ignore
110
+ freqs = torch.outer(t, freqs).float() # type: ignore
111
+ freqs_cos = torch.cos(freqs) # real part
112
+ freqs_sin = torch.sin(freqs) # imaginary part
113
+ return torch.cat([freqs_cos, freqs_sin], dim=-1)
114
+
115
+
116
+ def load_checkpoint(ckpt_path, use_ema=True):
117
+ checkpoint = torch.load(ckpt_path, weights_only=True)
118
+ if use_ema:
119
+ checkpoint["model_state_dict"] = {
120
+ k.replace("ema_model.", ""): v
121
+ for k, v in checkpoint["ema_model_state_dict"].items()
122
+ if k not in ["initted", "step"]
123
+ }
124
+ dict_state = checkpoint["model_state_dict"]
125
+ text_embed_dict = {}
126
+ for key in dict_state.keys():
127
+ # transformer.text_embed.text_embed.weight -> text_embed.weight
128
+ if "text_embed" in key:
129
+ text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[key]
130
+ return text_embed_dict
131
+
132
+
133
+ class F5TTS(object):
134
+ def __init__(
135
+ self,
136
+ config,
137
+ debug_mode=True,
138
+ stream: Optional[torch.cuda.Stream] = None,
139
+ tllm_model_dir: Optional[str] = None,
140
+ model_path: Optional[str] = None,
141
+ vocab_size: Optional[int] = None,
142
+ ):
143
+ self.dtype = config["pretrained_config"]["dtype"]
144
+
145
+ rank = tensorrt_llm.mpi_rank()
146
+ world_size = config["pretrained_config"]["mapping"]["world_size"]
147
+ cp_size = config["pretrained_config"]["mapping"]["cp_size"]
148
+ tp_size = config["pretrained_config"]["mapping"]["tp_size"]
149
+ pp_size = config["pretrained_config"]["mapping"]["pp_size"]
150
+ assert pp_size == 1
151
+ self.mapping = tensorrt_llm.Mapping(
152
+ world_size=world_size, rank=rank, cp_size=cp_size, tp_size=tp_size, pp_size=1, gpus_per_node=1
153
+ )
154
+
155
+ local_rank = rank % self.mapping.gpus_per_node
156
+ self.device = torch.device(f"cuda:{local_rank}")
157
+
158
+ torch.cuda.set_device(self.device)
159
+
160
+ self.stream = stream
161
+ if self.stream is None:
162
+ self.stream = torch.cuda.Stream(self.device)
163
+ torch.cuda.set_stream(self.stream)
164
+
165
+ engine_file = os.path.join(tllm_model_dir, f"rank{rank}.engine")
166
+ logger.info(f"Loading engine from {engine_file}")
167
+ with open(engine_file, "rb") as f:
168
+ engine_buffer = f.read()
169
+
170
+ assert engine_buffer is not None
171
+
172
+ self.session = Session.from_serialized_engine(engine_buffer)
173
+
174
+ self.debug_mode = debug_mode
175
+
176
+ self.inputs = {}
177
+ self.outputs = {}
178
+ self.buffer_allocated = False
179
+
180
+ expected_tensor_names = ["noise", "cond", "time", "rope_cos", "rope_sin", "input_lengths", "denoised"]
181
+
182
+ found_tensor_names = [self.session.engine.get_tensor_name(i) for i in range(self.session.engine.num_io_tensors)]
183
+ if not self.debug_mode and set(expected_tensor_names) != set(found_tensor_names):
184
+ logger.error(
185
+ f"The following expected tensors are not found: {set(expected_tensor_names).difference(set(found_tensor_names))}"
186
+ )
187
+ logger.error(
188
+ f"Those tensors in engine are not expected: {set(found_tensor_names).difference(set(expected_tensor_names))}"
189
+ )
190
+ logger.error(f"Expected tensor names: {expected_tensor_names}")
191
+ logger.error(f"Found tensor names: {found_tensor_names}")
192
+ raise RuntimeError("Tensor names in engine are not the same as expected.")
193
+ if self.debug_mode:
194
+ self.debug_tensors = list(set(found_tensor_names) - set(expected_tensor_names))
195
+
196
+ self.max_mel_len = 4096
197
+ self.text_embedding = TextEmbedding(
198
+ text_num_embeds=vocab_size, text_dim=512, conv_layers=4, precompute_max_pos=self.max_mel_len
199
+ ).to(self.device)
200
+ self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True)
201
+
202
+ self.target_audio_sample_rate = 24000
203
+ self.target_rms = 0.15 # target rms for audio
204
+ self.n_fft = 1024
205
+ self.win_length = 1024
206
+ self.hop_length = 256
207
+ self.n_mel_channels = 100
208
+ # self.max_mel_len = 3000
209
+ self.head_dim = 64
210
+ self.base_rescale_factor = 1.0
211
+ self.interpolation_factor = 1.0
212
+ base = 10000.0 * self.base_rescale_factor ** (self.head_dim / (self.head_dim - 2))
213
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
214
+ freqs = torch.outer(torch.arange(self.max_mel_len, dtype=torch.float32), inv_freq) / self.interpolation_factor
215
+ self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0)
216
+ self.rope_cos = self.freqs.cos().half()
217
+ self.rope_sin = self.freqs.sin().half()
218
+ self.nfe_steps = 16
219
+ t = torch.linspace(0, 1, self.nfe_steps + 1, dtype=torch.float32)
220
+ time_step = t + (-1.0) * (torch.cos(torch.pi * 0.5 * t) - 1 + t)
221
+ delta_t = torch.diff(time_step)
222
+ # WAR: hard coding 256 here
223
+ tmp_dim = 256
224
+ time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32)
225
+ half_dim = tmp_dim // 2
226
+ emb_factor = math.log(10000) / (half_dim - 1)
227
+ emb_factor = 1000.0 * torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb_factor)
228
+ for i in range(self.nfe_steps):
229
+ emb = time_step[i] * emb_factor
230
+ time_expand[:, i, :] = torch.cat((emb.sin(), emb.cos()), dim=-1)
231
+ self.time_expand = time_expand.to(self.device)
232
+ self.delta_t = torch.cat((delta_t, delta_t), dim=0).contiguous().to(self.device)
233
+
234
+ def _tensor_dtype(self, name):
235
+ # return torch dtype given tensor name for convenience
236
+ dtype = trt_dtype_to_torch(self.session.engine.get_tensor_dtype(name))
237
+ return dtype
238
+
239
+ def _setup(self, batch_size, seq_len):
240
+ for i in range(self.session.engine.num_io_tensors):
241
+ name = self.session.engine.get_tensor_name(i)
242
+ if self.session.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
243
+ shape = list(self.session.engine.get_tensor_shape(name))
244
+ shape[0] = batch_size
245
+ shape[1] = seq_len
246
+ self.outputs[name] = torch.empty(shape, dtype=self._tensor_dtype(name), device=self.device)
247
+
248
+ self.buffer_allocated = True
249
+
250
+ def cuda_stream_guard(func):
251
+ """Sync external stream and set current stream to the one bound to the session. Reset on exit."""
252
+
253
+ @wraps(func)
254
+ def wrapper(self, *args, **kwargs):
255
+ external_stream = torch.cuda.current_stream()
256
+ if external_stream != self.stream:
257
+ external_stream.synchronize()
258
+ torch.cuda.set_stream(self.stream)
259
+ ret = func(self, *args, **kwargs)
260
+ if external_stream != self.stream:
261
+ self.stream.synchronize()
262
+ torch.cuda.set_stream(external_stream)
263
+ return ret
264
+
265
+ return wrapper
266
+
267
+ @cuda_stream_guard
268
+ def forward(
269
+ self,
270
+ noise: torch.Tensor,
271
+ cond: torch.Tensor,
272
+ time_expand: torch.Tensor,
273
+ rope_cos: torch.Tensor,
274
+ rope_sin: torch.Tensor,
275
+ input_lengths: torch.Tensor,
276
+ delta_t: torch.Tensor,
277
+ use_perf: bool = False,
278
+ ):
279
+ if use_perf:
280
+ torch.cuda.nvtx.range_push("flow matching")
281
+ cfg_strength = 2.0
282
+ batch_size = noise.shape[0]
283
+ half_batch = batch_size // 2
284
+ noise_half = noise[:half_batch] # Store the initial half of noise
285
+
286
+ input_type = str_dtype_to_torch(self.dtype)
287
+
288
+ # Keep a copy of the initial tensors
289
+ cond = cond.to(input_type)
290
+ rope_cos = rope_cos.to(input_type)
291
+ rope_sin = rope_sin.to(input_type)
292
+ input_lengths = input_lengths.to(str_dtype_to_torch("int32"))
293
+
294
+ # Instead of iteratively updating noise within a single model context,
295
+ # we'll do a single forward pass for each iteration with fresh context setup
296
+ for i in range(self.nfe_steps):
297
+ # Re-setup the buffers for clean execution
298
+ self._setup(batch_size, noise.shape[1])
299
+ if not self.buffer_allocated:
300
+ raise RuntimeError("Buffer not allocated, please call setup first!")
301
+
302
+ # Re-create combined noises for this iteration
303
+ current_noise = torch.cat([noise_half, noise_half], dim=0).to(input_type)
304
+
305
+ # Get time step for this iteration
306
+ current_time = time_expand[:, i].to(input_type)
307
+
308
+ # Create fresh input dictionary for this iteration
309
+ current_inputs = {
310
+ "noise": current_noise,
311
+ "cond": cond,
312
+ "time": current_time,
313
+ "rope_cos": rope_cos,
314
+ "rope_sin": rope_sin,
315
+ "input_lengths": input_lengths,
316
+ }
317
+
318
+ # Update inputs and set shapes
319
+ self.inputs.clear() # Clear previous inputs
320
+ self.inputs.update(**current_inputs)
321
+ self.session.set_shapes(self.inputs)
322
+
323
+ if use_perf:
324
+ torch.cuda.nvtx.range_push(f"execute {i}")
325
+ ok = self.session.run(self.inputs, self.outputs, self.stream.cuda_stream)
326
+ assert ok, "Failed to execute model"
327
+ # self.session.context.execute_async_v3(self.stream.cuda_stream)
328
+ if use_perf:
329
+ torch.cuda.nvtx.range_pop()
330
+ # Process results
331
+ t_scale = delta_t[i].unsqueeze(0).to(input_type)
332
+
333
+ # Extract predictions
334
+ pred_cond = self.outputs["denoised"][:half_batch]
335
+ pred_uncond = self.outputs["denoised"][half_batch:]
336
+
337
+ # Apply classifier-free guidance with safeguards
338
+ guidance = pred_cond + (pred_cond - pred_uncond) * cfg_strength
339
+ # Calculate update for noise
340
+ noise_half = noise_half + guidance * t_scale
341
+ if use_perf:
342
+ torch.cuda.nvtx.range_pop()
343
+ return noise_half
344
+
345
+ def sample(
346
+ self,
347
+ text_pad_sequence: torch.Tensor,
348
+ ref_mel_batch: torch.Tensor,
349
+ ref_mel_len_batch: torch.Tensor,
350
+ estimated_reference_target_mel_len: List[int],
351
+ remove_input_padding: bool = False,
352
+ use_perf: bool = False,
353
+ ):
354
+ if use_perf:
355
+ torch.cuda.nvtx.range_push("text embedding")
356
+ batch = text_pad_sequence.shape[0]
357
+ max_seq_len = ref_mel_batch.shape[1]
358
+
359
+ text_pad_sequence_drop = torch.cat(
360
+ (text_pad_sequence, torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(self.device)), dim=0
361
+ )
362
+
363
+ text_embedding_drop_list = []
364
+ for i in range(batch + 1):
365
+ text_embedding_drop_list.append(self.text_embedding(text_pad_sequence_drop[i].unsqueeze(0).to(self.device)))
366
+ text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0)
367
+
368
+ text_embedding = text_embedding_drop_condition[:-1]
369
+ # text_embedding_drop B,T,C batch should be the same
370
+ text_embedding_drop = text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1)
371
+
372
+ noise = torch.randn_like(ref_mel_batch).to(self.device)
373
+ rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1)
374
+ rope_sin = self.rope_sin[:, :max_seq_len, :].float().repeat(batch, 1, 1)
375
+
376
+ cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1)
377
+ cat_mel_text_drop = torch.cat(
378
+ (
379
+ torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device),
380
+ text_embedding_drop,
381
+ ),
382
+ dim=-1,
383
+ )
384
+
385
+ time_expand = self.time_expand.repeat(2 * batch, 1, 1).contiguous()
386
+
387
+ # Convert estimated_reference_target_mel_len to tensor
388
+ input_lengths = torch.tensor(estimated_reference_target_mel_len, dtype=torch.int32)
389
+
390
+ # combine above along the batch dimension
391
+ inputs = {
392
+ "noise": torch.cat((noise, noise), dim=0).contiguous(),
393
+ "cond": torch.cat((cat_mel_text, cat_mel_text_drop), dim=0).contiguous(),
394
+ "time_expand": time_expand,
395
+ "rope_cos": torch.cat((rope_cos, rope_cos), dim=0).contiguous(),
396
+ "rope_sin": torch.cat((rope_sin, rope_sin), dim=0).contiguous(),
397
+ "input_lengths": torch.cat((input_lengths, input_lengths), dim=0).contiguous(),
398
+ "delta_t": self.delta_t,
399
+ }
400
+ if use_perf and remove_input_padding:
401
+ torch.cuda.nvtx.range_push("remove input padding")
402
+ if remove_input_padding:
403
+ max_seq_len = inputs["cond"].shape[1]
404
+ inputs["noise"] = remove_tensor_padding(inputs["noise"], inputs["input_lengths"])
405
+ inputs["cond"] = remove_tensor_padding(inputs["cond"], inputs["input_lengths"])
406
+ # for time_expand, convert from B,D to B,T,D by repeat
407
+ inputs["time_expand"] = inputs["time_expand"].unsqueeze(1).repeat(1, max_seq_len, 1, 1)
408
+ inputs["time_expand"] = remove_tensor_padding(inputs["time_expand"], inputs["input_lengths"])
409
+ inputs["rope_cos"] = remove_tensor_padding(inputs["rope_cos"], inputs["input_lengths"])
410
+ inputs["rope_sin"] = remove_tensor_padding(inputs["rope_sin"], inputs["input_lengths"])
411
+ if use_perf and remove_input_padding:
412
+ torch.cuda.nvtx.range_pop()
413
+ for key in inputs:
414
+ inputs[key] = inputs[key].to(self.device)
415
+ if use_perf:
416
+ torch.cuda.nvtx.range_pop()
417
+ start_time = time.time()
418
+ denoised = self.forward(**inputs, use_perf=use_perf)
419
+ cost_time = time.time() - start_time
420
+ if use_perf and remove_input_padding:
421
+ torch.cuda.nvtx.range_push("remove input padding output")
422
+ if remove_input_padding:
423
+ denoised_list = []
424
+ start_idx = 0
425
+ for i in range(batch):
426
+ denoised_list.append(denoised[start_idx : start_idx + inputs["input_lengths"][i]])
427
+ start_idx += inputs["input_lengths"][i]
428
+ if use_perf and remove_input_padding:
429
+ torch.cuda.nvtx.range_pop()
430
+ return denoised_list, cost_time
431
+ return denoised, cost_time
src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py ADDED
@@ -0,0 +1,275 @@
1
+ # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+ import json
27
+ import torch
28
+ from torch.nn.utils.rnn import pad_sequence
29
+ import torch.nn.functional as F
30
+ from torch.utils.dlpack import from_dlpack, to_dlpack
31
+ import torchaudio
32
+ import jieba
33
+ import triton_python_backend_utils as pb_utils
34
+ from pypinyin import Style, lazy_pinyin
35
+ import os
36
+ from f5_tts_trtllm import F5TTS
37
+
38
+
39
+ def get_tokenizer(vocab_file_path: str):
40
+ """
41
+ tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
42
+ - "char" for char-wise tokenizer, need .txt vocab_file
43
+ - "byte" for utf-8 tokenizer
44
+ - "custom" if you're directly passing in a path to the vocab.txt you want to use
45
+ vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
46
+ - if use "char", derived from unfiltered character & symbol counts of custom dataset
47
+ - if use "byte", set to 256 (unicode byte range)
48
+ """
49
+ with open(vocab_file_path, "r", encoding="utf-8") as f:
50
+ vocab_char_map = {}
51
+ for i, char in enumerate(f):
52
+ vocab_char_map[char[:-1]] = i
53
+ vocab_size = len(vocab_char_map)
54
+ return vocab_char_map, vocab_size
55
+
56
+
57
+ def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
58
+ final_reference_target_texts_list = []
59
+ custom_trans = str.maketrans(
60
+ {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
61
+ ) # add custom trans here, to address oov
62
+
63
+ def is_chinese(c):
64
+ return "\u3100" <= c <= "\u9fff" # common chinese characters
65
+
66
+ for text in reference_target_texts_list:
67
+ char_list = []
68
+ text = text.translate(custom_trans)
69
+ for seg in jieba.cut(text):
70
+ seg_byte_len = len(bytes(seg, "UTF-8"))
71
+ if seg_byte_len == len(seg): # if pure alphabets and symbols
72
+ if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
73
+ char_list.append(" ")
74
+ char_list.extend(seg)
75
+ elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
76
+ seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
77
+ for i, c in enumerate(seg):
78
+ if is_chinese(c):
79
+ char_list.append(" ")
80
+ char_list.append(seg_[i])
81
+ else: # if mixed characters, alphabets and symbols
82
+ for c in seg:
83
+ if ord(c) < 256:
84
+ char_list.extend(c)
85
+ elif is_chinese(c):
86
+ char_list.append(" ")
87
+ char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
88
+ else:
89
+ char_list.append(c)
90
+ final_reference_target_texts_list.append(char_list)
91
+
92
+ return final_reference_target_texts_list
93
+
94
+
95
+ def list_str_to_idx(
96
+ text: list[str] | list[list[str]],
97
+ vocab_char_map: dict[str, int], # {char: idx}
98
+ padding_value=-1,
99
+ ): # noqa: F722
100
+ list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
101
+ return list_idx_tensors
102
+
103
+
104
+ class TritonPythonModel:
105
+ def initialize(self, args):
106
+ self.use_perf = True
107
+ self.device = torch.device("cuda")
108
+ self.target_audio_sample_rate = 24000
109
+ self.target_rms = 0.15 # target rms for audio
110
+ self.n_fft = 1024
111
+ self.win_length = 1024
112
+ self.hop_length = 256
113
+ self.n_mel_channels = 100
114
+ self.max_mel_len = 3000
115
+ self.head_dim = 64
116
+
117
+ parameters = json.loads(args["model_config"])["parameters"]
118
+ for key, value in parameters.items():
119
+ parameters[key] = value["string_value"]
120
+
121
+ self.vocab_char_map, self.vocab_size = get_tokenizer(parameters["vocab_file"])
122
+ self.reference_sample_rate = int(parameters["reference_audio_sample_rate"])
123
+ self.resampler = torchaudio.transforms.Resample(self.reference_sample_rate, self.target_audio_sample_rate)
124
+
125
+ self.tllm_model_dir = parameters["tllm_model_dir"]
126
+ config_file = os.path.join(self.tllm_model_dir, "config.json")
127
+ with open(config_file) as f:
128
+ config = json.load(f)
129
+ self.model = F5TTS(
130
+ config,
131
+ debug_mode=False,
132
+ tllm_model_dir=self.tllm_model_dir,
133
+ model_path=parameters["model_path"],
134
+ vocab_size=self.vocab_size,
135
+ )
136
+
137
+ self.vocoder = parameters["vocoder"]
138
+ assert self.vocoder in ["vocos", "bigvgan"]
139
+ if self.vocoder == "vocos":
140
+ self.mel_stft = torchaudio.transforms.MelSpectrogram(
141
+ sample_rate=self.target_audio_sample_rate,
142
+ n_fft=self.n_fft,
143
+ win_length=self.win_length,
144
+ hop_length=self.hop_length,
145
+ n_mels=self.n_mel_channels,
146
+ power=1,
147
+ center=True,
148
+ normalized=False,
149
+ norm=None,
150
+ ).to(self.device)
151
+ self.compute_mel_fn = self.get_vocos_mel_spectrogram
152
+ elif self.vocoder == "bigvgan":
153
+ self.compute_mel_fn = self.get_bigvgan_mel_spectrogram
154
+
155
+ def get_vocos_mel_spectrogram(self, waveform):
156
+ mel = self.mel_stft(waveform)
157
+ mel = mel.clamp(min=1e-5).log()
158
+ return mel.transpose(1, 2)
159
+
160
+ def forward_vocoder(self, mel):
161
+ mel = mel.to(torch.float32).contiguous().cpu()
162
+ input_tensor_0 = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel))
163
+
164
+ inference_request = pb_utils.InferenceRequest(
165
+ model_name="vocoder", requested_output_names=["waveform"], inputs=[input_tensor_0]
166
+ )
167
+ inference_response = inference_request.exec()
168
+ if inference_response.has_error():
169
+ raise pb_utils.TritonModelException(inference_response.error().message())
170
+ else:
171
+ waveform = pb_utils.get_output_tensor_by_name(inference_response, "waveform")
172
+ waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
173
+
174
+ return waveform
175
+
176
+ def execute(self, requests):
177
+ (
178
+ reference_text_list,
179
+ target_text_list,
180
+ reference_target_texts_list,
181
+ estimated_reference_target_mel_len,
182
+ reference_mel_len,
183
+ ) = [], [], [], [], []
184
+ mel_features_list = []
185
+ if self.use_perf:
186
+ torch.cuda.nvtx.range_push("preprocess")
187
+ for request in requests:
188
+ wav_tensor = pb_utils.get_input_tensor_by_name(request, "reference_wav")
189
+ wav_lens = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
190
+
191
+ reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
192
+ reference_text = reference_text[0][0].decode("utf-8")
193
+ reference_text_list.append(reference_text)
194
+ target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
195
+ target_text = target_text[0][0].decode("utf-8")
196
+ target_text_list.append(target_text)
197
+
198
+ text = reference_text + target_text
199
+ reference_target_texts_list.append(text)
200
+
201
+ wav = from_dlpack(wav_tensor.to_dlpack())
202
+ wav_len = from_dlpack(wav_lens.to_dlpack())
203
+ wav_len = wav_len.squeeze()
204
+ assert wav.shape[0] == 1, "Only support batch size 1 for now."
205
+ wav = wav[:, :wav_len]
206
+
207
+ ref_rms = torch.sqrt(torch.mean(torch.square(wav)))
208
+ if ref_rms < self.target_rms:
209
+ wav = wav * self.target_rms / ref_rms
210
+ if self.reference_sample_rate != self.target_audio_sample_rate:
211
+ wav = self.resampler(wav)
212
+ wav = wav.to(self.device)
213
+ if self.use_perf:
214
+ torch.cuda.nvtx.range_push("compute_mel")
215
+ mel_features = self.compute_mel_fn(wav)
216
+ if self.use_perf:
217
+ torch.cuda.nvtx.range_pop()
218
+ mel_features_list.append(mel_features)
219
+
220
+ reference_mel_len.append(mel_features.shape[1])
221
+ estimated_reference_target_mel_len.append(
222
+ int(mel_features.shape[1] * (1 + len(target_text) / len(reference_text)))
223
+ )
224
+
225
+ max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len)
226
+
227
+ batch = len(requests)
228
+ mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float16).to(self.device)
229
+ for i, mel in enumerate(mel_features_list):
230
+ mel_features[i, : mel.shape[1], :] = mel
231
+
232
+ reference_mel_len_tensor = torch.LongTensor(reference_mel_len).to(self.device)
233
+
234
+ pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
235
+ text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map)
236
+
237
+ for i, item in enumerate(text_pad_sequence):
238
+ text_pad_sequence[i] = F.pad(
239
+ item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
240
+ )
241
+ text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
242
+ text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(self.device)
243
+ text_pad_sequence = F.pad(
244
+ text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
245
+ )
246
+ if self.use_perf:
247
+ torch.cuda.nvtx.range_pop()
248
+
249
+ denoised, cost_time = self.model.sample(
250
+ text_pad_sequence,
251
+ mel_features,
252
+ reference_mel_len_tensor,
253
+ estimated_reference_target_mel_len,
254
+ remove_input_padding=False,
255
+ use_perf=self.use_perf,
256
+ )
257
+ if self.use_perf:
258
+ torch.cuda.nvtx.range_push("vocoder")
259
+
260
+ responses = []
261
+ for i in range(batch):
262
+ ref_me_len = reference_mel_len[i]
263
+ estimated_mel_len = estimated_reference_target_mel_len[i]
264
+ denoised_one_item = denoised[i, ref_me_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
265
+ audio = self.forward_vocoder(denoised_one_item)
266
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
267
+ if rms < self.target_rms:
268
+ audio = audio * self.target_rms / rms
269
+
270
+ audio = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
271
+ inference_response = pb_utils.InferenceResponse(output_tensors=[audio])
272
+ responses.append(inference_response)
273
+ if self.use_perf:
274
+ torch.cuda.nvtx.range_pop()
275
+ return responses
src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt ADDED
@@ -0,0 +1,81 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ name: "f5_tts"
16
+ backend: "python"
17
+ max_batch_size: 4
18
+ dynamic_batching {
19
+ max_queue_delay_microseconds: 1000
20
+ }
21
+ parameters [
22
+ {
23
+ key: "vocab_file"
24
+ value: { string_value: "${vocab}"}
25
+ },
26
+ {
27
+ key: "model_path",
28
+ value: {string_value:"${model}"}
29
+ },
30
+ {
31
+ key: "tllm_model_dir",
32
+ value: {string_value:"${trtllm}"}
33
+ },
34
+ {
35
+ key: "reference_audio_sample_rate",
36
+ value: {string_value:"16000"}
37
+ },
38
+ {
39
+ key: "vocoder",
40
+ value: {string_value:"${vocoder}"}
41
+ }
42
+ ]
43
+
44
+ input [
45
+ {
46
+ name: "reference_wav"
47
+ data_type: TYPE_FP32
48
+ dims: [-1]
49
+ optional: True
50
+ },
51
+ {
52
+ name: "reference_wav_len"
53
+ data_type: TYPE_INT32
54
+ dims: [1]
55
+ optional: True
56
+ },
57
+ {
58
+ name: "reference_text"
59
+ data_type: TYPE_STRING
60
+ dims: [1]
61
+ },
62
+ {
63
+ name: "target_text"
64
+ data_type: TYPE_STRING
65
+ dims: [1]
66
+ }
67
+ ]
68
+ output [
69
+ {
70
+ name: "waveform"
71
+ data_type: TYPE_FP32
72
+ dims: [ -1 ]
73
+ }
74
+ ]
75
+
76
+ instance_group [
77
+ {
78
+ count: 1
79
+ kind: KIND_GPU
80
+ }
81
+ ]
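
A minimal Python client sketch matching the inputs and output declared above (illustrative only; it parallels the provided `client_http.py`/`client_grpc.py` and assumes the Triton server started by `run.sh` listens on `localhost:8001`, that the reference clip is 16 kHz mono to match `reference_audio_sample_rate`, and that the Vocos output rate is 24 kHz):

```python
import numpy as np
import soundfile as sf  # available via librosa in the server image; install separately if needed
import tritonclient.grpc as grpcclient

wav, sr = sf.read("ref_16k_mono.wav", dtype="float32")  # hypothetical 16 kHz reference clip
wav = wav[np.newaxis, :]                                 # shape [1, num_samples], batch of 1

def text_input(name, text):
    # TYPE_STRING inputs are sent as BYTES object arrays
    t = grpcclient.InferInput(name, [1, 1], "BYTES")
    t.set_data_from_numpy(np.array([[text.encode("utf-8")]], dtype=object))
    return t

inputs = [
    grpcclient.InferInput("reference_wav", list(wav.shape), "FP32"),
    grpcclient.InferInput("reference_wav_len", [1, 1], "INT32"),
    text_input("reference_text", "transcript of the reference clip"),
    text_input("target_text", "text to synthesize"),
]
inputs[0].set_data_from_numpy(wav)
inputs[1].set_data_from_numpy(np.array([[wav.shape[1]]], dtype=np.int32))

client = grpcclient.InferenceServerClient("localhost:8001")
result = client.infer("f5_tts", inputs, outputs=[grpcclient.InferRequestedOutput("waveform")])
sf.write("output.wav", result.as_numpy("waveform").squeeze(), 24000)
```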
src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/1/.gitkeep ADDED
File without changes
src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/config.pbtxt ADDED
@@ -0,0 +1,32 @@
1
+ name: "vocoder"
2
+ backend: "tensorrt"
3
+ default_model_filename: "vocoder.plan"
4
+ max_batch_size: 4
5
+
6
+ input [
7
+ {
8
+ name: "mel"
9
+ data_type: TYPE_FP32
10
+ dims: [ 100, -1 ]
11
+ }
12
+ ]
13
+
14
+ output [
15
+ {
16
+ name: "waveform"
17
+ data_type: TYPE_FP32
18
+ dims: [ -1 ]
19
+ }
20
+ ]
21
+
22
+ dynamic_batching {
23
+ preferred_batch_size: [1, 2, 4]
24
+ max_queue_delay_microseconds: 1
25
+ }
26
+
27
+ instance_group [
28
+ {
29
+ count: 1
30
+ kind: KIND_GPU
31
+ }
32
+ ]
src/f5_tts/runtime/triton_trtllm/patch/__init__.py ADDED
@@ -0,0 +1,196 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from .baichuan.model import BaichuanForCausalLM
16
+ from .bert.model import (BertForQuestionAnswering,
17
+ BertForSequenceClassification, BertModel,
18
+ RobertaForQuestionAnswering,
19
+ RobertaForSequenceClassification, RobertaModel)
20
+ from .bloom.model import BloomForCausalLM, BloomModel
21
+ from .chatglm.config import ChatGLMConfig
22
+ from .chatglm.model import ChatGLMForCausalLM, ChatGLMModel
23
+ from .cogvlm.config import CogVLMConfig
24
+ from .cogvlm.model import CogVLMForCausalLM
25
+ from .commandr.model import CohereForCausalLM
26
+ from .dbrx.config import DbrxConfig
27
+ from .dbrx.model import DbrxForCausalLM
28
+ from .deepseek_v1.model import DeepseekForCausalLM
29
+ from .deepseek_v2.model import DeepseekV2ForCausalLM
30
+ from .dit.model import DiT
31
+ from .eagle.model import EagleForCausalLM
32
+ from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder
33
+ from .falcon.config import FalconConfig
34
+ from .falcon.model import FalconForCausalLM, FalconModel
35
+ from .gemma.config import GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig
36
+ from .gemma.model import GemmaForCausalLM
37
+ from .gpt.config import GPTConfig
38
+ from .gpt.model import GPTForCausalLM, GPTModel
39
+ from .gptj.config import GPTJConfig
40
+ from .gptj.model import GPTJForCausalLM, GPTJModel
41
+ from .gptneox.model import GPTNeoXForCausalLM, GPTNeoXModel
42
+ from .grok.model import GrokForCausalLM
43
+ from .llama.config import LLaMAConfig
44
+ from .llama.model import LLaMAForCausalLM, LLaMAModel
45
+ from .mamba.model import MambaForCausalLM
46
+ from .medusa.config import MedusaConfig
47
+ from .medusa.model import MedusaForCausalLm
48
+ from .mllama.model import MLLaMAModel
49
+ from .modeling_utils import (PretrainedConfig, PretrainedModel,
50
+ SpeculativeDecodingMode)
51
+ from .mpt.model import MPTForCausalLM, MPTModel
52
+ from .nemotron_nas.model import DeciLMForCausalLM
53
+ from .opt.model import OPTForCausalLM, OPTModel
54
+ from .phi3.model import Phi3ForCausalLM, Phi3Model
55
+ from .phi.model import PhiForCausalLM, PhiModel
56
+ from .qwen.model import QWenForCausalLM
57
+ from .recurrentgemma.model import RecurrentGemmaForCausalLM
58
+ from .redrafter.model import ReDrafterForCausalLM
59
+ from .f5tts.model import F5TTS
60
+
61
+ __all__ = [
62
+ 'BertModel',
63
+ 'BertForQuestionAnswering',
64
+ 'BertForSequenceClassification',
65
+ 'RobertaModel',
66
+ 'RobertaForQuestionAnswering',
67
+ 'RobertaForSequenceClassification',
68
+ 'BloomModel',
69
+ 'BloomForCausalLM',
70
+ 'DiT',
71
+ 'DeepseekForCausalLM',
72
+ 'FalconConfig',
73
+ 'DeepseekV2ForCausalLM',
74
+ 'FalconForCausalLM',
75
+ 'FalconModel',
76
+ 'GPTConfig',
77
+ 'GPTModel',
78
+ 'GPTForCausalLM',
79
+ 'OPTForCausalLM',
80
+ 'OPTModel',
81
+ 'LLaMAConfig',
82
+ 'LLaMAForCausalLM',
83
+ 'LLaMAModel',
84
+ 'MedusaConfig',
85
+ 'MedusaForCausalLm',
86
+ 'ReDrafterForCausalLM',
87
+ 'GPTJConfig',
88
+ 'GPTJModel',
89
+ 'GPTJForCausalLM',
90
+ 'GPTNeoXModel',
91
+ 'GPTNeoXForCausalLM',
92
+ 'PhiModel',
93
+ 'PhiConfig',
94
+ 'Phi3Model',
95
+ 'Phi3Config',
96
+ 'PhiForCausalLM',
97
+ 'Phi3ForCausalLM',
98
+ 'ChatGLMConfig',
99
+ 'ChatGLMForCausalLM',
100
+ 'ChatGLMModel',
101
+ 'BaichuanForCausalLM',
102
+ 'QWenConfig',
103
+ 'QWenForCausalLM',
104
+ 'QWenModel',
105
+ 'EncoderModel',
106
+ 'DecoderModel',
107
+ 'PretrainedConfig',
108
+ 'PretrainedModel',
109
+ 'WhisperEncoder',
110
+ 'MambaForCausalLM',
111
+ 'MambaConfig',
112
+ 'MPTForCausalLM',
113
+ 'MPTModel',
114
+ 'SkyworkForCausalLM',
115
+ 'GemmaConfig',
116
+ 'GemmaForCausalLM',
117
+ 'DbrxConfig',
118
+ 'DbrxForCausalLM',
119
+ 'RecurrentGemmaForCausalLM',
120
+ 'CogVLMConfig',
121
+ 'CogVLMForCausalLM',
122
+ 'EagleForCausalLM',
123
+ 'SpeculativeDecodingMode',
124
+ 'CohereForCausalLM',
125
+ 'MLLaMAModel',
126
+ 'F5TTS',
127
+ ]
128
+
129
+ MODEL_MAP = {
130
+ 'GPT2LMHeadModel': GPTForCausalLM,
131
+ 'GPT2LMHeadCustomModel': GPTForCausalLM,
132
+ 'GPTBigCodeForCausalLM': GPTForCausalLM,
133
+ 'Starcoder2ForCausalLM': GPTForCausalLM,
134
+ 'FuyuForCausalLM': GPTForCausalLM,
135
+ 'Kosmos2ForConditionalGeneration': GPTForCausalLM,
136
+ 'JAISLMHeadModel': GPTForCausalLM,
137
+ 'GPTForCausalLM': GPTForCausalLM,
138
+ 'NemotronForCausalLM': GPTForCausalLM,
139
+ 'OPTForCausalLM': OPTForCausalLM,
140
+ 'BloomForCausalLM': BloomForCausalLM,
141
+ 'RWForCausalLM': FalconForCausalLM,
142
+ 'FalconForCausalLM': FalconForCausalLM,
143
+ 'PhiForCausalLM': PhiForCausalLM,
144
+ 'Phi3ForCausalLM': Phi3ForCausalLM,
145
+ 'Phi3VForCausalLM': Phi3ForCausalLM,
146
+ 'Phi3SmallForCausalLM': Phi3ForCausalLM,
147
+ 'PhiMoEForCausalLM': Phi3ForCausalLM,
148
+ 'MambaForCausalLM': MambaForCausalLM,
149
+ 'GPTNeoXForCausalLM': GPTNeoXForCausalLM,
150
+ 'GPTJForCausalLM': GPTJForCausalLM,
151
+ 'MPTForCausalLM': MPTForCausalLM,
152
+ 'GLMModel': ChatGLMForCausalLM,
153
+ 'ChatGLMModel': ChatGLMForCausalLM,
154
+ 'ChatGLMForCausalLM': ChatGLMForCausalLM,
155
+ 'LlamaForCausalLM': LLaMAForCausalLM,
156
+ 'ExaoneForCausalLM': LLaMAForCausalLM,
157
+ 'MistralForCausalLM': LLaMAForCausalLM,
158
+ 'MixtralForCausalLM': LLaMAForCausalLM,
159
+ 'ArcticForCausalLM': LLaMAForCausalLM,
160
+ 'Grok1ModelForCausalLM': GrokForCausalLM,
161
+ 'InternLMForCausalLM': LLaMAForCausalLM,
162
+ 'InternLM2ForCausalLM': LLaMAForCausalLM,
163
+ 'MedusaForCausalLM': MedusaForCausalLm,
164
+ 'ReDrafterForCausalLM': ReDrafterForCausalLM,
165
+ 'BaichuanForCausalLM': BaichuanForCausalLM,
166
+ 'BaiChuanForCausalLM': BaichuanForCausalLM,
167
+ 'SkyworkForCausalLM': LLaMAForCausalLM,
168
+ GEMMA_ARCHITECTURE: GemmaForCausalLM,
169
+ GEMMA2_ARCHITECTURE: GemmaForCausalLM,
170
+ 'QWenLMHeadModel': QWenForCausalLM,
171
+ 'QWenForCausalLM': QWenForCausalLM,
172
+ 'Qwen2ForCausalLM': QWenForCausalLM,
173
+ 'Qwen2MoeForCausalLM': QWenForCausalLM,
174
+ 'Qwen2ForSequenceClassification': QWenForCausalLM,
175
+ 'Qwen2VLForConditionalGeneration': QWenForCausalLM,
176
+ 'WhisperEncoder': WhisperEncoder,
177
+ 'EncoderModel': EncoderModel,
178
+ 'DecoderModel': DecoderModel,
179
+ 'DbrxForCausalLM': DbrxForCausalLM,
180
+ 'RecurrentGemmaForCausalLM': RecurrentGemmaForCausalLM,
181
+ 'CogVLMForCausalLM': CogVLMForCausalLM,
182
+ 'DiT': DiT,
183
+ 'DeepseekForCausalLM': DeepseekForCausalLM,
184
+ 'DeciLMForCausalLM': DeciLMForCausalLM,
185
+ 'DeepseekV2ForCausalLM': DeepseekV2ForCausalLM,
186
+ 'EagleForCausalLM': EagleForCausalLM,
187
+ 'CohereForCausalLM': CohereForCausalLM,
188
+ 'MllamaForConditionalGeneration': MLLaMAModel,
189
+ 'BertForQuestionAnswering': BertForQuestionAnswering,
190
+ 'BertForSequenceClassification': BertForSequenceClassification,
191
+ 'BertModel': BertModel,
192
+ 'RobertaModel': RobertaModel,
193
+ 'RobertaForQuestionAnswering': RobertaForQuestionAnswering,
194
+ 'RobertaForSequenceClassification': RobertaForSequenceClassification,
195
+ 'F5TTS': F5TTS
196
+ }
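
Once `run.sh` has copied `patch/` over the installed `tensorrt_llm/models` package, the new architecture is reachable through the regular model registry. A quick sanity-check sketch (assumes a patched `tensorrt-llm==0.16.0` environment):

```python
# Verify the patch took effect before running trtllm-build.
from tensorrt_llm.models import MODEL_MAP, F5TTS

assert MODEL_MAP["F5TTS"] is F5TTS
print("F5TTS registered from", F5TTS.__module__)
```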
src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py ADDED
@@ -0,0 +1,225 @@
1
+ from __future__ import annotations
2
+ import sys
3
+ import os
4
+
5
+ import tensorrt as trt
6
+ from collections import OrderedDict
7
+ from ..._utils import str_dtype_to_trt
8
+ from ...plugin import current_all_reduce_helper
9
+ from ..modeling_utils import PretrainedConfig, PretrainedModel
10
+ from ...functional import Tensor, concat
11
+ from ...module import Module, ModuleList
12
+ from tensorrt_llm._common import default_net
13
+ from ...layers import Linear
14
+
15
+ from .modules import (
16
+ TimestepEmbedding,
17
+ ConvPositionEmbedding,
18
+ DiTBlock,
19
+ AdaLayerNormZero_Final,
20
+ )
21
+
22
+ current_file_path = os.path.abspath(__file__)
23
+ parent_dir = os.path.dirname(current_file_path)
24
+ sys.path.append(parent_dir)
25
+
26
+
27
+ class InputEmbedding(Module):
28
+ def __init__(self, mel_dim, text_dim, out_dim):
29
+ super().__init__()
30
+ self.proj = Linear(mel_dim * 2 + text_dim, out_dim)
31
+ self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
32
+
33
+ def forward(self, x, cond):
34
+ x = self.proj(concat([x, cond], dim=-1))
35
+ return self.conv_pos_embed(x) + x
36
+
37
+
38
+ class F5TTS(PretrainedModel):
39
+ def __init__(self, config: PretrainedConfig):
40
+ super().__init__(config)
41
+ self.dtype = str_dtype_to_trt(config.dtype)
42
+
43
+ self.time_embed = TimestepEmbedding(config.hidden_size)
44
+ self.input_embed = InputEmbedding(config.mel_dim, config.text_dim, config.hidden_size)
45
+
46
+ self.dim = config.hidden_size
47
+ self.depth = config.num_hidden_layers
48
+ self.transformer_blocks = ModuleList(
49
+ [
50
+ DiTBlock(
51
+ dim=self.dim,
52
+ heads=config.num_attention_heads,
53
+ dim_head=config.dim_head,
54
+ ff_mult=config.ff_mult,
55
+ dropout=config.dropout,
56
+ )
57
+ for _ in range(self.depth)
58
+ ]
59
+ )
60
+
61
+ self.norm_out = AdaLayerNormZero_Final(config.hidden_size) # final modulation
62
+ self.proj_out = Linear(config.hidden_size, config.mel_dim)
63
+
64
+ def forward(
65
+ self,
66
+ noise, # noised input audio
67
+ cond, # masked cond audio
68
+ time, # time step
69
+ rope_cos,
70
+ rope_sin,
71
+ input_lengths,
72
+ scale=1.0,
73
+ ):
74
+ t = self.time_embed(time)
75
+ x = self.input_embed(noise, cond)
76
+ for block in self.transformer_blocks:
77
+ x = block(x, t, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
78
+ denoise = self.proj_out(self.norm_out(x, t))
79
+ denoise.mark_output("denoised", self.dtype)
80
+ return denoise
81
+
82
+ def prepare_inputs(self, **kwargs):
83
+ max_batch_size = kwargs["max_batch_size"]
84
+ batch_size_range = [2, 2, max_batch_size]
85
+ mel_size = 100
86
+ max_seq_len = 3000
87
+ num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size]
88
+ hidden_size = 512
89
+ concat_feature_dim = mel_size + hidden_size
90
+ freq_embed_dim = 256
91
+ head_dim = 64
92
+ mapping = self.config.mapping
93
+ if mapping.tp_size > 1:
94
+ current_all_reduce_helper().set_workspace_tensor(mapping, 1)
95
+ if default_net().plugin_config.remove_input_padding:
96
+ noise = Tensor(
97
+ name="noise",
98
+ dtype=self.dtype,
99
+ shape=[-1, mel_size],
100
+ dim_range=OrderedDict(
101
+ [
102
+ ("num_frames", [num_frames_range]),
103
+ ("n_mels", [mel_size]),
104
+ ]
105
+ ),
106
+ )
107
+ cond = Tensor(
108
+ name="cond",
109
+ dtype=self.dtype,
110
+ shape=[-1, concat_feature_dim],
111
+ dim_range=OrderedDict(
112
+ [
113
+ ("num_frames", [num_frames_range]),
114
+ ("embedded_length", [concat_feature_dim]),
115
+ ]
116
+ ),
117
+ )
118
+ time = Tensor(
119
+ name="time",
120
+ dtype=self.dtype,
121
+ shape=[-1, freq_embed_dim],
122
+ dim_range=OrderedDict(
123
+ [
124
+ ("num_frames", [num_frames_range]),
125
+ ("freq_dim", [freq_embed_dim]),
126
+ ]
127
+ ),
128
+ )
129
+ rope_cos = Tensor(
130
+ name="rope_cos",
131
+ dtype=self.dtype,
132
+ shape=[-1, head_dim],
133
+ dim_range=OrderedDict(
134
+ [
135
+ ("num_frames", [num_frames_range]),
136
+ ("head_dim", [head_dim]),
137
+ ]
138
+ ),
139
+ )
140
+ rope_sin = Tensor(
141
+ name="rope_sin",
142
+ dtype=self.dtype,
143
+ shape=[-1, head_dim],
144
+ dim_range=OrderedDict(
145
+ [
146
+ ("num_frames", [num_frames_range]),
147
+ ("head_dim", [head_dim]),
148
+ ]
149
+ ),
150
+ )
151
+
152
+ else:
153
+ noise = Tensor(
154
+ name="noise",
155
+ dtype=self.dtype,
156
+ shape=[-1, -1, mel_size],
157
+ dim_range=OrderedDict(
158
+ [
159
+ ("batch_size", [batch_size_range]),
160
+ ("max_duration", [[100, max_seq_len // 2, max_seq_len]]),
161
+ ("n_mels", [mel_size]),
162
+ ]
163
+ ),
164
+ )
165
+ cond = Tensor(
166
+ name="cond",
167
+ dtype=self.dtype,
168
+ shape=[-1, -1, concat_feature_dim],
169
+ dim_range=OrderedDict(
170
+ [
171
+ ("batch_size", [batch_size_range]),
172
+ ("max_duration", [[100, max_seq_len // 2, max_seq_len]]),
173
+ ("embedded_length", [concat_feature_dim]),
174
+ ]
175
+ ),
176
+ )
177
+ time = Tensor(
178
+ name="time",
179
+ dtype=self.dtype,
180
+ shape=[-1, freq_embed_dim],
181
+ dim_range=OrderedDict(
182
+ [
183
+ ("batch_size", [batch_size_range]),
184
+ ("freq_dim", [freq_embed_dim]),
185
+ ]
186
+ ),
187
+ )
188
+ rope_cos = Tensor(
189
+ name="rope_cos",
190
+ dtype=self.dtype,
191
+ shape=[-1, -1, head_dim],
192
+ dim_range=OrderedDict(
193
+ [
194
+ ("batch_size", [batch_size_range]),
195
+ ("max_duration", [[100, max_seq_len // 2, max_seq_len]]),
196
+ ("head_dim", [head_dim]),
197
+ ]
198
+ ),
199
+ )
200
+ rope_sin = Tensor(
201
+ name="rope_sin",
202
+ dtype=self.dtype,
203
+ shape=[-1, -1, head_dim],
204
+ dim_range=OrderedDict(
205
+ [
206
+ ("batch_size", [batch_size_range]),
207
+ ("max_duration", [[100, max_seq_len // 2, max_seq_len]]),
208
+ ("head_dim", [head_dim]),
209
+ ]
210
+ ),
211
+ )
212
+ input_lengths = Tensor(
213
+ name="input_lengths",
214
+ dtype=trt.int32,
215
+ shape=[-1],
216
+ dim_range=OrderedDict([("batch_size", [batch_size_range])]),
217
+ )
218
+ return {
219
+ "noise": noise,
220
+ "cond": cond,
221
+ "time": time,
222
+ "rope_cos": rope_cos,
223
+ "rope_sin": rope_sin,
224
+ "input_lengths": input_lengths,
225
+ }
src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py ADDED
@@ -0,0 +1,410 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ import numpy as np
10
+ from tensorrt_llm._common import default_net
11
+ from ..._utils import trt_dtype_to_np, str_dtype_to_trt
12
+ from ...functional import (
13
+ Tensor,
14
+ chunk,
15
+ concat,
16
+ constant,
17
+ expand,
18
+ shape,
19
+ silu,
20
+ slice,
21
+ permute,
22
+ expand_mask,
23
+ expand_dims_like,
24
+ unsqueeze,
25
+ matmul,
26
+ softmax,
27
+ squeeze,
28
+ cast,
29
+ gelu,
30
+ )
31
+ from ...functional import expand_dims, view, bert_attention
32
+ from ...layers import LayerNorm, Linear, Conv1d, Mish, RowLinear, ColumnLinear
33
+ from ...module import Module
34
+
35
+
36
+ class FeedForward(Module):
37
+ def __init__(self, dim, dim_out=None, mult=4, dropout=0.0):
38
+ super().__init__()
39
+ inner_dim = int(dim * mult)
40
+ dim_out = dim_out if dim_out is not None else dim
41
+
42
+ self.project_in = Linear(dim, inner_dim)
43
+ self.ff = Linear(inner_dim, dim_out)
44
+
45
+ def forward(self, x):
46
+ return self.ff(gelu(self.project_in(x)))
47
+
48
+
49
+ class AdaLayerNormZero(Module):
50
+ def __init__(self, dim):
51
+ super().__init__()
52
+
53
+ self.linear = Linear(dim, dim * 6)
54
+ self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)
55
+
56
+ def forward(self, x, emb=None):
57
+ emb = self.linear(silu(emb))
58
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunk(emb, 6, dim=1)
59
+ x = self.norm(x)
60
+ ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
61
+ if default_net().plugin_config.remove_input_padding:
62
+ x = x * (ones + scale_msa) + shift_msa
63
+ else:
64
+ x = x * (ones + unsqueeze(scale_msa, 1)) + unsqueeze(shift_msa, 1)
65
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
66
+
67
+
68
+ class AdaLayerNormZero_Final(Module):
69
+ def __init__(self, dim):
70
+ super().__init__()
71
+
72
+ self.linear = Linear(dim, dim * 2)
73
+
74
+ self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)
75
+
76
+ def forward(self, x, emb):
77
+ emb = self.linear(silu(emb))
78
+ scale, shift = chunk(emb, 2, dim=1)
79
+ ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
80
+ if default_net().plugin_config.remove_input_padding:
81
+ x = self.norm(x) * (ones + scale) + shift
82
+ else:
83
+ x = self.norm(x) * unsqueeze((ones + scale), 1)
84
+ x = x + unsqueeze(shift, 1)
85
+ return x
86
+
87
+
88
+ class ConvPositionEmbedding(Module):
89
+ def __init__(self, dim, kernel_size=31, groups=16):
90
+ super().__init__()
91
+ assert kernel_size % 2 != 0
92
+ self.conv1d1 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
93
+ self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
94
+ self.mish = Mish()
95
+
96
+ def forward(self, x, mask=None): # noqa: F722
97
+ if default_net().plugin_config.remove_input_padding:
98
+ x = unsqueeze(x, 0)
99
+ x = permute(x, [0, 2, 1])
100
+ x = self.mish(self.conv1d2(self.mish(self.conv1d1(x))))
101
+ out = permute(x, [0, 2, 1])
102
+ if default_net().plugin_config.remove_input_padding:
103
+ out = squeeze(out, 0)
104
+ return out
105
+
106
+
107
+ class Attention(Module):
108
+ def __init__(
109
+ self,
110
+ processor: AttnProcessor,
111
+ dim: int,
112
+ heads: int = 16,
113
+ dim_head: int = 64,
114
+ dropout: float = 0.0,
115
+ context_dim: Optional[int] = None, # if not None -> joint attention
116
+ context_pre_only=None,
117
+ ):
118
+ super().__init__()
119
+
120
+ if not hasattr(F, "scaled_dot_product_attention"):
121
+ raise ImportError("Attention requires PyTorch 2.0; to use it, please upgrade PyTorch to 2.0.")
122
+
123
+ self.processor = processor
124
+
125
+ self.dim = dim # hidden_size
126
+ self.heads = heads
127
+ self.inner_dim = dim_head * heads
128
+ self.dropout = dropout
129
+ self.attention_head_size = dim_head
130
+ self.context_dim = context_dim
131
+ self.context_pre_only = context_pre_only
132
+ self.tp_size = 1
133
+ self.num_attention_heads = heads // self.tp_size
134
+ self.num_attention_kv_heads = heads // self.tp_size # 8
135
+ self.dtype = str_dtype_to_trt("float32")
136
+ self.attention_hidden_size = self.attention_head_size * self.num_attention_heads
137
+ self.to_q = ColumnLinear(
138
+ dim,
139
+ self.tp_size * self.num_attention_heads * self.attention_head_size,
140
+ bias=True,
141
+ dtype=self.dtype,
142
+ tp_group=None,
143
+ tp_size=self.tp_size,
144
+ )
145
+ self.to_k = ColumnLinear(
146
+ dim,
147
+ self.tp_size * self.num_attention_heads * self.attention_head_size,
148
+ bias=True,
149
+ dtype=self.dtype,
150
+ tp_group=None,
151
+ tp_size=self.tp_size,
152
+ )
153
+ self.to_v = ColumnLinear(
154
+ dim,
155
+ self.tp_size * self.num_attention_heads * self.attention_head_size,
156
+ bias=True,
157
+ dtype=self.dtype,
158
+ tp_group=None,
159
+ tp_size=self.tp_size,
160
+ )
161
+
162
+ if self.context_dim is not None:
163
+ self.to_k_c = Linear(context_dim, self.inner_dim)
164
+ self.to_v_c = Linear(context_dim, self.inner_dim)
165
+ if self.context_pre_only is not None:
166
+ self.to_q_c = Linear(context_dim, self.inner_dim)
167
+
168
+ self.to_out = RowLinear(
169
+ self.tp_size * self.num_attention_heads * self.attention_head_size,
170
+ dim,
171
+ bias=True,
172
+ dtype=self.dtype,
173
+ tp_group=None,
174
+ tp_size=self.tp_size,
175
+ )
176
+
177
+ if self.context_pre_only is not None and not self.context_pre_only:
178
+ self.to_out_c = Linear(self.inner_dim, dim)
179
+
180
+ def forward(
181
+ self,
182
+ x, # noised input x
183
+ rope_cos,
184
+ rope_sin,
185
+ input_lengths,
186
+ c=None, # context c
187
+ scale=1.0,
188
+ rope=None,
189
+ c_rope=None, # rotary position embedding for c
190
+ ) -> torch.Tensor:
191
+ if c is not None:
192
+ return self.processor(self, x, c=c, input_lengths=input_lengths, scale=scale, rope=rope, c_rope=c_rope)
193
+ else:
194
+ return self.processor(
195
+ self, x, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale
196
+ )
197
+
198
+
199
+ def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
200
+ shape_tensor = concat(
201
+ [shape(tensor, i) / 2 if i == (tensor.ndim() - 1) else shape(tensor, i) for i in range(tensor.ndim())]
202
+ )
203
+ if default_net().plugin_config.remove_input_padding:
204
+ assert tensor.ndim() == 2
205
+ x1 = slice(tensor, [0, 0], shape_tensor, [1, 2])
206
+ x2 = slice(tensor, [0, 1], shape_tensor, [1, 2])
207
+ x1 = expand_dims(x1, 2)
208
+ x2 = expand_dims(x2, 2)
209
+ zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))))
210
+ x2 = zero - x2
211
+ x = concat([x2, x1], 2)
212
+ out = view(x, concat([shape(x, 0), shape(x, 1) * 2]))
213
+ else:
214
+ assert tensor.ndim() == 3
215
+
216
+ x1 = slice(tensor, [0, 0, 0], shape_tensor, [1, 1, 2])
217
+ x2 = slice(tensor, [0, 0, 1], shape_tensor, [1, 1, 2])
218
+ x1 = expand_dims(x1, 3)
219
+ x2 = expand_dims(x2, 3)
220
+ zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))))
221
+ x2 = zero - x2
222
+ x = concat([x2, x1], 3)
223
+ out = view(x, concat([shape(x, 0), shape(x, 1), shape(x, 2) * 2]))
224
+
225
+ return out
226
+
227
+
228
+ def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin):
229
+ if default_net().plugin_config.remove_input_padding:
230
+ rot_dim = shape(rope_cos, -1) # 64
231
+ new_t_shape = concat([shape(x, 0), rot_dim]) # (-1, 64)
232
+ x_ = slice(x, [0, 0], new_t_shape, [1, 1])
233
+ end_dim = shape(x, -1) - shape(rope_cos, -1)
234
+ new_t_unrotated_shape = concat([shape(x, 0), end_dim]) # (2, -1, 960)
235
+ x_unrotated = slice(x, concat([0, rot_dim]), new_t_unrotated_shape, [1, 1])
236
+ out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
237
+ else:
238
+ rot_dim = shape(rope_cos, 2) # 64
239
+ new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim]) # (2, -1, 64)
240
+ x_ = slice(x, [0, 0, 0], new_t_shape, [1, 1, 1])
241
+ end_dim = shape(x, 2) - shape(rope_cos, 2)
242
+ new_t_unrotated_shape = concat([shape(x, 0), shape(x, 1), end_dim]) # (2, -1, 960)
243
+ x_unrotated = slice(x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1])
244
+ out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
245
+ return out
246
+
247
+
248
+ class AttnProcessor:
249
+ def __init__(self):
250
+ pass
251
+
252
+ def __call__(
253
+ self,
254
+ attn,
255
+ x, # noised input x
256
+ rope_cos,
257
+ rope_sin,
258
+ input_lengths,
259
+ scale=1.0,
260
+ rope=None,
261
+ ) -> torch.FloatTensor:
262
+ query = attn.to_q(x)
263
+ key = attn.to_k(x)
264
+ value = attn.to_v(x)
265
+ # k,v,q all (2,1226,1024)
266
+ query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin)
267
+ key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin)
268
+
269
+ # attention
270
+ inner_dim = key.shape[-1]
271
+ norm_factor = math.sqrt(attn.attention_head_size)
272
+ q_scaling = 1.0 / norm_factor
273
+ mask = None
274
+ if not default_net().plugin_config.remove_input_padding:
275
+ N = shape(x, 1)
276
+ B = shape(x, 0)
277
+ seq_len_2d = concat([1, N])
278
+ max_position_embeddings = 4096
279
+ # create position ids
280
+ position_ids_buffer = constant(np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0))
281
+ tmp_position_ids = slice(position_ids_buffer, starts=[0, 0], sizes=seq_len_2d)
282
+ tmp_position_ids = expand(tmp_position_ids, concat([B, N])) # BxL
283
+ tmp_input_lengths = unsqueeze(input_lengths, 1) # Bx1
284
+ tmp_input_lengths = expand(tmp_input_lengths, concat([B, N])) # BxL
285
+ mask = tmp_position_ids < tmp_input_lengths # BxL
286
+ mask = mask.cast("int32")
287
+
288
+ if default_net().plugin_config.bert_attention_plugin:
289
+ qkv = concat([query, key, value], dim=-1)
290
+ # TRT plugin mode
291
+ assert input_lengths is not None
292
+ if default_net().plugin_config.remove_input_padding:
293
+ qkv = qkv.view(concat([-1, 3 * inner_dim]))
294
+ max_input_length = constant(
295
+ np.zeros(
296
+ [
297
+ 2048,
298
+ ],
299
+ dtype=np.int32,
300
+ )
301
+ )
302
+ else:
303
+ max_input_length = None
304
+ context = bert_attention(
305
+ qkv,
306
+ input_lengths,
307
+ attn.num_attention_heads,
308
+ attn.attention_head_size,
309
+ q_scaling=q_scaling,
310
+ max_input_length=max_input_length,
311
+ )
312
+ else:
313
+ assert not default_net().plugin_config.remove_input_padding
314
+
315
+ def transpose_for_scores(x):
316
+ new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size])
317
+
318
+ y = x.view(new_x_shape)
319
+ y = y.transpose(1, 2)
320
+ return y
321
+
322
+ def transpose_for_scores_k(x):
323
+ new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size])
324
+
325
+ y = x.view(new_x_shape)
326
+ y = y.permute([0, 2, 3, 1])
327
+ return y
328
+
329
+ query = transpose_for_scores(query)
330
+ key = transpose_for_scores_k(key)
331
+ value = transpose_for_scores(value)
332
+
333
+ attention_scores = matmul(query, key, use_fp32_acc=False)
334
+
335
+ if mask is not None:
336
+ attention_mask = expand_mask(mask, shape(query, 2))
337
+ attention_mask = cast(attention_mask, attention_scores.dtype)
338
+ attention_scores = attention_scores + attention_mask
339
+
340
+ attention_probs = softmax(attention_scores, dim=-1)
341
+
342
+ context = matmul(attention_probs, value, use_fp32_acc=False).transpose(1, 2)
343
+ context = context.view(concat([shape(context, 0), shape(context, 1), attn.attention_hidden_size]))
344
+ context = attn.to_out(context)
345
+ if mask is not None:
346
+ mask = mask.view(concat([shape(mask, 0), shape(mask, 1), 1]))
347
+ mask = expand_dims_like(mask, context)
348
+ mask = cast(mask, context.dtype)
349
+ context = context * mask
350
+ return context
351
+
352
+
353
+ # DiT Block
354
+ class DiTBlock(Module):
355
+ def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1):
356
+ super().__init__()
357
+
358
+ self.attn_norm = AdaLayerNormZero(dim)
359
+ self.attn = Attention(
360
+ processor=AttnProcessor(),
361
+ dim=dim,
362
+ heads=heads,
363
+ dim_head=dim_head,
364
+ dropout=dropout,
365
+ )
366
+
367
+ self.ff_norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)
368
+ self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout)
369
+
370
+ def forward(
371
+ self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=None
372
+ ): # x: noised input, t: time embedding
373
+ # pre-norm & modulation for attention input
374
+ norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
375
+ # attention
376
+ # norm ----> (2,1226,1024)
377
+ attn_output = self.attn(x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
378
+
379
+ # process attention output for input x
380
+ if default_net().plugin_config.remove_input_padding:
381
+ x = x + gate_msa * attn_output
382
+ else:
383
+ x = x + unsqueeze(gate_msa, 1) * attn_output
384
+ ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
385
+ if default_net().plugin_config.remove_input_padding:
386
+ norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
387
+ else:
388
+ norm = self.ff_norm(x) * (ones + unsqueeze(scale_mlp, 1)) + unsqueeze(shift_mlp, 1)
389
+ # norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
390
+ ff_output = self.ff(norm)
391
+ if default_net().plugin_config.remove_input_padding:
392
+ x = x + gate_mlp * ff_output
393
+ else:
394
+ x = x + unsqueeze(gate_mlp, 1) * ff_output
395
+
396
+ return x
397
+
398
+
399
+ class TimestepEmbedding(Module):
400
+ def __init__(self, dim, freq_embed_dim=256, dtype=None):
401
+ super().__init__()
402
+ # self.time_embed = SinusPositionEmbedding(freq_embed_dim)
403
+ self.mlp1 = Linear(freq_embed_dim, dim, bias=True, dtype=dtype)
404
+ self.mlp2 = Linear(dim, dim, bias=True, dtype=dtype)
405
+
406
+ def forward(self, timestep):
407
+ t_freq = self.mlp1(timestep)
408
+ t_freq = silu(t_freq)
409
+ t_emb = self.mlp2(t_freq)
410
+ return t_emb
src/f5_tts/runtime/triton_trtllm/run.sh ADDED
@@ -0,0 +1,70 @@
1
+ stage=$1
2
+ stop_stage=$2
3
+ model=$3 # F5TTS_Base
4
+ if [ -z "$model" ]; then
5
+ echo "Error: model name is required, e.g. F5TTS_Base"
6
+ exit 1
7
+ fi
8
+ echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model"
9
+ export CUDA_VISIBLE_DEVICES=0
10
+
11
+ F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS
12
+ F5_TTS_TRT_LLM_CHECKPOINT_PATH=./trtllm_ckpt
13
+ F5_TTS_TRT_LLM_ENGINE_PATH=./f5_trt_llm_engine
14
+
15
+ vocoder_trt_engine_path=vocos_vocoder.plan
16
+ model_repo=./model_repo
17
+
18
+ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
19
+ echo "Downloading f5 tts from huggingface"
20
+ huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH
21
+
22
+ fi
23
+
24
+ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
25
+ echo "Converting checkpoint"
26
+ python3 ./scripts/convert_checkpoint.py \
27
+ --timm_ckpt "$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt" \
28
+ --output_dir "$F5_TTS_TRT_LLM_CHECKPOINT_PATH" --model_name $model
29
+ python_package_path=/usr/local/lib/python3.12/dist-packages
30
+ cp -r patch/* $python_package_path/tensorrt_llm/models
31
+ trtllm-build --checkpoint_dir $F5_TTS_TRT_LLM_CHECKPOINT_PATH \
32
+ --max_batch_size 8 \
33
+ --output_dir $F5_TTS_TRT_LLM_ENGINE_PATH --remove_input_padding disable
34
+ fi
35
+
36
+ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
37
+ echo "Exporting vocos vocoder"
38
+ onnx_vocoder_path=vocos_vocoder.onnx
39
+ python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $onnx_vocoder_path
40
+ bash scripts/export_vocos_trt.sh $onnx_vocoder_path $vocoder_trt_engine_path
41
+ fi
42
+
43
+ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
44
+ echo "Building triton server"
45
+ rm -r $model_repo
46
+ cp -r ./model_repo_f5_tts $model_repo
47
+ python3 scripts/fill_template.py -i $model_repo/f5_tts/config.pbtxt vocab:$F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt,model:$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt,trtllm:$F5_TTS_TRT_LLM_ENGINE_PATH,vocoder:vocos
48
+ cp $vocoder_trt_engine_path $model_repo/vocoder/1/vocoder.plan
49
+ fi
50
+
51
+ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
52
+ echo "Starting triton server"
53
+ tritonserver --model-repository=$model_repo
54
+ fi
55
+
56
+ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
57
+ echo "Testing triton server"
58
+ num_task=1
59
+ log_dir=./log_concurrent_tasks_${num_task}
60
+ rm -r $log_dir
61
+ python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir
62
+ fi
63
+
64
+ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
65
+ echo "Testing http client"
66
+ audio=../../infer/examples/basic/basic_ref_en.wav
67
+ reference_text="Some call me nature, others call me mother nature."
68
+ target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
69
+ python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text"
70
+ fi
src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py ADDED
@@ -0,0 +1,247 @@
1
+ # Modified from https://github.com/echocatzh/conv-stft/blob/master/conv_stft/conv_stft.py
2
+
3
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ # MIT License
18
+
19
+ # Copyright (c) 2020 Shimin Zhang
20
+
21
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
22
+ # of this software and associated documentation files (the "Software"), to deal
23
+ # in the Software without restriction, including without limitation the rights
24
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
+ # copies of the Software, and to permit persons to whom the Software is
26
+ # furnished to do so, subject to the following conditions:
27
+
28
+ # The above copyright notice and this permission notice shall be included in all
29
+ # copies or substantial portions of the Software.
30
+
31
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37
+ # SOFTWARE.
38
+
39
+ import torch as th
40
+ import torch.nn.functional as F
41
+ from scipy.signal import check_COLA, get_window
42
+
43
+ support_clp_op = None
44
+ if th.__version__ >= "1.7.0":
45
+ from torch.fft import rfft as fft
46
+
47
+ support_clp_op = True
48
+ else:
49
+ from torch import rfft as fft
50
+
51
+
52
+ class STFT(th.nn.Module):
53
+ def __init__(
54
+ self,
55
+ win_len=1024,
56
+ win_hop=512,
57
+ fft_len=1024,
58
+ enframe_mode="continue",
59
+ win_type="hann",
60
+ win_sqrt=False,
61
+ pad_center=True,
62
+ ):
63
+ """
64
+ Implementation of STFT using 1D convolution and 1D transposed convolution.
65
+ Framing of the signal is implemented in 2 ways, `break` and `continue`:
66
+ the `break` method is kaldi-like framing,
67
+ the `continue` method is librosa-like framing.
68
+
69
+ More information about `perfect reconstruction`:
70
+ 1. https://ww2.mathworks.cn/help/signal/ref/stft.html
71
+ 2. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.get_window.html
72
+
73
+ Args:
74
+ win_len (int): Number of points in one frame. Defaults to 1024.
75
+ win_hop (int): Number of framing stride. Defaults to 512.
76
+ fft_len (int): Number of DFT points. Defaults to 1024.
77
+ enframe_mode (str, optional): `break` and `continue`. Defaults to 'continue'.
78
+ win_type (str, optional): The type of window to create. Defaults to 'hann'.
79
+ win_sqrt (bool, optional): use the square-root window. Defaults to False.
80
+ pad_center (bool, optional): `perfect reconstruction` opts. Defaults to True.
81
+ """
82
+ super(STFT, self).__init__()
83
+ assert enframe_mode in ["break", "continue"]
84
+ assert fft_len >= win_len
85
+ self.win_len = win_len
86
+ self.win_hop = win_hop
87
+ self.fft_len = fft_len
88
+ self.mode = enframe_mode
89
+ self.win_type = win_type
90
+ self.win_sqrt = win_sqrt
91
+ self.pad_center = pad_center
92
+ self.pad_amount = self.fft_len // 2
93
+
94
+ en_k, fft_k, ifft_k, ola_k = self.__init_kernel__()
95
+ self.register_buffer("en_k", en_k)
96
+ self.register_buffer("fft_k", fft_k)
97
+ self.register_buffer("ifft_k", ifft_k)
98
+ self.register_buffer("ola_k", ola_k)
99
+
100
+ def __init_kernel__(self):
101
+ """
102
+ Generate enframe_kernel, fft_kernel, ifft_kernel and overlap-add kernel.
103
+ ** enframe_kernel: Using conv1d layer and identity matrix.
104
+ ** fft_kernel: Using linear layer for matrix multiplication. In fact,
105
+ enframe_kernel and fft_kernel can be combined, but for the sake of
106
+ readability, they are kept separate.
107
+ ** ifft_kernel, pinv of fft_kernel.
108
+ ** overlap-add kernel, just like enframe_kernel, but transposed.
109
+
110
+ Returns:
111
+ tuple: four kernels.
112
+ """
113
+ enframed_kernel = th.eye(self.fft_len)[:, None, :]
114
+ if support_clp_op:
115
+ tmp = fft(th.eye(self.fft_len))
116
+ fft_kernel = th.stack([tmp.real, tmp.imag], dim=2)
117
+ else:
118
+ fft_kernel = fft(th.eye(self.fft_len), 1)
119
+ if self.mode == "break":
120
+ enframed_kernel = th.eye(self.win_len)[:, None, :]
121
+ fft_kernel = fft_kernel[: self.win_len]
122
+ fft_kernel = th.cat((fft_kernel[:, :, 0], fft_kernel[:, :, 1]), dim=1)
123
+ ifft_kernel = th.pinverse(fft_kernel)[:, None, :]
124
+ window = get_window(self.win_type, self.win_len)
125
+
126
+ self.perfect_reconstruct = check_COLA(window, self.win_len, self.win_len - self.win_hop)
127
+ window = th.FloatTensor(window)
128
+ if self.mode == "continue":
129
+ left_pad = (self.fft_len - self.win_len) // 2
130
+ right_pad = left_pad + (self.fft_len - self.win_len) % 2
131
+ window = F.pad(window, (left_pad, right_pad))
132
+ if self.win_sqrt:
133
+ self.padded_window = window
134
+ window = th.sqrt(window)
135
+ else:
136
+ self.padded_window = window**2
137
+
138
+ fft_kernel = fft_kernel.T * window
139
+ ifft_kernel = ifft_kernel * window
140
+ ola_kernel = th.eye(self.fft_len)[: self.win_len, None, :]
141
+ if self.mode == "continue":
142
+ ola_kernel = th.eye(self.fft_len)[:, None, : self.fft_len]
143
+ return enframed_kernel, fft_kernel, ifft_kernel, ola_kernel
144
+
145
+ def is_perfect(self):
146
+ """
147
+ Whether the parameters win_len, win_hop and win_sqrt
148
+ obey constants overlap-add(COLA)
149
+
150
+ Returns:
151
+ bool: Return true if parameters obey COLA.
152
+ """
153
+ return self.perfect_reconstruct and self.pad_center
154
+
155
+ def transform(self, inputs, return_type="complex"):
156
+ """Take input data (audio) to STFT domain.
157
+
158
+ Args:
159
+ inputs (tensor): Tensor of floats, with shape (num_batch, num_samples)
160
+ return_type (str, optional): return (mag, phase) when `magphase`,
161
+ return (real, imag) when `realimag` and complex(real, imag) when `complex`.
162
+ Defaults to 'complex'.
163
+
164
+ Returns:
165
+ tuple: (mag, phase) when `magphase`, return (real, imag) when
166
+ `realimag`. Defaults to 'complex', each elements with shape
167
+ [num_batch, num_frequencies, num_frames]
168
+ """
169
+ assert return_type in ["magphase", "realimag", "complex"]
170
+ if inputs.dim() == 2:
171
+ inputs = th.unsqueeze(inputs, 1)
172
+ self.num_samples = inputs.size(-1)
173
+ if self.pad_center:
174
+ inputs = F.pad(inputs, (self.pad_amount, self.pad_amount), mode="reflect")
175
+ enframe_inputs = F.conv1d(inputs, self.en_k, stride=self.win_hop)
176
+ outputs = th.transpose(enframe_inputs, 1, 2)
177
+ outputs = F.linear(outputs, self.fft_k)
178
+ outputs = th.transpose(outputs, 1, 2)
179
+ dim = self.fft_len // 2 + 1
180
+ real = outputs[:, :dim, :]
181
+ imag = outputs[:, dim:, :]
182
+ if return_type == "realimag":
183
+ return real, imag
184
+ elif return_type == "complex":
185
+ assert support_clp_op
186
+ return th.complex(real, imag)
187
+ else:
188
+ mags = th.sqrt(real**2 + imag**2)
189
+ phase = th.atan2(imag, real)
190
+ return mags, phase
191
+
192
+ def inverse(self, input1, input2=None, input_type="magphase"):
193
+ """Call the inverse STFT (iSTFT), given tensors produced
194
+ by the `transform` function.
195
+
196
+ Args:
197
+ input1 (tensors): Magnitude/Real-part of STFT with shape
198
+ [num_batch, num_frequencies, num_frames]
199
+ input2 (tensors): Phase/Imag-part of STFT with shape
200
+ [num_batch, num_frequencies, num_frames]
201
+ input_type (str, optional): Mathematical meaning of input tensor's.
202
+ Defaults to 'magphase'.
203
+
204
+ Returns:
205
+ tensors: Reconstructed audio given magnitude and phase. Of
206
+ shape [num_batch, num_samples]
207
+ """
208
+ assert input_type in ["magphase", "realimag"]
209
+ if input_type == "realimag":
210
+ real, imag = None, None
211
+ if support_clp_op and th.is_complex(input1):
212
+ real, imag = input1.real, input1.imag
213
+ else:
214
+ real, imag = input1, input2
215
+ else:
216
+ real = input1 * th.cos(input2)
217
+ imag = input1 * th.sin(input2)
218
+ inputs = th.cat([real, imag], dim=1)
219
+ outputs = F.conv_transpose1d(inputs, self.ifft_k, stride=self.win_hop)
220
+ t = (self.padded_window[None, :, None]).repeat(1, 1, inputs.size(-1))
221
+ t = t.to(inputs.device)
222
+ coff = F.conv_transpose1d(t, self.ola_k, stride=self.win_hop)
223
+
224
+ num_frames = input1.size(-1)
225
+ num_samples = num_frames * self.win_hop
226
+
227
+ rm_start, rm_end = self.pad_amount, self.pad_amount + num_samples
228
+
229
+ outputs = outputs[..., rm_start:rm_end]
230
+ coff = coff[..., rm_start:rm_end]
231
+ coffidx = th.where(coff > 1e-8)
232
+ outputs[coffidx] = outputs[coffidx] / (coff[coffidx])
233
+ return outputs.squeeze(dim=1)
234
+
235
+ def forward(self, inputs):
236
+ """Take input data (audio) to STFT domain and then back to audio.
237
+
238
+ Args:
239
+ inputs (tensor): Tensor of floats, with shape [num_batch, num_samples]
240
+
241
+ Returns:
242
+ tensor: Reconstructed audio given magnitude and phase.
243
+ Of shape [num_batch, num_samples]
244
+ """
245
+ mag, phase = self.transform(inputs)
246
+ rec_wav = self.inverse(mag, phase)
247
+ return rec_wav
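
A short round-trip sketch of the module above (illustrative only; the vocoder export scripts are the actual consumers, and the window/hop values here are assumptions):

```python
import torch

# Analysis/synthesis round trip through the conv-based STFT.
stft = STFT(win_len=1024, win_hop=256, fft_len=1024, win_type="hann")
audio = torch.randn(2, 16000)                    # (num_batch, num_samples)
mag, phase = stft.transform(audio, return_type="magphase")
print(mag.shape)                                 # (2, fft_len // 2 + 1, num_frames)
recon = stft.inverse(mag, phase)                 # reconstructed audio, (num_batch, ~num_samples)
print(stft.is_perfect())                         # True when the window/hop pair satisfies COLA
```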
src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py ADDED
@@ -0,0 +1,359 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import re
5
+ import time
6
+ import traceback
7
+ from concurrent.futures import ThreadPoolExecutor, as_completed
8
+
9
+ import safetensors.torch
10
+ import torch
11
+
12
+ from tensorrt_llm import str_dtype_to_torch
13
+ from tensorrt_llm.mapping import Mapping
14
+ from tensorrt_llm.models.convert_utils import split, split_matrix_tp
15
+
16
+
17
+ def split_q_tp(v, n_head, n_hidden, tensor_parallel, rank):
18
+ split_v = split(v, tensor_parallel, rank, dim=1)
19
+ return split_v.contiguous()
20
+
21
+
22
+ def split_q_bias_tp(v, n_head, n_hidden, tensor_parallel, rank):
23
+ split_v = split(v, tensor_parallel, rank, dim=0)
24
+ return split_v.contiguous()
25
+
26
+
27
+ FACEBOOK_DIT_NAME_MAPPING = {
28
+ "^time_embed.time_mlp.0.weight$": "time_embed.mlp1.weight",
29
+ "^time_embed.time_mlp.0.bias$": "time_embed.mlp1.bias",
30
+ "^time_embed.time_mlp.2.weight$": "time_embed.mlp2.weight",
31
+ "^time_embed.time_mlp.2.bias$": "time_embed.mlp2.bias",
32
+ "^input_embed.conv_pos_embed.conv1d.0.weight$": "input_embed.conv_pos_embed.conv1d1.weight",
33
+ "^input_embed.conv_pos_embed.conv1d.0.bias$": "input_embed.conv_pos_embed.conv1d1.bias",
34
+ "^input_embed.conv_pos_embed.conv1d.2.weight$": "input_embed.conv_pos_embed.conv1d2.weight",
35
+ "^input_embed.conv_pos_embed.conv1d.2.bias$": "input_embed.conv_pos_embed.conv1d2.bias",
36
+ "^transformer_blocks.0.attn.to_out.0.weight$": "transformer_blocks.0.attn.to_out.weight",
37
+ "^transformer_blocks.0.attn.to_out.0.bias$": "transformer_blocks.0.attn.to_out.bias",
38
+ "^transformer_blocks.1.attn.to_out.0.weight$": "transformer_blocks.1.attn.to_out.weight",
39
+ "^transformer_blocks.1.attn.to_out.0.bias$": "transformer_blocks.1.attn.to_out.bias",
40
+ "^transformer_blocks.2.attn.to_out.0.weight$": "transformer_blocks.2.attn.to_out.weight",
41
+ "^transformer_blocks.2.attn.to_out.0.bias$": "transformer_blocks.2.attn.to_out.bias",
42
+ "^transformer_blocks.3.attn.to_out.0.weight$": "transformer_blocks.3.attn.to_out.weight",
43
+ "^transformer_blocks.3.attn.to_out.0.bias$": "transformer_blocks.3.attn.to_out.bias",
44
+ "^transformer_blocks.4.attn.to_out.0.weight$": "transformer_blocks.4.attn.to_out.weight",
45
+ "^transformer_blocks.4.attn.to_out.0.bias$": "transformer_blocks.4.attn.to_out.bias",
46
+ "^transformer_blocks.5.attn.to_out.0.weight$": "transformer_blocks.5.attn.to_out.weight",
47
+ "^transformer_blocks.5.attn.to_out.0.bias$": "transformer_blocks.5.attn.to_out.bias",
48
+ "^transformer_blocks.6.attn.to_out.0.weight$": "transformer_blocks.6.attn.to_out.weight",
49
+ "^transformer_blocks.6.attn.to_out.0.bias$": "transformer_blocks.6.attn.to_out.bias",
50
+ "^transformer_blocks.7.attn.to_out.0.weight$": "transformer_blocks.7.attn.to_out.weight",
51
+ "^transformer_blocks.7.attn.to_out.0.bias$": "transformer_blocks.7.attn.to_out.bias",
52
+ "^transformer_blocks.8.attn.to_out.0.weight$": "transformer_blocks.8.attn.to_out.weight",
53
+ "^transformer_blocks.8.attn.to_out.0.bias$": "transformer_blocks.8.attn.to_out.bias",
54
+ "^transformer_blocks.9.attn.to_out.0.weight$": "transformer_blocks.9.attn.to_out.weight",
55
+ "^transformer_blocks.9.attn.to_out.0.bias$": "transformer_blocks.9.attn.to_out.bias",
56
+ "^transformer_blocks.10.attn.to_out.0.weight$": "transformer_blocks.10.attn.to_out.weight",
57
+ "^transformer_blocks.10.attn.to_out.0.bias$": "transformer_blocks.10.attn.to_out.bias",
58
+ "^transformer_blocks.11.attn.to_out.0.weight$": "transformer_blocks.11.attn.to_out.weight",
59
+ "^transformer_blocks.11.attn.to_out.0.bias$": "transformer_blocks.11.attn.to_out.bias",
60
+ "^transformer_blocks.12.attn.to_out.0.weight$": "transformer_blocks.12.attn.to_out.weight",
61
+ "^transformer_blocks.12.attn.to_out.0.bias$": "transformer_blocks.12.attn.to_out.bias",
62
+ "^transformer_blocks.13.attn.to_out.0.weight$": "transformer_blocks.13.attn.to_out.weight",
63
+ "^transformer_blocks.13.attn.to_out.0.bias$": "transformer_blocks.13.attn.to_out.bias",
64
+ "^transformer_blocks.14.attn.to_out.0.weight$": "transformer_blocks.14.attn.to_out.weight",
65
+ "^transformer_blocks.14.attn.to_out.0.bias$": "transformer_blocks.14.attn.to_out.bias",
66
+ "^transformer_blocks.15.attn.to_out.0.weight$": "transformer_blocks.15.attn.to_out.weight",
67
+ "^transformer_blocks.15.attn.to_out.0.bias$": "transformer_blocks.15.attn.to_out.bias",
68
+ "^transformer_blocks.16.attn.to_out.0.weight$": "transformer_blocks.16.attn.to_out.weight",
69
+ "^transformer_blocks.16.attn.to_out.0.bias$": "transformer_blocks.16.attn.to_out.bias",
70
+ "^transformer_blocks.17.attn.to_out.0.weight$": "transformer_blocks.17.attn.to_out.weight",
71
+ "^transformer_blocks.17.attn.to_out.0.bias$": "transformer_blocks.17.attn.to_out.bias",
72
+ "^transformer_blocks.18.attn.to_out.0.weight$": "transformer_blocks.18.attn.to_out.weight",
73
+ "^transformer_blocks.18.attn.to_out.0.bias$": "transformer_blocks.18.attn.to_out.bias",
74
+ "^transformer_blocks.19.attn.to_out.0.weight$": "transformer_blocks.19.attn.to_out.weight",
75
+ "^transformer_blocks.19.attn.to_out.0.bias$": "transformer_blocks.19.attn.to_out.bias",
76
+ "^transformer_blocks.20.attn.to_out.0.weight$": "transformer_blocks.20.attn.to_out.weight",
77
+ "^transformer_blocks.20.attn.to_out.0.bias$": "transformer_blocks.20.attn.to_out.bias",
78
+ "^transformer_blocks.21.attn.to_out.0.weight$": "transformer_blocks.21.attn.to_out.weight",
79
+ "^transformer_blocks.21.attn.to_out.0.bias$": "transformer_blocks.21.attn.to_out.bias",
80
+ "^transformer_blocks.0.ff.ff.0.0.weight$": "transformer_blocks.0.ff.project_in.weight",
81
+ "^transformer_blocks.0.ff.ff.0.0.bias$": "transformer_blocks.0.ff.project_in.bias",
82
+ "^transformer_blocks.0.ff.ff.2.weight$": "transformer_blocks.0.ff.ff.weight",
83
+ "^transformer_blocks.0.ff.ff.2.bias$": "transformer_blocks.0.ff.ff.bias",
84
+ "^transformer_blocks.1.ff.ff.0.0.weight$": "transformer_blocks.1.ff.project_in.weight",
85
+ "^transformer_blocks.1.ff.ff.0.0.bias$": "transformer_blocks.1.ff.project_in.bias",
86
+ "^transformer_blocks.1.ff.ff.2.weight$": "transformer_blocks.1.ff.ff.weight",
87
+ "^transformer_blocks.1.ff.ff.2.bias$": "transformer_blocks.1.ff.ff.bias",
88
+ "^transformer_blocks.2.ff.ff.0.0.weight$": "transformer_blocks.2.ff.project_in.weight",
89
+ "^transformer_blocks.2.ff.ff.0.0.bias$": "transformer_blocks.2.ff.project_in.bias",
90
+ "^transformer_blocks.2.ff.ff.2.weight$": "transformer_blocks.2.ff.ff.weight",
91
+ "^transformer_blocks.2.ff.ff.2.bias$": "transformer_blocks.2.ff.ff.bias",
92
+ "^transformer_blocks.3.ff.ff.0.0.weight$": "transformer_blocks.3.ff.project_in.weight",
93
+ "^transformer_blocks.3.ff.ff.0.0.bias$": "transformer_blocks.3.ff.project_in.bias",
94
+ "^transformer_blocks.3.ff.ff.2.weight$": "transformer_blocks.3.ff.ff.weight",
95
+ "^transformer_blocks.3.ff.ff.2.bias$": "transformer_blocks.3.ff.ff.bias",
96
+ "^transformer_blocks.4.ff.ff.0.0.weight$": "transformer_blocks.4.ff.project_in.weight",
97
+ "^transformer_blocks.4.ff.ff.0.0.bias$": "transformer_blocks.4.ff.project_in.bias",
98
+ "^transformer_blocks.4.ff.ff.2.weight$": "transformer_blocks.4.ff.ff.weight",
99
+ "^transformer_blocks.4.ff.ff.2.bias$": "transformer_blocks.4.ff.ff.bias",
100
+ "^transformer_blocks.5.ff.ff.0.0.weight$": "transformer_blocks.5.ff.project_in.weight",
101
+ "^transformer_blocks.5.ff.ff.0.0.bias$": "transformer_blocks.5.ff.project_in.bias",
102
+ "^transformer_blocks.5.ff.ff.2.weight$": "transformer_blocks.5.ff.ff.weight",
103
+ "^transformer_blocks.5.ff.ff.2.bias$": "transformer_blocks.5.ff.ff.bias",
104
+ "^transformer_blocks.6.ff.ff.0.0.weight$": "transformer_blocks.6.ff.project_in.weight",
105
+ "^transformer_blocks.6.ff.ff.0.0.bias$": "transformer_blocks.6.ff.project_in.bias",
106
+ "^transformer_blocks.6.ff.ff.2.weight$": "transformer_blocks.6.ff.ff.weight",
107
+ "^transformer_blocks.6.ff.ff.2.bias$": "transformer_blocks.6.ff.ff.bias",
108
+ "^transformer_blocks.7.ff.ff.0.0.weight$": "transformer_blocks.7.ff.project_in.weight",
109
+ "^transformer_blocks.7.ff.ff.0.0.bias$": "transformer_blocks.7.ff.project_in.bias",
110
+ "^transformer_blocks.7.ff.ff.2.weight$": "transformer_blocks.7.ff.ff.weight",
111
+ "^transformer_blocks.7.ff.ff.2.bias$": "transformer_blocks.7.ff.ff.bias",
112
+ "^transformer_blocks.8.ff.ff.0.0.weight$": "transformer_blocks.8.ff.project_in.weight",
113
+ "^transformer_blocks.8.ff.ff.0.0.bias$": "transformer_blocks.8.ff.project_in.bias",
114
+ "^transformer_blocks.8.ff.ff.2.weight$": "transformer_blocks.8.ff.ff.weight",
115
+ "^transformer_blocks.8.ff.ff.2.bias$": "transformer_blocks.8.ff.ff.bias",
116
+ "^transformer_blocks.9.ff.ff.0.0.weight$": "transformer_blocks.9.ff.project_in.weight",
117
+ "^transformer_blocks.9.ff.ff.0.0.bias$": "transformer_blocks.9.ff.project_in.bias",
118
+ "^transformer_blocks.9.ff.ff.2.weight$": "transformer_blocks.9.ff.ff.weight",
119
+ "^transformer_blocks.9.ff.ff.2.bias$": "transformer_blocks.9.ff.ff.bias",
120
+ "^transformer_blocks.10.ff.ff.0.0.weight$": "transformer_blocks.10.ff.project_in.weight",
121
+ "^transformer_blocks.10.ff.ff.0.0.bias$": "transformer_blocks.10.ff.project_in.bias",
122
+ "^transformer_blocks.10.ff.ff.2.weight$": "transformer_blocks.10.ff.ff.weight",
123
+ "^transformer_blocks.10.ff.ff.2.bias$": "transformer_blocks.10.ff.ff.bias",
124
+ "^transformer_blocks.11.ff.ff.0.0.weight$": "transformer_blocks.11.ff.project_in.weight",
125
+ "^transformer_blocks.11.ff.ff.0.0.bias$": "transformer_blocks.11.ff.project_in.bias",
126
+ "^transformer_blocks.11.ff.ff.2.weight$": "transformer_blocks.11.ff.ff.weight",
127
+ "^transformer_blocks.11.ff.ff.2.bias$": "transformer_blocks.11.ff.ff.bias",
128
+ "^transformer_blocks.12.ff.ff.0.0.weight$": "transformer_blocks.12.ff.project_in.weight",
129
+ "^transformer_blocks.12.ff.ff.0.0.bias$": "transformer_blocks.12.ff.project_in.bias",
130
+ "^transformer_blocks.12.ff.ff.2.weight$": "transformer_blocks.12.ff.ff.weight",
131
+ "^transformer_blocks.12.ff.ff.2.bias$": "transformer_blocks.12.ff.ff.bias",
132
+ "^transformer_blocks.13.ff.ff.0.0.weight$": "transformer_blocks.13.ff.project_in.weight",
133
+ "^transformer_blocks.13.ff.ff.0.0.bias$": "transformer_blocks.13.ff.project_in.bias",
134
+ "^transformer_blocks.13.ff.ff.2.weight$": "transformer_blocks.13.ff.ff.weight",
135
+ "^transformer_blocks.13.ff.ff.2.bias$": "transformer_blocks.13.ff.ff.bias",
136
+ "^transformer_blocks.14.ff.ff.0.0.weight$": "transformer_blocks.14.ff.project_in.weight",
137
+ "^transformer_blocks.14.ff.ff.0.0.bias$": "transformer_blocks.14.ff.project_in.bias",
138
+ "^transformer_blocks.14.ff.ff.2.weight$": "transformer_blocks.14.ff.ff.weight",
139
+ "^transformer_blocks.14.ff.ff.2.bias$": "transformer_blocks.14.ff.ff.bias",
140
+ "^transformer_blocks.15.ff.ff.0.0.weight$": "transformer_blocks.15.ff.project_in.weight",
141
+ "^transformer_blocks.15.ff.ff.0.0.bias$": "transformer_blocks.15.ff.project_in.bias",
142
+ "^transformer_blocks.15.ff.ff.2.weight$": "transformer_blocks.15.ff.ff.weight",
143
+ "^transformer_blocks.15.ff.ff.2.bias$": "transformer_blocks.15.ff.ff.bias",
144
+ "^transformer_blocks.16.ff.ff.0.0.weight$": "transformer_blocks.16.ff.project_in.weight",
145
+ "^transformer_blocks.16.ff.ff.0.0.bias$": "transformer_blocks.16.ff.project_in.bias",
146
+ "^transformer_blocks.16.ff.ff.2.weight$": "transformer_blocks.16.ff.ff.weight",
147
+ "^transformer_blocks.16.ff.ff.2.bias$": "transformer_blocks.16.ff.ff.bias",
148
+ "^transformer_blocks.17.ff.ff.0.0.weight$": "transformer_blocks.17.ff.project_in.weight",
149
+ "^transformer_blocks.17.ff.ff.0.0.bias$": "transformer_blocks.17.ff.project_in.bias",
150
+ "^transformer_blocks.17.ff.ff.2.weight$": "transformer_blocks.17.ff.ff.weight",
151
+ "^transformer_blocks.17.ff.ff.2.bias$": "transformer_blocks.17.ff.ff.bias",
152
+ "^transformer_blocks.18.ff.ff.0.0.weight$": "transformer_blocks.18.ff.project_in.weight",
153
+ "^transformer_blocks.18.ff.ff.0.0.bias$": "transformer_blocks.18.ff.project_in.bias",
154
+ "^transformer_blocks.18.ff.ff.2.weight$": "transformer_blocks.18.ff.ff.weight",
155
+ "^transformer_blocks.18.ff.ff.2.bias$": "transformer_blocks.18.ff.ff.bias",
156
+ "^transformer_blocks.19.ff.ff.0.0.weight$": "transformer_blocks.19.ff.project_in.weight",
157
+ "^transformer_blocks.19.ff.ff.0.0.bias$": "transformer_blocks.19.ff.project_in.bias",
158
+ "^transformer_blocks.19.ff.ff.2.weight$": "transformer_blocks.19.ff.ff.weight",
159
+ "^transformer_blocks.19.ff.ff.2.bias$": "transformer_blocks.19.ff.ff.bias",
160
+ "^transformer_blocks.20.ff.ff.0.0.weight$": "transformer_blocks.20.ff.project_in.weight",
161
+ "^transformer_blocks.20.ff.ff.0.0.bias$": "transformer_blocks.20.ff.project_in.bias",
162
+ "^transformer_blocks.20.ff.ff.2.weight$": "transformer_blocks.20.ff.ff.weight",
163
+ "^transformer_blocks.20.ff.ff.2.bias$": "transformer_blocks.20.ff.ff.bias",
164
+ "^transformer_blocks.21.ff.ff.0.0.weight$": "transformer_blocks.21.ff.project_in.weight",
165
+ "^transformer_blocks.21.ff.ff.0.0.bias$": "transformer_blocks.21.ff.project_in.bias",
166
+ "^transformer_blocks.21.ff.ff.2.weight$": "transformer_blocks.21.ff.ff.weight",
167
+ "^transformer_blocks.21.ff.ff.2.bias$": "transformer_blocks.21.ff.ff.bias",
168
+ }
169
+
170
+
171
+ def parse_arguments():
172
+ parser = argparse.ArgumentParser()
173
+ parser.add_argument(
174
+ "--model_name",
175
+ type=str,
176
+ default="F5TTS_Base",
177
+ choices=[
178
+ "F5TTS_Base",
179
+ ],
180
+ ) # TODO: support F5TTS_v1_Base
181
+ parser.add_argument("--timm_ckpt", type=str, default="./ckpts/model_1200000.pt")
182
+ parser.add_argument(
183
+ "--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint"
184
+ )
185
+ parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
186
+ parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
187
+ parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
188
+ parser.add_argument("--cfg_scale", type=float, default=4.0)
189
+ parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size")
190
+ parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size")
191
+ parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size")
192
+ parser.add_argument("--dtype", type=str, default="float16", choices=["float32", "bfloat16", "float16"])
193
+ parser.add_argument("--fp8_linear", action="store_true", help="Whether use FP8 for linear layers")
194
+ parser.add_argument(
195
+ "--workers", type=int, default=1, help="The number of workers for converting checkpoint in parallel"
196
+ )
197
+ args = parser.parse_args()
198
+ return args
199
+
200
+
201
+ def convert_timm_dit(args, mapping, dtype="float32"):
202
+ weights = {}
203
+ tik = time.time()
204
+ torch_dtype = str_dtype_to_torch(dtype)
205
+ tensor_parallel = mapping.tp_size
206
+
207
+ model_params = dict(torch.load(args.timm_ckpt))
208
+ model_params = {
209
+ k: v for k, v in model_params["ema_model_state_dict"].items() if k.startswith("ema_model.transformer")
210
+ }
211
+ prefix = "ema_model.transformer."
212
+ model_params = {key[len(prefix) :] if key.startswith(prefix) else key: value for key, value in model_params.items()}
213
+
214
+ timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING
215
+
216
+ def get_trtllm_name(timm_name):
217
+ for k, v in timm_to_trtllm_name.items():
218
+ m = re.match(k, timm_name)
219
+ if m is not None:
220
+ if "*" in v:
221
+ v = v.replace("*", m.groups()[0])
222
+ return v
223
+ return timm_name
224
+
225
+ weights = dict()
226
+ for name, param in model_params.items():
227
+ if name == "input_embed.conv_pos_embed.conv1d.0.weight" or name == "input_embed.conv_pos_embed.conv1d.2.weight":
228
+ weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype).unsqueeze(-1)
229
+ else:
230
+ weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype)
231
+
232
+ assert len(weights) == len(model_params)
233
+
234
+ # new_prefix = 'f5_transformer.'
235
+ new_prefix = ""
236
+ weights = {new_prefix + key: value for key, value in weights.items()}
237
+ import math
238
+
239
+ scale_factor = math.pow(64, -0.25)
240
+ for k, v in weights.items():
241
+ if re.match("^transformer_blocks.*.attn.to_k.weight$", k):
242
+ weights[k] *= scale_factor
243
+ weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
244
+
245
+ elif re.match("^transformer_blocks.*.attn.to_k.bias$", k):
246
+ weights[k] *= scale_factor
247
+ weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
248
+
249
+ elif re.match("^transformer_blocks.*.attn.to_q.weight$", k):
250
+ weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
251
+ weights[k] *= scale_factor
252
+
253
+ elif re.match("^transformer_blocks.*.attn.to_q.bias$", k):
254
+ weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
255
+ weights[k] *= scale_factor
256
+
257
+ elif re.match("^transformer_blocks.*.attn.to_v.weight$", k):
258
+ weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
259
+
260
+ elif re.match("^transformer_blocks.*.attn.to_v.bias$", k):
261
+ weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
262
+
263
+ elif re.match("^transformer_blocks.*.attn.to_out.weight$", k):
264
+ weights[k] = split_matrix_tp(v, tensor_parallel, mapping.tp_rank, dim=1)
265
+
266
+ tok = time.time()
267
+ t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
268
+ print(f"Weights loaded. Total time: {t}")
269
+ return weights
270
+
271
+
272
+ def save_config(args):
273
+ if not os.path.exists(args.output_dir):
274
+ os.makedirs(args.output_dir)
275
+ config = {
276
+ "architecture": "F5TTS",
277
+ "dtype": args.dtype,
278
+ "hidden_size": 1024,
279
+ "num_hidden_layers": 22,
280
+ "num_attention_heads": 16,
281
+ "dim_head": 64,
282
+ "dropout": 0.1,
283
+ "ff_mult": 2,
284
+ "mel_dim": 100,
285
+ "text_num_embeds": 256,
286
+ "text_dim": 512,
287
+ "conv_layers": 4,
288
+ "long_skip_connection": False,
289
+ "mapping": {
290
+ "world_size": args.cp_size * args.tp_size * args.pp_size,
291
+ "cp_size": args.cp_size,
292
+ "tp_size": args.tp_size,
293
+ "pp_size": args.pp_size,
294
+ },
295
+ }
296
+ if args.fp8_linear:
297
+ config["quantization"] = {
298
+ "quant_algo": "FP8",
299
+ # TODO: add support for exclude modules.
300
+ # 'exclude_modules': "*final_layer*",
301
+ }
302
+
303
+ with open(os.path.join(args.output_dir, "config.json"), "w") as f:
304
+ json.dump(config, f, indent=4)
305
+
306
+
307
+ def convert_and_save(args, rank):
308
+ if rank == 0:
309
+ save_config(args)
310
+
311
+ mapping = Mapping(
312
+ world_size=args.cp_size * args.tp_size * args.pp_size,
313
+ rank=rank,
314
+ cp_size=args.cp_size,
315
+ tp_size=args.tp_size,
316
+ pp_size=args.pp_size,
317
+ )
318
+
319
+ weights = convert_timm_dit(args, mapping, dtype=args.dtype)
320
+
321
+ safetensors.torch.save_file(weights, os.path.join(args.output_dir, f"rank{rank}.safetensors"))
322
+
323
+
324
+ def execute(workers, func, args):
325
+ if workers == 1:
326
+ for rank, f in enumerate(func):
327
+ f(args, rank)
328
+ else:
329
+ with ThreadPoolExecutor(max_workers=workers) as p:
330
+ futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
331
+ exceptions = []
332
+ for future in as_completed(futures):
333
+ try:
334
+ future.result()
335
+ except Exception as e:
336
+ traceback.print_exc()
337
+ exceptions.append(e)
338
+ assert len(exceptions) == 0, "Checkpoint conversion failed, please check error log."
339
+
340
+
341
+ def main():
342
+ args = parse_arguments()
343
+ world_size = args.cp_size * args.tp_size * args.pp_size
344
+
345
+ assert args.pp_size == 1, "PP is not supported yet."
346
+
347
+ tik = time.time()
348
+ if args.timm_ckpt is None:
349
+ return
350
+ print("start execute")
351
+ execute(args.workers, [convert_and_save] * world_size, args)
352
+
353
+ tok = time.time()
354
+ t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
355
+ print(f"Total time of converting checkpoints: {t}")
356
+
357
+
358
+ if __name__ == "__main__":
359
+ main()
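
For reference, the converter above is driven entirely by the flags defined in `parse_arguments()`. A minimal invocation sketch follows; the script's file name is an assumption (it is not visible in this hunk), while the flag names and defaults come from the code above. With `world_size = cp_size * tp_size * pp_size`, the run writes one `rank{N}.safetensors` per rank plus a shared `config.json` into `--output_dir`.

```sh
# Sketch only: the converter's file name is assumed; the flags mirror parse_arguments() above.
python3 scripts/convert_checkpoint.py \
    --timm_ckpt ./ckpts/model_1200000.pt \
    --output_dir ./tllm_checkpoint \
    --model_name F5TTS_Base \
    --dtype float16 \
    --tp_size 1 --pp_size 1 --cp_size 1
```
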
src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py ADDED
@@ -0,0 +1,137 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from huggingface_hub import hf_hub_download
18
+
19
+ from conv_stft import STFT
20
+ from vocos import Vocos
21
+ import argparse
22
+
23
+ opset_version = 17
24
+
25
+
26
+ def get_args():
27
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
28
+ parser.add_argument(
29
+ "--vocoder",
30
+ type=str,
31
+ default="vocos",
32
+ choices=["vocos", "bigvgan"],
33
+ help="Vocoder to export",
34
+ )
35
+ parser.add_argument(
36
+ "--output-path",
37
+ type=str,
38
+ default="./vocos_vocoder.onnx",
39
+ help="Output path",
40
+ )
41
+ return parser.parse_args()
42
+
43
+
44
+ class ISTFTHead(nn.Module):
45
+ def __init__(self, n_fft: int, hop_length: int):
46
+ super().__init__()
47
+ self.out = None
48
+ self.stft = STFT(fft_len=n_fft, win_hop=hop_length, win_len=n_fft)
49
+
50
+ def forward(self, x: torch.Tensor):
51
+ x = self.out(x).transpose(1, 2)
52
+ mag, p = x.chunk(2, dim=1)
53
+ mag = torch.exp(mag)
54
+ mag = torch.clip(mag, max=1e2)
55
+ real = mag * torch.cos(p)
56
+ imag = mag * torch.sin(p)
57
+ audio = self.stft.inverse(input1=real, input2=imag, input_type="realimag")
58
+ return audio
59
+
60
+
61
+ class VocosVocoder(nn.Module):
62
+ def __init__(self, vocos_vocoder):
63
+ super(VocosVocoder, self).__init__()
64
+ self.vocos_vocoder = vocos_vocoder
65
+ istft_head_out = self.vocos_vocoder.head.out
66
+ n_fft = self.vocos_vocoder.head.istft.n_fft
67
+ hop_length = self.vocos_vocoder.head.istft.hop_length
68
+ istft_head_for_export = ISTFTHead(n_fft, hop_length)
69
+ istft_head_for_export.out = istft_head_out
70
+ self.vocos_vocoder.head = istft_head_for_export
71
+
72
+ def forward(self, mel):
73
+ waveform = self.vocos_vocoder.decode(mel)
74
+ return waveform
75
+
76
+
77
+ def export_VocosVocoder(vocos_vocoder, output_path, verbose):
78
+ vocos_vocoder = VocosVocoder(vocos_vocoder).cuda()
79
+ vocos_vocoder.eval()
80
+
81
+ dummy_batch_size = 8
82
+ dummy_input_length = 500
83
+
84
+ dummy_mel = torch.randn(dummy_batch_size, 100, dummy_input_length).cuda()
85
+
86
+ with torch.no_grad():
87
+ dummy_waveform = vocos_vocoder(mel=dummy_mel)
88
+ print(dummy_waveform.shape)
89
+
90
+ dummy_input = dummy_mel
91
+
92
+ torch.onnx.export(
93
+ vocos_vocoder,
94
+ dummy_input,
95
+ output_path,
96
+ opset_version=opset_version,
97
+ do_constant_folding=True,
98
+ input_names=["mel"],
99
+ output_names=["waveform"],
100
+ dynamic_axes={
101
+ "mel": {0: "batch_size", 2: "input_length"},
102
+ "waveform": {0: "batch_size", 1: "output_length"},
103
+ },
104
+ verbose=verbose,
105
+ )
106
+
107
+ print("Exported to {}".format(output_path))
108
+
109
+
110
+ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device="cpu", hf_cache_dir=None):
111
+ if vocoder_name == "vocos":
112
+ # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
113
+ if is_local:
114
+ print(f"Load vocos from local path {local_path}")
115
+ config_path = f"{local_path}/config.yaml"
116
+ model_path = f"{local_path}/pytorch_model.bin"
117
+ else:
118
+ print("Download Vocos from huggingface charactr/vocos-mel-24khz")
119
+ repo_id = "charactr/vocos-mel-24khz"
120
+ config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
121
+ model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
122
+ vocoder = Vocos.from_hparams(config_path)
123
+ state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
124
+ vocoder.load_state_dict(state_dict)
125
+ vocoder = vocoder.eval().to(device)
126
+ elif vocoder_name == "bigvgan":
127
+ raise NotImplementedError("BigVGAN is not supported yet")
128
+ vocoder.remove_weight_norm()
129
+ vocoder = vocoder.eval().to(device)
130
+ return vocoder
131
+
132
+
133
+ if __name__ == "__main__":
134
+ args = get_args()
135
+ vocoder = load_vocoder(vocoder_name=args.vocoder, device="cpu", hf_cache_dir=None)
136
+ if args.vocoder == "vocos":
137
+ export_VocosVocoder(vocoder, args.output_path, verbose=False)
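
The exported ONNX graph has a single `mel` input of shape (batch, 100, frames) and a `waveform` output, both with dynamic batch and length axes per the `dynamic_axes` above. A usage sketch, with the output path as a placeholder:

```sh
# Export the Vocos vocoder to ONNX; the output path is a placeholder.
python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path ./vocos_vocoder.onnx
```
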
src/f5_tts/runtime/triton_trtllm/scripts/export_vocos_trt.sh ADDED
@@ -0,0 +1,43 @@
1
+ #!/bin/bash
2
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ TRTEXEC="/usr/src/tensorrt/bin/trtexec"
17
+
18
+ ONNX_PATH=$1
19
+ ENGINE_PATH=$2
20
+ echo "ONNX_PATH: $ONNX_PATH"
21
+ echo "ENGINE_PATH: $ENGINE_PATH"
22
+ PRECISION="fp32"
23
+
24
+
25
+ MIN_BATCH_SIZE=1
26
+ OPT_BATCH_SIZE=1
27
+ MAX_BATCH_SIZE=8
28
+
29
+ MIN_INPUT_LENGTH=1
30
+ OPT_INPUT_LENGTH=1000
31
+ MAX_INPUT_LENGTH=3000
32
+
33
+ MEL_MIN_SHAPE="${MIN_BATCH_SIZE}x100x${MIN_INPUT_LENGTH}"
34
+ MEL_OPT_SHAPE="${OPT_BATCH_SIZE}x100x${OPT_INPUT_LENGTH}"
35
+ MEL_MAX_SHAPE="${MAX_BATCH_SIZE}x100x${MAX_INPUT_LENGTH}"
36
+
37
+ ${TRTEXEC} \
38
+ --minShapes="mel:${MEL_MIN_SHAPE}" \
39
+ --optShapes="mel:${MEL_OPT_SHAPE}" \
40
+ --maxShapes="mel:${MEL_MAX_SHAPE}" \
41
+ --onnx=${ONNX_PATH} \
42
+ --saveEngine=${ENGINE_PATH}
43
+
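
The wrapper above takes the ONNX model as `$1` and the destination engine as `$2`, and builds with dynamic `mel` shapes up to batch 8 × 100 mel bins × 3000 frames. A usage sketch (both paths are placeholders):

```sh
# Build a TensorRT engine from the exported vocoder ONNX; both paths are placeholders.
bash scripts/export_vocos_trt.sh ./vocos_vocoder.onnx ./vocos_vocoder.plan
```
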
src/f5_tts/runtime/triton_trtllm/scripts/fill_template.py ADDED
@@ -0,0 +1,36 @@
1
+ #! /usr/bin/env python3
2
+ from argparse import ArgumentParser
3
+ from string import Template
4
+
5
+
6
+ def main(file_path, substitutions, in_place, participant_ids):
7
+ with open(file_path) as f:
8
+ pbtxt = Template(f.read())
9
+
10
+ sub_dict = {"max_queue_size": 0}
11
+ sub_dict["participant_ids"] = participant_ids
12
+ for sub in substitutions.split(","):
13
+ key, value = sub.split(":")
14
+ sub_dict[key] = value
15
+
16
+ pbtxt = pbtxt.safe_substitute(sub_dict)
17
+
18
+ if in_place:
19
+ with open(file_path, "w") as f:
20
+ f.write(pbtxt)
21
+ else:
22
+ print(pbtxt)
23
+
24
+
25
+ if __name__ == "__main__":
26
+ parser = ArgumentParser()
27
+ parser.add_argument("file_path", help="path of the .pbtxt to modify")
28
+ parser.add_argument(
29
+ "substitutions",
30
+ help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2...",
31
+ )
32
+ parser.add_argument("--in_place", "-i", action="store_true", help="do the operation in-place")
33
+ parser.add_argument("--participant_ids", help="Participant IDs for the model", default="")
34
+ args = parser.parse_args()
35
+
36
+ main(**vars(args))
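
`fill_template.py` reads the `.pbtxt` as a `string.Template` and applies `safe_substitute`, so `${...}` placeholders named in the comma-separated argument are replaced, `max_queue_size` and `participant_ids` get defaults, and any unknown placeholders are left untouched. A usage sketch, where the config path and variable names are placeholders:

```sh
# Fill ${...} placeholders in a Triton config.pbtxt in place (-i); names and values are placeholders.
python3 scripts/fill_template.py -i model_repo/f5_tts/config.pbtxt variable_name_1:value_1,variable_name_2:value_2
```
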