# training script.

import os
from importlib.resources import files

import hydra

from f5_tts.model import CFM, DiT, Trainer, UNetT
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer

os.chdir(str(files("f5_tts").joinpath("../..")))  # change working directory to root of project (local editable)
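
# With config_name=None, Hydra requires a config to be selected at launch time via
# --config-name; the yaml files under src/f5_tts/configs (the config_path below) are
# the intended choices.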
@hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
def main(cfg):
    tokenizer = cfg.model.tokenizer
    mel_spec_type = cfg.model.mel_spec.mel_spec_type
    exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"

    # set text tokenizer
    if tokenizer != "custom":
        tokenizer_path = cfg.datasets.name
    else:
        tokenizer_path = cfg.model.tokenizer_path
    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
    # set model
    if "F5TTS" in cfg.model.name:
        model_cls = DiT
    elif "E2TTS" in cfg.model.name:
        model_cls = UNetT
    else:
        # guard added: without it, model_cls would be unbound for an unrecognized name
        raise ValueError(f"Unrecognized model name: {cfg.model.name}")
    wandb_resume_id = None

    model = CFM(
        transformer=model_cls(**cfg.model.arch, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
        mel_spec_kwargs=cfg.model.mel_spec,
        vocab_char_map=vocab_char_map,
    )
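
    # CFM = conditional flow matching wrapper: the transformer backbone learns the
    # vector field transporting noise to mel-spectrogram features (per the F5-TTS design;
    # the exact objective lives in f5_tts.model's CFM implementation)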

    # init trainer
    trainer = Trainer(
        model,
        epochs=cfg.optim.epochs,
        learning_rate=cfg.optim.learning_rate,
        num_warmup_updates=cfg.optim.num_warmup_updates,
        save_per_updates=cfg.ckpts.save_per_updates,
        checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
        batch_size=cfg.datasets.batch_size_per_gpu,
        batch_size_type=cfg.datasets.batch_size_type,
        max_samples=cfg.datasets.max_samples,
        grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
        max_grad_norm=cfg.optim.max_grad_norm,
        logger=cfg.ckpts.logger,
        wandb_project="CFM-TTS",
        wandb_run_name=exp_name,
        wandb_resume_id=wandb_resume_id,
        last_per_steps=cfg.ckpts.last_per_steps,
        log_samples=True,
        bnb_optimizer=cfg.optim.bnb_optimizer,
        mel_spec_type=mel_spec_type,
        is_local_vocoder=cfg.model.vocoder.is_local,
        local_vocoder_path=cfg.model.vocoder.local_path,
    )
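
    # note: batch_size is counted per GPU, and batch_size_type (e.g. "frame" or "sample"
    # in the shipped configs) decides whether the sampler batches by mel frames or by utterances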

    train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
    trainer.train(
        train_dataset,
        num_workers=cfg.datasets.num_workers,
        resumable_with_seed=666,  # seed for shuffling dataset
    )


if __name__ == "__main__":
    main()
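
# Example launch (the config name is illustrative; pick any yaml under src/f5_tts/configs):
#   accelerate launch src/f5_tts/train/train.py --config-name F5TTS_Base_train.yaml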