diff --git a/f5_tts/api.py b/f5_tts/api.py index 7c73c87dae0f79ddb53acd5101c21b525d119899..d9ca38e8bce90200521e8f4ec3c198dd69e7196c 100644 --- a/f5_tts/api.py +++ b/f5_tts/api.py @@ -5,43 +5,43 @@ from importlib.resources import files import soundfile as sf import tqdm from cached_path import cached_path -from omegaconf import OmegaConf from f5_tts.infer.utils_infer import ( + hop_length, + infer_process, load_model, load_vocoder, - transcribe, preprocess_ref_audio_text, - infer_process, remove_silence_for_generated_wav, save_spectrogram, + transcribe, + target_sample_rate, ) -from f5_tts.model import DiT, UNetT # noqa: F401. used for config +from f5_tts.model import DiT, UNetT from f5_tts.model.utils import seed_everything class F5TTS: def __init__( self, - model="F5TTS_v1_Base", + model_type="F5-TTS", ckpt_file="", vocab_file="", ode_method="euler", use_ema=True, - vocoder_local_path=None, + vocoder_name="vocos", + local_path=None, device=None, hf_cache_dir=None, ): - model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml"))) - model_cls = globals()[model_cfg.model.backbone] - model_arc = model_cfg.model.arch - - self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type - self.target_sample_rate = model_cfg.model.mel_spec.target_sample_rate - - self.ode_method = ode_method - self.use_ema = use_ema - + # Initialize parameters + self.final_wave = None + self.target_sample_rate = target_sample_rate + self.hop_length = hop_length + self.seed = -1 + self.mel_spec_type = vocoder_name + + # Set device if device is not None: self.device = device else: @@ -58,31 +58,39 @@ class F5TTS: ) # Load models - self.vocoder = load_vocoder( - self.mel_spec_type, vocoder_local_path is not None, vocoder_local_path, self.device, hf_cache_dir + self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir) + self.load_ema_model( + model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir ) - repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors" - - # override for previous models - if model == "F5TTS_Base": - if self.mel_spec_type == "vocos": - ckpt_step = 1200000 - elif self.mel_spec_type == "bigvgan": - model = "F5TTS_Base_bigvgan" - ckpt_type = "pt" - elif model == "E2TTS_Base": - repo_name = "E2-TTS" - ckpt_step = 1200000 + def load_vocoder_model(self, vocoder_name, local_path=None, hf_cache_dir=None): + self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device, hf_cache_dir) + + def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, hf_cache_dir=None): + if model_type == "F5-TTS": + if not ckpt_file: + if mel_spec_type == "vocos": + ckpt_file = str( + cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir) + ) + elif mel_spec_type == "bigvgan": + ckpt_file = str( + cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=hf_cache_dir) + ) + model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) + model_cls = DiT + elif model_type == "E2-TTS": + if not ckpt_file: + ckpt_file = str( + cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir) + ) + model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) + model_cls = UNetT else: - raise ValueError(f"Unknown model type: {model}") + raise ValueError(f"Unknown model type: {model_type}") - if not ckpt_file: - ckpt_file = str( - 
cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}", cache_dir=hf_cache_dir) - ) self.ema_model = load_model( - model_cls, model_arc, ckpt_file, self.mel_spec_type, vocab_file, self.ode_method, self.use_ema, self.device + model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device ) def transcribe(self, ref_audio, language=None): @@ -94,8 +102,8 @@ class F5TTS: if remove_silence: remove_silence_for_generated_wav(file_wave) - def export_spectrogram(self, spec, file_spec): - save_spectrogram(spec, file_spec) + def export_spectrogram(self, spect, file_spect): + save_spectrogram(spect, file_spect) def infer( self, @@ -113,16 +121,17 @@ class F5TTS: fix_duration=None, remove_silence=False, file_wave=None, - file_spec=None, - seed=None, + file_spect=None, + seed=-1, ): - if seed is None: - self.seed = random.randint(0, sys.maxsize) - seed_everything(self.seed) + if seed == -1: + seed = random.randint(0, sys.maxsize) + seed_everything(seed) + self.seed = seed ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device) - wav, sr, spec = infer_process( + wav, sr, spect = infer_process( ref_file, ref_text, gen_text, @@ -144,22 +153,22 @@ class F5TTS: if file_wave is not None: self.export_wav(wav, file_wave, remove_silence) - if file_spec is not None: - self.export_spectrogram(spec, file_spec) + if file_spect is not None: + self.export_spectrogram(spect, file_spect) - return wav, sr, spec + return wav, sr, spect if __name__ == "__main__": f5tts = F5TTS() - wav, sr, spec = f5tts.infer( + wav, sr, spect = f5tts.infer( ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")), ref_text="some call me nature, others call me mother nature.", gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""", file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")), - file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")), - seed=None, + file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")), + seed=-1, # random seed = -1 ) print("seed :", f5tts.seed) diff --git a/f5_tts/configs/E2TTS_Base.yaml b/f5_tts/configs/E2TTS_Base_train.yaml similarity index 73% rename from f5_tts/configs/E2TTS_Base.yaml rename to f5_tts/configs/E2TTS_Base_train.yaml index ee701829414864454d42be86260a33722eccdf38..da23b05dfdb87f88b4403a847ad3ab285f4a222e 100644 --- a/f5_tts/configs/E2TTS_Base.yaml +++ b/f5_tts/configs/E2TTS_Base_train.yaml @@ -1,16 +1,16 @@ hydra: run: dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} - + datasets: name: Emilia_ZH_EN # dataset name batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 - batch_size_type: frame # frame | sample + batch_size_type: frame # "frame" or "sample" max_samples: 64 # max sequences per batch if use frame-wise batch_size. 
we set 32 for small models, 64 for base models num_workers: 16 optim: - epochs: 11 + epochs: 15 learning_rate: 7.5e-5 num_warmup_updates: 20000 # warmup updates grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps @@ -20,29 +20,25 @@ optim: model: name: E2TTS_Base tokenizer: pinyin - tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt) - backbone: UNetT + tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) arch: dim: 1024 depth: 24 heads: 16 ff_mult: 4 - text_mask_padding: False - pe_attn_head: 1 mel_spec: target_sample_rate: 24000 n_mel_channels: 100 hop_length: 256 win_length: 1024 n_fft: 1024 - mel_spec_type: vocos # vocos | bigvgan + mel_spec_type: vocos # 'vocos' or 'bigvgan' vocoder: is_local: False # use local offline ckpt or not - local_path: null # local vocoder path + local_path: None # local vocoder path ckpts: - logger: wandb # wandb | tensorboard | null - log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples + logger: wandb # wandb | tensorboard | None save_per_updates: 50000 # save checkpoint per updates keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints last_per_updates: 5000 # save last checkpoint per updates diff --git a/f5_tts/configs/E2TTS_Small.yaml b/f5_tts/configs/E2TTS_Small_train.yaml similarity index 72% rename from f5_tts/configs/E2TTS_Small.yaml rename to f5_tts/configs/E2TTS_Small_train.yaml index cbb1f44e281ca9fc937eb8097af1bc3618c88d77..b2d1a6cb299f67c1a42a1148a264e1ccb16be2db 100644 --- a/f5_tts/configs/E2TTS_Small.yaml +++ b/f5_tts/configs/E2TTS_Small_train.yaml @@ -1,16 +1,16 @@ hydra: run: dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} - + datasets: name: Emilia_ZH_EN batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 - batch_size_type: frame # frame | sample + batch_size_type: frame # "frame" or "sample" max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models num_workers: 16 optim: - epochs: 11 + epochs: 15 learning_rate: 7.5e-5 num_warmup_updates: 20000 # warmup updates grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps @@ -20,29 +20,25 @@ optim: model: name: E2TTS_Small tokenizer: pinyin - tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt) - backbone: UNetT + tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) arch: dim: 768 depth: 20 heads: 12 ff_mult: 4 - text_mask_padding: False - pe_attn_head: 1 mel_spec: target_sample_rate: 24000 n_mel_channels: 100 hop_length: 256 win_length: 1024 n_fft: 1024 - mel_spec_type: vocos # vocos | bigvgan + mel_spec_type: vocos # 'vocos' or 'bigvgan' vocoder: is_local: False # use local offline ckpt or not - local_path: null # local vocoder path + local_path: None # local vocoder path ckpts: - logger: wandb # wandb | tensorboard | null - log_samples: True # infer random sample per save checkpoint. 
wip, normal to fail with extra long samples + logger: wandb # wandb | tensorboard | None save_per_updates: 50000 # save checkpoint per updates keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints last_per_updates: 5000 # save last checkpoint per updates diff --git a/f5_tts/configs/F5TTS_Base.yaml b/f5_tts/configs/F5TTS_Base_train.yaml similarity index 64% rename from f5_tts/configs/F5TTS_Base.yaml rename to f5_tts/configs/F5TTS_Base_train.yaml index 7043cb4a90206a741f509fa383bff575e903c3b9..24d811ea30cfae2a7ff95ff06d61daae1d62ba68 100644 --- a/f5_tts/configs/F5TTS_Base.yaml +++ b/f5_tts/configs/F5TTS_Base_train.yaml @@ -1,16 +1,16 @@ hydra: run: dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} - + datasets: - name: your_training_dataset # dataset name - batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 - batch_size_type: frame # frame | sample + name: vn_1000h # dataset name + batch_size_per_gpu: 2000 # 8 GPUs, 8 * 38400 = 307200 + batch_size_type: frame # "frame" or "sample" max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models num_workers: 16 optim: - epochs: 11 + epochs: 200 learning_rate: 7.5e-5 num_warmup_updates: 20000 # warmup updates grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps @@ -20,17 +20,14 @@ optim: model: name: F5TTS_Base # model name tokenizer: char # tokenizer type - tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt) - backbone: DiT + tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) arch: dim: 1024 depth: 22 heads: 16 ff_mult: 2 text_dim: 512 - text_mask_padding: False conv_layers: 4 - pe_attn_head: 1 checkpoint_activations: False # recompute activations and save memory for extra compute mel_spec: target_sample_rate: 24000 @@ -38,15 +35,14 @@ model: hop_length: 256 win_length: 1024 n_fft: 1024 - mel_spec_type: vocos # vocos | bigvgan + mel_spec_type: vocos # 'vocos' or 'bigvgan' vocoder: - is_local: False # use local offline ckpt or not - local_path: null # local vocoder path + is_local: True # use local offline ckpt or not + local_path: /mnt/i/Project/F5-TTS/ckpts/vocos # local vocoder path ckpts: - logger: tensorboard # wandb | tensorboard | null - log_samples: True # infer random sample per save checkpoint. 
wip, normal to fail with extra long samples - save_per_updates: 50000 # save checkpoint per updates + logger: tensorboard # wandb | tensorboard | None + save_per_updates: 30000 # save checkpoint per updates keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints last_per_updates: 5000 # save last checkpoint per updates save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} \ No newline at end of file diff --git a/f5_tts/configs/F5TTS_Small.yaml b/f5_tts/configs/F5TTS_Small_train.yaml similarity index 75% rename from f5_tts/configs/F5TTS_Small.yaml rename to f5_tts/configs/F5TTS_Small_train.yaml index faae390337d076b18e4a35c1af4ac48d92524952..790be06f0f1c647c3d5d53f44be2259ef1bb7571 100644 --- a/f5_tts/configs/F5TTS_Small.yaml +++ b/f5_tts/configs/F5TTS_Small_train.yaml @@ -1,16 +1,16 @@ hydra: run: dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} - + datasets: name: Emilia_ZH_EN batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 - batch_size_type: frame # frame | sample + batch_size_type: frame # "frame" or "sample" max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models num_workers: 16 optim: - epochs: 11 + epochs: 15 learning_rate: 7.5e-5 num_warmup_updates: 20000 # warmup updates grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps @@ -20,17 +20,14 @@ optim: model: name: F5TTS_Small tokenizer: pinyin - tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt) - backbone: DiT + tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) arch: dim: 768 depth: 18 heads: 12 ff_mult: 2 text_dim: 512 - text_mask_padding: False conv_layers: 4 - pe_attn_head: 1 checkpoint_activations: False # recompute activations and save memory for extra compute mel_spec: target_sample_rate: 24000 @@ -38,14 +35,13 @@ model: hop_length: 256 win_length: 1024 n_fft: 1024 - mel_spec_type: vocos # vocos | bigvgan + mel_spec_type: vocos # 'vocos' or 'bigvgan' vocoder: is_local: False # use local offline ckpt or not - local_path: null # local vocoder path + local_path: None # local vocoder path ckpts: - logger: wandb # wandb | tensorboard | null - log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples + logger: wandb # wandb | tensorboard | None save_per_updates: 50000 # save checkpoint per updates keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints last_per_updates: 5000 # save last checkpoint per updates diff --git a/f5_tts/configs/F5TTS_v1_Base.yaml b/f5_tts/configs/F5TTS_v1_Base.yaml deleted file mode 100644 index c7717facb114c0c1fc598e29a5c589485249d9d1..0000000000000000000000000000000000000000 --- a/f5_tts/configs/F5TTS_v1_Base.yaml +++ /dev/null @@ -1,53 +0,0 @@ -hydra: - run: - dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} - -datasets: - name: Emilia_ZH_EN # dataset name - batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 - batch_size_type: frame # frame | sample - max_samples: 64 # max sequences per batch if use frame-wise batch_size. 
we set 32 for small models, 64 for base models - num_workers: 16 - -optim: - epochs: 11 - learning_rate: 7.5e-5 - num_warmup_updates: 20000 # warmup updates - grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps - max_grad_norm: 1.0 # gradient clipping - bnb_optimizer: False # use bnb 8bit AdamW optimizer or not - -model: - name: F5TTS_v1_Base # model name - tokenizer: pinyin # tokenizer type - tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt) - backbone: DiT - arch: - dim: 1024 - depth: 22 - heads: 16 - ff_mult: 2 - text_dim: 512 - text_mask_padding: True - qk_norm: null # null | rms_norm - conv_layers: 4 - pe_attn_head: null - checkpoint_activations: False # recompute activations and save memory for extra compute - mel_spec: - target_sample_rate: 24000 - n_mel_channels: 100 - hop_length: 256 - win_length: 1024 - n_fft: 1024 - mel_spec_type: vocos # vocos | bigvgan - vocoder: - is_local: False # use local offline ckpt or not - local_path: null # local vocoder path - -ckpts: - logger: wandb # wandb | tensorboard | null - log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples - save_per_updates: 50000 # save checkpoint per updates - keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints - last_per_updates: 5000 # save last checkpoint per updates - save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} \ No newline at end of file diff --git a/f5_tts/eval/eval_infer_batch.py b/f5_tts/eval/eval_infer_batch.py index e779ff0703c1d6febef273fc1bae6e2ec2c2b266..785880ccd14564b1615b0dca66ed93e66fff2a1f 100644 --- a/f5_tts/eval/eval_infer_batch.py +++ b/f5_tts/eval/eval_infer_batch.py @@ -10,7 +10,6 @@ from importlib.resources import files import torch import torchaudio from accelerate import Accelerator -from omegaconf import OmegaConf from tqdm import tqdm from f5_tts.eval.utils_eval import ( @@ -19,26 +18,36 @@ from f5_tts.eval.utils_eval import ( get_seedtts_testset_metainfo, ) from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder -from f5_tts.model import CFM, DiT, UNetT # noqa: F401. 
used for config +from f5_tts.model import CFM, DiT, UNetT from f5_tts.model.utils import get_tokenizer accelerator = Accelerator() device = f"cuda:{accelerator.process_index}" -use_ema = True -target_rms = 0.1 +# --------------------- Dataset Settings -------------------- # +target_sample_rate = 24000 +n_mel_channels = 100 +hop_length = 256 +win_length = 1024 +n_fft = 1024 +target_rms = 0.1 rel_path = str(files("f5_tts").joinpath("../../")) def main(): + # ---------------------- infer setting ---------------------- # + parser = argparse.ArgumentParser(description="batch inference") parser.add_argument("-s", "--seed", default=None, type=int) + parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN") parser.add_argument("-n", "--expname", required=True) - parser.add_argument("-c", "--ckptstep", default=1250000, type=int) + parser.add_argument("-c", "--ckptstep", default=1200000, type=int) + parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"]) + parser.add_argument("-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"]) parser.add_argument("-nfe", "--nfestep", default=32, type=int) parser.add_argument("-o", "--odemethod", default="euler") @@ -49,8 +58,12 @@ def main(): args = parser.parse_args() seed = args.seed + dataset_name = args.dataset exp_name = args.expname ckpt_step = args.ckptstep + ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt" + mel_spec_type = args.mel_spec_type + tokenizer = args.tokenizer nfe_step = args.nfestep ode_method = args.odemethod @@ -64,19 +77,13 @@ def main(): use_truth_duration = False no_ref_audio = False - model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml"))) - model_cls = globals()[model_cfg.model.backbone] - model_arc = model_cfg.model.arch + if exp_name == "F5TTS_Base": + model_cls = DiT + model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) - dataset_name = model_cfg.datasets.name - tokenizer = model_cfg.model.tokenizer - - mel_spec_type = model_cfg.model.mel_spec.mel_spec_type - target_sample_rate = model_cfg.model.mel_spec.target_sample_rate - n_mel_channels = model_cfg.model.mel_spec.n_mel_channels - hop_length = model_cfg.model.mel_spec.hop_length - win_length = model_cfg.model.mel_spec.win_length - n_fft = model_cfg.model.mel_spec.n_fft + elif exp_name == "E2TTS_Base": + model_cls = UNetT + model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) if testset == "ls_pc_test_clean": metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst" @@ -104,6 +111,8 @@ def main(): # -------------------------------------------------# + use_ema = True + prompts_all = get_inference_prompt( metainfo, speed=speed, @@ -130,7 +139,7 @@ def main(): # Model model = CFM( - transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels), + transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), mel_spec_kwargs=dict( n_fft=n_fft, hop_length=hop_length, @@ -145,10 +154,6 @@ def main(): vocab_char_map=vocab_char_map, ).to(device) - ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt" - if not os.path.exists(ckpt_path): - print("Loading from self-organized training checkpoints rather than released pretrained.") - ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt" dtype = torch.float32 if mel_spec_type == "bigvgan" else None model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema) diff --git 
a/f5_tts/eval/eval_infer_batch.sh b/f5_tts/eval/eval_infer_batch.sh index a5b4f631eaa7df754a5a422d42657d2cd18acfea..47361e3ce6d7d2b0ea5305236e9b89580297428c 100644 --- a/f5_tts/eval/eval_infer_batch.sh +++ b/f5_tts/eval/eval_infer_batch.sh @@ -1,18 +1,13 @@ #!/bin/bash # e.g. F5-TTS, 16 NFE -accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16 -accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16 -accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16 +accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16 +accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16 +accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16 # e.g. Vanilla E2 TTS, 32 NFE -accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0 -accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0 -accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0 - -# e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh -python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8 -python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8 -python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 +accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0 +accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0 +accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0 # etc. 
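
For context, the `eval_infer_batch.py` hunk above drops the OmegaConf-driven setup in favor of a hard-coded per-experiment dispatch. Below is a minimal sketch of how that dispatch feeds the `CFM` constructor, mirroring the names and hyperparameters shown in the hunk; the `mel_spec_kwargs` entries beyond those visible in the hunk are assumptions, and `get_tokenizer` needs the dataset vocab to be available locally.

```python
from f5_tts.model import CFM, DiT, UNetT
from f5_tts.model.utils import get_tokenizer

# Mel settings hard-coded at module level in the patched eval_infer_batch.py
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024


def build_eval_model(exp_name, dataset_name="Emilia_ZH_EN", tokenizer="pinyin", mel_spec_type="vocos"):
    # Per-experiment dispatch that replaces the OmegaConf/YAML lookup
    if exp_name == "F5TTS_Base":
        model_cls = DiT
        model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
    elif exp_name == "E2TTS_Base":
        model_cls = UNetT
        model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
    else:
        raise ValueError(f"Unknown experiment name: {exp_name}")

    vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)

    # Same CFM construction as in the hunk; kwargs not visible there are assumed
    return CFM(
        transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
        mel_spec_kwargs=dict(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mel_channels=n_mel_channels,
            target_sample_rate=target_sample_rate,
            mel_spec_type=mel_spec_type,
        ),
        vocab_char_map=vocab_char_map,
    )
```
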
diff --git a/f5_tts/eval/eval_librispeech_test_clean.py b/f5_tts/eval/eval_librispeech_test_clean.py index 0b403689b02c250e0882e946f029f93ec8739292..f1722869ccfe270043a53a53319c165279ca7dc4 100644 --- a/f5_tts/eval/eval_librispeech_test_clean.py +++ b/f5_tts/eval/eval_librispeech_test_clean.py @@ -53,37 +53,43 @@ def main(): asr_ckpt_dir = "" # auto download to cache dir wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth" - # -------------------------------------------------------------------------- - - full_results = [] - metrics = [] + # --------------------------- WER --------------------------- if eval_task == "wer": + wer_results = [] + wers = [] + with mp.Pool(processes=len(gpus)) as pool: args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set] results = pool.map(run_asr_wer, args) for r in results: - full_results.extend(r) - elif eval_task == "sim": + wer_results.extend(r) + + wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl" + with open(wer_result_path, "w") as f: + for line in wer_results: + wers.append(line["wer"]) + json_line = json.dumps(line, ensure_ascii=False) + f.write(json_line + "\n") + + wer = round(np.mean(wers) * 100, 3) + print(f"\nTotal {len(wers)} samples") + print(f"WER : {wer}%") + print(f"Results have been saved to {wer_result_path}") + + # --------------------------- SIM --------------------------- + + if eval_task == "sim": + sims = [] with mp.Pool(processes=len(gpus)) as pool: args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set] results = pool.map(run_sim, args) for r in results: - full_results.extend(r) - else: - raise ValueError(f"Unknown metric type: {eval_task}") - - result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl" - with open(result_path, "w") as f: - for line in full_results: - metrics.append(line[eval_task]) - f.write(json.dumps(line, ensure_ascii=False) + "\n") - metric = round(np.mean(metrics), 5) - f.write(f"\n{eval_task.upper()}: {metric}\n") - - print(f"\nTotal {len(metrics)} samples") - print(f"{eval_task.upper()}: {metric}") - print(f"{eval_task.upper()} results saved to {result_path}") + sims.extend(r) + + sim = round(sum(sims) / len(sims), 3) + print(f"\nTotal {len(sims)} samples") + print(f"SIM : {sim}") if __name__ == "__main__": diff --git a/f5_tts/eval/eval_seedtts_testset.py b/f5_tts/eval/eval_seedtts_testset.py index 0bb68eeab3c018388aff6bf225ce0efa9d7acae1..95a5f44a2459eefd5f03e910682b13688f586e8e 100644 --- a/f5_tts/eval/eval_seedtts_testset.py +++ b/f5_tts/eval/eval_seedtts_testset.py @@ -52,37 +52,43 @@ def main(): asr_ckpt_dir = "" # auto download to cache dir wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth" - # -------------------------------------------------------------------------- - - full_results = [] - metrics = [] + # --------------------------- WER --------------------------- if eval_task == "wer": + wer_results = [] + wers = [] + with mp.Pool(processes=len(gpus)) as pool: args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set] results = pool.map(run_asr_wer, args) for r in results: - full_results.extend(r) - elif eval_task == "sim": + wer_results.extend(r) + + wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl" + with open(wer_result_path, "w") as f: + for line in wer_results: + wers.append(line["wer"]) + json_line = json.dumps(line, ensure_ascii=False) + f.write(json_line + "\n") + + wer = round(np.mean(wers) * 100, 3) + print(f"\nTotal {len(wers)} samples") + 
print(f"WER : {wer}%") + print(f"Results have been saved to {wer_result_path}") + + # --------------------------- SIM --------------------------- + + if eval_task == "sim": + sims = [] with mp.Pool(processes=len(gpus)) as pool: args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set] results = pool.map(run_sim, args) for r in results: - full_results.extend(r) - else: - raise ValueError(f"Unknown metric type: {eval_task}") - - result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl" - with open(result_path, "w") as f: - for line in full_results: - metrics.append(line[eval_task]) - f.write(json.dumps(line, ensure_ascii=False) + "\n") - metric = round(np.mean(metrics), 5) - f.write(f"\n{eval_task.upper()}: {metric}\n") - - print(f"\nTotal {len(metrics)} samples") - print(f"{eval_task.upper()}: {metric}") - print(f"{eval_task.upper()} results saved to {result_path}") + sims.extend(r) + + sim = round(sum(sims) / len(sims), 3) + print(f"\nTotal {len(sims)} samples") + print(f"SIM : {sim}") if __name__ == "__main__": diff --git a/f5_tts/eval/eval_utmos.py b/f5_tts/eval/eval_utmos.py index b6166e8ab073a6134b23936e15e440332991bab2..c4e944998477ff135cf84d662ea94656f2e3ecd6 100644 --- a/f5_tts/eval/eval_utmos.py +++ b/f5_tts/eval/eval_utmos.py @@ -19,23 +19,25 @@ def main(): predictor = predictor.to(device) audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}")) + utmos_results = {} utmos_score = 0 - utmos_result_path = Path(args.audio_dir) / "_utmos_results.jsonl" + for audio_path in tqdm(audio_paths, desc="Processing"): + wav_name = audio_path.stem + wav, sr = librosa.load(audio_path, sr=None, mono=True) + wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0) + score = predictor(wav_tensor, sr) + utmos_results[str(wav_name)] = score.item() + utmos_score += score.item() + + avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0 + print(f"UTMOS: {avg_score}") + + utmos_result_path = Path(args.audio_dir) / "utmos_results.json" with open(utmos_result_path, "w", encoding="utf-8") as f: - for audio_path in tqdm(audio_paths, desc="Processing"): - wav, sr = librosa.load(audio_path, sr=None, mono=True) - wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0) - score = predictor(wav_tensor, sr) - line = {} - line["wav"], line["utmos"] = str(audio_path.stem), score.item() - utmos_score += score.item() - f.write(json.dumps(line, ensure_ascii=False) + "\n") - avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0 - f.write(f"\nUTMOS: {avg_score:.4f}\n") - - print(f"UTMOS: {avg_score:.4f}") - print(f"UTMOS results saved to {utmos_result_path}") + json.dump(utmos_results, f, ensure_ascii=False, indent=4) + + print(f"Results have been saved to {utmos_result_path}") if __name__ == "__main__": diff --git a/f5_tts/eval/utils_eval.py b/f5_tts/eval/utils_eval.py index d8407adb82e765632fb939b3afc4f63a750869b7..7c0a8a8d8573a3286ba3d88d32de46ca6aa58f4e 100644 --- a/f5_tts/eval/utils_eval.py +++ b/f5_tts/eval/utils_eval.py @@ -389,10 +389,10 @@ def run_sim(args): model = model.cuda(device) model.eval() - sim_results = [] - for gen_wav, prompt_wav, truth in tqdm(test_set): - wav1, sr1 = torchaudio.load(gen_wav) - wav2, sr2 = torchaudio.load(prompt_wav) + sims = [] + for wav1, wav2, truth in tqdm(test_set): + wav1, sr1 = torchaudio.load(wav1) + wav2, sr2 = torchaudio.load(wav2) resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000) resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000) @@ -408,11 +408,6 @@ 
def run_sim(args): sim = F.cosine_similarity(emb1, emb2)[0].item() # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).") - sim_results.append( - { - "wav": Path(gen_wav).stem, - "sim": sim, - } - ) + sims.append(sim) - return sim_results + return sims diff --git a/f5_tts/infer/README.md b/f5_tts/infer/README.md index afcc1fc4bf5790f5ba3e4471b575c9924f8d2a63..8d194423b2d48da3c5569036a8854109a73e79de 100644 --- a/f5_tts/infer/README.md +++ b/f5_tts/infer/README.md @@ -23,24 +23,12 @@ Currently supported features: - Basic TTS with Chunk Inference - Multi-Style / Multi-Speaker Generation - Voice Chat powered by Qwen2.5-3B-Instruct -- [Custom inference with more language support](src/f5_tts/infer/SHARED.md) The cli command `f5-tts_infer-gradio` equals to `python src/f5_tts/infer/infer_gradio.py`, which launches a Gradio APP (web interface) for inference. The script will load model checkpoints from Huggingface. You can also manually download files and update the path to `load_model()` in `infer_gradio.py`. Currently only load TTS models first, will load ASR model to do transcription if `ref_text` not provided, will load LLM model if use Voice Chat. -More flags options: - -```bash -# Automatically launch the interface in the default web browser -f5-tts_infer-gradio --inbrowser - -# Set the root path of the application, if it's not served from the root ("/") of the domain -# For example, if the application is served at "https://example.com/myapp" -f5-tts_infer-gradio --root_path "/myapp" -``` - -Could also be used as a component for larger application: +Could also be used as a component for larger application. ```python import gradio as gr from f5_tts.infer.infer_gradio import app @@ -68,16 +56,17 @@ Basically you can inference with flags: ```bash # Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage) f5-tts_infer-cli \ ---model F5TTS_v1_Base \ +--model "F5-TTS" \ --ref_audio "ref_audio.wav" \ ---ref_text "The content, subtitle or transcription of reference audio." \ ---gen_text "Some text you want TTS model generate for you." +--ref_text "hình ảnh cực đoan trong em_vi của sơn tùng mờ thành phố bị khán giả chỉ trích" \ +--gen_text "tôi yêu em đến nay chừng có thể, ngọn lửa tình chưa hẳn đã tàn phai." \ +--vocoder_name vocos \ +--load_vocoder_from_local \ +--ckpt_file ckpts/F5TTS_Base_vocos_char_vnTTS/model_last.pt -# Use BigVGAN as vocoder. Currently only support F5TTS_Base. -f5-tts_infer-cli --model F5TTS_Base --vocoder_name bigvgan --load_vocoder_from_local - -# Use custom path checkpoint, e.g. -f5-tts_infer-cli --ckpt_file ckpts/F5TTS_v1_Base/model_1250000.safetensors +# Choose Vocoder +f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file +f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file # More instructions f5-tts_infer-cli --help @@ -92,8 +81,8 @@ f5-tts_infer-cli -c custom.toml For example, you can use `.toml` to pass in variables, refer to `src/f5_tts/infer/examples/basic/basic.toml`: ```toml -# F5TTS_v1_Base | E2TTS_Base -model = "F5TTS_v1_Base" +# F5-TTS | E2-TTS +model = "F5-TTS" ref_audio = "infer/examples/basic/basic_ref_en.wav" # If an empty "", transcribes the reference audio automatically. ref_text = "Some call me nature, others call me mother nature." @@ -107,8 +96,8 @@ output_dir = "tests" You can also leverage `.toml` file to do multi-style generation, refer to `src/f5_tts/infer/examples/multi/story.toml`. 
```toml -# F5TTS_v1_Base | E2TTS_Base -model = "F5TTS_v1_Base" +# F5-TTS | E2-TTS +model = "F5-TTS" ref_audio = "infer/examples/multi/main.flac" # If an empty "", transcribes the reference audio automatically. ref_text = "" @@ -128,27 +117,83 @@ ref_text = "" ``` You should mark the voice with `[main]` `[town]` `[country]` whenever you want to change voice, refer to `src/f5_tts/infer/examples/multi/story.txt`. -## Socket Real-time Service +## Speech Editing -Real-time voice output with chunk stream: +To test speech editing capabilities, use the following command: ```bash -# Start socket server -python src/f5_tts/socket_server.py - -# If PyAudio not installed -sudo apt-get install portaudio19-dev -pip install pyaudio - -# Communicate with socket client -python src/f5_tts/socket_client.py +python src/f5_tts/infer/speech_edit.py ``` -## Speech Editing - -To test speech editing capabilities, use the following command: +## Socket Realtime Client +To communicate with socket server you need to run ```bash -python src/f5_tts/infer/speech_edit.py +python src/f5_tts/socket_server.py ``` +
+Then create client to communicate + +``` python +import socket +import numpy as np +import asyncio +import pyaudio + +async def listen_to_voice(text, server_ip='localhost', server_port=9999): + client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + client_socket.connect((server_ip, server_port)) + + async def play_audio_stream(): + buffer = b'' + p = pyaudio.PyAudio() + stream = p.open(format=pyaudio.paFloat32, + channels=1, + rate=24000, # Ensure this matches the server's sampling rate + output=True, + frames_per_buffer=2048) + + try: + while True: + chunk = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 1024) + if not chunk: # End of stream + break + if b"END_OF_AUDIO" in chunk: + buffer += chunk.replace(b"END_OF_AUDIO", b"") + if buffer: + audio_array = np.frombuffer(buffer, dtype=np.float32).copy() # Make a writable copy + stream.write(audio_array.tobytes()) + break + buffer += chunk + if len(buffer) >= 4096: + audio_array = np.frombuffer(buffer[:4096], dtype=np.float32).copy() # Make a writable copy + stream.write(audio_array.tobytes()) + buffer = buffer[4096:] + finally: + stream.stop_stream() + stream.close() + p.terminate() + + try: + # Send only the text to the server + await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, text.encode('utf-8')) + await play_audio_stream() + print("Audio playback finished.") + + except Exception as e: + print(f"Error in listen_to_voice: {e}") + + finally: + client_socket.close() + +# Example usage: Replace this with your actual server IP and port +async def main(): + await listen_to_voice("my name is jenny..", server_ip='localhost', server_port=9998) + +# Run the main async function +asyncio.run(main()) +``` + +
+ diff --git a/f5_tts/infer/SHARED.md b/f5_tts/infer/SHARED.md index 79d7f56e22d64fcf2161efeac8a5ab7cdc007ea2..400548f7692ca9070913da333b8b88f7f56f3feb 100644 --- a/f5_tts/infer/SHARED.md +++ b/f5_tts/infer/SHARED.md @@ -16,7 +16,7 @@ ### Supported Languages - [Multilingual](#multilingual) - - [F5-TTS v1 v0 Base @ zh \& en @ F5-TTS](#f5-tts-v1-v0-base--zh--en--f5-tts) + - [F5-TTS Base @ zh \& en @ F5-TTS](#f5-tts-base--zh--en--f5-tts) - [English](#english) - [Finnish](#finnish) - [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen) @@ -37,17 +37,7 @@ ## Multilingual -#### F5-TTS v1 v0 Base @ zh & en @ F5-TTS -|Model|🤗Hugging Face|Data (Hours)|Model License| -|:---:|:------------:|:-----------:|:-------------:| -|F5-TTS v1 Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_v1_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0| - -```bash -Model: hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors -Vocab: hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt -Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} -``` - +#### F5-TTS Base @ zh & en @ F5-TTS |Model|🤗Hugging Face|Data (Hours)|Model License| |:---:|:------------:|:-----------:|:-------------:| |F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0| @@ -55,7 +45,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, " ```bash Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt -Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} +Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} ``` *Other infos, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) ...* @@ -74,7 +64,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, " ```bash Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt -Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} +Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} ``` @@ -88,7 +78,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, " ```bash Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt -Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} +Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} ``` - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french). 
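
Each `Model` / `Vocab` / `Config` triplet in SHARED.md maps directly onto the loading helpers touched elsewhere in this patch (`load_custom` in `infer_gradio.py`, `load_model` / `load_vocoder` in `utils_infer.py`). A minimal sketch, using the French community checkpoint listed above as the example; passing `mel_spec_type="vocos"` is an assumption based on how the CLI forwards the vocoder name.

```python
from cached_path import cached_path

from f5_tts.infer.utils_infer import load_model, load_vocoder
from f5_tts.model import DiT

# Model / Vocab / Config entry as listed above (community French checkpoint)
ckpt_path = str(cached_path("hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt"))
vocab_path = str(cached_path("hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt"))
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)

# Build the EMA model and the matching vocoder
ema_model = load_model(DiT, model_cfg, ckpt_path, mel_spec_type="vocos", vocab_file=vocab_path)
vocoder = load_vocoder(vocoder_name="vocos")
```
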
@@ -106,7 +96,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, " ```bash Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt -Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} +Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} ``` - Authors: SPRING Lab, Indian Institute of Technology, Madras @@ -123,7 +113,7 @@ Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "t ```bash Model: hf://alien79/F5-TTS-italian/model_159600.safetensors Vocab: hf://alien79/F5-TTS-italian/vocab.txt -Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} +Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} ``` - Trained by [Mithril Man](https://github.com/MithrilMan) @@ -141,7 +131,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, " ```bash Model: hf://Jmica/F5TTS/JA_25498980/model_25498980.pt Vocab: hf://Jmica/F5TTS/JA_25498980/vocab_updated.txt -Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} +Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} ``` @@ -158,7 +148,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, " ```bash Model: hf://hotstone228/F5-TTS-Russian/model_last.safetensors Vocab: hf://hotstone228/F5-TTS-Russian/vocab.txt -Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} +Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} ``` - Finetuned by [HotDro4illa](https://github.com/HotDro4illa) - Any improvements are welcome diff --git a/f5_tts/infer/__pycache__/infer_cli.cpython-310.pyc b/f5_tts/infer/__pycache__/infer_cli.cpython-310.pyc index bc51877c581cf39343122c71364d0a5d8ba6755e..e3b0b96d4963757af2b44d4ba8c8eef5865785b8 100644 Binary files a/f5_tts/infer/__pycache__/infer_cli.cpython-310.pyc and b/f5_tts/infer/__pycache__/infer_cli.cpython-310.pyc differ diff --git a/f5_tts/infer/__pycache__/utils_infer.cpython-310.pyc b/f5_tts/infer/__pycache__/utils_infer.cpython-310.pyc index 5d2259d61845102088f368d2f126d3f8c9ba58c4..738578b59068737b992fb45b2bd5e2d284d23bff 100644 Binary files a/f5_tts/infer/__pycache__/utils_infer.cpython-310.pyc and b/f5_tts/infer/__pycache__/utils_infer.cpython-310.pyc differ diff --git a/f5_tts/infer/examples/basic/basic.toml b/f5_tts/infer/examples/basic/basic.toml index bc3ebb4e3f96368f5d5e7d5ad347a7c852f6f1cb..61a1272cd186d0afe54fb17399e84091bfae0575 100644 --- a/f5_tts/infer/examples/basic/basic.toml +++ b/f5_tts/infer/examples/basic/basic.toml @@ -1,5 +1,5 @@ -# F5TTS_v1_Base | E2TTS_Base -model = "F5TTS_v1_Base" +# F5-TTS | E2-TTS +model = "F5-TTS" ref_audio = "infer/examples/basic/basic_ref_en.wav" # If an empty "", transcribes the reference audio automatically. ref_text = "Some call me nature, others call me mother nature." @@ -8,4 +8,4 @@ gen_text = "I don't really care what you call me. 
I've been a silent spectator, gen_file = "" remove_silence = false output_dir = "tests" -output_file = "infer_cli_basic.wav" +output_file = "infer_cli_basic.wav" \ No newline at end of file diff --git a/f5_tts/infer/examples/multi/story.toml b/f5_tts/infer/examples/multi/story.toml index f073c26d7335b5ad4ce33b41e18bfa46c87914c8..10ba3fc8eb16531002baa1bef1638577118b829e 100644 --- a/f5_tts/infer/examples/multi/story.toml +++ b/f5_tts/infer/examples/multi/story.toml @@ -1,5 +1,5 @@ -# F5TTS_v1_Base | E2TTS_Base -model = "F5TTS_v1_Base" +# F5-TTS | E2-TTS +model = "F5-TTS" ref_audio = "infer/examples/multi/main.flac" # If an empty "", transcribes the reference audio automatically. ref_text = "" diff --git a/f5_tts/infer/infer_cli.py b/f5_tts/infer/infer_cli.py index 3510a325007c2fb30350458fa7fce31dfe7008e1..33f9622abcbdf9914463b16d0a0ef58e792a1a7d 100644 --- a/f5_tts/infer/infer_cli.py +++ b/f5_tts/infer/infer_cli.py @@ -27,7 +27,7 @@ from f5_tts.infer.utils_infer import ( preprocess_ref_audio_text, remove_silence_for_generated_wav, ) -from f5_tts.model import DiT, UNetT # noqa: F401. used for config +from f5_tts.model import DiT, UNetT parser = argparse.ArgumentParser( @@ -50,8 +50,7 @@ parser.add_argument( "-m", "--model", type=str, - default="F5TTS_Base", - help="The model name: F5TTS_v1_Base | F5TTS_Base | E2TTS_Base | etc.", + help="The model name: F5-TTS | E2-TTS", ) parser.add_argument( "-mc", @@ -173,7 +172,8 @@ config = tomli.load(open(args.config, "rb")) # command-line interface parameters -model = args.model or config.get("model", "F5TTS_v1_Base") +model = args.model or config.get("model", "F5-TTS") +model_cfg = args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath("configs/F5TTS_Base_train.yaml"))) ckpt_file = args.ckpt_file or config.get("ckpt_file", "") vocab_file = args.vocab_file or config.get("vocab_file", "") @@ -236,7 +236,7 @@ if save_chunk: # load vocoder if vocoder_name == "vocos": - vocoder_local_path = "../checkpoints/vocos-mel-24khz" + vocoder_local_path = "ckpts/vocos" elif vocoder_name == "bigvgan": vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x" @@ -245,32 +245,37 @@ vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_loc # load TTS model -model_cfg = OmegaConf.load( - args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml"))) -).model -model_cls = globals()[model_cfg.backbone] - -repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors" - -if model != "F5TTS_Base": - assert vocoder_name == model_cfg.mel_spec.mel_spec_type - -# override for previous models -if model == "F5TTS_Base": - if vocoder_name == "vocos": +if model == "F5-TTS": + model_cls = DiT + model_cfg = OmegaConf.load(model_cfg).model.arch + if not ckpt_file: # path not specified, download from repo + if vocoder_name == "vocos": + repo_name = "F5-TTS" + exp_name = "F5TTS_Base" + ckpt_step = 1200000 + ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) + # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path + # ckpt_file = f"ckpts/{exp_name}/model_last.pt" # .pt | .safetensors; local path + elif vocoder_name == "bigvgan": + repo_name = "F5-TTS" + exp_name = "F5TTS_Base_bigvgan" + ckpt_step = 1250000 + ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt")) + +elif model == "E2-TTS": + assert args.model_cfg is None, "E2-TTS does not support custom model_cfg yet" + assert vocoder_name 
== "vocos", "E2-TTS only supports vocoder vocos yet" + model_cls = UNetT + model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) + if not ckpt_file: # path not specified, download from repo + repo_name = "E2-TTS" + exp_name = "E2TTS_Base" ckpt_step = 1200000 - elif vocoder_name == "bigvgan": - model = "F5TTS_Base_bigvgan" - ckpt_type = "pt" -elif model == "E2TTS_Base": - repo_name = "E2-TTS" - ckpt_step = 1200000 - -if not ckpt_file: - ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}")) + ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) + # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path print(f"Using {model}...") -ema_model = load_model(model_cls, model_cfg.arch, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file) +ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file) # inference process diff --git a/f5_tts/infer/infer_gradio.py b/f5_tts/infer/infer_gradio.py index 72202f6b7a6b6233a9d86fde3760be604513c0e4..1fb1bfb93252fd4809217977a24e88e02015b902 100644 --- a/f5_tts/infer/infer_gradio.py +++ b/f5_tts/infer/infer_gradio.py @@ -41,12 +41,12 @@ from f5_tts.infer.utils_infer import ( ) -DEFAULT_TTS_MODEL = "F5-TTS_v1" +DEFAULT_TTS_MODEL = "F5-TTS" tts_model_choice = DEFAULT_TTS_MODEL DEFAULT_TTS_MODEL_CFG = [ - "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors", - "hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt", + "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", + "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt", json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)), ] @@ -56,15 +56,13 @@ DEFAULT_TTS_MODEL_CFG = [ vocoder = load_vocoder() -def load_f5tts(): - ckpt_path = str(cached_path(DEFAULT_TTS_MODEL_CFG[0])) - F5TTS_model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2]) +def load_f5tts(ckpt_path=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))): + F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) return load_model(DiT, F5TTS_model_cfg, ckpt_path) -def load_e2tts(): - ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors")) - E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1) +def load_e2tts(ckpt_path=str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))): + E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) return load_model(UNetT, E2TTS_model_cfg, ckpt_path) @@ -75,7 +73,7 @@ def load_custom(ckpt_path: str, vocab_path="", model_cfg=None): if vocab_path.startswith("hf://"): vocab_path = str(cached_path(vocab_path)) if model_cfg is None: - model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2]) + model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path) @@ -132,7 +130,7 @@ def infer( ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info) - if model == DEFAULT_TTS_MODEL: + if model == "F5-TTS": ema_model = F5TTS_ema_model elif model == "E2-TTS": global E2TTS_ema_model @@ -764,7 +762,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip """ ) - last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info_v1.txt") + last_used_custom = 
files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info.txt") def load_last_used_custom(): try: @@ -823,30 +821,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip custom_model_cfg = gr.Dropdown( choices=[ DEFAULT_TTS_MODEL_CFG[2], - json.dumps( - dict( - dim=1024, - depth=22, - heads=16, - ff_mult=2, - text_dim=512, - text_mask_padding=False, - conv_layers=4, - pe_attn_head=1, - ) - ), - json.dumps( - dict( - dim=768, - depth=18, - heads=12, - ff_mult=2, - text_dim=512, - text_mask_padding=False, - conv_layers=4, - pe_attn_head=1, - ) - ), + json.dumps(dict(dim=768, depth=18, heads=12, ff_mult=2, text_dim=512, conv_layers=4)), ], value=load_last_used_custom()[2], allow_custom_value=True, @@ -900,24 +875,10 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip type=str, help='The root path (or "mount point") of the application, if it\'s not served from the root ("/") of the domain. Often used when the application is behind a reverse proxy that forwards requests to the application, e.g. set "/myapp" or full URL for application served at "https://example.com/myapp".', ) -@click.option( - "--inbrowser", - "-i", - is_flag=True, - default=False, - help="Automatically launch the interface in the default web browser", -) -def main(port, host, share, api, root_path, inbrowser): +def main(port, host, share, api, root_path): global app print("Starting app...") - app.queue(api_open=api).launch( - server_name=host, - server_port=port, - share=share, - show_api=api, - root_path=root_path, - inbrowser=inbrowser, - ) + app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api, root_path=root_path) if __name__ == "__main__": diff --git a/f5_tts/infer/speech_edit.py b/f5_tts/infer/speech_edit.py index d8d073eadaa2bf14ceba43c04acc7453eb662ee8..40ab554d74790e8fdca98b570ad665c0ac64a1da 100644 --- a/f5_tts/infer/speech_edit.py +++ b/f5_tts/infer/speech_edit.py @@ -1,16 +1,13 @@ import os -os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility - -from importlib.resources import files +os.environ["PYTOCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility import torch import torch.nn.functional as F import torchaudio -from omegaconf import OmegaConf from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram -from f5_tts.model import CFM, DiT, UNetT # noqa: F401. 
used for config +from f5_tts.model import CFM, DiT, UNetT from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer device = ( @@ -24,40 +21,44 @@ device = ( ) +# --------------------- Dataset Settings -------------------- # + +target_sample_rate = 24000 +n_mel_channels = 100 +hop_length = 256 +win_length = 1024 +n_fft = 1024 +mel_spec_type = "vocos" # 'vocos' or 'bigvgan' +target_rms = 0.1 + +tokenizer = "pinyin" +dataset_name = "Emilia_ZH_EN" + + # ---------------------- infer setting ---------------------- # seed = None # int | None -exp_name = "F5TTS_v1_Base" # F5TTS_v1_Base | E2TTS_Base -ckpt_step = 1250000 +exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base +ckpt_step = 1200000 nfe_step = 32 # 16, 32 cfg_strength = 2.0 ode_method = "euler" # euler | midpoint sway_sampling_coef = -1.0 speed = 1.0 -target_rms = 0.1 - - -model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml"))) -model_cls = globals()[model_cfg.model.backbone] -model_arc = model_cfg.model.arch -dataset_name = model_cfg.datasets.name -tokenizer = model_cfg.model.tokenizer +if exp_name == "F5TTS_Base": + model_cls = DiT + model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) -mel_spec_type = model_cfg.model.mel_spec.mel_spec_type -target_sample_rate = model_cfg.model.mel_spec.target_sample_rate -n_mel_channels = model_cfg.model.mel_spec.n_mel_channels -hop_length = model_cfg.model.mel_spec.hop_length -win_length = model_cfg.model.mel_spec.win_length -n_fft = model_cfg.model.mel_spec.n_fft +elif exp_name == "E2TTS_Base": + model_cls = UNetT + model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) - -ckpt_path = str(files("f5_tts").joinpath("../../")) + f"ckpts/{exp_name}/model_{ckpt_step}.safetensors" +ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors" output_dir = "tests" - # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment] # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git # [write the origin_text into a file, e.g. tests/test_edit.txt] @@ -66,7 +67,7 @@ output_dir = "tests" # [--language "zho" for Chinese, "eng" for English] # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"] -audio_to_edit = str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")) +audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_en.wav" origin_text = "Some call me nature, others call me mother nature." target_text = "Some call me optimist, others call me realist." 
parts_to_edit = [ @@ -105,7 +106,7 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer) # Model model = CFM( - transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels), + transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), mel_spec_kwargs=dict( n_fft=n_fft, hop_length=hop_length, diff --git a/f5_tts/infer/utils_infer.py b/f5_tts/infer/utils_infer.py index b2fd72719dbbb511480216538e4999a8735095ba..63b6d61bd6da137a28d46995dbdfd25e561b9109 100644 --- a/f5_tts/infer/utils_infer.py +++ b/f5_tts/infer/utils_infer.py @@ -2,9 +2,8 @@ # Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format import os import sys -from concurrent.futures import ThreadPoolExecutor -os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility +os.environ["PYTOCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/") import hashlib @@ -110,8 +109,13 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev repo_id = "charactr/vocos-mel-24khz" config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml") model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin") + # print("Download Vocos from huggingface charactr/vocos-mel-24khz") + # repo_id = "charactr/vocos-mel-24khz" + # config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml") + # model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin") vocoder = Vocos.from_hparams(config_path) state_dict = torch.load(model_path, map_location="cpu", weights_only=True) + # print(state_dict) from vocos.feature_extractors import EncodecFeatures if isinstance(vocoder.feature_extractor, EncodecFeatures): @@ -301,19 +305,19 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in ) non_silent_wave = AudioSegment.silent(duration=0) for non_silent_seg in non_silent_segs: - if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000: + if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000: show_info("Audio is over 15s, clipping short. (1)") break non_silent_wave += non_silent_seg # 2. try to find short silence for clipping if 1. failed - if len(non_silent_wave) > 12000: + if len(non_silent_wave) > 15000: non_silent_segs = silence.split_on_silence( aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10 ) non_silent_wave = AudioSegment.silent(duration=0) for non_silent_seg in non_silent_segs: - if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000: + if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000: show_info("Audio is over 15s, clipping short. (2)") break non_silent_wave += non_silent_seg @@ -321,8 +325,8 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in aseg = non_silent_wave # 3. if no proper silence found for clipping - if len(aseg) > 12000: - aseg = aseg[:12000] + if len(aseg) > 15000: + aseg = aseg[:15000] show_info("Audio is over 15s, clipping short. 
(3)") aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50) @@ -383,31 +387,29 @@ def infer_process( ): # Split the input text into batches audio, sr = torchaudio.load(ref_audio) - max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr)) + max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr)) gen_text_batches = chunk_text(gen_text, max_chars=max_chars) for i, gen_text in enumerate(gen_text_batches): print(f"gen_text {i}", gen_text) print("\n") show_info(f"Generating audio in {len(gen_text_batches)} batches...") - return next( - infer_batch_process( - (audio, sr), - ref_text, - gen_text_batches, - model_obj, - vocoder, - mel_spec_type=mel_spec_type, - progress=progress, - target_rms=target_rms, - cross_fade_duration=cross_fade_duration, - nfe_step=nfe_step, - cfg_strength=cfg_strength, - sway_sampling_coef=sway_sampling_coef, - speed=speed, - fix_duration=fix_duration, - device=device, - ) + return infer_batch_process( + (audio, sr), + ref_text, + gen_text_batches, + model_obj, + vocoder, + mel_spec_type=mel_spec_type, + progress=progress, + target_rms=target_rms, + cross_fade_duration=cross_fade_duration, + nfe_step=nfe_step, + cfg_strength=cfg_strength, + sway_sampling_coef=sway_sampling_coef, + speed=speed, + fix_duration=fix_duration, + device=device, ) @@ -430,8 +432,6 @@ def infer_batch_process( speed=1, fix_duration=None, device=None, - streaming=False, - chunk_size=2048, ): audio, sr = ref_audio if audio.shape[0] > 1: @@ -450,12 +450,7 @@ def infer_batch_process( if len(ref_text[-1].encode("utf-8")) == 1: ref_text = ref_text + " " - - def process_batch(gen_text): - local_speed = speed - if len(gen_text.encode("utf-8")) < 10: - local_speed = 0.3 - + for i, gen_text in enumerate(progress.tqdm(gen_text_batches)): # Prepare the text text_list = [ref_text + gen_text] final_text_list = convert_char_to_pinyin(text_list) @@ -467,7 +462,7 @@ def infer_batch_process( # Calculate duration ref_text_len = len(ref_text.encode("utf-8")) gen_text_len = len(gen_text.encode("utf-8")) - duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed) + duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed) # inference with torch.inference_mode(): @@ -479,88 +474,64 @@ def infer_batch_process( cfg_strength=cfg_strength, sway_sampling_coef=sway_sampling_coef, ) - del _ - generated = generated.to(torch.float32) # generated mel spectrogram + generated = generated.to(torch.float32) generated = generated[:, ref_audio_len:, :] - generated = generated.permute(0, 2, 1) + generated_mel_spec = generated.permute(0, 2, 1) if mel_spec_type == "vocos": - generated_wave = vocoder.decode(generated) + generated_wave = vocoder.decode(generated_mel_spec) elif mel_spec_type == "bigvgan": - generated_wave = vocoder(generated) + generated_wave = vocoder(generated_mel_spec) if rms < target_rms: generated_wave = generated_wave * rms / target_rms # wav -> numpy generated_wave = generated_wave.squeeze().cpu().numpy() - if streaming: - for j in range(0, len(generated_wave), chunk_size): - yield generated_wave[j : j + chunk_size], target_sample_rate - else: - generated_cpu = generated[0].cpu().numpy() - del generated - yield generated_wave, generated_cpu - - if streaming: - for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches: - for chunk in process_batch(gen_text): - yield chunk + generated_waves.append(generated_wave) + 
spectrograms.append(generated_mel_spec[0].cpu().numpy()) + + # Combine all generated waves with cross-fading + if cross_fade_duration <= 0: + # Simply concatenate + final_wave = np.concatenate(generated_waves) else: - with ThreadPoolExecutor() as executor: - futures = [executor.submit(process_batch, gen_text) for gen_text in gen_text_batches] - for future in progress.tqdm(futures) if progress is not None else futures: - result = future.result() - if result: - generated_wave, generated_mel_spec = next(result) - generated_waves.append(generated_wave) - spectrograms.append(generated_mel_spec) - - if generated_waves: - if cross_fade_duration <= 0: - # Simply concatenate - final_wave = np.concatenate(generated_waves) - else: - # Combine all generated waves with cross-fading - final_wave = generated_waves[0] - for i in range(1, len(generated_waves)): - prev_wave = final_wave - next_wave = generated_waves[i] - - # Calculate cross-fade samples, ensuring it does not exceed wave lengths - cross_fade_samples = int(cross_fade_duration * target_sample_rate) - cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave)) - - if cross_fade_samples <= 0: - # No overlap possible, concatenate - final_wave = np.concatenate([prev_wave, next_wave]) - continue - - # Overlapping parts - prev_overlap = prev_wave[-cross_fade_samples:] - next_overlap = next_wave[:cross_fade_samples] - - # Fade out and fade in - fade_out = np.linspace(1, 0, cross_fade_samples) - fade_in = np.linspace(0, 1, cross_fade_samples) - - # Cross-faded overlap - cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in - - # Combine - new_wave = np.concatenate( - [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]] - ) - - final_wave = new_wave - - # Create a combined spectrogram - combined_spectrogram = np.concatenate(spectrograms, axis=1) - - yield final_wave, target_sample_rate, combined_spectrogram + final_wave = generated_waves[0] + for i in range(1, len(generated_waves)): + prev_wave = final_wave + next_wave = generated_waves[i] + + # Calculate cross-fade samples, ensuring it does not exceed wave lengths + cross_fade_samples = int(cross_fade_duration * target_sample_rate) + cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave)) + + if cross_fade_samples <= 0: + # No overlap possible, concatenate + final_wave = np.concatenate([prev_wave, next_wave]) + continue + + # Overlapping parts + prev_overlap = prev_wave[-cross_fade_samples:] + next_overlap = next_wave[:cross_fade_samples] + + # Fade out and fade in + fade_out = np.linspace(1, 0, cross_fade_samples) + fade_in = np.linspace(0, 1, cross_fade_samples) + + # Cross-faded overlap + cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in + + # Combine + new_wave = np.concatenate( + [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]] + ) - else: - yield None, target_sample_rate, None + final_wave = new_wave + + # Create a combined spectrogram + combined_spectrogram = np.concatenate(spectrograms, axis=1) + + return final_wave, target_sample_rate, combined_spectrogram # remove silence from generated wav diff --git a/f5_tts/model/__pycache__/__init__.cpython-310.pyc b/f5_tts/model/__pycache__/__init__.cpython-310.pyc index 359ae39b5df55f28ad2965f896129b589e8066aa..2da2a47046288e90e75e870e5443b839b6dcd0de 100644 Binary files a/f5_tts/model/__pycache__/__init__.cpython-310.pyc and b/f5_tts/model/__pycache__/__init__.cpython-310.pyc differ diff --git 
a/f5_tts/model/__pycache__/cfm.cpython-310.pyc b/f5_tts/model/__pycache__/cfm.cpython-310.pyc index 18757153295524251e4ff8644fee9042ade65741..12a68b10fa0741ca5ce6e299e5d162ed427615da 100644 Binary files a/f5_tts/model/__pycache__/cfm.cpython-310.pyc and b/f5_tts/model/__pycache__/cfm.cpython-310.pyc differ diff --git a/f5_tts/model/__pycache__/dataset.cpython-310.pyc b/f5_tts/model/__pycache__/dataset.cpython-310.pyc index 9f6fd9a6da4f0c27f25e864b8a017a11dbc1015f..88a37db43a31672e67275d0d3466b70269f030dd 100644 Binary files a/f5_tts/model/__pycache__/dataset.cpython-310.pyc and b/f5_tts/model/__pycache__/dataset.cpython-310.pyc differ diff --git a/f5_tts/model/__pycache__/modules.cpython-310.pyc b/f5_tts/model/__pycache__/modules.cpython-310.pyc index bc2f694f1beb894e990fabd7985049a227687149..9140cf01b44f2cf74b55bffb184b80ed3d1e6add 100644 Binary files a/f5_tts/model/__pycache__/modules.cpython-310.pyc and b/f5_tts/model/__pycache__/modules.cpython-310.pyc differ diff --git a/f5_tts/model/__pycache__/trainer.cpython-310.pyc b/f5_tts/model/__pycache__/trainer.cpython-310.pyc index 6a0aa87041dcaad61c2357b253adc1b043a7ae75..991aa0bcaa7b718000e6536442ab0ceb9e760e07 100644 Binary files a/f5_tts/model/__pycache__/trainer.cpython-310.pyc and b/f5_tts/model/__pycache__/trainer.cpython-310.pyc differ diff --git a/f5_tts/model/__pycache__/utils.cpython-310.pyc b/f5_tts/model/__pycache__/utils.cpython-310.pyc index fb1272d9e04f5cb861b9bdf4b9062d344b620005..1f775f5618f68237083dc4d0ae93cdace30b1e97 100644 Binary files a/f5_tts/model/__pycache__/utils.cpython-310.pyc and b/f5_tts/model/__pycache__/utils.cpython-310.pyc differ diff --git a/f5_tts/model/backbones/README.md b/f5_tts/model/backbones/README.md index 09bd4da5b51d3349b0136cec601d5d3ae9ed92f0..155671e16fbf128a243ece9033cefd47b957af88 100644 --- a/f5_tts/model/backbones/README.md +++ b/f5_tts/model/backbones/README.md @@ -4,7 +4,7 @@ ### unett.py - flat unet transformer - structure same as in e2-tts & voicebox paper except using rotary pos emb -- possible abs pos emb & convnextv2 blocks for embedded text before concat +- update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat ### dit.py - adaln-zero dit @@ -14,7 +14,7 @@ - possible long skip connection (first layer to last layer) ### mmdit.py -- stable diffusion 3 block structure +- sd3 structure - timestep as condition - left stream: text embedded and applied a abs pos emb - right stream: masked_cond & noised_input concatted and with same conv pos emb as unett diff --git a/f5_tts/model/backbones/__pycache__/dit.cpython-310.pyc b/f5_tts/model/backbones/__pycache__/dit.cpython-310.pyc index b5b85894806c493a7186a308e52ae6a9e15090d1..155311aa152d8d8cb6aed71073fc5aff93431714 100644 Binary files a/f5_tts/model/backbones/__pycache__/dit.cpython-310.pyc and b/f5_tts/model/backbones/__pycache__/dit.cpython-310.pyc differ diff --git a/f5_tts/model/backbones/__pycache__/mmdit.cpython-310.pyc b/f5_tts/model/backbones/__pycache__/mmdit.cpython-310.pyc index b2cc7a60bec3d7e34ff41240b5bcea01e41db879..532da527d82af94d5cfdcd0467bb4da46ed4e0aa 100644 Binary files a/f5_tts/model/backbones/__pycache__/mmdit.cpython-310.pyc and b/f5_tts/model/backbones/__pycache__/mmdit.cpython-310.pyc differ diff --git a/f5_tts/model/backbones/__pycache__/unett.cpython-310.pyc b/f5_tts/model/backbones/__pycache__/unett.cpython-310.pyc index 0d828ff8c804ec441c69e55c12c51fc982a2e431..e35d44f77964f62a4d59369485a3b55681eb5a6a 100644 Binary files a/f5_tts/model/backbones/__pycache__/unett.cpython-310.pyc 
and b/f5_tts/model/backbones/__pycache__/unett.cpython-310.pyc differ diff --git a/f5_tts/model/backbones/dit.py b/f5_tts/model/backbones/dit.py index c4625285acb0d2d6e1d9b8a154edce45a654df7d..1ecd10e425041d924ec7361d5cd3b2750da11e09 100644 --- a/f5_tts/model/backbones/dit.py +++ b/f5_tts/model/backbones/dit.py @@ -20,7 +20,7 @@ from f5_tts.model.modules import ( ConvNeXtV2Block, ConvPositionEmbedding, DiTBlock, - AdaLayerNorm_Final, + AdaLayerNormZero_Final, precompute_freqs_cis, get_pos_embed_indices, ) @@ -30,12 +30,10 @@ from f5_tts.model.modules import ( class TextEmbedding(nn.Module): - def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2): + def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2): super().__init__() self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token - self.mask_padding = mask_padding # mask filler and batch padding tokens or not - if conv_layers > 0: self.extra_modeling = True self.precompute_max_pos = 4096 # ~44s of 24khz audio @@ -51,8 +49,6 @@ class TextEmbedding(nn.Module): text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens batch, text_len = text.shape[0], text.shape[1] text = F.pad(text, (0, seq_len - text_len), value=0) - if self.mask_padding: - text_mask = text == 0 if drop_text: # cfg for text text = torch.zeros_like(text) @@ -68,13 +64,7 @@ class TextEmbedding(nn.Module): text = text + text_pos_embed # convnextv2 blocks - if self.mask_padding: - text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) - for block in self.text_blocks: - text = block(text) - text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) - else: - text = self.text_blocks(text) + text = self.text_blocks(text) return text @@ -113,10 +103,7 @@ class DiT(nn.Module): mel_dim=100, text_num_embeds=256, text_dim=None, - text_mask_padding=True, - qk_norm=None, conv_layers=0, - pe_attn_head=None, long_skip_connection=False, checkpoint_activations=False, ): @@ -125,10 +112,7 @@ class DiT(nn.Module): self.time_embed = TimestepEmbedding(dim) if text_dim is None: text_dim = mel_dim - self.text_embed = TextEmbedding( - text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers - ) - self.text_cond, self.text_uncond = None, None # text cache + self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers) self.input_embed = InputEmbedding(mel_dim, text_dim, dim) self.rotary_embed = RotaryEmbedding(dim_head) @@ -137,40 +121,15 @@ class DiT(nn.Module): self.depth = depth self.transformer_blocks = nn.ModuleList( - [ - DiTBlock( - dim=dim, - heads=heads, - dim_head=dim_head, - ff_mult=ff_mult, - dropout=dropout, - qk_norm=qk_norm, - pe_attn_head=pe_attn_head, - ) - for _ in range(depth) - ] + [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)] ) self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None - self.norm_out = AdaLayerNorm_Final(dim) # final modulation + self.norm_out = AdaLayerNormZero_Final(dim) # final modulation self.proj_out = nn.Linear(dim, mel_dim) self.checkpoint_activations = checkpoint_activations - self.initialize_weights() - - def initialize_weights(self): - # Zero-out AdaLN layers in DiT blocks: - for block in self.transformer_blocks: - nn.init.constant_(block.attn_norm.linear.weight, 0) - nn.init.constant_(block.attn_norm.linear.bias, 0) - - # Zero-out output 
layers: - nn.init.constant_(self.norm_out.linear.weight, 0) - nn.init.constant_(self.norm_out.linear.bias, 0) - nn.init.constant_(self.proj_out.weight, 0) - nn.init.constant_(self.proj_out.bias, 0) - def ckpt_wrapper(self, module): # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py def ckpt_forward(*inputs): @@ -179,9 +138,6 @@ class DiT(nn.Module): return ckpt_forward - def clear_cache(self): - self.text_cond, self.text_uncond = None, None - def forward( self, x: float["b n d"], # nosied input audio # noqa: F722 @@ -191,25 +147,14 @@ class DiT(nn.Module): drop_audio_cond, # cfg for cond audio drop_text, # cfg for text mask: bool["b n"] | None = None, # noqa: F722 - cache=False, ): batch, seq_len = x.shape[0], x.shape[1] if time.ndim == 0: time = time.repeat(batch) - # t: conditioning time, text: text, x: noised audio + cond audio + text + # t: conditioning time, c: context (text + masked cond audio), x: noised input audio t = self.time_embed(time) - if cache: - if drop_text: - if self.text_uncond is None: - self.text_uncond = self.text_embed(text, seq_len, drop_text=True) - text_embed = self.text_uncond - else: - if self.text_cond is None: - self.text_cond = self.text_embed(text, seq_len, drop_text=False) - text_embed = self.text_cond - else: - text_embed = self.text_embed(text, seq_len, drop_text=drop_text) + text_embed = self.text_embed(text, seq_len, drop_text=drop_text) x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) rope = self.rotary_embed.forward_from_seq_len(seq_len) diff --git a/f5_tts/model/backbones/mmdit.py b/f5_tts/model/backbones/mmdit.py index d150555430886d64768af5b4a808701be288d01e..64c7ef18e1195631f3917af95ca7c8ac12462bf8 100644 --- a/f5_tts/model/backbones/mmdit.py +++ b/f5_tts/model/backbones/mmdit.py @@ -18,7 +18,7 @@ from f5_tts.model.modules import ( TimestepEmbedding, ConvPositionEmbedding, MMDiTBlock, - AdaLayerNorm_Final, + AdaLayerNormZero_Final, precompute_freqs_cis, get_pos_embed_indices, ) @@ -28,24 +28,18 @@ from f5_tts.model.modules import ( class TextEmbedding(nn.Module): - def __init__(self, out_dim, text_num_embeds, mask_padding=True): + def __init__(self, out_dim, text_num_embeds): super().__init__() self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token - self.mask_padding = mask_padding # mask filler and batch padding tokens or not - self.precompute_max_pos = 1024 self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False) def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722 - text = text + 1 # use 0 as filler token. 
preprocess of batch pad -1, see list_str_to_idx() - if self.mask_padding: - text_mask = text == 0 - - if drop_text: # cfg for text + text = text + 1 + if drop_text: text = torch.zeros_like(text) - - text = self.text_embed(text) # b nt -> b nt d + text = self.text_embed(text) # sinus pos emb batch_start = torch.zeros((text.shape[0],), dtype=torch.long) @@ -55,9 +49,6 @@ class TextEmbedding(nn.Module): text = text + text_pos_embed - if self.mask_padding: - text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) - return text @@ -92,16 +83,13 @@ class MMDiT(nn.Module): dim_head=64, dropout=0.1, ff_mult=4, - mel_dim=100, text_num_embeds=256, - text_mask_padding=True, - qk_norm=None, + mel_dim=100, ): super().__init__() self.time_embed = TimestepEmbedding(dim) - self.text_embed = TextEmbedding(dim, text_num_embeds, mask_padding=text_mask_padding) - self.text_cond, self.text_uncond = None, None # text cache + self.text_embed = TextEmbedding(dim, text_num_embeds) self.audio_embed = AudioEmbedding(mel_dim, dim) self.rotary_embed = RotaryEmbedding(dim_head) @@ -118,33 +106,13 @@ class MMDiT(nn.Module): dropout=dropout, ff_mult=ff_mult, context_pre_only=i == depth - 1, - qk_norm=qk_norm, ) for i in range(depth) ] ) - self.norm_out = AdaLayerNorm_Final(dim) # final modulation + self.norm_out = AdaLayerNormZero_Final(dim) # final modulation self.proj_out = nn.Linear(dim, mel_dim) - self.initialize_weights() - - def initialize_weights(self): - # Zero-out AdaLN layers in MMDiT blocks: - for block in self.transformer_blocks: - nn.init.constant_(block.attn_norm_x.linear.weight, 0) - nn.init.constant_(block.attn_norm_x.linear.bias, 0) - nn.init.constant_(block.attn_norm_c.linear.weight, 0) - nn.init.constant_(block.attn_norm_c.linear.bias, 0) - - # Zero-out output layers: - nn.init.constant_(self.norm_out.linear.weight, 0) - nn.init.constant_(self.norm_out.linear.bias, 0) - nn.init.constant_(self.proj_out.weight, 0) - nn.init.constant_(self.proj_out.bias, 0) - - def clear_cache(self): - self.text_cond, self.text_uncond = None, None - def forward( self, x: float["b n d"], # nosied input audio # noqa: F722 @@ -154,7 +122,6 @@ class MMDiT(nn.Module): drop_audio_cond, # cfg for cond audio drop_text, # cfg for text mask: bool["b n"] | None = None, # noqa: F722 - cache=False, ): batch = x.shape[0] if time.ndim == 0: @@ -162,17 +129,7 @@ class MMDiT(nn.Module): # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio t = self.time_embed(time) - if cache: - if drop_text: - if self.text_uncond is None: - self.text_uncond = self.text_embed(text, drop_text=True) - c = self.text_uncond - else: - if self.text_cond is None: - self.text_cond = self.text_embed(text, drop_text=False) - c = self.text_cond - else: - c = self.text_embed(text, drop_text=drop_text) + c = self.text_embed(text, drop_text=drop_text) x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond) seq_len = x.shape[1] diff --git a/f5_tts/model/backbones/unett.py b/f5_tts/model/backbones/unett.py index 11e4d026544089c6664f3f4a5f00433d48e69ceb..acf649a52448e87a34a2af4bc14051caaba74c86 100644 --- a/f5_tts/model/backbones/unett.py +++ b/f5_tts/model/backbones/unett.py @@ -33,12 +33,10 @@ from f5_tts.model.modules import ( class TextEmbedding(nn.Module): - def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2): + def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2): super().__init__() self.text_embed = nn.Embedding(text_num_embeds + 1, 
text_dim) # use 0 as filler token - self.mask_padding = mask_padding # mask filler and batch padding tokens or not - if conv_layers > 0: self.extra_modeling = True self.precompute_max_pos = 4096 # ~44s of 24khz audio @@ -54,8 +52,6 @@ class TextEmbedding(nn.Module): text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens batch, text_len = text.shape[0], text.shape[1] text = F.pad(text, (0, seq_len - text_len), value=0) - if self.mask_padding: - text_mask = text == 0 if drop_text: # cfg for text text = torch.zeros_like(text) @@ -71,13 +67,7 @@ class TextEmbedding(nn.Module): text = text + text_pos_embed # convnextv2 blocks - if self.mask_padding: - text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) - for block in self.text_blocks: - text = block(text) - text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) - else: - text = self.text_blocks(text) + text = self.text_blocks(text) return text @@ -116,10 +106,7 @@ class UNetT(nn.Module): mel_dim=100, text_num_embeds=256, text_dim=None, - text_mask_padding=True, - qk_norm=None, conv_layers=0, - pe_attn_head=None, skip_connect_type: Literal["add", "concat", "none"] = "concat", ): super().__init__() @@ -128,10 +115,7 @@ class UNetT(nn.Module): self.time_embed = TimestepEmbedding(dim) if text_dim is None: text_dim = mel_dim - self.text_embed = TextEmbedding( - text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers - ) - self.text_cond, self.text_uncond = None, None # text cache + self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers) self.input_embed = InputEmbedding(mel_dim, text_dim, dim) self.rotary_embed = RotaryEmbedding(dim_head) @@ -150,12 +134,11 @@ class UNetT(nn.Module): attn_norm = RMSNorm(dim) attn = Attention( - processor=AttnProcessor(pe_attn_head=pe_attn_head), + processor=AttnProcessor(), dim=dim, heads=heads, dim_head=dim_head, dropout=dropout, - qk_norm=qk_norm, ) ff_norm = RMSNorm(dim) @@ -178,9 +161,6 @@ class UNetT(nn.Module): self.norm_out = RMSNorm(dim) self.proj_out = nn.Linear(dim, mel_dim) - def clear_cache(self): - self.text_cond, self.text_uncond = None, None - def forward( self, x: float["b n d"], # nosied input audio # noqa: F722 @@ -190,7 +170,6 @@ class UNetT(nn.Module): drop_audio_cond, # cfg for cond audio drop_text, # cfg for text mask: bool["b n"] | None = None, # noqa: F722 - cache=False, ): batch, seq_len = x.shape[0], x.shape[1] if time.ndim == 0: @@ -198,17 +177,7 @@ class UNetT(nn.Module): # t: conditioning time, c: context (text + masked cond audio), x: noised input audio t = self.time_embed(time) - if cache: - if drop_text: - if self.text_uncond is None: - self.text_uncond = self.text_embed(text, seq_len, drop_text=True) - text_embed = self.text_uncond - else: - if self.text_cond is None: - self.text_cond = self.text_embed(text, seq_len, drop_text=False) - text_embed = self.text_cond - else: - text_embed = self.text_embed(text, seq_len, drop_text=drop_text) + text_embed = self.text_embed(text, seq_len, drop_text=drop_text) x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) # postfix time t to input x, [b n d] -> [b n+1 d] diff --git a/f5_tts/model/cfm.py b/f5_tts/model/cfm.py index ea4b67f846e2f8992e19444ac8275b905e3a50ae..b0cefc0cc008221eb09c99c5ff9112ae7e9e99c1 100644 --- a/f5_tts/model/cfm.py +++ b/f5_tts/model/cfm.py @@ -162,13 +162,13 @@ class CFM(nn.Module): # predict flow pred = self.transformer( - x=x, cond=step_cond, 
text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, cache=True + x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False ) if cfg_strength < 1e-5: return pred null_pred = self.transformer( - x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True, cache=True + x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True ) return pred + (pred - null_pred) * cfg_strength @@ -195,7 +195,6 @@ class CFM(nn.Module): t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) trajectory = odeint(fn, y0, t, **self.odeint_kwargs) - self.transformer.clear_cache() sampled = trajectory[-1] out = sampled diff --git a/f5_tts/model/dataset.py b/f5_tts/model/dataset.py index b0622aa0c131804ee983f575fa3c07c2d38a68f8..75eedddb513804d64a5b66ebb9c415cb9e72d79e 100644 --- a/f5_tts/model/dataset.py +++ b/f5_tts/model/dataset.py @@ -173,7 +173,7 @@ class DynamicBatchSampler(Sampler[list[int]]): """ def __init__( - self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_residual: bool = False + self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False ): self.sampler = sampler self.frames_threshold = frames_threshold @@ -208,15 +208,12 @@ class DynamicBatchSampler(Sampler[list[int]]): batch = [] batch_frames = 0 - if not drop_residual and len(batch) > 0: + if not drop_last and len(batch) > 0: batches.append(batch) del indices self.batches = batches - # Ensure even batches with accelerate BatchSamplerShard cls under frame_per_batch setting - self.drop_last = True - def set_epoch(self, epoch: int) -> None: """Sets the epoch for this sampler.""" self.epoch = epoch @@ -256,7 +253,7 @@ def load_dataset( print("Loading dataset ...") if dataset_type == "CustomDataset": - rel_data_path = str(files("f5_tts").joinpath(f"../../data/{dataset_name}")) + rel_data_path = str(files("f5_tts").joinpath(f"../../data/{dataset_name}_{tokenizer}")) if audio_type == "raw": try: train_dataset = load_from_disk(f"{rel_data_path}/raw") diff --git a/f5_tts/model/modules.py b/f5_tts/model/modules.py index 8e5c3c27a29882a52156f38e0ecc4553f257ac8f..bf67fffb1dabf456d4cc804380d42358fe0ca79f 100644 --- a/f5_tts/model/modules.py +++ b/f5_tts/model/modules.py @@ -269,36 +269,11 @@ class ConvNeXtV2Block(nn.Module): return residual + x -# RMSNorm - - -class RMSNorm(nn.Module): - def __init__(self, dim: int, eps: float): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - self.native_rms_norm = float(torch.__version__[:3]) >= 2.4 - - def forward(self, x): - if self.native_rms_norm: - if self.weight.dtype in [torch.float16, torch.bfloat16]: - x = x.to(self.weight.dtype) - x = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps) - else: - variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True) - x = x * torch.rsqrt(variance + self.eps) - if self.weight.dtype in [torch.float16, torch.bfloat16]: - x = x.to(self.weight.dtype) - x = x * self.weight - - return x - - -# AdaLayerNorm +# AdaLayerNormZero # return with modulated x for attn input, and params for later mlp modulation -class AdaLayerNorm(nn.Module): +class AdaLayerNormZero(nn.Module): def __init__(self, dim): super().__init__() @@ -315,11 +290,11 @@ class AdaLayerNorm(nn.Module): return x, gate_msa, shift_mlp, scale_mlp, gate_mlp -# AdaLayerNorm for final layer +# AdaLayerNormZero for final layer # return only with modulated x for 
attn input, cuz no more mlp modulation -class AdaLayerNorm_Final(nn.Module): +class AdaLayerNormZero_Final(nn.Module): def __init__(self, dim): super().__init__() @@ -366,8 +341,7 @@ class Attention(nn.Module): dim_head: int = 64, dropout: float = 0.0, context_dim: Optional[int] = None, # if not None -> joint attention - context_pre_only: bool = False, - qk_norm: Optional[str] = None, + context_pre_only=None, ): super().__init__() @@ -388,32 +362,18 @@ class Attention(nn.Module): self.to_k = nn.Linear(dim, self.inner_dim) self.to_v = nn.Linear(dim, self.inner_dim) - if qk_norm is None: - self.q_norm = None - self.k_norm = None - elif qk_norm == "rms_norm": - self.q_norm = RMSNorm(dim_head, eps=1e-6) - self.k_norm = RMSNorm(dim_head, eps=1e-6) - else: - raise ValueError(f"Unimplemented qk_norm: {qk_norm}") - if self.context_dim is not None: - self.to_q_c = nn.Linear(context_dim, self.inner_dim) self.to_k_c = nn.Linear(context_dim, self.inner_dim) self.to_v_c = nn.Linear(context_dim, self.inner_dim) - if qk_norm is None: - self.c_q_norm = None - self.c_k_norm = None - elif qk_norm == "rms_norm": - self.c_q_norm = RMSNorm(dim_head, eps=1e-6) - self.c_k_norm = RMSNorm(dim_head, eps=1e-6) + if self.context_pre_only is not None: + self.to_q_c = nn.Linear(context_dim, self.inner_dim) self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(self.inner_dim, dim)) self.to_out.append(nn.Dropout(dropout)) - if self.context_dim is not None and not self.context_pre_only: - self.to_out_c = nn.Linear(self.inner_dim, context_dim) + if self.context_pre_only is not None and not self.context_pre_only: + self.to_out_c = nn.Linear(self.inner_dim, dim) def forward( self, @@ -433,11 +393,8 @@ class Attention(nn.Module): class AttnProcessor: - def __init__( - self, - pe_attn_head: int | None = None, # number of attention head to apply rope, None for all - ): - self.pe_attn_head = pe_attn_head + def __init__(self): + pass def __call__( self, @@ -448,11 +405,19 @@ class AttnProcessor: ) -> torch.FloatTensor: batch_size = x.shape[0] - # `sample` projections + # `sample` projections. query = attn.to_q(x) key = attn.to_k(x) value = attn.to_v(x) + # apply rotary position embedding + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + # attention inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -460,25 +425,6 @@ class AttnProcessor: key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - # qk norm - if attn.q_norm is not None: - query = attn.q_norm(query) - if attn.k_norm is not None: - key = attn.k_norm(key) - - # apply rotary position embedding - if rope is not None: - freqs, xpos_scale = rope - q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) - - if self.pe_attn_head is not None: - pn = self.pe_attn_head - query[:, :pn, :, :] = apply_rotary_pos_emb(query[:, :pn, :, :], freqs, q_xpos_scale) - key[:, :pn, :, :] = apply_rotary_pos_emb(key[:, :pn, :, :], freqs, k_xpos_scale) - else: - query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) - key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) - # mask. e.g. 
inference got a batch with different target durations, mask out the padding if mask is not None: attn_mask = mask @@ -524,36 +470,16 @@ class JointAttnProcessor: batch_size = c.shape[0] - # `sample` projections + # `sample` projections. query = attn.to_q(x) key = attn.to_k(x) value = attn.to_v(x) - # `context` projections + # `context` projections. c_query = attn.to_q_c(c) c_key = attn.to_k_c(c) c_value = attn.to_v_c(c) - # attention - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - c_query = c_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - c_key = c_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - c_value = c_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # qk norm - if attn.q_norm is not None: - query = attn.q_norm(query) - if attn.k_norm is not None: - key = attn.k_norm(key) - if attn.c_q_norm is not None: - c_query = attn.c_q_norm(c_query) - if attn.c_k_norm is not None: - c_key = attn.c_k_norm(c_key) - # apply rope for context and noised input independently if rope is not None: freqs, xpos_scale = rope @@ -566,10 +492,16 @@ class JointAttnProcessor: c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale) c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale) - # joint attention - query = torch.cat([query, c_query], dim=2) - key = torch.cat([key, c_key], dim=2) - value = torch.cat([value, c_value], dim=2) + # attention + query = torch.cat([query, c_query], dim=1) + key = torch.cat([key, c_key], dim=1) + value = torch.cat([value, c_value], dim=1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # mask. e.g. 
inference got a batch with different target durations, mask out the padding if mask is not None: @@ -608,17 +540,16 @@ class JointAttnProcessor: class DiTBlock(nn.Module): - def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None): + def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1): super().__init__() - self.attn_norm = AdaLayerNorm(dim) + self.attn_norm = AdaLayerNormZero(dim) self.attn = Attention( - processor=AttnProcessor(pe_attn_head=pe_attn_head), + processor=AttnProcessor(), dim=dim, heads=heads, dim_head=dim_head, dropout=dropout, - qk_norm=qk_norm, ) self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) @@ -654,30 +585,26 @@ class MMDiTBlock(nn.Module): context_pre_only: last layer only do prenorm + modulation cuz no more ffn """ - def __init__( - self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_dim=None, context_pre_only=False, qk_norm=None - ): + def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False): super().__init__() - if context_dim is None: - context_dim = dim + self.context_pre_only = context_pre_only - self.attn_norm_c = AdaLayerNorm_Final(context_dim) if context_pre_only else AdaLayerNorm(context_dim) - self.attn_norm_x = AdaLayerNorm(dim) + self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim) + self.attn_norm_x = AdaLayerNormZero(dim) self.attn = Attention( processor=JointAttnProcessor(), dim=dim, heads=heads, dim_head=dim_head, dropout=dropout, - context_dim=context_dim, + context_dim=dim, context_pre_only=context_pre_only, - qk_norm=qk_norm, ) if not context_pre_only: - self.ff_norm_c = nn.LayerNorm(context_dim, elementwise_affine=False, eps=1e-6) - self.ff_c = FeedForward(dim=context_dim, mult=ff_mult, dropout=dropout, approximate="tanh") + self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") else: self.ff_norm_c = None self.ff_c = None diff --git a/f5_tts/model/trainer.py b/f5_tts/model/trainer.py index d9ab4a8c76de7ac7c89ea29e43ef2b79f0da1862..26970a32d047bba9fb158053ea2628971da3005b 100644 --- a/f5_tts/model/trainer.py +++ b/f5_tts/model/trainer.py @@ -32,7 +32,7 @@ class Trainer: save_per_updates=1000, keep_last_n_checkpoints: int = -1, # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints checkpoint_path=None, - batch_size_per_gpu=32, + batch_size=32, batch_size_type: str = "sample", max_samples=32, grad_accumulation_steps=1, @@ -40,7 +40,7 @@ class Trainer: noise_scheduler: str | None = None, duration_predictor: torch.nn.Module | None = None, logger: str | None = "wandb", # "wandb" | "tensorboard" | None - wandb_project="test_f5-tts", + wandb_project="test_e2-tts", wandb_run_name="test_run", wandb_resume_id: str = None, log_samples: bool = False, @@ -51,7 +51,6 @@ class Trainer: mel_spec_type: str = "vocos", # "vocos" | "bigvgan" is_local_vocoder: bool = False, # use local path vocoder local_vocoder_path: str = "", # local vocoder path - cfg_dict: dict = dict(), # training config ): ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) @@ -73,23 +72,21 @@ class Trainer: else: init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}} - if not cfg_dict: - cfg_dict = { + self.accelerator.init_trackers( + project_name=wandb_project, + init_kwargs=init_kwargs, + config={ "epochs": epochs, "learning_rate": learning_rate, "num_warmup_updates": 
num_warmup_updates, - "batch_size_per_gpu": batch_size_per_gpu, + "batch_size": batch_size, "batch_size_type": batch_size_type, "max_samples": max_samples, "grad_accumulation_steps": grad_accumulation_steps, "max_grad_norm": max_grad_norm, + "gpus": self.accelerator.num_processes, "noise_scheduler": noise_scheduler, - } - cfg_dict["gpus"] = self.accelerator.num_processes - self.accelerator.init_trackers( - project_name=wandb_project, - init_kwargs=init_kwargs, - config=cfg_dict, + }, ) elif self.logger == "tensorboard": @@ -114,9 +111,9 @@ class Trainer: self.save_per_updates = save_per_updates self.keep_last_n_checkpoints = keep_last_n_checkpoints self.last_per_updates = default(last_per_updates, save_per_updates) - self.checkpoint_path = default(checkpoint_path, "ckpts/test_f5-tts") + self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts") - self.batch_size_per_gpu = batch_size_per_gpu + self.batch_size = batch_size self.batch_size_type = batch_size_type self.max_samples = max_samples self.grad_accumulation_steps = grad_accumulation_steps @@ -182,7 +179,7 @@ class Trainer: if ( not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) - or not any(filename.endswith((".pt", ".safetensors")) for filename in os.listdir(self.checkpoint_path)) + or not any(filename.endswith(".pt") for filename in os.listdir(self.checkpoint_path)) ): return 0 @@ -194,7 +191,7 @@ class Trainer: all_checkpoints = [ f for f in os.listdir(self.checkpoint_path) - if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith((".pt", ".safetensors")) + if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith(".pt") ] # First try to find regular training checkpoints @@ -208,16 +205,8 @@ class Trainer: # If no training checkpoints, use pretrained model latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_")) - if latest_checkpoint.endswith(".safetensors"): # always a pretrained checkpoint - from safetensors.torch import load_file - - checkpoint = load_file(f"{self.checkpoint_path}/{latest_checkpoint}", device="cpu") - checkpoint = {"ema_model_state_dict": checkpoint} - elif latest_checkpoint.endswith(".pt"): - # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ - checkpoint = torch.load( - f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu" - ) + # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ + checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu") # patch for backward compatibility, 305e3ea for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]: @@ -282,7 +271,7 @@ class Trainer: num_workers=num_workers, pin_memory=True, persistent_workers=True, - batch_size=self.batch_size_per_gpu, + batch_size=self.batch_size, shuffle=True, generator=generator, ) @@ -291,10 +280,10 @@ class Trainer: sampler = SequentialSampler(train_dataset) batch_sampler = DynamicBatchSampler( sampler, - self.batch_size_per_gpu, + self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, # This enables reproducible shuffling - drop_residual=False, + drop_last=False, ) train_dataloader = DataLoader( train_dataset, diff --git a/f5_tts/model/utils.py b/f5_tts/model/utils.py index 
e8811274dc25245703e78048314a98b69afdcedd..040f965628ffb72f2d7d3c665bdf322e6ce4bed5 100644 --- a/f5_tts/model/utils.py +++ b/f5_tts/model/utils.py @@ -109,7 +109,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"): - if use "byte", set to 256 (unicode byte range) """ if tokenizer in ["pinyin", "char"]: - tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}/vocab.txt") + tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt") with open(tokenizer_path, "r", encoding="utf-8") as f: vocab_char_map = {} for i, char in enumerate(f): @@ -133,12 +133,22 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"): # convert char to pinyin +jieba.initialize() +print("Word segmentation module jieba initialized.\n") -def convert_char_to_pinyin(text_list, polyphone=True): - if jieba.dt.initialized is False: - jieba.default_logger.setLevel(50) # CRITICAL - jieba.initialize() +# def convert_char_to_pinyin(text_list, polyphone=True): +# final_text_list = [] +# for text in text_list: +# char_list = [char for char in text if char not in "。,、;:?!《》【】—…:;\"()[]{}"] +# final_text_list.append(char_list) +# # print(final_text_list) +# return final_text_list + +# def convert_char_to_pinyin(text_list, polyphone=True): +# final_text_list = [char for char in text_list if char not in "。,、;:?!《》【】—…:;?!\"()[]{}"] +# return final_text_list +def convert_char_to_pinyin(text_list, polyphone=True): final_text_list = [] custom_trans = str.maketrans( {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"} @@ -174,13 +184,11 @@ def convert_char_to_pinyin(text_list, polyphone=True): else: char_list.append(c) final_text_list.append(char_list) - + # print(final_text_list) return final_text_list - # filter func for dirty data with many repetitions - def repetition_found(text, length=2, tolerance=10): pattern_count = defaultdict(int) for i in range(len(text) - length + 1): diff --git a/f5_tts/scripts/count_max_epoch.py b/f5_tts/scripts/count_max_epoch.py index fe291e52f636aea1b61eabca6d3279a33e664c94..18d36df332e8b93bd2760e4f8d9f4b2354740286 100644 --- a/f5_tts/scripts/count_max_epoch.py +++ b/f5_tts/scripts/count_max_epoch.py @@ -9,7 +9,7 @@ mel_hop_length = 256 mel_sampling_rate = 24000 # target -wanted_max_updates = 1200000 +wanted_max_updates = 1000000 # train params gpus = 8 diff --git a/f5_tts/socket_client.py b/f5_tts/socket_client.py deleted file mode 100644 index 4cad5e7178eec1758b7d999c64842ee99e410971..0000000000000000000000000000000000000000 --- a/f5_tts/socket_client.py +++ /dev/null @@ -1,61 +0,0 @@ -import socket -import asyncio -import pyaudio -import numpy as np -import logging -import time - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998): - client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - await asyncio.get_event_loop().run_in_executor(None, client_socket.connect, (server_ip, int(server_port))) - - start_time = time.time() - first_chunk_time = None - - async def play_audio_stream(): - nonlocal first_chunk_time - p = pyaudio.PyAudio() - stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048) - - try: - while True: - data = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 8192) - if not data: - break - if data == b"END": - logger.info("End of audio received.") - break - - audio_array = np.frombuffer(data, dtype=np.float32) - 
stream.write(audio_array.tobytes()) - - if first_chunk_time is None: - first_chunk_time = time.time() - - finally: - stream.stop_stream() - stream.close() - p.terminate() - - logger.info(f"Total time taken: {time.time() - start_time:.4f} seconds") - - try: - data_to_send = f"{text}".encode("utf-8") - await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, data_to_send) - await play_audio_stream() - - except Exception as e: - logger.error(f"Error in listen_to_F5TTS: {e}") - - finally: - client_socket.close() - - -if __name__ == "__main__": - text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency. Let's break down the components" - - asyncio.run(listen_to_F5TTS(text_to_send)) diff --git a/f5_tts/socket_server.py b/f5_tts/socket_server.py index 344b1d7ab3a3bcaf5a04080252becc39d2bd6fb9..fc5dab4dfcf9d5af55f4813db32626657b506381 100644 --- a/f5_tts/socket_server.py +++ b/f5_tts/socket_server.py @@ -1,75 +1,21 @@ import argparse import gc -import logging -import numpy as np -import queue import socket import struct -import threading +import torch +import torchaudio import traceback -import wave from importlib.resources import files +from threading import Thread -import torch -import torchaudio -from huggingface_hub import hf_hub_download -from omegaconf import OmegaConf - -from f5_tts.model.backbones.dit import DiT # noqa: F401. used for config -from f5_tts.infer.utils_infer import ( - chunk_text, - preprocess_ref_audio_text, - load_vocoder, - load_model, - infer_batch_process, -) - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -class AudioFileWriterThread(threading.Thread): - """Threaded file writer to avoid blocking the TTS streaming process.""" - - def __init__(self, output_file, sampling_rate): - super().__init__() - self.output_file = output_file - self.sampling_rate = sampling_rate - self.queue = queue.Queue() - self.stop_event = threading.Event() - self.audio_data = [] - - def run(self): - """Process queued audio data and write it to a file.""" - logger.info("AudioFileWriterThread started.") - with wave.open(self.output_file, "wb") as wf: - wf.setnchannels(1) - wf.setsampwidth(2) - wf.setframerate(self.sampling_rate) - - while not self.stop_event.is_set() or not self.queue.empty(): - try: - chunk = self.queue.get(timeout=0.1) - if chunk is not None: - chunk = np.int16(chunk * 32767) - self.audio_data.append(chunk) - wf.writeframes(chunk.tobytes()) - except queue.Empty: - continue - - def add_chunk(self, chunk): - """Add a new chunk to the queue.""" - self.queue.put(chunk) - - def stop(self): - """Stop writing and ensure all queued data is written.""" - self.stop_event.set() - self.join() - logger.info("Audio writing completed.") +from cached_path import cached_path + +from infer.utils_infer import infer_batch_process, preprocess_ref_audio_text, load_vocoder, load_model +from model.backbones.dit import DiT class TTSStreamingProcessor: - def __init__(self, model, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32): + def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32): self.device = device or ( "cuda" if torch.cuda.is_available() @@ -79,135 +25,124 @@ class TTSStreamingProcessor: if torch.backends.mps.is_available() else "cpu" ) - model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml"))) - self.model_cls = 
globals()[model_cfg.model.backbone] - self.model_arc = model_cfg.model.arch - self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type - self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate - - self.model = self.load_ema_model(ckpt_file, vocab_file, dtype) - self.vocoder = self.load_vocoder_model() - - self.update_reference(ref_audio, ref_text) - self._warm_up() - self.file_writer_thread = None - self.first_package = True - def load_ema_model(self, ckpt_file, vocab_file, dtype): - return load_model( - self.model_cls, - self.model_arc, + # Load the model using the provided checkpoint and vocab files + self.model = load_model( + model_cls=DiT, + model_cfg=dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4), ckpt_path=ckpt_file, - mel_spec_type=self.mel_spec_type, + mel_spec_type="vocos", # or "bigvgan" depending on vocoder vocab_file=vocab_file, ode_method="euler", use_ema=True, device=self.device, ).to(self.device, dtype=dtype) - def load_vocoder_model(self): - return load_vocoder(vocoder_name=self.mel_spec_type, is_local=False, local_path=None, device=self.device) + # Load the vocoder + self.vocoder = load_vocoder(is_local=False) - def update_reference(self, ref_audio, ref_text): - self.ref_audio, self.ref_text = preprocess_ref_audio_text(ref_audio, ref_text) - self.audio, self.sr = torchaudio.load(self.ref_audio) + # Set sampling rate for streaming + self.sampling_rate = 24000 # Consistency with client - ref_audio_duration = self.audio.shape[-1] / self.sr - ref_text_byte_len = len(self.ref_text.encode("utf-8")) - self.max_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration)) - self.few_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) / 2) - self.min_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) / 4) + # Set reference audio and text + self.ref_audio = ref_audio + self.ref_text = ref_text + + # Warm up the model + self._warm_up() def _warm_up(self): - logger.info("Warming up the model...") + """Warm up the model with a dummy input to ensure it's ready for real-time processing.""" + print("Warming up the model...") + ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text) + audio, sr = torchaudio.load(ref_audio) gen_text = "Warm-up text for the model." 
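        # [hedged sketch, not part of the original file: with the non-streaming utils_infer in this
        #  patch, infer_batch_process returns a complete (waveform, sample_rate, spectrogram) tuple,
        #  so the warm-up below can simply discard the result, e.g.
        #      _wave, _sr, _spect = infer_batch_process(
        #          (audio, sr), ref_text, [gen_text], self.model, self.vocoder, device=self.device
        #      )
        #  one such pass loads the weights onto the device and initializes the vocoder, so the first
        #  client request is served without model start-up latency.]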
- for _ in infer_batch_process( - (self.audio, self.sr), - self.ref_text, - [gen_text], - self.model, - self.vocoder, - progress=None, - device=self.device, - streaming=True, - ): - pass - logger.info("Warm-up completed.") - - def generate_stream(self, text, conn): - text_batches = chunk_text(text, max_chars=self.max_chars) - if self.first_package: - text_batches = chunk_text(text_batches[0], max_chars=self.few_chars) + text_batches[1:] - text_batches = chunk_text(text_batches[0], max_chars=self.min_chars) + text_batches[1:] - self.first_package = False - - audio_stream = infer_batch_process( - (self.audio, self.sr), - self.ref_text, - text_batches, + + # Pass the vocoder as an argument here + infer_batch_process((audio, sr), ref_text, [gen_text], self.model, self.vocoder, device=self.device) + print("Warm-up completed.") + + def generate_stream(self, text, play_steps_in_s=0.5): + """Generate audio in chunks and yield them in real-time.""" + # Preprocess the reference audio and text + ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text) + + # Load reference audio + audio, sr = torchaudio.load(ref_audio) + + # Run inference for the input text + audio_chunk, final_sample_rate, _ = infer_batch_process( + (audio, sr), + ref_text, + [text], self.model, self.vocoder, - progress=None, - device=self.device, - streaming=True, - chunk_size=2048, + device=self.device, # Pass vocoder here ) - # Reset the file writer thread - if self.file_writer_thread is not None: - self.file_writer_thread.stop() - self.file_writer_thread = AudioFileWriterThread("output.wav", self.sampling_rate) - self.file_writer_thread.start() - - for audio_chunk, _ in audio_stream: - if len(audio_chunk) > 0: - logger.info(f"Generated audio chunk of size: {len(audio_chunk)}") + # Break the generated audio into chunks and send them + chunk_size = int(final_sample_rate * play_steps_in_s) - # Send audio chunk via socket - conn.sendall(struct.pack(f"{len(audio_chunk)}f", *audio_chunk)) + if len(audio_chunk) < chunk_size: + packed_audio = struct.pack(f"{len(audio_chunk)}f", *audio_chunk) + yield packed_audio + return - # Write to file asynchronously - self.file_writer_thread.add_chunk(audio_chunk) + for i in range(0, len(audio_chunk), chunk_size): + chunk = audio_chunk[i : i + chunk_size] - logger.info("Finished sending audio stream.") - conn.sendall(b"END") # Send end signal + # Check if it's the final chunk + if i + chunk_size >= len(audio_chunk): + chunk = audio_chunk[i:] - # Ensure all audio data is written before exiting - self.file_writer_thread.stop() + # Send the chunk if it is not empty + if len(chunk) > 0: + packed_audio = struct.pack(f"{len(chunk)}f", *chunk) + yield packed_audio -def handle_client(conn, processor): +def handle_client(client_socket, processor): try: - with conn: - conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - while True: - data = conn.recv(1024) - if not data: - processor.first_package = True - break - data_str = data.decode("utf-8").strip() - logger.info(f"Received text: {data_str}") - - try: - processor.generate_stream(data_str, conn) - except Exception as inner_e: - logger.error(f"Error during processing: {inner_e}") - traceback.print_exc() - break + while True: + # Receive data from the client + data = client_socket.recv(1024).decode("utf-8") + if not data: + break + + try: + # The client sends the text input + text = data.strip() + + # Generate and stream audio chunks + for audio_chunk in processor.generate_stream(text): + client_socket.sendall(audio_chunk) + + # 
Send end-of-audio signal + client_socket.sendall(b"END_OF_AUDIO") + + except Exception as inner_e: + print(f"Error during processing: {inner_e}") + traceback.print_exc() # Print the full traceback to diagnose the issue + break + except Exception as e: - logger.error(f"Error handling client: {e}") + print(f"Error handling client: {e}") traceback.print_exc() + finally: + client_socket.close() def start_server(host, port, processor): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind((host, port)) - s.listen() - logger.info(f"Server started on {host}:{port}") - while True: - conn, addr = s.accept() - logger.info(f"Connected by {addr}") - handle_client(conn, processor) + server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server.bind((host, port)) + server.listen(5) + print(f"Server listening on {host}:{port}") + + while True: + client_socket, addr = server.accept() + print(f"Accepted connection from {addr}") + client_handler = Thread(target=handle_client, args=(client_socket, processor)) + client_handler.start() if __name__ == "__main__": @@ -216,14 +151,9 @@ if __name__ == "__main__": parser.add_argument("--host", default="0.0.0.0") parser.add_argument("--port", default=9998) - parser.add_argument( - "--model", - default="F5TTS_v1_Base", - help="The model name, e.g. F5TTS_v1_Base", - ) parser.add_argument( "--ckpt_file", - default=str(hf_hub_download(repo_id="SWivid/F5-TTS", filename="F5TTS_v1_Base/model_1250000.safetensors")), + default=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors")), help="Path to the model checkpoint file", ) parser.add_argument( @@ -251,7 +181,6 @@ if __name__ == "__main__": try: # Initialize the processor with the model and vocoder processor = TTSStreamingProcessor( - model=args.model, ckpt_file=args.ckpt_file, vocab_file=args.vocab_file, ref_audio=args.ref_audio, diff --git a/f5_tts/train/README.md b/f5_tts/train/README.md index 25d2380b1388dbd4344fcc76ed3caef843e3f917..a57577ff258f1d79dd06604f0f6c33753bcda26a 100644 --- a/f5_tts/train/README.md +++ b/f5_tts/train/README.md @@ -40,10 +40,10 @@ Once your datasets are prepared, you can start the training process. accelerate config # .yaml files are under src/f5_tts/configs directory -accelerate launch src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml +accelerate launch src/f5_tts/train/train.py --config-name F5TTS_Base_train.yaml # possible to overwrite accelerate and hydra config -accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml ++datasets.batch_size_per_gpu=19200 +accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_Small_train.yaml ++datasets.batch_size_per_gpu=19200 ``` ### 2. Finetuning practice @@ -53,7 +53,7 @@ Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py` see [#1 The `use_ema = True` is harmful for early-stage finetuned checkpoints (which goes just few updates, thus ema weights still dominated by pretrained ones), try turn it off and see if provide better results. -### 3. W&B Logging +### 3. Wandb Logging The `wandb/` dir will be created under path you run training/finetuning scripts. @@ -62,7 +62,7 @@ By default, the training script does NOT use logging (assuming you didn't manual To turn on wandb logging, you can either: 1. Manually login with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login) -2. 
Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/authorize and set the environment variable as follows: +2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/site/ and set the environment variable as follows: On Mac & Linux: @@ -75,7 +75,7 @@ On Windows: ``` set WANDB_API_KEY= ``` -Moreover, if you couldn't access W&B and want to log metrics offline, you can set the environment variable as follows: +Moreover, if you couldn't access Wandb and want to log metrics offline, you can the environment variable as follows: ``` export WANDB_MODE=offline diff --git a/f5_tts/train/__pycache__/finetune_gradio.cpython-310.pyc b/f5_tts/train/__pycache__/finetune_gradio.cpython-310.pyc index 99749fa9f4e5c0d6410e72267479385a281a7b30..99c30bc05f3e1003c701b30cd712ec6b6bb30d4e 100644 Binary files a/f5_tts/train/__pycache__/finetune_gradio.cpython-310.pyc and b/f5_tts/train/__pycache__/finetune_gradio.cpython-310.pyc differ diff --git a/f5_tts/train/datasets/prepare_csv_wavs.py b/f5_tts/train/datasets/prepare_csv_wavs.py index 14794d6268e959b9246861609574760a1f2f9166..323a143a29ee4279be193276d0df0b8d248f2e9d 100644 --- a/f5_tts/train/datasets/prepare_csv_wavs.py +++ b/f5_tts/train/datasets/prepare_csv_wavs.py @@ -24,7 +24,7 @@ from f5_tts.model.utils import ( ) -PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/your_training_dataset/vocab.txt") +PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt") def is_csv_wavs_format(input_dataset_dir): @@ -224,7 +224,7 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine voca_out_path = out_dir / "vocab.txt" if is_finetune: file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix() - # shutil.copy2(file_vocab_finetune, voca_out_path) # Không cần copy lại vocab, do đã thực hiện ở bước chuẩn bị dữ liệu + shutil.copy2(file_vocab_finetune, voca_out_path) else: with open(voca_out_path.as_posix(), "w") as f: for vocab in sorted(text_vocab_set): diff --git a/f5_tts/train/datasets/prepare_emilia.py b/f5_tts/train/datasets/prepare_emilia.py index d9b276afa68d671cee69f45cc16d2b12cd0859a4..581ffa6882d632274b312669b27bdae68729504e 100644 --- a/f5_tts/train/datasets/prepare_emilia.py +++ b/f5_tts/train/datasets/prepare_emilia.py @@ -206,14 +206,14 @@ def main(): if __name__ == "__main__": - max_workers = 32 + max_workers = 16 tokenizer = "pinyin" # "pinyin" | "char" polyphone = True - langs = ["ZH", "EN"] - dataset_dir = "/Emilia_Dataset/raw" - dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}" + langs = ["EN"] + dataset_dir = "data/datasetVN" + dataset_name = f"vnTTS_{'_'.join(langs)}_{tokenizer}" save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}" print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n") diff --git a/f5_tts/train/datasets/prepare_libritts.py b/f5_tts/train/datasets/prepare_libritts.py index 2a35dd97980154500be715b41a41d6acae15361f..ad002a996dd5ac9d6b655d7763d6daaa0fb92ddc 100644 --- a/f5_tts/train/datasets/prepare_libritts.py +++ b/f5_tts/train/datasets/prepare_libritts.py @@ -11,6 +11,11 @@ from tqdm import tqdm import soundfile as sf from datasets.arrow_writer import ArrowWriter +from f5_tts.model.utils import ( + repetition_found, + convert_char_to_pinyin, +) + def deal_with_audio_dir(audio_dir): sub_result, durations = [], [] @@ -18,7 +23,7 @@ def deal_with_audio_dir(audio_dir): audio_lists = list(audio_dir.rglob("*.wav")) for line in 
audio_lists: - text_path = line.with_suffix(".normalized.txt") + text_path = line.with_suffix(".lab") text = open(text_path, "r").read().strip() duration = sf.info(line).duration if duration < 0.4 or duration > 30: @@ -76,13 +81,13 @@ def main(): if __name__ == "__main__": - max_workers = 36 + max_workers = 16 tokenizer = "char" # "pinyin" | "char" - SUB_SET = ["train-clean-100", "train-clean-360", "train-other-500"] - dataset_dir = "/LibriTTS" - dataset_name = f"LibriTTS_{'_'.join(SUB_SET)}_{tokenizer}".replace("train-clean-", "").replace("train-other-", "") + SUB_SET = ["mc"] + dataset_dir = "data/datasetVN" + dataset_name = f"vnTTS_{'_'.join(SUB_SET)}_{tokenizer}" save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}" print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n") main() diff --git a/f5_tts/train/datasets/prepare_metadata.py b/f5_tts/train/datasets/prepare_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..7b033ec622437904b72b26bd5827fcaa3521b954 --- /dev/null +++ b/f5_tts/train/datasets/prepare_metadata.py @@ -0,0 +1,12 @@ +import glob +from tqdm import tqdm + +wavs_path = glob.glob("data/datasetVN/mc/mc1/*.wav") + +with open("data/vnTTS__char/metadata.csv", "w", encoding="utf8") as fw: + fw.write("audio_file|text\n") + for wav_path in tqdm(wavs_path): + wav_name = wav_path.split("/")[-1] + with open(wav_path.replace(".wav", ".lab"), "r", encoding="utf8") as fr: + text = fr.readlines()[0].replace("\n", "") + fw.write("wavs/" + wav_name + "|" + text + "\n") \ No newline at end of file diff --git a/f5_tts/train/finetune_cli.py b/f5_tts/train/finetune_cli.py index 3785179d418814b9349ae03b3335d42eb39765a0..463194dd873afdb9f4f877ce06203f4792a1de6c 100644 --- a/f5_tts/train/finetune_cli.py +++ b/f5_tts/train/finetune_cli.py @@ -1,13 +1,12 @@ import argparse import os import shutil -from importlib.resources import files from cached_path import cached_path - from f5_tts.model import CFM, UNetT, DiT, Trainer from f5_tts.model.utils import get_tokenizer from f5_tts.model.dataset import load_dataset +from importlib.resources import files # -------------------------- Dataset Settings --------------------------- # @@ -21,16 +20,21 @@ mel_spec_type = "vocos" # 'vocos' or 'bigvgan' # -------------------------- Argument Parsing --------------------------- # def parse_args(): + # batch_size_per_gpu = 1000 setting for 8GB GPU + # batch_size_per_gpu = 1600 setting for 12GB GPU + # batch_size_per_gpu = 2000 setting for 16GB GPU + # batch_size_per_gpu = 3200 setting for 24GB GPU + + # num_warmup_updates = 300 for 5000 samples (about 10 hours) + + # adjust save_per_updates and last_per_updates to whatever values you need + parser = argparse.ArgumentParser(description="Train CFM Model") parser.add_argument( - "--exp_name", - type=str, - default="F5TTS_v1_Base", - choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], - help="Experiment name", + "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name" ) - parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use") + parser.add_argument("--dataset_name", type=str, default="vnTTS_mc", help="Name of the dataset to use") parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training") parser.add_argument("--batch_size_per_gpu", type=int, default=3200, help="Batch size per GPU") parser.add_argument( @@ -39,7 +43,7 @@ def parse_args(): 
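The new prepare_metadata.py above assumes every .wav under data/datasetVN/mc/mc1 has a sibling .lab transcript, and prepare_libritts.py now drops clips outside the 0.4-30 s range. Before running the finetune CLI below, a minimal pre-flight check along those lines can catch missing transcripts early; this is only a sketch using the folder path hard-coded above, everything else is an assumption and not part of this patch:

```python
from pathlib import Path

import soundfile as sf

# Hypothetical sanity check; mirrors the layout assumed by prepare_metadata.py above.
dataset_root = Path("data/datasetVN/mc/mc1")

missing_labs, out_of_range = [], []
for wav_path in sorted(dataset_root.glob("*.wav")):
    lab_path = wav_path.with_suffix(".lab")
    if not lab_path.is_file():
        missing_labs.append(wav_path.name)
        continue
    duration = sf.info(str(wav_path)).duration
    if duration < 0.4 or duration > 30:  # same bounds as the prepare_libritts.py filter above
        out_of_range.append((wav_path.name, round(duration, 2)))

print(f"wavs without a .lab transcript: {len(missing_labs)}")
print(f"wavs outside 0.4-30 s: {len(out_of_range)}")
```

Running something like this before prepare_metadata.py avoids writing metadata rows whose transcript file is missing.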
parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch") parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps") parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping") - parser.add_argument("--epochs", type=int, default=1000, help="Number of training epochs") + parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs") parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup updates") parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X updates") parser.add_argument( @@ -50,7 +54,7 @@ def parse_args(): ) parser.add_argument("--last_per_updates", type=int, default=50000, help="Save last checkpoint every X updates") parser.add_argument("--finetune", action="store_true", help="Use Finetune") - parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint") + parser.add_argument("--pretrain", type=str, default="/mnt/d/ckpts/vn_tts_mc_vlog/pretrained_model_1200000.pt", help="the path to the checkpoint") parser.add_argument( "--tokenizer", type=str, default="char", choices=["pinyin", "char", "custom"], help="Tokenizer type" ) @@ -84,54 +88,19 @@ def main(): checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}")) # Model parameters based on experiment name - - if args.exp_name == "F5TTS_v1_Base": + if args.exp_name == "F5TTS_Base": wandb_resume_id = None model_cls = DiT - model_cfg = dict( - dim=1024, - depth=22, - heads=16, - ff_mult=2, - text_dim=512, - conv_layers=4, - ) - if args.finetune: - if args.pretrain is None: - ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors")) - else: - ckpt_path = args.pretrain - - elif args.exp_name == "F5TTS_Base": - wandb_resume_id = None - model_cls = DiT - model_cfg = dict( - dim=1024, - depth=22, - heads=16, - ff_mult=2, - text_dim=512, - text_mask_padding=False, - conv_layers=4, - pe_attn_head=1, - ) + model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) if args.finetune: if args.pretrain is None: ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt")) else: ckpt_path = args.pretrain - elif args.exp_name == "E2TTS_Base": wandb_resume_id = None model_cls = UNetT - model_cfg = dict( - dim=1024, - depth=24, - heads=16, - ff_mult=4, - text_mask_padding=False, - pe_attn_head=1, - ) + model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) if args.finetune: if args.pretrain is None: ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt")) @@ -149,10 +118,8 @@ def main(): if not os.path.isfile(file_checkpoint): shutil.copy2(ckpt_path, file_checkpoint) print("copy checkpoint for finetune") - print("Pretrained checkpoint được sử dụng: " + file_checkpoint) # Use the tokenizer and tokenizer_path provided in the command line arguments - tokenizer = args.tokenizer if tokenizer == "custom": if not args.tokenizer_path: @@ -163,8 +130,8 @@ def main(): vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) - print("vocab : ", vocab_size) - print("vocoder : ", mel_spec_type) + print("\nvocab : ", vocab_size) + print("\nvocoder : ", mel_spec_type) mel_spec_kwargs = dict( n_fft=n_fft, @@ -189,7 +156,7 @@ def main(): save_per_updates=args.save_per_updates, keep_last_n_checkpoints=args.keep_last_n_checkpoints, checkpoint_path=checkpoint_path, - 
batch_size_per_gpu=args.batch_size_per_gpu, + batch_size=args.batch_size_per_gpu, batch_size_type=args.batch_size_type, max_samples=args.max_samples, grad_accumulation_steps=args.grad_accumulation_steps, diff --git a/f5_tts/train/finetune_gradio.py b/f5_tts/train/finetune_gradio.py index 578c93104b930a1de97cfdd77b1193c66eceaccd..6c5f465fb46099cd71f5aef78f3d285be88e190b 100644 --- a/f5_tts/train/finetune_gradio.py +++ b/f5_tts/train/finetune_gradio.py @@ -1,36 +1,36 @@ +import threading +import queue +import re + import gc import json -import numpy as np import os import platform import psutil -import queue import random -import re import signal import shutil import subprocess import sys import tempfile -import threading import time from glob import glob -from importlib.resources import files -from scipy.io import wavfile import click import gradio as gr import librosa +import numpy as np import torch import torchaudio -from cached_path import cached_path from datasets import Dataset as Dataset_ from datasets.arrow_writer import ArrowWriter -from safetensors.torch import load_file, save_file - +from safetensors.torch import save_file +from scipy.io import wavfile +from cached_path import cached_path from f5_tts.api import F5TTS from f5_tts.model.utils import convert_char_to_pinyin from f5_tts.infer.utils_infer import transcribe +from importlib.resources import files training_process = None @@ -43,7 +43,7 @@ last_ema = None path_data = str(files("f5_tts").joinpath("../../data")) -path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts")) +path_project_ckpts = "/mnt/d/ckpts" file_train = str(files("f5_tts").joinpath("train/finetune_cli.py")) device = ( @@ -118,16 +118,16 @@ def load_settings(project_name): # Default settings default_settings = { - "exp_name": "F5TTS_v1_Base", - "learning_rate": 1e-5, - "batch_size_per_gpu": 1, - "batch_size_type": "sample", + "exp_name": "F5TTS_Base", + "learning_rate": 1e-05, + "batch_size_per_gpu": 1000, + "batch_size_type": "frame", "max_samples": 64, - "grad_accumulation_steps": 4, + "grad_accumulation_steps": 1, "max_grad_norm": 1, "epochs": 100, - "num_warmup_updates": 100, - "save_per_updates": 500, + "num_warmup_updates": 2, + "save_per_updates": 300, "keep_last_n_checkpoints": -1, "last_per_updates": 100, "finetune": True, @@ -362,18 +362,18 @@ def terminate_process(pid): def start_training( dataset_name="", - exp_name="F5TTS_v1_Base", - learning_rate=1e-5, - batch_size_per_gpu=1, - batch_size_type="sample", + exp_name="F5TTS_Base", + learning_rate=1e-4, + batch_size_per_gpu=400, + batch_size_type="frame", max_samples=64, - grad_accumulation_steps=4, + grad_accumulation_steps=1, max_grad_norm=1.0, - epochs=100, - num_warmup_updates=100, - save_per_updates=500, + epochs=11, + num_warmup_updates=200, + save_per_updates=400, keep_last_n_checkpoints=-1, - last_per_updates=100, + last_per_updates=800, finetune=True, file_checkpoint_train="", tokenizer_type="pinyin", @@ -797,14 +797,14 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()): print(f"Error processing {file_audio}: {e}") continue - if duration < 1 or duration > 30: - if duration > 30: - error_files.append([file_audio, "duration > 30 sec"]) + if duration < 1 or duration > 25: + if duration > 25: + error_files.append([file_audio, "duration > 25 sec"]) if duration < 1: error_files.append([file_audio, "duration < 1 sec "]) continue if len(text) < 3: - error_files.append([file_audio, "very short text length 3"]) + error_files.append([file_audio, "very small text len 3"]) 
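One way to read the frame-based defaults above (batch_size_type = "frame", batch_size_per_gpu = 1000): a frame here is one mel hop, so with the 256-sample hop length and 24 kHz sample rate used elsewhere in this patch, 1000 frames is roughly ten seconds of audio per GPU per step. A small illustrative calculation, using only the defaults shown above:

```python
# Relate a frame-based batch size to seconds of audio (illustration only).
hop_length = 256           # mel hop length used in this repo
sampling_rate = 24000      # target sample rate
batch_size_per_gpu = 1000  # default frames per GPU set above

seconds_per_frame = hop_length / sampling_rate
print(f"{batch_size_per_gpu * seconds_per_frame:.1f} s of audio per GPU per step")  # ~10.7 s
```

This is also why the memory-error advice above is to lower batch_size_per_gpu: fewer frames per step means less audio, and therefore less activation memory, per GPU.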
continue text = clear_text(text) @@ -871,37 +871,40 @@ def check_user(value): def calculate_train( name_project, - epochs, - learning_rate, - batch_size_per_gpu, batch_size_type, max_samples, + learning_rate, num_warmup_updates, + save_per_updates, + last_per_updates, finetune, ): path_project = os.path.join(path_data, name_project) - file_duration = os.path.join(path_project, "duration.json") - - hop_length = 256 - sampling_rate = 24000 + file_duraction = os.path.join(path_project, "duration.json") - if not os.path.isfile(file_duration): + if not os.path.isfile(file_duraction): return ( - epochs, - learning_rate, - batch_size_per_gpu, + 1000, max_samples, num_warmup_updates, + save_per_updates, + last_per_updates, "project not found !", + learning_rate, ) - with open(file_duration, "r") as file: + with open(file_duraction, "r") as file: data = json.load(file) duration_list = data["duration"] - max_sample_length = max(duration_list) * sampling_rate / hop_length - total_samples = len(duration_list) - total_duration = sum(duration_list) + samples = len(duration_list) + hours = sum(duration_list) / 3600 + + # if torch.cuda.is_available(): + # gpu_properties = torch.cuda.get_device_properties(0) + # total_memory = gpu_properties.total_memory / (1024**3) + # elif torch.backends.mps.is_available(): + # total_memory = psutil.virtual_memory().available / (1024**3) if torch.cuda.is_available(): gpu_count = torch.cuda.device_count() @@ -909,39 +912,64 @@ def calculate_train( for i in range(gpu_count): gpu_properties = torch.cuda.get_device_properties(i) total_memory += gpu_properties.total_memory / (1024**3) # in GB + elif torch.xpu.is_available(): gpu_count = torch.xpu.device_count() total_memory = 0 for i in range(gpu_count): gpu_properties = torch.xpu.get_device_properties(i) total_memory += gpu_properties.total_memory / (1024**3) + elif torch.backends.mps.is_available(): gpu_count = 1 total_memory = psutil.virtual_memory().available / (1024**3) - avg_gpu_memory = total_memory / gpu_count - - # rough estimate of batch size if batch_size_type == "frame": - batch_size_per_gpu = max(int(38400 * (avg_gpu_memory - 5) / 75), int(max_sample_length)) - elif batch_size_type == "sample": - batch_size_per_gpu = int(200 / (total_duration / total_samples)) - - if total_samples < 64: - max_samples = int(total_samples * 0.25) - - num_warmup_updates = max(num_warmup_updates, int(total_samples * 0.05)) - - # take 1.2M updates as the maximum - max_updates = 1200000 + batch = int(total_memory * 0.5) + batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch) + batch_size_per_gpu = int(38400 / batch) + else: + batch_size_per_gpu = int(total_memory / 8) + batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu) + batch = batch_size_per_gpu - if batch_size_type == "frame": - mini_batch_duration = batch_size_per_gpu * gpu_count * hop_length / sampling_rate - updates_per_epoch = total_duration / mini_batch_duration - elif batch_size_type == "sample": - updates_per_epoch = total_samples / batch_size_per_gpu / gpu_count + if batch_size_per_gpu <= 0: + batch_size_per_gpu = 1 - epochs = int(max_updates / updates_per_epoch) + if samples < 64: + max_samples = int(samples * 0.25) + else: + max_samples = 64 + + num_warmup_updates = int(samples * 0.05) + save_per_updates = int(samples * 0.10) + last_per_updates = int(save_per_updates * 0.25) + + max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples) + num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else 
num)(num_warmup_updates) + save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates) + last_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_updates) + if last_per_updates <= 0: + last_per_updates = 2 + + total_hours = hours + mel_hop_length = 256 + mel_sampling_rate = 24000 + + # target + wanted_max_updates = 1000000 + + # train params + gpus = gpu_count + frames_per_gpu = batch_size_per_gpu # 8 * 38400 = 307200 + grad_accum = 1 + + # intermediate + mini_batch_frames = frames_per_gpu * grad_accum * gpus + mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600 + updates_per_epoch = total_hours / mini_batch_hours + # steps_per_epoch = updates_per_epoch * grad_accum + epochs = wanted_max_updates / updates_per_epoch if finetune: learning_rate = 1e-5 @@ -949,12 +977,14 @@ def calculate_train( learning_rate = 7.5e-5 return ( - epochs, - learning_rate, batch_size_per_gpu, max_samples, num_warmup_updates, - total_samples, + save_per_updates, + last_per_updates, + samples, + learning_rate, + int(epochs), ) @@ -991,11 +1021,7 @@ def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42): torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False - if ckpt_path.endswith(".safetensors"): - ckpt = load_file(ckpt_path, device="cpu") - ckpt = {"ema_model_state_dict": ckpt} - elif ckpt_path.endswith(".pt"): - ckpt = torch.load(ckpt_path, map_location="cpu") + ckpt = torch.load(ckpt_path, map_location="cpu") ema_sd = ckpt.get("ema_model_state_dict", {}) embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight" @@ -1063,11 +1089,9 @@ def vocab_extend(project_name, symbols, model_type): with open(file_vocab_project, "w", encoding="utf-8") as f: f.write("\n".join(vocab)) - if model_type == "F5TTS_v1_Base": - ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors")) - elif model_type == "F5TTS_Base": + if model_type == "F5-TTS": ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt")) - elif model_type == "E2TTS_Base": + else: ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt")) vocab_size_new = len(miss_symbols) @@ -1077,12 +1101,12 @@ def vocab_extend(project_name, symbols, model_type): os.makedirs(new_ckpt_path, exist_ok=True) # Add pretrained_ prefix to model when copying for consistency with finetune_cli.py - new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_" + os.path.basename(ckpt_path)) + new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_model_1200000.pt") - size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new) + size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new-1) vocab_new = "\n".join(miss_symbols) - return f"vocab old size : {size_vocab}\nvocab new size : {size}\nvocab add : {vocab_size_new}\nnew symbols :\n{vocab_new}" + return f"vocab old size : {size_vocab}\nvocab new size : {size}\nvocab add : {vocab_size_new-1}\nnew symbols :\n{vocab_new}" def vocab_check(project_name): @@ -1207,21 +1231,21 @@ def infer( vocab_file = os.path.join(path_data, project, "vocab.txt") tts_api = F5TTS( - model=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema + model_type=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema ) print("update >> ", device_test, file_checkpoint, use_ema) with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f: 
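To make the revised calculate_train heuristic above concrete, here is the same arithmetic run by hand for an assumed single 24 GB GPU and a 10-hour project; both figures are illustrative and not defaults of this patch:

```python
# Worked example of the auto-setting heuristic above (assumed: one 24 GB GPU, 10 h of audio).
total_memory = 24   # GB across all GPUs
total_hours = 10
gpus, grad_accum = 1, 1

batch = int(total_memory * 0.5)          # 12, already even
batch_size_per_gpu = int(38400 / batch)  # 3200 frames per GPU

mel_hop_length, mel_sampling_rate = 256, 24000
mini_batch_frames = batch_size_per_gpu * grad_accum * gpus
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
updates_per_epoch = total_hours / mini_batch_hours  # ~1055 updates per pass over the data
epochs = int(1000000 / updates_per_epoch)           # ~948 epochs to reach the 1M-update target
print(batch_size_per_gpu, round(updates_per_epoch), epochs)
```

In other words, the suggested epoch count shrinks as the dataset grows, because the 1,000,000-update budget is fixed.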
tts_api.infer( - ref_file=ref_audio, - ref_text=ref_text.lower().strip(), gen_text=gen_text.lower().strip(), + ref_text=ref_text.lower().strip(), + ref_file=ref_audio, nfe_step=nfe_step, - speed=speed, - remove_silence=remove_silence, file_wave=f.name, + speed=speed, seed=seed, + remove_silence=remove_silence, ) return f.name, tts_api.device, str(tts_api.seed) @@ -1380,14 +1404,14 @@ def get_audio_select(file_sample): with gr.Blocks() as app: gr.Markdown( """ -# F5 TTS Automatic Finetune +# E2/F5 TTS Automatic Finetune -This is a local web UI for F5 TTS finetuning support. This app supports the following TTS models: +This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models: * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching) * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS) -The pretrained checkpoints support English and Chinese. +The checkpoints support English and Chinese. For tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussions/143) """ @@ -1464,9 +1488,7 @@ Check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are incl Using the extended model, you can finetune to a new language that is missing symbols in the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder. ```""") - exp_name_extend = gr.Radio( - label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base" - ) + exp_name_extend = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS") with gr.Row(): txt_extend = gr.Textbox( @@ -1535,9 +1557,9 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare] ) - with gr.TabItem("Train Model"): + with gr.TabItem("Train Data"): gr.Markdown("""```plaintext -The auto-setting is still experimental. Set a large value of epoch if not sure; and keep last N checkpoints if limited disk space. +The auto-setting is still experimental. Please make sure that the epochs, save per updates, and last per updates are set correctly, or change them manually as needed. If you encounter a memory error, try reducing the batch size per GPU to a smaller number. 
```""") with gr.Row(): @@ -1551,13 +1573,11 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint", value="") with gr.Row(): - exp_name = gr.Radio( - label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base" - ) + exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base") learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5) with gr.Row(): - batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=3200) + batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000) max_samples = gr.Number(label="Max Samples", value=64) with gr.Row(): @@ -1565,23 +1585,23 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0) with gr.Row(): - epochs = gr.Number(label="Epochs", value=100) - num_warmup_updates = gr.Number(label="Warmup Updates", value=100) + epochs = gr.Number(label="Epochs", value=10) + num_warmup_updates = gr.Number(label="Warmup Updates", value=2) with gr.Row(): - save_per_updates = gr.Number(label="Save per Updates", value=500) + save_per_updates = gr.Number(label="Save per Updates", value=300) keep_last_n_checkpoints = gr.Number( label="Keep Last N Checkpoints", value=-1, step=1, precision=0, - info="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints", + info="-1: Keep all checkpoints, 0: Only save final model_last.pt, N>0: Keep last N checkpoints", ) last_per_updates = gr.Number(label="Last per Updates", value=100) with gr.Row(): ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer") - mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="fp16") + mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="none") cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb") start_button = gr.Button("Start Training") stop_button = gr.Button("Stop Training", interactive=False) @@ -1698,21 +1718,23 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle fn=calculate_train, inputs=[ cm_project, - epochs, - learning_rate, - batch_size_per_gpu, batch_size_type, max_samples, + learning_rate, num_warmup_updates, + save_per_updates, + last_per_updates, ch_finetune, ], outputs=[ - epochs, - learning_rate, batch_size_per_gpu, max_samples, num_warmup_updates, + save_per_updates, + last_per_updates, lb_samples, + learning_rate, + epochs, ], ) @@ -1722,25 +1744,25 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle def setup_load_settings(): output_components = [ - exp_name, - learning_rate, - batch_size_per_gpu, - batch_size_type, - max_samples, - grad_accumulation_steps, - max_grad_norm, - epochs, - num_warmup_updates, - save_per_updates, - keep_last_n_checkpoints, - last_per_updates, - ch_finetune, - file_checkpoint_train, - tokenizer_type, - tokenizer_file, - mixed_precision, - cd_logger, - ch_8bit_adam, + exp_name, # 1 + learning_rate, # 2 + batch_size_per_gpu, # 3 + batch_size_type, # 4 + max_samples, # 5 + grad_accumulation_steps, # 6 + max_grad_norm, # 7 + epochs, # 8 + num_warmup_updates, # 9 + save_per_updates, # 10 + keep_last_n_checkpoints, # 11 + last_per_updates, # 12 + ch_finetune, # 13 + file_checkpoint_train, # 14 + tokenizer_type, # 15 + tokenizer_file, # 16 + mixed_precision, # 17 + cd_logger, # 18 + 
ch_8bit_adam, # 19 ] return output_components @@ -1762,9 +1784,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle gr.Markdown("""```plaintext SOS: Check the use_ema setting (True or False) for your model to see what works best for you. use seed -1 from random ```""") - exp_name = gr.Radio( - label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base" - ) + exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS") list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False) with gr.Row(): @@ -1818,9 +1838,9 @@ SOS: Check the use_ema setting (True or False) for your model to see what works bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint]) cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint]) - with gr.TabItem("Prune Checkpoint"): + with gr.TabItem("Reduce Checkpoint"): gr.Markdown("""```plaintext -Reduce the Base model size from 5GB to 1.3GB. The new checkpoint file prunes out optimizer and etc., can be used for inference or finetuning afterward, but not able to resume pretraining. +Reduce the model size from 5GB to 1.3GB. The new checkpoint can be used for inference or fine-tuning afterward, but it cannot be used to continue training. ```""") txt_path_checkpoint = gr.Text(label="Path to Checkpoint:") txt_path_checkpoint_small = gr.Text(label="Path to Output:") diff --git a/f5_tts/train/train.py b/f5_tts/train/train.py index 2e191a3707a23a710bc7d98510ee7c3df50ea4ca..ade54be5fd53fecd9b3e7073bbe1b9d579518888 100644 --- a/f5_tts/train/train.py +++ b/f5_tts/train/train.py @@ -4,9 +4,8 @@ import os from importlib.resources import files import hydra -from omegaconf import OmegaConf -from f5_tts.model import CFM, DiT, UNetT, Trainer # noqa: F401. 
used for config +from f5_tts.model import CFM, DiT, Trainer, UNetT from f5_tts.model.dataset import load_dataset from f5_tts.model.utils import get_tokenizer @@ -15,13 +14,9 @@ os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to @hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None) def main(cfg): - model_cls = globals()[cfg.model.backbone] - model_arc = cfg.model.arch tokenizer = cfg.model.tokenizer mel_spec_type = cfg.model.mel_spec.mel_spec_type - exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}" - wandb_resume_id = None # set text tokenizer if tokenizer != "custom": @@ -31,8 +26,14 @@ def main(cfg): vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) # set model + if "F5TTS" in cfg.model.name: + model_cls = DiT + elif "E2TTS" in cfg.model.name: + model_cls = UNetT + wandb_resume_id = None + model = CFM( - transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels), + transformer=model_cls(**cfg.model.arch, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels), mel_spec_kwargs=cfg.model.mel_spec, vocab_char_map=vocab_char_map, ) @@ -44,9 +45,9 @@ def main(cfg): learning_rate=cfg.optim.learning_rate, num_warmup_updates=cfg.optim.num_warmup_updates, save_per_updates=cfg.ckpts.save_per_updates, - keep_last_n_checkpoints=cfg.ckpts.keep_last_n_checkpoints, + keep_last_n_checkpoints=getattr(cfg.ckpts, "keep_last_n_checkpoints", -1), checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")), - batch_size_per_gpu=cfg.datasets.batch_size_per_gpu, + batch_size=cfg.datasets.batch_size_per_gpu, batch_size_type=cfg.datasets.batch_size_type, max_samples=cfg.datasets.max_samples, grad_accumulation_steps=cfg.optim.grad_accumulation_steps, @@ -56,12 +57,11 @@ def main(cfg): wandb_run_name=exp_name, wandb_resume_id=wandb_resume_id, last_per_updates=cfg.ckpts.last_per_updates, - log_samples=cfg.ckpts.log_samples, + log_samples=True, bnb_optimizer=cfg.optim.bnb_optimizer, mel_spec_type=mel_spec_type, is_local_vocoder=cfg.model.vocoder.is_local, local_vocoder_path=cfg.model.vocoder.local_path, - cfg_dict=OmegaConf.to_container(cfg, resolve=True), ) train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
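With the train.py changes above, the backbone class is now inferred from the model name in the Hydra config rather than read from a backbone field. A stripped-down sketch of that selection and the CFM assembly; the arch dicts are the F5TTS_Base/E2TTS_Base shapes used in finetune_cli.py above, and mel_spec_kwargs is assumed to be a plain dict carrying the same fields as cfg.model.mel_spec:

```python
from f5_tts.model import CFM, DiT, UNetT

ARCHS = {
    # Same shapes as finetune_cli.py above.
    "F5TTS_Base": dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
    "E2TTS_Base": dict(dim=1024, depth=24, heads=16, ff_mult=4),
}

def build_cfm(model_name, mel_spec_kwargs, vocab_size, vocab_char_map):
    # Same rule as train.py above: the name decides the backbone.
    model_cls = DiT if "F5TTS" in model_name else UNetT
    transformer = model_cls(
        **ARCHS[model_name], text_num_embeds=vocab_size, mel_dim=mel_spec_kwargs["n_mel_channels"]
    )
    return CFM(transformer=transformer, mel_spec_kwargs=mel_spec_kwargs, vocab_char_map=vocab_char_map)
```

The trade-off versus the config-driven backbone lookup it replaces is that a new backbone now needs a code change here, while existing *_train.yaml configs keep working without an extra field.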
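Returning to the renamed "Reduce Checkpoint" tab above: the handler itself is not part of this diff, but conceptually the 5 GB to 1.3 GB reduction amounts to keeping only the EMA weights and dropping optimizer and step state, which is why the result can still be used for inference or further finetuning but not for resuming a run. A rough sketch under that assumption; the paths and the helper name are hypothetical, and only the ema_model_state_dict key matches the one used by expand_model_embeddings above:

```python
import torch

def reduce_checkpoint(ckpt_path: str, out_path: str) -> None:
    """Keep only the EMA weights; optimizer and scheduler state are discarded."""
    ckpt = torch.load(ckpt_path, map_location="cpu")
    ema_sd = ckpt.get("ema_model_state_dict", ckpt)  # fall back to a bare state dict
    torch.save({"ema_model_state_dict": ema_sd}, out_path)

# Hypothetical usage:
# reduce_checkpoint("/mnt/d/ckpts/vnTTS_mc/model_last.pt", "/mnt/d/ckpts/vnTTS_mc/model_last_small.pt")
```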