import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchaudio


# Hyper-parameters of the pretrained d-vector speaker encoder.
EMBEDDER_PARAMS = {
    'num_mels': 40,      # mel filterbank channels per frame
    'n_fft': 512,        # FFT size used for the STFT and mel basis
    'emb_dim': 256,      # dimensionality of the output speaker embedding
    'lstm_hidden': 768,  # LSTM hidden size
    'lstm_layers': 3,    # number of stacked LSTM layers
    'window': 80,        # sliding-window length over mel frames
    'stride': 40,        # sliding-window hop over mel frames
}
|
|
def set_requires_grad(nets, requires_grad=False):
    """Set requires_grad=False for all the networks to avoid unnecessary
    computations.

    Parameters:
        nets (network list)   -- a list of networks
        requires_grad (bool)  -- whether the networks require gradients or not
    """
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad
|
|
class LinearNorm(nn.Module):
    """Linear projection from the LSTM hidden state to the embedding space."""

    def __init__(self, hp):
        super(LinearNorm, self).__init__()
        self.linear_layer = nn.Linear(hp["lstm_hidden"], hp["emb_dim"])

    def forward(self, x):
        return self.linear_layer(x)
|
|
class SpeechEmbedder(nn.Module):
    """LSTM-based d-vector speaker encoder over sliding mel windows."""

    def __init__(self, hp):
        super(SpeechEmbedder, self).__init__()
        self.lstm = nn.LSTM(hp["num_mels"],
                            hp["lstm_hidden"],
                            num_layers=hp["lstm_layers"],
                            batch_first=True)
        self.proj = LinearNorm(hp)
        self.hp = hp

    def forward(self, mel):
        # Slice the utterance into overlapping windows:
        # (num_mels, T) -> (num_mels, T', window)
        mels = mel.unfold(1, self.hp["window"], self.hp["stride"])
        mels = mels.permute(1, 2, 0)  # (T', window, num_mels)
        x, _ = self.lstm(mels)  # (T', window, lstm_hidden)
        x = x[:, -1, :]  # keep only the last frame of each window
        x = self.proj(x)  # (T', emb_dim)
        x = x / torch.norm(x, p=2, dim=1, keepdim=True)  # unit-norm per window

        # Average the per-window embeddings, then renormalize to unit length.
        x = x.mean(dim=0)
        if x.norm(p=2) != 0:
            x = x / x.norm(p=2)
        return x
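
# Illustrative only: a shape sanity check for SpeechEmbedder.forward under
# the default hyper-parameters, using a random 200-frame mel spectrogram
# (the input values here are assumptions, not from the source):
#
#     hp = EMBEDDER_PARAMS
#     embedder = SpeechEmbedder(hp)
#     mel = torch.randn(hp["num_mels"], 200)               # (40, 200)
#     windows = mel.unfold(1, hp["window"], hp["stride"])  # (40, 4, 80)
#     # T' = (200 - 80) // 40 + 1 = 4 overlapping windows
#     dvec = embedder(mel)
#     assert dvec.shape == (hp["emb_dim"],)                # unit-norm (256,)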
|
|
class SpkrEmbedder(nn.Module):
    RATE = 16000

    def __init__(
        self,
        embedder_path,
        embedder_params=EMBEDDER_PARAMS,
        rate=16000,
        hop_length=160,
        win_length=400,
        pad=False,
    ):
        super(SpkrEmbedder, self).__init__()
        # Load the pretrained speaker encoder and freeze its weights.
        embedder_pt = torch.load(embedder_path, map_location="cpu")
        self.embedder = SpeechEmbedder(embedder_params)
        self.embedder.load_state_dict(embedder_pt)
        self.embedder.eval()
        set_requires_grad(self.embedder, requires_grad=False)
        self.embedder_params = embedder_params

        # Mel filterbank matrix, registered as a buffer so it moves with
        # the module across devices.
        self.register_buffer('mel_basis', torch.from_numpy(
            librosa.filters.mel(
                sr=self.RATE,
                n_fft=self.embedder_params["n_fft"],
                n_mels=self.embedder_params["num_mels"])
        ))

        # Resample incoming audio when its rate differs from the model rate.
        self.resample = None
        if rate != self.RATE:
            self.resample = torchaudio.transforms.Resample(rate, self.RATE)
        self.hop_length = hop_length
        self.win_length = win_length
        self.pad = pad

    def get_mel(self, y):
        # Optionally pad short clips so at least one full window fits.
        if self.pad and y.shape[-1] < 14000:
            y = F.pad(y, (0, 14000 - y.shape[-1]))

        window = torch.hann_window(self.win_length).to(y)
        # return_complex=True is required by recent PyTorch releases; the
        # power spectrum |STFT|^2 is identical either way.
        y = torch.stft(y, n_fft=self.embedder_params["n_fft"],
                       hop_length=self.hop_length,
                       win_length=self.win_length,
                       window=window,
                       return_complex=True)
        magnitudes = torch.abs(y) ** 2
        mel = torch.log10(self.mel_basis @ magnitudes + 1e-6)
        return mel
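
    # Illustrative only: expected shapes through get_mel for a mono 1-second,
    # 16 kHz input (assumed example values, not from the source):
    #
    #     y:          (1, 16000)  waveform
    #     stft:       (1, 257, 101) complex  # n_fft//2+1 bins, 1 + 16000//160 frames
    #     magnitudes: (1, 257, 101)  power spectrogram
    #     mel:        (1, 40, 101)   log10-mel after the mel_basis projection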
|
    def forward(self, inputs):
        dvecs = []
        for wav in inputs:
            # Match the embedder's 16 kHz rate before feature extraction.
            if self.resample is not None:
                wav = self.resample(wav)
            mel = self.get_mel(wav)
            if mel.dim() == 3:  # drop a leading batch dimension if present
                mel = mel.squeeze(0)
            dvecs += [self.embedder(mel)]
        dvecs = torch.stack(dvecs)

        # Average the per-utterance d-vectors and renormalize to unit length.
        dvec = torch.mean(dvecs, dim=0)
        dvec = dvec / torch.norm(dvec)

        return dvec
|
|
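# A minimal usage sketch, not part of the original module: builds the
# embedder from a checkpoint and averages d-vectors over two clips. The
# checkpoint path and the random waveforms are placeholders/assumptions;
# a real SpeechEmbedder state_dict (.pt) is required for meaningful output.
if __name__ == "__main__":
    embedder = SpkrEmbedder("/path/to/embedder.pt", rate=22050)
    wavs = [torch.randn(1, 22050) for _ in range(2)]  # two 1 s clips at 22.05 kHz
    with torch.no_grad():
        dvec = embedder(wavs)
    print(dvec.shape)  # torch.Size([256]), unit L2 norm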