Jimmy Vu committed on
Commit
2e99c77
·
1 Parent(s): 8460d0e
.gitignore ADDED
@@ -0,0 +1,173 @@
+ WadaSNR/
+ .idea/
+ *.pyc
+ .DS_Store
+ ./__init__.py
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ .static_storage/
+ .media/
+ local_settings.py
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # pyenv
+ .python-version
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+
+ # vim
+ *.swp
+ *.swm
+ *.swn
+ *.swo
+
+ # pytorch models
+ *.pth
+ *.pth.tar
+ !dummy_speakers.pth
+ result/
+
+ # setup.py
+ version.py
+
+ # jupyter dummy files
+ core
+
+ # ignore local datasets
+ recipes/WIP/*
+ recipes/ljspeech/LJSpeech-1.1/*
+ recipes/vctk/VCTK/*
+ recipes/**/*.npy
+ recipes/**/*.json
+ VCTK-Corpus-removed-silence/*
+
+ # ignore training logs
+ trainer_*_log.txt
+
+ # files used internally for dev, test etc.
+ tests/outputs/*
+ tests/train_outputs/*
+ TODO.txt
+ .vscode/*
+ data/*
+ notebooks/data/*
+ TTS/tts/utils/monotonic_align/core.c
+ .vscode-upload.json
+ temp_build/*
+ events.out*
+ old_configs/*
+ model_importers/*
+ model_profiling/*
+ docs/source/TODO/*
+ .noseids
+ .dccache
+ log.txt
+ umap.png
+ *.out
+ SocialMedia.txt
+ output.wav
+ tts_output.wav
+ deps.json
+ speakers.json
+ internal/*
+ *_pitch.npy
+ *_phoneme.npy
+ wandb
+ depot/*
+ coqui_recipes/*
+ local_scripts/*
+ coqui_demos/*
+ cache/*
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: green
  colorTo: gray
  sdk: gradio
  sdk_version: 5.15.0
- app_file: app.py
+ app_file: gradio_app.py
  pinned: false
  license: mpl-2.0
  short_description: Coqui-XTTS Text-to-Speech Demo with Vietnamese
app.py DELETED
@@ -1,7 +0,0 @@
- import gradio as gr
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
gradio_app.py ADDED
@@ -0,0 +1,298 @@
+ import os
+ import time
+ import uuid
+ import hashlib
+ from pathlib import Path
+
+ import gradio as gr
+ import torch
+ import torchaudio
+ import numpy as np
+
+ from underthesea import sent_tokenize
+ from df.enhance import enhance, init_df, load_audio, save_audio
+
+ from huggingface_hub import hf_hub_download, snapshot_download
+
+ from langdetect import detect
+
+ from utils.vietnamese_normalization import normalize_vietnamese_text
+ from utils.logger import setup_logger
+ from utils.sentence import split_sentence, merge_sentences
+
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ logger = setup_logger(__file__)
+
+ df_model, df_state = None, None
+
+ APP_DIR = os.path.dirname(os.path.abspath(__file__))
+ checkpoint_dir = f"{APP_DIR}/cache"
+ temp_dir = f"{APP_DIR}/cache/temp/"
+ sample_audio_dir = f"{APP_DIR}/cache/audio_samples/"
+ enhance_audio_dir = f"{APP_DIR}/cache/audio_enhances/"
+ for d in [checkpoint_dir, temp_dir, sample_audio_dir, enhance_audio_dir]:
+     os.makedirs(d, exist_ok=True)
+
+ language_dict = {'English': 'en', 'Español (Spanish)': 'es', 'Français (French)': 'fr',
+                  'Deutsch (German)': 'de', 'Italiano (Italian)': 'it', 'Português (Portuguese)': 'pt',
+                  'Polski (Polish)': 'pl', 'Türkçe (Turkish)': 'tr', 'Русский (Russian)': 'ru',
+                  'Nederlands (Dutch)': 'nl', 'Čeština (Czech)': 'cs', 'العربية (Arabic)': 'ar', '中文 (Chinese)': 'zh-cn',
+                  'Magyar nyelv (Hungarian)': 'hu', '한국어 (Korean)': 'ko', '日本語 (Japanese)': 'ja',
+                  'Tiếng Việt (Vietnamese)': 'vi', 'Auto': 'auto'}
+
+ default_language = 'Auto'
+ language_codes = list(language_dict.values())
+ def lang_detect(text):
+     try:
+         lang = detect(text)
+         if lang == 'zh-tw':
+             return 'zh-cn'
+         return lang if lang in language_codes else 'en'
+     except Exception:
+         return 'en'
+
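For reference, lang_detect folds Traditional Chinese into zh-cn and falls back to English whenever detection fails or returns a code outside language_dict. An illustrative sketch; langdetect is probabilistic on short strings, so the first two results are likely rather than guaranteed:

lang_detect("Xin chào các bạn")  # likely 'vi'
lang_detect("這是一段繁體中文")     # likely 'zh-tw' from langdetect, folded to 'zh-cn'
lang_detect("")                  # detection raises, so this falls back to 'en'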
+ input_text_max_length = 3000
+ use_deepspeed = False
+
+ try:
+     import spaces
+ except ImportError:
+     from utils import spaces
+
+ xtts_model = None
+ def load_model():
+     global xtts_model
+
+     from TTS.tts.configs.xtts_config import XttsConfig
+     from TTS.tts.models.xtts import Xtts
+     repo_id = "jimmyvu/xtts"
+     snapshot_download(repo_id=repo_id,
+                       local_dir=checkpoint_dir,
+                       allow_patterns=["*.safetensors", "*.wav", "*.json"],
+                       ignore_patterns="*.pth")
+
+     config = XttsConfig()
+     config.load_json(os.path.join(checkpoint_dir, "config.json"))
+     xtts_model = Xtts.init_from_config(config)
+
+     logger.info("Loading model...")
+     xtts_model.load_safetensors_checkpoint(
+         config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed
+     )
+     if torch.cuda.is_available():
+         xtts_model.cuda()
+     logger.info(f"Successfully loaded model from {checkpoint_dir}")
+
+ load_model()
+
+ default_speaker_reference_audio = os.path.join(sample_audio_dir, 'harvard.wav')
+
+ @spaces.GPU
+ def generate_speech(input_text, speaker_reference_audio, enhance_speech, temperature=0.3, top_p=0.85, top_k=50, repetition_penalty=10.0, language='Auto', *args):
+     """Process text and generate audio."""
+     global df_model, df_state, xtts_model
+     log_messages = ""
+     if len(input_text) > input_text_max_length:
+         gr.Warning("Text is too long! Please provide a shorter text.")
+         log_messages += "Text is too long! Please provide a shorter text.\n"
+         return None, log_messages
+
+     language_code = language_dict.get(language, 'en')
+     logger.info(f"Language [{language}], code: [{language_code}]")
+     lang = lang_detect(input_text) if language_code == 'auto' else language_code
+     if (lang not in ['ja', 'ko', 'zh-cn'] and len(input_text.split()) < 2) or \
+        (lang in ['ja', 'ko', 'zh-cn'] and len(input_text) < 2):
+         gr.Warning("Text is too short! Please provide a longer text.")
+         log_messages += "Text is too short! Please provide a longer text.\n"
+         return None, log_messages
+
+     if not speaker_reference_audio:
+         gr.Warning("Please provide at least one reference audio!")
+         log_messages += "Please provide at least one reference audio!\n"
+         return None, log_messages
+
+     start = time.time()
+     logger.info(f"Start processing text: {input_text[:30]}... [length: {len(input_text)}]")
+
+     if enhance_speech:
+         logger.info("Enhancing reference audio...")
+         _, audio_file = os.path.split(speaker_reference_audio)
+         enhanced_audio_path = os.path.join(enhance_audio_dir, f"{audio_file}.enh.wav")
+         if not os.path.exists(enhanced_audio_path):
+             if not df_model:
+                 df_model, df_state, _ = init_df()
+             audio, _ = load_audio(speaker_reference_audio, sr=df_state.sr())
+             # denoise audio
+             enhanced_audio = enhance(df_model, df_state, audio)
+             # save enhanced audio
+             save_audio(enhanced_audio_path, enhanced_audio, sr=df_state.sr())
+         speaker_reference_audio = enhanced_audio_path
+
+     gpt_cond_latent, speaker_embedding = xtts_model.get_conditioning_latents(
+         audio_path=speaker_reference_audio,
+         gpt_cond_len=xtts_model.config.gpt_cond_len,
+         max_ref_length=xtts_model.config.max_ref_len,
+         sound_norm_refs=xtts_model.config.sound_norm_refs,
+     )
+
+     # Split text by sentence
+     if lang in ["ja", "zh-cn"]:
+         sentences = input_text.split("。")
+     else:
+         sentences = sent_tokenize(input_text)
+     # merge short sentences into the next/previous ones
+     sentences = merge_sentences(sentences)
+     # inference
+     wav_array = inference(sentences, language_code, gpt_cond_latent, speaker_embedding, temperature, top_p, top_k, repetition_penalty)
+     end = time.time()
+     logger.info(f"End processing text: {input_text[:30]}... Processing time: {end - start:.2f}s")
+     log_messages += f"Processing time: {end - start:.2f}s"
+     return (24000, wav_array), log_messages
+
+
+ def inference(sentences, language_code, gpt_cond_latent, speaker_embedding, temperature, top_p, top_k, repetition_penalty):
+     # set a dynamic length penalty from -1.0 to 1.0 based on text length
+     max_text_length = 180
+     dynamic_length_penalty = lambda text_length: (2 * (min(max_text_length, text_length) / max_text_length)) - 1
+     # inference
+     out_wavs = []
+     for sentence in sentences:
+         if len(sentence.strip()) == 0:
+             continue
+         lang = lang_detect(sentence) if language_code == 'auto' else language_code
+         if lang == 'vi':
+             sentence = normalize_vietnamese_text(sentence)
+         # split overly long sentences
+         texts = split_sentence(sentence) if len(sentence) > max_text_length else [sentence]
+         for text in texts:
+             logger.info(f"[{lang}] {text}")
+             try:
+                 out = xtts_model.inference(
+                     text=text,
+                     language=lang,
+                     gpt_cond_latent=gpt_cond_latent,
+                     speaker_embedding=speaker_embedding,
+                     temperature=temperature,
+                     top_p=top_p,
+                     top_k=top_k,
+                     repetition_penalty=repetition_penalty,
+                     length_penalty=dynamic_length_penalty(len(text)),
+                     enable_text_splitting=True,
+                 )
+                 out_wavs.append(out["wav"])
+             except Exception as e:
+                 logger.error(f"Error processing text: {text} - {e}")
+
+     return np.concatenate(out_wavs)
+
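For reference, dynamic_length_penalty maps text length linearly onto [-1.0, 1.0] and saturates at max_text_length = 180 characters:

dynamic_length_penalty(0)    # -1.0
dynamic_length_penalty(90)   #  0.0
dynamic_length_penalty(180)  #  1.0
dynamic_length_penalty(400)  #  1.0 (clamped by min(max_text_length, text_length))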
+ def build_gradio_ui():
+     """Builds the Gradio UI."""
+     theme = gr.Theme.from_hub('JohnSmith9982/small_and_pretty')
+     setattr(theme, 'button_secondary_background_fill', '#fcd53f')
+     setattr(theme, 'checkbox_border_color', '#02c160')
+     setattr(theme, 'input-border-width', '1px')
+     setattr(theme, 'input-background-fill', '#ffffff')
+     setattr(theme, 'input-background-fill_focus', '#ffffff')
+     setattr(theme, 'input-border-color', '#d1d5db')
+     setattr(theme, 'input-border-color_focus', '#fcd53f')
+
+     default_prompt = ("Hi, I am a multilingual text-to-speech AI model.\n"
+                       "Bonjour, je suis un modèle d'IA de synthèse vocale multilingue.\n"
+                       "Hallo, ich bin ein mehrsprachiges Text-zu-Sprache KI-Modell.\n"
+                       "Ciao, sono un modello di intelligenza artificiale di sintesi vocale multilingue.\n"
+                       "Привет, я многоязычная модель искусственного интеллекта, преобразующая текст в речь.\n"
+                       "Xin chào, tôi là một mô hình AI chuyển đổi văn bản thành giọng nói đa ngôn ngữ.\n")
+
+     with gr.Blocks(title="Coqui XTTS Demo", theme=theme) as ui:
+         gr.Markdown(
+             """
+             # 🐸 Coqui-XTTS Text-to-Speech Demo
+             Convert text to speech with advanced voice cloning and enhancement.
+             Supports 17 languages, \u2605 **Vietnamese** \u2605 newly added.
+             """
+         )
+
+         with gr.Tab("Text to Speech"):
+             with gr.Row():
+                 with gr.Column():
+                     input_text = gr.Text(label="Enter Text Here",
+                                          placeholder="Write the text you want to convert...",
+                                          value=default_prompt,
+                                          lines=5,
+                                          max_length=input_text_max_length)
+                     speaker_reference_audio = gr.Audio(
+                         label="Speaker reference audio:",
+                         type="filepath",
+                         editable=False,
+                         min_length=3,
+                         max_length=300,
+                         value=default_speaker_reference_audio
+                     )
+                     enhance_speech = gr.Checkbox(label="Enhance Reference Audio", value=False)
+                     language = gr.Dropdown(label="Target Language", choices=list(language_dict.keys()), value=default_language)
+                     generate_button = gr.Button("Generate Speech")
+                 with gr.Column():
+                     audio_output = gr.Audio(label="Generated Audio")
+                     log_output = gr.Text(label="Log Output")
+
+         with gr.Tab("Clone Your Voice"):
+             with gr.Row():
+                 with gr.Column():
+                     input_text_mic = gr.Text(label="Enter Text Here",
+                                              placeholder="Write the text you want to convert...",
+                                              lines=5,
+                                              max_length=input_text_max_length)
+                     mic_ref_audio = gr.Audio(label="Record Reference Audio", sources=["microphone"])
+                     enhance_speech_mic = gr.Checkbox(label="Enhance Reference Audio", value=True)
+                     language_mic = gr.Dropdown(label="Target Language", choices=list(language_dict.keys()), value=default_language)
+                     generate_button_mic = gr.Button("Generate Speech")
+                 with gr.Column():
+                     audio_output_mic = gr.Audio(label="Generated Audio")
+                     log_output_mic = gr.Text(label="Log Output")
+
+
+         def process_mic_and_generate(input_text_mic, mic_ref_audio, enhance_speech_mic, temperature, top_p, top_k, repetition_penalty, language_mic):
+             if mic_ref_audio:
+                 data = str(time.time()).encode("utf-8")
+                 audio_hash = hashlib.sha1(data).hexdigest()[:10]
+                 output_path = os.path.join(temp_dir, f"mic_{audio_hash}.wav")
+
+                 torch_audio = torch.from_numpy(mic_ref_audio[1].astype(float))
+                 try:
+                     torchaudio.save(output_path, torch_audio.unsqueeze(0), mic_ref_audio[0])
+                     return generate_speech(input_text_mic, output_path, enhance_speech_mic, temperature, top_p, top_k, repetition_penalty, language_mic)
+                 except Exception as e:
+                     logger.error(f"Error saving audio file: {e}")
+                     return None, f"Error saving audio file: {e}"
+             else:
+                 return None, "Please record an audio clip first!"
+
+         with gr.Tab("Advanced Settings"):
+             with gr.Row():
+                 with gr.Column():
+                     temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.3, step=0.05)
+                     repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=50.0, value=9.5, step=1.0)
+
+                 with gr.Column():
+                     top_p = gr.Slider(label="Top P", minimum=0.5, maximum=1.0, value=0.85, step=0.05)
+                     top_k = gr.Slider(label="Top K", minimum=0, maximum=100, value=50, step=5)
+
+         generate_button.click(
+             generate_speech,
+             inputs=[input_text, speaker_reference_audio, enhance_speech, temperature, top_p, top_k, repetition_penalty, language],
+             outputs=[audio_output, log_output],
+         )
+
+         generate_button_mic.click(
+             process_mic_and_generate,
+             inputs=[input_text_mic, mic_ref_audio, enhance_speech_mic, temperature, top_p, top_k, repetition_penalty, language_mic],
+             outputs=[audio_output_mic, log_output_mic],
+         )
+
+     return ui
+
+ if __name__ == "__main__":
+     ui = build_gradio_ui()
+     ui.launch(debug=False)
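generate_speech can also be driven without the UI. A minimal headless sketch; soundfile is an assumption here (it is not in requirements.txt, any wav writer works), and note the function returns (None, log) when input validation fails:

import soundfile as sf  # assumption: not in requirements.txt

result, log = generate_speech(
    "Xin chào, đây là một bài kiểm tra.",
    default_speaker_reference_audio,      # bundled harvard.wav sample
    enhance_speech=False,
    language='Tiếng Việt (Vietnamese)',
)
if result is not None:
    sample_rate, wav = result             # sample_rate is fixed at 24000
    sf.write("tts_output.wav", wav, sample_rate)
else:
    print(log)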
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ gradio==4.44.1
+ deepfilternet==0.5.6
+ underthesea==6.8.0
+ deepspeed
+ colorama
+ pyvi
+ langdetect
+ cutlet
+ unidic
+ # cutlet and unidic are needed for Japanese; after installing, run:
+ # python -m unidic download
+ git+https://github.com/quangvu3/coqui-xtts.git
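The comment above means the UniDic dictionary must be fetched once after install; a programmatic equivalent of the `python -m unidic download` command:

import subprocess
import sys

# Download the UniDic dictionary that cutlet needs for Japanese text.
subprocess.run([sys.executable, "-m", "unidic", "download"], check=True)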
utils/__init__.py ADDED
File without changes
utils/cuda_toolkit.py ADDED
@@ -0,0 +1,19 @@
+ import subprocess
+ import os
+
+ def install_cuda_toolkit():
+     # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
+     CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
+     CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
+     subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
+     subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
+     subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
+
+     os.environ["CUDA_HOME"] = "/usr/local/cuda"
+     os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
+     os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
+         os.environ["CUDA_HOME"],
+         "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
+     )
+     # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
+     os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
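install_cuda_toolkit is intended to run once at process start, before anything that compiles CUDA extensions (DeepSpeed's JIT ops, for example) looks for CUDA_HOME; a minimal sketch:

from utils.cuda_toolkit import install_cuda_toolkit

# Downloads and installs the 12.2 toolkit, then exports CUDA_HOME/PATH/LD_LIBRARY_PATH.
install_cuda_toolkit()
# CUDA-dependent builds (e.g., DeepSpeed op compilation) can run after this point.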
utils/logger.py ADDED
@@ -0,0 +1,146 @@
+ import logging
+ import sys
+ from pathlib import Path
+ from datetime import datetime
+ import colorama
+ from colorama import Fore, Back, Style
+ from typing import Optional, Union
+ import re
+ import traceback
+ import copy
+ import os
+
+ # Initialize colorama
+ colorama.init()
+
+ class ColoredFormatter(logging.Formatter):
+     """Colored formatter for structured log output.
+
+     This formatter adds color-coding, icons, timestamps, and file location
+     information to log messages. It supports different color schemes for
+     different log levels and includes special formatting for exceptions.
+
+     Attributes:
+         COLORS (dict): Color schemes for different log levels, including:
+             - color: Foreground color
+             - style: Text style (dim, normal, bright)
+             - icon: Emoji icon for the log level
+             - bg: Background color (for critical logs)
+     """
+
+     COLORS = {
+         'DEBUG': {
+             'color': Fore.CYAN,
+             'style': Style.DIM,
+             'icon': '🔍'
+         },
+         'INFO': {
+             'color': Fore.GREEN,
+             'style': Style.NORMAL,
+             'icon': 'ℹ️'
+         },
+         'WARNING': {
+             'color': Fore.YELLOW,
+             'style': Style.BRIGHT,
+             'icon': '⚠️'
+         },
+         'ERROR': {
+             'color': Fore.RED,
+             'style': Style.BRIGHT,
+             'icon': '❌'
+         },
+         'CRITICAL': {
+             'color': Fore.WHITE,
+             'style': Style.BRIGHT,
+             'bg': Back.RED,
+             'icon': '💀'
+         }
+     }
+
+     def format(self, record: logging.LogRecord) -> str:
+         """Format a log record with color and structure.
+
+         This method formats log records with:
+         - Timestamp in HH:MM:SS.mmm format
+         - File location (filename:line)
+         - Color-coded level name with icon
+         - Color-coded message
+         - Formatted exception traceback if present
+
+         Args:
+             record (logging.LogRecord): Log record to format.
+
+         Returns:
+             str: Formatted log message with color and structure.
+         """
+         colored_record = copy.copy(record)
+
+         # Get color scheme
+         scheme = self.COLORS.get(record.levelname, {
+             'color': Fore.WHITE,
+             'style': Style.NORMAL,
+             'icon': '•'
+         })
+
+         # Format timestamp
+         timestamp = datetime.fromtimestamp(record.created).strftime('%H:%M:%S.%f')[:-3]
+
+         # Get file location
+         file_location = f"{os.path.basename(record.pathname)}:{record.lineno}"
+
+         # Build components
+         components = []
+
+         # log formatting
+         components.extend([
+             f"{Fore.BLUE}{timestamp}{Style.RESET_ALL}",
+             f"{Fore.WHITE}{Style.DIM}{file_location}{Style.RESET_ALL}",
+             f"{scheme['color']}{scheme['style']}{scheme['icon']} {record.levelname:8}{Style.RESET_ALL}",
+             f"{scheme['color']}{record.msg}{Style.RESET_ALL}"
+         ])
+
+         # Add exception info
+         if record.exc_info:
+             components.append(
+                 f"\n{Fore.RED}{Style.BRIGHT}"
+                 f"{''.join(traceback.format_exception(*record.exc_info))}"
+                 f"{Style.RESET_ALL}"
+             )
+
+         return " | ".join(components)
+
+
+ def setup_logger(
+     name: Optional[Union[str, Path]] = None,
+     level: int = logging.INFO
+ ) -> logging.Logger:
+     """Set up a colored logger.
+
+     This function creates or retrieves a logger with colored console output.
+     If a file path is provided as the name, it will use the filename
+     (without extension) as the logger name.
+
+     Args:
+         name (Optional[Union[str, Path]], optional): Logger name or __file__ for
+             module name. Defaults to None.
+         level (int, optional): Logging level. Defaults to logging.INFO.
+
+     Returns:
+         logging.Logger: Configured logger instance.
+     """
+     # Get logger name from file path
+     if isinstance(name, (str, Path)) and Path(name).suffix == '.py':
+         name = Path(name).stem
+
+     # Get or create logger
+     logger = logging.getLogger(name)
+     logger.setLevel(level)
+
+     # Only add handler if none exists
+     if not logger.handlers:
+         # Create console handler
+         console_handler = logging.StreamHandler(sys.stdout)
+         console_handler.setFormatter(ColoredFormatter())
+         logger.addHandler(console_handler)
+
+     return logger
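Usage mirrors gradio_app.py: passing __file__ names the logger after the module's file stem. A short sketch; the rendered line is approximate since colors and the call site vary:

logger = setup_logger(__file__)
logger.info("model loaded")
# -> roughly: 14:02:11.042 | my_module.py:2 | ℹ️ INFO     | model loaded
logger.error("inference failed", exc_info=True)  # appends a bright red traceback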
utils/sentence.py ADDED
@@ -0,0 +1,75 @@
+
+ def split_sentence(sentence, delimiters=",;-!?"):
+     """
+     Splits a sentence into two halves, prioritizing the delimiter closest to the middle.
+     If no delimiter is found, it ensures words are not split in the middle.
+
+     Args:
+         sentence (str): The input sentence to split.
+         delimiters (str): A string of delimiters to prioritize for splitting (default: ",;-!?").
+
+     Returns:
+         tuple: A tuple containing the two halves of the sentence.
+     """
+     # Find all delimiter indices in the sentence
+     delimiter_indices = [i for i, char in enumerate(sentence) if char in delimiters]
+
+     if delimiter_indices:
+         # Calculate the midpoint of the sentence
+         midpoint = len(sentence) // 2
+
+         # Find the delimiter closest to the midpoint
+         closest_delimiter = min(delimiter_indices, key=lambda x: abs(x - midpoint))
+
+         # Split at the closest delimiter
+         first_half = sentence[:closest_delimiter].strip()
+         second_half = sentence[closest_delimiter + 1:].strip()
+     else:
+         # If no delimiter, split at the nearest space (word boundary)
+         midpoint = len(sentence) // 2
+
+         # Find the nearest space (word boundary) around the midpoint
+         left_space = sentence.rfind(" ", 0, midpoint)
+         right_space = sentence.find(" ", midpoint)
+
+         # Choose the closest space to the midpoint
+         if left_space == -1 and right_space == -1:
+             # No spaces found (single word), split at midpoint
+             split_index = midpoint
+         elif left_space == -1:
+             # Only right space found
+             split_index = right_space
+         elif right_space == -1:
+             # Only left space found
+             split_index = left_space
+         else:
+             # Choose the closest space to the midpoint
+             split_index = left_space if (midpoint - left_space) <= (right_space - midpoint) else right_space
+
+         # Split the sentence into two parts
+         first_half = sentence[:split_index].strip()
+         second_half = sentence[split_index:].strip()
+
+     return first_half, second_half
+
+
+ def merge_sentences(sentences):
+     """Handle short sentences by merging them into the next/previous ones."""
+     merged_sentences = []
+     i = 0
+     while i < len(sentences):
+         s = sentences[i]
+         word_count = len(s.split())
+         j = 1
+         # merge a short sentence into the next one until it is long enough
+         while word_count <= 6 and i + j < len(sentences):
+             s += ' ' + sentences[i + j]
+             word_count = len(s.split())
+             j += 1
+         merged_sentences.append(s)
+         i += j
+     # merge the last one into the previous one until it is long enough
+     while len(merged_sentences) > 1 and len(merged_sentences[-1].split()) < 6:
+         merged_sentences[-2] += ' ' + merged_sentences[-1]
+         merged_sentences.pop()
+     return merged_sentences
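A quick sketch of the expected behavior, derived from the logic above:

split_sentence("The quick brown fox, jumps over the lazy dog")
# -> ("The quick brown fox", "jumps over the lazy dog")  split at the comma nearest the middle

split_sentence("Hello world example")
# -> ("Hello world", "example")  no delimiter, so it splits at the space nearest the middle

merge_sentences(["Hi.", "How are you?", "I am doing fine."])
# -> ["Hi. How are you? I am doing fine."]  short sentences are merged until they exceed 6 words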
utils/spaces.py ADDED
@@ -0,0 +1,10 @@
+ import functools
+
+ def GPU(func):
+     """No-op stand-in for the Hugging Face Spaces `spaces.GPU` decorator,
+     used as a fallback when the `spaces` package is not available."""
+     @functools.wraps(func)  # Preserves the original function's metadata
+     def wrapper(*args, **kwargs):
+         result = func(*args, **kwargs)
+         return result
+     return wrapper
utils/vietnamese_normalization.py ADDED
@@ -0,0 +1,360 @@
+ import re
+ from underthesea import text_normalize
+
+ # Dictionary to map numbers to Vietnamese words
+ number_to_words = {
+     0: 'không',
+     1: 'một',
+     2: 'hai',
+     3: 'ba',
+     4: 'bốn',
+     5: 'năm',
+     6: 'sáu',
+     7: 'bảy',
+     8: 'tám',
+     9: 'chín',
+     10: 'mười',
+     100: 'trăm',
+     1000: 'nghìn',
+     1000000: 'triệu',
+     1000000000: 'tỷ'
+ }
+
+ # Dictionary to map Roman numerals to integers
+ roman_to_int = {
+     'I': 1,
+     'V': 5,
+     'X': 10,
+     'L': 50,
+     'C': 100,
+     'D': 500,
+     'M': 1000
+ }
+
+ # Function to convert Roman numerals to integers
+ def roman_to_integer(roman):
+     total = 0
+     prev_value = 0
+     for char in reversed(roman):
+         value = roman_to_int.get(char, 0)
+         if value < prev_value:
+             total -= value
+         else:
+             total += value
+         prev_value = value
+     return total
+
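The loop scans right to left and subtracts any symbol smaller than the one before it, which handles subtractive forms like IV and IX:

roman_to_integer("XIV")    # 14: V, then I subtracted, then X added
roman_to_integer("MMXXV")  # 2025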
+ currency_symbols = {
+     '~': '~ ',
+     '%': 'phần trăm',
+     '$': 'đô la',
+     '₫': 'đồng',
+     'đ': 'đồng',
+     '€': 'ơ rô',
+     '£': 'bảng',
+     '¥': 'yên',
+     '₹': 'ru pi',
+     '₽': 'rúp',
+     '₺': 'li ra',
+     '₩': 'uôn',
+ }
+
+ def currency_symbol_to_word(currency_sign):
+     if currency_sign in currency_symbols:
+         return currency_symbols[currency_sign]
+     return currency_sign
+
+ def detect_number_format(number_str):
+     # Check if the number contains a comma and a dot
+     if ',' in number_str and '.' in number_str:
+         # If the last comma is after the last dot, it's Vietnamese
+         if number_str.rfind(',') > number_str.rfind('.'):
+             # Validate Vietnamese format
+             if re.match(r'^\d{1,3}(?:\.\d{3})*(?:,\d+)?$', number_str):
+                 return "Vietnamese"
+             else:
+                 return "Invalid"
+         # Otherwise, it's US
+         else:
+             # Validate US format
+             if re.match(r'^\d{1,3}(?:,\d{3})*(?:\.\d+)?$', number_str):
+                 return "US"
+             else:
+                 return "Invalid"
+     # If only commas are present
+     elif ',' in number_str:
+         if re.match(r'^\d{1,3}(?:,\d{3})*(?:\.\d+)?$', number_str):
+             return "US"
+         elif re.match(r'^(\d+,\d+)?$', number_str):
+             return "Vietnamese"
+         else:
+             return "Invalid"
+     # If only dots are present
+     elif '.' in number_str:
+         if re.match(r'^\d{1,3}(?:\.\d{3})*(?:,\d+)?$', number_str):
+             return "Vietnamese"
+         elif re.match(r'^(\d+\.\d+)?$', number_str):
+             return "US"
+         else:
+             return "Invalid"
+     # If no separators are present, assume Vietnamese (default)
+     else:
+         return "Vietnamese"
+
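The disambiguation rule: whichever separator appears last is treated as the decimal mark, and the grouping pattern is then validated. For example:

detect_number_format("1.234,5")   # "Vietnamese" (dot groups thousands, comma marks decimals)
detect_number_format("1,234.5")   # "US"
detect_number_format("1234")      # "Vietnamese" (no separator defaults to Vietnamese)
detect_number_format("12,34.5")   # "Invalid" (comma group is not 3 digits)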
+ # Function to convert numbers to Vietnamese words
+ def number_to_vietnamese_words(number_str):
+     number_str = str(number_str)
+     if detect_number_format(number_str) == 'Invalid':
+         return number_str
+
+     if detect_number_format(number_str) == 'US':  # convert a US number to a Vietnamese one: 1,234.5 to 1234,5
+         number = re.sub(r'\.', ',', re.sub(r',', '', number_str))
+     else:  # remove any dot inside the number
+         number = re.sub(r'\.', '', number_str)
+
+     if isinstance(number, str) and ',' in number:
+         # Handle decimal numbers (e.g., "120,57")
+         integer_part, decimal_part = number.split(',')
+         integer_words = _convert_integer_part(int(integer_part))
+         decimal_words = _convert_decimal_part(decimal_part)
+         return f"{integer_words} phẩy {decimal_words}"
+     else:
+         # Handle integer numbers
+         return _convert_integer_part(int(number))
+
+ # Helper function to convert the integer part of a number
+ def _convert_integer_part(number):
+     if number == 0:
+         return number_to_words[0]
+
+     words = []
+
+     # Handle billions
+     if number >= 1000000000:
+         billion = number // 1000000000
+         words.append(_convert_integer_part(billion))
+         words.append(number_to_words[1000000000])
+         number %= 1000000000
+
+     # Handle millions
+     if number >= 1000000:
+         million = number // 1000000
+         words.append(_convert_integer_part(million))
+         words.append(number_to_words[1000000])
+         number %= 1000000
+
+     # Handle thousands
+     if number >= 1000:
+         thousand = number // 1000
+         words.append(_convert_integer_part(thousand))
+         words.append(number_to_words[1000])
+         number %= 1000
+         if number < 100 and number > 0:
+             words.append('không trăm')
+         if number < 10 and number > 0:
+             words.append('không')
+
+     # Handle hundreds
+     if number >= 100:
+         hundred = number // 100
+         words.append(number_to_words[hundred])
+         words.append(number_to_words[100])
+         number %= 100
+         if number > 0 and number < 10:
+             words.append('lẻ')  # Add "lẻ" for numbers like 106 (một trăm lẻ sáu)
+
+     # Handle tens and units
+     if number >= 20:
+         ten = number // 10
+         words.append(number_to_words[ten])
+         words.append('mươi')
+         number %= 10
+     elif number >= 10:
+         words.append(number_to_words[10])
+         number %= 10
+
+     # Handle units (1-9)
+     if number > 0:
+         if number == 5 and len(words) > 1 and words[-1] not in ['lẻ', 'không']:
+             w = 'lăm'
+         elif number == 1 and len(words) > 1 and words[-1] not in ['lẻ', 'mười', 'không']:
+             w = 'mốt'
+         else:
+             w = number_to_words[number]
+         words.append(w)
+
+     return ' '.join(words)
+
+
+ # Helper function to convert the decimal part of a number
+ def _convert_decimal_part(decimal_part):
+     words = []
+     for digit in decimal_part:
+         words.append(number_to_words[int(digit)])
+     return ' '.join(words)
+
+ # abbreviation replacement
+ abbreviation_map = {
+     "AI": "Ây Ai",
+     "ASEAN": "A Xê An",
+     "ATGT": "An toàn giao thông",
+     "BCA": "Bộ Công an",
+     "BCH": "Ban chấp hành",
+     "BCHTW": "Ban Chấp hành Trung ương",
+     "BCT": "Bộ Chính trị",
+     "BGD": "Bộ Giáo dục",
+     "BKH": "Bộ Khoa học và Công nghệ",
+     "BNN": "Bộ Nông nghiệp",
+     "BQP": "Bộ Quốc phòng",
+     "BTC": "Ban tổ chức",
+     "BTL": "Bộ Tư lệnh",
+     "BYT": "Bộ Y tế",
+     "CA": "công an",
+     "CAND": "Công an nhân dân",
+     "CNCS": "chủ nghĩa cộng sản",
+     "CNTB": "chủ nghĩa tư bản",
+     "CNXH": "chủ nghĩa xã hội",
+     "CNY": "nhân dân tệ",
+     "CSGT": "Cảnh sát giao thông",
+     "CTN": "Chủ tịch nước",
+     "ĐBQH": "Đại biểu Quốc hội",
+     "ĐBSCL": "Đồng bằng sông Cửu Long",
+     "ĐCS": "Đảng cộng sản",
+     "ĐH": "Đại học",
+     "ĐHBK": "Đại học Bách khoa",
+     "ĐHKHTN": "Đại học Khoa học tự nhiên",
+     "ĐHQG": "Đại học Quốc gia",
+     "ĐSQ": "Đại sứ quán",
+     "EU": "Ơ u",
+     "GD": "Giáo dục",
+     "HCM": "Hồ Chí Minh",
+     "HĐBA": "Hội đồng bảo an",
+     "HĐND": "Hội đồng nhân dân",
+     "HĐQT": "Hội đồng quản trị",
+     "HN": "Hà Nội",
+     "HV": "Học viện",
+     "KHXH&NV": "Khoa học Xã hội và Nhân văn",
+     "KT": "Kinh tế",
+     "KTQS": "Kỹ thuật Quân sự",
+     "LĐ": "lao động",
+     "KHKT": "khoa học kỹ thuật",
+     "km": "ki lô mét",
+     "LHQ": "Liên Hiệp Quốc",
+     "NATO": "Na tô",
+     "ND": "nhân dân",
+     "NHNN": "ngân hàng nhà nước",
+     "NXB": "Nhà xuất bản",
+     "PCCC": "Phòng cháy chữa cháy",
+     "PTTH": "Phổ thông trung học",
+     "PTCS": "Phổ thông cơ sở",
+     "QĐND": "Quân đội nhân dân",
+     "QĐNDVN": "Quân đội nhân dân Việt Nam",
+     "QG": "Quốc gia",
+     "QK": "Quân khu",
+     "sau CN": "sau công nguyên",
+     "SG": "Sài Gòn",
+     "TAND": "Tòa án nhân dân",
+     "TBCN": "tư bản chủ nghĩa",
+     "TBT": "Tổng bí thư",
+     "TCN": "trước công nguyên",
+     "TCT": "Tổng công ty",
+     "THCS": "Trung học cơ sở",
+     "THPT": "Trung học phổ thông",
+     "TNHH": "Trách nhiệm hữu hạn",
+     "TNHH MTV": "Trách nhiệm hữu hạn một thành viên",
+     "TP": "thành phố",
+     "TP.": "thành phố",
+     "TPHCM": "Thành phố Hồ Chí Minh",
+     "TT": "Thủ tướng",
+     "TTCK": "Thị trường chứng khoán",
+     "TTTC": "Thị trường tài chính",
+     "TTCP": "Thủ tướng chính phủ",
+     "TTNT": "Trí tuệ nhân tạo",
+     "TTXVN": "Thông tấn xã Việt Nam",
+     "TƯ": "Trung ương",
+     "TW": "Trung ương",
+     "UB": "Ủy ban",
+     "UBND": "Ủy ban nhân dân",
+     "VH": "Văn hóa",
+     "VKSND": "Viện kiểm sát nhân dân",
+     "VN": "Việt Nam",
+     "VND": "Việt Nam đồng",
+     "XH": "Xã hội",
+     "XHCN": "xã hội chủ nghĩa",
+     "%": "phần trăm",
+     "@": "a còng",
+     "&": "và",
+ }
+
+ abbreviation_pattern = re.compile(r'\b(' + '|'.join(re.escape(key) for key in abbreviation_map.keys()) + r')\b')
+ def replace_abbreviations(text):
+     def replacement(match):
+         return abbreviation_map[match.group(0)]
+     return abbreviation_pattern.sub(replacement, text)
+
+
+ def convert_abbreviations(text):
+     """Converts abbreviations like M.A.S.H. to MASH"""
+     return re.sub(r"([A-Z]\.){2,}", lambda match: "".join(c for c in match.group(0) if c.isalpha()), text)
+
+
+ # Function to normalize Vietnamese text
+ def normalize_vietnamese_text(text):
+     text = text_normalize(text)
+
+     def replace_slash_with_word(text):
+         def replacement(match):
+             word = match.group(1)
+             if word in ['ngày', 'giờ', 'tháng', 'quí', 'quý', 'năm']:
+                 return f" mỗi {word}"
+             else:
+                 return f" trên {word}"
+         return re.sub(r'/(\w+)', replacement, text)
+
+     # find and replace "/word" with "per word"
+     text = replace_slash_with_word(text)
+
+     # Convert standalone currency amounts (e.g., $200, ₫200, €50, £75, ¥1000)
+     def replace_currency(match):
+         currency_sign = match.group(1)
+         amount = match.group(2)
+         return f"{number_to_vietnamese_words(amount)} {currency_symbol_to_word(currency_sign)}"
+     text = re.sub(r'([$₫đ€£¥₹₽₩₺])([\d.,]+)', replace_currency, text)
+
+     # (reverse case) convert standalone currency amounts (e.g., 200$, 200đ, 50€, 75£, 1000¥)
+     def replace_currency_suffix(match):
+         amount = match.group(1)
+         currency_sign = match.group(2)
+         return f"{number_to_vietnamese_words(amount)} {currency_symbol_to_word(currency_sign)}"
+     text = re.sub(r'([\d.,]+)([$₫đ€£¥₹₽₩₺%])', replace_currency_suffix, text)
+
+     # in case the symbol [¥] is used for Chinese currency and followed by CNY
+     text = text.replace('yên CNY', 'nhân dân tệ')
+
+     # Replace abbreviations
+     text = convert_abbreviations(text)
+     text = replace_abbreviations(text)
+
+     # Convert Roman numerals to integers
+     def replace_roman(match):
+         roman_numeral = match.group()
+         return str(roman_to_integer(roman_numeral))
+     # Replace Roman numerals with integers
+     text = re.sub(r'\b[IVXLCDM]+\b', replace_roman, text)
+
+     # Convert standalone numbers to words
+     text = re.sub(r'\b[\d.,]+\b', lambda match: number_to_vietnamese_words(match.group()), text)
+
+     # Fix common spacing and punctuation issues
+     text = re.sub(r'\s+', ' ', text)  # Remove extra spaces
+     text = re.sub(r'\s([,\.])', r'\1', text)  # Remove space before punctuation
+     text = re.sub(r'([,\.])(\S)', r'\1 \2', text)  # Add space after punctuation
+
+     text = (text.replace("..", ".")
+                 .replace("!.", "!")
+                 .replace("?.", "?")
+                 .replace(" .", ".")
+                 .replace(" ,", ",")
+                 .replace(" (", ", ")
+                 .replace(") ", ", ")
+            )
+
+     return text.strip()
+
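End to end, numbers are first normalized to the Vietnamese comma-decimal form and then spelled out group by group; tracing the code above:

number_to_vietnamese_words("1234,5")
# -> 'một nghìn hai trăm ba mươi bốn phẩy năm'

normalize_vietnamese_text("Giá 200$/kg")
# -> roughly 'Giá hai trăm đô la trên kg'
#    (the exact output also depends on underthesea's text_normalize pass)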