# Coqui-Xtts-Demo / gradio_app.py
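"""Gradio demo app for Coqui XTTS text-to-speech.

Downloads an XTTS checkpoint from the Hugging Face Hub, optionally denoises
reference audio with DeepFilterNet, and exposes three synthesis tabs:
built-in voices, reference-audio voice cloning, and microphone-based cloning.
"""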
import os
import sys
import time
import hashlib
import site
import subprocess
import gradio as gr
import torch
import torchaudio
import numpy as np
from underthesea import sent_tokenize
from df.enhance import enhance, init_df, load_audio, save_audio
from huggingface_hub import snapshot_download
from langdetect import detect
from utils.vietnamese_normalization import normalize_vietnamese_text
from utils.logger import setup_logger
from utils.sentence import split_sentence, merge_sentences
import warnings
warnings.filterwarnings("ignore")
logger = setup_logger(__file__)
df_model, df_state = None, None
APP_DIR = os.path.dirname(os.path.abspath(__file__))
checkpoint_dir = f"{APP_DIR}/cache"
temp_dir = f"{APP_DIR}/cache/temp/"
sample_audio_dir = f"{APP_DIR}/cache/audio_samples/"
enhance_audio_dir = f"{APP_DIR}/cache/audio_enhances/"
speakers_dir = f"{APP_DIR}/cache/speakers/"
for d in [checkpoint_dir, temp_dir, sample_audio_dir, enhance_audio_dir]:
    os.makedirs(d, exist_ok=True)
language_dict = {'English': 'en', 'Español (Spanish)': 'es', 'Français (French)': 'fr',
                 'Deutsch (German)': 'de', 'Italiano (Italian)': 'it', 'Português (Portuguese)': 'pt',
                 'Polski (Polish)': 'pl', 'Türkçe (Turkish)': 'tr', 'Русский (Russian)': 'ru',
                 'Nederlands (Dutch)': 'nl', 'Čeština (Czech)': 'cs', 'العربية (Arabic)': 'ar', '中文 (Chinese)': 'zh-cn',
                 'Magyar nyelv (Hungarian)': 'hu', '한국어 (Korean)': 'ko', '日本語 (Japanese)': 'ja',
                 'Tiếng Việt (Vietnamese)': 'vi', 'Auto': 'auto'}
default_language = 'Auto'
language_codes = [v for _, v in language_dict.items()]
def lang_detect(text):
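    """Detect the language of `text`, falling back to 'en' for unsupported or undetectable input."""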
    try:
        lang = detect(text)
        if lang == 'zh-tw':
            # langdetect may report Traditional Chinese; XTTS only knows 'zh-cn'
            return 'zh-cn'
        return lang if lang in language_codes else 'en'
    except Exception:
        return 'en'
input_text_max_length = 3000
use_deepspeed = False
try:
    import spaces
except ImportError:
    from utils import spaces
xtts_model = None
def load_model():
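    """Download the XTTS checkpoint from the Hub and load it (on GPU when available)."""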
    global xtts_model
    from TTS.tts.configs.xtts_config import XttsConfig
    from TTS.tts.models.xtts import Xtts
    repo_id = "jimmyvu/xtts"
    snapshot_download(repo_id=repo_id,
                      local_dir=checkpoint_dir,
                      allow_patterns=["*.safetensors", "*.wav", "*.json"],
                      ignore_patterns="*.pth")
    config = XttsConfig()
    config.load_json(os.path.join(checkpoint_dir, "config.json"))
    xtts_model = Xtts.init_from_config(config)
    logger.info("Loading model...")
    xtts_model.load_safetensors_checkpoint(
        config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed
    )
    if torch.cuda.is_available():
        xtts_model.cuda()
    logger.info(f"Successfully loaded model from {checkpoint_dir}")
load_model()
def download_unidic():
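    """Download the unidic dictionary (used for Japanese tokenization) if it is not installed yet."""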
    site_package_path = site.getsitepackages()[0]
    unidic_path = os.path.join(site_package_path, "unidic", "dicdir")
    if not os.path.exists(unidic_path):
        logger.info("Downloading unidic...")
        subprocess.call([sys.executable, "-m", "unidic", "download"])
download_unidic()
default_speaker_reference_audio = os.path.join(sample_audio_dir, 'harvard.wav')
default_speaker_id = "Aaron Dreschner"
def validate_input(input_text, language):
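    """Return a non-empty warning string if the text is too long or too short, otherwise an empty string."""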
log_messages = ""
if len(input_text) > input_text_max_length:
gr.Warning("Text is too long! Please provide a shorter text.")
log_messages += "Text is too long! Please provide a shorter text.\n"
return log_messages
language_code = language_dict.get(language, 'en')
logger.info(f"Language [{language}], code: [{language_code}]")
lang = lang_detect(input_text) if language == 'Auto' else language_code
if (lang not in ['ja', 'kr', 'zh-cn'] and len(input_text.split()) < 2) or \
(lang in ['ja', 'kr', 'zh-cn'] and len(input_text) < 2):
gr.Warning("Text is too short! Please provide a longer text.")
log_messages += "Text is too short! Please provide a longer text.\n"
return log_messages
@spaces.GPU
def synthesize_speech(input_text, speaker_id, temperature=0.3, top_p=0.85, top_k=50, repetition_penalty=10.0, language='Auto'):
"""Process text and generate audio."""
global xtts_model
log_messages = validate_input(input_text, language)
if log_messages:
return None, log_messages
start = time.time()
logger.info(f"Start processing text: {input_text[:30]}... [length: {len(input_text)}]")
# inference
wav_array, num_of_tokens = inference(input_text=input_text,
language=language,
speaker_id=speaker_id,
gpt_cond_latent=None,
speaker_embedding=None,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=float(repetition_penalty))
end = time.time()
processing_time = end - start
tokens_per_second = num_of_tokens/processing_time
logger.info(f"End processing text: {input_text[:30]}")
message = f"💡 {tokens_per_second:.1f} tok/s • {num_of_tokens} tokens • in {processing_time:.2f} seconds"
logger.info(message)
log_messages += message
return (24000, wav_array), log_messages
@spaces.GPU
def generate_speech(input_text, speaker_reference_audio, enhance_speech, temperature=0.3, top_p=0.85, top_k=50, repetition_penalty=10.0, language='Auto'):
"""Process text and generate audio."""
global df_model, df_state, xtts_model
log_messages = validate_input(input_text, language)
if log_messages:
return None, log_messages
if not speaker_reference_audio:
gr.Warning("Please provide at least one reference audio!")
log_messages += "Please provide at least one reference audio!\n"
return None, log_messages
start = time.time()
logger.info(f"Start processing text: {input_text[:30]}... [length: {len(input_text)}]")
if enhance_speech:
logger.info("Enhancing reference audio...")
_, audio_file = os.path.split(speaker_reference_audio)
enhanced_audio_path = os.path.join(enhance_audio_dir, f"{audio_file}.enh.wav")
if not os.path.exists(enhanced_audio_path):
if not df_model:
df_model, df_state, _ = init_df()
audio, _ = load_audio(speaker_reference_audio, sr=df_state.sr())
# denoise audio
enhanced_audio = enhance(df_model, df_state, audio)
# save enhanced audio
save_audio(enhanced_audio_path, enhanced_audio, sr=df_state.sr())
speaker_reference_audio = enhanced_audio_path
gpt_cond_latent, speaker_embedding = xtts_model.get_conditioning_latents(
audio_path=speaker_reference_audio,
gpt_cond_len=xtts_model.config.gpt_cond_len,
max_ref_length=xtts_model.config.max_ref_len,
sound_norm_refs=xtts_model.config.sound_norm_refs,
)
# inference
wav_array, num_of_tokens = inference(input_text=input_text,
language=language,
speaker_id=None,
gpt_cond_latent=gpt_cond_latent,
speaker_embedding=speaker_embedding,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=float(repetition_penalty))
end = time.time()
processing_time = end - start
tokens_per_second = num_of_tokens/processing_time
logger.info(f"End processing text: {input_text[:30]}")
message = f"💡 {tokens_per_second:.1f} tok/s • {num_of_tokens} tokens • in {processing_time:.2f} seconds"
logger.info(message)
log_messages += message
return (24000, wav_array), log_messages
def inference(input_text, language, speaker_id=None, gpt_cond_latent=None, speaker_embedding=None, temperature=0.3, top_p=0.85, top_k=50, repetition_penalty=10.0):
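    """Split `input_text` into sentences and synthesize each one, returning (waveform, token count)."""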
    language_code = lang_detect(input_text) if language == 'Auto' else language_dict.get(language, 'en')
    # Split text by sentence
    if language_code in ["ja", "zh-cn"]:
        sentences = input_text.split("。")
    else:
        sentences = sent_tokenize(input_text)
    # merge short sentences into the next/previous ones
    sentences = merge_sentences(sentences)
    # set a dynamic length penalty from -1.0 to 1.0 based on text length
    max_text_length = 180
    dynamic_length_penalty = lambda text_length: (2 * (min(max_text_length, text_length) / max_text_length)) - 1
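    # e.g. 45 chars -> -0.5, 90 chars -> 0.0, 180+ chars -> 1.0 (passed to xtts_model.inference as length_penalty)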
    if speaker_id is not None:
        # built-in voices: reuse the precomputed latents from the speaker manager
        gpt_cond_latent, speaker_embedding = xtts_model.speaker_manager.speakers[speaker_id].values()
    # inference
    out_wavs = []
    num_of_tokens = 0
    for sentence in sentences:
        if len(sentence.strip()) == 0:
            continue
        lang = lang_detect(sentence) if language == 'Auto' else language_code
        if lang == 'vi':
            sentence = normalize_vietnamese_text(sentence)
        text_tokens = torch.IntTensor(xtts_model.tokenizer.encode(sentence, lang=lang)).unsqueeze(0).to(xtts_model.device)
        num_of_tokens += text_tokens.shape[-1]
        txts = split_sentence(sentence, max_text_length=max_text_length)
        for txt in txts:
            logger.info(f"[{lang}] {txt}")
            try:
                out = xtts_model.inference(
                    text=txt,
                    language=lang,
                    gpt_cond_latent=gpt_cond_latent,
                    speaker_embedding=speaker_embedding,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    repetition_penalty=repetition_penalty,
                    length_penalty=dynamic_length_penalty(len(sentence)),
                    enable_text_splitting=False,
                )
                out_wavs.append(out["wav"])
            except Exception as e:
                logger.error(f"Error processing text: {e}")
    return np.concatenate(out_wavs), num_of_tokens
def build_gradio_ui():
"""Builds and launches the Gradio UI."""
default_prompt = ("Hi, I am a multilingual text-to-speech AI model.\n"
"Bonjour, je suis un modèle d'IA de synthèse vocale multilingue.\n"
"Hallo, ich bin ein mehrsprachiges Text-zu-Sprache KI-Modell.\n"
"Ciao, sono un modello di intelligenza artificiale di sintesi vocale multilingue.\n"
"Привет, я многоязычная модель искусственного интеллекта, преобразующая текст в речь.\n"
"Xin chào, tôi là một mô hình AI chuyển đổi văn bản thành giọng nói đa ngôn ngữ.\n")
    with gr.Blocks(title="Coqui XTTS Demo", theme='jimmyvu/small_and_pretty') as ui:
        gr.Markdown(
            """
            # 🐸 Coqui-XTTS Text-to-Speech Demo
            Convert text to speech with advanced voice cloning and enhancement.
            Supports 17 languages, \u2605 **Vietnamese** \u2605 newly added.
            """
        )
with gr.Tab("Built-in Voice"):
with gr.Row():
with gr.Column():
input_text = gr.Text(label="Enter Text Here",
placeholder="Write the text you want to synthesize...",
value=default_prompt,
lines=5,
max_length=input_text_max_length)
speaker_id = gr.Dropdown(label="Speaker", choices=[k for k in xtts_model.speaker_manager.speakers.keys()], value=default_speaker_id)
language = gr.Dropdown(label="Target Language", choices=[k for k in language_dict.keys()], value=default_language)
synthesize_button = gr.Button("Generate Speech")
with gr.Column():
audio_output = gr.Audio(label="Generated Audio")
log_output = gr.Text(label="Log Output")
with gr.Tab("Reference Voice"):
with gr.Row():
with gr.Column():
input_text_generate = gr.Text(label="Enter Text Here",
placeholder="Write the text you want to synthesize...",
lines=5,
max_length=input_text_max_length)
speaker_reference_audio = gr.Audio(
label="Speaker reference audio:",
type="filepath",
editable=False,
min_length=3,
max_length=300,
value=default_speaker_reference_audio
)
enhance_speech = gr.Checkbox(label="Enhance Reference Audio", value=False)
language_generate = gr.Dropdown(label="Target Language", choices=[k for k in language_dict.keys()], value=default_language)
generate_button = gr.Button("Generate Speech")
with gr.Column():
audio_output_generate = gr.Audio(label="Generated Audio")
log_output_generate = gr.Text(label="Log Output")
with gr.Tab("Clone Your Voice"):
with gr.Row():
with gr.Column():
input_text_mic = gr.Text(label="Enter Text Here",
placeholder="Write the text you want to synthesize...",
lines=5,
max_length=input_text_max_length)
mic_ref_audio = gr.Audio(label="Record Reference Audio", sources=["microphone"])
enhance_speech_mic = gr.Checkbox(label="Enhance Reference Audio", value=True)
language_mic = gr.Dropdown(label="Target Language", choices=[k for k in language_dict.keys()], value=default_language)
generate_button_mic = gr.Button("Generate Speech")
with gr.Column():
audio_output_mic = gr.Audio(label="Generated Audio")
log_output_mic = gr.Text(label="Log Output")
        def process_mic_and_generate(input_text_mic, mic_ref_audio, enhance_speech_mic, temperature, top_p, top_k, repetition_penalty, language_mic):
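            """Save the microphone recording to a temp wav file and delegate to generate_speech."""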
            if mic_ref_audio:
                data = str(time.time()).encode("utf-8")
                hash = hashlib.sha1(data).hexdigest()[:10]
                output_path = os.path.join(temp_dir, f"mic_{hash}.wav")
                torch_audio = torch.from_numpy(mic_ref_audio[1].astype(float))
                try:
                    torchaudio.save(output_path, torch_audio.unsqueeze(0), mic_ref_audio[0])
                    return generate_speech(input_text_mic, output_path, enhance_speech_mic, temperature, top_p, top_k, repetition_penalty, language_mic)
                except Exception as e:
                    logger.error(f"Error saving audio file: {e}")
                    return None, f"Error saving audio file: {e}"
            else:
                return None, "Please record an audio!"
with gr.Tab("Advanced Settings"):
with gr.Row():
with gr.Column():
temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.3, step=0.05)
repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=50.0, value=9.5, step=1.0)
with gr.Column():
top_p = gr.Slider(label="Top P", minimum=0.5, maximum=1.0, value=0.85, step=0.05)
top_k = gr.Slider(label="Top K", minimum=0, maximum=100, value=50, step=5)
        synthesize_button.click(
            synthesize_speech,
            inputs=[input_text, speaker_id, temperature, top_p, top_k, repetition_penalty, language],
            outputs=[audio_output, log_output],
        )
        generate_button.click(
            generate_speech,
            inputs=[input_text_generate, speaker_reference_audio, enhance_speech, temperature, top_p, top_k, repetition_penalty, language_generate],
            outputs=[audio_output_generate, log_output_generate],
        )
        generate_button_mic.click(
            process_mic_and_generate,
            inputs=[input_text_mic, mic_ref_audio, enhance_speech_mic, temperature, top_p, top_k, repetition_penalty, language_mic],
            outputs=[audio_output_mic, log_output_mic],
        )
    return ui
if __name__ == "__main__":
    ui = build_gradio_ui()
    ui.launch(debug=False)