# Spaces: Running on Zero
# -*- coding: utf-8 -*- | |
import yaml | |
import logging | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from .nn.feature_extractor import MelFeatureExtractor | |
from .nn.modules import OmniAudioEncoder, OmniAudioDecoder, ResidualDownConv, UpConv, Transformer, Vocos | |
from .nn.quantizer import ResidualVQ | |
class XY_Tokenizer(nn.Module):
    """Speech codec with a semantic and an acoustic encoder channel and a residual VQ.

    ``inference_tokenize`` / ``encode`` turn waveforms into quantization codes;
    ``inference_detokenize`` / ``decode`` turn codes back into waveforms.
    ``encode`` and ``decode`` accept inputs longer than 30 seconds and process them
    in overlapping windows of at most 30 seconds.
    """

    def __init__(self, generator_params):
        """Build every sub-module from a config mapping.

        Args:
            generator_params: Mapping providing ``input_sample_rate``,
                ``output_sample_rate``, ``quantizer_kwargs`` (which must contain
                ``input_dim`` and ``num_quantizers``) and one ``*_kwargs`` entry for
                each sub-module constructed below.
        """
        super().__init__()
        # Basic parameters
        self.input_sample_rate = generator_params['input_sample_rate']
        self.output_sample_rate = generator_params['output_sample_rate']
        # Input waveform samples per code frame (``encode`` derives code counts from it).
        self.encoder_downsample_rate = 1280
        # Output waveform samples per code frame (``decode`` derives output lengths from it).
        self.decoder_upsample_rate = 1920
        self.code_dim = generator_params['quantizer_kwargs']['input_dim']

        ## Codec part
        ## Semantic channel
        self.semantic_encoder = OmniAudioEncoder(**generator_params['semantic_encoder_kwargs'])
        self.semantic_encoder_adapter = Transformer(**generator_params['semantic_encoder_adapter_kwargs'])
        ## Acoustic channel
        self.acoustic_encoder = OmniAudioEncoder(**generator_params['acoustic_encoder_kwargs'])
        ## Semantic & acoustic shared parameters
        self.pre_rvq_adapter = Transformer(**generator_params['pre_rvq_adapter_kwargs'])
        self.downsample = ResidualDownConv(**generator_params['downsample_kwargs'])
        self.quantizer = ResidualVQ(**generator_params['quantizer_kwargs'])
        self.nq = generator_params['quantizer_kwargs']['num_quantizers']
        self.post_rvq_adapter = Transformer(**generator_params['post_rvq_adapter_kwargs'])
        ## Acoustic channel
        self.upsample = UpConv(**generator_params['upsample_kwargs'])
        self.acoustic_decoder = OmniAudioDecoder(**generator_params['acoustic_decoder_kwargs'])
        self.enhanced_vocos = Vocos(**generator_params['vocos_kwargs'])
        ## Feature extractor
        self.feature_extractor = MelFeatureExtractor(**generator_params['feature_extractor_kwargs'])

    def inference_tokenize(self, x, input_lengths):
        """
        Input:
            x: Waveform tensor # (B, 1, T), T <= 30s * sample_rate
            input_lengths: Valid length for each sample # (B,)
        Output:
            dict: Contains the following key-value pairs
                "zq": Quantized embeddings # (B, D, T)
                "codes": Quantization codes # (nq, B, T)
                "codes_lengths": Quantization code lengths # (B,)
        """
        # The feature extractor takes a list of 1-D numpy arrays, each trimmed to its valid length.
        list_x = [xi[:, :x_len].reshape(-1).cpu().numpy() for xi, x_len in zip(x, input_lengths)]
        features = self.feature_extractor(
            list_x,
            sampling_rate=self.input_sample_rate,
            return_tensors="pt",
            return_attention_mask=True,
        )
        input_mel = features['input_features'].to(x.device).to(x.dtype)  # (B, D, 3000)
        audio_attention_mask = features['attention_mask'].to(x.device)  # (B, 3000)
        # Number of valid mel frames per sample
        mel_output_length = torch.sum(audio_attention_mask, dim=-1).long()  # (B,)

        # Semantic channel
        semantic_encoder_output, semantic_encoder_output_length = self.semantic_encoder(input_mel, mel_output_length)  # (B, D, T), 100hz -> 50hz
        semantic_encoder_adapter_output, semantic_encoder_adapter_output_length = self.semantic_encoder_adapter(semantic_encoder_output, semantic_encoder_output_length)  # (B, D, T), 50hz
        # Acoustic channel
        acoustic_encoder_output, acoustic_encoder_output_length = self.acoustic_encoder(input_mel, mel_output_length)  # (B, D, T), 100hz -> 50hz

        # Semantic & acoustic mixing: concatenate along the feature axis
        concated_semantic_acoustic_channel = torch.cat([semantic_encoder_adapter_output, acoustic_encoder_output], dim=1)  # (B, D, T)
        concated_semantic_acoustic_channel_length = acoustic_encoder_output_length
        pre_rvq_adapter_output, pre_rvq_adapter_output_length = self.pre_rvq_adapter(concated_semantic_acoustic_channel, concated_semantic_acoustic_channel_length)  # (B, D, T), 50hz
        downsample_output, downsample_output_length = self.downsample(pre_rvq_adapter_output, pre_rvq_adapter_output_length)  # (B, D, T), 50hz -> 12.5hz
        # Quantizer also returns the VQ loss and per-layer outputs, which inference does not need.
        zq, codes, _, _, quantizer_output_length = self.quantizer(downsample_output, downsample_output_length)  # (B, D, T), (nq, B, T), (nq,), (nq, B, D, T), (B,)
        return {
            "zq": zq,  # (B, D, T)
            "codes": codes,  # (nq, B, T)
            "codes_lengths": quantizer_output_length,  # (B,)
        }

    def inference_detokenize(self, codes, codes_lengths):
        """
        Input:
            codes: Quantization codes # (nq, B, T)
            codes_lengths: Quantization code lengths for each sample # (B,)
        Output:
            dict: Contains the following key-value pairs
                "y": Synthesized audio waveform # (B, 1, T)
                "output_length": Output lengths # (B,)
        """
        zq = self.quantizer.decode_codes(codes)  # (B, D, T)
        post_rvq_adapter_output, post_rvq_adapter_output_length = self.post_rvq_adapter(zq, codes_lengths)  # (B, D, T), 12.5hz
        # Acoustic channel
        upsample_output, upsample_output_length = self.upsample(post_rvq_adapter_output, post_rvq_adapter_output_length)  # (B, D, T), 12.5hz -> 50hz
        acoustic_decoder_output, acoustic_decoder_output_length = self.acoustic_decoder(upsample_output, upsample_output_length)  # (B, D, T), 50hz -> 100hz
        y, vocos_output_length = self.enhanced_vocos(acoustic_decoder_output, acoustic_decoder_output_length)  # (B, 1, T), 100hz -> 16khz
        return {
            "y": y,  # (B, 1, T)
            "output_length": vocos_output_length,  # (B,)
        }

    def _encode_chunk_sizes(self, overlap_seconds):
        """Return ``(chunk_size, duration_size, code_duration_length)`` for ``encode``.

        ``chunk_size`` is the window length in input samples (30 s),
        ``code_duration_length`` the number of codes kept per window, and
        ``duration_size`` the window stride in input samples. The stride is
        rounded down to a whole number of code frames. Otherwise, e.g. a 21 s
        stride at 16 kHz (262.5 frames), each window would drop part of a
        frame and the concatenated codes would drift away from the audio.

        Raises:
            ValueError: if ``overlap_seconds`` is outside ``[0, 30)`` or leaves
                less than one code frame per window.
        """
        if not 0 <= overlap_seconds < 30:
            raise ValueError(f"overlap_seconds must be in [0, 30), got {overlap_seconds!r}")
        chunk_size = int(30 * self.input_sample_rate)  # Maximum samples per chunk
        code_duration_length = int((30 - overlap_seconds) * self.input_sample_rate) // self.encoder_downsample_rate
        if code_duration_length < 1:
            raise ValueError(f"overlap_seconds={overlap_seconds!r} leaves less than one code frame per window")
        duration_size = code_duration_length * self.encoder_downsample_rate  # Stride, aligned to code frames
        return chunk_size, duration_size, code_duration_length

    def encode(self, wav_list, overlap_seconds=10, device=torch.device("cuda")):
        """
        Input:
            wav_list: List of audio waveforms, each with potentially different length, may exceed 30 seconds # B * (T,)
            overlap_seconds: Overlap in seconds, process 30 seconds at a time, keeping (30 - overlap_seconds) seconds
                of valid output (rounded down to a whole number of code frames). Must satisfy 0 <= overlap_seconds < 30.
            device: Device on which the padded batch and the output codes are allocated
        Output:
            dict: Contains the following key-value pairs
                "codes_list": List of quantization codes # B * (nq, T); empty list when wav_list is empty
        Raises:
            ValueError: if overlap_seconds is out of range
        """
        chunk_size, duration_size, code_duration_length = self._encode_chunk_sizes(overlap_seconds)
        # Nothing to encode (``max`` below would fail on an empty sequence).
        if len(wav_list) == 0:
            return {"codes_list": []}

        # Right-pad every waveform to the longest one in the batch
        max_length = max(len(wav) for wav in wav_list)
        batch_size = len(wav_list)
        wav_tensor = torch.zeros(batch_size, 1, max_length, device=device)
        input_lengths = torch.zeros(batch_size, dtype=torch.long, device=device)
        for i, wav in enumerate(wav_list):
            wav_tensor[i, 0, :len(wav)] = wav
            input_lengths[i] = len(wav)  # (B,)

        # Windows start every ``duration_size`` samples and span up to ``chunk_size`` samples.
        # Only the first ``code_duration_length`` codes of each window are kept; the tail
        # (the overlap) is encoded again at the start of the next window.
        max_chunks = (max_length + duration_size - 1) // duration_size
        codes_list = []
        for chunk_idx in range(max_chunks):
            start = chunk_idx * duration_size
            end = min(start + chunk_size, max_length)
            chunk = wav_tensor[:, :, start:end]  # (B, 1, T')
            chunk_lengths = torch.clamp(input_lengths - start, 0, end - start)  # (B,)
            # Skip empty chunks
            if chunk_lengths.max() == 0:
                continue
            result = self.inference_tokenize(chunk, chunk_lengths)  # {"zq": (B, D, T'), "codes": (nq, B, T'), "codes_lengths": (B,)}
            chunk_codes = result["codes"]  # (nq, B, T')
            chunk_code_lengths = result["codes_lengths"]  # (B,)
            # Keep the non-overlapping part; samples that produced fewer codes are zero-padded
            valid_code_lengths = torch.clamp(chunk_code_lengths, 0, code_duration_length)  # (B,)
            valid_chunk_codes = torch.zeros(self.nq, batch_size, code_duration_length, device=device, dtype=chunk_codes.dtype)
            for b in range(batch_size):
                if valid_code_lengths[b] > 0:
                    valid_chunk_codes[:, b, :valid_code_lengths[b]] = chunk_codes[:, b, :valid_code_lengths[b]]
            codes_list.append(valid_chunk_codes)  # (nq, B, code_duration_length)

        # Concatenate all chunks and trim each sample to the code count of its own length
        if codes_list:
            codes_tensor = torch.cat(codes_list, dim=-1)  # (nq, B, T_total)
            codes_list = [codes_tensor[:, i, :input_lengths[i] // self.encoder_downsample_rate] for i in range(batch_size)]  # B * (nq, T)
        else:
            codes_list = [torch.zeros(self.nq, 0, device=device, dtype=torch.long) for _ in range(batch_size)]  # B * (nq, 0)
        return {
            "codes_list": codes_list  # B * (nq, T)
        }

    def _decode_chunk_sizes(self, overlap_seconds):
        """Return ``(chunk_code_length, duration_code_length, duration_wav_length)`` for ``decode``.

        ``chunk_code_length`` is the window length in codes (30 s of input audio),
        ``duration_code_length`` the window stride / number of codes whose audio is
        kept, and ``duration_wav_length`` the matching number of output samples.

        Raises:
            ValueError: if ``overlap_seconds`` is outside ``[0, 30)`` or leaves
                less than one code frame per window.
        """
        if not 0 <= overlap_seconds < 30:
            raise ValueError(f"overlap_seconds must be in [0, 30), got {overlap_seconds!r}")
        chunk_code_length = int(30 * self.input_sample_rate // self.encoder_downsample_rate)  # Maximum code length per chunk
        duration_code_length = int((30 - overlap_seconds) * self.input_sample_rate // self.encoder_downsample_rate)  # Valid code length per chunk
        if duration_code_length < 1:
            raise ValueError(f"overlap_seconds={overlap_seconds!r} leaves less than one code frame per window")
        duration_wav_length = duration_code_length * self.decoder_upsample_rate  # Valid waveform length per chunk
        return chunk_code_length, duration_code_length, duration_wav_length

    def decode(self, codes_list, overlap_seconds=10, device=torch.device("cuda")):
        """
        Input:
            codes_list: List of quantization codes # B * (nq, T)
            overlap_seconds: Overlap in seconds, process 30 seconds at a time, keeping (30 - overlap_seconds) seconds
                of valid output. Must satisfy 0 <= overlap_seconds < 30.
            device: Device on which the padded codes and the output waveforms are allocated
        Output:
            dict: Contains the following key-value pairs
                "syn_wav_list": List of synthesized audio waveforms # B * (T,); empty list when codes_list is empty
        Raises:
            ValueError: if overlap_seconds is out of range
        """
        chunk_code_length, duration_code_length, duration_wav_length = self._decode_chunk_sizes(overlap_seconds)
        # Nothing to decode (``max`` below would fail on an empty sequence).
        if len(codes_list) == 0:
            return {"syn_wav_list": []}

        # Right-pad every code sequence to the longest one in the batch
        max_code_length = max(codes.shape[-1] for codes in codes_list)
        batch_size = len(codes_list)
        codes_tensor = torch.zeros(self.nq, batch_size, max_code_length, device=device, dtype=torch.long)
        code_lengths = torch.zeros(batch_size, dtype=torch.long, device=device)
        for i, codes in enumerate(codes_list):
            codes_tensor[:, i, :codes.shape[-1]] = codes.to(device)
            code_lengths[i] = codes.shape[-1]  # (B,)

        # Windows start every ``duration_code_length`` codes and span up to ``chunk_code_length``
        # codes; only the audio of the first ``duration_code_length`` codes is kept per window.
        max_chunks = (max_code_length + duration_code_length - 1) // duration_code_length
        wav_list = []
        for chunk_idx in range(max_chunks):
            start = chunk_idx * duration_code_length
            end = min(start + chunk_code_length, max_code_length)
            chunk_codes = codes_tensor[:, :, start:end]  # (nq, B, T')
            chunk_code_lengths = torch.clamp(code_lengths - start, 0, end - start)  # (B,)
            # Skip empty chunks
            if chunk_code_lengths.max() == 0:
                continue
            result = self.inference_detokenize(chunk_codes, chunk_code_lengths)  # {"y": (B, 1, T'), "output_length": (B,)}
            chunk_wav = result["y"]  # (B, 1, T')
            chunk_wav_lengths = result["output_length"]  # (B,)
            # Keep the non-overlapping part; shorter outputs are zero-padded
            valid_wav_lengths = torch.clamp(chunk_wav_lengths, 0, duration_wav_length)  # (B,)
            valid_chunk_wav = torch.zeros(batch_size, 1, duration_wav_length, device=device)
            for b in range(batch_size):
                if valid_wav_lengths[b] > 0:
                    valid_chunk_wav[b, :, :valid_wav_lengths[b]] = chunk_wav[b, :, :valid_wav_lengths[b]]
            wav_list.append(valid_chunk_wav)  # (B, 1, duration_wav_length)

        # Concatenate all chunks and trim each sample to the output length of its own codes
        if wav_list:
            wav_tensor = torch.cat(wav_list, dim=-1)  # (B, 1, T_total)
            syn_wav_list = [wav_tensor[i, 0, :code_lengths[i] * self.decoder_upsample_rate] for i in range(batch_size)]  # B * (T,)
        else:
            syn_wav_list = [torch.zeros(0, device=device) for _ in range(batch_size)]  # B * (0,)
        return {
            "syn_wav_list": syn_wav_list  # B * (T,)
        }

    @classmethod
    def load_from_checkpoint(cls, config_path: str, ckpt_path: str):
        """Build a model from a YAML config and load weights from a checkpoint.

        Args:
            config_path: YAML file whose ``generator_params`` entry is passed to the constructor.
            ckpt_path: Checkpoint loaded with ``torch.load`` on CPU. It is either a state dict
                or a mapping holding one under the ``'generator'`` key.
        Returns:
            The constructed model with the checkpoint weights loaded.
        """
        logging.info("Loading model from %s and %s", config_path, ckpt_path)
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
        model = cls(config['generator_params'])
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints
        # (consider weights_only=True if the checkpoint holds plain tensors).
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        state_dict = checkpoint['generator'] if 'generator' in checkpoint else checkpoint
        model.load_state_dict(state_dict)
        return model