Spaces:
Sleeping
Sleeping
Jimmy Vu
commited on
Commit
·
2e99c77
1
Parent(s):
8460d0e
Add files
Browse files- .gitignore +173 -0
- README.md +1 -1
- app.py +0 -7
- gradio_app.py +298 -0
- requirements.txt +12 -0
- utils/__init__.py +0 -0
- utils/cuda_toolkit.py +19 -0
- utils/logger.py +146 -0
- utils/sentence.py +75 -0
- utils/spaces.py +10 -0
- utils/vietnamese_normalization.py +360 -0
.gitignore
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
WadaSNR/
|
2 |
+
.idea/
|
3 |
+
*.pyc
|
4 |
+
.DS_Store
|
5 |
+
./__init__.py
|
6 |
+
# Byte-compiled / optimized / DLL files
|
7 |
+
__pycache__/
|
8 |
+
*.py[cod]
|
9 |
+
*$py.class
|
10 |
+
|
11 |
+
# C extensions
|
12 |
+
*.so
|
13 |
+
|
14 |
+
# Distribution / packaging
|
15 |
+
.Python
|
16 |
+
build/
|
17 |
+
develop-eggs/
|
18 |
+
dist/
|
19 |
+
downloads/
|
20 |
+
eggs/
|
21 |
+
.eggs/
|
22 |
+
lib/
|
23 |
+
lib64/
|
24 |
+
parts/
|
25 |
+
sdist/
|
26 |
+
var/
|
27 |
+
wheels/
|
28 |
+
*.egg-info/
|
29 |
+
.installed.cfg
|
30 |
+
*.egg
|
31 |
+
MANIFEST
|
32 |
+
|
33 |
+
# PyInstaller
|
34 |
+
# Usually these files are written by a python script from a template
|
35 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
36 |
+
*.manifest
|
37 |
+
*.spec
|
38 |
+
|
39 |
+
# Installer logs
|
40 |
+
pip-log.txt
|
41 |
+
pip-delete-this-directory.txt
|
42 |
+
|
43 |
+
# Unit test / coverage reports
|
44 |
+
htmlcov/
|
45 |
+
.tox/
|
46 |
+
.coverage
|
47 |
+
.coverage.*
|
48 |
+
.cache
|
49 |
+
nosetests.xml
|
50 |
+
coverage.xml
|
51 |
+
*.cover
|
52 |
+
.hypothesis/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
.static_storage/
|
61 |
+
.media/
|
62 |
+
local_settings.py
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# pyenv
|
81 |
+
.python-version
|
82 |
+
|
83 |
+
# celery beat schedule file
|
84 |
+
celerybeat-schedule
|
85 |
+
|
86 |
+
# SageMath parsed files
|
87 |
+
*.sage.py
|
88 |
+
|
89 |
+
# Environments
|
90 |
+
.env
|
91 |
+
.venv
|
92 |
+
env/
|
93 |
+
venv/
|
94 |
+
ENV/
|
95 |
+
env.bak/
|
96 |
+
venv.bak/
|
97 |
+
|
98 |
+
# Spyder project settings
|
99 |
+
.spyderproject
|
100 |
+
.spyproject
|
101 |
+
|
102 |
+
# Rope project settings
|
103 |
+
.ropeproject
|
104 |
+
|
105 |
+
# mkdocs documentation
|
106 |
+
/site
|
107 |
+
|
108 |
+
# mypy
|
109 |
+
.mypy_cache/
|
110 |
+
|
111 |
+
# vim
|
112 |
+
*.swp
|
113 |
+
*.swm
|
114 |
+
*.swn
|
115 |
+
*.swo
|
116 |
+
|
117 |
+
# pytorch models
|
118 |
+
*.pth
|
119 |
+
*.pth.tar
|
120 |
+
!dummy_speakers.pth
|
121 |
+
result/
|
122 |
+
|
123 |
+
# setup.py
|
124 |
+
version.py
|
125 |
+
|
126 |
+
# jupyter dummy files
|
127 |
+
core
|
128 |
+
|
129 |
+
# ignore local datasets
|
130 |
+
recipes/WIP/*
|
131 |
+
recipes/ljspeech/LJSpeech-1.1/*
|
132 |
+
recipes/vctk/VCTK/*
|
133 |
+
recipes/**/*.npy
|
134 |
+
recipes/**/*.json
|
135 |
+
VCTK-Corpus-removed-silence/*
|
136 |
+
|
137 |
+
# ignore training logs
|
138 |
+
trainer_*_log.txt
|
139 |
+
|
140 |
+
# files used internally for dev, test etc.
|
141 |
+
tests/outputs/*
|
142 |
+
tests/train_outputs/*
|
143 |
+
TODO.txt
|
144 |
+
.vscode/*
|
145 |
+
data/*
|
146 |
+
notebooks/data/*
|
147 |
+
TTS/tts/utils/monotonic_align/core.c
|
148 |
+
.vscode-upload.json
|
149 |
+
temp_build/*
|
150 |
+
events.out*
|
151 |
+
old_configs/*
|
152 |
+
model_importers/*
|
153 |
+
model_profiling/*
|
154 |
+
docs/source/TODO/*
|
155 |
+
.noseids
|
156 |
+
.dccache
|
157 |
+
log.txt
|
158 |
+
umap.png
|
159 |
+
*.out
|
160 |
+
SocialMedia.txt
|
161 |
+
output.wav
|
162 |
+
tts_output.wav
|
163 |
+
deps.json
|
164 |
+
speakers.json
|
165 |
+
internal/*
|
166 |
+
*_pitch.npy
|
167 |
+
*_phoneme.npy
|
168 |
+
wandb
|
169 |
+
depot/*
|
170 |
+
coqui_recipes/*
|
171 |
+
local_scripts/*
|
172 |
+
coqui_demos/*
|
173 |
+
cache/*
|
README.md
CHANGED
@@ -5,7 +5,7 @@ colorFrom: green
|
|
5 |
colorTo: gray
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.15.0
|
8 |
-
app_file:
|
9 |
pinned: false
|
10 |
license: mpl-2.0
|
11 |
short_description: Coqui-XTTS Text-to-Speech Demo with Vietnamese
|
|
|
5 |
colorTo: gray
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.15.0
|
8 |
+
app_file: gradio_app.py
|
9 |
pinned: false
|
10 |
license: mpl-2.0
|
11 |
short_description: Coqui-XTTS Text-to-Speech Demo with Vietnamese
|
app.py
DELETED
@@ -1,7 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
|
3 |
-
def greet(name):
|
4 |
-
return "Hello " + name + "!!"
|
5 |
-
|
6 |
-
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
|
7 |
-
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gradio_app.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import uuid
|
4 |
+
import hashlib
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import torch
|
9 |
+
import torchaudio
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
from underthesea import sent_tokenize
|
13 |
+
from df.enhance import enhance, init_df, load_audio, save_audio
|
14 |
+
|
15 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
16 |
+
|
17 |
+
from langdetect import detect
|
18 |
+
|
19 |
+
from utils.vietnamese_normalization import normalize_vietnamese_text
|
20 |
+
from utils.logger import setup_logger
|
21 |
+
from utils.sentence import split_sentence, merge_sentences
|
22 |
+
|
23 |
+
import warnings
|
24 |
+
warnings.filterwarnings("ignore")
|
25 |
+
|
26 |
+
logger = setup_logger(__file__)
|
27 |
+
|
28 |
+
# DeepFilterNet model and state; both stay None until the first request
# that asks for reference-audio enhancement initialises them.
df_model = None
df_state = None

# All working directories live under <app dir>/cache.
APP_DIR = os.path.dirname(os.path.abspath(__file__))
checkpoint_dir = f"{APP_DIR}/cache"
temp_dir = f"{APP_DIR}/cache/temp/"
sample_audio_dir = f"{APP_DIR}/cache/audio_samples/"
enhance_audio_dir = f"{APP_DIR}/cache/audio_enhances/"
for d in (checkpoint_dir, temp_dir, sample_audio_dir, enhance_audio_dir):
    os.makedirs(d, exist_ok=True)

# UI label -> language code. 'Auto' maps to the pseudo-code 'auto', which
# callers resolve through lang_detect().
language_dict = {
    'English': 'en',
    'Español (Spanish)': 'es',
    'Français (French)': 'fr',
    'Deutsch (German)': 'de',
    'Italiano (Italian)': 'it',
    'Português (Portuguese)': 'pt',
    'Polski (Polish)': 'pl',
    'Türkçe (Turkish)': 'tr',
    'Русский (Russian)': 'ru',
    'Nederlands (Dutch)': 'nl',
    'Čeština (Czech)': 'cs',
    'العربية (Arabic)': 'ar',
    '中文 (Chinese)': 'zh-cn',
    'Magyar nyelv (Hungarian)': 'hu',
    '한국어 (Korean)': 'ko',
    '日本語 (Japanese)': 'ja',
    'Tiếng Việt (Vietnamese)': 'vi',
    'Auto': 'auto',
}

default_language = 'Auto'
# Codes in the same order as the labels above (includes 'auto').
language_codes = list(language_dict.values())
|
47 |
+
def lang_detect(text):
    """Detect the language of *text* and map it onto a supported code.

    Args:
        text (str): Text whose language should be detected.

    Returns:
        str: ``'zh-cn'`` when the detector reports ``'zh-tw'``, the detected
        code when it is listed in ``language_codes``, and ``'en'`` otherwise
        (including when detection itself fails).
    """
    try:
        lang = detect(text)
    except Exception:
        # Detection is best-effort: fall back to English on any detector
        # error, but let KeyboardInterrupt/SystemExit propagate (a bare
        # ``except:`` would have swallowed those too).
        return 'en'
    if lang == 'zh-tw':
        # Only 'zh-cn' is offered by the app, so fold 'zh-tw' into it.
        return 'zh-cn'
    return lang if lang in language_codes else 'en'
|
55 |
+
|
56 |
+
input_text_max_length = 3000
|
57 |
+
use_deepspeed = False
|
58 |
+
|
59 |
+
try:
|
60 |
+
import spaces
|
61 |
+
except ImportError:
|
62 |
+
from utils import spaces
|
63 |
+
|
64 |
+
xtts_model = None
|
65 |
+
def load_model():
    """Fetch the XTTS checkpoint files and initialise the global ``xtts_model``.

    Downloads the ``jimmyvu/xtts`` snapshot into ``checkpoint_dir`` (only
    ``*.safetensors``, ``*.wav`` and ``*.json`` files; ``*.pth`` is ignored),
    builds the model from ``config.json``, loads the safetensors weights and
    moves the model to CUDA when a GPU is available.
    """
    global xtts_model

    from TTS.tts.configs.xtts_config import XttsConfig
    from TTS.tts.models.xtts import Xtts

    snapshot_download(
        repo_id="jimmyvu/xtts",
        local_dir=checkpoint_dir,
        allow_patterns=["*.safetensors", "*.wav", "*.json"],
        ignore_patterns="*.pth",
    )

    config_path = os.path.join(checkpoint_dir, "config.json")
    config = XttsConfig()
    config.load_json(config_path)
    xtts_model = Xtts.init_from_config(config)

    logger.info("Loading model...")
    xtts_model.load_safetensors_checkpoint(
        config,
        checkpoint_dir=checkpoint_dir,
        use_deepspeed=use_deepspeed,
    )
    if torch.cuda.is_available():
        xtts_model.cuda()
    logger.info(f"Successfully loaded model from {checkpoint_dir}")
|
87 |
+
|
88 |
+
load_model()
|
89 |
+
|
90 |
+
default_speaker_reference_audio = os.path.join(sample_audio_dir, 'harvard.wav')
|
91 |
+
|
92 |
+
@spaces.GPU
def generate_speech(input_text, speaker_reference_audio, enhance_speech, temperature=0.3, top_p=0.85, top_k=50, repetition_penalty=10.0, language='Auto', *args):
    """Validate the request, build speaker conditioning and synthesise speech.

    Args:
        input_text (str): Text to synthesise; rejected when longer than
            ``input_text_max_length`` or shorter than two words/characters.
        speaker_reference_audio (str): Path of the reference audio file.
        enhance_speech (bool): When true, the reference audio is denoised
            with DeepFilterNet first (result cached in ``enhance_audio_dir``).
        temperature, top_p, top_k, repetition_penalty: Sampling parameters
            forwarded unchanged to ``inference``.
        language (str): A key of ``language_dict``; unknown labels fall back
            to ``'en'``.
        *args: Ignored.

    Returns:
        tuple: ``((24000, wav_array), log_messages)`` on success, or
        ``(None, log_messages)`` when validation fails.
    """
    global df_model, df_state, xtts_model
    log_messages = ""
    if len(input_text) > input_text_max_length:
        gr.Warning("Text is too long! Please provide a shorter text.")
        log_messages += "Text is too long! Please provide a shorter text.\n"
        return None, log_messages

    language_code = language_dict.get(language, 'en')
    logger.info(f"Language [{language}], code: [{language_code}]")
    lang = lang_detect(input_text) if language_code == 'auto' else language_code

    # Japanese/Korean/Chinese input is measured in characters, everything
    # else in whitespace-separated words. The Korean code used throughout
    # this app is 'ko' (see language_dict); the previous 'kr' never matched,
    # so Korean text silently took the word-count branch.
    char_counted_langs = ('ja', 'ko', 'zh-cn')
    if lang in char_counted_langs:
        too_short = len(input_text) < 2
    else:
        too_short = len(input_text.split()) < 2
    if too_short:
        gr.Warning("Text is too short! Please provide a longer text.")
        log_messages += "Text is too short! Please provide a longer text.\n"
        return None, log_messages

    if not speaker_reference_audio:
        gr.Warning("Please provide at least one reference audio!")
        log_messages += "Please provide at least one reference audio!\n"
        return None, log_messages

    start = time.time()
    logger.info(f"Start processing text: {input_text[:30]}... [length: {len(input_text)}]")

    if enhance_speech:
        logger.info("Enhancing reference audio...")
        audio_file = os.path.basename(speaker_reference_audio)
        # NOTE(review): the cache key is the file name only, so two different
        # recordings sharing a name would reuse the first enhanced file.
        enhanced_audio_path = os.path.join(enhance_audio_dir, f"{audio_file}.enh.wav")
        if not os.path.exists(enhanced_audio_path):
            if not df_model:
                # Initialise DeepFilterNet once, on first use.
                df_model, df_state, _ = init_df()
            audio, _ = load_audio(speaker_reference_audio, sr=df_state.sr())
            # denoise audio
            enhanced_audio = enhance(df_model, df_state, audio)
            # save enhanced audio
            save_audio(enhanced_audio_path, enhanced_audio, sr=df_state.sr())
        speaker_reference_audio = enhanced_audio_path

    gpt_cond_latent, speaker_embedding = xtts_model.get_conditioning_latents(
        audio_path=speaker_reference_audio,
        gpt_cond_len=xtts_model.config.gpt_cond_len,
        max_ref_length=xtts_model.config.max_ref_len,
        sound_norm_refs=xtts_model.config.sound_norm_refs,
    )

    # Split text by sentence
    if lang in ["ja", "zh-cn"]:
        sentences = input_text.split("。")
    else:
        sentences = sent_tokenize(input_text)
    # merge short sentences to next/prev ones
    sentences = merge_sentences(sentences)
    # inference; the unresolved language_code is passed on purpose so that
    # 'auto' is re-detected per sentence inside inference().
    wav_array = inference(sentences, language_code, gpt_cond_latent, speaker_embedding, temperature, top_p, top_k, repetition_penalty)
    end = time.time()
    logger.info(f"End processing text: {input_text[:30]}... Processing time: {end - start:.2f}s")
    log_messages += f"Processing time: {end - start:.2f}s"
    # 24000 is presumably the model's output sample rate — TODO confirm.
    return (24000, wav_array), log_messages
|
153 |
+
|
154 |
+
|
155 |
+
def inference(sentences, language_code, gpt_cond_latent, speaker_embedding, temperature, top_p, top_k, repetition_penalty):
    """Synthesise each sentence with the global XTTS model and join the audio.

    Args:
        sentences: Iterable of sentence strings; blank entries are skipped.
        language_code (str): Language code, or ``'auto'`` to detect the
            language of every sentence with ``lang_detect``.
        gpt_cond_latent, speaker_embedding: Speaker conditioning forwarded to
            ``xtts_model.inference``.
        temperature, top_p, top_k, repetition_penalty: Sampling parameters
            forwarded to ``xtts_model.inference``.

    Returns:
        numpy.ndarray: Concatenation of the ``"wav"`` outputs of all chunks
        that were synthesised successfully.

    Raises:
        ValueError: If no chunk produced audio (all sentences blank or every
            synthesis call failed).
    """
    max_text_length = 180

    def dynamic_length_penalty(text_length):
        # Length penalty grows linearly from -1.0 to 1.0 with the text
        # length, saturating at max_text_length characters.
        return (2 * (min(max_text_length, text_length) / max_text_length)) - 1

    out_wavs = []
    for sentence in sentences:
        if len(sentence.strip()) == 0:
            continue
        lang = lang_detect(sentence) if language_code == 'auto' else language_code
        if lang == 'vi':
            sentence = normalize_vietnamese_text(sentence)
        # split too long sentence
        texts = split_sentence(sentence) if len(sentence) > max_text_length else [sentence]
        for text in texts:
            if not text.strip():
                # split_sentence may hand back an empty half; nothing to say.
                continue
            logger.info(f"[{lang}] {text}")
            try:
                out = xtts_model.inference(
                    text=text,
                    language=lang,
                    gpt_cond_latent=gpt_cond_latent,
                    speaker_embedding=speaker_embedding,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    repetition_penalty=repetition_penalty,
                    length_penalty=dynamic_length_penalty(len(text)),
                    enable_text_splitting=True,
                )
                out_wavs.append(out["wav"])
            except Exception as e:
                # A failing chunk is logged and skipped so that the remaining
                # chunks can still be synthesised.
                logger.error(f"Error processing text: {text} - {e}")

    if not out_wavs:
        # np.concatenate([]) would raise an opaque ValueError; keep the same
        # exception type but say what actually went wrong.
        raise ValueError("No audio was generated: every sentence was empty or failed during synthesis.")
    return np.concatenate(out_wavs)
|
189 |
+
|
190 |
+
def build_gradio_ui():
    """Build the Gradio Blocks UI and return it (the caller launches it).

    The UI has three tabs: text-to-speech with an uploaded reference audio,
    voice cloning from a microphone recording, and shared advanced sampling
    settings. Both "Generate Speech" buttons end up in ``generate_speech``.

    Returns:
        gr.Blocks: The assembled, not yet launched, interface.
    """
    theme = gr.Theme.from_hub('JohnSmith9982/small_and_pretty')
    # Theme overrides, applied in this order on top of the hub theme.
    theme_overrides = {
        'button_secondary_background_fill': '#fcd53f',
        'checkbox_border_color': '#02c160',
        'input-border-width': '1px',
        'input-background-fill': '#ffffff',
        'input-background-fill_focus': '#ffffff',
        'input-border-color': '#d1d5db',
        'input-border-color_focus': '#fcd53f',
    }
    for attr_name, attr_value in theme_overrides.items():
        setattr(theme, attr_name, attr_value)

    default_prompt = ("Hi, I am a multilingual text-to-speech AI model.\n"
                      "Bonjour, je suis un modèle d'IA de synthèse vocale multilingue.\n"
                      "Hallo, ich bin ein mehrsprachiges Text-zu-Sprache KI-Modell.\n"
                      "Ciao, sono un modello di intelligenza artificiale di sintesi vocale multilingue.\n"
                      "Привет, я многоязычная модель искусственного интеллекта, преобразующая текст в речь.\n"
                      "Xin chào, tôi là một mô hình AI chuyển đổi văn bản thành giọng nói đa ngôn ngữ.\n")

    language_choices = list(language_dict)

    with gr.Blocks(title="Coqui XTTS Demo", theme=theme) as ui:
        gr.Markdown(
            """
            # 🐸 Coqui-XTTS Text-to-Speech Demo
            Convert text to speech with advanced voice cloning and enhancement.
            Support 17 languages, \u2605 **Vietnamese** \u2605 newly added.
            """
        )

        with gr.Tab("Text to Speech"):
            with gr.Row():
                with gr.Column():
                    input_text = gr.Text(
                        label="Enter Text Here",
                        placeholder="Write the text you want to convert...",
                        value=default_prompt,
                        lines=5,
                        max_length=input_text_max_length,
                    )
                    speaker_reference_audio = gr.Audio(
                        label="Speaker reference audio:",
                        type="filepath",
                        editable=False,
                        min_length=3,
                        max_length=300,
                        value=default_speaker_reference_audio,
                    )
                    enhance_speech = gr.Checkbox(label="Enhance Reference Audio", value=False)
                    language = gr.Dropdown(label="Target Language", choices=language_choices, value=default_language)
                    generate_button = gr.Button("Generate Speech")
                with gr.Column():
                    audio_output = gr.Audio(label="Generated Audio")
                    log_output = gr.Text(label="Log Output")

        with gr.Tab("Clone Your Voice"):
            with gr.Row():
                with gr.Column():
                    input_text_mic = gr.Text(
                        label="Enter Text Here",
                        placeholder="Write the text you want to convert...",
                        lines=5,
                        max_length=input_text_max_length,
                    )
                    mic_ref_audio = gr.Audio(label="Record Reference Audio", sources=["microphone"])
                    enhance_speech_mic = gr.Checkbox(label="Enhance Reference Audio", value=True)
                    language_mic = gr.Dropdown(label="Target Language", choices=language_choices, value=default_language)
                    generate_button_mic = gr.Button("Generate Speech")
                with gr.Column():
                    audio_output_mic = gr.Audio(label="Generated Audio")
                    log_output_mic = gr.Text(label="Log Output")

        def process_mic_and_generate(input_text_mic, mic_ref_audio, enhance_speech_mic, temperature, top_p, top_k, repetition_penalty, language_mic):
            """Save the microphone recording to a temp wav and synthesise with it."""
            if not mic_ref_audio:
                return None, "Please record an audio!"

            # File name derived from the current timestamp.
            stamp = str(time.time()).encode("utf-8")
            digest = hashlib.sha1(stamp).hexdigest()[:10]
            output_path = os.path.join(temp_dir, f"mic_{digest}.wav")

            # mic_ref_audio is indexed as (sample_rate, samples).
            sample_rate = mic_ref_audio[0]
            torch_audio = torch.from_numpy(mic_ref_audio[1].astype(float))
            try:
                torchaudio.save(output_path, torch_audio.unsqueeze(0), sample_rate)
                return generate_speech(input_text_mic, output_path, enhance_speech_mic, temperature, top_p, top_k, repetition_penalty, language_mic)
            except Exception as e:
                logger.error(f"Error saving audio file: {e}")
                return None, f"Error saving audio file: {e}"

        with gr.Tab("Advanced Settings"):
            with gr.Row():
                with gr.Column():
                    temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.3, step=0.05)
                    repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=50.0, value=9.5, step=1.0)

                with gr.Column():
                    top_p = gr.Slider(label="Top P", minimum=0.5, maximum=1.0, value=0.85, step=0.05)
                    top_k = gr.Slider(label="Top K", minimum=0, maximum=100, value=50, step=5)

        # Both tabs share the sliders from the "Advanced Settings" tab.
        sampling_inputs = [temperature, top_p, top_k, repetition_penalty]

        generate_button.click(
            generate_speech,
            inputs=[input_text, speaker_reference_audio, enhance_speech, *sampling_inputs, language],
            outputs=[audio_output, log_output],
        )

        generate_button_mic.click(
            process_mic_and_generate,
            inputs=[input_text_mic, mic_ref_audio, enhance_speech_mic, *sampling_inputs, language_mic],
            outputs=[audio_output_mic, log_output_mic],
        )

    return ui
|
295 |
+
|
296 |
+
if __name__ == "__main__":
|
297 |
+
ui = build_gradio_ui()
|
298 |
+
ui.launch(debug=False)
|
requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio==4.44.1
|
2 |
+
deepfilternet==0.5.6
|
3 |
+
underthesea==6.8.0
|
4 |
+
deepspeed
|
5 |
+
colorama
|
6 |
+
pyvi
|
7 |
+
langdetect
|
8 |
+
cutlet
|
9 |
+
unidic
|
10 |
+
# for Japanese
|
11 |
+
# python -m unidic download
|
12 |
+
git+https://github.com/quangvu3/coqui-xtts.git
|
utils/__init__.py
ADDED
File without changes
|
utils/cuda_toolkit.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import subprocess
|
2 |
+
import os
|
3 |
+
|
4 |
+
def install_cuda_toolkit():
    """Download and run the CUDA 12.2 toolkit installer, then export CUDA env vars.

    The installer is fetched with ``wget`` into ``/tmp``, made executable and
    run with ``--silent --toolkit``. Afterwards ``CUDA_HOME``, ``PATH``,
    ``LD_LIBRARY_PATH`` and ``TORCH_CUDA_ARCH_LIST`` are set in ``os.environ``.
    Exit codes of the subprocesses are not checked.
    """
    # Alternative installer (CUDA 11.8), kept for reference:
    # https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
    installer_url = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
    installer_path = f"/tmp/{os.path.basename(installer_url)}"

    # Download, mark executable, run: strictly in this order.
    for command in (
        ["wget", "-q", installer_url, "-O", installer_path],
        ["chmod", "+x", installer_path],
        [installer_path, "--silent", "--toolkit"],
    ):
        subprocess.call(command)

    cuda_home = "/usr/local/cuda"
    os.environ["CUDA_HOME"] = cuda_home
    os.environ["PATH"] = f"{cuda_home}/bin:{os.environ['PATH']}"
    previous_ld_path = os.environ.get("LD_LIBRARY_PATH", "")
    os.environ["LD_LIBRARY_PATH"] = f"{cuda_home}/lib:{previous_ld_path}"
    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
|
utils/logger.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import sys
|
3 |
+
from pathlib import Path
|
4 |
+
from datetime import datetime
|
5 |
+
import colorama
|
6 |
+
from colorama import Fore, Back, Style
|
7 |
+
from typing import Optional, Union
|
8 |
+
import re
|
9 |
+
import traceback
|
10 |
+
import copy
|
11 |
+
import os
|
12 |
+
|
13 |
+
# Initialize colorama
|
14 |
+
colorama.init()
|
15 |
+
|
16 |
+
class ColoredFormatter(logging.Formatter):
    """Colored formatter for structured log output.

    This formatter adds color-coding, icons, timestamps, and file location
    information to log messages. It supports different color schemes for
    different log levels and includes special formatting for exceptions.

    Attributes:
        COLORS (dict): Color schemes for different log levels, including:
            - color: Foreground color
            - style: Text style (dim, normal, bright)
            - icon: Emoji icon for the log level
            - bg: Background color (for critical logs)
    """

    COLORS = {
        'DEBUG': {
            'color': Fore.CYAN,
            'style': Style.DIM,
            'icon': '🔍'
        },
        'INFO': {
            'color': Fore.GREEN,
            'style': Style.NORMAL,
            'icon': 'ℹ️'
        },
        'WARNING': {
            'color': Fore.YELLOW,
            'style': Style.BRIGHT,
            'icon': '⚠️'
        },
        'ERROR': {
            'color': Fore.RED,
            'style': Style.BRIGHT,
            'icon': '❌'
        },
        'CRITICAL': {
            'color': Fore.WHITE,
            'style': Style.BRIGHT,
            'bg': Back.RED,
            'icon': '💀'
        }
    }

    def format(self, record: logging.LogRecord) -> str:
        """Format a log record with color and structure.

        The output consists of these parts joined by " | ":
        - Timestamp in HH:MM:SS.mmm format
        - File location (filename:line)
        - Color-coded level name with icon
        - Color-coded message (with %-style arguments interpolated)
        - Formatted exception traceback if present

        Args:
            record (logging.LogRecord): Log record to format.

        Returns:
            str: Formatted log message with color and structure.
        """
        # Get color scheme; unknown level names get a neutral default.
        scheme = self.COLORS.get(record.levelname, {
            'color': Fore.WHITE,
            'style': Style.NORMAL,
            'icon': '•'
        })
        # Only some schemes (CRITICAL) define a background color.
        background = scheme.get('bg', '')

        # Format timestamp; [:-3] trims microseconds down to milliseconds.
        timestamp = datetime.fromtimestamp(record.created).strftime('%H:%M:%S.%f')[:-3]

        # Get file location
        file_location = f"{os.path.basename(record.pathname)}:{record.lineno}"

        # getMessage() merges record.args into record.msg, so calls such as
        # logger.info("x=%s", x) render correctly; record.msg alone would
        # print the raw format string.
        message = record.getMessage()

        components = [
            f"{Fore.BLUE}{timestamp}{Style.RESET_ALL}",
            f"{Fore.WHITE}{Style.DIM}{file_location}{Style.RESET_ALL}",
            f"{background}{scheme['color']}{scheme['style']}{scheme['icon']} {record.levelname:8}{Style.RESET_ALL}",
            f"{background}{scheme['color']}{message}{Style.RESET_ALL}",
        ]

        # Add exception info
        if record.exc_info:
            components.append(
                f"\n{Fore.RED}{Style.BRIGHT}"
                f"{''.join(traceback.format_exception(*record.exc_info))}"
                f"{Style.RESET_ALL}"
            )

        return " | ".join(components)
|
111 |
+
|
112 |
+
|
113 |
+
def setup_logger(
    name: Optional[Union[str, Path]] = None,
    level: int = logging.INFO
) -> logging.Logger:
    """Set up a colored logger.

    Creates or retrieves a logger and, if it has no handlers yet, attaches a
    stdout handler using ``ColoredFormatter``. If a ``.py`` file path is given
    as the name, the file name without extension is used as the logger name.

    Args:
        name (Optional[Union[str, Path]], optional): Logger name or __file__ for
            module name. Defaults to None (the root logger).
        level (int, optional): Logging level. Defaults to logging.INFO.

    Returns:
        logging.Logger: Configured logger instance.
    """
    if isinstance(name, (str, Path)):
        as_path = Path(name)
        if as_path.suffix == '.py':
            # Get logger name from file path
            name = as_path.stem
        elif isinstance(name, Path):
            # logging.getLogger() only accepts str names; a Path without a
            # ".py" suffix used to reach it unchanged and raise TypeError.
            name = str(name)

    # Get or create logger
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Only add handler if none exists, so repeated calls do not duplicate
    # output.
    if not logger.handlers:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(ColoredFormatter())
        logger.addHandler(console_handler)

    return logger
|
utils/sentence.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
def split_sentence(sentence, delimiters=",;-!?"):
    """
    Splits a sentence into two halves, prioritizing the delimiter closest to the middle.
    If no usable delimiter is found, it splits at the space nearest to the
    middle so words are not cut; a single word is cut at the midpoint.

    A delimiter is only usable when text remains on both sides of it. This
    keeps a leading or trailing delimiter (e.g. the "?" that ends a question)
    from producing an empty half and leaving the sentence effectively unsplit.
    The chosen delimiter itself is dropped from the result.

    Args:
        sentence (str): The input sentence to split.
        delimiters (str): A string of delimiters to prioritize for splitting (default: ",;-!?").

    Returns:
        tuple: A tuple containing the two halves of the sentence, each stripped.
    """
    midpoint = len(sentence) // 2

    # Delimiter positions that leave non-blank text on both sides.
    usable_delimiters = [
        index
        for index, char in enumerate(sentence)
        if char in delimiters
        and sentence[:index].strip()
        and sentence[index + 1:].strip()
    ]

    if usable_delimiters:
        # Split at the usable delimiter closest to the midpoint (the leftmost
        # one wins a tie).
        cut = min(usable_delimiters, key=lambda index: abs(index - midpoint))
        return sentence[:cut].strip(), sentence[cut + 1:].strip()

    # No usable delimiter: look for the nearest space around the midpoint.
    left_space = sentence.rfind(" ", 0, midpoint)
    right_space = sentence.find(" ", midpoint)

    if left_space == -1 and right_space == -1:
        # No spaces found (single word), split at midpoint
        split_index = midpoint
    elif left_space == -1:
        split_index = right_space
    elif right_space == -1:
        split_index = left_space
    else:
        # Choose the closest space to the midpoint; the left one wins a tie.
        split_index = left_space if (midpoint - left_space) <= (right_space - midpoint) else right_space

    return sentence[:split_index].strip(), sentence[split_index:].strip()
|
54 |
+
|
55 |
+
|
56 |
+
def merge_sentences(sentences):
    """Merge short sentences into their neighbours.

    Walking left to right, a sentence of six words or fewer absorbs the
    following sentences (joined with a single space) until it has more than
    six words or the input is exhausted. Afterwards, while the last merged
    sentence has fewer than six words, it is appended to the one before it.

    Args:
        sentences (list[str]): Sentences in reading order.

    Returns:
        list[str]: The merged sentences.
    """
    merged = []
    total = len(sentences)
    position = 0
    while position < total:
        current = sentences[position]
        position += 1
        # Forward pass: keep absorbing the next sentence while too short.
        while len(current.split()) <= 6 and position < total:
            current += ' ' + sentences[position]
            position += 1
        merged.append(current)

    # Backward pass: a short tail is folded into its predecessor.
    while len(merged) > 1 and len(merged[-1].split()) < 6:
        tail = merged.pop()
        merged[-1] += ' ' + tail
    return merged
|
utils/spaces.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
def GPU(func=None, **_options):
    """No-op stand-in for the ``spaces.GPU`` decorator of Hugging Face Spaces.

    Lets code written for a Space run unchanged where the ``spaces`` package
    is not installed. Both decorator forms are accepted:

        @GPU
        def f(...): ...

        @GPU(duration=120)
        def g(...): ...

    Args:
        func: The function to decorate when used as a bare ``@GPU``.
        **_options: Options of the parameterized form (e.g. ``duration``);
            accepted and ignored.

    Returns:
        The wrapped function (bare form) or a decorator (parameterized form).
        The wrapper simply calls the original function and returns its result.
    """
    def decorate(target):
        @functools.wraps(target)  # Preserves original function's metadata
        def wrapper(*args, **kwargs):
            return target(*args, **kwargs)
        return wrapper

    if callable(func):
        # Bare usage: @GPU
        return decorate(func)
    # Parameterized usage: @GPU(...) must hand back the actual decorator.
    return decorate
|
utils/vietnamese_normalization.py
ADDED
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from underthesea import text_normalize
|
3 |
+
|
4 |
+
# Dictionary to map numbers to Vietnamese words
|
5 |
+
number_to_words = dict(enumerate(
    # Index in this tuple is the value being named (0 through 10).
    ('không', 'một', 'hai', 'ba', 'bốn', 'năm', 'sáu', 'bảy', 'tám', 'chín', 'mười')
))
# Magnitude words: hundred, thousand, million, billion.
number_to_words.update({
    10 ** 2: 'trăm',
    10 ** 3: 'nghìn',
    10 ** 6: 'triệu',
    10 ** 9: 'tỷ',
})
|
22 |
+
|
23 |
+
# Dictionary to map Roman numerals to integers
|
24 |
+
roman_to_int = dict(zip('IVXLCDM', (1, 5, 10, 50, 100, 500, 1000)))


def roman_to_integer(roman):
    """Convert a Roman numeral string to an integer.

    Each symbol is subtracted when the symbol immediately to its right has
    a larger value, and added otherwise. Characters that are not Roman
    symbols count as 0. The input is not checked for being a well-formed
    numeral.

    Args:
        roman: string of Roman numeral characters.

    Returns:
        The integer total (0 for an empty string).
    """
    values = [roman_to_int.get(symbol, 0) for symbol in roman]
    total = 0
    # Pair each value with its right-hand neighbour; the last one pairs with 0.
    for current, following in zip(values, values[1:] + [0]):
        if current < following:
            total -= current
        else:
            total += current
    return total
|
46 |
+
|
47 |
+
# Spoken Vietnamese form for each supported symbol ('~' is only padded with a space).
currency_symbols = {
    '~': '~ ',
    '%': 'phần trăm',
    '$': 'đô la',
    '₫': 'đồng',
    'đ': 'đồng',
    '€': 'ơ rô',
    '£': 'bảng',
    '¥': 'yên',
    '₹': 'ru pi',
    '₽': 'rúp',
    '₺': 'li ra',
    '₩': 'uôn',
}


def currency_symbol_to_word(currency_sign):
    """Return the spoken form of a symbol, or the symbol itself if unknown."""
    return currency_symbols.get(currency_sign, currency_sign)
|
66 |
+
|
67 |
+
# Dot-grouped thousands with an optional comma decimal part, e.g. 1.234.567,89
_VIETNAMESE_GROUPED_RE = re.compile(r'^\d{1,3}(?:\.\d{3})*(?:,\d+)?$')
# Comma-grouped thousands with an optional dot decimal part, e.g. 1,234,567.89
_US_GROUPED_RE = re.compile(r'^\d{1,3}(?:,\d{3})*(?:\.\d+)?$')
# Ungrouped number with a comma decimal part, e.g. 120,57
_COMMA_DECIMAL_RE = re.compile(r'^\d+,\d+$')
# Ungrouped number with a dot decimal part, e.g. 120.57
_DOT_DECIMAL_RE = re.compile(r'^\d+\.\d+$')
# Plain run of digits with no separators
_PLAIN_DIGITS_RE = re.compile(r'^\d+$')


def detect_number_format(number_str):
    """Classify the separator convention used by a numeric string.

    Rules, in order:
      * Both ',' and '.' present: whichever separator appears last is taken
        as the decimal mark (',' last -> Vietnamese, '.' last -> US), and
        the string must match that convention's grouped pattern.
      * Only ',': comma-grouped thousands ("1,234") is US; otherwise a
        single comma decimal ("1,5") is Vietnamese.
      * Only '.': dot-grouped thousands ("1.234") is Vietnamese; otherwise
        a single dot decimal ("1.5") is US.
      * No separators: Vietnamese when the string is all digits.

    Unlike the previous version, a separator-free string that is not made
    of digits (including the empty string) is reported as "Invalid" rather
    than "Vietnamese", so callers do not try to parse it as an integer.

    Args:
        number_str: the candidate number as a string.

    Returns:
        One of "Vietnamese", "US" or "Invalid".
    """
    has_comma = ',' in number_str
    has_dot = '.' in number_str

    if has_comma and has_dot:
        if number_str.rfind(',') > number_str.rfind('.'):
            return "Vietnamese" if _VIETNAMESE_GROUPED_RE.match(number_str) else "Invalid"
        return "US" if _US_GROUPED_RE.match(number_str) else "Invalid"

    if has_comma:
        # Thousands grouping wins over a decimal reading: "1,234" is 1234.
        if _US_GROUPED_RE.match(number_str):
            return "US"
        if _COMMA_DECIMAL_RE.match(number_str):
            return "Vietnamese"
        return "Invalid"

    if has_dot:
        # Thousands grouping wins over a decimal reading: "1.234" is 1234.
        if _VIETNAMESE_GROUPED_RE.match(number_str):
            return "Vietnamese"
        if _DOT_DECIMAL_RE.match(number_str):
            return "US"
        return "Invalid"

    # No separators: default to Vietnamese, but only for real digit strings.
    return "Vietnamese" if _PLAIN_DIGITS_RE.match(number_str) else "Invalid"
|
103 |
+
|
104 |
+
# Function to convert numbers to Vietnamese words
|
105 |
+
def number_to_vietnamese_words(number_str):
    """Spell out a numeric string in Vietnamese words.

    The separator convention is determined with ``detect_number_format``.
    US-style input ("1,234.5") is first rewritten to the Vietnamese form
    ("1234,5"); for Vietnamese-style input the dot group separators are
    dropped. A comma-separated decimal part is read digit by digit after
    the word "phẩy".

    Input that cannot be interpreted as a non-negative number is returned
    unchanged (as a string) instead of raising or being dropped.

    Args:
        number_str: the number to convert; it is coerced with ``str()``.

    Returns:
        The Vietnamese reading, or ``str(number_str)`` when the input is
        not a recognised number.
    """
    number_str = str(number_str)
    # Classify once; the result drives both the validity check and the rewrite.
    number_format = detect_number_format(number_str)
    if number_format == 'Invalid':
        return number_str

    if number_format == 'US':
        # Convert US number to Vietnamese one: 1,234.5 -> 1234,5
        number = number_str.replace(',', '').replace('.', ',')
    else:
        # Remove any dot (thousands separator) inside the number
        number = number_str.replace('.', '')

    integer_part, separator, decimal_part = number.partition(',')
    try:
        integer_value = int(integer_part)
    except ValueError:
        # Separator-free junk such as "" or "abc": leave it untouched.
        return number_str
    if integer_value < 0:
        # The integer converter has no reading for negative values.
        return number_str

    integer_words = _convert_integer_part(integer_value)
    if not separator:
        return integer_words
    # Handle decimal numbers (e.g., "120,57")
    return f"{integer_words} phẩy {_convert_decimal_part(decimal_part)}"
|
124 |
+
|
125 |
+
# Helper function to convert the integer part of a number
|
126 |
+
def _convert_integer_part(number):
    """Spell out a non-negative integer in Vietnamese words.

    The number is split into billions, millions and thousands; each group
    is converted recursively and followed by its magnitude word. The
    remaining 0-999 part is then read as hundreds, tens and units.

    Conventions used:
      * When a higher group was emitted and the remainder is below 100, it
        is padded with "không trăm" (plus "không" when below 10), e.g.
        1005 -> "một nghìn không trăm không năm". This padding now also
        applies after millions/billions, not only after thousands.
      * After a hundreds word, a single-digit remainder is introduced by
        "lẻ" (106 -> "một trăm lẻ sáu").
      * 5 is read "lăm" only directly after a tens word ("mười"/"mươi"),
        and 1 is read "mốt" only directly after "mươi". The previous
        length-based check produced "mười năm" for 15 and "một triệu mốt"
        for 1000001.

    Args:
        number: integer to convert; expected to be >= 0.

    Returns:
        The words joined by single spaces.
    """
    if number == 0:
        return number_to_words[0]

    words = []

    # Billions, millions, thousands: "<group reading> <magnitude word>".
    for magnitude in (1000000000, 1000000, 1000):
        if number >= magnitude:
            words.append(_convert_integer_part(number // magnitude))
            words.append(number_to_words[magnitude])
            number %= magnitude

    # Pad a short remainder that follows any higher group.
    if words and 0 < number < 100:
        words.append('không trăm')
        if number < 10:
            words.append('không')

    # Handle hundreds
    if number >= 100:
        words.append(number_to_words[number // 100])
        words.append(number_to_words[100])
        number %= 100
        if 0 < number < 10:
            words.append('lẻ')  # Add "lẻ" for numbers like 106 (một trăm lẻ sáu)

    # Handle tens
    if number >= 20:
        words.append(number_to_words[number // 10])
        words.append('mươi')
        number %= 10
    elif number >= 10:
        words.append(number_to_words[10])
        number %= 10

    # Handle units (1-9); the reading of 5 and 1 depends on the preceding word.
    if number > 0:
        previous = words[-1] if words else None
        if number == 5 and previous in ('mươi', 'mười'):
            words.append('lăm')
        elif number == 1 and previous == 'mươi':
            words.append('mốt')
        else:
            words.append(number_to_words[number])

    return ' '.join(words)
|
184 |
+
|
185 |
+
|
186 |
+
# Helper function to convert the decimal part of a number
|
187 |
+
def _convert_decimal_part(decimal_part):
    """Read a string of decimal digits one digit at a time.

    Args:
        decimal_part: the digits after the decimal mark, as a string.

    Returns:
        The Vietnamese word for each digit, joined by single spaces.
    """
    return ' '.join(number_to_words[int(digit)] for digit in decimal_part)
|
192 |
+
|
193 |
+
# abbreviation replacement
|
194 |
+
abbreviation_map = {
    "AI": "Ây Ai",
    "ASEAN": "A Xê An",
    "ATGT": "An toàn giao thông",
    "BCA": "Bộ Công an",
    "BCH": "Ban chấp hành",
    "BCHTW": "Ban Chấp hành Trung ương",
    "BCT": "Bộ Chính trị",
    "BGD": "Bộ Giáo dục",
    "BKH": "Bộ Khoa học và Công nghệ",
    "BNN": "Bộ Nông nghiệp",
    "BQP": "Bộ Quốc phòng",
    "BTC": "Ban tổ chức",
    "BTL": "Bộ Tư lệnh",
    "BYT": "Bộ Y tế",
    "CA": "công an",
    "CAND": "Công an nhân dân",
    "CNCS": "chủ nghĩa cộng sản",
    "CNTB": "chủ nghĩa tư bản",
    "CNXH": "chủ nghĩa xã hội",
    "CNY": "nhân dân tệ",
    "CSGT": "Cảnh sát giao thông",
    "CTN": "Chủ tịch nước",
    "ĐBQH": "Đại biểu Quốc hội",
    "ĐBSCL": "Đồng bằng sông Cửu Long",
    "ĐCS": "Đảng cộng sản",
    "ĐH": "Đại học",
    "ĐHBK": "Đại học Bách khoa",
    "ĐHKHTN": "Đại học Khoa học tự nhiên",
    "ĐHQG": "Đại học Quốc gia",
    "ĐSQ": "Đại sứ quán",
    "EU": "Ơ u",
    "GD": "Giáo dục",
    "HCM": "Hồ Chí Minh",
    "HĐBA": "Hội đồng bảo an",
    "HĐND": "Hội đồng nhân dân",
    "HĐQT": "Hội đồng quản trị",
    "HN": "Hà Nội",
    "HV": "Học viện",
    "KHXH&NV": "Khoa học Xã hội và Nhân văn",
    "KT": "Kinh tế",
    "KTQS": "Kỹ thuật Quân sự",
    "LĐ": "lao động",
    "KHKT": "khoa học kỹ thuật",
    "km": "ki lô mét",
    "LHQ": "Liên Hiệp Quốc",
    "NATO": "Na tô",
    "ND": "nhân dân",
    "NHNN": "ngân hàng nhà nước",
    "NXB": "Nhà xuất bản",
    "PCCC": "Phòng cháy chữa cháy",
    "PTTH": "Phổ thông trung học",
    "PTCS": "Phổ thông cơ sở",
    "QĐND": "Quân đội nhân dân",
    "QĐNDVN": "Quân đội nhân dân Việt Nam",
    "QG": "Quốc gia",
    "QK": "Quân khu",
    "sau CN": "sau công nguyên",
    "SG": "Sài Gòn",
    "TAND": "Tòa án nhân dân",
    "TBCN": "tư bản chủ nghĩa",
    "TBT": "Tổng bí thư",
    "TCN": "trước công nguyên",
    "TCT": "Tổng công ty",
    "THCS": "Trung học cơ sở",
    "THPT": "Trung học phổ thông",
    "TNHH": "Trách nhiệm hữu hạn",
    "TNHH MTV": "Trách nhiệm hữu hạn một thành viên",
    "TP": "thành phố",
    "TP.": "thành phố",
    "TPHCM": "Thành phố Hồ Chí Minh",
    "TT": "Thủ tướng",
    "TTCK": "Thị trường chứng khoán",
    "TTTC": "Thị trường tài chính",
    "TTCP": "Thủ tướng chính phủ",
    "TTNT": "Trí tuệ nhân tạo",
    "TTXVN": "Thông tấn xã Việt Nam",
    "TƯ": "Trung ương",
    "TW": "Trung ương",
    "UB": "Ủy ban",
    "UBND": "Ủy ban nhân dân",
    "VH": "Văn hóa",
    "VKSND": "Viện kiểm sát nhân dân",
    "VN": "Việt Nam",
    "VND": "Việt Nam đồng",
    "XH": "Xã hội",
    "XHCN": "xã hội chủ nghĩa",
    "%": "phần trăm",
    "@": "a còng",
    "&": "và",
}


def _is_word_char(char):
    """Return True for characters that the regex class ``\\w`` matches."""
    return char.isalnum() or char == '_'


def _abbreviation_regex(key):
    """Build the regex alternative for one abbreviation key.

    A ``\\b`` anchor is only added on an edge where it can hold: where the
    key starts/ends with a word character, or after a trailing '.' (so a
    dotted key such as "TP." only matches when glued to a following word
    and a sentence-final period is left alone). Symbol keys like '%' get
    no anchor, otherwise they could never match next to a space.
    """
    leading = r'\b' if _is_word_char(key[0]) else ''
    trailing = r'\b' if _is_word_char(key[-1]) or key[-1] == '.' else ''
    return leading + re.escape(key) + trailing


# Longest keys first: regex alternation takes the first alternative that
# matches, so "TNHH MTV" and "TP." must be tried before "TNHH" and "TP".
abbreviation_pattern = re.compile(
    '(' + '|'.join(
        _abbreviation_regex(key)
        for key in sorted(abbreviation_map, key=len, reverse=True)
    ) + ')'
)


def replace_abbreviations(text):
    """Expand every abbreviation listed in ``abbreviation_map``.

    Matching is case-sensitive. When a key begins or ends with a
    non-word character (e.g. '%', '&', "TP.") and touches a word character
    in the text, a space is inserted on that side so the expansion does not
    fuse with its neighbour ("A&B" -> "A và B").

    Args:
        text: input string.

    Returns:
        The text with abbreviations replaced by their expansions.
    """
    def replacement(match):
        key = match.group(0)
        expansion = abbreviation_map[key]
        source = match.string
        start, end = match.start(), match.end()
        if not _is_word_char(key[0]) and start > 0 and _is_word_char(source[start - 1]):
            expansion = ' ' + expansion
        if not _is_word_char(key[-1]) and end < len(source) and _is_word_char(source[end]):
            expansion += ' '
        return expansion
    return abbreviation_pattern.sub(replacement, text)
|
291 |
+
|
292 |
+
|
293 |
+
def convert_abbreviations(text):
    """Collapse dotted initialisms such as "M.A.S.H." into "MASH".

    Only runs of two or more "capital letter + dot" pairs (ASCII A-Z) are
    affected; the dots inside such a run are removed.
    """
    def strip_dots(match):
        # The match holds only capital letters and dots, so dropping the
        # dots leaves exactly the letters.
        return match.group(0).replace('.', '')
    return re.sub(r"([A-Z]\.){2,}", strip_dots, text)
|
296 |
+
|
297 |
+
|
298 |
+
# Function to normalize Vietnamese text
|
299 |
+
def normalize_vietnamese_text(text):
    """Rewrite Vietnamese text so symbols and numbers are spelled out.

    Steps, in order:
      1. Run ``underthesea.text_normalize`` on the input.
      2. Replace "/word" with " mỗi word" for time units (ngày, giờ, tháng,
         quí/quý, năm) and " trên word" otherwise.
      3. Spell out currency amounts written with a leading symbol ("$200")
         or a trailing symbol or percent sign ("200đ", "50%").
      4. Replace "yên CNY" with "nhân dân tệ" (¥ used for Chinese currency).
      5. Collapse dotted initialisms and expand known abbreviations.
      6. Replace standalone runs of the letters I, V, X, L, C, D, M with
         their value as a Roman numeral.
      7. Spell out the remaining standalone numbers.
      8. Tidy whitespace and punctuation; " (" and ") " become ", ".

    The amount in step 3 must start and end with a digit
    (``\\d+(?:[.,]\\d+)*``). The previous pattern ``[\\d.,]+`` also consumed
    trailing punctuation ("$100." was parsed as the invalid amount "100.")
    and could match a bare "," or "." directly before 'đ' or '%'.

    Args:
        text: raw input text.

    Returns:
        The normalized text with surrounding whitespace stripped.
    """
    text = text_normalize(text)

    def replace_slash_with_word(text):
        def replacement(match):
            word = match.group(1)
            if word in ['ngày', 'giờ', 'tháng', 'quí', 'quý', 'năm']:
                return f" mỗi {word}"
            else:
                return f" trên {word}"
        return re.sub(r'/(\w+)', replacement, text)

    # find and replace "/word" with "per word"
    text = replace_slash_with_word(text)

    # Digits with optional inner '.'/',' separators; never edge punctuation.
    amount_pattern = r'\d+(?:[.,]\d+)*'

    # Convert standalone currency amounts (e.g., $200, ₫200, €50, £75, ¥1000)
    def replace_currency(match):
        currency_sign = match.group(1)
        amount = match.group(2)
        return f"{number_to_vietnamese_words(amount)} {currency_symbol_to_word(currency_sign)}"
    text = re.sub(r'([$₫đ€£¥₹₽₩₺])(' + amount_pattern + r')', replace_currency, text)

    # (reverse case) convert standalone currency amounts (e.g., 200$, 200đ, 50€, 75£, 1000¥)
    def replace_currency_suffix(match):
        amount = match.group(1)
        currency_sign = match.group(2)
        return f"{number_to_vietnamese_words(amount)} {currency_symbol_to_word(currency_sign)}"
    text = re.sub(r'(' + amount_pattern + r')([$₫đ€£¥₹₽₩₺%])', replace_currency_suffix, text)

    # in case symbol [¥] is used for Chinese currency and followed by CNY
    text = text.replace('yên CNY', 'nhân dân tệ')

    # Replace abbreviations
    text = convert_abbreviations(text)
    text = replace_abbreviations(text)

    # Convert Roman numerals to integers
    def replace_roman(match):
        roman_numeral = match.group()
        return str(roman_to_integer(roman_numeral))
    # NOTE(review): this matches any standalone token made only of the
    # letters IVXLCDM, so capitalised words/acronyms built from them are
    # converted as well - confirm this is acceptable for the input domain.
    text = re.sub(r'\b[IVXLCDM]+\b', replace_roman, text)

    # Convert standalone numbers to words
    text = re.sub(r'\b[\d.,]+\b', lambda match: number_to_vietnamese_words(match.group()), text)

    # Fix common grammar errors
    text = re.sub(r'\s+', ' ', text)  # Remove extra spaces
    text = re.sub(r'\s([,\.])', r'\1', text)  # Remove space before punctuation
    text = re.sub(r'([,\.])(\S)', r'\1 \2', text)  # Add space after punctuation

    text = (
        text.replace("..", ".")
        .replace("!.", "!")
        .replace("?.", "?")
        .replace(" .", ".")
        .replace(" ,", ",")
        .replace(" (", ", ")
        .replace(") ", ", ")
    )

    return text.strip()
|
360 |
+
|