license: apache-2.0
ReSpark TTS Model
This repository contains the ReSpark Text-to-Speech (TTS) model, which generates speech from text and can clone a voice from a reference audio clip. It is based on the RWKV architecture and uses the BiCodec tokenizer for audio encoding and decoding.
Installation
First, install the required dependencies:
pip install transformers rwkv-fla torch torchaudio torchvision soundfile numpy librosa omegaconf soxr einx
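To confirm the installation before running anything heavy, a quick importability check can help. This is a minimal sketch, not part of the repository: missing_modules is a hypothetical helper, and the import name "fla" for the rwkv-fla package is an assumption (the other names are assumed to match their package names).

```python
from importlib.util import find_spec

# Top-level import names for the packages in the pip command above.
# Assumption: rwkv-fla is imported as "fla"; the rest use their package names.
REQUIRED_MODULES = [
    "transformers", "fla", "torch", "torchaudio", "torchvision",
    "soundfile", "numpy", "librosa", "omegaconf", "soxr", "einx",
]


def missing_modules(names):
    """Return the top-level module names that cannot be located."""
    return [name for name in names if find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules are importable.")
```

The check only locates the modules; it does not import them, so it says nothing about GPU or driver compatibility.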
Usage
The tts.py script provides a complete example of how to use this model for text-to-speech synthesis with voice cloning.
Running the Test Script
To generate speech, simply run the script:
python tts.py
How it Works
The script performs the following steps:
- Loads the pre-trained AutoModelForCausalLM and AutoTokenizer from the current directory.
- Initializes the BiCodecTokenizer for audio encoding and decoding.
- Loads a reference audio file (kafka.wav) and its corresponding transcript (prompt_text) to provide a voice prompt.
- Resamples the reference audio to match the model's expected sample rate (24000 Hz).
- Takes a target text (text) to be synthesized.
- Calls the generate_speech function, which generates audio based on the target text and the voice from the reference audio.
- Saves the generated audio to output.wav.
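The reference-audio preparation in the steps above can be sketched as a small helper. This is an illustration under assumptions, not the script's own code: prepare_prompt_audio is a hypothetical name, and the mono downmix is an added precaution (tts.py does not do it, and whether the tokenizer requires mono input is assumed here). The resampling call mirrors the librosa usage in tts.py.

```python
import numpy as np


def prepare_prompt_audio(audio, orig_sr, target_sr):
    """Return mono float32 audio at target_sr (hypothetical helper)."""
    audio = np.asarray(audio, dtype=np.float32)
    if audio.ndim == 2:
        # soundfile returns multi-channel audio as (frames, channels);
        # averaging the channels is an assumption, not something tts.py does.
        audio = audio.mean(axis=1)
    if orig_sr != target_sr:
        # Imported lazily so the helper works without librosa when no
        # resampling is needed.
        import librosa
        audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
    return np.asarray(audio, dtype=np.float32)
```

In tts.py, the target rate would come from audio_tokenizer.config['sample_rate'] rather than a hard-coded value.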
You can modify the prompt_text, prompt_audio_file, and text variables in tts.py to synthesize different text with different voices.
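If editing the script for every run becomes tedious, the same three variables could be exposed as command-line options. The following is a minimal sketch using argparse; the flag names and the --output option are hypothetical and do not exist in tts.py as shipped.

```python
import argparse


def parse_args(argv=None):
    """Parse hypothetical command-line options mirroring the variables in tts.py."""
    parser = argparse.ArgumentParser(
        description="Synthesize speech with a cloned voice."
    )
    parser.add_argument("--text", required=True,
                        help="Text to synthesize.")
    parser.add_argument("--prompt-audio-file", default="kafka.wav",
                        help="Reference audio file that supplies the voice.")
    parser.add_argument("--prompt-text", default=None,
                        help="Transcript of the reference audio (optional).")
    parser.add_argument("--output", default="output.wav",
                        help="Where to write the generated audio.")
    return parser.parse_args(argv)
```

The parsed values (args.text, args.prompt_audio_file, args.prompt_text, args.output) would then replace the hard-coded assignments in the main section of the script.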
Example Code (tts.py)
import os
import sys
current_dir = os.path.dirname(os.path.abspath(__file__))
print('add current dir to sys.path', current_dir)
sys.path.append(current_dir)
from sparktts.models.audio_tokenizer import BiCodecTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM
import soundfile as sf
import numpy as np
import torch
from utilities import generate_embeddings
def generate_speech(model, tokenizer, text, bicodec, prompt_text=None, prompt_audio=None,
                    max_new_tokens=3000, do_sample=True, top_k=50, top_p=0.95,
                    temperature=1.0, device="cuda:0"):
    """
    Generate speech for `text`, optionally using `prompt_audio` (and its
    transcript `prompt_text`) as a voice prompt.

    Returns the waveform produced by `bicodec.detokenize`.
    """
    # The last vocabulary entry is used as the end-of-sequence token.
    eos_token_id = model.config.vocab_size - 1

    # Build the input embeddings and global tokens from the text and prompt.
    embeddings = generate_embeddings(
        model=model,
        tokenizer=tokenizer,
        text=text,
        bicodec=bicodec,
        prompt_text=prompt_text,
        prompt_audio=prompt_audio
    )
    global_tokens = embeddings['global_tokens'].unsqueeze(0)

    # pad_token_id may be missing or None; fall back to the tokenizer's EOS token.
    pad_token_id = getattr(tokenizer, 'pad_token_id', None)
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    model.eval()
    with torch.no_grad():
        generated_outputs = model.generate(
            inputs_embeds=embeddings['input_embs'],
            attention_mask=torch.ones((1, embeddings['input_embs'].shape[1]),
                                      dtype=torch.long, device=device),
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            use_cache=True
        )

    # Drop the final token (expected to be EOS) before decoding.
    semantic_tokens_tensor = generated_outputs[:, :-1]

    # Decode the global and semantic tokens back into a waveform.
    with torch.no_grad():
        wav = bicodec.detokenize(global_tokens, semantic_tokens_tensor)
    return wav
# --- Main execution ---
device = 'cuda:0'
# Initialize tokenizers and model
audio_tokenizer = BiCodecTokenizer(model_dir=current_dir, device=device)
tokenizer = AutoTokenizer.from_pretrained(current_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(current_dir, trust_remote_code=True)
model = model.bfloat16().to(device)
model.eval()
# Prepare prompt audio and text for voice cloning
prompt_text = "我们并不是通过物理移动手段找到星河的。"
prompt_audio_file = os.path.join(current_dir, 'kafka.wav')
prompt_audio, sampling_rate = sf.read(prompt_audio_file)
# Resample audio if necessary
target_sample_rate = audio_tokenizer.config['sample_rate']
if sampling_rate != target_sample_rate:
    from librosa import resample
    prompt_audio = resample(prompt_audio, orig_sr=sampling_rate, target_sr=target_sample_rate)
prompt_audio = np.array(prompt_audio, dtype=np.float32)
# Text to synthesize
text = "科学技术是第一生产力,最近 AI的迅猛发展让我们看到了迈向星辰大海的希望。"
# Generate speech
wav = generate_speech(model, tokenizer, text, audio_tokenizer, prompt_audio=prompt_audio, device=device)
# Save the output
sf.write('output.wav', wav, target_sample_rate)
print("Generated audio saved to output.wav")
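One detail worth noting in generate_speech: it removes the last generated token on the assumption that it is the EOS token. If generation stops because max_new_tokens was reached, that last token is ordinary content and gets dropped. A more defensive trim could look like the sketch below, which works on plain Python lists; trim_at_eos is a hypothetical helper and not part of tts.py.

```python
def trim_at_eos(token_ids, eos_token_id):
    """Return the tokens before the first EOS, or all tokens if EOS never appears."""
    for index, token in enumerate(token_ids):
        if token == eos_token_id:
            # Everything from the first EOS onward is discarded.
            return list(token_ids[:index])
    # No EOS found (for example, generation hit max_new_tokens): keep everything.
    return list(token_ids)
```

Assuming a batch size of 1, one way to apply this would be to convert generated_outputs[0] to a list with .tolist(), trim it, and rebuild a tensor of shape (1, n) before calling bicodec.detokenize.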