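"""Gradio demo app for Coqui XTTS text-to-speech on Hugging Face Spaces.

Supports 17 languages (Vietnamese newly added), built-in speaker voices,
voice cloning from reference audio, and optional reference-audio
enhancement with DeepFilterNet.
"""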
import os
import sys
import time
import hashlib
import site
import subprocess
import warnings

import gradio as gr
import torch
import torchaudio
import numpy as np
from underthesea import sent_tokenize
from df.enhance import enhance, init_df, load_audio, save_audio
from huggingface_hub import snapshot_download
from langdetect import detect

from utils.vietnamese_normalization import normalize_vietnamese_text
from utils.logger import setup_logger
from utils.sentence import split_sentence, merge_sentences

warnings.filterwarnings("ignore")

logger = setup_logger(__file__)
df_model, df_state = None, None

APP_DIR = os.path.dirname(os.path.abspath(__file__))
checkpoint_dir = f"{APP_DIR}/cache"
temp_dir = f"{APP_DIR}/cache/temp/"
sample_audio_dir = f"{APP_DIR}/cache/audio_samples/"
enhance_audio_dir = f"{APP_DIR}/cache/audio_enhances/"
speakers_dir = f"{APP_DIR}/cache/speakers/"

for d in [checkpoint_dir, temp_dir, sample_audio_dir, enhance_audio_dir, speakers_dir]:
    os.makedirs(d, exist_ok=True)
language_dict = {
    'English': 'en', 'Español (Spanish)': 'es', 'Français (French)': 'fr',
    'Deutsch (German)': 'de', 'Italiano (Italian)': 'it', 'Português (Portuguese)': 'pt',
    'Polski (Polish)': 'pl', 'Türkçe (Turkish)': 'tr', 'Русский (Russian)': 'ru',
    'Nederlands (Dutch)': 'nl', 'Čeština (Czech)': 'cs', 'العربية (Arabic)': 'ar',
    '中文 (Chinese)': 'zh-cn', 'Magyar nyelv (Hungarian)': 'hu', '한국어 (Korean)': 'ko',
    '日本語 (Japanese)': 'ja', 'Tiếng Việt (Vietnamese)': 'vi', 'Auto': 'auto',
}
default_language = 'Auto'
language_codes = list(language_dict.values())
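# Detect the language of a text snippet with langdetect, normalizing
# Traditional Chinese ('zh-tw') to 'zh-cn' and falling back to English
# for unsupported or undetectable input.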
def lang_detect(text):
    try:
        lang = detect(text)
        if lang == 'zh-tw':
            return 'zh-cn'
        return lang if lang in language_codes else 'en'
    except Exception:
        return 'en'
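# For example (results depend on langdetect's detector):
#   lang_detect("Xin chào thế giới")  -> 'vi'
#   lang_detect("12345")              -> 'en' (detection fails, fallback)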
input_text_max_length = 3000
use_deepspeed = False
# Use the `spaces` package on Hugging Face Spaces; fall back to a local stub.
try:
    import spaces
except ImportError:
    from utils import spaces
xtts_model = None
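# Download the XTTS checkpoint (safetensors weights, config, and sample
# audio) from the Hugging Face Hub and load it, moving the model to GPU
# when one is available.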
def load_model():
    global xtts_model
    from TTS.tts.configs.xtts_config import XttsConfig
    from TTS.tts.models.xtts import Xtts

    repo_id = "jimmyvu/xtts"
    snapshot_download(repo_id=repo_id,
                      local_dir=checkpoint_dir,
                      allow_patterns=["*.safetensors", "*.wav", "*.json"],
                      ignore_patterns="*.pth")
    config = XttsConfig()
    config.load_json(os.path.join(checkpoint_dir, "config.json"))
    xtts_model = Xtts.init_from_config(config)
    logger.info("Loading model...")
    xtts_model.load_safetensors_checkpoint(
        config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed
    )
    if torch.cuda.is_available():
        xtts_model.cuda()
    logger.info(f"Successfully loaded model from {checkpoint_dir}")
load_model()
def download_unidic():
    # Download the unidic dictionary (needed for Japanese tokenization)
    # if it is not already present in site-packages.
    site_package_path = site.getsitepackages()[0]
    unidic_path = os.path.join(site_package_path, "unidic", "dicdir")
    if not os.path.exists(unidic_path):
        logger.info("Downloading unidic...")
        subprocess.call([sys.executable, "-m", "unidic", "download"])
download_unidic()

default_speaker_reference_audio = os.path.join(sample_audio_dir, 'harvard.wav')
default_speaker_id = "Aaron Dreschner"
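# Validate user input before synthesis: reject texts longer than
# `input_text_max_length` and texts too short to synthesize (fewer than
# 2 words, or fewer than 2 characters for ja/ko/zh-cn).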
def validate_input(input_text, language):
    log_messages = ""
    if len(input_text) > input_text_max_length:
        gr.Warning("Text is too long! Please provide a shorter text.")
        log_messages += "Text is too long! Please provide a shorter text.\n"
        return log_messages
    language_code = language_dict.get(language, 'en')
    logger.info(f"Language [{language}], code: [{language_code}]")
    lang = lang_detect(input_text) if language == 'Auto' else language_code
    # Japanese, Korean, and Chinese are measured in characters; other
    # languages in words.
    if (lang not in ['ja', 'ko', 'zh-cn'] and len(input_text.split()) < 2) or \
            (lang in ['ja', 'ko', 'zh-cn'] and len(input_text) < 2):
        gr.Warning("Text is too short! Please provide a longer text.")
        log_messages += "Text is too short! Please provide a longer text.\n"
    # Return an empty string (falsy) when the input is valid; callers append
    # to this value, so it must never be None.
    return log_messages
def synthesize_speech(input_text, speaker_id, temperature=0.3, top_p=0.85, top_k=50, repetition_penalty=10.0, language='Auto'):
    """Generate speech with a built-in speaker voice."""
    global xtts_model
    log_messages = validate_input(input_text, language)
    if log_messages:
        return None, log_messages
    start = time.time()
    logger.info(f"Start processing text: {input_text[:30]}... [length: {len(input_text)}]")
    # inference
    wav_array, num_of_tokens = inference(input_text=input_text,
                                         language=language,
                                         speaker_id=speaker_id,
                                         gpt_cond_latent=None,
                                         speaker_embedding=None,
                                         temperature=temperature,
                                         top_p=top_p,
                                         top_k=top_k,
                                         repetition_penalty=float(repetition_penalty))
    end = time.time()
    processing_time = end - start
    tokens_per_second = num_of_tokens / processing_time
    logger.info(f"End processing text: {input_text[:30]}")
    message = f"💡 {tokens_per_second:.1f} tok/s • {num_of_tokens} tokens • in {processing_time:.2f} seconds"
    logger.info(message)
    log_messages += message
    # XTTS outputs 24 kHz audio.
    return (24000, wav_array), log_messages
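# Voice cloning path: optionally denoise the reference audio with
# DeepFilterNet (caching the enhanced file), extract the speaker's
# conditioning latents and embedding, then run inference.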
def generate_speech(input_text, speaker_reference_audio, enhance_speech, temperature=0.3, top_p=0.85, top_k=50, repetition_penalty=10.0, language='Auto'):
    """Generate speech by cloning the voice from a reference audio file."""
    global df_model, df_state, xtts_model
    log_messages = validate_input(input_text, language)
    if log_messages:
        return None, log_messages
    if not speaker_reference_audio:
        gr.Warning("Please provide at least one reference audio!")
        log_messages += "Please provide at least one reference audio!\n"
        return None, log_messages
    start = time.time()
    logger.info(f"Start processing text: {input_text[:30]}... [length: {len(input_text)}]")
    if enhance_speech:
        logger.info("Enhancing reference audio...")
        _, audio_file = os.path.split(speaker_reference_audio)
        enhanced_audio_path = os.path.join(enhance_audio_dir, f"{audio_file}.enh.wav")
        if not os.path.exists(enhanced_audio_path):
            if not df_model:
                df_model, df_state, _ = init_df()
            audio, _ = load_audio(speaker_reference_audio, sr=df_state.sr())
            # denoise audio
            enhanced_audio = enhance(df_model, df_state, audio)
            # save enhanced audio
            save_audio(enhanced_audio_path, enhanced_audio, sr=df_state.sr())
        # use the (possibly cached) enhanced file as the reference
        speaker_reference_audio = enhanced_audio_path
    gpt_cond_latent, speaker_embedding = xtts_model.get_conditioning_latents(
        audio_path=speaker_reference_audio,
        gpt_cond_len=xtts_model.config.gpt_cond_len,
        max_ref_length=xtts_model.config.max_ref_len,
        sound_norm_refs=xtts_model.config.sound_norm_refs,
    )
    # inference
    wav_array, num_of_tokens = inference(input_text=input_text,
                                         language=language,
                                         speaker_id=None,
                                         gpt_cond_latent=gpt_cond_latent,
                                         speaker_embedding=speaker_embedding,
                                         temperature=temperature,
                                         top_p=top_p,
                                         top_k=top_k,
                                         repetition_penalty=float(repetition_penalty))
    end = time.time()
    processing_time = end - start
    tokens_per_second = num_of_tokens / processing_time
    logger.info(f"End processing text: {input_text[:30]}")
    message = f"💡 {tokens_per_second:.1f} tok/s • {num_of_tokens} tokens • in {processing_time:.2f} seconds"
    logger.info(message)
    log_messages += message
    return (24000, wav_array), log_messages
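# Core synthesis loop: split the input into sentences (merging very short
# ones), detect the language per sentence in Auto mode, normalize
# Vietnamese text, and synthesize each chunk of at most `max_text_length`
# characters, concatenating the resulting waveforms.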
def inference(input_text, language, speaker_id=None, gpt_cond_latent=None, speaker_embedding=None, temperature=0.3, top_p=0.85, top_k=50, repetition_penalty=10.0):
    language_code = lang_detect(input_text) if language == 'Auto' else language_dict.get(language, 'en')
    # Split text by sentence
    if language_code in ["ja", "zh-cn"]:
        sentences = input_text.split("。")
    else:
        sentences = sent_tokenize(input_text)
    # merge short sentences into the next/previous ones
    sentences = merge_sentences(sentences)
    # Set a dynamic length penalty from -1.0 to 1.0 based on text length,
    # e.g. 0 chars -> -1.0, 90 chars -> 0.0, 180+ chars -> 1.0.
    max_text_length = 180
    dynamic_length_penalty = lambda text_length: (2 * (min(max_text_length, text_length) / max_text_length)) - 1
    if speaker_id is not None:
        gpt_cond_latent, speaker_embedding = xtts_model.speaker_manager.speakers[speaker_id].values()
    # inference
    out_wavs = []
    num_of_tokens = 0
    for sentence in sentences:
        if len(sentence.strip()) == 0:
            continue
        lang = lang_detect(sentence) if language == 'Auto' else language_code
        if lang == 'vi':
            sentence = normalize_vietnamese_text(sentence)
        text_tokens = torch.IntTensor(xtts_model.tokenizer.encode(sentence, lang=lang)).unsqueeze(0).to(xtts_model.device)
        num_of_tokens += text_tokens.shape[-1]
        txts = split_sentence(sentence, max_text_length=max_text_length)
        for txt in txts:
            logger.info(f"[{lang}] {txt}")
            try:
                out = xtts_model.inference(
                    text=txt,
                    language=lang,
                    gpt_cond_latent=gpt_cond_latent,
                    speaker_embedding=speaker_embedding,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    repetition_penalty=repetition_penalty,
                    length_penalty=dynamic_length_penalty(len(sentence)),
                    enable_text_splitting=False,
                )
                out_wavs.append(out["wav"])
            except Exception as e:
                logger.error(f"Error processing text: {e}")
    return np.concatenate(out_wavs), num_of_tokens
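# UI layout: three synthesis tabs ("Built-in Voice", "Reference Voice",
# "Clone Your Voice") plus an "Advanced Settings" tab whose sampling
# sliders (temperature, top_p, top_k, repetition penalty) feed all three.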
def build_gradio_ui():
    """Build the Gradio UI."""
    default_prompt = ("Hi, I am a multilingual text-to-speech AI model.\n"
                      "Bonjour, je suis un modèle d'IA de synthèse vocale multilingue.\n"
                      "Hallo, ich bin ein mehrsprachiges Text-zu-Sprache KI-Modell.\n"
                      "Ciao, sono un modello di intelligenza artificiale di sintesi vocale multilingue.\n"
                      "Привет, я многоязычная модель искусственного интеллекта, преобразующая текст в речь.\n"
                      "Xin chào, tôi là một mô hình AI chuyển đổi văn bản thành giọng nói đa ngôn ngữ.\n")
    with gr.Blocks(title="Coqui XTTS Demo", theme='jimmyvu/small_and_pretty') as ui:
        gr.Markdown(
            """
            # 🐸 Coqui-XTTS Text-to-Speech Demo
            Convert text to speech with advanced voice cloning and enhancement.
            Supports 17 languages, \u2605 **Vietnamese** \u2605 newly added.
            """
        )
with gr.Tab("Built-in Voice"): | |
with gr.Row(): | |
with gr.Column(): | |
input_text = gr.Text(label="Enter Text Here", | |
placeholder="Write the text you want to synthesize...", | |
value=default_prompt, | |
lines=5, | |
max_length=input_text_max_length) | |
speaker_id = gr.Dropdown(label="Speaker", choices=[k for k in xtts_model.speaker_manager.speakers.keys()], value=default_speaker_id) | |
language = gr.Dropdown(label="Target Language", choices=[k for k in language_dict.keys()], value=default_language) | |
synthesize_button = gr.Button("Generate Speech") | |
with gr.Column(): | |
audio_output = gr.Audio(label="Generated Audio") | |
log_output = gr.Text(label="Log Output") | |
with gr.Tab("Reference Voice"): | |
with gr.Row(): | |
with gr.Column(): | |
input_text_generate = gr.Text(label="Enter Text Here", | |
placeholder="Write the text you want to synthesize...", | |
lines=5, | |
max_length=input_text_max_length) | |
speaker_reference_audio = gr.Audio( | |
label="Speaker reference audio:", | |
type="filepath", | |
editable=False, | |
min_length=3, | |
max_length=300, | |
value=default_speaker_reference_audio | |
) | |
enhance_speech = gr.Checkbox(label="Enhance Reference Audio", value=False) | |
language_generate = gr.Dropdown(label="Target Language", choices=[k for k in language_dict.keys()], value=default_language) | |
generate_button = gr.Button("Generate Speech") | |
with gr.Column(): | |
audio_output_generate = gr.Audio(label="Generated Audio") | |
log_output_generate = gr.Text(label="Log Output") | |
with gr.Tab("Clone Your Voice"): | |
with gr.Row(): | |
with gr.Column(): | |
input_text_mic = gr.Text(label="Enter Text Here", | |
placeholder="Write the text you want to synthesize...", | |
lines=5, | |
max_length=input_text_max_length) | |
mic_ref_audio = gr.Audio(label="Record Reference Audio", sources=["microphone"]) | |
enhance_speech_mic = gr.Checkbox(label="Enhance Reference Audio", value=True) | |
language_mic = gr.Dropdown(label="Target Language", choices=[k for k in language_dict.keys()], value=default_language) | |
generate_button_mic = gr.Button("Generate Speech") | |
with gr.Column(): | |
audio_output_mic = gr.Audio(label="Generated Audio") | |
log_output_mic = gr.Text(label="Log Output") | |
        def process_mic_and_generate(input_text_mic, mic_ref_audio, enhance_speech_mic, temperature, top_p, top_k, repetition_penalty, language_mic):
            # Save the recorded microphone audio (a (sample_rate, ndarray)
            # tuple) to a temporary wav file, then run the cloning path on it.
            if mic_ref_audio:
                data = str(time.time()).encode("utf-8")
                file_hash = hashlib.sha1(data).hexdigest()[:10]
                output_path = os.path.join(temp_dir, f"mic_{file_hash}.wav")
                sample_rate, audio_array = mic_ref_audio
                # Gradio records integer PCM; normalize to float in [-1, 1].
                if np.issubdtype(audio_array.dtype, np.integer):
                    audio_array = audio_array.astype(np.float32) / np.iinfo(audio_array.dtype).max
                torch_audio = torch.from_numpy(audio_array.astype(np.float32))
                try:
                    torchaudio.save(output_path, torch_audio.unsqueeze(0), sample_rate)
                    return generate_speech(input_text_mic, output_path, enhance_speech_mic, temperature, top_p, top_k, repetition_penalty, language_mic)
                except Exception as e:
                    logger.error(f"Error saving audio file: {e}")
                    return None, f"Error saving audio file: {e}"
            else:
                return None, "Please record an audio!"
with gr.Tab("Advanced Settings"): | |
with gr.Row(): | |
with gr.Column(): | |
temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.3, step=0.05) | |
repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=50.0, value=9.5, step=1.0) | |
with gr.Column(): | |
top_p = gr.Slider(label="Top P", minimum=0.5, maximum=1.0, value=0.85, step=0.05) | |
top_k = gr.Slider(label="Top K", minimum=0, maximum=100, value=50, step=5) | |
        synthesize_button.click(
            synthesize_speech,
            inputs=[input_text, speaker_id, temperature, top_p, top_k, repetition_penalty, language],
            outputs=[audio_output, log_output],
        )
        generate_button.click(
            generate_speech,
            inputs=[input_text_generate, speaker_reference_audio, enhance_speech, temperature, top_p, top_k, repetition_penalty, language_generate],
            outputs=[audio_output_generate, log_output_generate],
        )
        generate_button_mic.click(
            process_mic_and_generate,
            inputs=[input_text_mic, mic_ref_audio, enhance_speech_mic, temperature, top_p, top_k, repetition_penalty, language_mic],
            outputs=[audio_output_mic, log_output_mic],
        )
    return ui
if __name__ == "__main__":
    ui = build_gradio_ui()
    ui.launch(debug=False)