mrfakename committed
Commit 0bebc31 · verified · 1 parent: 0abf49a

Sync from GitHub repo


This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the GitHub repo.

Files changed (41)
  1. .pre-commit-config.yaml +5 -2
  2. app.py +166 -97
  3. pyproject.toml +1 -1
  4. ruff.toml +1 -1
  5. src/f5_tts/api.py +2 -2
  6. src/f5_tts/eval/ecapa_tdnn.py +1 -0
  7. src/f5_tts/eval/eval_infer_batch.py +2 -0
  8. src/f5_tts/eval/eval_librispeech_test_clean.py +4 -5
  9. src/f5_tts/eval/eval_seedtts_testset.py +4 -5
  10. src/f5_tts/infer/infer_cli.py +7 -7
  11. src/f5_tts/infer/speech_edit.py +3 -1
  12. src/f5_tts/infer/utils_infer.py +4 -4
  13. src/f5_tts/model/__init__.py +2 -4
  14. src/f5_tts/model/backbones/dit.py +4 -5
  15. src/f5_tts/model/backbones/mmdit.py +3 -4
  16. src/f5_tts/model/backbones/unett.py +6 -6
  17. src/f5_tts/model/trainer.py +1 -0
  18. src/f5_tts/model/utils.py +2 -3
  19. src/f5_tts/runtime/triton_trtllm/benchmark.py +12 -11
  20. src/f5_tts/runtime/triton_trtllm/client_grpc.py +0 -1
  21. src/f5_tts/runtime/triton_trtllm/client_http.py +3 -2
  22. src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py +6 -7
  23. src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py +6 -5
  24. src/f5_tts/runtime/triton_trtllm/patch/__init__.py +3 -2
  25. src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py +9 -12
  26. src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py +14 -12
  27. src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py +1 -0
  28. src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py +0 -1
  29. src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py +4 -3
  30. src/f5_tts/scripts/count_params_gflops.py +5 -4
  31. src/f5_tts/socket_client.py +5 -3
  32. src/f5_tts/socket_server.py +5 -4
  33. src/f5_tts/train/datasets/prepare_csv_wavs.py +7 -8
  34. src/f5_tts/train/datasets/prepare_emilia.py +3 -5
  35. src/f5_tts/train/datasets/prepare_emilia_v2.py +6 -6
  36. src/f5_tts/train/datasets/prepare_libritts.py +3 -1
  37. src/f5_tts/train/datasets/prepare_ljspeech.py +3 -1
  38. src/f5_tts/train/datasets/prepare_wenetspeech4tts.py +2 -1
  39. src/f5_tts/train/finetune_cli.py +2 -2
  40. src/f5_tts/train/finetune_gradio.py +5 -5
  41. src/f5_tts/train/train.py +1 -0
.pre-commit-config.yaml CHANGED
@@ -3,11 +3,14 @@ repos:
     # Ruff version.
     rev: v0.11.2
     hooks:
-      # Run the linter.
       - id: ruff
+        name: ruff linter
        args: [--fix]
-      # Run the formatter.
       - id: ruff-format
+        name: ruff formatter
+      - id: ruff
+        name: ruff sorter
+        args: [--select, I, --fix]
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v5.0.0
     hooks:
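Note: the new "ruff sorter" hook is what drives most of the per-file churn in this commit — rule group I is Ruff's isort-compatible import sorting, applied via ruff check --select I --fix. A minimal before/after sketch, taken from the src/f5_tts/model/utils.py hunk further down:

    # Before sorting: stdlib/third-party imports interleaved in separate groups
    import torch
    from torch.nn.utils.rnn import pad_sequence

    import jieba
    from pypinyin import lazy_pinyin, Style

    # After the sorter: one alphabetized third-party block, and the names
    # inside each from-import are sorted as well
    import jieba
    import torch
    from pypinyin import Style, lazy_pinyin
    from torch.nn.utils.rnn import pad_sequence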
app.py CHANGED
@@ -6,6 +6,7 @@ import json
 import re
 import tempfile
 from collections import OrderedDict
+from functools import lru_cache
 from importlib.resources import files
 
 import click
@@ -17,6 +18,7 @@ import torchaudio
 from cached_path import cached_path
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+
 try:
     import spaces
 
@@ -32,15 +34,15 @@ def gpu_decorator(func):
         return func
 
 
-from f5_tts.model import DiT, UNetT
 from f5_tts.infer.utils_infer import (
-    load_vocoder,
+    infer_process,
     load_model,
+    load_vocoder,
     preprocess_ref_audio_text,
-    infer_process,
     remove_silence_for_generated_wav,
     save_spectrogram,
 )
+from f5_tts.model import DiT, UNetT
 
 
 DEFAULT_TTS_MODEL = "F5-TTS_v1"
@@ -122,6 +124,7 @@ def load_text_from_file(file):
     return gr.update(value=text)
 
 
+@lru_cache(maxsize=100)
 @gpu_decorator
 def infer(
     ref_audio_orig,
@@ -140,7 +143,11 @@ def infer(
         return gr.update(), gr.update(), ref_text
 
     # Set inference seed
+    if seed < 0 or seed > 2**31 - 1:
+        gr.Warning("Seed must in range 0 ~ 2147483647. Using random seed instead.")
+        seed = np.random.randint(0, 2**31 - 1)
     torch.manual_seed(seed)
+    used_seed = seed
 
     if not gen_text.strip():
         gr.Warning("Please enter text to generate or upload a text file.")
@@ -191,7 +198,7 @@ def infer(
         spectrogram_path = tmp_spectrogram.name
         save_spectrogram(combined_spectrogram, spectrogram_path)
 
-    return (final_sample_rate, final_wave), spectrogram_path, ref_text
+    return (final_sample_rate, final_wave), spectrogram_path, ref_text, used_seed
 
 
 with gr.Blocks() as app_credits:
@@ -277,27 +284,21 @@ with gr.Blocks() as app_tts:
         nfe_slider,
         speed_slider,
     ):
-        # Determine the seed to use
         if randomize_seed:
-            seed = np.random.randint(0, 2**31 - 1)
-        else:
-            seed = seed_input
-            if seed < 0 or seed > 2**31 - 1:
-                gr.Warning("Seed must in range 0 ~ 2147483647. Using random seed instead.")
-                seed = np.random.randint(0, 2**31 - 1)
+            seed_input = np.random.randint(0, 2**31 - 1)
 
-        audio_out, spectrogram_path, ref_text_out = infer(
+        audio_out, spectrogram_path, ref_text_out, used_seed = infer(
             ref_audio_input,
             ref_text_input,
            gen_text_input,
            tts_model_choice,
            remove_silence,
-            seed=seed,
+            seed=seed_input,
            cross_fade_duration=cross_fade_duration_slider,
            nfe_step=nfe_slider,
            speed=speed_slider,
        )
-        return audio_out, spectrogram_path, ref_text_out, seed
+        return audio_out, spectrogram_path, ref_text_out, used_seed
 
     gen_text_file.upload(
         load_text_from_file,
@@ -329,26 +330,34 @@ with gr.Blocks() as app_tts:
 
 
 def parse_speechtypes_text(gen_text):
-    # Pattern to find {speechtype}
-    pattern = r"\{(.*?)\}"
+    # Pattern to find {str} or {"name": str, "seed": int, "speed": float}
+    pattern = r"(\{.*?\})"
 
     # Split the text by the pattern
     tokens = re.split(pattern, gen_text)
 
     segments = []
 
-    current_style = "Regular"
+    current_type_dict = {
+        "name": "Regular",
+        "seed": -1,
+        "speed": 1.0,
+    }
 
     for i in range(len(tokens)):
         if i % 2 == 0:
             # This is text
             text = tokens[i].strip()
             if text:
-                segments.append({"style": current_style, "text": text})
+                current_type_dict["text"] = text
+                segments.append(current_type_dict)
         else:
-            # This is style
-            style = tokens[i].strip()
-            current_style = style
+            # This is type
+            type_str = tokens[i].strip()
+            try:  # if type dict
+                current_type_dict = json.loads(type_str)
+            except json.decoder.JSONDecodeError:
+                current_type_dict = {"name": type_str, "seed": -1, "speed": 1.0}
 
     return segments
 
@@ -366,41 +375,48 @@ with gr.Blocks() as app_multistyle:
     with gr.Row():
         gr.Markdown(
             """
-            **Example Input:**
-            {Regular} Hello, I'd like to order a sandwich please.
-            {Surprised} What do you mean you're out of bread?
-            {Sad} I really wanted a sandwich though...
-            {Angry} You know what, darn you and your little shop!
-            {Whisper} I'll just go back home and cry now.
+            **Example Input:** <br>
+            {Regular} Hello, I'd like to order a sandwich please. <br>
+            {Surprised} What do you mean you're out of bread? <br>
+            {Sad} I really wanted a sandwich though... <br>
+            {Angry} You know what, darn you and your little shop! <br>
+            {Whisper} I'll just go back home and cry now. <br>
             {Shouting} Why me?!
             """
         )
 
         gr.Markdown(
             """
-            **Example Input 2:**
-            {Speaker1_Happy} Hello, I'd like to order a sandwich please.
-            {Speaker2_Regular} Sorry, we're out of bread.
-            {Speaker1_Sad} I really wanted a sandwich though...
-            {Speaker2_Whisper} I'll give you the last one I was hiding.
+            **Example Input 2:** <br>
+            {"name": "Speaker1_Happy", "seed": -1, "speed": 1} Hello, I'd like to order a sandwich please. <br>
+            {"name": "Speaker2_Regular", "seed": -1, "speed": 1} Sorry, we're out of bread. <br>
+            {"name": "Speaker1_Sad", "seed": -1, "speed": 1} I really wanted a sandwich though... <br>
+            {"name": "Speaker2_Whisper", "seed": -1, "speed": 1} I'll give you the last one I was hiding.
             """
        )
 
     gr.Markdown(
-        "Upload different audio clips for each speech type. The first speech type is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button."
+        'Upload different audio clips for each speech type. The first speech type is mandatory. You can add additional speech types by clicking the "Add Speech Type" button.'
     )
 
     # Regular speech type (mandatory)
-    with gr.Row() as regular_row:
+    with gr.Row(variant="compact") as regular_row:
         with gr.Column(scale=1, min_width=160):
             regular_name = gr.Textbox(value="Regular", label="Speech Type Name")
             regular_insert = gr.Button("Insert Label", variant="secondary")
         with gr.Column(scale=3):
             regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
         with gr.Column(scale=3):
-            regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=8, scale=3)
-        with gr.Column(scale=1):
-            regular_ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1)
+            regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=4)
+            with gr.Row():
+                regular_seed_slider = gr.Slider(
+                    show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed, -1 for random"
+                )
+                regular_speed_slider = gr.Slider(
+                    show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed"
+                )
+        with gr.Column(scale=1, min_width=160):
+            regular_ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"])
 
     # Regular speech type (max 100)
     max_speech_types = 100
@@ -409,32 +425,54 @@ with gr.Blocks() as app_multistyle:
     speech_type_audios = [regular_audio]
     speech_type_ref_texts = [regular_ref_text]
     speech_type_ref_text_files = [regular_ref_text_file]
+    speech_type_seeds = [regular_seed_slider]
+    speech_type_speeds = [regular_speed_slider]
     speech_type_delete_btns = [None]
     speech_type_insert_btns = [regular_insert]
 
     # Additional speech types (99 more)
     for i in range(max_speech_types - 1):
-        with gr.Row(visible=False) as row:
+        with gr.Row(variant="compact", visible=False) as row:
             with gr.Column(scale=1, min_width=160):
                 name_input = gr.Textbox(label="Speech Type Name")
-                delete_btn = gr.Button("Delete Type", variant="secondary")
                 insert_btn = gr.Button("Insert Label", variant="secondary")
+                delete_btn = gr.Button("Delete Type", variant="stop")
             with gr.Column(scale=3):
                 audio_input = gr.Audio(label="Reference Audio", type="filepath")
             with gr.Column(scale=3):
-                ref_text_input = gr.Textbox(label="Reference Text", lines=8, scale=3)
-            with gr.Column(scale=1):
-                ref_text_file_input = gr.File(
-                    label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1
-                )
+                ref_text_input = gr.Textbox(label="Reference Text", lines=4)
+                with gr.Row():
+                    seed_input = gr.Slider(
+                        show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed. -1 for random"
+                    )
+                    speed_input = gr.Slider(
+                        show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed"
+                    )
+            with gr.Column(scale=1, min_width=160):
+                ref_text_file_input = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"])
         speech_type_rows.append(row)
         speech_type_names.append(name_input)
        speech_type_audios.append(audio_input)
        speech_type_ref_texts.append(ref_text_input)
        speech_type_ref_text_files.append(ref_text_file_input)
+        speech_type_seeds.append(seed_input)
+        speech_type_speeds.append(speed_input)
        speech_type_delete_btns.append(delete_btn)
        speech_type_insert_btns.append(insert_btn)
 
+    # Global logic for all speech types
+    for i in range(max_speech_types):
+        speech_type_audios[i].clear(
+            lambda: [None, None],
+            None,
+            [speech_type_ref_texts[i], speech_type_ref_text_files[i]],
+        )
+        speech_type_ref_text_files[i].upload(
+            load_text_from_file,
+            inputs=[speech_type_ref_text_files[i]],
+            outputs=[speech_type_ref_texts[i]],
+        )
+
     # Button to add speech type
     add_speech_type_btn = gr.Button("Add Speech Type")
 
@@ -470,18 +508,6 @@ with gr.Blocks() as app_multistyle:
                 speech_type_ref_text_files[i],
             ],
         )
-        speech_type_ref_text_files[i].upload(
-            load_text_from_file,
-            inputs=[speech_type_ref_text_files[i]],
-            outputs=[speech_type_ref_texts[i]],
-        )
-
-    # Update regular speech type ref text file
-    regular_ref_text_file.upload(
-        load_text_from_file,
-        inputs=[regular_ref_text_file],
-        outputs=[regular_ref_text],
-    )
 
     # Text input for the prompt
     with gr.Row():
@@ -495,10 +521,17 @@ with gr.Blocks() as app_multistyle:
         gen_text_file_multistyle = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1)
 
     def make_insert_speech_type_fn(index):
-        def insert_speech_type_fn(current_text, speech_type_name):
+        def insert_speech_type_fn(current_text, speech_type_name, speech_type_seed, speech_type_speed):
             current_text = current_text or ""
-            speech_type_name = speech_type_name or "None"
-            updated_text = current_text + f"{{{speech_type_name}}} "
+            if not speech_type_name:
+                gr.Warning("Please enter speech type name before insert.")
+                return current_text
+            speech_type_dict = {
+                "name": speech_type_name,
+                "seed": speech_type_seed,
+                "speed": speech_type_speed,
+            }
+            updated_text = current_text + json.dumps(speech_type_dict) + " "
             return updated_text
 
         return insert_speech_type_fn
@@ -507,16 +540,24 @@ with gr.Blocks() as app_multistyle:
         insert_fn = make_insert_speech_type_fn(i)
         insert_btn.click(
             insert_fn,
-            inputs=[gen_text_input_multistyle, speech_type_names[i]],
+            inputs=[gen_text_input_multistyle, speech_type_names[i], speech_type_seeds[i], speech_type_speeds[i]],
             outputs=gen_text_input_multistyle,
        )
 
-    with gr.Accordion("Advanced Settings", open=False):
-        remove_silence_multistyle = gr.Checkbox(
-            label="Remove Silences",
-            info="Turn on to automatically detect and crop long silences.",
-            value=True,
-        )
+    with gr.Accordion("Advanced Settings", open=True):
+        with gr.Row():
+            with gr.Column():
+                show_cherrypick_multistyle = gr.Checkbox(
+                    label="Show Cherry-pick Interface",
+                    info="Turn on to show interface, picking seeds from previous generations.",
+                    value=False,
+                )
+            with gr.Column():
+                remove_silence_multistyle = gr.Checkbox(
+                    label="Remove Silences",
+                    info="Turn on to automatically detect and crop long silences.",
+                    value=True,
+                )
 
     # Generate button
     generate_multistyle_btn = gr.Button("Generate Multi-Style Speech", variant="primary")
@@ -524,6 +565,24 @@ with gr.Blocks() as app_multistyle:
     # Output audio
     audio_output_multistyle = gr.Audio(label="Synthesized Audio")
 
+    # Used seed gallery
+    cherrypick_interface_multistyle = gr.Textbox(
+        label="Cherry-pick Interface",
+        lines=10,
+        max_lines=40,
+        show_copy_button=True,
+        interactive=False,
+        visible=False,
+    )
+
+    # Logic control to show/hide the cherrypick interface
+    show_cherrypick_multistyle.change(
+        lambda is_visible: gr.update(visible=is_visible),
+        show_cherrypick_multistyle,
+        cherrypick_interface_multistyle,
+    )
+
+    # Function to load text to generate from file
     gen_text_file_multistyle.upload(
         load_text_from_file,
         inputs=[gen_text_file_multistyle],
@@ -557,44 +616,60 @@ with gr.Blocks() as app_multistyle:
 
         # For each segment, generate speech
         generated_audio_segments = []
-        current_style = "Regular"
+        current_type_name = "Regular"
+        inference_meta_data = ""
 
         for segment in segments:
-            style = segment["style"]
+            name = segment["name"]
+            seed_input = segment["seed"]
+            speed = segment["speed"]
             text = segment["text"]
 
-            if style in speech_types:
-                current_style = style
+            if name in speech_types:
+                current_type_name = name
             else:
-                gr.Warning(f"Type {style} is not available, will use Regular as default.")
-                current_style = "Regular"
+                gr.Warning(f"Type {name} is not available, will use Regular as default.")
+                current_type_name = "Regular"
 
             try:
-                ref_audio = speech_types[current_style]["audio"]
+                ref_audio = speech_types[current_type_name]["audio"]
             except KeyError:
-                gr.Warning(f"Please provide reference audio for type {current_style}.")
-                return [None] + [speech_types[style]["ref_text"] for style in speech_types]
-            ref_text = speech_types[current_style].get("ref_text", "")
+                gr.Warning(f"Please provide reference audio for type {current_type_name}.")
+                return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
+            ref_text = speech_types[current_type_name].get("ref_text", "")
 
-            # TODO. Attribute each type a unique seed (maybe also speed, pseudo-feature for #730 #813)
-            seed = np.random.randint(0, 2**31 - 1)
+            if seed_input == -1:
+                seed_input = np.random.randint(0, 2**31 - 1)
 
-            # Generate speech for this segment
-            audio_out, _, ref_text_out = infer(
-                ref_audio, ref_text, text, tts_model_choice, remove_silence, seed, 0, show_info=print
-            )  # show_info=print no pull to top when generating
+            # Generate or retrieve speech for this segment
+            audio_out, _, ref_text_out, used_seed = infer(
+                ref_audio,
+                ref_text,
+                text,
+                tts_model_choice,
+                remove_silence,
+                seed=seed_input,
+                cross_fade_duration=0,
+                speed=speed,
+                show_info=print,  # no pull to top when generating
+            )
            sr, audio_data = audio_out
 
            generated_audio_segments.append(audio_data)
-            speech_types[current_style]["ref_text"] = ref_text_out
+            speech_types[current_type_name]["ref_text"] = ref_text_out
+            inference_meta_data += json.dumps(dict(name=name, seed=used_seed, speed=speed)) + f" {text}\n"
 
        # Concatenate all audio segments
        if generated_audio_segments:
            final_audio_data = np.concatenate(generated_audio_segments)
-            return [(sr, final_audio_data)] + [speech_types[style]["ref_text"] for style in speech_types]
+            return (
+                [(sr, final_audio_data)]
+                + [speech_types[name]["ref_text"] for name in speech_types]
+                + [inference_meta_data]
+            )
        else:
            gr.Warning("No audio generated.")
-            return [None] + [speech_types[style]["ref_text"] for style in speech_types]
+            return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
 
     generate_multistyle_btn.click(
         generate_multistyle_speech,
@@ -607,7 +682,7 @@ with gr.Blocks() as app_multistyle:
         + [
             remove_silence_multistyle,
         ],
-        outputs=[audio_output_multistyle] + speech_type_ref_texts,
+        outputs=[audio_output_multistyle] + speech_type_ref_texts + [cherrypick_interface_multistyle],
     )
 
     # Validation function to disable Generate button if speech types are missing
@@ -624,7 +699,7 @@ with gr.Blocks() as app_multistyle:
 
         # Parse the gen_text to get the speech types used
         segments = parse_speechtypes_text(gen_text)
-        speech_types_in_text = set(segment["style"] for segment in segments)
+        speech_types_in_text = set(segment["name"] for segment in segments)
 
         # Check if all speech types in text are available
         missing_speech_types = speech_types_in_text - speech_types_available
@@ -788,27 +863,21 @@ Have a conversation with an AI using your reference voice!
         if not last_ai_response or conv_state[-1]["role"] != "assistant":
             return None, ref_text, seed_input
 
-        # Determine the seed to use
         if randomize_seed:
-            seed = np.random.randint(0, 2**31 - 1)
-        else:
-            seed = seed_input
-            if seed < 0 or seed > 2**31 - 1:
-                gr.Warning("Seed must in range 0 ~ 2147483647. Using random seed instead.")
-                seed = np.random.randint(0, 2**31 - 1)
+            seed_input = np.random.randint(0, 2**31 - 1)
 
-        audio_result, _, ref_text_out = infer(
+        audio_result, _, ref_text_out, used_seed = infer(
             ref_audio,
             ref_text,
            last_ai_response,
            tts_model_choice,
            remove_silence,
-            seed=seed,
+            seed=seed_input,
            cross_fade_duration=0.15,
            speed=1.0,
            show_info=print,  # show_info=print no pull to top when generating
        )
-        return audio_result, ref_text_out, seed
+        return audio_result, ref_text_out, used_seed
 
     def clear_conversation():
         """Reset the conversation"""
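Two behavioral changes above are easy to miss. Seed validation moved inside infer(), which now also returns the seed it actually used; and infer() gained functools.lru_cache(maxsize=100), so calling it again with identical arguments (seed included) returns the cached result instead of re-synthesizing — this is what lets the multi-style tab "generate or retrieve" a segment for a given seed. A minimal sketch of that caching behavior (fake_infer is a hypothetical stand-in, not the real function):

    from functools import lru_cache

    calls = 0

    @lru_cache(maxsize=100)
    def fake_infer(ref_audio, ref_text, gen_text, seed):
        # Stand-in for the expensive synthesis call; lru_cache keys on the
        # argument values, so every argument must be hashable.
        global calls
        calls += 1
        return f"audio for {gen_text!r} with seed {seed}"

    fake_infer("ref.wav", "hello", "hi there", 42)
    fake_infer("ref.wav", "hello", "hi there", 42)  # cache hit, no second call
    assert calls == 1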
pyproject.toml CHANGED
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts"
-version = "1.1.2"
+version = "1.1.3"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}
ruff.toml CHANGED
@@ -6,5 +6,5 @@ target-version = "py310"
 dummy-variable-rgx = "^_.*$"
 
 [lint.isort]
-force-single-line = true
+force-single-line = false
 lines-after-imports = 2
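Flipping force-single-line to false is the companion to the new sorter hook: with the old value, Ruff's isort rules would rewrite every from-import as one statement per imported name, while the new value permits the combined one-line form this commit standardizes on. A sketch of the difference:

    # force-single-line = true: one statement per name
    from f5_tts.eval.utils_eval import get_librispeech_test
    from f5_tts.eval.utils_eval import run_asr_wer
    from f5_tts.eval.utils_eval import run_sim

    # force-single-line = false: the combined form used throughout this commit
    from f5_tts.eval.utils_eval import get_librispeech_test, run_asr_wer, run_sim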
src/f5_tts/api.py CHANGED
@@ -9,13 +9,13 @@ from hydra.utils import get_class
 from omegaconf import OmegaConf
 
 from f5_tts.infer.utils_infer import (
+    infer_process,
     load_model,
     load_vocoder,
-    transcribe,
     preprocess_ref_audio_text,
-    infer_process,
     remove_silence_for_generated_wav,
     save_spectrogram,
+    transcribe,
 )
 from f5_tts.model.utils import seed_everything
 
src/f5_tts/eval/ecapa_tdnn.py CHANGED
@@ -4,6 +4,7 @@
 # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
 
 import os
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
src/f5_tts/eval/eval_infer_batch.py CHANGED
@@ -1,6 +1,7 @@
 import os
 import sys
 
+
 sys.path.append(os.getcwd())
 
 import argparse
@@ -23,6 +24,7 @@ from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
 from f5_tts.model import CFM
 from f5_tts.model.utils import get_tokenizer
 
+
 accelerator = Accelerator()
 device = f"cuda:{accelerator.process_index}"
 
src/f5_tts/eval/eval_librispeech_test_clean.py CHANGED
@@ -5,17 +5,16 @@ import json
 import os
 import sys
 
+
 sys.path.append(os.getcwd())
 
 import multiprocessing as mp
 from importlib.resources import files
 
 import numpy as np
-from f5_tts.eval.utils_eval import (
-    get_librispeech_test,
-    run_asr_wer,
-    run_sim,
-)
+
+from f5_tts.eval.utils_eval import get_librispeech_test, run_asr_wer, run_sim
+
 
 rel_path = str(files("f5_tts").joinpath("../../"))
 
src/f5_tts/eval/eval_seedtts_testset.py CHANGED
@@ -5,17 +5,16 @@ import json
 import os
 import sys
 
+
 sys.path.append(os.getcwd())
 
 import multiprocessing as mp
 from importlib.resources import files
 
 import numpy as np
-from f5_tts.eval.utils_eval import (
-    get_seed_tts_test,
-    run_asr_wer,
-    run_sim,
-)
+
+from f5_tts.eval.utils_eval import get_seed_tts_test, run_asr_wer, run_sim
+
 
 rel_path = str(files("f5_tts").joinpath("../../"))
 
src/f5_tts/infer/infer_cli.py CHANGED
@@ -14,20 +14,20 @@ from hydra.utils import get_class
 from omegaconf import OmegaConf
 
 from f5_tts.infer.utils_infer import (
-    mel_spec_type,
-    target_rms,
-    cross_fade_duration,
-    nfe_step,
     cfg_strength,
-    sway_sampling_coef,
-    speed,
-    fix_duration,
+    cross_fade_duration,
     device,
+    fix_duration,
     infer_process,
     load_model,
     load_vocoder,
+    mel_spec_type,
+    nfe_step,
     preprocess_ref_audio_text,
     remove_silence_for_generated_wav,
+    speed,
+    sway_sampling_coef,
+    target_rms,
 )
 
 
src/f5_tts/infer/speech_edit.py CHANGED
@@ -1,5 +1,6 @@
 import os
 
+
 os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # for MPS device compatibility
 
 from importlib.resources import files
@@ -7,14 +8,15 @@ from importlib.resources import files
 import torch
 import torch.nn.functional as F
 import torchaudio
+from cached_path import cached_path
 from hydra.utils import get_class
 from omegaconf import OmegaConf
-from cached_path import cached_path
 
 from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
 from f5_tts.model import CFM
 from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
 
+
 device = (
     "cuda"
     if torch.cuda.is_available()
src/f5_tts/infer/utils_infer.py CHANGED
@@ -4,6 +4,7 @@ import os
 import sys
 from concurrent.futures import ThreadPoolExecutor
 
+
 os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # for MPS device compatibility
 sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
 
@@ -14,6 +15,7 @@ from importlib.resources import files
 
 import matplotlib
 
+
 matplotlib.use("Agg")
 
 import matplotlib.pylab as plt
@@ -27,10 +29,8 @@ from transformers import pipeline
 from vocos import Vocos
 
 from f5_tts.model import CFM
-from f5_tts.model.utils import (
-    get_tokenizer,
-    convert_char_to_pinyin,
-)
+from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
+
 
 _ref_audio_cache = {}
 
src/f5_tts/model/__init__.py CHANGED
@@ -1,9 +1,7 @@
-from f5_tts.model.cfm import CFM
-
-from f5_tts.model.backbones.unett import UNetT
 from f5_tts.model.backbones.dit import DiT
 from f5_tts.model.backbones.mmdit import MMDiT
-
+from f5_tts.model.backbones.unett import UNetT
+from f5_tts.model.cfm import CFM
 from f5_tts.model.trainer import Trainer
 
 
src/f5_tts/model/backbones/dit.py CHANGED
@@ -10,19 +10,18 @@ d - dimension
 from __future__ import annotations
 
 import torch
-from torch import nn
 import torch.nn.functional as F
-
+from torch import nn
 from x_transformers.x_transformers import RotaryEmbedding
 
 from f5_tts.model.modules import (
-    TimestepEmbedding,
+    AdaLayerNorm_Final,
     ConvNeXtV2Block,
     ConvPositionEmbedding,
     DiTBlock,
-    AdaLayerNorm_Final,
-    precompute_freqs_cis,
+    TimestepEmbedding,
     get_pos_embed_indices,
+    precompute_freqs_cis,
 )
 
 
src/f5_tts/model/backbones/mmdit.py CHANGED
@@ -11,16 +11,15 @@ from __future__ import annotations
 
 import torch
 from torch import nn
-
 from x_transformers.x_transformers import RotaryEmbedding
 
 from f5_tts.model.modules import (
-    TimestepEmbedding,
+    AdaLayerNorm_Final,
     ConvPositionEmbedding,
     MMDiTBlock,
-    AdaLayerNorm_Final,
-    precompute_freqs_cis,
+    TimestepEmbedding,
     get_pos_embed_indices,
+    precompute_freqs_cis,
 )
 
 
src/f5_tts/model/backbones/unett.py CHANGED
@@ -8,24 +8,24 @@ d - dimension
 """
 
 from __future__ import annotations
+
 from typing import Literal
 
 import torch
-from torch import nn
 import torch.nn.functional as F
-
+from torch import nn
 from x_transformers import RMSNorm
 from x_transformers.x_transformers import RotaryEmbedding
 
 from f5_tts.model.modules import (
-    TimestepEmbedding,
-    ConvNeXtV2Block,
-    ConvPositionEmbedding,
     Attention,
     AttnProcessor,
+    ConvNeXtV2Block,
+    ConvPositionEmbedding,
     FeedForward,
-    precompute_freqs_cis,
+    TimestepEmbedding,
     get_pos_embed_indices,
+    precompute_freqs_cis,
 )
 
 
src/f5_tts/model/trainer.py CHANGED
@@ -19,6 +19,7 @@ from f5_tts.model import CFM
 from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
 from f5_tts.model.utils import default, exists
 
+
 # trainer
 
 
src/f5_tts/model/utils.py CHANGED
@@ -5,12 +5,11 @@ import random
 from collections import defaultdict
 from importlib.resources import files
 
+import jieba
 import torch
+from pypinyin import Style, lazy_pinyin
 from torch.nn.utils.rnn import pad_sequence
 
-import jieba
-from pypinyin import lazy_pinyin, Style
-
 
 # seed everything
 
src/f5_tts/runtime/triton_trtllm/benchmark.py CHANGED
@@ -30,26 +30,27 @@ import argparse
 import json
 import os
 import time
-from typing import List, Dict, Union
+from typing import Dict, List, Union
 
+import datasets
+import jieba
+import tensorrt as trt
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
-from torch.nn.utils.rnn import pad_sequence
 import torchaudio
-import jieba
-from pypinyin import Style, lazy_pinyin
 from datasets import load_dataset
-import datasets
+from f5_tts_trtllm import F5TTS
 from huggingface_hub import hf_hub_download
+from pypinyin import Style, lazy_pinyin
+from tensorrt_llm._utils import trt_dtype_to_torch
+from tensorrt_llm.logger import logger
+from tensorrt_llm.runtime.session import Session, TensorInfo
+from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader, DistributedSampler
 from tqdm import tqdm
 from vocos import Vocos
-from f5_tts_trtllm import F5TTS
-import tensorrt as trt
-from tensorrt_llm.runtime.session import Session, TensorInfo
-from tensorrt_llm.logger import logger
-from tensorrt_llm._utils import trt_dtype_to_torch
+
 
 torch.manual_seed(0)
 
@@ -381,8 +382,8 @@ def main():
     import sys
 
     sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
-    from f5_tts.model import DiT
     from f5_tts.infer.utils_infer import load_model
+    from f5_tts.model import DiT
 
     F5TTS_model_cfg = dict(
         dim=1024,
src/f5_tts/runtime/triton_trtllm/client_grpc.py CHANGED
@@ -44,7 +44,6 @@ python3 client_grpc.py \
 import argparse
 import asyncio
 import json
-
 import os
 import time
 import types
src/f5_tts/runtime/triton_trtllm/client_http.py CHANGED
@@ -23,10 +23,11 @@
 # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import argparse
+
+import numpy as np
 import requests
 import soundfile as sf
-import numpy as np
-import argparse
 
 
 def get_args():
src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py CHANGED
@@ -1,18 +1,17 @@
-import tensorrt as trt
-import os
 import math
+import os
 import time
-from typing import List, Optional
 from functools import wraps
+from typing import List, Optional
 
+import tensorrt as trt
 import tensorrt_llm
-from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
-from tensorrt_llm.logger import logger
-from tensorrt_llm.runtime.session import Session
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
+from tensorrt_llm.logger import logger
+from tensorrt_llm.runtime.session import Session
 
 
 def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py CHANGED
@@ -24,16 +24,17 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 import json
+import os
+
+import jieba
 import torch
-from torch.nn.utils.rnn import pad_sequence
 import torch.nn.functional as F
-from torch.utils.dlpack import from_dlpack, to_dlpack
 import torchaudio
-import jieba
 import triton_python_backend_utils as pb_utils
-from pypinyin import Style, lazy_pinyin
-import os
 from f5_tts_trtllm import F5TTS
+from pypinyin import Style, lazy_pinyin
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.dlpack import from_dlpack, to_dlpack
 
 
 def get_tokenizer(vocab_file_path: str):
src/f5_tts/runtime/triton_trtllm/patch/__init__.py CHANGED
@@ -34,6 +34,7 @@ from .deepseek_v2.model import DeepseekV2ForCausalLM
 from .dit.model import DiT
 from .eagle.model import EagleForCausalLM
 from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder
+from .f5tts.model import F5TTS
 from .falcon.config import FalconConfig
 from .falcon.model import FalconForCausalLM, FalconModel
 from .gemma.config import GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig
@@ -54,12 +55,12 @@ from .modeling_utils import PretrainedConfig, PretrainedModel, SpeculativeDecodi
 from .mpt.model import MPTForCausalLM, MPTModel
 from .nemotron_nas.model import DeciLMForCausalLM
 from .opt.model import OPTForCausalLM, OPTModel
-from .phi3.model import Phi3ForCausalLM, Phi3Model
 from .phi.model import PhiForCausalLM, PhiModel
+from .phi3.model import Phi3ForCausalLM, Phi3Model
 from .qwen.model import QWenForCausalLM
 from .recurrentgemma.model import RecurrentGemmaForCausalLM
 from .redrafter.model import ReDrafterForCausalLM
-from .f5tts.model import F5TTS
+
 
 __all__ = [
     "BertModel",
src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py CHANGED
@@ -1,23 +1,20 @@
 from __future__ import annotations
-import sys
+
 import os
+import sys
+from collections import OrderedDict
 
 import tensorrt as trt
-from collections import OrderedDict
+from tensorrt_llm._common import default_net
+
 from ..._utils import str_dtype_to_trt
-from ...plugin import current_all_reduce_helper
-from ..modeling_utils import PretrainedConfig, PretrainedModel
 from ...functional import Tensor, concat
-from ...module import Module, ModuleList
-from tensorrt_llm._common import default_net
 from ...layers import Linear
-
-from .modules import (
-    TimestepEmbedding,
-    ConvPositionEmbedding,
-    DiTBlock,
-    AdaLayerNormZero_Final,
-)
+from ...module import Module, ModuleList
+from ...plugin import current_all_reduce_helper
+from ..modeling_utils import PretrainedConfig, PretrainedModel
+from .modules import AdaLayerNormZero_Final, ConvPositionEmbedding, DiTBlock, TimestepEmbedding
+
 
 current_file_path = os.path.abspath(__file__)
 parent_dir = os.path.dirname(current_file_path)
src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py CHANGED
@@ -3,33 +3,35 @@ from __future__ import annotations
 import math
 from typing import Optional
 
+import numpy as np
 import torch
 import torch.nn.functional as F
-
-import numpy as np
 from tensorrt_llm._common import default_net
-from ..._utils import trt_dtype_to_np, str_dtype_to_trt
+
+from ..._utils import str_dtype_to_trt, trt_dtype_to_np
 from ...functional import (
     Tensor,
+    bert_attention,
+    cast,
     chunk,
     concat,
     constant,
     expand,
+    expand_dims,
+    expand_dims_like,
+    expand_mask,
+    gelu,
+    matmul,
+    permute,
     shape,
     silu,
     slice,
-    permute,
-    expand_mask,
-    expand_dims_like,
-    unsqueeze,
-    matmul,
     softmax,
     squeeze,
-    cast,
-    gelu,
+    unsqueeze,
+    view,
 )
-from ...functional import expand_dims, view, bert_attention
-from ...layers import LayerNorm, Linear, Conv1d, Mish, RowLinear, ColumnLinear
+from ...layers import ColumnLinear, Conv1d, LayerNorm, Linear, Mish, RowLinear
 from ...module import Module
 
 
src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py CHANGED
@@ -40,6 +40,7 @@ import torch as th
 import torch.nn.functional as F
 from scipy.signal import check_COLA, get_window
 
+
 support_clp_op = None
 if th.__version__ >= "1.7.0":
     from torch.fft import rfft as fft
src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py CHANGED
@@ -8,7 +8,6 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 
 import safetensors.torch
 import torch
-
 from tensorrt_llm import str_dtype_to_torch
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.convert_utils import split, split_matrix_tp
src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py CHANGED
@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import argparse
+
 import torch
 import torch.nn as nn
-from huggingface_hub import hf_hub_download
-
 from conv_stft import STFT
+from huggingface_hub import hf_hub_download
 from vocos import Vocos
-import argparse
+
 
 opset_version = 17
 
src/f5_tts/scripts/count_params_gflops.py CHANGED
@@ -1,12 +1,13 @@
-import sys
 import os
+import sys
 
+
 sys.path.append(os.getcwd())
 
-from f5_tts.model import CFM, DiT
-
-import torch
 import thop
+import torch
+
+from f5_tts.model import CFM, DiT
 
 
 """ ~155M """
src/f5_tts/socket_client.py CHANGED
@@ -1,10 +1,12 @@
-import socket
 import asyncio
-import pyaudio
-import numpy as np
 import logging
+import socket
 import time
 
+import numpy as np
+import pyaudio
+
+
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
src/f5_tts/socket_server.py CHANGED
@@ -1,7 +1,6 @@
 import argparse
 import gc
 import logging
-import numpy as np
 import queue
 import socket
 import struct
@@ -10,6 +9,7 @@ import traceback
 import wave
 from importlib.resources import files
 
+import numpy as np
 import torch
 import torchaudio
 from huggingface_hub import hf_hub_download
@@ -18,12 +18,13 @@ from omegaconf import OmegaConf
 
 from f5_tts.infer.utils_infer import (
     chunk_text,
-    preprocess_ref_audio_text,
-    load_vocoder,
-    load_model,
     infer_batch_process,
+    load_model,
+    load_vocoder,
+    preprocess_ref_audio_text,
 )
 
+
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
src/f5_tts/train/datasets/prepare_csv_wavs.py CHANGED
@@ -1,12 +1,13 @@
+import concurrent.futures
+import multiprocessing
 import os
-import sys
+import shutil
 import signal
 import subprocess  # For invoking ffprobe
-import shutil
-import concurrent.futures
-import multiprocessing
+import sys
 from contextlib import contextmanager
 
+
 sys.path.append(os.getcwd())
 
 import argparse
@@ -16,12 +17,10 @@ from importlib.resources import files
 from pathlib import Path
 
 import torchaudio
-from tqdm import tqdm
 from datasets.arrow_writer import ArrowWriter
+from tqdm import tqdm
 
-from f5_tts.model.utils import (
-    convert_char_to_pinyin,
-)
+from f5_tts.model.utils import convert_char_to_pinyin
 
 
 PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
src/f5_tts/train/datasets/prepare_emilia.py CHANGED
@@ -7,20 +7,18 @@
 import os
 import sys
 
+
 sys.path.append(os.getcwd())
 
 import json
 from concurrent.futures import ProcessPoolExecutor
 from importlib.resources import files
 from pathlib import Path
-from tqdm import tqdm
 
 from datasets.arrow_writer import ArrowWriter
+from tqdm import tqdm
 
-from f5_tts.model.utils import (
-    repetition_found,
-    convert_char_to_pinyin,
-)
+from f5_tts.model.utils import convert_char_to_pinyin, repetition_found
 
 
 out_zh = {
src/f5_tts/train/datasets/prepare_emilia_v2.py CHANGED
@@ -1,17 +1,17 @@
 # put in src/f5_tts/train/datasets/prepare_emilia_v2.py
 # prepares Emilia dataset with the new format w/ Emilia-YODAS
 
-import os
 import json
+import os
 from concurrent.futures import ProcessPoolExecutor
+from importlib.resources import files
 from pathlib import Path
-from tqdm import tqdm
+
 from datasets.arrow_writer import ArrowWriter
-from importlib.resources import files
+from tqdm import tqdm
+
+from f5_tts.model.utils import repetition_found
 
-from f5_tts.model.utils import (
-    repetition_found,
-)
 
 # Define filters for exclusion
 out_en = set()
src/f5_tts/train/datasets/prepare_libritts.py CHANGED
@@ -1,15 +1,17 @@
 import os
 import sys
 
+
 sys.path.append(os.getcwd())
 
 import json
 from concurrent.futures import ProcessPoolExecutor
 from importlib.resources import files
 from pathlib import Path
-from tqdm import tqdm
+
 import soundfile as sf
 from datasets.arrow_writer import ArrowWriter
+from tqdm import tqdm
 
 
 def deal_with_audio_dir(audio_dir):
src/f5_tts/train/datasets/prepare_ljspeech.py CHANGED
@@ -1,14 +1,16 @@
 import os
 import sys
 
+
 sys.path.append(os.getcwd())
 
 import json
 from importlib.resources import files
 from pathlib import Path
-from tqdm import tqdm
+
 import soundfile as sf
 from datasets.arrow_writer import ArrowWriter
+from tqdm import tqdm
 
 
 def main():
src/f5_tts/train/datasets/prepare_wenetspeech4tts.py CHANGED
@@ -4,15 +4,16 @@
 import os
 import sys
 
+
 sys.path.append(os.getcwd())
 
 import json
 from concurrent.futures import ProcessPoolExecutor
 from importlib.resources import files
-from tqdm import tqdm
 
 import torchaudio
 from datasets import Dataset
+from tqdm import tqdm
 
 from f5_tts.model.utils import convert_char_to_pinyin
 
src/f5_tts/train/finetune_cli.py CHANGED
@@ -5,9 +5,9 @@ from importlib.resources import files
 
 from cached_path import cached_path
 
-from f5_tts.model import CFM, UNetT, DiT, Trainer
-from f5_tts.model.utils import get_tokenizer
+from f5_tts.model import CFM, DiT, Trainer, UNetT
 from f5_tts.model.dataset import load_dataset
+from f5_tts.model.utils import get_tokenizer
 
 
 # -------------------------- Dataset Settings --------------------------- #
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -1,14 +1,12 @@
 import gc
 import json
-import numpy as np
 import os
 import platform
-import psutil
 import queue
 import random
 import re
-import signal
 import shutil
+import signal
 import subprocess
 import sys
 import tempfile
@@ -16,21 +14,23 @@ import threading
 import time
 from glob import glob
 from importlib.resources import files
-from scipy.io import wavfile
 
 import click
 import gradio as gr
 import librosa
+import numpy as np
+import psutil
 import torch
 import torchaudio
 from cached_path import cached_path
 from datasets import Dataset as Dataset_
 from datasets.arrow_writer import ArrowWriter
 from safetensors.torch import load_file, save_file
+from scipy.io import wavfile
 
 from f5_tts.api import F5TTS
-from f5_tts.model.utils import convert_char_to_pinyin
 from f5_tts.infer.utils_infer import transcribe
+from f5_tts.model.utils import convert_char_to_pinyin
 
 
 training_process = None
src/f5_tts/train/train.py CHANGED
@@ -10,6 +10,7 @@ from f5_tts.model import CFM, Trainer
 from f5_tts.model.dataset import load_dataset
 from f5_tts.model.utils import get_tokenizer
 
+
 os.chdir(str(files("f5_tts").joinpath("../..")))  # change working directory to root of project (local editable)
 