VidMuse-cvpr
This view is limited to 50 files because it contains too many changes.
- README.md +92 -1
- audiocraft/__init__.py +26 -0
- audiocraft/adversarial/__init__.py +22 -0
- audiocraft/adversarial/discriminators/__init__.py +10 -0
- audiocraft/adversarial/discriminators/base.py +34 -0
- audiocraft/adversarial/discriminators/mpd.py +106 -0
- audiocraft/adversarial/discriminators/msd.py +126 -0
- audiocraft/adversarial/discriminators/msstftd.py +134 -0
- audiocraft/adversarial/losses.py +228 -0
- audiocraft/data/__init__.py +10 -0
- audiocraft/data/audio.py +231 -0
- audiocraft/data/audio_dataset.py +694 -0
- audiocraft/data/audio_utils.py +176 -0
- audiocraft/data/info_audio_dataset.py +111 -0
- audiocraft/data/music_dataset.py +307 -0
- audiocraft/data/sound_dataset.py +330 -0
- audiocraft/data/video.py +83 -0
- audiocraft/data/zip.py +76 -0
- audiocraft/environment.py +176 -0
- audiocraft/losses/__init__.py +21 -0
- audiocraft/losses/balancer.py +136 -0
- audiocraft/losses/sisnr.py +97 -0
- audiocraft/losses/specloss.py +149 -0
- audiocraft/losses/stftloss.py +207 -0
- audiocraft/metrics/__init__.py +14 -0
- audiocraft/metrics/chroma_cosinesim.py +72 -0
- audiocraft/metrics/clap_consistency.py +84 -0
- audiocraft/metrics/fad.py +329 -0
- audiocraft/metrics/kld.py +220 -0
- audiocraft/metrics/rvm.py +110 -0
- audiocraft/metrics/visqol.py +216 -0
- audiocraft/models/__init__.py +18 -0
- audiocraft/models/audiogen.py +267 -0
- audiocraft/models/builders.py +268 -0
- audiocraft/models/encodec.py +580 -0
- audiocraft/models/lm.py +685 -0
- audiocraft/models/lm_back.py +698 -0
- audiocraft/models/loaders.py +149 -0
- audiocraft/models/multibanddiffusion.py +196 -0
- audiocraft/models/transformer_module.py +177 -0
- audiocraft/models/unet.py +214 -0
- audiocraft/models/vidmuse.py +425 -0
- audiocraft/modules/__init__.py +22 -0
- audiocraft/modules/activations.py +96 -0
- audiocraft/modules/chroma.py +66 -0
- audiocraft/modules/codebooks_patterns.py +544 -0
- audiocraft/modules/conditioners.py +1357 -0
- audiocraft/modules/conv.py +243 -0
- audiocraft/modules/diffusion_schedule.py +272 -0
- audiocraft/modules/lstm.py +25 -0
README.md
CHANGED
@@ -1,3 +1,94 @@
 ---
-license: cc-by-
+license: cc-by-4.0
 ---
+
+# VidMuse
+
+## VidMuse: A Simple Video-to-Music Generation Framework with Long-Short-Term Modeling
+
+[TL;DR]: VidMuse is a framework for generating high-fidelity music aligned with video content, using Long-Short-Term modeling. It has been accepted to CVPR 2025.
+
+### Links
+- **[Paper](https://arxiv.org/pdf/2406.04321)**: Explore the research behind VidMuse.
+- **[Project](https://vidmuse.github.io/)**: Visit the official project page for more information and updates.
+- **[Dataset](https://huggingface.co/datasets/HKUSTAudio/VidMuse-Dataset)**: Download the dataset used in the paper.
+
+## Clone the repository
+```bash
+GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/HKUSTAudio/VidMuse
+cd VidMuse
+```
+
+## Usage
+
+1. First, install the [`VidMuse` library](https://github.com/ZeyueT/VidMuse):
+```
+conda create -n VidMuse python=3.9
+conda activate VidMuse
+pip install git+https://github.com/ZeyueT/VidMuse.git
+```
+
+2. Install ffmpeg:
+```bash
+sudo apt-get install ffmpeg
+# Or if you are using Anaconda or Miniconda
+conda install "ffmpeg<5" -c conda-forge
+```
+
+3. Run the following Python code:
+
+```py
+from video_processor import VideoProcessor, merge_video_audio
+from audiocraft.models import VidMuse
+import scipy
+
+# Path to the video
+video_path = 'sample.mp4'
+# Initialize the video processor
+processor = VideoProcessor()
+# Process the video to obtain tensors and duration
+local_video_tensor, global_video_tensor, duration = processor.process(video_path)
+
+progress = True
+USE_DIFFUSION = False
+
+# Load the pre-trained VidMuse model
+MODEL = VidMuse.get_pretrained('HKUSTAudio/VidMuse')
+# Set generation parameters for the model based on video duration
+MODEL.set_generation_params(duration=duration)
+
+try:
+    # Generate outputs using the model
+    outputs = MODEL.generate([local_video_tensor, global_video_tensor], progress=progress, return_tokens=USE_DIFFUSION)
+except RuntimeError as e:
+    print(e)
+
+# Detach outputs from the computation graph and convert to CPU float tensor
+outputs = outputs.detach().cpu().float()
+
+sampling_rate = 32000
+output_wav_path = "vidmuse_sample.wav"
+# Write the output audio data to a WAV file
+scipy.io.wavfile.write(output_wav_path, rate=sampling_rate, data=outputs[0, 0].numpy())
+
+output_video_path = "vidmuse_sample.mp4"
+# Merge the original video with the generated music
+merge_video_audio(video_path, output_wav_path, output_video_path)
+```
+
+## Citation
+If you find our work useful, please consider citing:
+
+```
+@article{tian2024vidmuse,
+  title={Vidmuse: A simple video-to-music generation framework with long-short-term modeling},
+  author={Tian, Zeyue and Liu, Zhaoyang and Yuan, Ruibin and Pan, Jiahao and Liu, Qifeng and Tan, Xu and Chen, Qifeng and Xue, Wei and Guo, Yike},
+  journal={arXiv preprint arXiv:2406.04321},
+  year={2024}
+}
+```
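
For reference, the repository also ships its own audio writer, `audiocraft.data.audio.audio_write` (added later in this diff), which peak/RMS-normalizes and pipes through ffmpeg. The following is a minimal sketch of using it in place of `scipy.io.wavfile.write`, assuming `outputs` is the `[batch, channels, time]` float tensor at 32 kHz produced by `MODEL.generate` in the README example above.

```python
# Sketch only: write the generated waveform with the repo's own helper instead of scipy.
# Assumes `outputs` is the [B, C, T] float tensor at 32 kHz produced by MODEL.generate above.
from audiocraft.data.audio import audio_write

out_path = audio_write('vidmuse_sample',   # stem name; the '.wav' suffix is added automatically
                       outputs[0].cpu(),   # [channels, time] tensor for the first sample
                       sample_rate=32000,
                       format='wav',
                       strategy='peak')    # peak-normalize to avoid clipping
print(f'wrote {out_path}')
```
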
audiocraft/__init__.py
ADDED
@@ -0,0 +1,26 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+AudioCraft is a general framework for training audio generative models.
+At the moment we provide the training code for:
+
+- [MusicGen](https://arxiv.org/abs/2306.05284), a state-of-the-art
+  text-to-music and melody+text autoregressive generative model.
+  For the solver, see `audiocraft.solvers.musicgen.MusicGenSolver`, and for the model,
+  `audiocraft.models.musicgen.MusicGen`.
+- [AudioGen](https://arxiv.org/abs/2209.15352), a state-of-the-art
+  text-to-general-audio generative model.
+- [EnCodec](https://arxiv.org/abs/2210.13438), efficient and high fidelity
+  neural audio codec which provides an excellent tokenizer for autoregressive language models.
+  See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`.
+- [MultiBandDiffusion](TODO), alternative diffusion-based decoder compatible with EnCodec that
+  improves the perceived quality and reduces the artifacts coming from adversarial decoders.
+"""
+
+# flake8: noqa
+from . import data, modules, models
+
+__version__ = '1.2.0a1'
audiocraft/adversarial/__init__.py
ADDED
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Adversarial losses and discriminator architectures."""
+
+# flake8: noqa
+from .discriminators import (
+    MultiPeriodDiscriminator,
+    MultiScaleDiscriminator,
+    MultiScaleSTFTDiscriminator
+)
+from .losses import (
+    AdversarialLoss,
+    AdvLossType,
+    get_adv_criterion,
+    get_fake_criterion,
+    get_real_criterion,
+    FeatLossType,
+    FeatureMatchingLoss
+)
audiocraft/adversarial/discriminators/__init__.py
ADDED
@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# flake8: noqa
+from .mpd import MultiPeriodDiscriminator
+from .msd import MultiScaleDiscriminator
+from .msstftd import MultiScaleSTFTDiscriminator
audiocraft/adversarial/discriminators/base.py
ADDED
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+import typing as tp
+
+import torch
+import torch.nn as nn
+
+
+FeatureMapType = tp.List[torch.Tensor]
+LogitsType = torch.Tensor
+MultiDiscriminatorOutputType = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
+
+
+class MultiDiscriminator(ABC, nn.Module):
+    """Base implementation for discriminators composed of sub-discriminators acting at different scales.
+    """
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
+        ...
+
+    @property
+    @abstractmethod
+    def num_discriminators(self) -> int:
+        """Number of discriminators.
+        """
+        ...
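
To make the contract above concrete, here is a minimal, purely illustrative subclass (not part of this diff): a `MultiDiscriminator` must return one logits tensor and one list of feature maps per sub-discriminator, and expose `num_discriminators`. It assumes the `audiocraft` package from this diff (and its dependencies such as `flashy` and `einops`) is importable.

```python
# Toy subclass sketch showing the MultiDiscriminator contract defined above.
import torch
import torch.nn as nn

from audiocraft.adversarial.discriminators.base import (
    MultiDiscriminator, MultiDiscriminatorOutputType)


class ToyMultiDiscriminator(MultiDiscriminator):
    def __init__(self, n: int = 2):
        super().__init__()
        # Two trivial 1x1-conv "discriminators" operating on [B, 1, T] waveforms.
        self.discs = nn.ModuleList([nn.Conv1d(1, 1, kernel_size=1) for _ in range(n)])

    @property
    def num_discriminators(self) -> int:
        return len(self.discs)

    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        logits, fmaps = [], []
        for d in self.discs:
            out = d(x)
            logits.append(out)    # one logits tensor per sub-discriminator
            fmaps.append([out])   # one list of feature maps per sub-discriminator
        return logits, fmaps
```
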
audiocraft/adversarial/discriminators/mpd.py
ADDED
@@ -0,0 +1,106 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...modules import NormConv2d
+from .base import MultiDiscriminator, MultiDiscriminatorOutputType
+
+
+def get_padding(kernel_size: int, dilation: int = 1) -> int:
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+class PeriodDiscriminator(nn.Module):
+    """Period sub-discriminator.
+
+    Args:
+        period (int): Period between samples of audio.
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        n_layers (int): Number of convolutional layers.
+        kernel_sizes (list of int): Kernel sizes for convolutions.
+        stride (int): Stride for convolutions.
+        filters (int): Initial number of filters in convolutions.
+        filters_scale (int): Multiplier of number of filters as we increase depth.
+        max_filters (int): Maximum number of filters.
+        norm (str): Normalization method.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+    """
+    def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1,
+                 n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3,
+                 filters: int = 8, filters_scale: int = 4, max_filters: int = 1024,
+                 norm: str = 'weight_norm', activation: str = 'LeakyReLU',
+                 activation_params: dict = {'negative_slope': 0.2}):
+        super().__init__()
+        self.period = period
+        self.n_layers = n_layers
+        self.activation = getattr(torch.nn, activation)(**activation_params)
+        self.convs = nn.ModuleList()
+        in_chs = in_channels
+        for i in range(self.n_layers):
+            out_chs = min(filters * (filters_scale ** (i + 1)), max_filters)
+            eff_stride = 1 if i == self.n_layers - 1 else stride
+            self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_sizes[0], 1), stride=(eff_stride, 1),
+                                         padding=((kernel_sizes[0] - 1) // 2, 0), norm=norm))
+            in_chs = out_chs
+        self.conv_post = NormConv2d(in_chs, out_channels, kernel_size=(kernel_sizes[1], 1), stride=1,
+                                    padding=((kernel_sizes[1] - 1) // 2, 0), norm=norm)
+
+    def forward(self, x: torch.Tensor):
+        fmap = []
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), 'reflect')
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for conv in self.convs:
+            x = conv(x)
+            x = self.activation(x)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        # x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class MultiPeriodDiscriminator(MultiDiscriminator):
+    """Multi-Period (MPD) Discriminator.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        periods (Sequence[int]): Periods between samples of audio for the sub-discriminators.
+        **kwargs: Additional args for `PeriodDiscriminator`
+    """
+    def __init__(self, in_channels: int = 1, out_channels: int = 1,
+                 periods: tp.Sequence[int] = [2, 3, 5, 7, 11], **kwargs):
+        super().__init__()
+        self.discriminators = nn.ModuleList([
+            PeriodDiscriminator(p, in_channels, out_channels, **kwargs) for p in periods
+        ])
+
+    @property
+    def num_discriminators(self):
+        return len(self.discriminators)
+
+    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
+        logits = []
+        fmaps = []
+        for disc in self.discriminators:
+            logit, fmap = disc(x)
+            logits.append(logit)
+            fmaps.append(fmap)
+        return logits, fmaps
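
A short sketch of how the MPD defined above consumes a waveform batch; the batch size and length below are arbitrary, and the sketch assumes the `audiocraft` package from this diff (with its `NormConv2d` module and the `einops`/`torchaudio` dependencies pulled in by the package imports) is installed.

```python
# Sketch: run the Multi-Period Discriminator on a random waveform batch.
import torch
from audiocraft.adversarial.discriminators import MultiPeriodDiscriminator

mpd = MultiPeriodDiscriminator()       # default periods [2, 3, 5, 7, 11]
wav = torch.randn(4, 1, 32000)         # [batch, channels, time], e.g. 1 s at 32 kHz

logits, fmaps = mpd(wav)
assert len(logits) == mpd.num_discriminators == 5
# Each sub-discriminator returns one 2D logit map plus one feature map per layer
# (5 strided convs + conv_post = 6 feature maps).
print([l.shape for l in logits])
print([len(f) for f in fmaps])
```
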
audiocraft/adversarial/discriminators/msd.py
ADDED
@@ -0,0 +1,126 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ...modules import NormConv1d
+from .base import MultiDiscriminator, MultiDiscriminatorOutputType
+
+
+class ScaleDiscriminator(nn.Module):
+    """Waveform sub-discriminator.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        kernel_sizes (Sequence[int]): Kernel sizes for first and last convolutions.
+        filters (int): Number of initial filters for convolutions.
+        max_filters (int): Maximum number of filters.
+        downsample_scales (Sequence[int]): Scale for downsampling implemented as strided convolutions.
+        inner_kernel_sizes (Sequence[int] or None): Kernel sizes for inner convolutions.
+        groups (Sequence[int] or None): Groups for inner convolutions.
+        strides (Sequence[int] or None): Strides for inner convolutions.
+        paddings (Sequence[int] or None): Paddings for inner convolutions.
+        norm (str): Normalization method.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        pad (str): Padding for initial convolution.
+        pad_params (dict): Parameters to provide to the padding module.
+    """
+    def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Sequence[int] = [5, 3],
+                 filters: int = 16, max_filters: int = 1024, downsample_scales: tp.Sequence[int] = [4, 4, 4, 4],
+                 inner_kernel_sizes: tp.Optional[tp.Sequence[int]] = None, groups: tp.Optional[tp.Sequence[int]] = None,
+                 strides: tp.Optional[tp.Sequence[int]] = None, paddings: tp.Optional[tp.Sequence[int]] = None,
+                 norm: str = 'weight_norm', activation: str = 'LeakyReLU',
+                 activation_params: dict = {'negative_slope': 0.2}, pad: str = 'ReflectionPad1d',
+                 pad_params: dict = {}):
+        super().__init__()
+        assert len(kernel_sizes) == 2
+        assert kernel_sizes[0] % 2 == 1
+        assert kernel_sizes[1] % 2 == 1
+        assert (inner_kernel_sizes is None or len(inner_kernel_sizes) == len(downsample_scales))
+        assert (groups is None or len(groups) == len(downsample_scales))
+        assert (strides is None or len(strides) == len(downsample_scales))
+        assert (paddings is None or len(paddings) == len(downsample_scales))
+        self.activation = getattr(torch.nn, activation)(**activation_params)
+        self.convs = nn.ModuleList()
+        self.convs.append(
+            nn.Sequential(
+                getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params),
+                NormConv1d(in_channels, filters, kernel_size=np.prod(kernel_sizes), stride=1, norm=norm)
+            )
+        )
+
+        in_chs = filters
+        for i, downsample_scale in enumerate(downsample_scales):
+            out_chs = min(in_chs * downsample_scale, max_filters)
+            default_kernel_size = downsample_scale * 10 + 1
+            default_stride = downsample_scale
+            default_padding = (default_kernel_size - 1) // 2
+            default_groups = in_chs // 4
+            self.convs.append(
+                NormConv1d(in_chs, out_chs,
+                           kernel_size=inner_kernel_sizes[i] if inner_kernel_sizes else default_kernel_size,
+                           stride=strides[i] if strides else default_stride,
+                           groups=groups[i] if groups else default_groups,
+                           padding=paddings[i] if paddings else default_padding,
+                           norm=norm))
+            in_chs = out_chs
+
+        out_chs = min(in_chs * 2, max_filters)
+        self.convs.append(NormConv1d(in_chs, out_chs, kernel_size=kernel_sizes[0], stride=1,
+                                     padding=(kernel_sizes[0] - 1) // 2, norm=norm))
+        self.conv_post = NormConv1d(out_chs, out_channels, kernel_size=kernel_sizes[1], stride=1,
+                                    padding=(kernel_sizes[1] - 1) // 2, norm=norm)
+
+    def forward(self, x: torch.Tensor):
+        fmap = []
+        for layer in self.convs:
+            x = layer(x)
+            x = self.activation(x)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        # x = torch.flatten(x, 1, -1)
+        return x, fmap
+
+
+class MultiScaleDiscriminator(MultiDiscriminator):
+    """Multi-Scale (MSD) Discriminator,
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        downsample_factor (int): Downsampling factor between the different scales.
+        scale_norms (Sequence[str]): Normalization for each sub-discriminator.
+        **kwargs: Additional args for ScaleDiscriminator.
+    """
+    def __init__(self, in_channels: int = 1, out_channels: int = 1, downsample_factor: int = 2,
+                 scale_norms: tp.Sequence[str] = ['weight_norm', 'weight_norm', 'weight_norm'], **kwargs):
+        super().__init__()
+        self.discriminators = nn.ModuleList([
+            ScaleDiscriminator(in_channels, out_channels, norm=norm, **kwargs) for norm in scale_norms
+        ])
+        self.downsample = nn.AvgPool1d(downsample_factor * 2, downsample_factor, padding=downsample_factor)
+
+    @property
+    def num_discriminators(self):
+        return len(self.discriminators)
+
+    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
+        logits = []
+        fmaps = []
+        for i, disc in enumerate(self.discriminators):
+            if i != 0:
+                self.downsample(x)
+            logit, fmap = disc(x)
+            logits.append(logit)
+            fmaps.append(fmap)
+        return logits, fmaps
audiocraft/adversarial/discriminators/msstftd.py
ADDED
@@ -0,0 +1,134 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+import torchaudio
+import torch
+from torch import nn
+from einops import rearrange
+
+from ...modules import NormConv2d
+from .base import MultiDiscriminator, MultiDiscriminatorOutputType
+
+
+def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
+    return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)
+
+
+class DiscriminatorSTFT(nn.Module):
+    """STFT sub-discriminator.
+
+    Args:
+        filters (int): Number of filters in convolutions.
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        n_fft (int): Size of FFT for each scale.
+        hop_length (int): Length of hop between STFT windows for each scale.
+        kernel_size (tuple of int): Inner Conv2d kernel sizes.
+        stride (tuple of int): Inner Conv2d strides.
+        dilations (list of int): Inner Conv2d dilation on the time dimension.
+        win_length (int): Window size for each scale.
+        normalized (bool): Whether to normalize by magnitude after stft.
+        norm (str): Normalization method.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        growth (int): Growth factor for the filters.
+    """
+    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
+                 n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
+                 filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
+                 stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
+                 activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
+        super().__init__()
+        assert len(kernel_size) == 2
+        assert len(stride) == 2
+        self.filters = filters
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.normalized = normalized
+        self.activation = getattr(torch.nn, activation)(**activation_params)
+        self.spec_transform = torchaudio.transforms.Spectrogram(
+            n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
+            normalized=self.normalized, center=False, pad_mode=None, power=None)
+        spec_channels = 2 * self.in_channels
+        self.convs = nn.ModuleList()
+        self.convs.append(
+            NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
+        )
+        in_chs = min(filters_scale * self.filters, max_filters)
+        for i, dilation in enumerate(dilations):
+            out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
+            self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
+                                         dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
+                                         norm=norm))
+            in_chs = out_chs
+        out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
+        self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]),
+                                     padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+                                     norm=norm))
+        self.conv_post = NormConv2d(out_chs, self.out_channels,
+                                    kernel_size=(kernel_size[0], kernel_size[0]),
+                                    padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+                                    norm=norm)
+
+    def forward(self, x: torch.Tensor):
+        fmap = []
+        z = self.spec_transform(x)  # [B, 2, Freq, Frames, 2]
+        z = torch.cat([z.real, z.imag], dim=1)
+        z = rearrange(z, 'b c w t -> b c t w')
+        for i, layer in enumerate(self.convs):
+            z = layer(z)
+            z = self.activation(z)
+            fmap.append(z)
+        z = self.conv_post(z)
+        return z, fmap
+
+
+class MultiScaleSTFTDiscriminator(MultiDiscriminator):
+    """Multi-Scale STFT (MS-STFT) discriminator.
+
+    Args:
+        filters (int): Number of filters in convolutions.
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        sep_channels (bool): Separate channels to distinct samples for stereo support.
+        n_ffts (Sequence[int]): Size of FFT for each scale.
+        hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
+        win_lengths (Sequence[int]): Window size for each scale.
+        **kwargs: Additional args for STFTDiscriminator.
+    """
+    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, sep_channels: bool = False,
+                 n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128],
+                 win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs):
+        super().__init__()
+        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
+        self.sep_channels = sep_channels
+        self.discriminators = nn.ModuleList([
+            DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
+                              n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs)
+            for i in range(len(n_ffts))
+        ])
+
+    @property
+    def num_discriminators(self):
+        return len(self.discriminators)
+
+    def _separate_channels(self, x: torch.Tensor) -> torch.Tensor:
+        B, C, T = x.shape
+        return x.view(-1, 1, T)
+
+    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
+        logits = []
+        fmaps = []
+        for disc in self.discriminators:
+            logit, fmap = disc(x)
+            logits.append(logit)
+            fmaps.append(fmap)
+        return logits, fmaps
audiocraft/adversarial/losses.py
ADDED
@@ -0,0 +1,228 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Utility module to handle adversarial losses without requiring to mess up the main training loop.
+"""
+
+import typing as tp
+
+import flashy
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+ADVERSARIAL_LOSSES = ['mse', 'hinge', 'hinge2']
+
+
+AdvLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor], torch.Tensor]]
+FeatLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
+
+
+class AdversarialLoss(nn.Module):
+    """Adversary training wrapper.
+
+    Args:
+        adversary (nn.Module): The adversary module will be used to estimate the logits given the fake and real samples.
+            We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]``
+            where the first item is a list of logits and the second item is a list of feature maps.
+        optimizer (torch.optim.Optimizer): Optimizer used for training the given module.
+        loss (AdvLossType): Loss function for generator training.
+        loss_real (AdvLossType): Loss function for adversarial training on logits from real samples.
+        loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples.
+        loss_feat (FeatLossType): Feature matching loss function for generator training.
+        normalize (bool): Whether to normalize by number of sub-discriminators.
+
+    Example of usage:
+        adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake)
+        for real in loader:
+            noise = torch.randn(...)
+            fake = model(noise)
+            adv_loss.train_adv(fake, real)
+            loss, _ = adv_loss(fake, real)
+            loss.backward()
+    """
+    def __init__(self,
+                 adversary: nn.Module,
+                 optimizer: torch.optim.Optimizer,
+                 loss: AdvLossType,
+                 loss_real: AdvLossType,
+                 loss_fake: AdvLossType,
+                 loss_feat: tp.Optional[FeatLossType] = None,
+                 normalize: bool = True):
+        super().__init__()
+        self.adversary: nn.Module = adversary
+        flashy.distrib.broadcast_model(self.adversary)
+        self.optimizer = optimizer
+        self.loss = loss
+        self.loss_real = loss_real
+        self.loss_fake = loss_fake
+        self.loss_feat = loss_feat
+        self.normalize = normalize
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        # Add the optimizer state dict inside our own.
+        super()._save_to_state_dict(destination, prefix, keep_vars)
+        destination[prefix + 'optimizer'] = self.optimizer.state_dict()
+        return destination
+
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        # Load optimizer state.
+        self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer'))
+        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+    def get_adversary_pred(self, x):
+        """Run adversary model, validating expected output format."""
+        logits, fmaps = self.adversary(x)
+        assert isinstance(logits, list) and all([isinstance(t, torch.Tensor) for t in logits]), \
+            f'Expecting a list of tensors as logits but {type(logits)} found.'
+        assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.'
+        for fmap in fmaps:
+            assert isinstance(fmap, list) and all([isinstance(f, torch.Tensor) for f in fmap]), \
+                f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
+        return logits, fmaps
+
+    def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
+        """Train the adversary with the given fake and real example.
+
+        We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]].
+        The first item being the logits and second item being a list of feature maps for each sub-discriminator.
+
+        This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`)
+        and call the optimizer.
+        """
+        loss = torch.tensor(0., device=fake.device)
+        all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach())
+        all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach())
+        n_sub_adversaries = len(all_logits_fake_is_fake)
+        for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake):
+            loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake)
+
+        if self.normalize:
+            loss /= n_sub_adversaries
+
+        self.optimizer.zero_grad()
+        with flashy.distrib.eager_sync_model(self.adversary):
+            loss.backward()
+        self.optimizer.step()
+
+        return loss
+
+    def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        """Return the loss for the generator, i.e. trying to fool the adversary,
+        and feature matching loss if provided.
+        """
+        adv = torch.tensor(0., device=fake.device)
+        feat = torch.tensor(0., device=fake.device)
+        with flashy.utils.readonly(self.adversary):
+            all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake)
+            all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real)
+            n_sub_adversaries = len(all_logits_fake_is_fake)
+            for logit_fake_is_fake in all_logits_fake_is_fake:
+                adv += self.loss(logit_fake_is_fake)
+            if self.loss_feat:
+                for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real):
+                    feat += self.loss_feat(fmap_fake, fmap_real)
+
+        if self.normalize:
+            adv /= n_sub_adversaries
+            feat /= n_sub_adversaries
+
+        return adv, feat
+
+
+def get_adv_criterion(loss_type: str) -> tp.Callable:
+    assert loss_type in ADVERSARIAL_LOSSES
+    if loss_type == 'mse':
+        return mse_loss
+    elif loss_type == 'hinge':
+        return hinge_loss
+    elif loss_type == 'hinge2':
+        return hinge2_loss
+    raise ValueError('Unsupported loss')
+
+
+def get_fake_criterion(loss_type: str) -> tp.Callable:
+    assert loss_type in ADVERSARIAL_LOSSES
+    if loss_type == 'mse':
+        return mse_fake_loss
+    elif loss_type in ['hinge', 'hinge2']:
+        return hinge_fake_loss
+    raise ValueError('Unsupported loss')
+
+
+def get_real_criterion(loss_type: str) -> tp.Callable:
+    assert loss_type in ADVERSARIAL_LOSSES
+    if loss_type == 'mse':
+        return mse_real_loss
+    elif loss_type in ['hinge', 'hinge2']:
+        return hinge_real_loss
+    raise ValueError('Unsupported loss')
+
+
+def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
+    return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
+
+
+def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
+    return F.mse_loss(x, torch.tensor(0., device=x.device).expand_as(x))
+
+
+def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
+    return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
+
+
+def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
+    return -torch.mean(torch.min(-x - 1, torch.tensor(0., device=x.device).expand_as(x)))
+
+
+def mse_loss(x: torch.Tensor) -> torch.Tensor:
+    if x.numel() == 0:
+        return torch.tensor([0.0], device=x.device)
+    return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
+
+
+def hinge_loss(x: torch.Tensor) -> torch.Tensor:
+    if x.numel() == 0:
+        return torch.tensor([0.0], device=x.device)
+    return -x.mean()
+
+
+def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
+    if x.numel() == 0:
+        return torch.tensor([0.0])
+    return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
+
+
+class FeatureMatchingLoss(nn.Module):
+    """Feature matching loss for adversarial training.
+
+    Args:
+        loss (nn.Module): Loss to use for feature matching (default=torch.nn.L1).
+        normalize (bool): Whether to normalize the loss.
+            by number of feature maps.
+    """
+    def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: bool = True):
+        super().__init__()
+        self.loss = loss
+        self.normalize = normalize
+
+    def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
+        assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
+        feat_loss = torch.tensor(0., device=fmap_fake[0].device)
+        feat_scale = torch.tensor(0., device=fmap_fake[0].device)
+        n_fmaps = 0
+        for (feat_fake, feat_real) in zip(fmap_fake, fmap_real):
+            assert feat_fake.shape == feat_real.shape
+            n_fmaps += 1
+            feat_loss += self.loss(feat_fake, feat_real)
+            feat_scale += torch.mean(torch.abs(feat_real))
+
+        if self.normalize:
+            feat_loss /= n_fmaps
+
+        return feat_loss
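
The losses in this file are meant to be composed with the discriminators defined above, roughly as in the class docstring's usage example. The following is a hedged sketch of that wiring: the `filters=32` value and the Adam settings are illustrative (not taken from this diff), and the `flashy` dependency used by `AdversarialLoss` must be installed.

```python
# Sketch: wire a multi-scale STFT discriminator into the AdversarialLoss wrapper.
import torch
from audiocraft.adversarial import (
    AdversarialLoss, FeatureMatchingLoss, MultiScaleSTFTDiscriminator,
    get_adv_criterion, get_fake_criterion, get_real_criterion)

adversary = MultiScaleSTFTDiscriminator(filters=32)       # filters value is illustrative
opt = torch.optim.Adam(adversary.parameters(), lr=1e-4)   # illustrative optimizer settings

adv_loss = AdversarialLoss(
    adversary, opt,
    loss=get_adv_criterion('hinge'),
    loss_real=get_real_criterion('hinge'),
    loss_fake=get_fake_criterion('hinge'),
    loss_feat=FeatureMatchingLoss())

real = torch.randn(2, 1, 16000)            # stand-in for real audio
fake = torch.randn(2, 1, 16000)            # stand-in for a generator's output

d_loss = adv_loss.train_adv(fake, real)    # one discriminator update
g_adv, g_feat = adv_loss(fake, real)       # generator-side adversarial + feature-matching losses
```
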
audiocraft/data/__init__.py
ADDED
@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Audio loading and writing support. Datasets for raw audio
+or also including some metadata."""
+
+# flake8: noqa
+from . import audio, audio_dataset, info_audio_dataset, music_dataset, sound_dataset
audiocraft/data/audio.py
ADDED
@@ -0,0 +1,231 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Audio IO methods are defined in this module (info, read, write),
+We rely on av library for faster read when possible, otherwise on torchaudio.
+"""
+
+from dataclasses import dataclass
+from pathlib import Path
+import logging
+import typing as tp
+
+import numpy as np
+import soundfile
+import torch
+from torch.nn import functional as F
+
+import av
+import subprocess as sp
+
+from .audio_utils import f32_pcm, normalize_audio
+
+_av_initialized = False
+
+
+def _init_av():
+    global _av_initialized
+    if _av_initialized:
+        return
+    logger = logging.getLogger('libav.mp3')
+    logger.setLevel(logging.ERROR)
+    _av_initialized = True
+
+
+@dataclass(frozen=True)
+class AudioFileInfo:
+    sample_rate: int
+    duration: float
+    channels: int
+
+
+def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
+    _init_av()
+    with av.open(str(filepath)) as af:
+        stream = af.streams.audio[0]
+        sample_rate = stream.codec_context.sample_rate
+        duration = float(stream.duration * stream.time_base)
+        channels = stream.channels
+        return AudioFileInfo(sample_rate, duration, channels)
+
+
+def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
+    info = soundfile.info(filepath)
+    return AudioFileInfo(info.samplerate, info.duration, info.channels)
+
+
+def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
+    # torchaudio no longer returns useful duration informations for some formats like mp3s.
+    filepath = Path(filepath)
+    if filepath.suffix in ['.flac', '.ogg']:  # TODO: Validate .ogg can be safely read with av_info
+        # ffmpeg has some weird issue with flac.
+        return _soundfile_info(filepath)
+    else:
+        return _av_info(filepath)
+
+
+def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
+    """FFMPEG-based audio file reading using PyAV bindings.
+    Soundfile cannot read mp3 and av_read is more efficient than torchaudio.
+
+    Args:
+        filepath (str or Path): Path to audio file to read.
+        seek_time (float): Time at which to start reading in the file.
+        duration (float): Duration to read from the file. If set to -1, the whole file is read.
+    Returns:
+        tuple of torch.Tensor, int: Tuple containing audio data and sample rate
+    """
+    _init_av()
+    with av.open(str(filepath)) as af:
+        stream = af.streams.audio[0]
+        sr = stream.codec_context.sample_rate
+        num_frames = int(sr * duration) if duration >= 0 else -1
+        frame_offset = int(sr * seek_time)
+        # we need a small negative offset otherwise we get some edge artifact
+        # from the mp3 decoder.
+        af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream)
+        frames = []
+        length = 0
+        for frame in af.decode(streams=stream.index):
+            current_offset = int(frame.rate * frame.pts * frame.time_base)
+            strip = max(0, frame_offset - current_offset)
+            buf = torch.from_numpy(frame.to_ndarray())
+            if buf.shape[0] != stream.channels:
+                buf = buf.view(-1, stream.channels).t()
+            buf = buf[:, strip:]
+            frames.append(buf)
+            length += buf.shape[1]
+            if num_frames > 0 and length >= num_frames:
+                break
+        assert frames
+        # If the above assert fails, it is likely because we seeked past the end of file point,
+        # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
+        # This will need proper debugging, in due time.
+        wav = torch.cat(frames, dim=1)
+        assert wav.shape[0] == stream.channels
+        if num_frames > 0:
+            wav = wav[:, :num_frames]
+        return f32_pcm(wav), sr
+
+
+def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
+               duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
+    """Read audio by picking the most appropriate backend tool based on the audio format.
+
+    Args:
+        filepath (str or Path): Path to audio file to read.
+        seek_time (float): Time at which to start reading in the file.
+        duration (float): Duration to read from the file. If set to -1, the whole file is read.
+        pad (bool): Pad output audio if not reaching expected duration.
+    Returns:
+        tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
+    """
+    fp = Path(filepath)
+    if fp.suffix in ['.flac', '.ogg']:  # TODO: check if we can safely use av_read for .ogg
+        # There is some bug with ffmpeg and reading flac
+        info = _soundfile_info(filepath)
+        frames = -1 if duration <= 0 else int(duration * info.sample_rate)
+        frame_offset = int(seek_time * info.sample_rate)
+        wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
+        assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
+        wav = torch.from_numpy(wav).t().contiguous()
+        if len(wav.shape) == 1:
+            wav = torch.unsqueeze(wav, 0)
+    else:
+        wav, sr = _av_read(filepath, seek_time, duration)
+
+    if pad and duration > 0:
+        expected_frames = int(duration * sr)
+        wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
+    return wav, sr
+
+
+def _piping_to_ffmpeg(out_path: tp.Union[str, Path], wav: torch.Tensor, sample_rate: int, flags: tp.List[str]):
+    # ffmpeg is always installed and torchaudio is a bit unstable lately, so let's bypass it entirely.
+    assert wav.dim() == 2, wav.shape
+    command = [
+        'ffmpeg',
+        '-loglevel', 'error',
+        '-y', '-f', 'f32le', '-ar', str(sample_rate), '-ac', str(wav.shape[0]),
+        '-i', '-'] + flags + [str(out_path)]
+    input_ = f32_pcm(wav).t().detach().cpu().numpy().tobytes()
+    sp.run(command, input=input_, check=True)
+
+
+def audio_write(stem_name: tp.Union[str, Path],
+                wav: torch.Tensor, sample_rate: int,
+                format: str = 'wav', mp3_rate: int = 320, ogg_rate: tp.Optional[int] = None,
+                normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1,
+                rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
+                loudness_compressor: bool = False,
+                log_clipping: bool = True, make_parent_dir: bool = True,
+                add_suffix: bool = True) -> Path:
+    """Convenience function for saving audio to disk. Returns the filename the audio was written to.
+
+    Args:
+        stem_name (str or Path): Filename without extension which will be added automatically.
+        wav (torch.Tensor): Audio data to save.
+        sample_rate (int): Sample rate of audio data.
+        format (str): Either "wav", "mp3", "ogg", or "flac".
+        mp3_rate (int): kbps when using mp3s.
+        ogg_rate (int): kbps when using ogg/vorbis. If not provided, let ffmpeg decide for itself.
+        normalize (bool): if `True` (default), normalizes according to the prescribed
+            strategy (see after). If `False`, the strategy is only used in case clipping
+            would happen.
+        strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
+            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
+            with extra headroom to avoid clipping. 'clip' just clips.
+        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
+        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
+            than the `peak_clip` one to avoid further clipping.
+        loudness_headroom_db (float): Target loudness for loudness normalization.
+        loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
+        log_clipping (bool): If True, basic logging on stderr when clipping still
+            occurs despite strategy (only for 'rms').
+        make_parent_dir (bool): Make parent directory if it doesn't exist.
+    Returns:
+        Path: Path of the saved audio.
+    """
+    assert wav.dtype.is_floating_point, "wav is not floating point"
+    if wav.dim() == 1:
+        wav = wav[None]
+    elif wav.dim() > 2:
+        raise ValueError("Input wav should be at most 2 dimension.")
+    assert wav.isfinite().all()
+    wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
+                          rms_headroom_db, loudness_headroom_db, loudness_compressor,
+                          log_clipping=log_clipping, sample_rate=sample_rate,
+                          stem_name=str(stem_name))
+    if format == 'mp3':
+        suffix = '.mp3'
+        flags = ['-f', 'mp3', '-c:a', 'libmp3lame', '-b:a', f'{mp3_rate}k']
+    elif format == 'wav':
+        suffix = '.wav'
+        flags = ['-f', 'wav', '-c:a', 'pcm_s16le']
+    elif format == 'ogg':
+        suffix = '.ogg'
+        flags = ['-f', 'ogg', '-c:a', 'libvorbis']
+        if ogg_rate is not None:
+            flags += ['-b:a', f'{ogg_rate}k']
+    elif format == 'flac':
+        suffix = '.flac'
+        flags = ['-f', 'flac']
+    else:
+        raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
+    if not add_suffix:
+        suffix = ''
+    path = Path(str(stem_name) + suffix)
+    if make_parent_dir:
+        path.parent.mkdir(exist_ok=True, parents=True)
+    try:
+        _piping_to_ffmpeg(path, wav, sample_rate, flags)
+    except Exception:
+        if path.exists():
+            # we do not want to leave half written files around.
+            path.unlink()
+        raise
+    return path
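
A brief sketch of the read/write round trip these helpers provide; the input file path is hypothetical, and `ffmpeg` must be on the PATH for `audio_write`, as the code above notes.

```python
# Sketch: read a 5-second excerpt of a file and re-encode it as mp3 with the helpers above.
from audiocraft.data.audio import audio_info, audio_read, audio_write

info = audio_info('example.wav')                    # hypothetical input file
wav, sr = audio_read('example.wav', seek_time=0.0, duration=5.0, pad=True)
print(info.sample_rate, info.duration, wav.shape)   # wav is a [channels, samples] float tensor

out = audio_write('example_excerpt', wav, sr, format='mp3', strategy='peak')
print(f'wrote {out}')                               # the '.mp3' suffix is added automatically
```
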
audiocraft/data/audio_dataset.py
ADDED
@@ -0,0 +1,694 @@
+# Modified from Audiocraft (https://github.com/facebookresearch/audiocraft)
+
+"""AudioDataset support. In order to handle a larger number of files
+without having to scan again the folders, we precompute some metadata
+(filename, sample rate, duration), and use that to efficiently sample audio segments.
+"""
+import argparse
+import copy
+from concurrent.futures import ThreadPoolExecutor, Future
+from dataclasses import dataclass, fields
+from contextlib import ExitStack
+from functools import lru_cache
+import gzip
+import json
+import logging
+import os
+from pathlib import Path
+import random
+import sys
+import typing as tp
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+from .audio import audio_read, audio_info
+from .video import video_read_local, video_read_global
+from .audio_utils import convert_audio
+from .zip import PathInZip
+import h5py
+try:
+    import dora
+except ImportError:
+    dora = None  # type: ignore
+
+
+@dataclass(order=True)
+class BaseInfo:
+
+    @classmethod
+    def _dict2fields(cls, dictionary: dict):
+        return {
+            field.name: dictionary[field.name]
+            for field in fields(cls) if field.name in dictionary
+        }
+
+    @classmethod
+    def from_dict(cls, dictionary: dict):
+        _dictionary = cls._dict2fields(dictionary)
+        return cls(**_dictionary)
+
+    def to_dict(self):
+        return {
+            field.name: self.__getattribute__(field.name)
+            for field in fields(self)
+        }
+
+
+@dataclass(order=True)
+class AudioMeta(BaseInfo):
+    path: str
+    video_path: str
+    duration: float
+    sample_rate: int
+    amplitude: tp.Optional[float] = None
+    weight: tp.Optional[float] = None
+    # info_path is used to load additional information about the audio file that is stored in zip files.
+    info_path: tp.Optional[PathInZip] = None
+
+    @classmethod
+    def from_dict(cls, dictionary: dict):
+        # print(f'dictionary:{dictionary}')
+        # print(f'cls:{cls}')
+        base = cls._dict2fields(dictionary)
+        # print(f'base:{base}')
+        if 'info_path' in base and base['info_path'] is not None:
+            base['info_path'] = PathInZip(base['info_path'])
+        # print(f'base:{base}')
+        # exit()
+        return cls(**base)
+
+    def to_dict(self):
+        d = super().to_dict()
+        if d['info_path'] is not None:
+            d['info_path'] = str(d['info_path'])
+        return d
+
+
+@dataclass(order=True)
+class SegmentInfo(BaseInfo):
+    meta: AudioMeta
+    seek_time: float
+    # The following values are given once the audio is processed, e.g.
+    # at the target sample rate and target number of channels.
+    n_frames: int      # actual number of frames without padding
+    total_frames: int  # total number of frames, padding included
+    sample_rate: int   # actual sample rate
+    channels: int      # number of audio channels.
+
+
+DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
+
+logger = logging.getLogger(__name__)
+
+
+def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
+    """AudioMeta from a path to an audio file.
+
+    Args:
+        file_path (str): Resolved path of valid audio file.
+        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
+    Returns:
+        AudioMeta: Audio file path and its metadata.
+    """
+    info = audio_info(file_path)
+    amplitude: tp.Optional[float] = None
+    if not minimal:
+        wav, sr = audio_read(file_path)
+        amplitude = wav.abs().max().item()
+    return AudioMeta(file_path, info.duration, info.sample_rate, amplitude)
+
+
+def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
+    """If Dora is available as a dependency, try to resolve potential relative paths
+    in list of AudioMeta. This method is expected to be used when loading meta from file.
+
+    Args:
+        m (AudioMeta): Audio meta to resolve.
+        fast (bool): If True, uses a really fast check for determining if a file
+            is already absolute or not. Only valid on Linux/Mac.
+    Returns:
+        AudioMeta: Audio meta with resolved path.
+    """
+    def is_abs(m):
+        if fast:
+            return str(m)[0] == '/'
+        else:
+            os.path.isabs(str(m))
+
+    if not dora:
+        return m
+
+    if not is_abs(m.path):
+        m.path = dora.git_save.to_absolute_path(m.path)
+    if m.info_path is not None and not is_abs(m.info_path.zip_path):
+        m.info_path.zip_path = dora.git_save.to_absolute_path(m.path)
+    return m
+
+
+def find_audio_files(path: tp.Union[Path, str],
|
151 |
+
exts: tp.List[str] = DEFAULT_EXTS,
|
152 |
+
resolve: bool = True,
|
153 |
+
minimal: bool = True,
|
154 |
+
progress: bool = False,
|
155 |
+
workers: int = 0) -> tp.List[AudioMeta]:
|
156 |
+
"""Build a list of AudioMeta from a given path,
|
157 |
+
collecting relevant audio files and fetching meta info.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
path (str or Path): Path to folder containing audio files.
|
161 |
+
exts (list of str): List of file extensions to consider for audio files.
|
162 |
+
minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
|
163 |
+
progress (bool): Whether to log progress on audio files collection.
|
164 |
+
workers (int): number of parallel workers, if 0, use only the current thread.
|
165 |
+
Returns:
|
166 |
+
list of AudioMeta: List of audio file path and its metadata.
|
167 |
+
"""
|
168 |
+
audio_files = []
|
169 |
+
futures: tp.List[Future] = []
|
170 |
+
pool: tp.Optional[ThreadPoolExecutor] = None
|
171 |
+
with ExitStack() as stack:
|
172 |
+
if workers > 0:
|
173 |
+
pool = ThreadPoolExecutor(workers)
|
174 |
+
stack.enter_context(pool)
|
175 |
+
|
176 |
+
if progress:
|
177 |
+
print("Finding audio files...")
|
178 |
+
for root, folders, files in os.walk(path, followlinks=True):
|
179 |
+
for file in files:
|
180 |
+
full_path = Path(root) / file
|
181 |
+
if full_path.suffix.lower() in exts:
|
182 |
+
audio_files.append(full_path)
|
183 |
+
if pool is not None:
|
184 |
+
futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal))
|
185 |
+
if progress:
|
186 |
+
print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)
|
187 |
+
|
188 |
+
if progress:
|
189 |
+
print("Getting audio metadata...")
|
190 |
+
meta: tp.List[AudioMeta] = []
|
191 |
+
for idx, file_path in enumerate(audio_files):
|
192 |
+
try:
|
193 |
+
if pool is None:
|
194 |
+
m = _get_audio_meta(str(file_path), minimal)
|
195 |
+
else:
|
196 |
+
m = futures[idx].result()
|
197 |
+
if resolve:
|
198 |
+
m = _resolve_audio_meta(m)
|
199 |
+
except Exception as err:
|
200 |
+
print("Error with", str(file_path), err, file=sys.stderr)
|
201 |
+
continue
|
202 |
+
meta.append(m)
|
203 |
+
if progress:
|
204 |
+
print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
|
205 |
+
meta.sort()
|
206 |
+
return meta
|
207 |
+
|
208 |
+
|
209 |
+
def load_audio_meta(path: tp.Union[str, Path],
|
210 |
+
resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
|
211 |
+
"""Load list of AudioMeta from an optionally compressed json file.
|
212 |
+
|
213 |
+
Args:
|
214 |
+
path (str or Path): Path to JSON file.
|
215 |
+
resolve (bool): Whether to resolve the path from AudioMeta (default=True).
|
216 |
+
fast (bool): activates some tricks to make things faster.
|
217 |
+
Returns:
|
218 |
+
list of AudioMeta: List of audio file path and its total duration.
|
219 |
+
"""
|
220 |
+
open_fn = gzip.open if str(path).lower().endswith('.gz') else open
|
221 |
+
with open_fn(path, 'rb') as fp: # type: ignore
|
222 |
+
lines = fp.readlines()
|
223 |
+
meta = []
|
224 |
+
for line in lines:
|
225 |
+
d = json.loads(line)
|
226 |
+
# print(f'line:{d}')
|
227 |
+
m = AudioMeta.from_dict(d)
|
228 |
+
# print(f'm:{m}')
|
229 |
+
|
230 |
+
if resolve:
|
231 |
+
m = _resolve_audio_meta(m, fast=fast)
|
232 |
+
# print(f'm:{m}')
|
233 |
+
meta.append(m)
|
234 |
+
# exit()
|
235 |
+
# print(f'meta:{meta}')
|
236 |
+
# exit()
|
237 |
+
return meta
|
238 |
+
|
239 |
+
|
240 |
+
def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
|
241 |
+
"""Save the audio metadata to the file pointer as json.
|
242 |
+
|
243 |
+
Args:
|
244 |
+
path (str or Path): Path to JSON file.
|
245 |
+
metadata (list of BaseAudioMeta): List of audio meta to save.
|
246 |
+
"""
|
247 |
+
Path(path).parent.mkdir(exist_ok=True, parents=True)
|
248 |
+
open_fn = gzip.open if str(path).lower().endswith('.gz') else open
|
249 |
+
with open_fn(path, 'wb') as fp: # type: ignore
|
250 |
+
for m in meta:
|
251 |
+
json_str = json.dumps(m.to_dict()) + '\n'
|
252 |
+
json_bytes = json_str.encode('utf-8')
|
253 |
+
fp.write(json_bytes)
|
254 |
+
|
255 |
+
|
256 |
+
class AudioDataset:
|
257 |
+
"""Base audio dataset.
|
258 |
+
|
259 |
+
The dataset takes a list of AudioMeta and create a dataset composed of segments of audio
|
260 |
+
and potentially additional information, by creating random segments from the list of audio
|
261 |
+
files referenced in the metadata and applying minimal data pre-processing such as resampling,
|
262 |
+
mixing of channels, padding, etc.
|
263 |
+
|
264 |
+
If no segment_duration value is provided, the AudioDataset will return the full wav for each
|
265 |
+
audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
|
266 |
+
duration, applying padding if required.
|
267 |
+
|
268 |
+
By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
|
269 |
+
allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
|
270 |
+
original audio meta.
|
271 |
+
|
272 |
+
Note that you can call `start_epoch(epoch)` in order to get
|
273 |
+
a deterministic "randomization" for `shuffle=True`.
|
274 |
+
For a given epoch and dataset index, this will always return the same extract.
|
275 |
+
You can get back some diversity by setting the `shuffle_seed` param.
|
276 |
+
|
277 |
+
Args:
|
278 |
+
meta (list of AudioMeta): List of audio files metadata.
|
279 |
+
segment_duration (float, optional): Optional segment duration of audio to load.
|
280 |
+
If not specified, the dataset will load the full audio segment from the file.
|
281 |
+
shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
|
282 |
+
sample_rate (int): Target sample rate of the loaded audio samples.
|
283 |
+
channels (int): Target number of channels of the loaded audio samples.
|
284 |
+
sample_on_duration (bool): Set to `True` to sample segments with probability
|
285 |
+
dependent on audio file duration. This is only used if `segment_duration` is provided.
|
286 |
+
sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
|
287 |
+
`AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
|
288 |
+
of the file duration and file weight. This is only used if `segment_duration` is provided.
|
289 |
+
min_segment_ratio (float): Minimum segment ratio to use when the audio file
|
290 |
+
is shorter than the desired segment.
|
291 |
+
max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
|
292 |
+
return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
|
293 |
+
min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided
|
294 |
+
audio shorter than this will be filtered out.
|
295 |
+
max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided
|
296 |
+
audio longer than this will be filtered out.
|
297 |
+
shuffle_seed (int): can be used to further randomize
|
298 |
+
load_wav (bool): if False, skip loading the wav but returns a tensor of 0
|
299 |
+
with the expected segment_duration (which must be provided if load_wav is False).
|
300 |
+
permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration`
|
301 |
+
are False. Will ensure a permutation on files when going through the dataset.
|
302 |
+
In that case the epoch number must be provided in order for the model
|
303 |
+
to continue the permutation across epochs. In that case, it is assumed
|
304 |
+
that `num_samples = total_batch_size * num_updates_per_epoch`, with
|
305 |
+
`total_batch_size` the overall batch size accounting for all gpus.
|
306 |
+
"""
|
307 |
+
def __init__(self,
|
308 |
+
meta: tp.List[AudioMeta],
|
309 |
+
segment_duration: tp.Optional[float] = None,
|
310 |
+
shuffle: bool = True,
|
311 |
+
num_samples: int = 10_000,
|
312 |
+
sample_rate: int = 48_000,
|
313 |
+
video_fps: int = 2,
|
314 |
+
video_overlap: int = 2,
|
315 |
+
if_add_gobal: bool = False,
|
316 |
+
global_mode: str = "average",
|
317 |
+
global_num_frames: int = 64,
|
318 |
+
global_feature_path: bool = False,
|
319 |
+
channels: int = 2,
|
320 |
+
pad: bool = True,
|
321 |
+
sample_on_duration: bool = True,
|
322 |
+
sample_on_weight: bool = True,
|
323 |
+
min_segment_ratio: float = 0.5,
|
324 |
+
max_read_retry: int = 10,
|
325 |
+
return_info: bool = False,
|
326 |
+
min_audio_duration: tp.Optional[float] = None,
|
327 |
+
max_audio_duration: tp.Optional[float] = None,
|
328 |
+
shuffle_seed: int = 0,
|
329 |
+
load_wav: bool = True,
|
330 |
+
permutation_on_files: bool = False,
|
331 |
+
):
|
332 |
+
assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta."
|
333 |
+
assert segment_duration is None or segment_duration > 0
|
334 |
+
assert segment_duration is None or min_segment_ratio >= 0
|
335 |
+
self.segment_duration = segment_duration
|
336 |
+
self.min_segment_ratio = min_segment_ratio
|
337 |
+
self.max_audio_duration = max_audio_duration
|
338 |
+
self.min_audio_duration = min_audio_duration
|
339 |
+
if self.min_audio_duration is not None and self.max_audio_duration is not None:
|
340 |
+
assert self.min_audio_duration <= self.max_audio_duration
|
341 |
+
self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
|
342 |
+
assert len(self.meta) # Fail fast if all data has been filtered.
|
343 |
+
self.total_duration = sum(d.duration for d in self.meta)
|
344 |
+
|
345 |
+
if segment_duration is None:
|
346 |
+
num_samples = len(self.meta)
|
347 |
+
self.num_samples = num_samples
|
348 |
+
self.shuffle = shuffle
|
349 |
+
self.sample_rate = sample_rate
|
350 |
+
self.video_fps = video_fps
|
351 |
+
self.video_overlap = video_overlap
|
352 |
+
self.if_add_gobal = if_add_gobal
|
353 |
+
self.global_mode = global_mode
|
354 |
+
self.global_num_frames = global_num_frames
|
355 |
+
self.global_feature_path = global_feature_path
|
356 |
+
self.channels = channels
|
357 |
+
self.pad = pad
|
358 |
+
self.sample_on_weight = sample_on_weight
|
359 |
+
self.sample_on_duration = sample_on_duration
|
360 |
+
self.sampling_probabilities = self._get_sampling_probabilities()
|
361 |
+
self.max_read_retry = max_read_retry
|
362 |
+
self.return_info = return_info
|
363 |
+
self.shuffle_seed = shuffle_seed
|
364 |
+
self.current_epoch: tp.Optional[int] = None
|
365 |
+
self.load_wav = load_wav
|
366 |
+
if not load_wav:
|
367 |
+
assert segment_duration is not None
|
368 |
+
self.permutation_on_files = permutation_on_files
|
369 |
+
if permutation_on_files:
|
370 |
+
assert not self.sample_on_duration
|
371 |
+
assert not self.sample_on_weight
|
372 |
+
assert self.shuffle
|
373 |
+
|
374 |
+
def start_epoch(self, epoch: int):
|
375 |
+
self.current_epoch = epoch
|
376 |
+
|
377 |
+
def __len__(self):
|
378 |
+
return self.num_samples
|
379 |
+
|
380 |
+
def _get_sampling_probabilities(self, normalized: bool = True):
|
381 |
+
"""Return the sampling probabilities for each file inside `self.meta`."""
|
382 |
+
scores: tp.List[float] = []
|
383 |
+
for file_meta in self.meta:
|
384 |
+
score = 1.
|
385 |
+
if self.sample_on_weight and file_meta.weight is not None:
|
386 |
+
score *= file_meta.weight
|
387 |
+
if self.sample_on_duration:
|
388 |
+
score *= file_meta.duration
|
389 |
+
scores.append(score)
|
390 |
+
probabilities = torch.tensor(scores)
|
391 |
+
if normalized:
|
392 |
+
probabilities /= probabilities.sum()
|
393 |
+
return probabilities
|
394 |
+
|
395 |
+
@staticmethod
|
396 |
+
@lru_cache(16)
|
397 |
+
def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int):
|
398 |
+
# Used to keep the most recent files permutation in memory implicitely.
|
399 |
+
# will work unless someone is using a lot of Datasets in parallel.
|
400 |
+
rng = torch.Generator()
|
401 |
+
rng.manual_seed(base_seed + permutation_index)
|
402 |
+
return torch.randperm(num_files, generator=rng)
|
403 |
+
|
404 |
+
def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
|
405 |
+
"""Sample a given file from `self.meta`. Can be overridden in subclasses.
|
406 |
+
This is only called if `segment_duration` is not None.
|
407 |
+
|
408 |
+
You must use the provided random number generator `rng` for reproducibility.
|
409 |
+
You can further make use of the index accessed.
|
410 |
+
"""
|
411 |
+
if self.permutation_on_files:
|
412 |
+
assert self.current_epoch is not None
|
413 |
+
total_index = self.current_epoch * len(self) + index
|
414 |
+
permutation_index = total_index // len(self.meta)
|
415 |
+
relative_index = total_index % len(self.meta)
|
416 |
+
permutation = AudioDataset._get_file_permutation(
|
417 |
+
len(self.meta), permutation_index, self.shuffle_seed)
|
418 |
+
file_index = permutation[relative_index]
|
419 |
+
return self.meta[file_index]
|
420 |
+
|
421 |
+
if not self.sample_on_weight and not self.sample_on_duration:
|
422 |
+
file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
|
423 |
+
else:
|
424 |
+
file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())
|
425 |
+
|
426 |
+
return self.meta[file_index]
|
427 |
+
|
428 |
+
def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1):
|
429 |
+
# Override this method in subclass if needed.
|
430 |
+
|
431 |
+
if self.load_wav:
|
432 |
+
return audio_read(path, seek_time, duration, pad=False)
|
433 |
+
else:
|
434 |
+
assert self.segment_duration is not None
|
435 |
+
n_frames = int(self.sample_rate * self.segment_duration)
|
436 |
+
return torch.zeros(self.channels, n_frames), self.sample_rate
|
437 |
+
|
438 |
+
def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
|
439 |
+
if self.segment_duration is None:
|
440 |
+
file_meta = self.meta[index]
|
441 |
+
out, sr = audio_read(file_meta.path)
|
442 |
+
out = convert_audio(out, sr, self.sample_rate, self.channels)
|
443 |
+
n_frames = out.shape[-1]
|
444 |
+
out = convert_audio(out, sr, self.sample_rate, self.channels)
|
445 |
+
|
446 |
+
if self.if_add_gobal:
|
447 |
+
# global_feature_path
|
448 |
+
if self.global_feature_path!='' and os.path.exists(self.global_feature_path):
|
449 |
+
ytb_id = file_meta.video_path.split('/')[-1][:11]
|
450 |
+
with h5py.File(f'{self.global_feature_path}/{ytb_id}.h5', 'r') as file:
|
451 |
+
data = file['global_video_array'][:]
|
452 |
+
global_video = np.array(data)
|
453 |
+
local_video = video_read_local(file_meta.video_path, target_fps=self.video_fps, seek_time=seek_time, duration=self.segment_duration)
|
454 |
+
else:
|
455 |
+
|
456 |
+
local_video, global_video = video_read_global(file_meta.video_path, target_fps=self.video_fps, seek_time=seek_time, duration=self.segment_duration, global_mode=self.global_mode, global_num_frames=self.global_num_frames)
|
457 |
+
|
458 |
+
else: # local only
|
459 |
+
video = video_read_local(file_meta.video_path, target_fps=self.video_fps, seek_time=seek_time, duration=self.segment_duration)
|
460 |
+
|
461 |
+
segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
|
462 |
+
sample_rate=self.sample_rate, channels=out.shape[0])
|
463 |
+
else:
|
464 |
+
rng = torch.Generator()
|
465 |
+
if self.shuffle:
|
466 |
+
# We use index, plus extra randomness, either totally random if we don't know the epoch.
|
467 |
+
# otherwise we make use of the epoch number and optional shuffle_seed.
|
468 |
+
if self.current_epoch is None:
|
469 |
+
rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
|
470 |
+
else:
|
471 |
+
rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
|
472 |
+
else:
|
473 |
+
# We only use index
|
474 |
+
rng.manual_seed(index)
|
475 |
+
|
476 |
+
for retry in range(self.max_read_retry):
|
477 |
+
file_meta = self.sample_file(index, rng)
|
478 |
+
# We add some variance in the file position even if audio file is smaller than segment
|
479 |
+
# without ending up with empty segments
|
480 |
+
|
481 |
+
overlap = self.video_overlap
|
482 |
+
segment_duration_no_overlap = self.segment_duration - overlap
|
483 |
+
max_seek = max(0, file_meta.duration - segment_duration_no_overlap * self.min_segment_ratio)
|
484 |
+
max_value = max_seek
|
485 |
+
random_value = torch.rand(1, generator=rng).item() * max_value
|
486 |
+
base_seek_time = segment_duration_no_overlap * int(random_value // segment_duration_no_overlap)
|
487 |
+
|
488 |
+
seek_time = random.randint(base_seek_time, base_seek_time + overlap)
|
489 |
+
seek_time = min(max_seek, seek_time)
|
490 |
+
|
491 |
+
try:
|
492 |
+
out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
|
493 |
+
|
494 |
+
out = convert_audio(out, sr, self.sample_rate, self.channels)
|
495 |
+
n_frames = out.shape[-1]
|
496 |
+
target_frames = int(self.segment_duration * self.sample_rate)
|
497 |
+
|
498 |
+
if self.if_add_gobal:
|
499 |
+
if self.global_feature_path!='' and os.path.exists(self.global_feature_path):
|
500 |
+
ytb_id = file_meta.video_path.split('/')[-1][:11]
|
501 |
+
with h5py.File(f'{self.global_feature_path}/{ytb_id}.h5', 'r') as file:
|
502 |
+
data = file['global_video_array'][:]
|
503 |
+
global_video = np.array(data)
|
504 |
+
indices = np.linspace(0, global_video.shape[1]-1, num=self.global_num_frames, endpoint=True).round().astype(int)
|
505 |
+
global_video = global_video[:,indices,:,:]
|
506 |
+
global_video = torch.from_numpy(global_video)
|
507 |
+
local_video = video_read_local(file_meta.video_path, target_fps=self.video_fps, seek_time=seek_time, duration=self.segment_duration)
|
508 |
+
else:
|
509 |
+
local_video, global_video = video_read_global(file_meta.video_path, target_fps=self.video_fps, seek_time=seek_time, duration=self.segment_duration, global_mode=self.global_mode, global_num_frames=self.global_num_frames)
|
510 |
+
assert global_video.shape[1]==self.global_num_frames
|
511 |
+
|
512 |
+
n_frames_video = local_video.shape[1]
|
513 |
+
|
514 |
+
else: # local only
|
515 |
+
video = video_read_local(file_meta.video_path, target_fps=self.video_fps, seek_time=seek_time, duration=self.segment_duration)
|
516 |
+
n_frames_video = video.shape[1]
|
517 |
+
|
518 |
+
target_frames_video = int(self.segment_duration * self.video_fps)
|
519 |
+
|
520 |
+
if self.pad:
|
521 |
+
out = F.pad(out, (0, target_frames - n_frames))
|
522 |
+
|
523 |
+
segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
|
524 |
+
sample_rate=self.sample_rate, channels=out.shape[0])
|
525 |
+
except Exception as exc:
|
526 |
+
logger.warning("Error opening file %s: %r", file_meta.path, exc)
|
527 |
+
if retry == self.max_read_retry - 1:
|
528 |
+
raise
|
529 |
+
else:
|
530 |
+
break
|
531 |
+
if self.if_add_gobal:
|
532 |
+
if self.return_info:
|
533 |
+
# Returns the wav and additional information on the wave segment
|
534 |
+
return out, [local_video, global_video], segment_info
|
535 |
+
else:
|
536 |
+
return out, [local_video, global_video]
|
537 |
+
else:
|
538 |
+
if self.return_info:
|
539 |
+
# Returns the wav and additional information on the wave segment
|
540 |
+
return out, [video], segment_info
|
541 |
+
else:
|
542 |
+
return out, [video]
|
543 |
+
|
544 |
+
def collater(self, samples):
|
545 |
+
"""The collater function has to be provided to the dataloader
|
546 |
+
if AudioDataset has return_info=True in order to properly collate
|
547 |
+
the samples of a batch.
|
548 |
+
"""
|
549 |
+
if self.segment_duration is None and len(samples) > 1:
|
550 |
+
assert self.pad, "Must allow padding when batching examples of different durations."
|
551 |
+
|
552 |
+
# In this case the audio reaching the collater is of variable length as segment_duration=None.
|
553 |
+
to_pad = self.segment_duration is None and self.pad
|
554 |
+
if to_pad:
|
555 |
+
max_len = max([wav.shape[-1] for wav, _ in samples])
|
556 |
+
|
557 |
+
def _pad_wav(wav):
|
558 |
+
return F.pad(wav, (0, max_len - wav.shape[-1]))
|
559 |
+
|
560 |
+
if self.return_info:
|
561 |
+
if len(samples) > 0:
|
562 |
+
assert len(samples[0]) == 3
|
563 |
+
assert isinstance(samples[0][0], torch.Tensor)
|
564 |
+
assert isinstance(samples[0][1], list)
|
565 |
+
assert isinstance(samples[0][2], SegmentInfo)
|
566 |
+
|
567 |
+
|
568 |
+
wavs = [wav for wav, _, _ in samples]
|
569 |
+
video_lists = [video_list for _, video_list, _ in samples]
|
570 |
+
segment_infos = [copy.deepcopy(info) for _, _, info in samples]
|
571 |
+
wav = torch.stack(wavs)
|
572 |
+
|
573 |
+
assert isinstance(video_lists[0],list)
|
574 |
+
if len(video_lists[0])==1:
|
575 |
+
videos=[video_list[0] for video_list in video_lists]
|
576 |
+
if to_pad:
|
577 |
+
# Each wav could be of a different duration as they are not segmented.
|
578 |
+
for i in range(len(samples)):
|
579 |
+
# Determines the total length of the signal with padding, so we update here as we pad.
|
580 |
+
segment_infos[i].total_frames = max_len
|
581 |
+
wavs[i] = _pad_wav(wavs[i])
|
582 |
+
video = torch.stack(videos)
|
583 |
+
|
584 |
+
return wav, [video], segment_infos
|
585 |
+
|
586 |
+
elif len(video_lists[0])==2:
|
587 |
+
|
588 |
+
local_videos=[video_list[0] for video_list in video_lists]
|
589 |
+
global_videos=[video_list[1] for video_list in video_lists]
|
590 |
+
|
591 |
+
if to_pad:
|
592 |
+
# Each wav could be of a different duration as they are not segmented.
|
593 |
+
for i in range(len(samples)):
|
594 |
+
# Determines the total length of the signal with padding, so we update here as we pad.
|
595 |
+
segment_infos[i].total_frames = max_len
|
596 |
+
wavs[i] = _pad_wav(wavs[i])
|
597 |
+
local_video = torch.stack(local_videos)
|
598 |
+
global_video = torch.stack(global_videos)
|
599 |
+
|
600 |
+
return wav, [local_video, global_video], segment_infos
|
601 |
+
|
602 |
+
else:
|
603 |
+
assert isinstance(samples[0], torch.Tensor)
|
604 |
+
if to_pad:
|
605 |
+
samples = [_pad_wav(s) for s in samples]
|
606 |
+
return torch.stack(samples)
|
607 |
+
|
608 |
+
def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
|
609 |
+
"""Filters out audio files with audio durations that will not allow to sample examples from them."""
|
610 |
+
orig_len = len(meta)
|
611 |
+
|
612 |
+
# Filter data that is too short.
|
613 |
+
if self.min_audio_duration is not None:
|
614 |
+
meta = [m for m in meta if m.duration >= self.min_audio_duration]
|
615 |
+
|
616 |
+
# Filter data that is too long.
|
617 |
+
if self.max_audio_duration is not None:
|
618 |
+
meta = [m for m in meta if m.duration <= self.max_audio_duration]
|
619 |
+
|
620 |
+
filtered_len = len(meta)
|
621 |
+
removed_percentage = 100*(1-float(filtered_len)/orig_len)
|
622 |
+
msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
|
623 |
+
if removed_percentage < 10:
|
624 |
+
logging.debug(msg)
|
625 |
+
else:
|
626 |
+
logging.warning(msg)
|
627 |
+
return meta
|
628 |
+
|
629 |
+
@classmethod
|
630 |
+
def from_meta(cls, root: tp.Union[str, Path], **kwargs):
|
631 |
+
"""Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.
|
632 |
+
|
633 |
+
Args:
|
634 |
+
root (str or Path): Path to root folder containing audio files.
|
635 |
+
kwargs: Additional keyword arguments for the AudioDataset.
|
636 |
+
"""
|
637 |
+
root = Path(root)
|
638 |
+
if root.is_dir():
|
639 |
+
if (root / 'data.jsonl').exists():
|
640 |
+
root = root / 'data.jsonl'
|
641 |
+
elif (root / 'data.jsonl.gz').exists():
|
642 |
+
root = root / 'data.jsonl.gz'
|
643 |
+
else:
|
644 |
+
raise ValueError("Don't know where to read metadata from in the dir. "
|
645 |
+
"Expecting either a data.jsonl or data.jsonl.gz file but none found.")
|
646 |
+
meta = load_audio_meta(root)
|
647 |
+
|
648 |
+
return cls(meta, **kwargs)
|
649 |
+
|
650 |
+
@classmethod
|
651 |
+
def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
|
652 |
+
exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
|
653 |
+
"""Instantiate AudioDataset from a path containing (possibly nested) audio files.
|
654 |
+
|
655 |
+
Args:
|
656 |
+
root (str or Path): Path to root folder containing audio files.
|
657 |
+
minimal_meta (bool): Whether to only load minimal metadata or not.
|
658 |
+
exts (list of str): Extensions for audio files.
|
659 |
+
kwargs: Additional keyword arguments for the AudioDataset.
|
660 |
+
"""
|
661 |
+
root = Path(root)
|
662 |
+
if root.is_file():
|
663 |
+
meta = load_audio_meta(root, resolve=True)
|
664 |
+
else:
|
665 |
+
meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
|
666 |
+
return cls(meta, **kwargs)
|
667 |
+
|
668 |
+
|
669 |
+
def main():
|
670 |
+
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
|
671 |
+
parser = argparse.ArgumentParser(
|
672 |
+
prog='audio_dataset',
|
673 |
+
description='Generate .jsonl files by scanning a folder.')
|
674 |
+
parser.add_argument('root', help='Root folder with all the audio files')
|
675 |
+
parser.add_argument('output_meta_file',
|
676 |
+
help='Output file to store the metadata, ')
|
677 |
+
parser.add_argument('--complete',
|
678 |
+
action='store_false', dest='minimal', default=True,
|
679 |
+
help='Retrieve all metadata, even the one that are expansive '
|
680 |
+
'to compute (e.g. normalization).')
|
681 |
+
parser.add_argument('--resolve',
|
682 |
+
action='store_true', default=False,
|
683 |
+
help='Resolve the paths to be absolute and with no symlinks.')
|
684 |
+
parser.add_argument('--workers',
|
685 |
+
default=10, type=int,
|
686 |
+
help='Number of workers.')
|
687 |
+
args = parser.parse_args()
|
688 |
+
meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
|
689 |
+
resolve=args.resolve, minimal=args.minimal, workers=args.workers)
|
690 |
+
save_audio_meta(args.output_meta_file, meta)
|
691 |
+
|
692 |
+
|
693 |
+
if __name__ == '__main__':
|
694 |
+
main()
|
audiocraft/data/audio_utils.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Various utilities for audio convertion (pcm format, sample rate and channels),
|
7 |
+
and volume normalization."""
|
8 |
+
import sys
|
9 |
+
import typing as tp
|
10 |
+
|
11 |
+
import julius
|
12 |
+
import torch
|
13 |
+
import torchaudio
|
14 |
+
|
15 |
+
|
16 |
+
def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
|
17 |
+
"""Convert audio to the given number of channels.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
wav (torch.Tensor): Audio wave of shape [B, C, T].
|
21 |
+
channels (int): Expected number of channels as output.
|
22 |
+
Returns:
|
23 |
+
torch.Tensor: Downmixed or unchanged audio wave [B, C, T].
|
24 |
+
"""
|
25 |
+
*shape, src_channels, length = wav.shape
|
26 |
+
if src_channels == channels:
|
27 |
+
pass
|
28 |
+
elif channels == 1:
|
29 |
+
# Case 1:
|
30 |
+
# The caller asked 1-channel audio, and the stream has multiple
|
31 |
+
# channels, downmix all channels.
|
32 |
+
wav = wav.mean(dim=-2, keepdim=True)
|
33 |
+
elif src_channels == 1:
|
34 |
+
# Case 2:
|
35 |
+
# The caller asked for multiple channels, but the input file has
|
36 |
+
# a single channel, replicate the audio over all channels.
|
37 |
+
wav = wav.expand(*shape, channels, length)
|
38 |
+
elif src_channels >= channels:
|
39 |
+
# Case 3:
|
40 |
+
# The caller asked for multiple channels, and the input file has
|
41 |
+
# more channels than requested. In that case return the first channels.
|
42 |
+
wav = wav[..., :channels, :]
|
43 |
+
else:
|
44 |
+
# Case 4: What is a reasonable choice here?
|
45 |
+
raise ValueError('The audio file has less channels than requested but is not mono.')
|
46 |
+
return wav
|
47 |
+
|
48 |
+
|
49 |
+
def convert_audio(wav: torch.Tensor, from_rate: float,
|
50 |
+
to_rate: float, to_channels: int) -> torch.Tensor:
|
51 |
+
"""Convert audio to new sample rate and number of audio channels."""
|
52 |
+
wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
|
53 |
+
wav = convert_audio_channels(wav, to_channels)
|
54 |
+
return wav
|
55 |
+
|
56 |
+
|
57 |
+
def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
|
58 |
+
loudness_compressor: bool = False, energy_floor: float = 2e-3):
|
59 |
+
"""Normalize an input signal to a user loudness in dB LKFS.
|
60 |
+
Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
wav (torch.Tensor): Input multichannel audio data.
|
64 |
+
sample_rate (int): Sample rate.
|
65 |
+
loudness_headroom_db (float): Target loudness of the output in dB LUFS.
|
66 |
+
loudness_compressor (bool): Uses tanh for soft clipping.
|
67 |
+
energy_floor (float): anything below that RMS level will not be rescaled.
|
68 |
+
Returns:
|
69 |
+
torch.Tensor: Loudness normalized output data.
|
70 |
+
"""
|
71 |
+
energy = wav.pow(2).mean().sqrt().item()
|
72 |
+
if energy < energy_floor:
|
73 |
+
return wav
|
74 |
+
transform = torchaudio.transforms.Loudness(sample_rate)
|
75 |
+
input_loudness_db = transform(wav).item()
|
76 |
+
# calculate the gain needed to scale to the desired loudness level
|
77 |
+
delta_loudness = -loudness_headroom_db - input_loudness_db
|
78 |
+
gain = 10.0 ** (delta_loudness / 20.0)
|
79 |
+
output = gain * wav
|
80 |
+
if loudness_compressor:
|
81 |
+
output = torch.tanh(output)
|
82 |
+
assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
|
83 |
+
return output
|
84 |
+
|
85 |
+
|
86 |
+
def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
|
87 |
+
"""Utility function to clip the audio with logging if specified."""
|
88 |
+
max_scale = wav.abs().max()
|
89 |
+
if log_clipping and max_scale > 1:
|
90 |
+
clamp_prob = (wav.abs() > 1).float().mean().item()
|
91 |
+
print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):",
|
92 |
+
clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr)
|
93 |
+
wav.clamp_(-1, 1)
|
94 |
+
|
95 |
+
|
96 |
+
def normalize_audio(wav: torch.Tensor, normalize: bool = True,
|
97 |
+
strategy: str = 'peak', peak_clip_headroom_db: float = 1,
|
98 |
+
rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
|
99 |
+
loudness_compressor: bool = False, log_clipping: bool = False,
|
100 |
+
sample_rate: tp.Optional[int] = None,
|
101 |
+
stem_name: tp.Optional[str] = None) -> torch.Tensor:
|
102 |
+
"""Normalize the audio according to the prescribed strategy (see after).
|
103 |
+
|
104 |
+
Args:
|
105 |
+
wav (torch.Tensor): Audio data.
|
106 |
+
normalize (bool): if `True` (default), normalizes according to the prescribed
|
107 |
+
strategy (see after). If `False`, the strategy is only used in case clipping
|
108 |
+
would happen.
|
109 |
+
strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
|
110 |
+
i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
|
111 |
+
with extra headroom to avoid clipping. 'clip' just clips.
|
112 |
+
peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
|
113 |
+
rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
|
114 |
+
than the `peak_clip` one to avoid further clipping.
|
115 |
+
loudness_headroom_db (float): Target loudness for loudness normalization.
|
116 |
+
loudness_compressor (bool): If True, uses tanh based soft clipping.
|
117 |
+
log_clipping (bool): If True, basic logging on stderr when clipping still
|
118 |
+
occurs despite strategy (only for 'rms').
|
119 |
+
sample_rate (int): Sample rate for the audio data (required for loudness).
|
120 |
+
stem_name (str, optional): Stem name for clipping logging.
|
121 |
+
Returns:
|
122 |
+
torch.Tensor: Normalized audio.
|
123 |
+
"""
|
124 |
+
scale_peak = 10 ** (-peak_clip_headroom_db / 20)
|
125 |
+
scale_rms = 10 ** (-rms_headroom_db / 20)
|
126 |
+
if strategy == 'peak':
|
127 |
+
rescaling = (scale_peak / wav.abs().max())
|
128 |
+
if normalize or rescaling < 1:
|
129 |
+
wav = wav * rescaling
|
130 |
+
elif strategy == 'clip':
|
131 |
+
wav = wav.clamp(-scale_peak, scale_peak)
|
132 |
+
elif strategy == 'rms':
|
133 |
+
mono = wav.mean(dim=0)
|
134 |
+
rescaling = scale_rms / mono.pow(2).mean().sqrt()
|
135 |
+
if normalize or rescaling < 1:
|
136 |
+
wav = wav * rescaling
|
137 |
+
_clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
|
138 |
+
elif strategy == 'loudness':
|
139 |
+
assert sample_rate is not None, "Loudness normalization requires sample rate."
|
140 |
+
wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
|
141 |
+
_clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
|
142 |
+
else:
|
143 |
+
assert wav.abs().max() < 1
|
144 |
+
assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
|
145 |
+
return wav
|
146 |
+
|
147 |
+
|
148 |
+
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
149 |
+
"""Convert audio to float 32 bits PCM format.
|
150 |
+
"""
|
151 |
+
if wav.dtype.is_floating_point:
|
152 |
+
return wav
|
153 |
+
elif wav.dtype == torch.int16:
|
154 |
+
return wav.float() / 2**15
|
155 |
+
elif wav.dtype == torch.int32:
|
156 |
+
return wav.float() / 2**31
|
157 |
+
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
|
158 |
+
|
159 |
+
|
160 |
+
def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
|
161 |
+
"""Convert audio to int 16 bits PCM format.
|
162 |
+
|
163 |
+
..Warning:: There exist many formula for doing this conversion. None are perfect
|
164 |
+
due to the asymmetry of the int16 range. One either have possible clipping, DC offset,
|
165 |
+
or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
|
166 |
+
it is possible that `i16_pcm(f32_pcm)) != Identity`.
|
167 |
+
"""
|
168 |
+
if wav.dtype.is_floating_point:
|
169 |
+
assert wav.abs().max() <= 1
|
170 |
+
candidate = (wav * 2 ** 15).round()
|
171 |
+
if candidate.max() >= 2 ** 15: # clipping would occur
|
172 |
+
candidate = (wav * (2 ** 15 - 1)).round()
|
173 |
+
return candidate.short()
|
174 |
+
else:
|
175 |
+
assert wav.dtype == torch.int16
|
176 |
+
return wav
|
audiocraft/data/info_audio_dataset.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Base classes for the datasets that also provide non-audio metadata,
|
7 |
+
e.g. description, text transcription etc.
|
8 |
+
"""
|
9 |
+
from dataclasses import dataclass
|
10 |
+
import logging
|
11 |
+
import math
|
12 |
+
import re
|
13 |
+
import typing as tp
|
14 |
+
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from .audio_dataset import AudioDataset, AudioMeta
|
18 |
+
from ..environment import AudioCraftEnvironment
|
19 |
+
from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes
|
20 |
+
|
21 |
+
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
|
26 |
+
"""Monkey-patch meta to match cluster specificities."""
|
27 |
+
meta.path = AudioCraftEnvironment.apply_dataset_mappers(meta.path)
|
28 |
+
if meta.info_path is not None:
|
29 |
+
meta.info_path.zip_path = AudioCraftEnvironment.apply_dataset_mappers(meta.info_path.zip_path)
|
30 |
+
return meta
|
31 |
+
|
32 |
+
|
33 |
+
def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
|
34 |
+
"""Monkey-patch all meta to match cluster specificities."""
|
35 |
+
return [_clusterify_meta(m) for m in meta]
|
36 |
+
|
37 |
+
|
38 |
+
@dataclass
|
39 |
+
class AudioInfo(SegmentWithAttributes):
|
40 |
+
"""Dummy SegmentInfo with empty attributes.
|
41 |
+
|
42 |
+
The InfoAudioDataset is expected to return metadata that inherits
|
43 |
+
from SegmentWithAttributes class and can return conditioning attributes.
|
44 |
+
|
45 |
+
This basically guarantees all datasets will be compatible with current
|
46 |
+
solver that contain conditioners requiring this.
|
47 |
+
"""
|
48 |
+
audio_tokens: tp.Optional[torch.Tensor] = None # populated when using cached batch for training a LM.
|
49 |
+
|
50 |
+
def to_condition_attributes(self) -> ConditioningAttributes:
|
51 |
+
return ConditioningAttributes()
|
52 |
+
|
53 |
+
|
54 |
+
class InfoAudioDataset(AudioDataset):
|
55 |
+
"""AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform.
|
56 |
+
|
57 |
+
See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
|
58 |
+
"""
|
59 |
+
def __init__(self, meta: tp.List[AudioMeta], **kwargs):
|
60 |
+
super().__init__(clusterify_all_meta(meta), **kwargs)
|
61 |
+
|
62 |
+
def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
|
63 |
+
if not self.return_info:
|
64 |
+
wav = super().__getitem__(index)
|
65 |
+
assert isinstance(wav, torch.Tensor)
|
66 |
+
return wav
|
67 |
+
wav, video_list, meta = super().__getitem__(index)
|
68 |
+
assert isinstance(video_list, list)
|
69 |
+
return wav, video_list, AudioInfo(**meta.to_dict())
|
70 |
+
|
71 |
+
|
72 |
+
def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
|
73 |
+
"""Preprocess a single keyword or possible a list of keywords."""
|
74 |
+
if isinstance(value, list):
|
75 |
+
return get_keyword_list(value)
|
76 |
+
else:
|
77 |
+
return get_keyword(value)
|
78 |
+
|
79 |
+
|
80 |
+
def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
|
81 |
+
"""Preprocess a single keyword."""
|
82 |
+
if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
|
83 |
+
return None
|
84 |
+
else:
|
85 |
+
return value.strip()
|
86 |
+
|
87 |
+
|
88 |
+
def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
|
89 |
+
"""Preprocess a single keyword."""
|
90 |
+
if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
|
91 |
+
return None
|
92 |
+
else:
|
93 |
+
return value.strip().lower()
|
94 |
+
|
95 |
+
|
96 |
+
def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
|
97 |
+
"""Preprocess a list of keywords."""
|
98 |
+
if isinstance(values, str):
|
99 |
+
values = [v.strip() for v in re.split(r'[,\s]', values)]
|
100 |
+
elif isinstance(values, float) and math.isnan(values):
|
101 |
+
values = []
|
102 |
+
if not isinstance(values, list):
|
103 |
+
logger.debug(f"Unexpected keyword list {values}")
|
104 |
+
values = [str(values)]
|
105 |
+
|
106 |
+
kws = [get_keyword(v) for v in values]
|
107 |
+
kw_list = [k for k in kws if k is not None]
|
108 |
+
if len(kw_list) == 0:
|
109 |
+
return None
|
110 |
+
else:
|
111 |
+
return kw_list
|
audiocraft/data/music_dataset.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Dataset of music tracks with rich metadata.
|
7 |
+
"""
|
8 |
+
from dataclasses import dataclass, field, fields, replace
|
9 |
+
import gzip
|
10 |
+
import json
|
11 |
+
import logging
|
12 |
+
from pathlib import Path
|
13 |
+
import random
|
14 |
+
import typing as tp
|
15 |
+
|
16 |
+
import torch
|
17 |
+
|
18 |
+
from .info_audio_dataset import (
|
19 |
+
InfoAudioDataset,
|
20 |
+
AudioInfo,
|
21 |
+
get_keyword_list,
|
22 |
+
get_keyword,
|
23 |
+
get_string
|
24 |
+
)
|
25 |
+
from ..modules.conditioners import (
|
26 |
+
ConditioningAttributes,
|
27 |
+
JointEmbedCondition,
|
28 |
+
WavCondition,
|
29 |
+
)
|
30 |
+
from ..utils.utils import warn_once
|
31 |
+
|
32 |
+
|
33 |
+
logger = logging.getLogger(__name__)
|
34 |
+
|
35 |
+
|
36 |
+
@dataclass
|
37 |
+
class MusicInfo(AudioInfo):
|
38 |
+
"""Segment info augmented with music metadata.
|
39 |
+
"""
|
40 |
+
# music-specific metadata
|
41 |
+
title: tp.Optional[str] = None
|
42 |
+
artist: tp.Optional[str] = None # anonymized artist id, used to ensure no overlap between splits
|
43 |
+
key: tp.Optional[str] = None
|
44 |
+
bpm: tp.Optional[float] = None
|
45 |
+
genre: tp.Optional[str] = None
|
46 |
+
moods: tp.Optional[list] = None
|
47 |
+
keywords: tp.Optional[list] = None
|
48 |
+
description: tp.Optional[str] = None
|
49 |
+
name: tp.Optional[str] = None
|
50 |
+
instrument: tp.Optional[str] = None
|
51 |
+
# original wav accompanying the metadata
|
52 |
+
self_wav: tp.Optional[WavCondition] = None
|
53 |
+
# dict mapping attributes names to tuple of wav, text and metadata
|
54 |
+
joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
|
55 |
+
|
56 |
+
@property
|
57 |
+
def has_music_meta(self) -> bool:
|
58 |
+
return self.name is not None
|
59 |
+
|
60 |
+
def to_condition_attributes(self) -> ConditioningAttributes:
|
61 |
+
out = ConditioningAttributes()
|
62 |
+
for _field in fields(self):
|
63 |
+
key, value = _field.name, getattr(self, _field.name)
|
64 |
+
if key == 'self_wav':
|
65 |
+
out.wav[key] = value
|
66 |
+
elif key == 'joint_embed':
|
67 |
+
for embed_attribute, embed_cond in value.items():
|
68 |
+
out.joint_embed[embed_attribute] = embed_cond
|
69 |
+
else:
|
70 |
+
if isinstance(value, list):
|
71 |
+
value = ' '.join(value)
|
72 |
+
out.text[key] = value
|
73 |
+
return out
|
74 |
+
|
75 |
+
@staticmethod
|
76 |
+
def attribute_getter(attribute):
|
77 |
+
if attribute == 'bpm':
|
78 |
+
preprocess_func = get_bpm
|
79 |
+
elif attribute == 'key':
|
80 |
+
preprocess_func = get_musical_key
|
81 |
+
elif attribute in ['moods', 'keywords']:
|
82 |
+
preprocess_func = get_keyword_list
|
83 |
+
elif attribute in ['genre', 'name', 'instrument']:
|
84 |
+
preprocess_func = get_keyword
|
85 |
+
elif attribute in ['title', 'artist', 'description']:
|
86 |
+
preprocess_func = get_string
|
87 |
+
else:
|
88 |
+
preprocess_func = None
|
89 |
+
return preprocess_func
|
90 |
+
|
91 |
+
@classmethod
|
92 |
+
def from_dict(cls, dictionary: dict, fields_required: bool = False):
|
93 |
+
_dictionary: tp.Dict[str, tp.Any] = {}
|
94 |
+
|
95 |
+
# allow a subset of attributes to not be loaded from the dictionary
|
96 |
+
# these attributes may be populated later
|
97 |
+
post_init_attributes = ['self_wav', 'joint_embed']
|
98 |
+
optional_fields = ['keywords']
|
99 |
+
|
100 |
+
for _field in fields(cls):
|
101 |
+
if _field.name in post_init_attributes:
|
102 |
+
continue
|
103 |
+
elif _field.name not in dictionary:
|
104 |
+
if fields_required and _field.name not in optional_fields:
|
105 |
+
raise KeyError(f"Unexpected missing key: {_field.name}")
|
106 |
+
else:
|
107 |
+
preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
|
108 |
+
value = dictionary[_field.name]
|
109 |
+
if preprocess_func:
|
110 |
+
value = preprocess_func(value)
|
111 |
+
_dictionary[_field.name] = value
|
112 |
+
return cls(**_dictionary)
|
113 |
+
|
114 |
+
|
115 |
+
def augment_music_info_description(music_info: MusicInfo, merge_text_p: float = 0.,
|
116 |
+
drop_desc_p: float = 0., drop_other_p: float = 0.) -> MusicInfo:
|
117 |
+
"""Augment MusicInfo description with additional metadata fields and potential dropout.
|
118 |
+
Additional textual attributes are added given probability 'merge_text_conditions_p' and
|
119 |
+
the original textual description is dropped from the augmented description given probability drop_desc_p.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
music_info (MusicInfo): The music metadata to augment.
|
123 |
+
merge_text_p (float): Probability of merging additional metadata to the description.
|
124 |
+
If provided value is 0, then no merging is performed.
|
125 |
+
drop_desc_p (float): Probability of dropping the original description on text merge.
|
126 |
+
if provided value is 0, then no drop out is performed.
|
127 |
+
drop_other_p (float): Probability of dropping the other fields used for text augmentation.
|
128 |
+
Returns:
|
129 |
+
MusicInfo: The MusicInfo with augmented textual description.
|
130 |
+
"""
|
131 |
+
def is_valid_field(field_name: str, field_value: tp.Any) -> bool:
|
132 |
+
valid_field_name = field_name in ['key', 'bpm', 'genre', 'moods', 'instrument', 'keywords']
|
133 |
+
valid_field_value = field_value is not None and isinstance(field_value, (int, float, str, list))
|
134 |
+
keep_field = random.uniform(0, 1) < drop_other_p
|
135 |
+
return valid_field_name and valid_field_value and keep_field
|
136 |
+
|
137 |
+
def process_value(v: tp.Any) -> str:
|
138 |
+
if isinstance(v, (int, float, str)):
|
139 |
+
return str(v)
|
140 |
+
if isinstance(v, list):
|
141 |
+
return ", ".join(v)
|
142 |
+
else:
|
143 |
+
raise ValueError(f"Unknown type for text value! ({type(v), v})")
|
144 |
+
|
145 |
+
description = music_info.description
|
146 |
+
|
147 |
+
metadata_text = ""
|
148 |
+
if random.uniform(0, 1) < merge_text_p:
|
149 |
+
meta_pairs = [f'{_field.name}: {process_value(getattr(music_info, _field.name))}'
|
150 |
+
for _field in fields(music_info) if is_valid_field(_field.name, getattr(music_info, _field.name))]
|
151 |
+
random.shuffle(meta_pairs)
|
152 |
+
metadata_text = ". ".join(meta_pairs)
|
153 |
+
description = description if not random.uniform(0, 1) < drop_desc_p else None
|
154 |
+
logger.debug(f"Applying text augmentation on MMI info. description: {description}, metadata: {metadata_text}")
|
155 |
+
|
156 |
+
if description is None:
|
157 |
+
description = metadata_text if len(metadata_text) > 1 else None
|
158 |
+
else:
|
159 |
+
description = ". ".join([description.rstrip('.'), metadata_text])
|
160 |
+
description = description.strip() if description else None
|
161 |
+
|
162 |
+
music_info = replace(music_info)
|
163 |
+
music_info.description = description
|
164 |
+
return music_info
|
165 |
+
|
166 |
+
|
167 |
+
class Paraphraser:
|
168 |
+
def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_p: float = 0.):
|
169 |
+
self.paraphrase_p = paraphrase_p
|
170 |
+
open_fn = gzip.open if str(paraphrase_source).lower().endswith('.gz') else open
|
171 |
+
with open_fn(paraphrase_source, 'rb') as f: # type: ignore
|
172 |
+
self.paraphrase_source = json.loads(f.read())
|
173 |
+
logger.info(f"loaded paraphrasing source from: {paraphrase_source}")
|
174 |
+
|
175 |
+
def sample_paraphrase(self, audio_path: str, description: str):
|
176 |
+
if random.random() >= self.paraphrase_p:
|
177 |
+
return description
|
178 |
+
info_path = Path(audio_path).with_suffix('.json')
|
179 |
+
if info_path not in self.paraphrase_source:
|
180 |
+
warn_once(logger, f"{info_path} not in paraphrase source!")
|
181 |
+
return description
|
182 |
+
new_desc = random.choice(self.paraphrase_source[info_path])
|
183 |
+
logger.debug(f"{description} -> {new_desc}")
|
184 |
+
return new_desc
|
185 |
+
|
186 |
+
|
187 |
+
class MusicDataset(InfoAudioDataset):
    """Music dataset is an AudioDataset with music-related metadata.

    Args:
        info_fields_required (bool): Whether to enforce having required fields.
        merge_text_p (float): Probability of merging additional metadata to the description.
        drop_desc_p (float): Probability of dropping the original description on text merge.
        drop_other_p (float): Probability of dropping the other fields used for text augmentation.
        joint_embed_attributes (list[str]): A list of attributes for which joint embedding metadata is returned.
        paraphrase_source (str, optional): Path to the .json or .json.gz file containing the
            paraphrases for the description. The json should be a dict where keys are the
            original info path (e.g. track_path.json) and each value is a list of possible
            paraphrases.
        paraphrase_p (float): Probability of taking a paraphrase.

    See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
    """
    def __init__(self, *args, info_fields_required: bool = True,
                 merge_text_p: float = 0., drop_desc_p: float = 0., drop_other_p: float = 0.,
                 joint_embed_attributes: tp.List[str] = [],
                 paraphrase_source: tp.Optional[str] = None, paraphrase_p: float = 0,
                 **kwargs):
        kwargs['return_info'] = True  # We require the info for each song of the dataset.
        super().__init__(*args, **kwargs)
        self.info_fields_required = info_fields_required
        self.merge_text_p = merge_text_p
        self.drop_desc_p = drop_desc_p
        self.drop_other_p = drop_other_p
        self.joint_embed_attributes = joint_embed_attributes
        self.paraphraser = None
        if paraphrase_source is not None:
            self.paraphraser = Paraphraser(paraphrase_source, paraphrase_p)

    def __getitem__(self, index):
        wav, video_list, info = super().__getitem__(index)
        assert isinstance(video_list, list)

        if len(video_list) == 1:
            video = video_list[0]
            info_data = info.to_dict()
            music_info_path = Path(info.meta.path).with_suffix('.json')
            if Path(music_info_path).exists():
                with open(music_info_path, 'r') as json_file:
                    music_data = json.load(json_file)
                    music_data.update(info_data)
                    music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required)
                if self.paraphraser is not None:
                    music_info.description = self.paraphraser.sample(music_info.meta.path, music_info.description)
                if self.merge_text_p:
                    music_info = augment_music_info_description(
                        music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p)
            else:
                music_info = MusicInfo.from_dict(info_data, fields_required=False)

            music_info.self_wav = WavCondition(
                wav=wav[None], length=torch.tensor([info.n_frames]),
                sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])

            for att in self.joint_embed_attributes:
                att_value = getattr(music_info, att)
                joint_embed_cond = JointEmbedCondition(
                    wav[None], [att_value], torch.tensor([info.n_frames]),
                    sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
                music_info.joint_embed[att] = joint_embed_cond

            return wav, [video], music_info

        elif len(video_list) == 2:
            local_video = video_list[0]
            global_video = video_list[1]

            info_data = info.to_dict()
            music_info_path = Path(info.meta.path).with_suffix('.json')

            if Path(music_info_path).exists():
                with open(music_info_path, 'r') as json_file:
                    music_data = json.load(json_file)
                    music_data.update(info_data)
                    music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required)
                if self.paraphraser is not None:
                    music_info.description = self.paraphraser.sample(music_info.meta.path, music_info.description)
                if self.merge_text_p:
                    music_info = augment_music_info_description(
                        music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p)
            else:
                music_info = MusicInfo.from_dict(info_data, fields_required=False)

            music_info.self_wav = WavCondition(
                wav=wav[None], length=torch.tensor([info.n_frames]),
                sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])

            for att in self.joint_embed_attributes:
                att_value = getattr(music_info, att)
                joint_embed_cond = JointEmbedCondition(
                    wav[None], [att_value], torch.tensor([info.n_frames]),
                    sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
                music_info.joint_embed[att] = joint_embed_cond

            return wav, [local_video, global_video], music_info


def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess key keywords, discarding them if multiple keys are defined."""
    if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
        return None
    elif ',' in value:
        # For now, we discard when multiple keys are defined separated with commas
        return None
    else:
        return value.strip().lower()


def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]:
    """Preprocess to a float."""
    if value is None:
        return None
    try:
        return float(value)
    except ValueError:
        return None
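A minimal sketch of how an item produced by `MusicDataset.__getitem__` could be consumed. It assumes `dataset` is an already-configured `MusicDataset` instance built from a manifest of paired audio/video files (construction details are not shown in this diff); names such as `inspect_item` are illustrative only.

```python
def inspect_item(dataset, index: int = 0):
    # MusicDataset items are (waveform, list of video tensors, MusicInfo).
    wav, videos, music_info = dataset[index]
    if len(videos) == 2:
        # Long-short-term setting: a local stream plus a global stream.
        local_video, global_video = videos
        print("local video:", local_video.shape, "global video:", global_video.shape)
    else:
        print("video:", videos[0].shape)
    print("waveform:", wav.shape)
    # self_wav wraps the same waveform with its length and metadata for conditioning.
    print("self_wav length:", music_info.self_wav.length)
```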
audiocraft/data/sound_dataset.py
ADDED
@@ -0,0 +1,330 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Dataset of audio with a simple description.
"""

from dataclasses import dataclass, fields, replace
import json
from pathlib import Path
import random
import typing as tp

import numpy as np
import torch

from .info_audio_dataset import (
    InfoAudioDataset,
    get_keyword_or_keyword_list
)
from ..modules.conditioners import (
    ConditioningAttributes,
    SegmentWithAttributes,
    WavCondition,
)


EPS = torch.finfo(torch.float32).eps
TARGET_LEVEL_LOWER = -35
TARGET_LEVEL_UPPER = -15


@dataclass
class SoundInfo(SegmentWithAttributes):
    """Segment info augmented with Sound metadata.
    """
    description: tp.Optional[str] = None
    self_wav: tp.Optional[torch.Tensor] = None

    @property
    def has_sound_meta(self) -> bool:
        return self.description is not None

    def to_condition_attributes(self) -> ConditioningAttributes:
        out = ConditioningAttributes()

        for _field in fields(self):
            key, value = _field.name, getattr(self, _field.name)
            if key == 'self_wav':
                out.wav[key] = value
            else:
                out.text[key] = value
        return out

    @staticmethod
    def attribute_getter(attribute):
        if attribute == 'description':
            preprocess_func = get_keyword_or_keyword_list
        else:
            preprocess_func = None
        return preprocess_func

    @classmethod
    def from_dict(cls, dictionary: dict, fields_required: bool = False):
        _dictionary: tp.Dict[str, tp.Any] = {}

        # allow a subset of attributes to not be loaded from the dictionary
        # these attributes may be populated later
        post_init_attributes = ['self_wav']

        for _field in fields(cls):
            if _field.name in post_init_attributes:
                continue
            elif _field.name not in dictionary:
                if fields_required:
                    raise KeyError(f"Unexpected missing key: {_field.name}")
            else:
                preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
                value = dictionary[_field.name]
                if preprocess_func:
                    value = preprocess_func(value)
                _dictionary[_field.name] = value
        return cls(**_dictionary)


class SoundDataset(InfoAudioDataset):
    """Sound audio dataset: Audio dataset with environmental sound-specific metadata.

    Args:
        info_fields_required (bool): Whether all the mandatory metadata fields should be in the loaded metadata.
        external_metadata_source (tp.Optional[str]): Folder containing JSON metadata for the corresponding dataset.
            The metadata files contained in this folder are expected to match the stem of the audio file with
            a json extension.
        aug_p (float): Probability of performing audio mixing augmentation on the batch.
        mix_p (float): Proportion of batch items that are mixed together when applying audio mixing augmentation.
        mix_snr_low (int): Lowerbound for SNR value sampled for mixing augmentation.
        mix_snr_high (int): Upperbound for SNR value sampled for mixing augmentation.
        mix_min_overlap (float): Minimum overlap between audio files when performing mixing augmentation.
        kwargs: Additional arguments for AudioDataset.

    See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
    """
    def __init__(
        self,
        *args,
        info_fields_required: bool = True,
        external_metadata_source: tp.Optional[str] = None,
        aug_p: float = 0.,
        mix_p: float = 0.,
        mix_snr_low: int = -5,
        mix_snr_high: int = 5,
        mix_min_overlap: float = 0.5,
        **kwargs
    ):
        kwargs['return_info'] = True  # We require the info for each song of the dataset.
        super().__init__(*args, **kwargs)
        self.info_fields_required = info_fields_required
        self.external_metadata_source = external_metadata_source
        self.aug_p = aug_p
        self.mix_p = mix_p
        if self.aug_p > 0:
            assert self.mix_p > 0, "Expecting some mixing proportion mix_p if aug_p > 0"
            assert self.channels == 1, "SoundDataset with audio mixing considers only monophonic audio"
        self.mix_snr_low = mix_snr_low
        self.mix_snr_high = mix_snr_high
        self.mix_min_overlap = mix_min_overlap

    def _get_info_path(self, path: tp.Union[str, Path]) -> Path:
        """Get path of JSON with metadata (description, etc.).
        If there exists a JSON with the same name as 'path.name', then it will be used.
        Else, such JSON will be searched for in an external json source folder if it exists.
        """
        info_path = Path(path).with_suffix('.json')
        if Path(info_path).exists():
            return info_path
        elif self.external_metadata_source and (Path(self.external_metadata_source) / info_path.name).exists():
            return Path(self.external_metadata_source) / info_path.name
        else:
            raise Exception(f"Unable to find a metadata JSON for path: {path}")

    def __getitem__(self, index):
        wav, info = super().__getitem__(index)
        info_data = info.to_dict()
        info_path = self._get_info_path(info.meta.path)
        if Path(info_path).exists():
            with open(info_path, 'r') as json_file:
                sound_data = json.load(json_file)
                sound_data.update(info_data)
                sound_info = SoundInfo.from_dict(sound_data, fields_required=self.info_fields_required)
                # if there are multiple descriptions, sample one randomly
                if isinstance(sound_info.description, list):
                    sound_info.description = random.choice(sound_info.description)
        else:
            sound_info = SoundInfo.from_dict(info_data, fields_required=False)

        sound_info.self_wav = WavCondition(
            wav=wav[None], length=torch.tensor([info.n_frames]),
            sample_rate=[sound_info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])

        return wav, sound_info

    def collater(self, samples):
        # when training, audio mixing is performed in the collate function
        wav, sound_info = super().collater(samples)  # SoundDataset always returns infos
        if self.aug_p > 0:
            wav, sound_info = mix_samples(wav, sound_info, self.aug_p, self.mix_p,
                                          snr_low=self.mix_snr_low, snr_high=self.mix_snr_high,
                                          min_overlap=self.mix_min_overlap)
        return wav, sound_info


def rms_f(x: torch.Tensor) -> torch.Tensor:
    return (x ** 2).mean(1).pow(0.5)


def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Tensor:
    """Normalize the signal to the target level."""
    rms = rms_f(audio)
    scalar = 10 ** (target_level / 20) / (rms + EPS)
    audio = audio * scalar.unsqueeze(1)
    return audio


def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) -> torch.Tensor:
    return (abs(audio) > clipping_threshold).any(1)


def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -> torch.Tensor:
    start = random.randint(0, int(src.shape[1] * (1 - min_overlap)))
    remainder = src.shape[1] - start
    if dst.shape[1] > remainder:
        src[:, start:] = src[:, start:] + dst[:, :remainder]
    else:
        src[:, start:start+dst.shape[1]] = src[:, start:start+dst.shape[1]] + dst
    return src


def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_overlap: float,
              target_level: int = -25, clipping_threshold: float = 0.99) -> torch.Tensor:
    """Function to mix clean speech and noise at various SNR levels.

    Args:
        clean (torch.Tensor): Clean audio source to mix, of shape [B, T].
        noise (torch.Tensor): Noise audio source to mix, of shape [B, T].
        snr (int): SNR level when mixing.
        min_overlap (float): Minimum overlap between the two mixed sources.
        target_level (int): Gain level in dB.
        clipping_threshold (float): Threshold for clipping the audio.
    Returns:
        torch.Tensor: The mixed audio, of shape [B, T].
    """
    if clean.shape[1] > noise.shape[1]:
        noise = torch.nn.functional.pad(noise, (0, clean.shape[1] - noise.shape[1]))
    else:
        noise = noise[:, :clean.shape[1]]

    # normalizing to -25 dB FS
    clean = clean / (clean.max(1)[0].abs().unsqueeze(1) + EPS)
    clean = normalize(clean, target_level)
    rmsclean = rms_f(clean)

    noise = noise / (noise.max(1)[0].abs().unsqueeze(1) + EPS)
    noise = normalize(noise, target_level)
    rmsnoise = rms_f(noise)

    # set the noise level for a given SNR
    noisescalar = (rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)).unsqueeze(1)
    noisenewlevel = noise * noisescalar

    # mix noise and clean speech
    noisyspeech = mix_pair(clean, noisenewlevel, min_overlap)

    # randomly select RMS value between -15 dBFS and -35 dBFS and normalize noisyspeech with that value
    # there is a small chance of clipping, which is not a major issue.
    noisy_rms_level = np.random.randint(TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER)
    rmsnoisy = rms_f(noisyspeech)
    scalarnoisy = (10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)).unsqueeze(1)
    noisyspeech = noisyspeech * scalarnoisy
    clean = clean * scalarnoisy
    noisenewlevel = noisenewlevel * scalarnoisy

    # final check to see if there are any amplitudes exceeding +/- 1. If so, normalize all the signals accordingly
    clipped = is_clipped(noisyspeech)
    if clipped.any():
        noisyspeech_maxamplevel = noisyspeech[clipped].max(1)[0].abs().unsqueeze(1) / (clipping_threshold - EPS)
        noisyspeech[clipped] = noisyspeech[clipped] / noisyspeech_maxamplevel

    return noisyspeech


def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high: int, min_overlap: float):
    if snr_low == snr_high:
        snr = snr_low
    else:
        snr = np.random.randint(snr_low, snr_high)
    mix = snr_mixer(src, dst, snr, min_overlap)
    return mix


def mix_text(src_text: str, dst_text: str):
    """Mix text from different sources by concatenating them."""
    if src_text == dst_text:
        return src_text
    return src_text + " " + dst_text


def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: float, mix_p: float,
                snr_low: int, snr_high: int, min_overlap: float):
    """Mix samples within a batch, summing the waveforms and concatenating the text infos.

    Args:
        wavs (torch.Tensor): Audio tensors of shape [B, C, T].
        infos (list[SoundInfo]): List of SoundInfo items corresponding to the audio.
        aug_p (float): Augmentation probability.
        mix_p (float): Proportion of items in the batch to mix (and merge) together.
        snr_low (int): Lowerbound for sampling SNR.
        snr_high (int): Upperbound for sampling SNR.
        min_overlap (float): Minimum overlap between mixed samples.
    Returns:
        tuple[torch.Tensor, list[SoundInfo]]: A tuple containing the mixed wavs
            and mixed SoundInfo for the given batch.
    """
    # no mixing to perform within the batch
    if mix_p == 0:
        return wavs, infos

    if random.uniform(0, 1) < aug_p:
        # perform all augmentations on waveforms as [B, T]
        # randomly picking pairs of audio to mix
        assert wavs.size(1) == 1, f"Mix samples requires monophonic audio but C={wavs.size(1)}"
        wavs = wavs.mean(dim=1, keepdim=False)
        B, T = wavs.shape
        k = int(mix_p * B)
        mixed_sources_idx = torch.randperm(B)[:k]
        mixed_targets_idx = torch.randperm(B)[:k]
        aug_wavs = snr_mix(
            wavs[mixed_sources_idx],
            wavs[mixed_targets_idx],
            snr_low,
            snr_high,
            min_overlap,
        )
        # mixing textual descriptions in metadata
        descriptions = [info.description for info in infos]
        aug_infos = []
        for i, j in zip(mixed_sources_idx, mixed_targets_idx):
            text = mix_text(descriptions[i], descriptions[j])
            m = replace(infos[i])
            m.description = text
            aug_infos.append(m)

        # back to [B, C, T]
        aug_wavs = aug_wavs.unsqueeze(1)
        assert aug_wavs.shape[0] > 0, "Samples mixing returned empty batch."
        assert aug_wavs.dim() == 3, f"Returned wav should be [B, C, T] but dim = {aug_wavs.dim()}"
        assert aug_wavs.shape[0] == len(aug_infos), "Mismatch between number of wavs and infos in the batch"

        return aug_wavs, aug_infos  # [B, C, T]
    else:
        # randomly pick samples in the batch to match
        # the batch size when performing audio mixing
        B, C, T = wavs.shape
        k = int(mix_p * B)
        wav_idx = torch.randperm(B)[:k]
        wavs = wavs[wav_idx]
        infos = [infos[i] for i in wav_idx]
        assert wavs.shape[0] == len(infos), "Mismatch between number of wavs and infos in the batch"

        return wavs, infos  # [B, C, T]
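A minimal, self-contained sketch of the SNR scaling performed by `snr_mixer` above: the noise is rescaled so that `rms(clean) / rms(scaled_noise)` matches the requested SNR in dB. Tensors are synthetic and shaped [B, T]; this is an illustration of the formula, not part of the dataset code.

```python
import torch

EPS = torch.finfo(torch.float32).eps

def rms(x: torch.Tensor) -> torch.Tensor:
    return (x ** 2).mean(1).pow(0.5)

clean = torch.randn(2, 16000) * 0.1
noise = torch.randn(2, 16000) * 0.3
snr_db = 5
# Same scalar as in snr_mixer: rms(clean) / 10^(snr/20) / rms(noise).
scalar = (rms(clean) / (10 ** (snr_db / 20)) / (rms(noise) + EPS)).unsqueeze(1)
scaled_noise = noise * scalar
# The realized SNR should be close to snr_db for each batch item.
print(20 * torch.log10(rms(clean) / rms(scaled_noise)))
```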
audiocraft/data/video.py
ADDED
@@ -0,0 +1,83 @@
import decord
from decord import VideoReader
from decord import cpu
import torch
import math
import einops
import torchvision.transforms as transforms

def adjust_video_duration(video_tensor, duration, target_fps):
    current_duration = video_tensor.shape[1]
    target_duration = duration * target_fps

    if current_duration > target_duration:
        video_tensor = video_tensor[:, :target_duration]
    elif current_duration < target_duration:
        last_frame = video_tensor[:, -1:]
        repeat_times = target_duration - current_duration
        video_tensor = torch.cat((video_tensor, last_frame.repeat(1, repeat_times, 1, 1)), dim=1)

    return video_tensor

def video_read_local(filepath, seek_time=0., duration=-1, target_fps=2):
    vr = VideoReader(filepath, ctx=cpu(0))
    fps = vr.get_avg_fps()

    if duration > 0:
        total_frames_to_read = target_fps * duration
        frame_interval = int(math.ceil(fps / target_fps))
        start_frame = int(seek_time * fps)
        end_frame = start_frame + frame_interval * total_frames_to_read
        frame_ids = list(range(start_frame, min(end_frame, len(vr)), frame_interval))
    else:
        frame_ids = list(range(0, len(vr), int(math.ceil(fps / target_fps))))

    frames = vr.get_batch(frame_ids)
    frames = torch.from_numpy(frames.asnumpy()).permute(0, 3, 1, 2)  # [N, H, W, C] -> [N, C, H, W]

    resize_transform = transforms.Resize((224, 224))
    frames = [resize_transform(frame) for frame in frames]
    video_tensor = torch.stack(frames)
    video_tensor = einops.rearrange(video_tensor, 't c h w -> c t h w')  # [T, C, H, W] -> [C, T, H, W]
    video_tensor = adjust_video_duration(video_tensor, duration, target_fps)
    assert video_tensor.shape[1] == duration * target_fps, f"the shape of video_tensor is {video_tensor.shape}"

    return video_tensor


def video_read_global(filepath, seek_time=0., duration=-1, target_fps=2, global_mode='average', global_num_frames=32):
    vr = VideoReader(filepath, ctx=cpu(0))
    fps = vr.get_avg_fps()
    frame_count = len(vr)

    if duration > 0:
        total_frames_to_read = target_fps * duration
        frame_interval = int(math.ceil(fps / target_fps))
        start_frame = int(seek_time * fps)
        end_frame = start_frame + frame_interval * total_frames_to_read
        frame_ids = list(range(start_frame, min(end_frame, frame_count), frame_interval))
    else:
        frame_ids = list(range(0, frame_count, int(math.ceil(fps / target_fps))))

    local_frames = vr.get_batch(frame_ids)
    local_frames = torch.from_numpy(local_frames.asnumpy()).permute(0, 3, 1, 2)  # [N, H, W, C] -> [N, C, H, W]

    resize_transform = transforms.Resize((224, 224))
    local_frames = [resize_transform(frame) for frame in local_frames]
    local_video_tensor = torch.stack(local_frames)
    local_video_tensor = einops.rearrange(local_video_tensor, 't c h w -> c t h w')  # [T, C, H, W] -> [C, T, H, W]
    local_video_tensor = adjust_video_duration(local_video_tensor, duration, target_fps)

    if global_mode == 'average':
        global_frame_ids = torch.linspace(0, frame_count - 1, global_num_frames).long()

    global_frames = vr.get_batch(global_frame_ids)
    global_frames = torch.from_numpy(global_frames.asnumpy()).permute(0, 3, 1, 2)  # [N, H, W, C] -> [N, C, H, W]

    global_frames = [resize_transform(frame) for frame in global_frames]
    global_video_tensor = torch.stack(global_frames)
    global_video_tensor = einops.rearrange(global_video_tensor, 't c h w -> c t h w')  # [T, C, H, W] -> [C, T, H, W]

    assert global_video_tensor.shape[1] == global_num_frames, f"the shape of global_video_tensor is {global_video_tensor.shape}"
    return local_video_tensor, global_video_tensor
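A hedged usage sketch for the frame samplers above. `"demo.mp4"` is a placeholder path; `video_read_global` returns a local stream sampled at `target_fps` plus a fixed-length global stream of `global_num_frames` frames spread evenly over the whole video.

```python
from audiocraft.data.video import video_read_global

local_video, global_video = video_read_global(
    "demo.mp4", seek_time=0., duration=30, target_fps=2, global_num_frames=32)
print(local_video.shape)   # [3, 30 * 2, 224, 224]
print(global_video.shape)  # [3, 32, 224, 224]
```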
audiocraft/data/zip.py
ADDED
@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Utility for reading some info from inside a zip file.
"""

import typing
import zipfile

from dataclasses import dataclass
from functools import lru_cache
from typing_extensions import Literal


DEFAULT_SIZE = 32
MODE = Literal['r', 'w', 'x', 'a']


@dataclass(order=True)
class PathInZip:
    """Hold a path of file within a zip file.

    Args:
        path (str): The convention is <path_to_zip>:<relative_path_inside_zip>.
            Let's assume there is a zip file /some/location/foo.zip
            and inside of it is a json file located at /data/file1.json,
            Then we expect path = "/some/location/foo.zip:/data/file1.json".
    """

    INFO_PATH_SEP = ':'
    zip_path: str
    file_path: str

    def __init__(self, path: str) -> None:
        split_path = path.split(self.INFO_PATH_SEP)
        assert len(split_path) == 2
        self.zip_path, self.file_path = split_path

    @classmethod
    def from_paths(cls, zip_path: str, file_path: str):
        return cls(zip_path + cls.INFO_PATH_SEP + file_path)

    def __str__(self) -> str:
        return self.zip_path + self.INFO_PATH_SEP + self.file_path


def _open_zip(path: str, mode: MODE = 'r'):
    return zipfile.ZipFile(path, mode)


_cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip)


def set_zip_cache_size(max_size: int):
    """Sets the maximal LRU caching for zip file opening.

    Args:
        max_size (int): the maximal LRU cache.
    """
    global _cached_open_zip
    _cached_open_zip = lru_cache(max_size)(_open_zip)


def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
    """Opens a file stored inside a zip and returns a file-like object.

    Args:
        path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of.
        mode (str): The mode in which to open the file with.
    Returns:
        A file-like object for PathInZip.
    """
    zf = _cached_open_zip(path_in_zip.zip_path)
    return zf.open(path_in_zip.file_path)
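A minimal sketch of the `PathInZip` convention, assuming a placeholder archive `/tmp/foo.zip` that contains `data/file1.json`.

```python
from audiocraft.data.zip import PathInZip, open_file_in_zip

# "<path_to_zip>:<relative_path_inside_zip>", as documented above.
path = PathInZip("/tmp/foo.zip:data/file1.json")
with open_file_in_zip(path) as f:
    raw = f.read()  # bytes of the JSON stored inside the zip
```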
audiocraft/environment.py
ADDED
@@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Provides cluster and tools configuration across clusters (slurm, dora, utilities).
"""

import logging
import os
from pathlib import Path
import re
import typing as tp

import omegaconf

from .utils.cluster import _guess_cluster_type


logger = logging.getLogger(__name__)


class AudioCraftEnvironment:
    """Environment configuration for teams and clusters.

    AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment
    or declared variables and the loaded team configuration. Additionally, AudioCraftEnvironment
    provides pointers to a reference folder, resolved automatically across clusters and shared across team members,
    allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically
    map dataset file paths to new locations across clusters, allowing to use the same manifest of files across clusters.

    The cluster type is identified automatically and the base configuration file is read from config/teams.yaml.
    Use the following environment variables to specify the cluster, team or configuration:

        AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
            cannot be inferred automatically.
        AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration.
            If not set, configuration is read from config/teams.yaml.
        AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
            Cluster configurations are shared across teams to match compute allocation,
            specify your cluster configuration in the configuration file under a key mapping
            your team name.
    """
    _instance = None
    DEFAULT_TEAM = "default"

    def __init__(self) -> None:
        """Loads configuration."""
        self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM)
        cluster_type = _guess_cluster_type()
        cluster = os.getenv(
            "AUDIOCRAFT_CLUSTER", cluster_type.value
        )
        logger.info("Detecting cluster type %s", cluster_type)

        self.cluster: str = cluster

        config_path = os.getenv(
            "AUDIOCRAFT_CONFIG",
            Path(__file__)
            .parent.parent.joinpath("config/teams", self.team)
            .with_suffix(".yaml"),
        )
        self.config = omegaconf.OmegaConf.load(config_path)
        self._dataset_mappers = []
        cluster_config = self._get_cluster_config()
        if "dataset_mappers" in cluster_config:
            for pattern, repl in cluster_config["dataset_mappers"].items():
                regex = re.compile(pattern)
                self._dataset_mappers.append((regex, repl))

    def _get_cluster_config(self) -> omegaconf.DictConfig:
        assert isinstance(self.config, omegaconf.DictConfig)
        return self.config[self.cluster]

    @classmethod
    def instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def reset(cls):
        """Clears the environment and forces a reload on next invocation."""
        cls._instance = None

    @classmethod
    def get_team(cls) -> str:
        """Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.
        If not defined, defaults to "default".
        """
        return cls.instance().team

    @classmethod
    def get_cluster(cls) -> str:
        """Gets the detected cluster.
        This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
        """
        return cls.instance().cluster

    @classmethod
    def get_dora_dir(cls) -> Path:
        """Gets the path to the dora directory for the current team and cluster.
        Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
        """
        cluster_config = cls.instance()._get_cluster_config()
        dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
        logger.warning(f"Dora directory: {dora_dir}")
        return Path(dora_dir)

    @classmethod
    def get_reference_dir(cls) -> Path:
        """Gets the path to the reference directory for the current team and cluster.
        Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
        """
        cluster_config = cls.instance()._get_cluster_config()
        return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))

    @classmethod
    def get_slurm_exclude(cls) -> tp.Optional[str]:
        """Get the list of nodes to exclude for that cluster."""
        cluster_config = cls.instance()._get_cluster_config()
        return cluster_config.get("slurm_exclude")

    @classmethod
    def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
        """Gets the requested partitions for the current team and cluster as a comma-separated string.

        Args:
            partition_types (list[str], optional): partition types to retrieve. Values must be
                from ['global', 'team']. If not provided, the global partition is returned.
        """
        if not partition_types:
            partition_types = ["global"]

        cluster_config = cls.instance()._get_cluster_config()
        partitions = [
            cluster_config["partitions"][partition_type]
            for partition_type in partition_types
        ]
        return ",".join(partitions)

    @classmethod
    def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
        """Converts reference placeholder in path with configured reference dir to resolve paths.

        Args:
            path (str or Path): Path to resolve.
        Returns:
            Path: Resolved path.
        """
        path = str(path)

        if path.startswith("//reference"):
            reference_dir = cls.get_reference_dir()
            logger.warn(f"Reference directory: {reference_dir}")
            assert (
                reference_dir.exists() and reference_dir.is_dir()
            ), f"Reference directory does not exist: {reference_dir}."
            path = re.sub("^//reference", str(reference_dir), path)

        return Path(path)

    @classmethod
    def apply_dataset_mappers(cls, path: str) -> str:
        """Applies dataset mapping regex rules as defined in the configuration.
        If no rules are defined, the path is returned as-is.
        """
        instance = cls.instance()

        for pattern, repl in instance._dataset_mappers:
            path = pattern.sub(repl, path)

        return path
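A hypothetical sketch of driving the environment through env vars before resolving paths. It assumes a valid `config/teams.yaml` with a `default` team entry is present; the directory values are placeholders, and the `//reference` remapping follows the class above.

```python
import os
from audiocraft.environment import AudioCraftEnvironment

os.environ["AUDIOCRAFT_TEAM"] = "default"        # pick the `default` key in config/teams.yaml
os.environ["AUDIOCRAFT_DORA_DIR"] = "/tmp/dora"  # override the configured dora_dir

print(AudioCraftEnvironment.get_dora_dir())
# Paths starting with //reference are remapped onto the configured reference_dir
# (the directory must exist for the assertion in resolve_reference_path to pass).
print(AudioCraftEnvironment.resolve_reference_path("//reference/checkpoints/model.bin"))
```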
audiocraft/losses/__init__.py
ADDED
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Loss related classes and functions. In particular the loss balancer from
EnCodec, and the usual spectral losses."""

# flake8: noqa
from .balancer import Balancer
from .sisnr import SISNR
from .stftloss import (
    LogSTFTMagnitudeLoss,
    MRSTFTLoss,
    SpectralConvergenceLoss,
    STFTLoss
)
from .specloss import (
    MelSpectrogramL1Loss,
    MultiScaleMelSpectrogramLoss,
)
audiocraft/losses/balancer.py
ADDED
@@ -0,0 +1,136 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import flashy
import torch
from torch import autograd


class Balancer:
    """Loss balancer.

    The loss balancer combines losses together to compute gradients for the backward.
    Given `y = f(...)`, and a number of losses `l1(y, ...)`, `l2(y, ...)`, with `...`
    not having any dependence on `f`, the balancer can efficiently normalize the partial gradients
    `d l1 / d y`, `d l2 / dy` before summing them in order to achieve a desired ratio between
    the losses. For instance if `weights = {'l1': 2, 'l2': 1}`, 66% of the gradient
    going into `f(...)` will come from `l1` on average, and 33% from `l2`. This allows for an easy
    interpretation of the weights even if the intrinsic scale of `l1`, `l2` ... is unknown.

    Noting `g1 = d l1 / dy`, etc., the balanced gradient `G` will be
    (with `avg` an exponential moving average over the updates),

        G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i)

    If `balance_grads` is False, this is deactivated, and instead the gradient will just be the
    standard sum of the partial gradients with the given weights.

    A call to the backward method of the balancer will compute the partial gradients,
    combining all the losses and potentially rescaling the gradients,
    which can help stabilize the training and reason about multiple losses with varying scales.
    The obtained gradient with respect to `y` is then back-propagated to `f(...)`.

    Expected usage:

        weights = {'loss_a': 1, 'loss_b': 4}
        balancer = Balancer(weights, ...)
        losses: dict = {}
        losses['loss_a'] = compute_loss_a(x, y)
        losses['loss_b'] = compute_loss_b(x, y)
        if model.training():
            effective_loss = balancer.backward(losses, x)

    Args:
        weights (dict[str, float]): Weight coefficient for each loss. The balancer expects the losses keys
            from the backward method to match the weights keys to assign weight to each of the provided loss.
        balance_grads (bool): Whether to rescale gradients so that weights reflect the fraction of the
            overall gradient, rather than a constant multiplier.
        total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
        ema_decay (float): EMA decay for averaging the norms.
        per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds
            when rescaling the gradients.
        epsilon (float): Epsilon value for numerical stability.
        monitor (bool): If True, stores in `self.metrics` the relative ratio between the norm of the gradients
            coming from each loss, when calling `backward()`.
    """
    def __init__(self, weights: tp.Dict[str, float], balance_grads: bool = True, total_norm: float = 1.,
                 ema_decay: float = 0.999, per_batch_item: bool = True, epsilon: float = 1e-12,
                 monitor: bool = False):
        self.weights = weights
        self.per_batch_item = per_batch_item
        self.total_norm = total_norm or 1.
        self.averager = flashy.averager(ema_decay or 1.)
        self.epsilon = epsilon
        self.monitor = monitor
        self.balance_grads = balance_grads
        self._metrics: tp.Dict[str, tp.Any] = {}

    @property
    def metrics(self):
        return self._metrics

    def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor) -> torch.Tensor:
        """Compute the backward and return the effective train loss, e.g. the loss obtained from
        computing the effective weights. If `balance_grads` is True, the effective weights
        are the ones that need to be applied to each gradient to respect the desired relative
        scale of gradients coming from each loss.

        Args:
            losses (Dict[str, torch.Tensor]): dictionary with the same keys as `self.weights`.
            input (torch.Tensor): the input of the losses, typically the output of the model.
                This should be the single point of dependence between the losses
                and the model being trained.
        """
        norms = {}
        grads = {}
        for name, loss in losses.items():
            # Compute partial derivative of the loss with respect to the input.
            grad, = autograd.grad(loss, [input], retain_graph=True)
            if self.per_batch_item:
                # We do not average the gradient over the batch dimension.
                dims = tuple(range(1, grad.dim()))
                norm = grad.norm(dim=dims, p=2).mean()
            else:
                norm = grad.norm(p=2)
            norms[name] = norm
            grads[name] = grad

        count = 1
        if self.per_batch_item:
            count = len(grad)
        # Average norms across workers. Theoretically we should average the
        # squared norm, then take the sqrt, but it worked fine like that.
        avg_norms = flashy.distrib.average_metrics(self.averager(norms), count)
        # We approximate the total norm of the gradient as the sums of the norms.
        # Obviously this can be very incorrect if all gradients are aligned, but it works fine.
        total = sum(avg_norms.values())

        self._metrics = {}
        if self.monitor:
            # Store the ratio of the total gradient represented by each loss.
            for k, v in avg_norms.items():
                self._metrics[f'ratio_{k}'] = v / total

        total_weights = sum([self.weights[k] for k in avg_norms])
        assert total_weights > 0.
        desired_ratios = {k: w / total_weights for k, w in self.weights.items()}

        out_grad = torch.zeros_like(input)
        effective_loss = torch.tensor(0., device=input.device, dtype=input.dtype)
        for name, avg_norm in avg_norms.items():
            if self.balance_grads:
                # g_balanced = g / avg(||g||) * total_norm * desired_ratio
                scale = desired_ratios[name] * self.total_norm / (self.epsilon + avg_norm)
            else:
                # We just do regular weighted sum of the gradients.
                scale = self.weights[name]
            out_grad.add_(grads[name], alpha=scale)
            effective_loss += scale * losses[name].detach()
        # Send the computed partial derivative with respect to the output of the model to the model.
        input.backward(out_grad)
        return effective_loss
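A minimal, hedged sketch of the `Balancer` on a toy model, following the "Expected usage" section of the docstring above. It assumes the `flashy` dependency is installed and that no distributed setup is required; the toy model and weights are illustrative.

```python
import torch
from audiocraft.losses.balancer import Balancer

model = torch.nn.Linear(8, 8)
x = torch.randn(4, 8)
target = torch.randn(4, 8)
y = model(x)  # single point of dependence between the losses and the model

losses = {
    'l1': torch.nn.functional.l1_loss(y, target),
    'l2': torch.nn.functional.mse_loss(y, target),
}
# With these weights, roughly 75% of the gradient reaching `model` should come from the l1 term.
balancer = Balancer(weights={'l1': 3, 'l2': 1}, monitor=True)
effective_loss = balancer.backward(losses, y)  # back-propagates the balanced gradient into `model`
print(effective_loss, balancer.metrics)
```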
audiocraft/losses/sisnr.py
ADDED
@@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import typing as tp

import torch
from torch import nn
from torch.nn import functional as F


def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.
    This will pad the input so that `F = ceil(T / K)`.
    see https://github.com/pytorch/pytorch/issues/60466
    """
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, tgt_length - length))
    strides = list(a.stride())
    assert strides[-1] == 1, "data should be contiguous"
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)


def _center(x: torch.Tensor) -> torch.Tensor:
    return x - x.mean(-1, True)


def _norm2(x: torch.Tensor) -> torch.Tensor:
    return x.pow(2).sum(-1, True)


class SISNR(nn.Module):
    """SISNR loss.

    Input should be [B, C, T], output is scalar.

    ..Warning:: This function returns the opposite of the SI-SNR (e.g. `-1 * regular_SI_SNR`).
        Consequently, lower scores are better in terms of reconstruction quality,
        in particular, it should be negative if training goes well. This is done this way so
        that this module can also be used as a loss function for training a model.

    Args:
        sample_rate (int): Sample rate.
        segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on
            entire audio only.
        overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap.
        epsilon (float): Epsilon value for numerical stability.
    """
    def __init__(
        self,
        sample_rate: int = 16000,
        segment: tp.Optional[float] = 20,
        overlap: float = 0.5,
        epsilon: float = torch.finfo(torch.float32).eps,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.segment = segment
        self.overlap = overlap
        self.epsilon = epsilon

    def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
        B, C, T = ref_sig.shape
        assert ref_sig.shape == out_sig.shape

        if self.segment is None:
            frame = T
            stride = T
        else:
            frame = int(self.segment * self.sample_rate)
            stride = int(frame * (1 - self.overlap))

        epsilon = self.epsilon * frame  # make epsilon prop to frame size.

        gt = _unfold(ref_sig, frame, stride)
        est = _unfold(out_sig, frame, stride)
        if self.segment is None:
            assert gt.shape[-1] == 1

        gt = _center(gt)
        est = _center(est)
        dot = torch.einsum("bcft,bcft->bcf", gt, est)

        proj = dot[:, :, :, None] * gt / (epsilon + _norm2(gt))
        noise = est - proj

        sisnr = 10 * (
            torch.log10(epsilon + _norm2(proj)) - torch.log10(epsilon + _norm2(noise))
        )
        return -1 * sisnr[..., 0].mean()
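A small sketch of the sign convention documented above: the module returns the negative SI-SNR, so a perfect reconstruction yields a strongly negative value, while a random estimate yields a value near zero or positive. The tensors are synthetic.

```python
import torch
from audiocraft.losses.sisnr import SISNR

loss = SISNR(sample_rate=16000, segment=None)
ref = torch.randn(2, 1, 16000)
print(loss(ref, ref))                    # large negative value (good reconstruction)
print(loss(torch.randn_like(ref), ref))  # near zero or positive (poor reconstruction)
```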
audiocraft/losses/specloss.py
ADDED
@@ -0,0 +1,149 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import numpy as np
from torchaudio.transforms import MelSpectrogram
import torch
from torch import nn
from torch.nn import functional as F

from ..modules import pad_for_conv1d


class MelSpectrogramWrapper(nn.Module):
    """Wrapper around MelSpectrogram torchaudio transform providing proper padding
    and additional post-processing including log scaling.

    Args:
        n_fft (int): Number of fft.
        hop_length (int): Hop size.
        win_length (int): Window length.
        n_mels (int): Number of mel bins.
        sample_rate (int): Sample rate.
        f_min (float or None): Minimum frequency.
        f_max (float or None): Maximum frequency.
        log (bool): Whether to scale with log.
        normalized (bool): Whether to normalize the melspectrogram.
        floor_level (float): Floor level based on human perception (default=1e-5).
    """
    def __init__(self, n_fft: int = 1024, hop_length: int = 256, win_length: tp.Optional[int] = None,
                 n_mels: int = 80, sample_rate: float = 22050, f_min: float = 0.0, f_max: tp.Optional[float] = None,
                 log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
        super().__init__()
        self.n_fft = n_fft
        hop_length = int(hop_length)
        self.hop_length = hop_length
        self.mel_transform = MelSpectrogram(n_mels=n_mels, sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
                                            win_length=win_length, f_min=f_min, f_max=f_max, normalized=normalized,
                                            window_fn=torch.hann_window, center=False)
        self.floor_level = floor_level
        self.log = log

    def forward(self, x):
        p = int((self.n_fft - self.hop_length) // 2)
        if len(x.shape) == 2:
            x = x.unsqueeze(1)
        x = F.pad(x, (p, p), "reflect")
        # Make sure that all the frames are full.
        # The combination of `pad_for_conv1d` and the above padding
        # will make the output of size ceil(T / hop).
        x = pad_for_conv1d(x, self.n_fft, self.hop_length)
        self.mel_transform.to(x.device)
        mel_spec = self.mel_transform(x)
        B, C, freqs, frame = mel_spec.shape
        if self.log:
            mel_spec = torch.log10(self.floor_level + mel_spec)
        return mel_spec.reshape(B, C * freqs, frame)


class MelSpectrogramL1Loss(torch.nn.Module):
    """L1 Loss on MelSpectrogram.

    Args:
        sample_rate (int): Sample rate.
        n_fft (int): Number of fft.
        hop_length (int): Hop size.
        win_length (int): Window length.
        n_mels (int): Number of mel bins.
        f_min (float or None): Minimum frequency.
        f_max (float or None): Maximum frequency.
        log (bool): Whether to scale with log.
        normalized (bool): Whether to normalize the melspectrogram.
        floor_level (float): Floor level value based on human perception (default=1e-5).
    """
    def __init__(self, sample_rate: int, n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024,
                 n_mels: int = 80, f_min: float = 0.0, f_max: tp.Optional[float] = None,
                 log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
        super().__init__()
        self.l1 = torch.nn.L1Loss()
        self.melspec = MelSpectrogramWrapper(n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                                             n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
                                             log=log, normalized=normalized, floor_level=floor_level)

    def forward(self, x, y):
        self.melspec.to(x.device)
        s_x = self.melspec(x)
        s_y = self.melspec(y)
        return self.l1(s_x, s_y)


class MultiScaleMelSpectrogramLoss(nn.Module):
    """Multi-Scale spectrogram loss (msspec).

    Args:
        sample_rate (int): Sample rate.
        range_start (int): Power of 2 to use for the first scale.
        range_end (int): Power of 2 to use for the last scale.
        n_mels (int): Number of mel bins.
        f_min (float): Minimum frequency.
        f_max (float or None): Maximum frequency.
        normalized (bool): Whether to normalize the melspectrogram.
        alphas (bool): Whether to use alphas as coefficients or not.
        floor_level (float): Floor level value based on human perception (default=1e-5).
    """
    def __init__(self, sample_rate: int, range_start: int = 6, range_end: int = 11,
                 n_mels: int = 64, f_min: float = 0.0, f_max: tp.Optional[float] = None,
                 normalized: bool = False, alphas: bool = True, floor_level: float = 1e-5):
        super().__init__()
        l1s = list()
        l2s = list()
        self.alphas = list()
        self.total = 0
        self.normalized = normalized
        for i in range(range_start, range_end):
            l1s.append(
                MelSpectrogramWrapper(n_fft=2 ** i, hop_length=(2 ** i) / 4, win_length=2 ** i,
                                      n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
                                      log=False, normalized=normalized, floor_level=floor_level))
            l2s.append(
                MelSpectrogramWrapper(n_fft=2 ** i, hop_length=(2 ** i) / 4, win_length=2 ** i,
                                      n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
                                      log=True, normalized=normalized, floor_level=floor_level))
            if alphas:
                self.alphas.append(np.sqrt(2 ** i - 1))
            else:
                self.alphas.append(1)
            self.total += self.alphas[-1] + 1

        self.l1s = nn.ModuleList(l1s)
        self.l2s = nn.ModuleList(l2s)

    def forward(self, x, y):
        loss = 0.0
        self.l1s.to(x.device)
        self.l2s.to(x.device)
        for i in range(len(self.alphas)):
            s_x_1 = self.l1s[i](x)
            s_y_1 = self.l1s[i](y)
            s_x_2 = self.l2s[i](x)
            s_y_2 = self.l2s[i](y)
            loss += F.l1_loss(s_x_1, s_y_1) + self.alphas[i] * F.mse_loss(s_x_2, s_y_2)
        if self.normalized:
            loss = loss / self.total
        return loss
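A hedged sketch of the multi-scale mel loss on random mono waveforms of shape [B, C, T]. The sample rate and tensor sizes are placeholders; the loss combines an L1 term on linear mel spectrograms and a weighted MSE term on log mel spectrograms across FFT scales, as implemented above.

```python
import torch
from audiocraft.losses.specloss import MultiScaleMelSpectrogramLoss

msspec = MultiScaleMelSpectrogramLoss(sample_rate=32000)
x = torch.randn(2, 1, 32000)  # predicted audio
y = torch.randn(2, 1, 32000)  # reference audio
print(msspec(x, y))  # scalar loss aggregated over scales n_fft = 2^6 .. 2^10
```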
audiocraft/losses/stftloss.py
ADDED
@@ -0,0 +1,207 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# Adapted from MIT code under the original license
|
7 |
+
# Copyright 2019 Tomoki Hayashi
|
8 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
9 |
+
import typing as tp
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from torch import nn
|
13 |
+
from torch.nn import functional as F
|
14 |
+
|
15 |
+
|
16 |
+
# TODO: Replace with torchaudio.STFT?
|
17 |
+
def _stft(x: torch.Tensor, fft_size: int, hop_length: int, win_length: int,
|
18 |
+
window: tp.Optional[torch.Tensor], normalized: bool) -> torch.Tensor:
|
19 |
+
"""Perform STFT and convert to magnitude spectrogram.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
x: Input signal tensor (B, C, T).
|
23 |
+
fft_size (int): FFT size.
|
24 |
+
hop_length (int): Hop size.
|
25 |
+
win_length (int): Window length.
|
26 |
+
window (torch.Tensor or None): Window function type.
|
27 |
+
normalized (bool): Whether to normalize the STFT or not.
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
torch.Tensor: Magnitude spectrogram (B, C, #frames, fft_size // 2 + 1).
|
31 |
+
"""
|
32 |
+
B, C, T = x.shape
|
33 |
+
x_stft = torch.stft(
|
34 |
+
x.view(-1, T), fft_size, hop_length, win_length, window,
|
35 |
+
normalized=normalized, return_complex=True,
|
36 |
+
)
|
37 |
+
x_stft = x_stft.view(B, C, *x_stft.shape[1:])
|
38 |
+
real = x_stft.real
|
39 |
+
imag = x_stft.imag
|
40 |
+
|
41 |
+
# NOTE(kan-bayashi): clamp is needed to avoid nan or inf
|
42 |
+
return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
|
43 |
+
|
44 |
+
|
45 |
+
class SpectralConvergenceLoss(nn.Module):
|
46 |
+
"""Spectral convergence loss.
|
47 |
+
"""
|
48 |
+
def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
|
49 |
+
super().__init__()
|
50 |
+
self.epsilon = epsilon
|
51 |
+
|
52 |
+
def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
|
53 |
+
"""Calculate forward propagation.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
x_mag: Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
57 |
+
y_mag: Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
58 |
+
Returns:
|
59 |
+
torch.Tensor: Spectral convergence loss value.
|
60 |
+
"""
|
61 |
+
return torch.norm(y_mag - x_mag, p="fro") / (torch.norm(y_mag, p="fro") + self.epsilon)
|
62 |
+
|
63 |
+
|
64 |
+
class LogSTFTMagnitudeLoss(nn.Module):
|
65 |
+
"""Log STFT magnitude loss.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
epsilon (float): Epsilon value for numerical stability.
|
69 |
+
"""
|
70 |
+
def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
|
71 |
+
super().__init__()
|
72 |
+
self.epsilon = epsilon
|
73 |
+
|
74 |
+
def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
|
75 |
+
"""Calculate forward propagation.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
x_mag (torch.Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
79 |
+
y_mag (torch.Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
80 |
+
Returns:
|
81 |
+
torch.Tensor: Log STFT magnitude loss value.
|
82 |
+
"""
|
83 |
+
return F.l1_loss(torch.log(self.epsilon + y_mag), torch.log(self.epsilon + x_mag))
|
84 |
+
|
85 |
+
|
86 |
+
class STFTLosses(nn.Module):
|
87 |
+
"""STFT losses.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
n_fft (int): Size of FFT.
|
91 |
+
hop_length (int): Hop length.
|
92 |
+
win_length (int): Window length.
|
93 |
+
window (str): Window function type.
|
94 |
+
normalized (bool): Whether to use normalized STFT or not.
|
95 |
+
epsilon (float): Epsilon for numerical stability.
|
96 |
+
"""
|
97 |
+
def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600,
|
98 |
+
window: str = "hann_window", normalized: bool = False,
|
99 |
+
epsilon: float = torch.finfo(torch.float32).eps):
|
100 |
+
super().__init__()
|
101 |
+
self.n_fft = n_fft
|
102 |
+
self.hop_length = hop_length
|
103 |
+
self.win_length = win_length
|
104 |
+
self.normalized = normalized
|
105 |
+
self.register_buffer("window", getattr(torch, window)(win_length))
|
106 |
+
self.spectral_convergence_loss = SpectralConvergenceLoss(epsilon)
|
107 |
+
self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss(epsilon)
|
108 |
+
|
109 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
110 |
+
"""Calculate forward propagation.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
x (torch.Tensor): Predicted signal (B, C, T).
|
114 |
+
y (torch.Tensor): Groundtruth signal (B, C, T).
|
115 |
+
Returns:
|
116 |
+
torch.Tensor: Spectral convergence loss value.
|
117 |
+
torch.Tensor: Log STFT magnitude loss value.
|
118 |
+
"""
|
119 |
+
x_mag = _stft(x, self.n_fft, self.hop_length,
|
120 |
+
self.win_length, self.window, self.normalized) # type: ignore
|
121 |
+
y_mag = _stft(y, self.n_fft, self.hop_length,
|
122 |
+
self.win_length, self.window, self.normalized) # type: ignore
|
123 |
+
sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
|
124 |
+
mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
|
125 |
+
|
126 |
+
return sc_loss, mag_loss
|
127 |
+
|
128 |
+
|
129 |
+
class STFTLoss(nn.Module):
|
130 |
+
"""Single Resolution STFT loss.
|
131 |
+
|
132 |
+
Args:
|
133 |
+
n_fft (int): Nb of FFT.
|
134 |
+
hop_length (int): Hop length.
|
135 |
+
win_length (int): Window length.
|
136 |
+
window (str): Window function type.
|
137 |
+
normalized (bool): Whether to use normalized STFT or not.
|
138 |
+
epsilon (float): Epsilon for numerical stability.
|
139 |
+
factor_sc (float): Coefficient for the spectral loss.
|
140 |
+
factor_mag (float): Coefficient for the magnitude loss.
|
141 |
+
"""
|
142 |
+
def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600,
|
143 |
+
window: str = "hann_window", normalized: bool = False,
|
144 |
+
factor_sc: float = 0.1, factor_mag: float = 0.1,
|
145 |
+
epsilon: float = torch.finfo(torch.float32).eps):
|
146 |
+
super().__init__()
|
147 |
+
self.loss = STFTLosses(n_fft, hop_length, win_length, window, normalized, epsilon)
|
148 |
+
self.factor_sc = factor_sc
|
149 |
+
self.factor_mag = factor_mag
|
150 |
+
|
151 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
152 |
+
"""Calculate forward propagation.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
x (torch.Tensor): Predicted signal (B, C, T).
|
156 |
+
y (torch.Tensor): Groundtruth signal (B, C, T).
|
157 |
+
Returns:
|
158 |
+
torch.Tensor: Single resolution STFT loss.
|
159 |
+
"""
|
160 |
+
sc_loss, mag_loss = self.loss(x, y)
|
161 |
+
return self.factor_sc * sc_loss + self.factor_mag * mag_loss
|
162 |
+
|
163 |
+
|
164 |
+
class MRSTFTLoss(nn.Module):
|
165 |
+
"""Multi resolution STFT loss.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
n_ffts (Sequence[int]): Sequence of FFT sizes.
|
169 |
+
hop_lengths (Sequence[int]): Sequence of hop sizes.
|
170 |
+
win_lengths (Sequence[int]): Sequence of window lengths.
|
171 |
+
window (str): Window function type.
|
172 |
+
factor_sc (float): Coefficient for the spectral loss.
|
173 |
+
factor_mag (float): Coefficient for the magnitude loss.
|
174 |
+
normalized (bool): Whether to use normalized STFT or not.
|
175 |
+
epsilon (float): Epsilon for numerical stability.
|
176 |
+
"""
|
177 |
+
def __init__(self, n_ffts: tp.Sequence[int] = [1024, 2048, 512], hop_lengths: tp.Sequence[int] = [120, 240, 50],
|
178 |
+
win_lengths: tp.Sequence[int] = [600, 1200, 240], window: str = "hann_window",
|
179 |
+
factor_sc: float = 0.1, factor_mag: float = 0.1,
|
180 |
+
normalized: bool = False, epsilon: float = torch.finfo(torch.float32).eps):
|
181 |
+
super().__init__()
|
182 |
+
assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
|
183 |
+
self.stft_losses = torch.nn.ModuleList()
|
184 |
+
for fs, ss, wl in zip(n_ffts, hop_lengths, win_lengths):
|
185 |
+
self.stft_losses += [STFTLosses(fs, ss, wl, window, normalized, epsilon)]
|
186 |
+
self.factor_sc = factor_sc
|
187 |
+
self.factor_mag = factor_mag
|
188 |
+
|
189 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
190 |
+
"""Calculate forward propagation.
|
191 |
+
|
192 |
+
Args:
|
193 |
+
x (torch.Tensor): Predicted signal (B, C, T).
|
194 |
+
y (torch.Tensor): Groundtruth signal (B, C, T).
|
195 |
+
Returns:
|
196 |
+
torch.Tensor: Multi resolution STFT loss.
|
197 |
+
"""
|
198 |
+
sc_loss = torch.Tensor([0.0])
|
199 |
+
mag_loss = torch.Tensor([0.0])
|
200 |
+
for f in self.stft_losses:
|
201 |
+
sc_l, mag_l = f(x, y)
|
202 |
+
sc_loss += sc_l
|
203 |
+
mag_loss += mag_l
|
204 |
+
sc_loss /= len(self.stft_losses)
|
205 |
+
mag_loss /= len(self.stft_losses)
|
206 |
+
|
207 |
+
return self.factor_sc * sc_loss + self.factor_mag * mag_loss
|
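A brief usage sketch for the multi-resolution STFT loss defined above, assuming the `audiocraft` package from this repository is importable. Note that `_stft` unpacks `B, C, T = x.shape`, so waveforms need an explicit channel dimension.

```python
import torch
from audiocraft.losses.stftloss import MRSTFTLoss

mrstft = MRSTFTLoss()               # default resolutions: 1024/2048/512-point FFTs
pred = torch.randn(2, 1, 24000)     # (B, C, T) predicted waveform
target = torch.randn(2, 1, 24000)   # (B, C, T) reference waveform
loss = mrstft(pred, target)         # 0.1 * spectral convergence + 0.1 * log magnitude, averaged over resolutions
print(loss)
```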
audiocraft/metrics/__init__.py
ADDED
@@ -0,0 +1,14 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Metrics like CLAP score, FAD, KLD, Visqol, Chroma similarity, etc.
|
7 |
+
"""
|
8 |
+
# flake8: noqa
|
9 |
+
from .clap_consistency import CLAPTextConsistencyMetric, TextConsistencyMetric
|
10 |
+
from .chroma_cosinesim import ChromaCosineSimilarityMetric
|
11 |
+
from .fad import FrechetAudioDistanceMetric
|
12 |
+
from .kld import KLDivergenceMetric, PasstKLDivergenceMetric
|
13 |
+
from .rvm import RelativeVolumeMel
|
14 |
+
from .visqol import ViSQOL
|
audiocraft/metrics/chroma_cosinesim.py
ADDED
@@ -0,0 +1,72 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torchmetrics
|
9 |
+
|
10 |
+
from ..data.audio_utils import convert_audio
|
11 |
+
from ..modules.chroma import ChromaExtractor
|
12 |
+
|
13 |
+
|
14 |
+
class ChromaCosineSimilarityMetric(torchmetrics.Metric):
|
15 |
+
"""Chroma cosine similarity metric.
|
16 |
+
|
17 |
+
This metric extracts a chromagram for a reference waveform and
|
18 |
+
a generated waveform and compares each frame using the cosine similarity
|
19 |
+
function. The output is the mean cosine similarity.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
sample_rate (int): Sample rate used by the chroma extractor.
|
23 |
+
n_chroma (int): Number of chroma used by the chroma extractor.
|
24 |
+
radix2_exp (int): Exponent for the chroma extractor.
|
25 |
+
argmax (bool): Whether the chroma extractor uses argmax.
|
26 |
+
eps (float): Epsilon for cosine similarity computation.
|
27 |
+
"""
|
28 |
+
def __init__(self, sample_rate: int, n_chroma: int, radix2_exp: int, argmax: bool, eps: float = 1e-8):
|
29 |
+
super().__init__()
|
30 |
+
self.chroma_sample_rate = sample_rate
|
31 |
+
self.n_chroma = n_chroma
|
32 |
+
self.eps = eps
|
33 |
+
self.chroma_extractor = ChromaExtractor(sample_rate=self.chroma_sample_rate, n_chroma=self.n_chroma,
|
34 |
+
radix2_exp=radix2_exp, argmax=argmax)
|
35 |
+
self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
|
36 |
+
self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
|
37 |
+
|
38 |
+
def update(self, preds: torch.Tensor, targets: torch.Tensor,
|
39 |
+
sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
|
40 |
+
"""Compute cosine similarity between chromagrams and accumulate scores over the dataset."""
|
41 |
+
if preds.size(0) == 0:
|
42 |
+
return
|
43 |
+
|
44 |
+
assert preds.shape == targets.shape, (
|
45 |
+
f"Preds and target shapes mismatch: preds={preds.shape}, targets={targets.shape}")
|
46 |
+
assert preds.size(0) == sizes.size(0), (
|
47 |
+
f"Number of items in preds ({preds.shape}) mismatch ",
|
48 |
+
f"with sizes ({sizes.shape})")
|
49 |
+
assert preds.size(0) == sample_rates.size(0), (
|
50 |
+
f"Number of items in preds ({preds.shape}) mismatch ",
|
51 |
+
f"with sample_rates ({sample_rates.shape})")
|
52 |
+
assert torch.all(sample_rates == sample_rates[0].item()), "All sample rates are not the same in the batch"
|
53 |
+
|
54 |
+
device = self.weight.device
|
55 |
+
preds, targets = preds.to(device), targets.to(device) # type: ignore
|
56 |
+
sample_rate = sample_rates[0].item()
|
57 |
+
preds = convert_audio(preds, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
|
58 |
+
targets = convert_audio(targets, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
|
59 |
+
gt_chroma = self.chroma_extractor(targets)
|
60 |
+
gen_chroma = self.chroma_extractor(preds)
|
61 |
+
chroma_lens = (sizes / self.chroma_extractor.winhop).ceil().int()
|
62 |
+
for i in range(len(gt_chroma)):
|
63 |
+
t = int(chroma_lens[i].item())
|
64 |
+
cosine_sim = torch.nn.functional.cosine_similarity(
|
65 |
+
gt_chroma[i, :t], gen_chroma[i, :t], dim=1, eps=self.eps)
|
66 |
+
self.cosine_sum += cosine_sim.sum(dim=0) # type: ignore
|
67 |
+
self.weight += torch.tensor(t) # type: ignore
|
68 |
+
|
69 |
+
def compute(self) -> float:
|
70 |
+
"""Computes the average cosine similarty across all generated/target chromagrams pairs."""
|
71 |
+
assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0" # type: ignore
|
72 |
+
return (self.cosine_sum / self.weight).item() # type: ignore
|
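A usage sketch for the chroma metric above, assuming the `audiocraft` package and its chroma-extraction dependencies are installed. The constructor values are illustrative only; pick values matching your evaluation config.

```python
import torch
from audiocraft.metrics import ChromaCosineSimilarityMetric

# illustrative settings for the chroma extractor
metric = ChromaCosineSimilarityMetric(sample_rate=32000, n_chroma=12, radix2_exp=12, argmax=True)

preds = torch.randn(2, 1, 32000)             # generated audio, (B, C, T)
targets = torch.randn(2, 1, 32000)           # reference audio, same shape
sizes = torch.tensor([32000, 32000])         # valid lengths per item
sample_rates = torch.tensor([32000, 32000])  # must all match within a batch

metric.update(preds, targets, sizes, sample_rates)
print(metric.compute())                      # mean per-frame cosine similarity
```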
audiocraft/metrics/clap_consistency.py
ADDED
@@ -0,0 +1,84 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from pathlib import Path
|
8 |
+
import typing as tp
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torchmetrics
|
12 |
+
from transformers import RobertaTokenizer # type: ignore
|
13 |
+
|
14 |
+
from ..data.audio_utils import convert_audio
|
15 |
+
from ..environment import AudioCraftEnvironment
|
16 |
+
from ..utils.utils import load_clap_state_dict
|
17 |
+
|
18 |
+
try:
|
19 |
+
import laion_clap # type: ignore
|
20 |
+
except ImportError:
|
21 |
+
laion_clap = None
|
22 |
+
|
23 |
+
|
24 |
+
class TextConsistencyMetric(torchmetrics.Metric):
|
25 |
+
"""Text consistency metric measuring consistency between audio and text pairs."""
|
26 |
+
|
27 |
+
def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
|
28 |
+
raise NotImplementedError("implement how to update the metric from the audio and text pairs.")
|
29 |
+
|
30 |
+
def compute(self):
|
31 |
+
raise NotImplementedError("implement how to compute the final metric score.")
|
32 |
+
|
33 |
+
|
34 |
+
class CLAPTextConsistencyMetric(TextConsistencyMetric):
|
35 |
+
"""Text consistency metric relying on Contrastive Language-Audio Pretraining (CLAP).
|
36 |
+
|
37 |
+
This metric is similar to the MuLan Cycle Consistency from MusicLM (https://arxiv.org/pdf/2301.11325.pdf)
|
38 |
+
or the CLAP score used in Make-An-Audio (https://arxiv.org/pdf/2301.12661v1.pdf).
|
39 |
+
|
40 |
+
As a joint audio-text embedding model, a pretrained CLAP model can be used to quantify the
|
41 |
+
similarity between audio-text pairs. We compute the CLAP embeddings from the text descriptions as
|
42 |
+
well as the generated audio based on them, and define the MCC metric as the average cosine similarity
|
43 |
+
between these embeddings.
|
44 |
+
|
45 |
+
Model implementation & pre-trained checkpoints: https://github.com/LAION-AI/CLAP
|
46 |
+
"""
|
47 |
+
def __init__(self, model_path: tp.Union[str, Path], model_arch: str = 'HTSAT-tiny', enable_fusion: bool = False):
|
48 |
+
super().__init__()
|
49 |
+
if laion_clap is None:
|
50 |
+
raise ImportError("Please install CLAP to compute text consistency: 'pip install laion_clap'")
|
51 |
+
self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
|
52 |
+
self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
|
53 |
+
self._initialize_model(model_path, model_arch, enable_fusion)
|
54 |
+
|
55 |
+
def _initialize_model(self, model_path: tp.Union[str, Path], model_arch: str, enable_fusion: bool):
|
56 |
+
model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
|
57 |
+
self.tokenize = RobertaTokenizer.from_pretrained('roberta-base')
|
58 |
+
self.model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
|
59 |
+
self.model_sample_rate = 48_000
|
60 |
+
load_clap_state_dict(self.model, model_path)
|
61 |
+
self.model.eval()
|
62 |
+
|
63 |
+
def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
|
64 |
+
# we use the default params from CLAP module here as well
|
65 |
+
return self.tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
|
66 |
+
|
67 |
+
def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
|
68 |
+
"""Compute cosine similarity between audio and text pairs and accumulate scores over the dataset."""
|
69 |
+
assert audio.size(0) == len(text), "Number of audio and text samples should match"
|
70 |
+
assert torch.all(sample_rates == sample_rates[0].item()), "All items in batch should have the same sample rate"
|
71 |
+
sample_rate = int(sample_rates[0].item())
|
72 |
+
# convert audio batch to 48kHz monophonic audio with no channel dimension: [B, C, T] -> [B, T]
|
73 |
+
audio = convert_audio(audio, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1).mean(dim=1)
|
74 |
+
audio_embeddings = self.model.get_audio_embedding_from_data(audio, use_tensor=True)
|
75 |
+
text_embeddings = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
|
76 |
+
# cosine similarity between the text and the audio embedding
|
77 |
+
cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_embeddings, dim=1, eps=1e-8)
|
78 |
+
self.cosine_sum += cosine_sim.sum(dim=0)
|
79 |
+
self.weight += torch.tensor(cosine_sim.size(0))
|
80 |
+
|
81 |
+
def compute(self):
|
82 |
+
"""Computes the average cosine similarty across all audio/text pairs."""
|
83 |
+
assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0" # type: ignore
|
84 |
+
return (self.cosine_sum / self.weight).item() # type: ignore
|
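A usage sketch for the CLAP-based consistency metric above. It requires `laion_clap` and a downloaded CLAP checkpoint; the checkpoint path below is a placeholder to adapt to whichever checkpoint you use (see https://github.com/LAION-AI/CLAP).

```python
import torch
from audiocraft.metrics import CLAPTextConsistencyMetric

# placeholder checkpoint path; model_arch matches the class default
metric = CLAPTextConsistencyMetric(model_path="checkpoints/clap_music.pt",
                                   model_arch="HTSAT-tiny", enable_fusion=False)

audio = torch.randn(2, 1, 48000)                       # generated audio, (B, C, T)
text = ["calm piano over rain", "upbeat synth pop"]    # matching prompts
sizes = torch.tensor([48000, 48000])
sample_rates = torch.tensor([48000, 48000])

metric.update(audio, text, sizes, sample_rates)
print(metric.compute())                                # average audio/text cosine similarity
```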
audiocraft/metrics/fad.py
ADDED
@@ -0,0 +1,329 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import logging
|
8 |
+
from pathlib import Path
|
9 |
+
import os
|
10 |
+
import subprocess
|
11 |
+
import tempfile
|
12 |
+
import typing as tp
|
13 |
+
|
14 |
+
from audiocraft.data.audio import audio_write
|
15 |
+
from audiocraft.data.audio_utils import convert_audio
|
16 |
+
import flashy
|
17 |
+
import torch
|
18 |
+
import torchmetrics
|
19 |
+
|
20 |
+
from ..environment import AudioCraftEnvironment
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
VGGISH_SAMPLE_RATE = 16_000
|
26 |
+
VGGISH_CHANNELS = 1
|
27 |
+
|
28 |
+
|
29 |
+
class FrechetAudioDistanceMetric(torchmetrics.Metric):
|
30 |
+
"""Fréchet Audio Distance computation based on official TensorFlow implementation from Google Research.
|
31 |
+
|
32 |
+
From: D.C. Dowson & B.V. Landau The Fréchet distance between
|
33 |
+
multivariate normal distributions
|
34 |
+
https://doi.org/10.1016/0047-259X(82)90077-X
|
35 |
+
The Fréchet distance between two multivariate gaussians,
|
36 |
+
`X ~ N(mu_x, sigma_x)` and `Y ~ N(mu_y, sigma_y)`, is `d^2`.
|
37 |
+
d^2 = (mu_x - mu_y)^2 + Tr(sigma_x + sigma_y - 2 * sqrt(sigma_x*sigma_y))
|
38 |
+
= (mu_x - mu_y)^2 + Tr(sigma_x) + Tr(sigma_y)
|
39 |
+
- 2 * Tr(sqrt(sigma_x*sigma_y)))
|
40 |
+
|
41 |
+
To use this FAD computation metric, you need to have the proper Frechet Audio Distance tool setup
|
42 |
+
from: https://github.com/google-research/google-research/tree/master/frechet_audio_distance
|
43 |
+
We provide the instructions below as a reference, but we do not guarantee further support
|
44 |
+
for the frechet_audio_distance installation. This was tested with python 3.10, cuda 11.8, tensorflow 2.12.0.
|
45 |
+
|
46 |
+
We recommend installing the frechet_audio_distance library in a dedicated env (e.g. conda).
|
47 |
+
|
48 |
+
1. Get the code and models following the repository instructions. We used the steps below:
|
49 |
+
git clone git@github.com:google-research/google-research.git
|
50 |
+
git clone git@github.com:tensorflow/models.git
|
51 |
+
mkdir google-research/tensorflow_models
|
52 |
+
touch google-research/tensorflow_models/__init__.py
|
53 |
+
cp -r models/research/audioset google-research/tensorflow_models/
|
54 |
+
touch google-research/tensorflow_models/audioset/__init__.py
|
55 |
+
echo "from .vggish import mel_features, vggish_params, vggish_slim" > \
|
56 |
+
google-research/tensorflow_models/audioset/__init__.py
|
57 |
+
# we can now remove the tensorflow models repository
|
58 |
+
# rm -r models
|
59 |
+
cd google-research
|
60 |
+
Follow the instructions to download the vggish checkpoint. AudioCraft base configuration
|
61 |
+
assumes it is placed in the AudioCraft reference dir.
|
62 |
+
|
63 |
+
Note that we apply the following changes for the code to work with TensorFlow 2.X and Python 3:
|
64 |
+
- Update xrange for range in:
|
65 |
+
https://github.com/google-research/google-research/blob/master/frechet_audio_distance/audioset_model.py
|
66 |
+
- Update `tf_record = tf.python_io.tf_record_iterator(filename).next()` to
|
67 |
+
`tf_record = tf.python_io.tf_record_iterator(filename).__next__()` in
|
68 |
+
https://github.com/google-research/google-research/blob/master/frechet_audio_distance/fad_utils.py
|
69 |
+
- Update `import vggish_params as params` to `from . import vggish_params as params` in:
|
70 |
+
https://github.com/tensorflow/models/blob/master/research/audioset/vggish/vggish_slim.py
|
71 |
+
- Add flag to provide a given batch size for running the AudioSet model in:
|
72 |
+
https://github.com/google-research/google-research/blob/master/frechet_audio_distance/create_embeddings_main.py
|
73 |
+
```
|
74 |
+
flags.DEFINE_integer('batch_size', 64,
|
75 |
+
'Number of samples in the batch for AudioSet model.')
|
76 |
+
```
|
77 |
+
Ensure you pass the flag to the create_embeddings_beam.create_pipeline function, adding:
|
78 |
+
`batch_size=FLAGS.batch_size` to the provided parameters.
|
79 |
+
|
80 |
+
2. Follow instructions for the library installation and a valid TensorFlow installation
|
81 |
+
```
|
82 |
+
# e.g. instructions from: https://www.tensorflow.org/install/pip
|
83 |
+
conda install -c conda-forge cudatoolkit=11.8.0
|
84 |
+
python3 -m pip install nvidia-cudnn-cu11==8.6.0.163 tensorflow==2.12.*
|
85 |
+
mkdir -p $CONDA_PREFIX/etc/conda/activate.d
|
86 |
+
echo 'CUDNN_PATH=$(dirname $(python -c "import nvidia.cudnn;print(nvidia.cudnn.__file__)"))' \
|
87 |
+
>> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
|
88 |
+
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/:$CUDNN_PATH/lib' \
|
89 |
+
>> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
|
90 |
+
source $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
|
91 |
+
# Verify install: on a machine with GPU device
|
92 |
+
python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"
|
93 |
+
```
|
94 |
+
|
95 |
+
Now install frechet_audio_distance required dependencies:
|
96 |
+
```
|
97 |
+
# We assume we already have TensorFlow installed from the above steps
|
98 |
+
pip install apache-beam numpy scipy tf_slim
|
99 |
+
```
|
100 |
+
|
101 |
+
Finally, follow remaining library instructions to ensure you have a working frechet_audio_distance setup
|
102 |
+
(you may want to specify --model_ckpt flag pointing to the model's path).
|
103 |
+
|
104 |
+
3. AudioCraft's FrechetAudioDistanceMetric requires 2 environment variables pointing to the python executable
|
105 |
+
and Tensorflow library path from the above installation steps:
|
106 |
+
export TF_PYTHON_EXE="<PATH_TO_THE_ENV_PYTHON_BINARY>"
|
107 |
+
export TF_LIBRARY_PATH="<PATH_TO_THE_ENV_CUDNN_LIBRARY>"
|
108 |
+
|
109 |
+
e.g. assuming we have installed everything in a dedicated conda env
|
110 |
+
with python 3.10 that is currently active:
|
111 |
+
export TF_PYTHON_EXE="$CONDA_PREFIX/bin/python"
|
112 |
+
export TF_LIBRARY_PATH="$CONDA_PREFIX/lib/python3.10/site-packages/nvidia/cudnn/lib"
|
113 |
+
|
114 |
+
Finally you may want to export the following variable:
|
115 |
+
export TF_FORCE_GPU_ALLOW_GROWTH=true
|
116 |
+
See: https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth
|
117 |
+
|
118 |
+
You can save those environment variables in your training conda env, when currently active:
|
119 |
+
`$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh`
|
120 |
+
e.g. assuming the env with TensorFlow and frechet_audio_distance install is named ac_eval,
|
121 |
+
and the training conda env is named audiocraft:
|
122 |
+
```
|
123 |
+
# activate training env
|
124 |
+
conda activate audiocraft
|
125 |
+
# get path to all envs
|
126 |
+
CONDA_ENV_DIR=$(dirname $CONDA_PREFIX)
|
127 |
+
# export pointers to evaluation env for using TensorFlow in FrechetAudioDistanceMetric
|
128 |
+
touch $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
|
129 |
+
echo 'export TF_PYTHON_EXE="$CONDA_ENV_DIR/ac_eval/bin/python"' >> \
|
130 |
+
$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
|
131 |
+
echo 'export TF_LIBRARY_PATH="$CONDA_ENV_DIR/ac_eval/lib/python3.10/site-packages/nvidia/cudnn/lib"' >> \
|
132 |
+
$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
|
133 |
+
# optionally:
|
134 |
+
echo 'export TF_FORCE_GPU_ALLOW_GROWTH=true' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
|
135 |
+
# you may need to reactivate the audiocraft env for this to take effect
|
136 |
+
```
|
137 |
+
|
138 |
+
Args:
|
139 |
+
bin (Path or str): Path to installed frechet audio distance code.
|
140 |
+
model_path (Path or str): Path to Tensorflow checkpoint for the model
|
141 |
+
used to compute statistics over the embedding beams.
|
142 |
+
format (str): Audio format used to save files.
|
143 |
+
log_folder (Path or str, optional): Path where to write process logs.
|
144 |
+
"""
|
145 |
+
def __init__(self, bin: tp.Union[Path, str], model_path: tp.Union[Path, str],
|
146 |
+
format: str = "wav", batch_size: tp.Optional[int] = None,
|
147 |
+
log_folder: tp.Optional[tp.Union[Path, str]] = None):
|
148 |
+
super().__init__()
|
149 |
+
self.model_sample_rate = VGGISH_SAMPLE_RATE
|
150 |
+
self.model_channels = VGGISH_CHANNELS
|
151 |
+
self.model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
|
152 |
+
assert Path(self.model_path).exists(), f"Could not find provided model checkpoint path at: {self.model_path}"
|
153 |
+
self.format = format
|
154 |
+
self.batch_size = batch_size
|
155 |
+
self.bin = bin
|
156 |
+
self.tf_env = {"PYTHONPATH": str(self.bin)}
|
157 |
+
self.python_path = os.environ.get('TF_PYTHON_EXE') or 'python'
|
158 |
+
logger.info("Python exe for TF is %s", self.python_path)
|
159 |
+
if 'TF_LIBRARY_PATH' in os.environ:
|
160 |
+
self.tf_env['LD_LIBRARY_PATH'] = os.environ['TF_LIBRARY_PATH']
|
161 |
+
if 'TF_FORCE_GPU_ALLOW_GROWTH' in os.environ:
|
162 |
+
self.tf_env['TF_FORCE_GPU_ALLOW_GROWTH'] = os.environ['TF_FORCE_GPU_ALLOW_GROWTH']
|
163 |
+
logger.info("Env for TF is %r", self.tf_env)
|
164 |
+
self.reset(log_folder)
|
165 |
+
self.add_state("total_files", default=torch.tensor(0.), dist_reduce_fx="sum")
|
166 |
+
|
167 |
+
def reset(self, log_folder: tp.Optional[tp.Union[Path, str]] = None):
|
168 |
+
"""Reset torchmetrics.Metrics state."""
|
169 |
+
log_folder = Path(log_folder or tempfile.mkdtemp())
|
170 |
+
self.tmp_dir = log_folder / 'fad'
|
171 |
+
self.tmp_dir.mkdir(exist_ok=True)
|
172 |
+
self.samples_tests_dir = self.tmp_dir / 'tests'
|
173 |
+
self.samples_tests_dir.mkdir(exist_ok=True)
|
174 |
+
self.samples_background_dir = self.tmp_dir / 'background'
|
175 |
+
self.samples_background_dir.mkdir(exist_ok=True)
|
176 |
+
self.manifest_tests = self.tmp_dir / 'files_tests.cvs'
|
177 |
+
self.manifest_background = self.tmp_dir / 'files_background.cvs'
|
178 |
+
self.stats_tests_dir = self.tmp_dir / 'stats_tests'
|
179 |
+
self.stats_background_dir = self.tmp_dir / 'stats_background'
|
180 |
+
self.counter = 0
|
181 |
+
|
182 |
+
def update(self, preds: torch.Tensor, targets: torch.Tensor,
|
183 |
+
sizes: torch.Tensor, sample_rates: torch.Tensor,
|
184 |
+
stems: tp.Optional[tp.List[str]] = None):
|
185 |
+
"""Update torchmetrics.Metrics by saving the audio and updating the manifest file."""
|
186 |
+
assert preds.shape == targets.shape, f"preds={preds.shape} != targets={targets.shape}"
|
187 |
+
num_samples = preds.shape[0]
|
188 |
+
assert num_samples == sizes.size(0) and num_samples == sample_rates.size(0)
|
189 |
+
assert stems is None or num_samples == len(set(stems))
|
190 |
+
for i in range(num_samples):
|
191 |
+
self.total_files += 1 # type: ignore
|
192 |
+
self.counter += 1
|
193 |
+
wav_len = int(sizes[i].item())
|
194 |
+
sample_rate = int(sample_rates[i].item())
|
195 |
+
pred_wav = preds[i]
|
196 |
+
target_wav = targets[i]
|
197 |
+
pred_wav = pred_wav[..., :wav_len]
|
198 |
+
target_wav = target_wav[..., :wav_len]
|
199 |
+
stem_name = stems[i] if stems is not None else f'sample_{self.counter}_{flashy.distrib.rank()}'
|
200 |
+
# dump audio files
|
201 |
+
try:
|
202 |
+
pred_wav = convert_audio(
|
203 |
+
pred_wav.unsqueeze(0), from_rate=sample_rate,
|
204 |
+
to_rate=self.model_sample_rate, to_channels=1).squeeze(0)
|
205 |
+
audio_write(
|
206 |
+
self.samples_tests_dir / stem_name, pred_wav, sample_rate=self.model_sample_rate,
|
207 |
+
format=self.format, strategy="peak")
|
208 |
+
except Exception as e:
|
209 |
+
logger.error(f"Exception occured when saving tests files for FAD computation: {repr(e)} - {e}")
|
210 |
+
try:
|
211 |
+
# for the ground truth audio, we enforce the 'peak' strategy to avoid modifying
|
212 |
+
# the original audio when writing it
|
213 |
+
target_wav = convert_audio(
|
214 |
+
target_wav.unsqueeze(0), from_rate=sample_rate,
|
215 |
+
to_rate=self.model_sample_rate, to_channels=1).squeeze(0)
|
216 |
+
audio_write(
|
217 |
+
self.samples_background_dir / stem_name, target_wav, sample_rate=self.model_sample_rate,
|
218 |
+
format=self.format, strategy="peak")
|
219 |
+
except Exception as e:
|
220 |
+
logger.error(f"Exception occured when saving background files for FAD computation: {repr(e)} - {e}")
|
221 |
+
|
222 |
+
def _get_samples_name(self, is_background: bool):
|
223 |
+
return 'background' if is_background else 'tests'
|
224 |
+
|
225 |
+
def _create_embedding_beams(self, is_background: bool, gpu_index: tp.Optional[int] = None):
|
226 |
+
if is_background:
|
227 |
+
input_samples_dir = self.samples_background_dir
|
228 |
+
input_filename = self.manifest_background
|
229 |
+
stats_name = self.stats_background_dir
|
230 |
+
else:
|
231 |
+
input_samples_dir = self.samples_tests_dir
|
232 |
+
input_filename = self.manifest_tests
|
233 |
+
stats_name = self.stats_tests_dir
|
234 |
+
beams_name = self._get_samples_name(is_background)
|
235 |
+
log_file = self.tmp_dir / f'fad_logs_create_beams_{beams_name}.log'
|
236 |
+
|
237 |
+
logger.info(f"Scanning samples folder to fetch list of files: {input_samples_dir}")
|
238 |
+
with open(input_filename, "w") as fout:
|
239 |
+
for path in Path(input_samples_dir).glob(f"*.{self.format}"):
|
240 |
+
fout.write(f"{str(path)}\n")
|
241 |
+
|
242 |
+
cmd = [
|
243 |
+
self.python_path, "-m",
|
244 |
+
"frechet_audio_distance.create_embeddings_main",
|
245 |
+
"--model_ckpt", f"{self.model_path}",
|
246 |
+
"--input_files", f"{str(input_filename)}",
|
247 |
+
"--stats", f"{str(stats_name)}",
|
248 |
+
]
|
249 |
+
if self.batch_size is not None:
|
250 |
+
cmd += ["--batch_size", str(self.batch_size)]
|
251 |
+
logger.info(f"Launching frechet_audio_distance embeddings main method: {' '.join(cmd)} on {beams_name}")
|
252 |
+
env = os.environ
|
253 |
+
if gpu_index is not None:
|
254 |
+
env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
|
255 |
+
process = subprocess.Popen(
|
256 |
+
cmd, stdout=open(log_file, "w"), env={**env, **self.tf_env}, stderr=subprocess.STDOUT)
|
257 |
+
return process, log_file
|
258 |
+
|
259 |
+
def _compute_fad_score(self, gpu_index: tp.Optional[int] = None):
|
260 |
+
cmd = [
|
261 |
+
self.python_path, "-m", "frechet_audio_distance.compute_fad",
|
262 |
+
"--test_stats", f"{str(self.stats_tests_dir)}",
|
263 |
+
"--background_stats", f"{str(self.stats_background_dir)}",
|
264 |
+
]
|
265 |
+
logger.info(f"Launching frechet_audio_distance compute fad method: {' '.join(cmd)}")
|
266 |
+
env = os.environ
|
267 |
+
if gpu_index is not None:
|
268 |
+
env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
|
269 |
+
result = subprocess.run(cmd, env={**env, **self.tf_env}, capture_output=True)
|
270 |
+
if result.returncode:
|
271 |
+
logger.error(
|
272 |
+
"Error with FAD computation from stats: \n %s \n %s",
|
273 |
+
result.stdout.decode(), result.stderr.decode()
|
274 |
+
)
|
275 |
+
raise RuntimeError("Error while executing FAD computation from stats")
|
276 |
+
try:
|
277 |
+
# result is "FAD: (d+).(d+)" hence we remove the prefix with (d+) being one digit or more
|
278 |
+
fad_score = float(result.stdout[4:])
|
279 |
+
return fad_score
|
280 |
+
except Exception as e:
|
281 |
+
raise RuntimeError(f"Error parsing FAD score from command stdout: {e}")
|
282 |
+
|
283 |
+
def _log_process_result(self, returncode: int, log_file: tp.Union[Path, str], is_background: bool) -> None:
|
284 |
+
beams_name = self._get_samples_name(is_background)
|
285 |
+
if returncode:
|
286 |
+
with open(log_file, "r") as f:
|
287 |
+
error_log = f.read()
|
288 |
+
logger.error(error_log)
|
289 |
+
os._exit(1)
|
290 |
+
else:
|
291 |
+
logger.info(f"Successfully computed embedding beams on {beams_name} samples.")
|
292 |
+
|
293 |
+
def _parallel_create_embedding_beams(self, num_of_gpus: int):
|
294 |
+
assert num_of_gpus > 0
|
295 |
+
logger.info("Creating embeddings beams in a parallel manner on different GPUs")
|
296 |
+
tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False, gpu_index=0)
|
297 |
+
bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True, gpu_index=1)
|
298 |
+
tests_beams_code = tests_beams_process.wait()
|
299 |
+
bg_beams_code = bg_beams_process.wait()
|
300 |
+
self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False)
|
301 |
+
self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True)
|
302 |
+
|
303 |
+
def _sequential_create_embedding_beams(self):
|
304 |
+
logger.info("Creating embeddings beams in a sequential manner")
|
305 |
+
tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False)
|
306 |
+
tests_beams_code = tests_beams_process.wait()
|
307 |
+
self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False)
|
308 |
+
bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True)
|
309 |
+
bg_beams_code = bg_beams_process.wait()
|
310 |
+
self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True)
|
311 |
+
|
312 |
+
@flashy.distrib.rank_zero_only
|
313 |
+
def _local_compute_frechet_audio_distance(self):
|
314 |
+
"""Compute Frechet Audio Distance score calling TensorFlow API."""
|
315 |
+
num_of_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
|
316 |
+
if num_of_gpus > 1:
|
317 |
+
self._parallel_create_embedding_beams(num_of_gpus)
|
318 |
+
else:
|
319 |
+
self._sequential_create_embedding_beams()
|
320 |
+
fad_score = self._compute_fad_score(gpu_index=0)
|
321 |
+
return fad_score
|
322 |
+
|
323 |
+
def compute(self) -> float:
|
324 |
+
"""Compute metrics."""
|
325 |
+
assert self.total_files.item() > 0, "No files dumped for FAD computation!" # type: ignore
|
326 |
+
fad_score = self._local_compute_frechet_audio_distance()
|
327 |
+
logger.warning(f"FAD score = {fad_score}")
|
328 |
+
fad_score = flashy.distrib.broadcast_object(fad_score, src=0)
|
329 |
+
return fad_score
|
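The docstring above gives the closed-form distance `d^2 = |mu_x - mu_y|^2 + Tr(sigma_x + sigma_y - 2 * sqrt(sigma_x * sigma_y))` between two Gaussians. The sketch below only illustrates that formula on toy embeddings; the actual metric computes embeddings and statistics through the external `frechet_audio_distance` tooling it shells out to.

```python
import torch

def frechet_distance(mu_x, sigma_x, mu_y, sigma_y):
    """d^2 = |mu_x - mu_y|^2 + Tr(sigma_x) + Tr(sigma_y) - 2 * Tr(sqrt(sigma_x @ sigma_y))."""
    diff = mu_x - mu_y
    # Tr(sqrt(A)) equals the sum of the square roots of A's eigenvalues
    eigvals = torch.linalg.eigvals(sigma_x @ sigma_y)
    covmean_trace = eigvals.sqrt().real.sum()
    return diff.dot(diff) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2 * covmean_trace

# toy embeddings standing in for "tests" and "background" VGGish features
x = torch.randn(2000, 8)
y = torch.randn(2000, 8) + 0.25
print(frechet_distance(x.mean(0), torch.cov(x.T), y.mean(0), torch.cov(y.T)).item())
```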
audiocraft/metrics/kld.py
ADDED
@@ -0,0 +1,220 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import contextlib
|
8 |
+
from functools import partial
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import typing as tp
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torchmetrics
|
15 |
+
|
16 |
+
from ..data.audio_utils import convert_audio
|
17 |
+
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class _patch_passt_stft:
|
23 |
+
"""Decorator to patch torch.stft in PaSST."""
|
24 |
+
def __init__(self):
|
25 |
+
self.old_stft = torch.stft
|
26 |
+
|
27 |
+
def __enter__(self):
|
28 |
+
# return_complex is a mandatory parameter in latest torch versions
|
29 |
+
# torch is throwing RuntimeErrors when not set
|
30 |
+
torch.stft = partial(torch.stft, return_complex=False)
|
31 |
+
|
32 |
+
def __exit__(self, *exc):
|
33 |
+
torch.stft = self.old_stft
|
34 |
+
|
35 |
+
|
36 |
+
def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
|
37 |
+
"""Computes the elementwise KL-Divergence loss between probability distributions
|
38 |
+
from generated samples and target samples.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
pred_probs (torch.Tensor): Probabilities for each label obtained
|
42 |
+
from a classifier on generated audio. Expected shape is [B, num_classes].
|
43 |
+
target_probs (torch.Tensor): Probabilities for each label obtained
|
44 |
+
from a classifier on target audio. Expected shape is [B, num_classes].
|
45 |
+
epsilon (float): Epsilon value.
|
46 |
+
Returns:
|
47 |
+
kld (torch.Tensor): KLD loss between each generated sample and target pair.
|
48 |
+
"""
|
49 |
+
kl_div = torch.nn.functional.kl_div((pred_probs + epsilon).log(), target_probs, reduction="none")
|
50 |
+
return kl_div.sum(-1)
|
51 |
+
|
52 |
+
|
53 |
+
class KLDivergenceMetric(torchmetrics.Metric):
|
54 |
+
"""Base implementation for KL Divergence metric.
|
55 |
+
|
56 |
+
The KL divergence is measured between probability distributions
|
57 |
+
of class predictions returned by a pre-trained audio classification model.
|
58 |
+
When the KL-divergence is low, the generated audio is expected to
|
59 |
+
have similar acoustic characteristics as the reference audio,
|
60 |
+
according to the classifier.
|
61 |
+
"""
|
62 |
+
def __init__(self):
|
63 |
+
super().__init__()
|
64 |
+
self.add_state("kld_pq_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
|
65 |
+
self.add_state("kld_qp_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
|
66 |
+
self.add_state("kld_all_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
|
67 |
+
self.add_state("weight", default=torch.tensor(0), dist_reduce_fx="sum")
|
68 |
+
|
69 |
+
def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
|
70 |
+
sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
|
71 |
+
"""Get model output given provided input tensor.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
x (torch.Tensor): Input audio tensor of shape [B, C, T].
|
75 |
+
sizes (torch.Tensor): Actual audio sample length, of shape [B].
|
76 |
+
sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
|
77 |
+
Returns:
|
78 |
+
probs (torch.Tensor): Probabilities over labels, of shape [B, num_classes].
|
79 |
+
"""
|
80 |
+
raise NotImplementedError("implement method to extract label distributions from the model.")
|
81 |
+
|
82 |
+
def update(self, preds: torch.Tensor, targets: torch.Tensor,
|
83 |
+
sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
|
84 |
+
"""Calculates running KL-Divergence loss between batches of audio
|
85 |
+
preds (generated) and targets (ground-truth).
|
86 |
+
Args:
|
87 |
+
preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T].
|
88 |
+
targets (torch.Tensor): Target samples to compare against, of shape [B, C, T].
|
89 |
+
sizes (torch.Tensor): Actual audio sample length, of shape [B].
|
90 |
+
sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
|
91 |
+
"""
|
92 |
+
assert preds.shape == targets.shape
|
93 |
+
assert preds.size(0) > 0, "Cannot update the loss with empty tensors"
|
94 |
+
preds_probs = self._get_label_distribution(preds, sizes, sample_rates)
|
95 |
+
targets_probs = self._get_label_distribution(targets, sizes, sample_rates)
|
96 |
+
if preds_probs is not None and targets_probs is not None:
|
97 |
+
assert preds_probs.shape == targets_probs.shape
|
98 |
+
kld_scores = kl_divergence(preds_probs, targets_probs)
|
99 |
+
assert not torch.isnan(kld_scores).any(), "kld_scores contains NaN value(s)!"
|
100 |
+
self.kld_pq_sum += torch.sum(kld_scores)
|
101 |
+
kld_qp_scores = kl_divergence(targets_probs, preds_probs)
|
102 |
+
self.kld_qp_sum += torch.sum(kld_qp_scores)
|
103 |
+
self.weight += torch.tensor(kld_scores.size(0))
|
104 |
+
|
105 |
+
def compute(self) -> dict:
|
106 |
+
"""Computes KL-Divergence across all evaluated pred/target pairs."""
|
107 |
+
weight: float = float(self.weight.item()) # type: ignore
|
108 |
+
assert weight > 0, "Unable to compute with total number of comparisons <= 0"
|
109 |
+
logger.info(f"Computing KL divergence on a total of {weight} samples")
|
110 |
+
kld_pq = self.kld_pq_sum.item() / weight # type: ignore
|
111 |
+
kld_qp = self.kld_qp_sum.item() / weight # type: ignore
|
112 |
+
kld_both = kld_pq + kld_qp
|
113 |
+
return {'kld': kld_pq, 'kld_pq': kld_pq, 'kld_qp': kld_qp, 'kld_both': kld_both}
|
114 |
+
|
115 |
+
|
116 |
+
class PasstKLDivergenceMetric(KLDivergenceMetric):
|
117 |
+
"""KL-Divergence metric based on pre-trained PASST classifier on AudioSet.
|
118 |
+
|
119 |
+
From: PaSST: Efficient Training of Audio Transformers with Patchout
|
120 |
+
Paper: https://arxiv.org/abs/2110.05069
|
121 |
+
Implementation: https://github.com/kkoutini/PaSST
|
122 |
+
|
123 |
+
Follow instructions from the github repo:
|
124 |
+
```
|
125 |
+
pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'
|
126 |
+
```
|
127 |
+
|
128 |
+
Args:
|
129 |
+
pretrained_length (float, optional): Audio duration used for the pretrained model.
|
130 |
+
"""
|
131 |
+
def __init__(self, pretrained_length: tp.Optional[float] = None):
|
132 |
+
super().__init__()
|
133 |
+
self._initialize_model(pretrained_length)
|
134 |
+
|
135 |
+
def _initialize_model(self, pretrained_length: tp.Optional[float] = None):
|
136 |
+
"""Initialize underlying PaSST audio classifier."""
|
137 |
+
model, sr, max_frames, min_frames = self._load_base_model(pretrained_length)
|
138 |
+
self.min_input_frames = min_frames
|
139 |
+
self.max_input_frames = max_frames
|
140 |
+
self.model_sample_rate = sr
|
141 |
+
self.model = model
|
142 |
+
self.model.eval()
|
143 |
+
self.model.to(self.device)
|
144 |
+
|
145 |
+
def _load_base_model(self, pretrained_length: tp.Optional[float]):
|
146 |
+
"""Load pretrained model from PaSST."""
|
147 |
+
try:
|
148 |
+
if pretrained_length == 30:
|
149 |
+
from hear21passt.base30sec import get_basic_model # type: ignore
|
150 |
+
max_duration = 30
|
151 |
+
elif pretrained_length == 20:
|
152 |
+
from hear21passt.base20sec import get_basic_model # type: ignore
|
153 |
+
max_duration = 20
|
154 |
+
else:
|
155 |
+
from hear21passt.base import get_basic_model # type: ignore
|
156 |
+
# Original PASST was trained on AudioSet with 10s-long audio samples
|
157 |
+
max_duration = 10
|
158 |
+
min_duration = 0.15
|
159 |
+
min_duration = 0.15
|
160 |
+
except ModuleNotFoundError:
|
161 |
+
raise ModuleNotFoundError(
|
162 |
+
"Please install hear21passt to compute KL divergence: ",
|
163 |
+
"pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'"
|
164 |
+
)
|
165 |
+
model_sample_rate = 32_000
|
166 |
+
max_input_frames = int(max_duration * model_sample_rate)
|
167 |
+
min_input_frames = int(min_duration * model_sample_rate)
|
168 |
+
with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f):
|
169 |
+
model = get_basic_model(mode='logits')
|
170 |
+
return model, model_sample_rate, max_input_frames, min_input_frames
|
171 |
+
|
172 |
+
def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len: int) -> tp.List[torch.Tensor]:
|
173 |
+
"""Process audio to feed to the pretrained model."""
|
174 |
+
wav = wav.unsqueeze(0)
|
175 |
+
wav = wav[..., :wav_len]
|
176 |
+
wav = convert_audio(wav, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1)
|
177 |
+
wav = wav.squeeze(0)
|
178 |
+
# we don't pad but return a list of audio segments as this otherwise affects the KLD computation
|
179 |
+
segments = torch.split(wav, self.max_input_frames, dim=-1)
|
180 |
+
valid_segments = []
|
181 |
+
for s in segments:
|
182 |
+
# ignoring too small segments that are breaking the model inference
|
183 |
+
if s.size(-1) > self.min_input_frames:
|
184 |
+
valid_segments.append(s)
|
185 |
+
return [s[None] for s in valid_segments]
|
186 |
+
|
187 |
+
def _get_model_preds(self, wav: torch.Tensor) -> torch.Tensor:
|
188 |
+
"""Run the pretrained model and get the predictions."""
|
189 |
+
assert wav.dim() == 3, f"Unexpected number of dims for preprocessed wav: {wav.shape}"
|
190 |
+
wav = wav.mean(dim=1)
|
191 |
+
# PaSST is printing a lot of garbage that we are not interested in
|
192 |
+
with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
|
193 |
+
with torch.no_grad(), _patch_passt_stft():
|
194 |
+
logits = self.model(wav.to(self.device))
|
195 |
+
probs = torch.softmax(logits, dim=-1)
|
196 |
+
return probs
|
197 |
+
|
198 |
+
def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
|
199 |
+
sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
|
200 |
+
"""Get model output given provided input tensor.
|
201 |
+
|
202 |
+
Args:
|
203 |
+
x (torch.Tensor): Input audio tensor of shape [B, C, T].
|
204 |
+
sizes (torch.Tensor): Actual audio sample length, of shape [B].
|
205 |
+
sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
|
206 |
+
Returns:
|
207 |
+
probs (torch.Tensor, optional): Probabilities over labels, of shape [B, num_classes].
|
208 |
+
"""
|
209 |
+
all_probs: tp.List[torch.Tensor] = []
|
210 |
+
for i, wav in enumerate(x):
|
211 |
+
sample_rate = int(sample_rates[i].item())
|
212 |
+
wav_len = int(sizes[i].item())
|
213 |
+
wav_segments = self._process_audio(wav, sample_rate, wav_len)
|
214 |
+
for segment in wav_segments:
|
215 |
+
probs = self._get_model_preds(segment).mean(dim=0)
|
216 |
+
all_probs.append(probs)
|
217 |
+
if len(all_probs) > 0:
|
218 |
+
return torch.stack(all_probs, dim=0)
|
219 |
+
else:
|
220 |
+
return None
|
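The `kl_divergence` helper above only needs classifier probabilities, so it can be exercised directly; the PaSST-based metric additionally requires `hear21passt` as noted in its docstring. A quick sketch:

```python
import torch
from audiocraft.metrics.kld import kl_divergence

# toy classifier outputs for 3 generated/reference pairs over 5 classes
pred_probs = torch.softmax(torch.randn(3, 5), dim=-1)
target_probs = torch.softmax(torch.randn(3, 5), dim=-1)
print(kl_divergence(pred_probs, target_probs))   # one KL value per pair, shape [3]
```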
audiocraft/metrics/rvm.py
ADDED
@@ -0,0 +1,110 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import typing as tp
|
8 |
+
import torch
|
9 |
+
from torch import nn
|
10 |
+
import torchaudio
|
11 |
+
|
12 |
+
|
13 |
+
def db_to_scale(volume: tp.Union[float, torch.Tensor]):
|
14 |
+
return 10 ** (volume / 20)
|
15 |
+
|
16 |
+
|
17 |
+
def scale_to_db(scale: torch.Tensor, min_volume: float = -120):
|
18 |
+
min_scale = db_to_scale(min_volume)
|
19 |
+
return 20 * torch.log10(scale.clamp(min=min_scale))
|
20 |
+
|
21 |
+
|
22 |
+
class RelativeVolumeMel(nn.Module):
|
23 |
+
"""Relative volume melspectrogram measure.
|
24 |
+
|
25 |
+
Computes a measure of distance over two mel spectrograms that is interpretable in terms
|
26 |
+
of decibels. Given `x_ref` and `x_est` two waveforms of shape `[*, T]`, it will
|
27 |
+
first renormalize both by the RMS volume of `x_ref`.
|
28 |
+
|
29 |
+
..Warning:: This class returns the volume of the distortion at the spectrogram level,
|
30 |
+
e.g. low negative values reflects lower distortion levels. For a SNR (like reported
|
31 |
+
in the MultiBandDiffusion paper), just take `-rvm`.
|
32 |
+
|
33 |
+
Then it computes the mel spectrogram `z_ref` and `z_est` and compute volume of the difference
|
34 |
+
relative to the volume of `z_ref` for each time-frequency bin. It further adds some limits, e.g.
|
35 |
+
clamping the values between -25 and 25 dB (controlled by `min_relative_volume` and `max_relative_volume`)
|
36 |
+
with the goal of avoiding the loss being dominated by parts where the reference is almost silent.
|
37 |
+
Indeed, volumes in dB can take unbounded values both towards -oo and +oo, which can make the final
|
38 |
+
average metric harder to interpret. Besides, anything below -30 dB of attenuation would sound extremely
|
39 |
+
good (for a neural network output, although sound engineers typically aim for much lower attenuations).
|
40 |
+
Similarly, anything above +30 dB would just be completely missing the target, and there is no point
|
41 |
+
in measuring by exactly how much it missed it. -25, 25 is a more conservative range, but also more
|
42 |
+
in line with what neural nets currently can achieve.
|
43 |
+
|
44 |
+
For instance, a Relative Volume Mel (RVM) score of -10 dB means that on average, the delta between
|
45 |
+
the target and reference mel-spec is 10 dB lower than the reference mel-spec value.
|
46 |
+
|
47 |
+
The metric can be aggregated over a given frequency band in order to have different insights for
|
48 |
+
different region of the spectrum. `num_aggregated_bands` controls the number of bands.
|
49 |
+
|
50 |
+
..Warning:: While this function is optimized for interpretability, nothing was done to ensure it
|
51 |
+
is numerically stable when computing its gradient. We thus advise against using it as a training loss.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
sample_rate (int): Sample rate of the input audio.
|
55 |
+
n_mels (int): Number of mel bands to use.
|
56 |
+
n_fft (int): Number of frequency bins for the STFT.
|
57 |
+
hop_length (int): Hop length of the STFT and the mel-spectrogram.
|
58 |
+
min_relative_volume (float): The error `z_ref - z_est` volume is given relative to
|
59 |
+
the volume of `z_ref`. If error is smaller than -25 dB of `z_ref`, then it is clamped.
|
60 |
+
max_relative_volume (float): Same as `min_relative_volume` but clamping if the error is larger than that.
|
61 |
+
max_initial_gain (float): When rescaling the audio at the very beginning, we will limit the gain
|
62 |
+
to that amount, to avoid rescaling near silence. Given in dB.
|
63 |
+
min_activity_volume (float): When computing the reference level from `z_ref`, will clamp low volume
|
64 |
+
bins to that amount. This is effectively our "zero" level for the reference mel-spectrogram,
|
65 |
+
and anything below that will be considered equally.
|
66 |
+
num_aggregated_bands (int): Number of bands to keep when computing the average RVM value.
|
67 |
+
For instance, a value of 3 would give 3 scores, roughly for low, mid and high freqs.
|
68 |
+
"""
|
69 |
+
def __init__(self, sample_rate: int = 24000, n_mels: int = 80, n_fft: int = 512,
|
70 |
+
hop_length: int = 128, min_relative_volume: float = -25,
|
71 |
+
max_relative_volume: float = 25, max_initial_gain: float = 25,
|
72 |
+
min_activity_volume: float = -25,
|
73 |
+
num_aggregated_bands: int = 4) -> None:
|
74 |
+
super().__init__()
|
75 |
+
self.melspec = torchaudio.transforms.MelSpectrogram(
|
76 |
+
n_mels=n_mels, n_fft=n_fft, hop_length=hop_length,
|
77 |
+
normalized=True, sample_rate=sample_rate, power=2)
|
78 |
+
self.min_relative_volume = min_relative_volume
|
79 |
+
self.max_relative_volume = max_relative_volume
|
80 |
+
self.max_initial_gain = max_initial_gain
|
81 |
+
self.min_activity_volume = min_activity_volume
|
82 |
+
self.num_aggregated_bands = num_aggregated_bands
|
83 |
+
|
84 |
+
def forward(self, estimate: torch.Tensor, ground_truth: torch.Tensor) -> tp.Dict[str, torch.Tensor]:
|
85 |
+
"""Compute RVM metric between estimate and reference samples.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
estimate (torch.Tensor): Estimate sample.
|
89 |
+
ground_truth (torch.Tensor): Reference sample.
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
dict[str, torch.Tensor]: Metrics with keys `rvm` for the overall average, and `rvm_{k}`
|
93 |
+
for the RVM over the k-th band (k=0..num_aggregated_bands - 1).
|
94 |
+
"""
|
95 |
+
min_scale = db_to_scale(-self.max_initial_gain)
|
96 |
+
std = ground_truth.pow(2).mean().sqrt().clamp(min=min_scale)
|
97 |
+
z_gt = self.melspec(ground_truth / std).sqrt()
|
98 |
+
z_est = self.melspec(estimate / std).sqrt()
|
99 |
+
|
100 |
+
delta = z_gt - z_est
|
101 |
+
ref_db = scale_to_db(z_gt, self.min_activity_volume)
|
102 |
+
delta_db = scale_to_db(delta.abs(), min_volume=-120)
|
103 |
+
relative_db = (delta_db - ref_db).clamp(self.min_relative_volume, self.max_relative_volume)
|
104 |
+
dims = list(range(relative_db.dim()))
|
105 |
+
dims.remove(dims[-2])
|
106 |
+
losses_per_band = relative_db.mean(dim=dims)
|
107 |
+
aggregated = [chunk.mean() for chunk in losses_per_band.chunk(self.num_aggregated_bands, dim=0)]
|
108 |
+
metrics = {f'rvm_{index}': value for index, value in enumerate(aggregated)}
|
109 |
+
metrics['rvm'] = losses_per_band.mean()
|
110 |
+
return metrics
|
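A usage sketch for the Relative Volume Mel measure above, assuming `torchaudio` is installed; lower (more negative) values mean less audible distortion relative to the reference.

```python
import torch
from audiocraft.metrics import RelativeVolumeMel

rvm = RelativeVolumeMel(sample_rate=24000, num_aggregated_bands=4)
reference = torch.randn(1, 24000)                            # ground-truth waveform, shape [*, T]
estimate = reference + 0.05 * torch.randn_like(reference)    # lightly distorted copy
scores = rvm(estimate, reference)
print({name: round(value.item(), 2) for name, value in scores.items()})  # 'rvm' plus per-band scores
```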
audiocraft/metrics/visqol.py
ADDED
@@ -0,0 +1,216 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import csv
+import json
+import logging
+from pathlib import Path
+import tempfile
+import typing as tp
+import subprocess
+import shutil
+
+import torch
+import torchaudio
+
+logger = logging.getLogger(__name__)
+
+
+class ViSQOL:
+    """ViSQOL wrapper to run ViSQOL from Python using a pre-installed binary.
+
+    To learn more about ViSQOL and how to build ViSQOL binary using bazel, please refer to the
+    instructions available in the open source repository: https://github.com/google/visqol
+
+    ViSQOL is capable of running in two modes:
+
+    Audio Mode:
+        When running in audio mode, input signals must have a 48kHz sample rate. Input should be resampled to 48kHz.
+        Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison.
+        Audio mode uses support vector regression, with the maximum range at ~4.75.
+
+    Speech Mode:
+        When running in speech mode, ViSQOL uses a wideband model. It therefore expects input sample rates of 16kHz.
+        Input should be resampled to 16kHz.
+        As part of the speech mode processing, a root mean square implementation for voice activity detection
+        is performed on the reference signal to determine what parts of the signal have voice activity and
+        should therefore be included in the comparison. The signal is normalized before performing the voice
+        activity detection.
+        Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison.
+        Speech mode is scaled to have a maximum MOS of 5.0 to match previous version behavior.
+
+    For more details, check the guidelines: https://github.com/google/visqol#general-guidelines-for-input
+
+    Args:
+        visqol_bin (str): Path to the ViSQOL binary.
+        mode (str): ViSQOL computation mode, expecting "audio" or "speech".
+        model (str): Name of the model to use for similarity to quality model.
+        debug (bool): Whether to also get debug metrics from ViSQOL or not.
+    """
+    SAMPLE_RATES_MODES = {"audio": 48_000, "speech": 16_000}
+    ALLOWED_SAMPLE_RATES = frozenset(SAMPLE_RATES_MODES.values())
+
+    def __init__(self, bin: tp.Union[Path, str], mode: str = "audio",
+                 model: str = "libsvm_nu_svr_model.txt", debug: bool = False):
+        assert bin is not None and Path(bin).exists(), f"Could not find ViSQOL binary in specified path: {bin}"
+        self.visqol_bin = str(bin)
+        self.visqol_mode = mode
+        self.target_sr = self._get_target_sr(self.visqol_mode)
+        self.model = model
+        self.debug = debug
+        assert Path(self.visqol_model).exists(), \
+            f"Could not find the specified model in ViSQOL install: {self.visqol_model}"
+
+    def _get_target_sr(self, mode: str) -> int:
+        # returns target sampling rate for the corresponding ViSQOL mode.
+        if mode not in ViSQOL.SAMPLE_RATES_MODES:
+            raise ValueError(
+                f"Unsupported mode! Allowed are: {', '.join(ViSQOL.SAMPLE_RATES_MODES.keys())}"
+            )
+        return ViSQOL.SAMPLE_RATES_MODES[mode]
+
+    def _prepare_files(
+        self, ref_sig: torch.Tensor, deg_sig: torch.Tensor, sr: int, target_sr: int, pad_with_silence: bool = False
+    ):
+        # prepare files for ViSQOL evaluation.
+        assert target_sr in ViSQOL.ALLOWED_SAMPLE_RATES
+        assert len(ref_sig) == len(deg_sig), (
+            "Expects same number of ref and degraded inputs",
+            f" but ref len {len(ref_sig)} != deg len {len(deg_sig)}"
+        )
+        # resample audio if needed
+        if sr != target_sr:
+            transform = torchaudio.transforms.Resample(sr, target_sr)
+            pad = int(0.5 * target_sr)
+            rs_ref = []
+            rs_deg = []
+            for i in range(len(ref_sig)):
+                rs_ref_i = transform(ref_sig[i])
+                rs_deg_i = transform(deg_sig[i])
+                if pad_with_silence:
+                    rs_ref_i = torch.nn.functional.pad(rs_ref_i, (pad, pad), mode='constant', value=0)
+                    rs_deg_i = torch.nn.functional.pad(rs_deg_i, (pad, pad), mode='constant', value=0)
+                rs_ref.append(rs_ref_i)
+                rs_deg.append(rs_deg_i)
+            ref_sig = torch.stack(rs_ref)
+            deg_sig = torch.stack(rs_deg)
+        # save audio chunks to tmp dir and create csv
+        tmp_dir = Path(tempfile.mkdtemp())
+        try:
+            tmp_input_csv_path = tmp_dir / "input.csv"
+            tmp_results_csv_path = tmp_dir / "results.csv"
+            tmp_debug_json_path = tmp_dir / "debug.json"
+            with open(tmp_input_csv_path, "w") as csv_file:
+                csv_writer = csv.writer(csv_file)
+                csv_writer.writerow(["reference", "degraded"])
+                for i in range(len(ref_sig)):
+                    tmp_ref_filename = tmp_dir / f"ref_{i}.wav"
+                    tmp_deg_filename = tmp_dir / f"deg_{i}.wav"
+                    torchaudio.save(
+                        tmp_ref_filename,
+                        torch.clamp(ref_sig[i], min=-0.99, max=0.99),
+                        sample_rate=target_sr,
+                        bits_per_sample=16,
+                        encoding="PCM_S"
+                    )
+                    torchaudio.save(
+                        tmp_deg_filename,
+                        torch.clamp(deg_sig[i], min=-0.99, max=0.99),
+                        sample_rate=target_sr,
+                        bits_per_sample=16,
+                        encoding="PCM_S"
+                    )
+                    csv_writer.writerow([str(tmp_ref_filename), str(tmp_deg_filename)])
+            return tmp_dir, tmp_input_csv_path, tmp_results_csv_path, tmp_debug_json_path
+        except Exception as e:
+            logger.error("Exception occurred when preparing files for ViSQOL: %s", e)
+            return tmp_dir, None, None, None
+
+    def _flush_files(self, tmp_dir: tp.Union[Path, str]):
+        # flush tmp files used to compute ViSQOL.
+        shutil.rmtree(str(tmp_dir))
+
+    def _collect_moslqo_score(self, results_csv_path: tp.Union[Path, str]) -> float:
+        # collect results for each evaluated pair and return averaged moslqo score.
+        with open(results_csv_path, "r") as csv_file:
+            reader = csv.DictReader(csv_file)
+            moslqo_scores = [float(row["moslqo"]) for row in reader]
+            if len(moslqo_scores) > 0:
+                return sum(moslqo_scores) / len(moslqo_scores)
+            else:
+                return 0.0
+
+    def _collect_debug_data(self, debug_json_path: tp.Union[Path, str]) -> dict:
+        # collect debug data for the visqol inference.
+        with open(debug_json_path, "r") as f:
+            data = json.load(f)
+        return data
+
+    @property
+    def visqol_model(self):
+        return f'{self.visqol_bin}/model/{self.model}'
+
+    def _run_visqol(
+        self,
+        input_csv_path: tp.Union[Path, str],
+        results_csv_path: tp.Union[Path, str],
+        debug_csv_path: tp.Optional[tp.Union[Path, str]],
+    ):
+        input_csv_path = str(input_csv_path)
+        results_csv_path = str(results_csv_path)
+        debug_csv_path = str(debug_csv_path)
+        cmd = [
+            f'{self.visqol_bin}/bazel-bin/visqol',
+            '--batch_input_csv', f'{input_csv_path}',
+            '--results_csv', f'{results_csv_path}'
+        ]
+        if debug_csv_path is not None:
+            cmd += ['--output_debug', f'{debug_csv_path}']
+        if self.visqol_mode == "speech":
+            cmd += ['--use_speech_mode']
+        cmd += ['--similarity_to_quality_model', f'{self.visqol_model}']
+        result = subprocess.run(cmd, capture_output=True)
+        if result.returncode:
+            logger.error("Error with visqol: \n %s \n %s", result.stdout.decode(), result.stderr.decode())
+            raise RuntimeError("Error while executing visqol")
+        result.check_returncode()
+
+    def __call__(
+        self,
+        ref_sig: torch.Tensor,
+        deg_sig: torch.Tensor,
+        sr: int,
+        pad_with_silence: bool = False,
+    ):
+        """Calculate the ViSQOL metric for a pair of audio signals at a given sample rate.
+        Args:
+            ref_sig (torch.Tensor): Reference signals as [B, C, T].
+            deg_sig (torch.Tensor): Degraded signals as [B, C, T].
+            sr (int): Sample rate of the two audio signals.
+            pad_with_silence (bool): Whether to pad the file with silences as recommended
+                in visqol guidelines (see: https://github.com/google/visqol#general-guidelines-for-input).
+        Returns:
+            float: The ViSQOL score or mean score for the batch.
+        """
+        logger.debug(f"Calculating visqol with mode={self.visqol_mode} on {len(ref_sig)} samples")
+        tmp_dir, input_csv, results_csv, debug_json = self._prepare_files(
+            ref_sig, deg_sig, sr, self.target_sr, pad_with_silence
+        )
+        try:
+            if input_csv and results_csv:
+                self._run_visqol(
+                    input_csv,
+                    results_csv,
+                    debug_json if self.debug else None,
+                )
+                mosqol = self._collect_moslqo_score(results_csv)
+                return mosqol
+            else:
+                raise RuntimeError("Something unexpected happened when running VISQOL!")
+        except Exception as e:
+            logger.error("Exception occurred when running ViSQOL: %s", e)
+        finally:
+            self._flush_files(tmp_dir)
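A minimal usage sketch for the wrapper above. The install path of the bazel-built ViSQOL binary is an assumption (use wherever you built google/visqol); inputs are resampled internally to the mode's target rate, so `sr` should simply match your tensors.

```python
# Sketch only: assumes google/visqol was built with bazel under /opt/visqol
# (path is an assumption) and ships the default libsvm_nu_svr_model.txt.
import torch
from audiocraft.metrics.visqol import ViSQOL

visqol = ViSQOL(bin='/opt/visqol', mode='audio')  # audio mode compares at 48 kHz
ref = torch.randn(2, 1, 32000)                    # [B, C, T] reference batch
deg = torch.randn(2, 1, 32000)                    # degraded / estimated batch

score = visqol(ref, deg, sr=32000, pad_with_silence=True)
print(score)  # mean MOS-LQO over the batch (None if the binary call failed)
```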
audiocraft/models/__init__.py
ADDED
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel.
+"""
+# flake8: noqa
+from . import builders, loaders
+from .encodec import (
+    CompressionModel, EncodecModel, DAC,
+    HFEncodecModel, HFEncodecCompressionModel)
+from .audiogen import AudioGen
+from .lm import LMModel
+from .multibanddiffusion import MultiBandDiffusion
+from .vidmuse import VidMuse
+from .unet import DiffusionUnet
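With this `__init__.py`, VidMuse is exposed at the package level alongside the stock audiocraft models; the snippet below only restates what the re-exports above already declare.

```python
# These imports are made available by the re-exports above.
from audiocraft.models import VidMuse, AudioGen, CompressionModel, MultiBandDiffusion

print(VidMuse)  # class defined in audiocraft/models/vidmuse.py
```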
audiocraft/models/audiogen.py
ADDED
@@ -0,0 +1,267 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Main model for using AudioGen. This will combine all the required components
|
9 |
+
and provide easy access to the generation API.
|
10 |
+
"""
|
11 |
+
|
12 |
+
import typing as tp
|
13 |
+
|
14 |
+
import torch
|
15 |
+
|
16 |
+
from .encodec import CompressionModel
|
17 |
+
from .lm import LMModel
|
18 |
+
from .builders import get_debug_compression_model, get_debug_lm_model
|
19 |
+
from .loaders import load_compression_model, load_lm_model
|
20 |
+
from ..data.audio_utils import convert_audio
|
21 |
+
from ..modules.conditioners import ConditioningAttributes
|
22 |
+
from ..utils.autocast import TorchAutocast
|
23 |
+
|
24 |
+
|
25 |
+
class AudioGen:
|
26 |
+
"""AudioGen main model with convenient generation API.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
name (str): name of the model.
|
30 |
+
compression_model (CompressionModel): Compression model
|
31 |
+
used to map audio to invertible discrete representations.
|
32 |
+
lm (LMModel): Language model over discrete representations.
|
33 |
+
max_duration (float, optional): maximum duration the model can produce,
|
34 |
+
otherwise, inferred from the training params.
|
35 |
+
"""
|
36 |
+
def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
|
37 |
+
max_duration: tp.Optional[float] = None):
|
38 |
+
self.name = name
|
39 |
+
self.compression_model = compression_model
|
40 |
+
self.lm = lm
|
41 |
+
# Just to be safe, let's put everything in eval mode.
|
42 |
+
self.compression_model.eval()
|
43 |
+
self.lm.eval()
|
44 |
+
|
45 |
+
if max_duration is None:
|
46 |
+
if hasattr(lm, 'cfg'):
|
47 |
+
max_duration = lm.cfg.dataset.segment_duration # type: ignore
|
48 |
+
else:
|
49 |
+
raise ValueError("You must provide max_duration when building directly AudioGen")
|
50 |
+
assert max_duration is not None
|
51 |
+
self.max_duration: float = max_duration
|
52 |
+
self.device = next(iter(lm.parameters())).device
|
53 |
+
self.generation_params: dict = {}
|
54 |
+
self.set_generation_params(duration=5) # 5 seconds by default
|
55 |
+
self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
|
56 |
+
if self.device.type == 'cpu':
|
57 |
+
self.autocast = TorchAutocast(enabled=False)
|
58 |
+
else:
|
59 |
+
self.autocast = TorchAutocast(
|
60 |
+
enabled=True, device_type=self.device.type, dtype=torch.float16)
|
61 |
+
|
62 |
+
@property
|
63 |
+
def frame_rate(self) -> float:
|
64 |
+
"""Roughly the number of AR steps per seconds."""
|
65 |
+
return self.compression_model.frame_rate
|
66 |
+
|
67 |
+
@property
|
68 |
+
def sample_rate(self) -> int:
|
69 |
+
"""Sample rate of the generated audio."""
|
70 |
+
return self.compression_model.sample_rate
|
71 |
+
|
72 |
+
@property
|
73 |
+
def audio_channels(self) -> int:
|
74 |
+
"""Audio channels of the generated audio."""
|
75 |
+
return self.compression_model.channels
|
76 |
+
|
77 |
+
@staticmethod
|
78 |
+
def get_pretrained(name: str = 'facebook/audiogen-medium', device=None):
|
79 |
+
"""Return pretrained model, we provide a single model for now:
|
80 |
+
- facebook/audiogen-medium (1.5B), text to sound,
|
81 |
+
# see: https://huggingface.co/facebook/audiogen-medium
|
82 |
+
"""
|
83 |
+
if device is None:
|
84 |
+
if torch.cuda.device_count():
|
85 |
+
device = 'cuda'
|
86 |
+
else:
|
87 |
+
device = 'cpu'
|
88 |
+
|
89 |
+
if name == 'debug':
|
90 |
+
# used only for unit tests
|
91 |
+
compression_model = get_debug_compression_model(device, sample_rate=16000)
|
92 |
+
lm = get_debug_lm_model(device)
|
93 |
+
return AudioGen(name, compression_model, lm, max_duration=10)
|
94 |
+
|
95 |
+
compression_model = load_compression_model(name, device=device)
|
96 |
+
lm = load_lm_model(name, device=device)
|
97 |
+
assert 'self_wav' not in lm.condition_provider.conditioners, \
|
98 |
+
"AudioGen do not support waveform conditioning for now"
|
99 |
+
return AudioGen(name, compression_model, lm)
|
100 |
+
|
101 |
+
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
|
102 |
+
top_p: float = 0.0, temperature: float = 1.0,
|
103 |
+
duration: float = 10.0, cfg_coef: float = 3.0,
|
104 |
+
two_step_cfg: bool = False, extend_stride: float = 2):
|
105 |
+
"""Set the generation parameters for AudioGen.
|
106 |
+
|
107 |
+
Args:
|
108 |
+
use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
|
109 |
+
top_k (int, optional): top_k used for sampling. Defaults to 250.
|
110 |
+
top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
|
111 |
+
temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
|
112 |
+
duration (float, optional): Duration of the generated waveform. Defaults to 10.0.
|
113 |
+
cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
|
114 |
+
two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
|
115 |
+
instead of batching together the two. This has some impact on how things
|
116 |
+
are padded but seems to have little impact in practice.
|
117 |
+
extend_stride: when doing extended generation (i.e. more than 10 seconds), by how much
|
118 |
+
should we extend the audio each time. Larger values will mean less context is
|
119 |
+
preserved, and shorter values will require extra computations.
|
120 |
+
"""
|
121 |
+
assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
|
122 |
+
self.extend_stride = extend_stride
|
123 |
+
self.duration = duration
|
124 |
+
self.generation_params = {
|
125 |
+
'use_sampling': use_sampling,
|
126 |
+
'temp': temperature,
|
127 |
+
'top_k': top_k,
|
128 |
+
'top_p': top_p,
|
129 |
+
'cfg_coef': cfg_coef,
|
130 |
+
'two_step_cfg': two_step_cfg,
|
131 |
+
}
|
132 |
+
|
133 |
+
def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
|
134 |
+
"""Override the default progress callback."""
|
135 |
+
self._progress_callback = progress_callback
|
136 |
+
|
137 |
+
def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor:
|
138 |
+
"""Generate samples conditioned on text.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
descriptions (list of str): A list of strings used as text conditioning.
|
142 |
+
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
143 |
+
"""
|
144 |
+
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
|
145 |
+
assert prompt_tokens is None
|
146 |
+
return self._generate_tokens(attributes, prompt_tokens, progress)
|
147 |
+
|
148 |
+
def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
|
149 |
+
descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
|
150 |
+
progress: bool = False) -> torch.Tensor:
|
151 |
+
"""Generate samples conditioned on audio prompts.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
prompt (torch.Tensor): A batch of waveforms used for continuation.
|
155 |
+
Prompt should be [B, C, T], or [C, T] if only one sample is generated.
|
156 |
+
prompt_sample_rate (int): Sampling rate of the given audio waveforms.
|
157 |
+
descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
|
158 |
+
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
159 |
+
"""
|
160 |
+
if prompt.dim() == 2:
|
161 |
+
prompt = prompt[None]
|
162 |
+
if prompt.dim() != 3:
|
163 |
+
raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
|
164 |
+
prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
|
165 |
+
if descriptions is None:
|
166 |
+
descriptions = [None] * len(prompt)
|
167 |
+
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
|
168 |
+
assert prompt_tokens is not None
|
169 |
+
return self._generate_tokens(attributes, prompt_tokens, progress)
|
170 |
+
|
171 |
+
@torch.no_grad()
|
172 |
+
def _prepare_tokens_and_attributes(
|
173 |
+
self,
|
174 |
+
descriptions: tp.Sequence[tp.Optional[str]],
|
175 |
+
prompt: tp.Optional[torch.Tensor],
|
176 |
+
) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
|
177 |
+
"""Prepare model inputs.
|
178 |
+
|
179 |
+
Args:
|
180 |
+
descriptions (list of str): A list of strings used as text conditioning.
|
181 |
+
prompt (torch.Tensor): A batch of waveforms used for continuation.
|
182 |
+
"""
|
183 |
+
attributes = [
|
184 |
+
ConditioningAttributes(text={'description': description})
|
185 |
+
for description in descriptions]
|
186 |
+
|
187 |
+
if prompt is not None:
|
188 |
+
if descriptions is not None:
|
189 |
+
assert len(descriptions) == len(prompt), "Prompt and nb. descriptions don't match"
|
190 |
+
prompt = prompt.to(self.device)
|
191 |
+
prompt_tokens, scale = self.compression_model.encode(prompt)
|
192 |
+
assert scale is None
|
193 |
+
else:
|
194 |
+
prompt_tokens = None
|
195 |
+
return attributes, prompt_tokens
|
196 |
+
|
197 |
+
def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
|
198 |
+
prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
|
199 |
+
"""Generate discrete audio tokens given audio prompt and/or conditions.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
attributes (list of ConditioningAttributes): Conditions used for generation (here text).
|
203 |
+
prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
|
204 |
+
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
205 |
+
Returns:
|
206 |
+
torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
|
207 |
+
"""
|
208 |
+
total_gen_len = int(self.duration * self.frame_rate)
|
209 |
+
max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
|
210 |
+
current_gen_offset: int = 0
|
211 |
+
|
212 |
+
def _progress_callback(generated_tokens: int, tokens_to_generate: int):
|
213 |
+
generated_tokens += current_gen_offset
|
214 |
+
if self._progress_callback is not None:
|
215 |
+
# Note that total_gen_len might be quite wrong depending on the
|
216 |
+
# codebook pattern used, but with delay it is almost accurate.
|
217 |
+
self._progress_callback(generated_tokens, total_gen_len)
|
218 |
+
else:
|
219 |
+
print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
|
220 |
+
|
221 |
+
if prompt_tokens is not None:
|
222 |
+
assert max_prompt_len >= prompt_tokens.shape[-1], \
|
223 |
+
"Prompt is longer than audio to generate"
|
224 |
+
|
225 |
+
callback = None
|
226 |
+
if progress:
|
227 |
+
callback = _progress_callback
|
228 |
+
|
229 |
+
if self.duration <= self.max_duration:
|
230 |
+
# generate by sampling from LM, simple case.
|
231 |
+
with self.autocast:
|
232 |
+
gen_tokens = self.lm.generate(
|
233 |
+
prompt_tokens, attributes,
|
234 |
+
callback=callback, max_gen_len=total_gen_len, **self.generation_params)
|
235 |
+
|
236 |
+
else:
|
237 |
+
all_tokens = []
|
238 |
+
if prompt_tokens is None:
|
239 |
+
prompt_length = 0
|
240 |
+
else:
|
241 |
+
all_tokens.append(prompt_tokens)
|
242 |
+
prompt_length = prompt_tokens.shape[-1]
|
243 |
+
|
244 |
+
stride_tokens = int(self.frame_rate * self.extend_stride)
|
245 |
+
while current_gen_offset + prompt_length < total_gen_len:
|
246 |
+
time_offset = current_gen_offset / self.frame_rate
|
247 |
+
chunk_duration = min(self.duration - time_offset, self.max_duration)
|
248 |
+
max_gen_len = int(chunk_duration * self.frame_rate)
|
249 |
+
with self.autocast:
|
250 |
+
gen_tokens = self.lm.generate(
|
251 |
+
prompt_tokens, attributes,
|
252 |
+
callback=callback, max_gen_len=max_gen_len, **self.generation_params)
|
253 |
+
if prompt_tokens is None:
|
254 |
+
all_tokens.append(gen_tokens)
|
255 |
+
else:
|
256 |
+
all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
|
257 |
+
prompt_tokens = gen_tokens[:, :, stride_tokens:]
|
258 |
+
prompt_length = prompt_tokens.shape[-1]
|
259 |
+
current_gen_offset += stride_tokens
|
260 |
+
|
261 |
+
gen_tokens = torch.cat(all_tokens, dim=-1)
|
262 |
+
|
263 |
+
# generate audio
|
264 |
+
assert gen_tokens.dim() == 3
|
265 |
+
with torch.no_grad():
|
266 |
+
gen_audio = self.compression_model.decode(gen_tokens, None)
|
267 |
+
return gen_audio
|
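A short end-to-end sketch of the generation API above. The checkpoint name is the one listed in the `get_pretrained` docstring; downloading it requires network access, and the prompt text is illustrative only.

```python
# Sketch only: text-to-sound generation with the AudioGen wrapper defined above.
from audiocraft.models import AudioGen

model = AudioGen.get_pretrained('facebook/audiogen-medium')   # downloads weights
model.set_generation_params(duration=5.0, top_k=250, cfg_coef=3.0)

wav = model.generate(['dog barking in the rain'], progress=True)  # [B, C, T]
print(wav.shape, model.sample_rate)
```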
audiocraft/models/builders.py
ADDED
@@ -0,0 +1,268 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
All the functions to build the relevant models and modules
|
9 |
+
from the Hydra config.
|
10 |
+
"""
|
11 |
+
|
12 |
+
import typing as tp
|
13 |
+
|
14 |
+
import audiocraft
|
15 |
+
import omegaconf
|
16 |
+
import torch
|
17 |
+
|
18 |
+
from .encodec import CompressionModel, EncodecModel, InterleaveStereoCompressionModel
|
19 |
+
from .lm import LMModel
|
20 |
+
from ..modules.codebooks_patterns import (
|
21 |
+
CodebooksPatternProvider,
|
22 |
+
DelayedPatternProvider,
|
23 |
+
MusicLMPattern,
|
24 |
+
ParallelPatternProvider,
|
25 |
+
UnrolledPatternProvider,
|
26 |
+
CoarseFirstPattern,
|
27 |
+
)
|
28 |
+
from ..modules.conditioners import (
|
29 |
+
BaseConditioner,
|
30 |
+
ChromaStemConditioner,
|
31 |
+
CLAPEmbeddingConditioner,
|
32 |
+
ConditionFuser,
|
33 |
+
ConditioningProvider,
|
34 |
+
LUTConditioner,
|
35 |
+
T5Conditioner,
|
36 |
+
)
|
37 |
+
# T5Conditioner
|
38 |
+
from .unet import DiffusionUnet
|
39 |
+
from .. import quantization as qt
|
40 |
+
from ..utils.utils import dict_from_config
|
41 |
+
from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor
|
42 |
+
from omegaconf import OmegaConf
|
43 |
+
|
44 |
+
def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
|
45 |
+
klass = {
|
46 |
+
'no_quant': qt.DummyQuantizer,
|
47 |
+
'rvq': qt.ResidualVectorQuantizer
|
48 |
+
}[quantizer]
|
49 |
+
kwargs = dict_from_config(getattr(cfg, quantizer))
|
50 |
+
if quantizer != 'no_quant':
|
51 |
+
kwargs['dimension'] = dimension
|
52 |
+
return klass(**kwargs)
|
53 |
+
|
54 |
+
|
55 |
+
def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
|
56 |
+
if encoder_name == 'seanet':
|
57 |
+
kwargs = dict_from_config(getattr(cfg, 'seanet'))
|
58 |
+
encoder_override_kwargs = kwargs.pop('encoder')
|
59 |
+
decoder_override_kwargs = kwargs.pop('decoder')
|
60 |
+
encoder_kwargs = {**kwargs, **encoder_override_kwargs}
|
61 |
+
decoder_kwargs = {**kwargs, **decoder_override_kwargs}
|
62 |
+
encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
|
63 |
+
decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
|
64 |
+
return encoder, decoder
|
65 |
+
else:
|
66 |
+
raise KeyError(f"Unexpected compression model {cfg.compression_model}")
|
67 |
+
|
68 |
+
|
69 |
+
def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
|
70 |
+
"""Instantiate a compression model."""
|
71 |
+
# cfg=eval(cfg)
|
72 |
+
cfg = OmegaConf.create(cfg)
|
73 |
+
|
74 |
+
if cfg.compression_model == 'encodec':
|
75 |
+
kwargs = dict_from_config(getattr(cfg, 'encodec'))
|
76 |
+
encoder_name = kwargs.pop('autoencoder')
|
77 |
+
quantizer_name = kwargs.pop('quantizer')
|
78 |
+
encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
|
79 |
+
quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
|
80 |
+
frame_rate = kwargs['sample_rate'] // encoder.hop_length
|
81 |
+
renormalize = kwargs.pop('renormalize', False)
|
82 |
+
# deprecated params
|
83 |
+
kwargs.pop('renorm', None)
|
84 |
+
return EncodecModel(encoder, decoder, quantizer,
|
85 |
+
frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
|
86 |
+
else:
|
87 |
+
raise KeyError(f"Unexpected compression model {cfg.compression_model}")
|
88 |
+
|
89 |
+
|
90 |
+
def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
|
91 |
+
"""Instantiate a transformer LM."""
|
92 |
+
|
93 |
+
if cfg.lm_model == 'transformer_lm':
|
94 |
+
kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
|
95 |
+
n_q = kwargs['n_q']
|
96 |
+
q_modeling = kwargs.pop('q_modeling', None)
|
97 |
+
codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
|
98 |
+
attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
|
99 |
+
cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
|
100 |
+
cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
|
101 |
+
fuser = get_condition_fuser(cfg)
|
102 |
+
condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
|
103 |
+
if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programmatically
|
104 |
+
kwargs['cross_attention'] = True
|
105 |
+
if codebooks_pattern_cfg.modeling is None:
|
106 |
+
assert q_modeling is not None, \
|
107 |
+
"LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
|
108 |
+
codebooks_pattern_cfg = omegaconf.OmegaConf.create(
|
109 |
+
{'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
|
110 |
+
)
|
111 |
+
|
112 |
+
pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
|
113 |
+
return LMModel(
|
114 |
+
pattern_provider=pattern_provider,
|
115 |
+
condition_provider=condition_provider,
|
116 |
+
|
117 |
+
visual_encoder=cfg.video.visual_encoder,
|
118 |
+
if_add_gobal=cfg.video.add_global.if_add_gobal,
|
119 |
+
|
120 |
+
fuser=fuser,
|
121 |
+
cfg_dropout=cfg_prob,
|
122 |
+
cfg_coef=cfg_coef,
|
123 |
+
attribute_dropout=attribute_dropout,
|
124 |
+
dtype=getattr(torch, cfg.dtype),
|
125 |
+
device=cfg.device,
|
126 |
+
**kwargs
|
127 |
+
).to(cfg.device)
|
128 |
+
else:
|
129 |
+
raise KeyError(f"Unexpected LM model {cfg.lm_model}")
|
130 |
+
|
131 |
+
|
132 |
+
def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
|
133 |
+
"""Instantiate a conditioning model."""
|
134 |
+
device = cfg.device
|
135 |
+
duration = cfg.dataset.segment_duration
|
136 |
+
cfg = getattr(cfg, 'conditioners')
|
137 |
+
dict_cfg = {} if cfg is None else dict_from_config(cfg)
|
138 |
+
conditioners: tp.Dict[str, BaseConditioner] = {}
|
139 |
+
condition_provider_args = dict_cfg.pop('args', {})
|
140 |
+
condition_provider_args.pop('merge_text_conditions_p', None)
|
141 |
+
condition_provider_args.pop('drop_desc_p', None)
|
142 |
+
|
143 |
+
for cond, cond_cfg in dict_cfg.items():
|
144 |
+
model_type = cond_cfg['model']
|
145 |
+
model_args = cond_cfg[model_type]
|
146 |
+
if model_type == 't5':
|
147 |
+
conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
|
148 |
+
elif model_type == 'lut':
|
149 |
+
conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
|
150 |
+
elif model_type == 'chroma_stem':
|
151 |
+
conditioners[str(cond)] = ChromaStemConditioner(
|
152 |
+
output_dim=output_dim,
|
153 |
+
duration=duration,
|
154 |
+
device=device,
|
155 |
+
**model_args
|
156 |
+
)
|
157 |
+
elif model_type == 'clap':
|
158 |
+
conditioners[str(cond)] = CLAPEmbeddingConditioner(
|
159 |
+
output_dim=output_dim,
|
160 |
+
device=device,
|
161 |
+
**model_args
|
162 |
+
)
|
163 |
+
else:
|
164 |
+
raise ValueError(f"Unrecognized conditioning model: {model_type}")
|
165 |
+
conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
|
166 |
+
return conditioner
|
167 |
+
|
168 |
+
|
169 |
+
def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
|
170 |
+
"""Instantiate a condition fuser object."""
|
171 |
+
fuser_cfg = getattr(cfg, 'fuser')
|
172 |
+
fuser_methods = ['sum', 'cross', 'prepend', 'input_interpolate']
|
173 |
+
fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
|
174 |
+
kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
|
175 |
+
fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
|
176 |
+
return fuser
|
177 |
+
|
178 |
+
|
179 |
+
def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
|
180 |
+
"""Instantiate a codebooks pattern provider object."""
|
181 |
+
pattern_providers = {
|
182 |
+
'parallel': ParallelPatternProvider,
|
183 |
+
'delay': DelayedPatternProvider,
|
184 |
+
'unroll': UnrolledPatternProvider,
|
185 |
+
'coarse_first': CoarseFirstPattern,
|
186 |
+
'musiclm': MusicLMPattern,
|
187 |
+
}
|
188 |
+
name = cfg.modeling
|
189 |
+
kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
|
190 |
+
klass = pattern_providers[name]
|
191 |
+
return klass(n_q, **kwargs)
|
192 |
+
|
193 |
+
|
194 |
+
def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
|
195 |
+
"""Instantiate a debug compression model to be used for unit tests."""
|
196 |
+
assert sample_rate in [16000, 32000], "unsupported sample rate for debug compression model"
|
197 |
+
model_ratios = {
|
198 |
+
16000: [10, 8, 8], # 25 Hz at 16kHz
|
199 |
+
32000: [10, 8, 16] # 25 Hz at 32kHz
|
200 |
+
}
|
201 |
+
ratios: tp.List[int] = model_ratios[sample_rate]
|
202 |
+
frame_rate = 25
|
203 |
+
seanet_kwargs: dict = {
|
204 |
+
'n_filters': 4,
|
205 |
+
'n_residual_layers': 1,
|
206 |
+
'dimension': 32,
|
207 |
+
'ratios': ratios,
|
208 |
+
}
|
209 |
+
encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
|
210 |
+
decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
|
211 |
+
quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
|
212 |
+
init_x = torch.randn(8, 32, 128)
|
213 |
+
quantizer(init_x, 1) # initialize kmeans etc.
|
214 |
+
compression_model = EncodecModel(
|
215 |
+
encoder, decoder, quantizer,
|
216 |
+
frame_rate=frame_rate, sample_rate=sample_rate, channels=1).to(device)
|
217 |
+
return compression_model.eval()
|
218 |
+
|
219 |
+
|
220 |
+
def get_diffusion_model(cfg: omegaconf.DictConfig):
|
221 |
+
# TODO Find a way to infer the channels from dset
|
222 |
+
channels = cfg.channels
|
223 |
+
num_steps = cfg.schedule.num_steps
|
224 |
+
return DiffusionUnet(
|
225 |
+
chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
|
226 |
+
|
227 |
+
|
228 |
+
def get_processor(cfg, sample_rate: int = 24000):
|
229 |
+
sample_processor = SampleProcessor()
|
230 |
+
if cfg.use:
|
231 |
+
kw = dict(cfg)
|
232 |
+
kw.pop('use')
|
233 |
+
kw.pop('name')
|
234 |
+
if cfg.name == "multi_band_processor":
|
235 |
+
sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw)
|
236 |
+
return sample_processor
|
237 |
+
|
238 |
+
|
239 |
+
def get_debug_lm_model(device='cpu'):
|
240 |
+
"""Instantiate a debug LM to be used for unit tests."""
|
241 |
+
pattern = DelayedPatternProvider(n_q=4)
|
242 |
+
dim = 16
|
243 |
+
providers = {
|
244 |
+
'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"),
|
245 |
+
}
|
246 |
+
condition_provider = ConditioningProvider(providers)
|
247 |
+
fuser = ConditionFuser(
|
248 |
+
{'cross': ['description'], 'prepend': [],
|
249 |
+
'sum': [], 'input_interpolate': []})
|
250 |
+
lm = LMModel(
|
251 |
+
pattern, condition_provider, fuser,
|
252 |
+
n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
|
253 |
+
cross_attention=True, causal=True)
|
254 |
+
return lm.to(device).eval()
|
255 |
+
|
256 |
+
|
257 |
+
def get_wrapped_compression_model(
|
258 |
+
compression_model: CompressionModel,
|
259 |
+
cfg: omegaconf.DictConfig) -> CompressionModel:
|
260 |
+
if hasattr(cfg, 'interleave_stereo_codebooks'):
|
261 |
+
if cfg.interleave_stereo_codebooks.use:
|
262 |
+
kwargs = dict_from_config(cfg.interleave_stereo_codebooks)
|
263 |
+
kwargs.pop('use')
|
264 |
+
compression_model = InterleaveStereoCompressionModel(compression_model, **kwargs)
|
265 |
+
if hasattr(cfg, 'compression_model_n_q'):
|
266 |
+
if cfg.compression_model_n_q is not None:
|
267 |
+
compression_model.set_num_codebooks(cfg.compression_model_n_q)
|
268 |
+
return compression_model
|
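To illustrate how these factories consume their Hydra sub-configs, here is a small sketch for `get_codebooks_pattern_provider`: `modeling` selects the provider class and an optional section of the same name supplies its kwargs, exactly as in the code above. The concrete delay values are illustrative.

```python
# Sketch only: building a codebook pattern provider from a minimal config.
from omegaconf import OmegaConf
from audiocraft.models.builders import get_codebooks_pattern_provider

cfg = OmegaConf.create({
    'modeling': 'delay',                # -> DelayedPatternProvider
    'delay': {'delays': [0, 1, 2, 3]},  # kwargs forwarded to the provider
})
provider = get_codebooks_pattern_provider(n_q=4, cfg=cfg)
pattern = provider.get_pattern(timesteps=8)
print(type(provider).__name__)
```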
audiocraft/models/encodec.py
ADDED
@@ -0,0 +1,580 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Compression models or wrapper around existing models.
|
7 |
+
Also defines the main interface that a model must follow to be usable as an audio tokenizer.
|
8 |
+
"""
|
9 |
+
|
10 |
+
from abc import ABC, abstractmethod
|
11 |
+
import logging
|
12 |
+
import math
|
13 |
+
from pathlib import Path
|
14 |
+
import typing as tp
|
15 |
+
|
16 |
+
from einops import rearrange
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
from torch import nn
|
20 |
+
from transformers import EncodecModel as HFEncodecModel
|
21 |
+
from .. import quantization as qt
|
22 |
+
# from semanticodec import SemantiCodec
|
23 |
+
|
24 |
+
logger = logging.getLogger()
|
25 |
+
|
26 |
+
|
27 |
+
class CompressionModel(ABC, nn.Module):
|
28 |
+
"""Base API for all compression models that aim at being used as audio tokenizers
|
29 |
+
with a language model.
|
30 |
+
"""
|
31 |
+
|
32 |
+
@abstractmethod
|
33 |
+
def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
|
34 |
+
...
|
35 |
+
|
36 |
+
@abstractmethod
|
37 |
+
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
38 |
+
"""See `EncodecModel.encode`."""
|
39 |
+
...
|
40 |
+
|
41 |
+
@abstractmethod
|
42 |
+
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
|
43 |
+
"""See `EncodecModel.decode`."""
|
44 |
+
...
|
45 |
+
|
46 |
+
@abstractmethod
|
47 |
+
def decode_latent(self, codes: torch.Tensor):
|
48 |
+
"""Decode from the discrete codes to continuous latent space."""
|
49 |
+
...
|
50 |
+
|
51 |
+
@property
|
52 |
+
@abstractmethod
|
53 |
+
def channels(self) -> int:
|
54 |
+
...
|
55 |
+
|
56 |
+
@property
|
57 |
+
@abstractmethod
|
58 |
+
def frame_rate(self) -> float:
|
59 |
+
...
|
60 |
+
|
61 |
+
@property
|
62 |
+
@abstractmethod
|
63 |
+
def sample_rate(self) -> int:
|
64 |
+
...
|
65 |
+
|
66 |
+
@property
|
67 |
+
@abstractmethod
|
68 |
+
def cardinality(self) -> int:
|
69 |
+
...
|
70 |
+
|
71 |
+
@property
|
72 |
+
@abstractmethod
|
73 |
+
def num_codebooks(self) -> int:
|
74 |
+
...
|
75 |
+
|
76 |
+
@property
|
77 |
+
@abstractmethod
|
78 |
+
def total_codebooks(self) -> int:
|
79 |
+
...
|
80 |
+
|
81 |
+
@abstractmethod
|
82 |
+
def set_num_codebooks(self, n: int):
|
83 |
+
"""Set the active number of codebooks used by the quantizer."""
|
84 |
+
...
|
85 |
+
|
86 |
+
@staticmethod
|
87 |
+
def get_pretrained(
|
88 |
+
name: str, device: tp.Union[torch.device, str] = 'cpu'
|
89 |
+
) -> 'CompressionModel':
|
90 |
+
"""Instantiate a CompressionModel from a given pretrained model.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
name (Path or str): name of the pretrained model. See after.
|
94 |
+
device (torch.device or str): Device on which the model is loaded.
|
95 |
+
|
96 |
+
Pretrained models:
|
97 |
+
- dac_44khz (https://github.com/descriptinc/descript-audio-codec)
|
98 |
+
- dac_24khz (same)
|
99 |
+
- facebook/encodec_24khz (https://huggingface.co/facebook/encodec_24khz)
|
100 |
+
- facebook/encodec_32khz (https://huggingface.co/facebook/encodec_32khz)
|
101 |
+
- your own model on Hugging Face. Export instructions to come...
|
102 |
+
"""
|
103 |
+
|
104 |
+
from . import builders, loaders
|
105 |
+
model: CompressionModel
|
106 |
+
if name in ['dac_44khz', 'dac_24khz']:
|
107 |
+
model_type = name.split('_')[1]
|
108 |
+
logger.info("Getting pretrained compression model from DAC %s", model_type)
|
109 |
+
model = DAC(model_type)
|
110 |
+
|
111 |
+
elif name in ['semantic_16khz']:
|
112 |
+
model_type = name.split('_')[1]
|
113 |
+
logger.info("Getting pretrained compression model from Semantic Codec %s", model_type)
|
114 |
+
model = Semantic_Codec(model_type)
|
115 |
+
|
116 |
+
elif name in ['debug_compression_model']:
|
117 |
+
logger.info("Getting pretrained compression model for debug")
|
118 |
+
model = builders.get_debug_compression_model()
|
119 |
+
elif Path(name).exists():
|
120 |
+
# We assume here if the path exists that it is in fact an AC checkpoint
|
121 |
+
# that was exported using `audiocraft.utils.export` functions.
|
122 |
+
model = loaders.load_compression_model(name, device=device)
|
123 |
+
else:
|
124 |
+
logger.info("Getting pretrained compression model from HF %s", name)
|
125 |
+
hf_model = HFEncodecModel.from_pretrained(name)
|
126 |
+
model = HFEncodecCompressionModel(hf_model).to(device)
|
127 |
+
return model.to(device).eval()
|
128 |
+
|
129 |
+
|
130 |
+
class EncodecModel(CompressionModel):
|
131 |
+
"""Encodec model operating on the raw waveform.
|
132 |
+
|
133 |
+
Args:
|
134 |
+
encoder (nn.Module): Encoder network.
|
135 |
+
decoder (nn.Module): Decoder network.
|
136 |
+
quantizer (qt.BaseQuantizer): Quantizer network.
|
137 |
+
frame_rate (int): Frame rate for the latent representation.
|
138 |
+
sample_rate (int): Audio sample rate.
|
139 |
+
channels (int): Number of audio channels.
|
140 |
+
causal (bool): Whether to use a causal version of the model.
|
141 |
+
renormalize (bool): Whether to renormalize the audio before running the model.
|
142 |
+
"""
|
143 |
+
# we need assignment to override the property in the abstract class,
|
144 |
+
# I couldn't find a better way...
|
145 |
+
frame_rate: float = 0
|
146 |
+
sample_rate: int = 0
|
147 |
+
channels: int = 0
|
148 |
+
|
149 |
+
def __init__(self,
|
150 |
+
encoder: nn.Module,
|
151 |
+
decoder: nn.Module,
|
152 |
+
quantizer: qt.BaseQuantizer,
|
153 |
+
frame_rate: int,
|
154 |
+
sample_rate: int,
|
155 |
+
channels: int,
|
156 |
+
causal: bool = False,
|
157 |
+
renormalize: bool = False):
|
158 |
+
super().__init__()
|
159 |
+
self.encoder = encoder
|
160 |
+
self.decoder = decoder
|
161 |
+
self.quantizer = quantizer
|
162 |
+
self.frame_rate = frame_rate
|
163 |
+
self.sample_rate = sample_rate
|
164 |
+
self.channels = channels
|
165 |
+
self.renormalize = renormalize
|
166 |
+
self.causal = causal
|
167 |
+
if self.causal:
|
168 |
+
# we force disabling here to avoid handling linear overlap of segments
|
169 |
+
# as supported in original EnCodec codebase.
|
170 |
+
assert not self.renormalize, 'Causal model does not support renormalize'
|
171 |
+
|
172 |
+
@property
|
173 |
+
def total_codebooks(self):
|
174 |
+
"""Total number of quantizer codebooks available."""
|
175 |
+
return self.quantizer.total_codebooks
|
176 |
+
|
177 |
+
@property
|
178 |
+
def num_codebooks(self):
|
179 |
+
"""Active number of codebooks used by the quantizer."""
|
180 |
+
return self.quantizer.num_codebooks
|
181 |
+
|
182 |
+
def set_num_codebooks(self, n: int):
|
183 |
+
"""Set the active number of codebooks used by the quantizer."""
|
184 |
+
self.quantizer.set_num_codebooks(n)
|
185 |
+
|
186 |
+
@property
|
187 |
+
def cardinality(self):
|
188 |
+
"""Cardinality of each codebook."""
|
189 |
+
return self.quantizer.bins
|
190 |
+
|
191 |
+
def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
192 |
+
scale: tp.Optional[torch.Tensor]
|
193 |
+
if self.renormalize:
|
194 |
+
mono = x.mean(dim=1, keepdim=True)
|
195 |
+
volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
|
196 |
+
scale = 1e-8 + volume
|
197 |
+
x = x / scale
|
198 |
+
scale = scale.view(-1, 1)
|
199 |
+
else:
|
200 |
+
scale = None
|
201 |
+
return x, scale
|
202 |
+
|
203 |
+
def postprocess(self,
|
204 |
+
x: torch.Tensor,
|
205 |
+
scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
|
206 |
+
if scale is not None:
|
207 |
+
assert self.renormalize
|
208 |
+
x = x * scale.view(-1, 1, 1)
|
209 |
+
return x
|
210 |
+
|
211 |
+
def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
|
212 |
+
assert x.dim() == 3
|
213 |
+
length = x.shape[-1]
|
214 |
+
x, scale = self.preprocess(x)
|
215 |
+
|
216 |
+
emb = self.encoder(x)
|
217 |
+
q_res = self.quantizer(emb, self.frame_rate)
|
218 |
+
out = self.decoder(q_res.x)
|
219 |
+
|
220 |
+
# remove extra padding added by the encoder and decoder
|
221 |
+
assert out.shape[-1] >= length, (out.shape[-1], length)
|
222 |
+
out = out[..., :length]
|
223 |
+
|
224 |
+
q_res.x = self.postprocess(out, scale)
|
225 |
+
|
226 |
+
return q_res
|
227 |
+
|
228 |
+
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
229 |
+
"""Encode the given input tensor to quantized representation along with scale parameter.
|
230 |
+
|
231 |
+
Args:
|
232 |
+
x (torch.Tensor): Float tensor of shape [B, C, T]
|
233 |
+
|
234 |
+
Returns:
|
235 |
+
codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
|
236 |
+
codes: a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
|
237 |
+
scale: a float tensor containing the scale for audio renormalization.
|
238 |
+
"""
|
239 |
+
assert x.dim() == 3
|
240 |
+
x, scale = self.preprocess(x)
|
241 |
+
emb = self.encoder(x)
|
242 |
+
codes = self.quantizer.encode(emb)
|
243 |
+
return codes, scale
|
244 |
+
|
245 |
+
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
|
246 |
+
"""Decode the given codes to a reconstructed representation, using the scale to perform
|
247 |
+
audio denormalization if needed.
|
248 |
+
|
249 |
+
Args:
|
250 |
+
codes (torch.Tensor): Int tensor of shape [B, K, T]
|
251 |
+
scale (torch.Tensor, optional): Float tensor containing the scale value.
|
252 |
+
|
253 |
+
Returns:
|
254 |
+
out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
|
255 |
+
"""
|
256 |
+
emb = self.decode_latent(codes)
|
257 |
+
out = self.decoder(emb)
|
258 |
+
out = self.postprocess(out, scale)
|
259 |
+
# out contains extra padding added by the encoder and decoder
|
260 |
+
return out
|
261 |
+
|
262 |
+
def decode_latent(self, codes: torch.Tensor):
|
263 |
+
"""Decode from the discrete codes to continuous latent space."""
|
264 |
+
return self.quantizer.decode(codes)
|
265 |
+
|
266 |
+
|
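Before the model-specific wrappers below, a round-trip sketch of the `CompressionModel` API defined above. The checkpoint name is taken from the `get_pretrained` docstring; the one-second random input is illustrative.

```python
# Sketch only: encode/decode round trip through a pretrained compression model.
import torch
from audiocraft.models.encodec import CompressionModel

model = CompressionModel.get_pretrained('facebook/encodec_32khz')
wav = torch.randn(1, model.channels, model.sample_rate)   # [B, C, T], ~1 second

codes, scale = model.encode(wav)     # codes: [B, K, T'] discrete tokens
recon = model.decode(codes, scale)   # [B, C, T''] reconstructed waveform
print(codes.shape, recon.shape, model.frame_rate)
```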
267 |
+
class DAC(CompressionModel):
|
268 |
+
def __init__(self, model_type: str = "44khz"):
|
269 |
+
super().__init__()
|
270 |
+
try:
|
271 |
+
import dac.utils
|
272 |
+
except ImportError:
|
273 |
+
raise RuntimeError("Could not import dac, make sure it is installed, "
|
274 |
+
"please run `pip install descript-audio-codec`")
|
275 |
+
self.model = dac.utils.load_model(model_type=model_type)
|
276 |
+
self.n_quantizers = self.total_codebooks
|
277 |
+
self.model.eval()
|
278 |
+
|
279 |
+
def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
|
280 |
+
# We don't support training with this.
|
281 |
+
raise NotImplementedError("Forward and training with DAC not supported.")
|
282 |
+
|
283 |
+
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
284 |
+
codes = self.model.encode(x, self.n_quantizers)[1] # [4(B), 9(self.n_quantizers), 430(86*T)], x: [4(B), 1, 220500(44100*T)]
|
285 |
+
|
286 |
+
return codes[:, :self.n_quantizers], None
|
287 |
+
|
288 |
+
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
|
289 |
+
assert scale is None
|
290 |
+
z_q = self.decode_latent(codes)
|
291 |
+
return self.model.decode(z_q)
|
292 |
+
|
293 |
+
def decode_latent(self, codes: torch.Tensor):
|
294 |
+
"""Decode from the discrete codes to continuous latent space."""
|
295 |
+
return self.model.quantizer.from_codes(codes)[0]
|
296 |
+
|
297 |
+
@property
|
298 |
+
def channels(self) -> int:
|
299 |
+
return 1
|
300 |
+
|
301 |
+
@property
|
302 |
+
def frame_rate(self) -> float:
|
303 |
+
return self.model.sample_rate / self.model.hop_length
|
304 |
+
|
305 |
+
@property
|
306 |
+
def sample_rate(self) -> int:
|
307 |
+
return self.model.sample_rate
|
308 |
+
|
309 |
+
@property
|
310 |
+
def cardinality(self) -> int:
|
311 |
+
return self.model.codebook_size
|
312 |
+
|
313 |
+
@property
|
314 |
+
def num_codebooks(self) -> int:
|
315 |
+
return self.n_quantizers
|
316 |
+
|
317 |
+
@property
|
318 |
+
def total_codebooks(self) -> int:
|
319 |
+
return self.model.n_codebooks
|
320 |
+
|
321 |
+
def set_num_codebooks(self, n: int):
|
322 |
+
"""Set the active number of codebooks used by the quantizer.
|
323 |
+
"""
|
324 |
+
assert n >= 1
|
325 |
+
assert n <= self.total_codebooks
|
326 |
+
self.n_quantizers = n
|
327 |
+
|
328 |
+
|
329 |
+
|
330 |
+
|
331 |
+
class Semantic_Codec(CompressionModel):
|
332 |
+
def __init__(self, model_type: str = "16khz"):
|
333 |
+
super().__init__()
|
334 |
+
try:
|
335 |
+
from semanticodec import SemantiCodec
|
336 |
+
except ImportError:
|
337 |
+
raise RuntimeError("Could not import semanticcodec, make sure it is installed, "
|
338 |
+
"please run `pip install git+https://github.com/haoheliu/SemantiCodec-inference.git`")
|
339 |
+
self.model = SemantiCodec(token_rate=100, semantic_vocab_size=16384)
|
340 |
+
# self.n_quantizers = self.total_codebooks
|
341 |
+
self.n_quantizers = 2
|
342 |
+
self.model.sample_rate = 16000
|
343 |
+
self.model.cardinality = 16384
|
344 |
+
self.model.frame_rate = 50
|
345 |
+
self.model.eval()
|
346 |
+
|
347 |
+
def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
|
348 |
+
# We don't support training with this.
|
349 |
+
raise NotImplementedError("Forward and training with DAC not supported.")
|
350 |
+
|
351 |
+
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
352 |
+
codes = self.model.encode(x)
|
353 |
+
# codes = self.model.encode(x, self.n_quantizers)[1]
|
354 |
+
return codes[:, :self.n_quantizers], None
|
355 |
+
|
356 |
+
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
|
357 |
+
assert scale is None
|
358 |
+
z_q = self.decode_latent(codes)
|
359 |
+
return self.model.decode(z_q)
|
360 |
+
|
361 |
+
def decode_latent(self, codes: torch.Tensor):
|
362 |
+
"""Decode from the discrete codes to continuous latent space."""
|
363 |
+
return self.model.quantizer.from_codes(codes)[0]
|
364 |
+
|
365 |
+
@property
|
366 |
+
def channels(self) -> int:
|
367 |
+
return 1
|
368 |
+
|
369 |
+
@property
|
370 |
+
def frame_rate(self) -> float:
|
371 |
+
return self.model.frame_rate
|
372 |
+
# return self.model.sample_rate / self.model.hop_length
|
373 |
+
|
374 |
+
@property
|
375 |
+
def sample_rate(self) -> int:
|
376 |
+
return self.model.sample_rate
|
377 |
+
|
378 |
+
@property
|
379 |
+
def cardinality(self) -> int:
|
380 |
+
return self.model.cardinality
|
381 |
+
|
382 |
+
@property
|
383 |
+
def num_codebooks(self) -> int:
|
384 |
+
return self.n_quantizers
|
385 |
+
|
386 |
+
@property
|
387 |
+
def total_codebooks(self) -> int:
|
388 |
+
return self.model.n_codebooks
|
389 |
+
|
390 |
+
def set_num_codebooks(self, n: int):
|
391 |
+
"""Set the active number of codebooks used by the quantizer.
|
392 |
+
"""
|
393 |
+
assert n >= 1
|
394 |
+
assert n <= self.total_codebooks
|
395 |
+
self.n_quantizers = n
|
396 |
+
|
397 |
+
class HFEncodecCompressionModel(CompressionModel):
|
398 |
+
"""Wrapper around HuggingFace Encodec.
|
399 |
+
"""
|
400 |
+
def __init__(self, model: HFEncodecModel):
|
401 |
+
super().__init__()
|
402 |
+
self.model = model
|
403 |
+
bws = self.model.config.target_bandwidths
|
404 |
+
num_codebooks = [
|
405 |
+
bw * 1000 / (self.frame_rate * math.log2(self.cardinality))
|
406 |
+
for bw in bws
|
407 |
+
]
|
408 |
+
deltas = [nc - int(nc) for nc in num_codebooks]
|
409 |
+
# Checking we didn't do some bad maths and we indeed have integers!
|
410 |
+
assert all(delta <= 1e-3 for delta in deltas), deltas
|
411 |
+
self.possible_num_codebooks = [int(nc) for nc in num_codebooks]
|
412 |
+
self.set_num_codebooks(max(self.possible_num_codebooks))
|
413 |
+
|
414 |
+
def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
|
415 |
+
# We don't support training with this.
|
416 |
+
raise NotImplementedError("Forward and training with HF EncodecModel not supported.")
|
417 |
+
|
418 |
+
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
419 |
+
bandwidth_index = self.possible_num_codebooks.index(self.num_codebooks)
|
420 |
+
bandwidth = self.model.config.target_bandwidths[bandwidth_index]
|
421 |
+
res = self.model.encode(x, None, bandwidth)
|
422 |
+
assert len(res[0]) == 1
|
423 |
+
assert len(res[1]) == 1
|
424 |
+
return res[0][0], res[1][0]
|
425 |
+
|
426 |
+
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
|
427 |
+
if scale is None:
|
428 |
+
scales = [None] # type: ignore
|
429 |
+
else:
|
430 |
+
scales = scale # type: ignore
|
431 |
+
res = self.model.decode(codes[None], scales)
|
432 |
+
return res[0]
|
433 |
+
|
434 |
+
def decode_latent(self, codes: torch.Tensor):
|
435 |
+
"""Decode from the discrete codes to continuous latent space."""
|
436 |
+
return self.model.quantizer.decode(codes.transpose(0, 1))
|
437 |
+
|
438 |
+
@property
|
439 |
+
def channels(self) -> int:
|
440 |
+
return self.model.config.audio_channels
|
441 |
+
|
442 |
+
@property
|
443 |
+
def frame_rate(self) -> float:
|
444 |
+
hop_length = int(np.prod(self.model.config.upsampling_ratios))
|
445 |
+
return self.sample_rate / hop_length
|
446 |
+
|
447 |
+
@property
|
448 |
+
def sample_rate(self) -> int:
|
449 |
+
return self.model.config.sampling_rate
|
450 |
+
|
451 |
+
@property
|
452 |
+
def cardinality(self) -> int:
|
453 |
+
return self.model.config.codebook_size
|
454 |
+
|
455 |
+
@property
|
456 |
+
def num_codebooks(self) -> int:
|
457 |
+
return self._num_codebooks
|
458 |
+
|
459 |
+
@property
|
460 |
+
def total_codebooks(self) -> int:
|
461 |
+
return max(self.possible_num_codebooks)
|
462 |
+
|
463 |
+
def set_num_codebooks(self, n: int):
|
464 |
+
"""Set the active number of codebooks used by the quantizer.
|
465 |
+
"""
|
466 |
+
if n not in self.possible_num_codebooks:
|
467 |
+
raise ValueError(f"Allowed values for num codebooks: {self.possible_num_codebooks}")
|
468 |
+
self._num_codebooks = n
|
469 |
+
|
470 |
+
|
471 |
+
class InterleaveStereoCompressionModel(CompressionModel):
|
472 |
+
"""Wraps a CompressionModel to support stereo inputs. The wrapped model
|
473 |
+
will be applied independently to the left and right channels, and both codebooks
|
474 |
+
will be interleaved. If the wrapped model returns a representation `[B, K ,T]` per
|
475 |
+
channel, then the output will be `[B, K * 2, T]` or `[B, K, T * 2]` depending on
|
476 |
+
`per_timestep`.
|
477 |
+
|
478 |
+
Args:
|
479 |
+
model (CompressionModel): Compression model to wrap.
|
480 |
+
per_timestep (bool): Whether to interleave on the timestep dimension
|
481 |
+
or on the codebooks dimension.
|
482 |
+
"""
|
483 |
+
def __init__(self, model: CompressionModel, per_timestep: bool = False):
|
484 |
+
super().__init__()
|
485 |
+
self.model = model
|
486 |
+
self.per_timestep = per_timestep
|
487 |
+
assert self.model.channels == 1, "Wrapped model is expected to be for monophonic audio"
|
488 |
+
|
489 |
+
@property
|
490 |
+
def total_codebooks(self):
|
491 |
+
return self.model.total_codebooks
|
492 |
+
|
493 |
+
@property
|
494 |
+
def num_codebooks(self):
|
495 |
+
"""Active number of codebooks used by the quantizer.
|
496 |
+
|
497 |
+
..Warning:: this reports the number of codebooks after the interleaving
|
498 |
+
of the codebooks!
|
499 |
+
"""
|
500 |
+
return self.model.num_codebooks if self.per_timestep else self.model.num_codebooks * 2
|
501 |
+
|
502 |
+
def set_num_codebooks(self, n: int):
|
503 |
+
"""Set the active number of codebooks used by the quantizer.
|
504 |
+
|
505 |
+
..Warning:: this sets the number of codebooks before the interleaving!
|
506 |
+
"""
|
507 |
+
self.model.set_num_codebooks(n)
|
508 |
+
|
509 |
+
@property
|
510 |
+
def num_virtual_steps(self) -> float:
|
511 |
+
"""Return the number of virtual steps, e.g. one real step
|
512 |
+
will be split into that many steps.
|
513 |
+
"""
|
514 |
+
return 2 if self.per_timestep else 1
|
515 |
+
|
516 |
+
@property
|
517 |
+
def frame_rate(self) -> float:
|
518 |
+
return self.model.frame_rate * self.num_virtual_steps
|
519 |
+
|
520 |
+
@property
|
521 |
+
def sample_rate(self) -> int:
|
522 |
+
return self.model.sample_rate
|
523 |
+
|
524 |
+
@property
|
525 |
+
def channels(self) -> int:
|
526 |
+
return 2
|
527 |
+
|
528 |
+
@property
|
529 |
+
def cardinality(self):
|
530 |
+
"""Cardinality of each codebook.
|
531 |
+
"""
|
532 |
+
return self.model.cardinality
|
533 |
+
|
534 |
+
def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
|
535 |
+
raise NotImplementedError("Not supported, use encode and decode.")
|
536 |
+
|
537 |
+
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
538 |
+
B, C, T = x.shape
|
539 |
+
assert C == self.channels, f"Expecting stereo audio but audio num channels is {C}"
|
540 |
+
|
541 |
+
indices_c0, scales_c0 = self.model.encode(x[:, 0, ...].unsqueeze(1))
|
542 |
+
indices_c1, scales_c1 = self.model.encode(x[:, 1, ...].unsqueeze(1))
|
543 |
+
indices = torch.stack([indices_c0, indices_c1], dim=0)
|
544 |
+
scales: tp.Optional[torch.Tensor] = None
|
545 |
+
if scales_c0 is not None and scales_c1 is not None:
|
546 |
+
scales = torch.stack([scales_c0, scales_c1], dim=1)
|
547 |
+
|
548 |
+
if self.per_timestep:
|
549 |
+
indices = rearrange(indices, 'c b k t -> b k (t c)', c=2)
|
550 |
+
else:
|
551 |
+
indices = rearrange(indices, 'c b k t -> b (k c) t', c=2)
|
552 |
+
|
553 |
+
return (indices, scales)
|
554 |
+
|
555 |
+
def get_left_right_codes(self, codes: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
556 |
+
if self.per_timestep:
|
557 |
+
codes = rearrange(codes, 'b k (t c) -> c b k t', c=2)
|
558 |
+
else:
|
559 |
+
codes = rearrange(codes, 'b (k c) t -> c b k t', c=2)
|
560 |
+
return codes[0], codes[1]
|
561 |
+
|
562 |
+
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
|
563 |
+
B, K, T = codes.shape
|
564 |
+
assert T % self.num_virtual_steps == 0, "Provided codes' number of timesteps does not match"
|
565 |
+
assert K == self.num_codebooks, "Provided codes' number of codebooks does not match"
|
566 |
+
|
567 |
+
scale_c0, scale_c1 = None, None
|
568 |
+
if scale is not None:
|
569 |
+
assert scale.size(0) == B and scale.size(1) == 2, f"Scale has unexpected shape: {scale.shape}"
|
570 |
+
scale_c0 = scale[0, ...]
|
571 |
+
scale_c1 = scale[1, ...]
|
572 |
+
|
573 |
+
codes_c0, codes_c1 = self.get_left_right_codes(codes)
|
574 |
+
audio_c0 = self.model.decode(codes_c0, scale_c0)
|
575 |
+
audio_c1 = self.model.decode(codes_c1, scale_c1)
|
576 |
+
return torch.cat([audio_c0, audio_c1], dim=1)
|
577 |
+
|
578 |
+
def decode_latent(self, codes: torch.Tensor):
|
579 |
+
"""Decode from the discrete codes to continuous latent space."""
|
580 |
+
raise NotImplementedError("Not supported by interleaved stereo wrapped models.")
|
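To make the interleaving behaviour of `InterleaveStereoCompressionModel` concrete, here is a small sketch using dummy code tensors in place of the wrapped mono compression model; the `einops.rearrange` patterns are the same ones used in `encode` and `get_left_right_codes`, while the batch/codebook/timestep sizes are arbitrary.

```python
# Illustrative sketch only: dummy left/right codes stand in for the wrapped
# mono compression model's output. Assumes torch and einops are installed.
import torch
from einops import rearrange

B, K, T = 2, 4, 10                      # batch, codebooks per channel, timesteps
codes_left = torch.randint(0, 1024, (B, K, T))
codes_right = torch.randint(0, 1024, (B, K, T))
stacked = torch.stack([codes_left, codes_right], dim=0)      # [2, B, K, T]

# per_timestep=False: interleave on the codebook axis -> [B, K * 2, T]
per_codebook = rearrange(stacked, 'c b k t -> b (k c) t', c=2)
# per_timestep=True: interleave on the time axis -> [B, K, T * 2]
per_timestep = rearrange(stacked, 'c b k t -> b k (t c)', c=2)
print(per_codebook.shape, per_timestep.shape)  # [2, 8, 10] and [2, 4, 20]

# Reverting either layout recovers the original per-channel codes.
left, right = rearrange(per_codebook, 'b (k c) t -> c b k t', c=2)
assert torch.equal(left, codes_left) and torch.equal(right, codes_right)
```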
audiocraft/models/lm.py
ADDED
@@ -0,0 +1,685 @@
1 |
+
# Modified from Audiocraft (https://github.com/facebookresearch/audiocraft)
|
2 |
+
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from functools import partial
|
5 |
+
import logging
|
6 |
+
import math
|
7 |
+
import typing as tp
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from torch import nn
|
11 |
+
|
12 |
+
from ..utils import utils
|
13 |
+
from ..modules.streaming import StreamingModule, State
|
14 |
+
from ..modules.transformer import StreamingTransformer, create_norm_fn
|
15 |
+
|
16 |
+
import time
|
17 |
+
from ..modules.conditioners import (
|
18 |
+
ConditionFuser,
|
19 |
+
ClassifierFreeGuidanceDropout,
|
20 |
+
AttributeDropout,
|
21 |
+
ConditioningProvider,
|
22 |
+
ConditioningAttributes,
|
23 |
+
ConditionType,
|
24 |
+
)
|
25 |
+
from ..modules.codebooks_patterns import CodebooksPatternProvider
|
26 |
+
from ..modules.activations import get_activation_fn
|
27 |
+
import warnings
|
28 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.transforms._transforms_video")
|
29 |
+
import torch.nn.init as init
|
30 |
+
import os
|
31 |
+
|
32 |
+
import logging
|
33 |
+
import random
|
34 |
+
import sys
|
35 |
+
import einops
|
36 |
+
from .transformer_module import Attention, PreNorm, FeedForward
|
37 |
+
from transformers import AutoProcessor, CLIPVisionModelWithProjection, VideoMAEModel
|
38 |
+
|
39 |
+
logger = logging.getLogger(__name__)
|
40 |
+
ConditionTensors = tp.Dict[str, ConditionType]
|
41 |
+
CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
|
42 |
+
|
43 |
+
def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
|
44 |
+
"""LM layer initialization.
|
45 |
+
Inspired from xlformers: https://github.com/fairinternal/xlformers
|
46 |
+
|
47 |
+
Args:
|
48 |
+
method (str): Method name for init function. Valid options are:
|
49 |
+
'gaussian', 'uniform'.
|
50 |
+
input_dim (int): Input dimension of the initialized module.
|
51 |
+
init_depth (int, optional): Optional init depth value used to rescale
|
52 |
+
the standard deviation if defined.
|
53 |
+
"""
|
54 |
+
# Compute std
|
55 |
+
std = 1 / math.sqrt(input_dim)
|
56 |
+
# Rescale with depth
|
57 |
+
if init_depth is not None:
|
58 |
+
std = std / math.sqrt(2 * init_depth)
|
59 |
+
|
60 |
+
if method == 'gaussian':
|
61 |
+
return partial(
|
62 |
+
torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
|
63 |
+
)
|
64 |
+
elif method == 'uniform':
|
65 |
+
bound = math.sqrt(3) * std # ensure the standard deviation is std
|
66 |
+
return partial(torch.nn.init.uniform_, a=-bound, b=bound)
|
67 |
+
else:
|
68 |
+
raise ValueError("Unsupported layer initialization method")
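As a quick illustration of the scaling performed by `get_init_fn`: the standard deviation starts at 1/sqrt(input_dim) and, when a depth is given, is further divided by sqrt(2 * init_depth) before a truncated-normal or bounded-uniform init is applied. The sketch below reproduces that arithmetic on an arbitrary `nn.Linear`; the dimensions are made up for the example.

```python
# Standalone sketch of the initialization scaling (not a call into this file).
import math
import torch
from torch import nn

input_dim, init_depth = 1024, 12
std = 1 / math.sqrt(input_dim)          # base scale from fan-in
std = std / math.sqrt(2 * init_depth)   # optional depth rescaling

linear = nn.Linear(input_dim, 1024)
# 'gaussian': truncated normal clipped at +/- 3 std
torch.nn.init.trunc_normal_(linear.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
# 'uniform': U(-bound, bound) with bound = sqrt(3) * std has the same std
bound = math.sqrt(3) * std
torch.nn.init.uniform_(linear.weight, a=-bound, b=bound)
print(linear.weight.std())  # close to std in both cases
```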
|
69 |
+
|
70 |
+
|
71 |
+
def init_layer(m: nn.Module,
|
72 |
+
method: str,
|
73 |
+
init_depth: tp.Optional[int] = None,
|
74 |
+
zero_bias_init: bool = False):
|
75 |
+
"""Wrapper around `get_init_fn for proper initialization of LM modules.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
m (nn.Module): Module to initialize.
|
79 |
+
method (str): Method name for the init function.
|
80 |
+
init_depth (int, optional): Optional init depth value used to rescale
|
81 |
+
the standard deviation if defined.
|
82 |
+
zero_bias_init (bool): Whether to initialize the bias to 0 or not.
|
83 |
+
"""
|
84 |
+
if isinstance(m, nn.Linear):
|
85 |
+
init_fn = get_init_fn(method, m.in_features, init_depth=init_depth)
|
86 |
+
if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
|
87 |
+
weight = m.weight.float()
|
88 |
+
init_fn(weight)
|
89 |
+
m.weight.data[:] = weight.half()
|
90 |
+
else:
|
91 |
+
init_fn(m.weight)
|
92 |
+
if zero_bias_init and m.bias is not None:
|
93 |
+
nn.init.constant_(m.bias, 0)
|
94 |
+
elif isinstance(m, nn.Embedding):
|
95 |
+
init_fn = get_init_fn(method, m.embedding_dim, init_depth=None)
|
96 |
+
if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
|
97 |
+
weight = m.weight.float()
|
98 |
+
init_fn(weight)
|
99 |
+
m.weight.data[:] = weight.half()
|
100 |
+
else:
|
101 |
+
init_fn(m.weight)
|
102 |
+
|
103 |
+
|
104 |
+
class ScaledEmbedding(nn.Embedding):
|
105 |
+
"""Boost learning rate for embeddings (with scale).
|
106 |
+
"""
|
107 |
+
def __init__(self, *args, lr=None, **kwargs):
|
108 |
+
super().__init__(*args, **kwargs)
|
109 |
+
self.lr = lr
|
110 |
+
|
111 |
+
def make_optim_group(self):
|
112 |
+
group = {"params": list(self.parameters())}
|
113 |
+
if self.lr is not None:
|
114 |
+
group["lr"] = self.lr
|
115 |
+
return group
|
116 |
+
|
117 |
+
|
118 |
+
@dataclass
|
119 |
+
class LMOutput:
|
120 |
+
# The logits are already re-aligned with the input codes
|
121 |
+
# hence no extra shift is required, e.g. when computing CE
|
122 |
+
logits: torch.Tensor # [B, K, T, card]
|
123 |
+
mask: torch.Tensor # [B, K, T]
|
124 |
+
|
125 |
+
|
126 |
+
class Transformer(nn.Module):
|
127 |
+
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
128 |
+
super().__init__()
|
129 |
+
self.layers = nn.ModuleList([])
|
130 |
+
self.norm = nn.LayerNorm(dim)
|
131 |
+
for _ in range(depth):
|
132 |
+
self.layers.append(nn.ModuleList([
|
133 |
+
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
134 |
+
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
135 |
+
]))
|
136 |
+
|
137 |
+
def forward(self, x):
|
138 |
+
for attn, ff in self.layers:
|
139 |
+
x = attn(x) + x
|
140 |
+
x = ff(x) + x
|
141 |
+
return self.norm(x)
|
142 |
+
|
143 |
+
|
144 |
+
class MultiHeadCrossAttention(nn.Module):
|
145 |
+
def __init__(self, x1, num_heads):
|
146 |
+
super().__init__()
|
147 |
+
self.num_heads = num_heads
|
148 |
+
self.depth = x1 // num_heads
|
149 |
+
|
150 |
+
self.query = nn.Linear(x1, x1)
|
151 |
+
self.key = nn.Linear(x1, x1)
|
152 |
+
self.value = nn.Linear(x1, x1)
|
153 |
+
|
154 |
+
self.final_linear = nn.Linear(x1, x1)
|
155 |
+
|
156 |
+
self.norm1 = nn.LayerNorm(x1)
|
157 |
+
self.norm2 = nn.LayerNorm(x1)
|
158 |
+
|
159 |
+
init.constant_(self.final_linear.weight, 0)
|
160 |
+
if self.final_linear.bias is not None:
|
161 |
+
init.constant_(self.final_linear.bias, 0)
|
162 |
+
|
163 |
+
def split_heads(self, x, batch_size):
|
164 |
+
x = x.view(batch_size, -1, self.num_heads, self.depth)
|
165 |
+
return x.permute(0, 2, 1, 3)
|
166 |
+
|
167 |
+
def forward(self, tensor_A, tensor_B):
|
168 |
+
batch_size = tensor_A.size(0)
|
169 |
+
|
170 |
+
Q = self.split_heads(self.query(tensor_A), batch_size)
|
171 |
+
K = self.split_heads(self.key(tensor_B), batch_size)
|
172 |
+
V = self.split_heads(self.value(tensor_B), batch_size)
|
173 |
+
|
174 |
+
attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.depth ** 0.5)
|
175 |
+
attention_scores = torch.softmax(attention_scores, dim=-1)
|
176 |
+
|
177 |
+
attention_output = torch.matmul(attention_scores, V)
|
178 |
+
attention_output = attention_output.permute(0, 2, 1, 3).contiguous()
|
179 |
+
|
180 |
+
output = attention_output.view(batch_size, -1, self.num_heads * self.depth)
|
181 |
+
|
182 |
+
output = self.norm1(output + tensor_A)
|
183 |
+
output = self.norm2(self.final_linear(output) + output)
|
184 |
+
return output
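A shape-level sketch of the cross attention above, with the learned query/key/value projections omitted: queries come from the local video stream, keys and values from the global stream, and the output keeps the local sequence length. The head count matches the value used later in this file (3 heads over a 768-dim feature); the token counts are illustrative.

```python
# Dummy-shape walkthrough of the local/global cross attention (projections omitted).
import torch

B, heads, dim = 2, 3, 768
depth = dim // heads                       # 256 dims per head
local = torch.randn(B, 32 * 50, dim)       # queries: local CLIP tokens
global_ = torch.randn(B, 10 * 50, dim)     # keys/values: global CLIP tokens

def split_heads(x):
    return x.view(B, -1, heads, depth).permute(0, 2, 1, 3)   # [B, h, S, depth]

Q, K, V = split_heads(local), split_heads(global_), split_heads(global_)
scores = torch.softmax(Q @ K.transpose(-2, -1) / depth ** 0.5, dim=-1)
out = (scores @ V).permute(0, 2, 1, 3).reshape(B, -1, dim)
print(out.shape)   # torch.Size([2, 1600, 768]) -- same length as the local stream
```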
|
185 |
+
|
186 |
+
|
187 |
+
def evenly_sample_or_duplicate_frames(video_tensor, target_frames=32):
|
188 |
+
num_frames = video_tensor.size(0)
|
189 |
+
if target_frames <= num_frames:
|
190 |
+
indices = torch.linspace(0, num_frames - 1, steps=target_frames).long()
|
191 |
+
return video_tensor[indices]
|
192 |
+
else:
|
193 |
+
scale_factor = target_frames / num_frames
|
194 |
+
repeated_indices = (torch.arange(target_frames) / scale_factor).long()
|
195 |
+
return video_tensor[repeated_indices]
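The helper above either subsamples or repeats frames to reach a fixed count. The toy example below shows the index patterns it produces for a longer and a shorter clip; the frame counts are arbitrary.

```python
# Index patterns produced by evenly_sample_or_duplicate_frames for two cases.
import torch

target = 8
# Case 1: clip longer than target -> evenly spaced subsampling.
num_frames = 20
print(torch.linspace(0, num_frames - 1, steps=target).long().tolist())
# [0, 2, 5, 8, 10, 13, 16, 19]
# Case 2: clip shorter than target -> indices repeated ~target/num_frames times.
num_frames = 3
scale = target / num_frames
print((torch.arange(target) / scale).long().tolist())
# [0, 0, 0, 1, 1, 1, 2, 2]
```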
|
196 |
+
|
197 |
+
class LMModel(StreamingModule):
|
198 |
+
"""Transformer-based language model on multiple streams of codes.
|
199 |
+
|
200 |
+
Args:
|
201 |
+
pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
|
202 |
+
condition_provider (MusicConditioningProvider): Conditioning provider from metadata.
|
203 |
+
fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
|
204 |
+
n_q (int): Number of parallel streams to model.
|
205 |
+
card (int): Cardinality, vocabulary size.
|
206 |
+
dim (int): Dimension of the transformer encoder.
|
207 |
+
num_heads (int): Number of heads for the transformer encoder.
|
208 |
+
hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
|
209 |
+
norm (str): Normalization method.
|
210 |
+
norm_first (bool): Use pre-norm instead of post-norm.
|
211 |
+
emb_lr (float, optional): Embedding-specific learning rate.
|
212 |
+
bias_proj (bool): Use bias for output projections.
|
213 |
+
weight_init (str, optional): Method for weight initialization.
|
214 |
+
depthwise_init (str, optional): Method for depthwise weight initialization.
|
215 |
+
zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
|
216 |
+
cfg_dropout (float): Classifier-free guidance dropout.
|
217 |
+
cfg_coef (float): Classifier-free guidance coefficient.
|
218 |
+
attribute_dropout (dict): Attribute dropout probabilities.
|
219 |
+
two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
|
220 |
+
**kwargs: Additional parameters for the transformer encoder.
|
221 |
+
"""
|
222 |
+
|
223 |
+
def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
|
224 |
+
visual_encoder,
|
225 |
+
if_add_gobal,
|
226 |
+
fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
|
227 |
+
hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
|
228 |
+
emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
|
229 |
+
weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
|
230 |
+
zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
|
231 |
+
attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False,
|
232 |
+
depth=2,
|
233 |
+
temporal_dim=768,
|
234 |
+
dim_head=64,
|
235 |
+
**kwargs):
|
236 |
+
super().__init__()
|
237 |
+
self.cfg_coef = cfg_coef
|
238 |
+
self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
|
239 |
+
self.att_dropout = AttributeDropout(p=attribute_dropout)
|
240 |
+
self.condition_provider = condition_provider
|
241 |
+
self.visual_encoder = visual_encoder
|
242 |
+
self.if_add_gobal = if_add_gobal
|
243 |
+
self.temporal_dim = temporal_dim
|
244 |
+
|
245 |
+
self.fuser = fuser
|
246 |
+
self.card = card
|
247 |
+
embed_dim = self.card + 1
|
248 |
+
self.n_q = n_q
|
249 |
+
self.dim = dim
|
250 |
+
self.pattern_provider = pattern_provider
|
251 |
+
self.two_step_cfg = two_step_cfg
|
252 |
+
self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
|
253 |
+
if 'activation' in kwargs:
|
254 |
+
kwargs['activation'] = get_activation_fn(kwargs['activation'])
|
255 |
+
self.transformer = StreamingTransformer(
|
256 |
+
d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
|
257 |
+
norm=norm, norm_first=norm_first, **kwargs)
|
258 |
+
|
259 |
+
|
260 |
+
self.out_norm: tp.Optional[nn.Module] = None
|
261 |
+
if norm_first:
|
262 |
+
self.out_norm = create_norm_fn(norm, dim)
|
263 |
+
self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
|
264 |
+
self._init_weights(weight_init, depthwise_init, zero_bias_init)
|
265 |
+
self._fsdp: tp.Optional[nn.Module]
|
266 |
+
self.__dict__['_fsdp'] = None
|
267 |
+
|
268 |
+
if self.visual_encoder == 'clip':
|
269 |
+
self.visual_encoder_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
|
270 |
+
self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
271 |
+
|
272 |
+
else:
|
273 |
+
print(f'The encoder now is: {self.visual_encoder}')
|
274 |
+
print('Please specify a supported video encoder.')
|
275 |
+
exit()
|
276 |
+
|
277 |
+
if self.visual_encoder == 'clip':
|
278 |
+
temporal_dim = 768
|
279 |
+
self.local_pos_embedding = nn.Parameter(torch.randn(1, 50, temporal_dim))
|
280 |
+
self.visual_encoder_model = self.visual_encoder_model.eval()
|
281 |
+
for param in self.visual_encoder_model.parameters():
|
282 |
+
param.requires_grad = False
|
283 |
+
|
284 |
+
self.local_temporal_transformer = Transformer(temporal_dim, depth, num_heads, dim_head, temporal_dim*hidden_scale, 0.) # [768, 4, 16, 64, 768*4]
|
285 |
+
|
286 |
+
if self.if_add_gobal:
|
287 |
+
if self.visual_encoder == 'clip':
|
288 |
+
self.global_pos_embedding = nn.Parameter(torch.randn(1, 50, temporal_dim))
|
289 |
+
|
290 |
+
self.global_temporal_transformer = Transformer(temporal_dim, depth, num_heads, dim_head, temporal_dim*hidden_scale, 0.) # [768, 4, 16, 64, 768*4]
|
291 |
+
|
292 |
+
cross_attention_num_heads = 3 # MultiHeadCrossAttention
|
293 |
+
self.multi_head_cross_attention = MultiHeadCrossAttention(temporal_dim, cross_attention_num_heads)
|
294 |
+
|
295 |
+
self.visual_feature_proj = nn.Linear(temporal_dim, dim)
|
296 |
+
|
297 |
+
|
298 |
+
def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
|
299 |
+
"""Initialization of the transformer module weights.
|
300 |
+
|
301 |
+
Args:
|
302 |
+
weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options.
|
303 |
+
depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
|
304 |
+
'current' where the depth corresponds to the current layer index or 'global' where the total number
|
305 |
+
of layer is used as depth. If not set, no depthwise initialization strategy is used.
|
306 |
+
zero_bias_init (bool): Whether to initialize bias to zero or not.
|
307 |
+
"""
|
308 |
+
assert depthwise_init is None or depthwise_init in ['current', 'global']
|
309 |
+
assert depthwise_init is None or weight_init is not None, \
|
310 |
+
"If 'depthwise_init' is defined, a 'weight_init' method should be provided."
|
311 |
+
assert not zero_bias_init or weight_init is not None, \
|
312 |
+
"If 'zero_bias_init', a 'weight_init' method should be provided"
|
313 |
+
|
314 |
+
if weight_init is None:
|
315 |
+
return
|
316 |
+
|
317 |
+
for emb_layer in self.emb:
|
318 |
+
init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
|
319 |
+
|
320 |
+
for layer_idx, tr_layer in enumerate(self.transformer.layers):
|
321 |
+
depth = None
|
322 |
+
if depthwise_init == 'current':
|
323 |
+
depth = layer_idx + 1
|
324 |
+
elif depthwise_init == 'global':
|
325 |
+
depth = len(self.transformer.layers)
|
326 |
+
init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
|
327 |
+
tr_layer.apply(init_fn)
|
328 |
+
|
329 |
+
for linear in self.linears:
|
330 |
+
init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
|
331 |
+
|
332 |
+
|
333 |
+
@property
|
334 |
+
def special_token_id(self) -> int:
|
335 |
+
return self.card
|
336 |
+
|
337 |
+
@property
|
338 |
+
def num_codebooks(self) -> int:
|
339 |
+
return self.n_q
|
340 |
+
|
341 |
+
def compute_video_emb(self, video_tensor_list: tp.List, device: str) -> torch.Tensor:
|
342 |
+
assert isinstance(video_tensor_list, list)
|
343 |
+
assert self.if_add_gobal
|
344 |
+
assert len(video_tensor_list) == 2
|
345 |
+
|
346 |
+
[local_video_tensor, global_video_tensor] = video_tensor_list
|
347 |
+
local_image = local_video_tensor.to(dtype=torch.float32)
|
348 |
+
global_image = global_video_tensor.to(dtype=torch.float32)
|
349 |
+
|
350 |
+
local_batch_size, _, local_time_length, _, _ = local_image.size()
|
351 |
+
local_image = einops.rearrange(local_image, 'b c t h w -> (b t) c h w')
|
352 |
+
|
353 |
+
global_batch_size, _, global_time_length, _, _ = global_image.size()
|
354 |
+
global_image = einops.rearrange(global_image, 'b c t h w -> (b t) c h w')
|
355 |
+
|
356 |
+
local_temporal_transformer = self.local_temporal_transformer
|
357 |
+
global_temporal_transformer = self.global_temporal_transformer
|
358 |
+
|
359 |
+
local_video_inputs = self.processor(images=local_image.float(), return_tensors="pt")
|
360 |
+
local_pixel_values = local_video_inputs['pixel_values'].to(device)
|
361 |
+
|
362 |
+
global_video_inputs = self.processor(images=global_image.float(), return_tensors="pt")
|
363 |
+
global_pixel_values = global_video_inputs['pixel_values'].to(device)
|
364 |
+
|
365 |
+
if self.visual_encoder == 'clip':
|
366 |
+
with torch.no_grad():
|
367 |
+
local_video_hidden = self.visual_encoder_model(pixel_values=local_pixel_values).last_hidden_state
|
368 |
+
local_video_hidden += self.local_pos_embedding
|
369 |
+
local_video_hidden = local_temporal_transformer(local_video_hidden)
|
370 |
+
local_video_hidden = einops.rearrange(
|
371 |
+
local_video_hidden, '(b t) q h -> b (t q) h',
|
372 |
+
b=local_batch_size, t=local_time_length
|
373 |
+
)
|
374 |
+
|
375 |
+
with torch.no_grad():
|
376 |
+
global_video_hidden = self.visual_encoder_model(pixel_values=global_pixel_values).last_hidden_state
|
377 |
+
global_video_hidden += self.global_pos_embedding
|
378 |
+
global_video_hidden = global_temporal_transformer(global_video_hidden)
|
379 |
+
global_video_hidden = einops.rearrange(
|
380 |
+
global_video_hidden, '(b t) q h -> b (t q) h',
|
381 |
+
b=global_batch_size, t=global_time_length
|
382 |
+
)
|
383 |
+
|
384 |
+
video_hidden = self.multi_head_cross_attention(local_video_hidden, global_video_hidden)
|
385 |
+
video_emb = self.visual_feature_proj(video_hidden)
|
386 |
+
|
387 |
+
return video_emb
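A shape-only sketch of the video-conditioning path in `compute_video_emb`, assuming the CLIP ViT-B/32 encoder (50 tokens of width 768 per frame). The CLIP forward pass is replaced by a random tensor so the snippet runs without downloading weights, the cross-attention step is skipped, and the frame counts and target width `dim` are illustrative.

```python
# Shape walkthrough of the video-conditioning path (CLIP output faked with randn).
import torch
from torch import nn
from einops import rearrange

B, T_local, T_global, tokens, temporal_dim, dim = 1, 32, 10, 50, 768, 1536

# Stand-in for CLIPVisionModelWithProjection(...).last_hidden_state over (B*T) frames.
local_hidden = torch.randn(B * T_local, tokens, temporal_dim)
global_hidden = torch.randn(B * T_global, tokens, temporal_dim)

local_seq = rearrange(local_hidden, '(b t) q h -> b (t q) h', b=B, t=T_local)
global_seq = rearrange(global_hidden, '(b t) q h -> b (t q) h', b=B, t=T_global)
print(local_seq.shape, global_seq.shape)   # [1, 1600, 768] and [1, 500, 768]

# After cross attention (local queries, global keys/values) the sequence keeps
# the local length, then a projection maps it to the LM width.
proj = nn.Linear(temporal_dim, dim)
video_emb = proj(local_seq)                # cross attention omitted in this sketch
print(video_emb.shape)                     # torch.Size([1, 1600, 1536])
```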
|
388 |
+
|
389 |
+
|
390 |
+
def forward(self, sequence: torch.Tensor,
|
391 |
+
conditions: tp.List[ConditioningAttributes],
|
392 |
+
video_tensor_list: tp.List,
|
393 |
+
precomputed_video_emb: tp.Optional[torch.Tensor] = None  # newly added parameter
|
394 |
+
) -> torch.Tensor:
|
395 |
+
|
396 |
+
B, K, S = sequence.shape
|
397 |
+
assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks"
|
398 |
+
input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
|
399 |
+
self.device = input_.device
|
400 |
+
assert self.device != "cpu"
|
401 |
+
|
402 |
+
if precomputed_video_emb is None:
|
403 |
+
video_emb = self.compute_video_emb(video_tensor_list, device=self.device)
|
404 |
+
else:
|
405 |
+
video_emb = precomputed_video_emb
|
406 |
+
|
407 |
+
out = self.transformer(input_, cross_attention_src=video_emb)
|
408 |
+
if self.out_norm:
|
409 |
+
out = self.out_norm(out)
|
410 |
+
logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1)
|
411 |
+
|
412 |
+
if len(self.fuser.fuse2cond['prepend']) > 0:
|
413 |
+
logits = logits[:, :, -S:]
|
414 |
+
return logits # [B, K, S, card]
|
415 |
+
|
416 |
+
|
417 |
+
def compute_predictions(
|
418 |
+
self, codes: torch.Tensor,
|
419 |
+
conditions: tp.List[ConditioningAttributes],
|
420 |
+
condition_tensors_list: tp.List) -> LMOutput:
|
421 |
+
"""Given an input tensor of codes [B, K, T] and list of conditions, runs the model
|
422 |
+
forward using the specified codes interleaving pattern.
|
423 |
+
|
424 |
+
Args:
|
425 |
+
codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
|
426 |
+
K the number of codebooks and T the number of timesteps.
|
427 |
+
conditions (list of ConditioningAttributes): Conditions to use when modeling
|
428 |
+
the given codes. Note that when evaluating multiple times with the same conditioning
|
429 |
+
you should pre-compute those and pass them as condition_tensors.
|
430 |
+
condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
|
431 |
+
tensors, see conditions.
|
432 |
+
Returns:
|
433 |
+
LMOutput: Language model outputs
|
434 |
+
logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
|
435 |
+
i.e. the first item corresponds to logits to predict the first code, meaning that
|
436 |
+
no additional shifting of codes and logits is required.
|
437 |
+
mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
|
438 |
+
Given the specified interleaving strategies, parts of the logits and codes should
|
439 |
+
not be considered as valid predictions because of invalid context.
|
440 |
+
"""
|
441 |
+
B, K, T = codes.shape
|
442 |
+
codes = codes.contiguous()
|
443 |
+
|
444 |
+
assert isinstance(condition_tensors_list, list)
|
445 |
+
pattern = self.pattern_provider.get_pattern(T)
|
446 |
+
sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
|
447 |
+
codes, self.special_token_id, keep_only_valid_steps=True
|
448 |
+
)
|
449 |
+
|
450 |
+
model = self if self._fsdp is None else self._fsdp
|
451 |
+
logits = model(sequence_codes, conditions, condition_tensors_list) # [B, K, S, card]
|
452 |
+
|
453 |
+
|
454 |
+
logits = logits.permute(0, 3, 1, 2) # [B, card, K, S]
|
455 |
+
logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
|
456 |
+
logits, float('nan'), keep_only_valid_steps=True
|
457 |
+
)
|
458 |
+
logits = logits.permute(0, 2, 3, 1) # [B, K, T, card]
|
459 |
+
logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T]
|
460 |
+
return LMOutput(logits, logits_mask)
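`compute_predictions` returns logits already aligned with the input codes plus a validity mask, so a typical training step computes cross-entropy only over masked-in positions. The sketch below does this on random tensors; the `B`, `K`, `T`, `card` values are arbitrary and `LMOutput` is mimicked by two plain tensors.

```python
# Masked cross-entropy over an LMOutput-like (logits, mask) pair, on dummy data.
import torch
import torch.nn.functional as F

B, K, T, card = 2, 4, 25, 1024
logits = torch.randn(B, K, T, card)          # as returned in LMOutput.logits
mask = torch.rand(B, K, T) > 0.1             # True where the position is valid
targets = torch.randint(0, card, (B, K, T))  # the original codes

valid_logits = logits[mask]                  # [N, card]
valid_targets = targets[mask]                # [N]
loss = F.cross_entropy(valid_logits, valid_targets)
print(loss)
```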
|
461 |
+
|
462 |
+
|
463 |
+
def _sample_next_token(
|
464 |
+
self,
|
465 |
+
sequence: torch.Tensor,
|
466 |
+
cfg_conditions_list: tp.List,
|
467 |
+
unconditional_state: State,
|
468 |
+
use_sampling: bool = False,
|
469 |
+
temp: float = 1.0,
|
470 |
+
top_k: int = 0,
|
471 |
+
top_p: float = 0.0,
|
472 |
+
cfg_coef: tp.Optional[float] = None,
|
473 |
+
two_step_cfg: tp.Optional[bool] = None,
|
474 |
+
precomputed_video_emb: tp.Optional[torch.Tensor] = None  # newly added parameter
|
475 |
+
) -> torch.Tensor:
|
476 |
+
"""Sample next token from the model given a sequence and a set of conditions. The model supports
|
477 |
+
multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
|
478 |
+
|
479 |
+
Args:
|
480 |
+
sequence (torch.Tensor): Current sequence of shape [B, K, S]
|
481 |
+
with K corresponding to the number of codebooks and S the number of sequence steps.
|
482 |
+
S = 1 in streaming mode, except for the first step that contains a bigger prompt.
|
483 |
+
condition_tensors (dict[str, ConditionType]): Set of conditions. If CFG is used,
|
484 |
+
should be twice the batch size, being the concatenation of the conditions + null conditions.
|
485 |
+
use_sampling (bool): Whether to use a sampling strategy or not.
|
486 |
+
temp (float): Sampling temperature.
|
487 |
+
top_k (int): K for "top-k" sampling.
|
488 |
+
top_p (float): P for "top-p" sampling.
|
489 |
+
cfg_coef (float, optional): classifier free guidance coefficient.
|
490 |
+
Returns:
|
491 |
+
next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
|
492 |
+
"""
|
493 |
+
B = sequence.shape[0]
|
494 |
+
cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
|
495 |
+
model = self if self._fsdp is None else self._fsdp
|
496 |
+
two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
|
497 |
+
|
498 |
+
assert isinstance(cfg_conditions_list, list)
|
499 |
+
assert len(cfg_conditions_list) == 2
|
500 |
+
local_cfg_conditions = cfg_conditions_list[0]
|
501 |
+
global_cfg_conditions = cfg_conditions_list[1]
|
502 |
+
|
503 |
+
if two_step_cfg and local_cfg_conditions != {}:
|
504 |
+
assert isinstance(local_cfg_conditions, tuple), type(local_cfg_conditions)
|
505 |
+
local_condition_tensors, local_null_condition_tensors = local_cfg_conditions
|
506 |
+
global_condition_tensors, global_null_condition_tensors = global_cfg_conditions
|
507 |
+
cond_logits = model(sequence, conditions=[], condition_tensors=[local_condition_tensors, global_condition_tensors])
|
508 |
+
|
509 |
+
state = self.get_streaming_state()
|
510 |
+
self.set_streaming_state(unconditional_state)
|
511 |
+
uncond_logits = model(sequence, conditions=[], condition_tensors=[local_null_condition_tensors, global_null_condition_tensors])
|
512 |
+
unconditional_state.update(self.get_streaming_state())
|
513 |
+
self.set_streaming_state(state)
|
514 |
+
logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_coef
|
515 |
+
else:
|
516 |
+
local_condition_tensors = cfg_conditions_list[0].to(sequence.device)
|
517 |
+
global_condition_tensors = cfg_conditions_list[1].to(sequence.device)
|
518 |
+
sequence = torch.cat([sequence, sequence], dim=0)
|
519 |
+
|
520 |
+
if precomputed_video_emb is None:
|
521 |
+
video_emb = self.compute_video_emb([cfg_conditions_list[0], cfg_conditions_list[1]], device=sequence.device)
|
522 |
+
else:
|
523 |
+
video_emb = precomputed_video_emb
|
524 |
+
|
525 |
+
all_logits = model(
|
526 |
+
sequence,
|
527 |
+
conditions=[],
|
528 |
+
video_tensor_list=[],
|
529 |
+
precomputed_video_emb=video_emb
|
530 |
+
)
|
531 |
+
cond_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card]
|
532 |
+
logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
|
533 |
+
|
534 |
+
logits = logits.permute(0, 1, 3, 2) # [B, K, card, T]
|
535 |
+
logits = logits[..., -1] # [B x K x card]
|
536 |
+
|
537 |
+
# Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
|
538 |
+
if use_sampling and temp > 0.0:
|
539 |
+
probs = torch.softmax(logits / temp, dim=-1)
|
540 |
+
if top_p > 0.0:
|
541 |
+
next_token = utils.sample_top_p(probs, p=top_p)
|
542 |
+
elif top_k > 0:
|
543 |
+
next_token = utils.sample_top_k(probs, k=top_k)
|
544 |
+
else:
|
545 |
+
next_token = utils.multinomial(probs, num_samples=1)
|
546 |
+
else:
|
547 |
+
next_token = torch.argmax(logits, dim=-1, keepdim=True)
|
548 |
+
return next_token
|
549 |
+
|
550 |
+
|
551 |
+
@torch.no_grad()
|
552 |
+
def generate(self,
|
553 |
+
prompt: tp.Optional[torch.Tensor] = None,
|
554 |
+
conditions_list: tp.List = [],
|
555 |
+
num_samples: tp.Optional[int] = None,
|
556 |
+
max_gen_len: int = 256,
|
557 |
+
use_sampling: bool = True,
|
558 |
+
temp: float = 1.0,
|
559 |
+
top_k: int = 250,
|
560 |
+
top_p: float = 0.0,
|
561 |
+
cfg_coef: tp.Optional[float] = None,
|
562 |
+
two_step_cfg: tp.Optional[bool] = None,
|
563 |
+
remove_prompts: bool = False,
|
564 |
+
check: bool = False,
|
565 |
+
callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
|
566 |
+
"""Generate tokens sampling from the model given a prompt or unconditionally. Generation can
|
567 |
+
be performed in a greedy fashion or using sampling with top-k and top-p strategies.
|
568 |
+
|
569 |
+
Args:
|
570 |
+
prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
|
571 |
+
conditions_list (list, optional): List holding the local and global conditioning inputs.
|
572 |
+
num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
|
573 |
+
max_gen_len (int): Maximum generation length.
|
574 |
+
use_sampling (bool): Whether to use a sampling strategy or not.
|
575 |
+
temp (float): Sampling temperature.
|
576 |
+
top_k (int): K for "top-k" sampling.
|
577 |
+
top_p (float): P for "top-p" sampling.
|
578 |
+
cfg_coef (float, optional): Classifier-free guidance coefficient.
|
579 |
+
two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
|
580 |
+
remove_prompts (bool): Whether to remove prompts from generation or not.
|
581 |
+
check (bool): Whether to apply further checks on generated sequence.
|
582 |
+
callback (Callback, optional): Callback function to report generation progress.
|
583 |
+
Returns:
|
584 |
+
torch.Tensor: Generated tokens.
|
585 |
+
"""
|
586 |
+
assert not self.training, "generation shouldn't be used in training mode."
|
587 |
+
first_param = next(iter(self.parameters()))
|
588 |
+
device = first_param.device
|
589 |
+
assert isinstance(conditions_list, list)
|
590 |
+
|
591 |
+
assert len(conditions_list) == 2
|
592 |
+
local_conditions = conditions_list[0]
|
593 |
+
global_conditions = conditions_list[1]
|
594 |
+
# Checking all input shapes are consistent.
|
595 |
+
possible_num_samples = []
|
596 |
+
if num_samples is not None:
|
597 |
+
possible_num_samples.append(num_samples)
|
598 |
+
elif prompt is not None:
|
599 |
+
possible_num_samples.append(prompt.shape[0])
|
600 |
+
elif local_conditions is not None:
|
601 |
+
possible_num_samples.append(len(local_conditions))
|
602 |
+
else:
|
603 |
+
possible_num_samples.append(1)
|
604 |
+
|
605 |
+
assert all(x == possible_num_samples[0] for x in possible_num_samples), "Inconsistent input shapes"
|
606 |
+
num_samples = possible_num_samples[0]
|
607 |
+
|
608 |
+
local_cfg_conditions: CFGConditions
|
609 |
+
global_cfg_conditions: CFGConditions
|
610 |
+
two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
|
611 |
+
local_null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(local_conditions)
|
612 |
+
local_cfg_conditions = torch.cat((local_conditions, local_null_conditions), dim=0)
|
613 |
+
global_null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(global_conditions)
|
614 |
+
global_cfg_conditions = torch.cat((global_conditions, global_null_conditions), dim=0)
|
615 |
+
|
616 |
+
if prompt is None:
|
617 |
+
assert num_samples > 0
|
618 |
+
prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
|
619 |
+
|
620 |
+
B, K, T = prompt.shape
|
621 |
+
start_offset = T
|
622 |
+
assert start_offset < max_gen_len
|
623 |
+
|
624 |
+
pattern = self.pattern_provider.get_pattern(max_gen_len)
|
625 |
+
# this token is used as default value for codes that are not generated yet
|
626 |
+
unknown_token = -1
|
627 |
+
|
628 |
+
|
629 |
+
gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
|
630 |
+
gen_codes[..., :start_offset] = prompt
|
631 |
+
gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
|
632 |
+
start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
|
633 |
+
assert start_offset_sequence is not None
|
634 |
+
|
635 |
+
video_emb = self.compute_video_emb([local_cfg_conditions, global_cfg_conditions], device=device)
|
636 |
+
|
637 |
+
with self.streaming():
|
638 |
+
unconditional_state = self.get_streaming_state()
|
639 |
+
prev_offset = 0
|
640 |
+
gen_sequence_len = gen_sequence.shape[-1]
|
641 |
+
|
642 |
+
for offset in range(start_offset_sequence, gen_sequence_len):
|
643 |
+
curr_sequence = gen_sequence[..., prev_offset:offset]
|
644 |
+
curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
|
645 |
+
if check:
|
646 |
+
assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
|
647 |
+
assert not (curr_sequence == unknown_token).any()
|
648 |
+
next_token = self._sample_next_token(
|
649 |
+
curr_sequence,
|
650 |
+
[local_cfg_conditions, global_cfg_conditions],
|
651 |
+
unconditional_state,
|
652 |
+
use_sampling,
|
653 |
+
temp,
|
654 |
+
top_k,
|
655 |
+
top_p,
|
656 |
+
cfg_coef=cfg_coef,
|
657 |
+
two_step_cfg=two_step_cfg,
|
658 |
+
precomputed_video_emb=video_emb  # reuse the video embedding computed once before the loop
|
659 |
+
)
|
660 |
+
valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
|
661 |
+
next_token[~valid_mask] = self.special_token_id
|
662 |
+
gen_sequence[..., offset:offset+1] = torch.where(
|
663 |
+
gen_sequence[..., offset:offset+1] == unknown_token,
|
664 |
+
next_token,
|
665 |
+
gen_sequence[..., offset:offset+1]
|
666 |
+
)
|
667 |
+
prev_offset = offset
|
668 |
+
if callback is not None:
|
669 |
+
callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
|
670 |
+
|
671 |
+
unconditional_state.clear()
|
672 |
+
assert not (gen_sequence == unknown_token).any()
|
673 |
+
assert (
|
674 |
+
gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
|
675 |
+
).all()
|
676 |
+
out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
|
677 |
+
|
678 |
+
assert (out_codes[..., :max_gen_len] != unknown_token).all()
|
679 |
+
assert (out_mask[..., :max_gen_len] == 1).all()
|
680 |
+
|
681 |
+
out_start_offset = start_offset if remove_prompts else 0
|
682 |
+
out_codes = out_codes[..., out_start_offset:max_gen_len]
|
683 |
+
|
684 |
+
assert (out_codes >= 0).all() and (out_codes <= self.card).all()
|
685 |
+
return out_codes
|
audiocraft/models/lm_back.py
ADDED
@@ -0,0 +1,698 @@
1 |
+
# Modified from Audiocraft (https://github.com/facebookresearch/audiocraft)
|
2 |
+
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from functools import partial
|
5 |
+
import logging
|
6 |
+
import math
|
7 |
+
import typing as tp
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from torch import nn
|
11 |
+
|
12 |
+
from ..utils import utils
|
13 |
+
from ..modules.streaming import StreamingModule, State
|
14 |
+
from ..modules.transformer import StreamingTransformer, create_norm_fn
|
15 |
+
|
16 |
+
import time
|
17 |
+
from ..modules.conditioners import (
|
18 |
+
ConditionFuser,
|
19 |
+
ClassifierFreeGuidanceDropout,
|
20 |
+
AttributeDropout,
|
21 |
+
ConditioningProvider,
|
22 |
+
ConditioningAttributes,
|
23 |
+
ConditionType,
|
24 |
+
)
|
25 |
+
from ..modules.codebooks_patterns import CodebooksPatternProvider
|
26 |
+
from ..modules.activations import get_activation_fn
|
27 |
+
import warnings
|
28 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.transforms._transforms_video")
|
29 |
+
import torch.nn.init as init
|
30 |
+
import os
|
31 |
+
|
32 |
+
import logging
|
33 |
+
import random
|
34 |
+
import sys
|
35 |
+
import einops
|
36 |
+
from .transformer_module import Attention, PreNorm, FeedForward
|
37 |
+
from transformers import AutoProcessor, CLIPVisionModelWithProjection, VideoMAEModel
|
38 |
+
|
39 |
+
|
40 |
+
logger = logging.getLogger(__name__)
|
41 |
+
ConditionTensors = tp.Dict[str, ConditionType]
|
42 |
+
CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
|
43 |
+
|
44 |
+
|
45 |
+
def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
|
46 |
+
"""LM layer initialization.
|
47 |
+
Inspired from xlformers: https://github.com/fairinternal/xlformers
|
48 |
+
|
49 |
+
Args:
|
50 |
+
method (str): Method name for init function. Valid options are:
|
51 |
+
'gaussian', 'uniform'.
|
52 |
+
input_dim (int): Input dimension of the initialized module.
|
53 |
+
init_depth (int, optional): Optional init depth value used to rescale
|
54 |
+
the standard deviation if defined.
|
55 |
+
"""
|
56 |
+
# Compute std
|
57 |
+
std = 1 / math.sqrt(input_dim)
|
58 |
+
# Rescale with depth
|
59 |
+
if init_depth is not None:
|
60 |
+
std = std / math.sqrt(2 * init_depth)
|
61 |
+
|
62 |
+
if method == 'gaussian':
|
63 |
+
return partial(
|
64 |
+
torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
|
65 |
+
)
|
66 |
+
elif method == 'uniform':
|
67 |
+
bound = math.sqrt(3) * std # ensure the standard deviation is `std`
|
68 |
+
return partial(torch.nn.init.uniform_, a=-bound, b=bound)
|
69 |
+
else:
|
70 |
+
raise ValueError("Unsupported layer initialization method")
|
71 |
+
|
72 |
+
|
73 |
+
def init_layer(m: nn.Module,
|
74 |
+
method: str,
|
75 |
+
init_depth: tp.Optional[int] = None,
|
76 |
+
zero_bias_init: bool = False):
|
77 |
+
"""Wrapper around ``get_init_fn`` for proper initialization of LM modules.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
m (nn.Module): Module to initialize.
|
81 |
+
method (str): Method name for the init function.
|
82 |
+
init_depth (int, optional): Optional init depth value used to rescale
|
83 |
+
the standard deviation if defined.
|
84 |
+
zero_bias_init (bool): Whether to initialize the bias to 0 or not.
|
85 |
+
"""
|
86 |
+
if isinstance(m, nn.Linear):
|
87 |
+
init_fn = get_init_fn(method, m.in_features, init_depth=init_depth)
|
88 |
+
if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
|
89 |
+
weight = m.weight.float()
|
90 |
+
init_fn(weight)
|
91 |
+
m.weight.data[:] = weight.half()
|
92 |
+
else:
|
93 |
+
init_fn(m.weight)
|
94 |
+
if zero_bias_init and m.bias is not None:
|
95 |
+
nn.init.constant_(m.bias, 0)
|
96 |
+
elif isinstance(m, nn.Embedding):
|
97 |
+
init_fn = get_init_fn(method, m.embedding_dim, init_depth=None)
|
98 |
+
if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
|
99 |
+
weight = m.weight.float()
|
100 |
+
init_fn(weight)
|
101 |
+
m.weight.data[:] = weight.half()
|
102 |
+
else:
|
103 |
+
init_fn(m.weight)
|
104 |
+
|
105 |
+
|
106 |
+
class ScaledEmbedding(nn.Embedding):
|
107 |
+
"""Boost learning rate for embeddings (with `scale`).
|
108 |
+
"""
|
109 |
+
def __init__(self, *args, lr=None, **kwargs):
|
110 |
+
super().__init__(*args, **kwargs)
|
111 |
+
self.lr = lr
|
112 |
+
|
113 |
+
def make_optim_group(self):
|
114 |
+
group = {"params": list(self.parameters())}
|
115 |
+
if self.lr is not None:
|
116 |
+
group["lr"] = self.lr
|
117 |
+
return group
|
118 |
+
|
119 |
+
|
120 |
+
@dataclass
|
121 |
+
class LMOutput:
|
122 |
+
# The logits are already re-aligned with the input codes
|
123 |
+
# hence no extra shift is required, e.g. when computing CE
|
124 |
+
logits: torch.Tensor # [B, K, T, card]
|
125 |
+
mask: torch.Tensor # [B, K, T]
|
126 |
+
|
127 |
+
class Transformer(nn.Module):
|
128 |
+
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
129 |
+
super().__init__()
|
130 |
+
self.layers = nn.ModuleList([])
|
131 |
+
self.norm = nn.LayerNorm(dim)
|
132 |
+
for _ in range(depth):
|
133 |
+
self.layers.append(nn.ModuleList([
|
134 |
+
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
135 |
+
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
136 |
+
]))
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
for attn, ff in self.layers:
|
140 |
+
x = attn(x) + x
|
141 |
+
x = ff(x) + x
|
142 |
+
return self.norm(x)
|
143 |
+
|
144 |
+
class MultiHeadCrossAttention(nn.Module):
|
145 |
+
def __init__(self, x1, num_heads):
|
146 |
+
super().__init__()
|
147 |
+
self.num_heads = num_heads
|
148 |
+
self.depth = x1 // num_heads
|
149 |
+
|
150 |
+
self.query = nn.Linear(x1, x1)
|
151 |
+
self.key = nn.Linear(x1, x1)
|
152 |
+
self.value = nn.Linear(x1, x1)
|
153 |
+
|
154 |
+
self.final_linear = nn.Linear(x1, x1)
|
155 |
+
|
156 |
+
self.norm1 = nn.LayerNorm(x1)
|
157 |
+
self.norm2 = nn.LayerNorm(x1)
|
158 |
+
|
159 |
+
init.constant_(self.final_linear.weight, 0)
|
160 |
+
if self.final_linear.bias is not None:
|
161 |
+
init.constant_(self.final_linear.bias, 0)
|
162 |
+
|
163 |
+
def split_heads(self, x, batch_size):
|
164 |
+
x = x.view(batch_size, -1, self.num_heads, self.depth)
|
165 |
+
return x.permute(0, 2, 1, 3)
|
166 |
+
|
167 |
+
def forward(self, tensor_A, tensor_B):
|
168 |
+
batch_size = tensor_A.size(0)
|
169 |
+
|
170 |
+
Q = self.split_heads(self.query(tensor_A), batch_size)
|
171 |
+
K = self.split_heads(self.key(tensor_B), batch_size)
|
172 |
+
V = self.split_heads(self.value(tensor_B), batch_size)
|
173 |
+
|
174 |
+
attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.depth ** 0.5)
|
175 |
+
attention_scores = torch.softmax(attention_scores, dim=-1)
|
176 |
+
|
177 |
+
attention_output = torch.matmul(attention_scores, V)
|
178 |
+
attention_output = attention_output.permute(0, 2, 1, 3).contiguous()
|
179 |
+
|
180 |
+
output = attention_output.view(batch_size, -1, self.num_heads * self.depth)
|
181 |
+
|
182 |
+
output = self.norm1(output + tensor_A)
|
183 |
+
output = self.norm2(self.final_linear(output) + output)
|
184 |
+
return output
|
185 |
+
|
186 |
+
|
187 |
+
def evenly_sample_or_duplicate_frames(video_tensor, target_frames=32):
|
188 |
+
num_frames = video_tensor.size(0)
|
189 |
+
if target_frames <= num_frames:
|
190 |
+
indices = torch.linspace(0, num_frames - 1, steps=target_frames).long()
|
191 |
+
return video_tensor[indices]
|
192 |
+
else:
|
193 |
+
scale_factor = target_frames / num_frames
|
194 |
+
repeated_indices = (torch.arange(target_frames) / scale_factor).long()
|
195 |
+
return video_tensor[repeated_indices]
|
196 |
+
|
197 |
+
class LMModel(StreamingModule):
|
198 |
+
"""Transformer-based language model on multiple streams of codes.
|
199 |
+
|
200 |
+
Args:
|
201 |
+
pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
|
202 |
+
condition_provider (MusicConditioningProvider): Conditioning provider from metadata.
|
203 |
+
fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
|
204 |
+
n_q (int): Number of parallel streams to model.
|
205 |
+
card (int): Cardinality, vocabulary size.
|
206 |
+
dim (int): Dimension of the transformer encoder.
|
207 |
+
num_heads (int): Number of heads for the transformer encoder.
|
208 |
+
hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
|
209 |
+
norm (str): Normalization method.
|
210 |
+
norm_first (bool): Use pre-norm instead of post-norm.
|
211 |
+
emb_lr (float, optional): Embedding-specific learning rate.
|
212 |
+
bias_proj (bool): Use bias for output projections.
|
213 |
+
weight_init (str, optional): Method for weight initialization.
|
214 |
+
depthwise_init (str, optional): Method for depthwise weight initialization.
|
215 |
+
zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
|
216 |
+
cfg_dropout (float): Classifier-free guidance dropout.
|
217 |
+
cfg_coef (float): Classifier-free guidance coefficient.
|
218 |
+
attribute_dropout (dict): Attribute dropout probabilities.
|
219 |
+
two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
|
220 |
+
**kwargs: Additional parameters for the transformer encoder.
|
221 |
+
"""
|
222 |
+
|
223 |
+
def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
|
224 |
+
visual_encoder,
|
225 |
+
if_add_gobal,
|
226 |
+
fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
|
227 |
+
hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
|
228 |
+
emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
|
229 |
+
weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
|
230 |
+
zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
|
231 |
+
attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False,
|
232 |
+
depth=2,
|
233 |
+
temporal_dim=768,
|
234 |
+
dim_head=64,
|
235 |
+
**kwargs):
|
236 |
+
super().__init__()
|
237 |
+
self.cfg_coef = cfg_coef
|
238 |
+
self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
|
239 |
+
self.att_dropout = AttributeDropout(p=attribute_dropout)
|
240 |
+
self.condition_provider = condition_provider
|
241 |
+
self.visual_encoder=visual_encoder
|
242 |
+
self.if_add_gobal=if_add_gobal
|
243 |
+
self.temporal_dim=temporal_dim
|
244 |
+
|
245 |
+
self.fuser = fuser
|
246 |
+
self.card = card
|
247 |
+
embed_dim = self.card + 1
|
248 |
+
self.n_q = n_q
|
249 |
+
self.dim = dim
|
250 |
+
self.pattern_provider = pattern_provider
|
251 |
+
self.two_step_cfg = two_step_cfg
|
252 |
+
self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
|
253 |
+
if 'activation' in kwargs:
|
254 |
+
kwargs['activation'] = get_activation_fn(kwargs['activation'])
|
255 |
+
self.transformer = StreamingTransformer(
|
256 |
+
d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
|
257 |
+
norm=norm, norm_first=norm_first, **kwargs)
|
258 |
+
|
259 |
+
|
260 |
+
self.out_norm: tp.Optional[nn.Module] = None
|
261 |
+
if norm_first:
|
262 |
+
self.out_norm = create_norm_fn(norm, dim)
|
263 |
+
self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
|
264 |
+
self._init_weights(weight_init, depthwise_init, zero_bias_init)
|
265 |
+
self._fsdp: tp.Optional[nn.Module]
|
266 |
+
self.__dict__['_fsdp'] = None
|
267 |
+
|
268 |
+
if self.visual_encoder=='clip':
|
269 |
+
self.visual_encoder_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
|
270 |
+
self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
271 |
+
|
272 |
+
else:
|
273 |
+
print(f'The encoder now is: {self.visual_encoder}')
|
274 |
+
print('Please specify a supported video encoder.')
|
275 |
+
exit()
|
276 |
+
|
277 |
+
if self.visual_encoder=='clip':
|
278 |
+
temporal_dim=768
|
279 |
+
self.local_pos_embedding = nn.Parameter(torch.randn(1, 50, temporal_dim))
|
280 |
+
self.visual_encoder_model = self.visual_encoder_model.eval()
|
281 |
+
for param in self.visual_encoder_model.parameters():
|
282 |
+
param.requires_grad = False
|
283 |
+
|
284 |
+
self.local_temporal_transformer = Transformer(temporal_dim, depth, num_heads, dim_head, temporal_dim*hidden_scale, 0.) # [768, 4, 16, 64, 768*4]
|
285 |
+
|
286 |
+
if self.if_add_gobal:
|
287 |
+
if self.visual_encoder=='clip':
|
288 |
+
self.global_pos_embedding = nn.Parameter(torch.randn(1, 50, temporal_dim))
|
289 |
+
|
290 |
+
self.global_temporal_transformer = Transformer(temporal_dim, depth, num_heads, dim_head, temporal_dim*hidden_scale, 0.) # [768, 4, 16, 64, 768*4]
|
291 |
+
|
292 |
+
cross_attention_num_heads = 3 # MultiHeadCrossAttention
|
293 |
+
self.multi_head_cross_attention = MultiHeadCrossAttention(temporal_dim, cross_attention_num_heads)
|
294 |
+
|
295 |
+
self.visual_feature_proj = nn.Linear(temporal_dim, dim)
|
296 |
+
|
297 |
+
|
298 |
+
def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
|
299 |
+
"""Initialization of the transformer module weights.
|
300 |
+
|
301 |
+
Args:
|
302 |
+
weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options.
|
303 |
+
depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
|
304 |
+
'current' where the depth corresponds to the current layer index or 'global' where the total number
|
305 |
+
of layer is used as depth. If not set, no depthwise initialization strategy is used.
|
306 |
+
zero_bias_init (bool): Whether to initialize bias to zero or not.
|
307 |
+
"""
|
308 |
+
assert depthwise_init is None or depthwise_init in ['current', 'global']
|
309 |
+
assert depthwise_init is None or weight_init is not None, \
|
310 |
+
"If 'depthwise_init' is defined, a 'weight_init' method should be provided."
|
311 |
+
assert not zero_bias_init or weight_init is not None, \
|
312 |
+
"If 'zero_bias_init', a 'weight_init' method should be provided"
|
313 |
+
|
314 |
+
if weight_init is None:
|
315 |
+
return
|
316 |
+
|
317 |
+
for emb_layer in self.emb:
|
318 |
+
init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
|
319 |
+
|
320 |
+
for layer_idx, tr_layer in enumerate(self.transformer.layers):
|
321 |
+
depth = None
|
322 |
+
if depthwise_init == 'current':
|
323 |
+
depth = layer_idx + 1
|
324 |
+
elif depthwise_init == 'global':
|
325 |
+
depth = len(self.transformer.layers)
|
326 |
+
init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
|
327 |
+
tr_layer.apply(init_fn)
|
328 |
+
|
329 |
+
for linear in self.linears:
|
330 |
+
init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
|
331 |
+
|
332 |
+
@property
|
333 |
+
def special_token_id(self) -> int:
|
334 |
+
return self.card
|
335 |
+
|
336 |
+
@property
|
337 |
+
def num_codebooks(self) -> int:
|
338 |
+
return self.n_q
|
339 |
+
|
340 |
+
def forward(self, sequence: torch.Tensor,
|
341 |
+
conditions: tp.List[ConditioningAttributes],
|
342 |
+
video_tensor_list: tp.List) -> torch.Tensor:
|
343 |
+
"""Apply language model on sequence and conditions.
|
344 |
+
Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
|
345 |
+
S the sequence steps, return the logits with shape [B, card, K, S].
|
346 |
+
|
347 |
+
Args:
|
348 |
+
            sequence (torch.Tensor): Indices of the codes to model.
|
349 |
+
conditions (list of ConditioningAttributes): Conditions to use when modeling
|
350 |
+
                the given codes. Note that when evaluating multiple times with the same conditioning
|
351 |
+
you should pre-compute those and pass them as `condition_tensors`.
|
352 |
+
# condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
|
353 |
+
# tensors, see `conditions`.
|
354 |
+
            video_tensor_list (list of torch.Tensor): Local and global video frame tensors of shape [b c t h w].
|
355 |
+
Returns:
|
356 |
+
torch.Tensor: Logits.
|
357 |
+
"""
|
358 |
+
|
359 |
+
B, K, S = sequence.shape
|
360 |
+
assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks"
|
361 |
+
input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
|
362 |
+
self.device = input_.device
|
363 |
+
assert self.device != "cpu"
|
364 |
+
|
365 |
+
if self.visual_encoder=='clip':
|
366 |
+
visual_encoder_model = self.visual_encoder_model
|
367 |
+
processor = self.processor
|
368 |
+
|
369 |
+
assert isinstance(video_tensor_list, list)
|
370 |
+
|
371 |
+
assert self.if_add_gobal
|
372 |
+
assert len(video_tensor_list)==2
|
373 |
+
|
374 |
+
[local_video_tensor, global_video_tensor] = video_tensor_list
|
375 |
+
local_image = local_video_tensor.to(dtype=torch.float32)
|
376 |
+
global_image = global_video_tensor.to(dtype=torch.float32)
|
377 |
+
|
378 |
+
local_batch_size,_,local_time_length,_,_ = local_image.size()
|
379 |
+
local_image = einops.rearrange(local_image, 'b c t h w -> (b t) c h w')
|
380 |
+
|
381 |
+
global_batch_size,_,global_time_length,_,_ = global_image.size()
|
382 |
+
global_image = einops.rearrange(global_image, 'b c t h w -> (b t) c h w')
|
383 |
+
|
384 |
+
local_temporal_transformer = self.local_temporal_transformer
|
385 |
+
global_temporal_transformer = self.global_temporal_transformer
|
386 |
+
|
387 |
+
local_video_inputs = processor(images=local_image.float(), return_tensors="pt")
|
388 |
+
local_pixel_values = local_video_inputs['pixel_values'].to(self.device)
|
389 |
+
|
390 |
+
global_video_inputs = processor(images=global_image.float(), return_tensors="pt")
|
391 |
+
global_pixel_values = global_video_inputs['pixel_values'].to(self.device)
|
392 |
+
|
393 |
+
if self.visual_encoder=='clip':
|
394 |
+
with torch.no_grad():
|
395 |
+
local_video_hidden = visual_encoder_model(pixel_values=local_pixel_values).last_hidden_state
|
396 |
+
local_video_hidden += self.local_pos_embedding
|
397 |
+
local_video_hidden = local_temporal_transformer(local_video_hidden)
|
398 |
+
local_video_hidden = einops.rearrange(local_video_hidden, '(b t) q h -> b (t q) h',b=local_batch_size,t=local_time_length)
|
399 |
+
with torch.no_grad():
|
400 |
+
global_video_hidden = visual_encoder_model(pixel_values=global_pixel_values).last_hidden_state
|
401 |
+
global_video_hidden += self.global_pos_embedding
|
402 |
+
global_video_hidden = global_temporal_transformer(global_video_hidden)
|
403 |
+
global_video_hidden = einops.rearrange(global_video_hidden, '(b t) q h -> b (t q) h',b=global_batch_size,t=global_time_length)
|
404 |
+
|
405 |
+
assert local_batch_size==global_batch_size
|
406 |
+
video_hidden = self.multi_head_cross_attention(local_video_hidden, global_video_hidden)
|
407 |
+
video_emb = self.visual_feature_proj(video_hidden)
|
408 |
+
|
409 |
+
out = self.transformer(input_, cross_attention_src=video_emb)
|
410 |
+
if self.out_norm:
|
411 |
+
out = self.out_norm(out)
|
412 |
+
logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1)
|
413 |
+
|
414 |
+
# remove the prefix from the model outputs
|
415 |
+
if len(self.fuser.fuse2cond['prepend']) > 0:
|
416 |
+
logits = logits[:, :, -S:]
|
417 |
+
return logits # [B, K, S, card]
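
The shape gymnastics in `forward` above (per-frame CLIP tokens, learned positions, then flattening frames into one long cross-attention source) are easier to see in isolation. Below is a minimal sketch with a random tensor standing in for the CLIP output; the 50-token / 768-dim constants follow the code above, while `lm_dim` is an arbitrary stand-in for the transformer width `dim`.

```python
# Minimal sketch of the visual-feature shape bookkeeping in `forward`.
# The CLIP encoder is replaced by a random tensor; only the reshapes and the
# final projection are real. 50 tokens / 768 dims follow the constants above.
import torch
import torch.nn as nn
import einops

B, T = 2, 6            # batch size, number of sampled frames
temporal_dim = 768     # CLIP ViT-B/32 hidden size
num_tokens = 50        # 49 patch tokens + 1 CLS token
lm_dim = 1024          # hypothetical LM width ("dim" in the model above)

# stand-in for CLIP's last_hidden_state over (B*T) frames
frame_feats = torch.randn(B * T, num_tokens, temporal_dim)

pos_embedding = nn.Parameter(torch.randn(1, num_tokens, temporal_dim))
frame_feats = frame_feats + pos_embedding            # add learnable positions

# fold the per-frame token axis into one long sequence per video
video_hidden = einops.rearrange(frame_feats, '(b t) q h -> b (t q) h', b=B, t=T)
print(video_hidden.shape)                            # torch.Size([2, 300, 768])

# project to the LM width so it can serve as the cross-attention source
proj = nn.Linear(temporal_dim, lm_dim)
video_emb = proj(video_hidden)
print(video_emb.shape)                               # torch.Size([2, 300, 1024])
```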
|
418 |
+
|
419 |
+
|
420 |
+
def compute_predictions(
|
421 |
+
self, codes: torch.Tensor,
|
422 |
+
conditions: tp.List[ConditioningAttributes],
|
423 |
+
condition_tensors_list: tp.List) -> LMOutput:
|
424 |
+
"""Given an input tensor of codes [B, K, T] and list of conditions, runs the model
|
425 |
+
forward using the specified codes interleaving pattern.
|
426 |
+
|
427 |
+
Args:
|
428 |
+
codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
|
429 |
+
K the number of codebooks and T the number of timesteps.
|
430 |
+
conditions (list of ConditioningAttributes): conditionings to use when modeling
|
431 |
+
                the given codes. Note that when evaluating multiple times with the same conditioning
|
432 |
+
you should pre-compute those and pass them as `condition_tensors`.
|
433 |
+
condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning
|
434 |
+
tensors, see `conditions`.
|
435 |
+
Returns:
|
436 |
+
LMOutput: Language model outputs
|
437 |
+
logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
|
438 |
+
i.e. the first item corresponds to logits to predict the first code, meaning that
|
439 |
+
no additional shifting of codes and logits is required.
|
440 |
+
mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
|
441 |
+
Given the specified interleaving strategies, parts of the logits and codes should
|
442 |
+
not be considered as valid predictions because of invalid context.
|
443 |
+
"""
|
444 |
+
B, K, T = codes.shape
|
445 |
+
codes = codes.contiguous()
|
446 |
+
|
447 |
+
assert isinstance(condition_tensors_list,list)
|
448 |
+
# map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
|
449 |
+
pattern = self.pattern_provider.get_pattern(T)
|
450 |
+
sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
|
451 |
+
codes, self.special_token_id, keep_only_valid_steps=True
|
452 |
+
)
|
453 |
+
|
454 |
+
model = self if self._fsdp is None else self._fsdp
|
455 |
+
logits = model(sequence_codes, conditions, condition_tensors_list) # [B, K, S, card]
|
456 |
+
|
457 |
+
# apply model on pattern sequence
|
458 |
+
# map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
|
459 |
+
# and provide the corresponding mask over invalid positions of tokens
|
460 |
+
logits = logits.permute(0, 3, 1, 2) # [B, card, K, S]
|
461 |
+
# note: we use nans as special token to make it obvious if we feed unexpected logits
|
462 |
+
logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
|
463 |
+
logits, float('nan'), keep_only_valid_steps=True
|
464 |
+
)
|
465 |
+
logits = logits.permute(0, 2, 3, 1) # [B, K, T, card]
|
466 |
+
logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T]
|
467 |
+
return LMOutput(logits, logits_mask)
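
`compute_predictions` relies on the pattern provider to interleave the K codebooks before the transformer and to undo that interleaving on the logits. The toy below illustrates the underlying delay idea with plain tensors; it is not the audiocraft pattern API, just the build/revert round trip it guarantees.

```python
# Toy illustration of the "delay" interleaving idea behind the pattern provider
# (not the audiocraft API itself): codebook k is shifted right by k steps and
# padded with a special token; reverting undoes the shift.
import torch

K, T = 4, 6
special = -100
codes = torch.arange(K * T).reshape(K, T)     # fake codes, one row per codebook

# build: sequence length grows to T + K - 1 because of the per-codebook delay
S = T + K - 1
seq = torch.full((K, S), special)
for k in range(K):
    seq[k, k:k + T] = codes[k]

# revert: shift each codebook back to recover the original alignment
recovered = torch.stack([seq[k, k:k + T] for k in range(K)])
assert torch.equal(recovered, codes)
print(seq)
```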
|
468 |
+
|
469 |
+
def _sample_next_token(self,
|
470 |
+
sequence: torch.Tensor,
|
471 |
+
cfg_conditions_list: tp.List,
|
472 |
+
unconditional_state: State,
|
473 |
+
use_sampling: bool = False,
|
474 |
+
temp: float = 1.0,
|
475 |
+
top_k: int = 0,
|
476 |
+
top_p: float = 0.0,
|
477 |
+
cfg_coef: tp.Optional[float] = None,
|
478 |
+
two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor:
|
479 |
+
"""Sample next token from the model given a sequence and a set of conditions. The model supports
|
480 |
+
multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
|
481 |
+
|
482 |
+
Args:
|
483 |
+
sequence (torch.Tensor): Current sequence of shape [B, K, S]
|
484 |
+
with K corresponding to the number of codebooks and S the number of sequence steps.
|
485 |
+
S = 1 in streaming mode, except for the first step that contains a bigger prompt.
|
486 |
+
            condition_tensors (dict[str, ConditionType]): Set of conditions. If CFG is used,
|
487 |
+
should be twice the batch size, being the concatenation of the conditions + null conditions.
|
488 |
+
use_sampling (bool): Whether to use a sampling strategy or not.
|
489 |
+
temp (float): Sampling temperature.
|
490 |
+
top_k (int): K for "top-k" sampling.
|
491 |
+
top_p (float): P for "top-p" sampling.
|
492 |
+
cfg_coef (float, optional): classifier free guidance coefficient
|
493 |
+
Returns:
|
494 |
+
next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
|
495 |
+
"""
|
496 |
+
B = sequence.shape[0]
|
497 |
+
cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
|
498 |
+
model = self if self._fsdp is None else self._fsdp
|
499 |
+
two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
|
500 |
+
|
501 |
+
assert isinstance(cfg_conditions_list,list)
|
502 |
+
|
503 |
+
assert len(cfg_conditions_list)==2
|
504 |
+
local_cfg_conditions=cfg_conditions_list[0]
|
505 |
+
global_cfg_conditions=cfg_conditions_list[1]
|
506 |
+
if two_step_cfg and local_cfg_conditions != {}:
|
507 |
+
assert isinstance(local_cfg_conditions, tuple), type(local_cfg_conditions)
|
508 |
+
local_condition_tensors, local_null_condition_tensors = local_cfg_conditions
|
509 |
+
global_condition_tensors, global_null_condition_tensors = global_cfg_conditions
|
510 |
+
cond_logits = model(sequence, conditions=[], condition_tensors=[local_condition_tensors, global_condition_tensors])
|
511 |
+
|
512 |
+
state = self.get_streaming_state()
|
513 |
+
self.set_streaming_state(unconditional_state)
|
514 |
+
uncond_logits = model(sequence, conditions=[], condition_tensors=[local_null_condition_tensors, global_null_condition_tensors])
|
515 |
+
unconditional_state.update(self.get_streaming_state())
|
516 |
+
self.set_streaming_state(state)
|
517 |
+
logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_coef
|
518 |
+
else:
|
519 |
+
local_condition_tensors = local_cfg_conditions
|
520 |
+
sequence = torch.cat([sequence, sequence], dim=0)
|
521 |
+
local_condition_tensors = local_condition_tensors.to(sequence.device)
|
522 |
+
|
523 |
+
global_condition_tensors = global_cfg_conditions
|
524 |
+
global_condition_tensors = global_condition_tensors.to(sequence.device)
|
525 |
+
|
526 |
+
all_logits = model(
|
527 |
+
sequence,
|
528 |
+
conditions=[], video_tensor_list=[local_condition_tensors, global_condition_tensors])
|
529 |
+
cond_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card]
|
530 |
+
logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
|
531 |
+
|
532 |
+
logits = logits.permute(0, 1, 3, 2) # [B, K, card, T]
|
533 |
+
logits = logits[..., -1] # [B x K x card]
|
534 |
+
|
535 |
+
# Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
|
536 |
+
if use_sampling and temp > 0.0:
|
537 |
+
probs = torch.softmax(logits / temp, dim=-1)
|
538 |
+
if top_p > 0.0:
|
539 |
+
next_token = utils.sample_top_p(probs, p=top_p)
|
540 |
+
elif top_k > 0:
|
541 |
+
next_token = utils.sample_top_k(probs, k=top_k)
|
542 |
+
else:
|
543 |
+
next_token = utils.multinomial(probs, num_samples=1)
|
544 |
+
else:
|
545 |
+
next_token = torch.argmax(logits, dim=-1, keepdim=True)
|
546 |
+
return next_token
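
The last few lines of `_sample_next_token` do two things: blend conditional and unconditional logits with the CFG coefficient, then sample. The sketch below reproduces that logic on dummy logits using plain `torch` calls in place of the `utils.sample_top_k`/`utils.sample_top_p` helpers, so it is an illustration of the same idea rather than the exact utilities.

```python
# Illustrative CFG blend + top-k sampling on dummy logits (plain torch,
# not the audiocraft.utils helpers used above).
import torch

B, K, card = 2, 4, 2048
cfg_coef, temp, top_k = 3.0, 1.0, 250

cond_logits = torch.randn(B, K, card)
uncond_logits = torch.randn(B, K, card)

# classifier-free guidance: push conditional logits away from unconditional ones
logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef

probs = torch.softmax(logits / temp, dim=-1)

# top-k: keep the k most likely codes per codebook, renormalize, then sample
topk_probs, topk_idx = probs.topk(top_k, dim=-1)
topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
sampled = torch.multinomial(topk_probs.view(-1, top_k), num_samples=1)
next_token = topk_idx.view(-1, top_k).gather(-1, sampled).view(B, K, 1)
print(next_token.shape)  # torch.Size([2, 4, 1])
```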
|
547 |
+
|
548 |
+
@torch.no_grad()
|
549 |
+
def generate(self,
|
550 |
+
prompt: tp.Optional[torch.Tensor] = None,
|
551 |
+
conditions_list: tp.List = [],
|
552 |
+
num_samples: tp.Optional[int] = None,
|
553 |
+
max_gen_len: int = 256,
|
554 |
+
use_sampling: bool = True,
|
555 |
+
temp: float = 1.0,
|
556 |
+
top_k: int = 250,
|
557 |
+
top_p: float = 0.0,
|
558 |
+
cfg_coef: tp.Optional[float] = None,
|
559 |
+
two_step_cfg: tp.Optional[bool] = None,
|
560 |
+
remove_prompts: bool = False,
|
561 |
+
check: bool = False,
|
562 |
+
callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
|
563 |
+
"""Generate tokens sampling from the model given a prompt or unconditionally. Generation can
|
564 |
+
        be performed in a greedy fashion or using sampling with top-k and top-p strategies.
|
565 |
+
|
566 |
+
Args:
|
567 |
+
prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
|
568 |
+
            conditions_list (list): List with the local and global video conditions.
|
569 |
+
num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
|
570 |
+
max_gen_len (int): Maximum generation length.
|
571 |
+
use_sampling (bool): Whether to use a sampling strategy or not.
|
572 |
+
temp (float): Sampling temperature.
|
573 |
+
top_k (int): K for "top-k" sampling.
|
574 |
+
top_p (float): P for "top-p" sampling.
|
575 |
+
            cfg_coef (float, optional): Classifier-free guidance coefficient.
|
576 |
+
two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
|
577 |
+
remove_prompts (bool): Whether to remove prompts from generation or not.
|
578 |
+
check (bool): Whether to apply further checks on generated sequence.
|
579 |
+
callback (Callback, optional): Callback function to report generation progress.
|
580 |
+
Returns:
|
581 |
+
torch.Tensor: Generated tokens.
|
582 |
+
"""
|
583 |
+
assert not self.training, "generation shouldn't be used in training mode."
|
584 |
+
first_param = next(iter(self.parameters()))
|
585 |
+
device = first_param.device
|
586 |
+
assert isinstance(conditions_list,list)
|
587 |
+
|
588 |
+
assert len(conditions_list)==2
|
589 |
+
local_conditions=conditions_list[0]
|
590 |
+
global_conditions=conditions_list[1]
|
591 |
+
# Checking all input shapes are consistent.
|
592 |
+
possible_num_samples = []
|
593 |
+
if num_samples is not None:
|
594 |
+
possible_num_samples.append(num_samples)
|
595 |
+
elif prompt is not None:
|
596 |
+
possible_num_samples.append(prompt.shape[0])
|
597 |
+
elif local_conditions is not None:
|
598 |
+
possible_num_samples.append(len(local_conditions))
|
599 |
+
else:
|
600 |
+
possible_num_samples.append(1)
|
601 |
+
|
602 |
+
        assert all(x == possible_num_samples[0] for x in possible_num_samples), "Inconsistent input shapes"
|
603 |
+
num_samples = possible_num_samples[0]
|
604 |
+
|
605 |
+
# below we create set of local_conditions: one conditional and one unconditional
|
606 |
+
# to do that we merge the regular condition together with the null condition
|
607 |
+
# we then do 1 forward pass instead of 2.
|
608 |
+
# the reason for that is two-fold:
|
609 |
+
# 1. it is about x2 faster than doing 2 forward passes
|
610 |
+
# 2. avoid the streaming API treating the 2 passes as part of different time steps
|
611 |
+
# We also support doing two different passes, in particular to ensure that
|
612 |
+
# the padding structure is exactly the same between train and test.
|
613 |
+
# With a batch size of 1, this can be slower though.
|
614 |
+
local_cfg_conditions: CFGConditions
|
615 |
+
global_cfg_conditions: CFGConditions
|
616 |
+
two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
|
617 |
+
local_null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(local_conditions)
|
618 |
+
local_cfg_conditions = torch.cat((local_conditions,local_null_conditions), dim=0)
|
619 |
+
|
620 |
+
global_null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(global_conditions)
|
621 |
+
global_cfg_conditions = torch.cat((global_conditions,global_null_conditions), dim=0)
|
622 |
+
|
623 |
+
if prompt is None:
|
624 |
+
assert num_samples > 0
|
625 |
+
prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
|
626 |
+
|
627 |
+
B, K, T = prompt.shape
|
628 |
+
start_offset = T
|
629 |
+
assert start_offset < max_gen_len
|
630 |
+
|
631 |
+
pattern = self.pattern_provider.get_pattern(max_gen_len)
|
632 |
+
# this token is used as default value for codes that are not generated yet
|
633 |
+
unknown_token = -1
|
634 |
+
|
635 |
+
# we generate codes up to the max_gen_len that will be mapped to the pattern sequence
|
636 |
+
gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
|
637 |
+
# filling the gen_codes with the prompt if needed
|
638 |
+
gen_codes[..., :start_offset] = prompt
|
639 |
+
# create the gen_sequence with proper interleaving from the pattern: [B, K, S]
|
640 |
+
gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
|
641 |
+
# retrieve the start_offset in the sequence:
|
642 |
+
# it is the first sequence step that contains the `start_offset` timestep
|
643 |
+
start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
|
644 |
+
assert start_offset_sequence is not None
|
645 |
+
|
646 |
+
with self.streaming():
|
647 |
+
unconditional_state = self.get_streaming_state()
|
648 |
+
prev_offset = 0
|
649 |
+
gen_sequence_len = gen_sequence.shape[-1] # gen_sequence shape is [B, K, S]
|
650 |
+
|
651 |
+
for offset in range(start_offset_sequence, gen_sequence_len):
|
652 |
+
# get current sequence (note that the streaming API is providing the caching over previous offsets)
|
653 |
+
curr_sequence = gen_sequence[..., prev_offset:offset]
|
654 |
+
curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
|
655 |
+
if check:
|
656 |
+
# check coherence between mask and sequence
|
657 |
+
assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
|
658 |
+
# should never happen as gen_sequence is filled progressively
|
659 |
+
assert not (curr_sequence == unknown_token).any()
|
660 |
+
# sample next token from the model, next token shape is [B, K, 1]
|
661 |
+
next_token = self._sample_next_token(
|
662 |
+
curr_sequence, [local_cfg_conditions, global_cfg_conditions], unconditional_state, use_sampling, temp, top_k, top_p,
|
663 |
+
cfg_coef=cfg_coef, two_step_cfg=two_step_cfg)
|
664 |
+
# ensure the tokens that should be masked are properly set to special_token_id
|
665 |
+
# as the model never output special_token_id
|
666 |
+
valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
|
667 |
+
next_token[~valid_mask] = self.special_token_id
|
668 |
+
# ensure we don't overwrite prompt tokens, we only write over unknown tokens
|
669 |
+
# (then mask tokens should be left as is as well, which is correct)
|
670 |
+
gen_sequence[..., offset:offset+1] = torch.where(
|
671 |
+
gen_sequence[..., offset:offset+1] == unknown_token,
|
672 |
+
next_token, gen_sequence[..., offset:offset+1]
|
673 |
+
)
|
674 |
+
prev_offset = offset
|
675 |
+
if callback is not None:
|
676 |
+
callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
|
677 |
+
|
678 |
+
unconditional_state.clear()
|
679 |
+
# ensure sequence has been entirely filled
|
680 |
+
assert not (gen_sequence == unknown_token).any()
|
681 |
+
# ensure gen_sequence pattern and mask are matching
|
682 |
+
# which means the gen_sequence is valid according to the pattern
|
683 |
+
assert (
|
684 |
+
gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
|
685 |
+
).all()
|
686 |
+
# get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
|
687 |
+
out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
|
688 |
+
|
689 |
+
# sanity checks over the returned codes and corresponding masks
|
690 |
+
assert (out_codes[..., :max_gen_len] != unknown_token).all()
|
691 |
+
assert (out_mask[..., :max_gen_len] == 1).all()
|
692 |
+
|
693 |
+
out_start_offset = start_offset if remove_prompts else 0
|
694 |
+
out_codes = out_codes[..., out_start_offset:max_gen_len]
|
695 |
+
|
696 |
+
# ensure the returned codes are all valid
|
697 |
+
assert (out_codes >= 0).all() and (out_codes <= self.card).all()
|
698 |
+
return out_codes
|
audiocraft/models/loaders.py
ADDED
@@ -0,0 +1,149 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Utility functions to load from the checkpoints.
|
9 |
+
Each checkpoint is a torch.saved dict with the following keys:
|
10 |
+
- 'xp.cfg': the hydra config as dumped during training. This should be used
|
11 |
+
to rebuild the object using the audiocraft.models.builders functions,
|
12 |
+
- 'model_best_state': a readily loadable best state for the model, including
|
13 |
+
the conditioner. The model obtained from `xp.cfg` should be compatible
|
14 |
+
with this state dict. In the case of a LM, the encodec model would not be
|
15 |
+
bundled along but instead provided separately.
|
16 |
+
|
17 |
+
Those functions also support loading from a remote location with the Torch Hub API.
|
18 |
+
They also support overriding some parameters, in particular the device and dtype
|
19 |
+
of the returned model.
|
20 |
+
"""
|
21 |
+
|
22 |
+
from pathlib import Path
|
23 |
+
from huggingface_hub import hf_hub_download
|
24 |
+
import typing as tp
|
25 |
+
import os
|
26 |
+
|
27 |
+
from omegaconf import OmegaConf, DictConfig
|
28 |
+
import torch
|
29 |
+
|
30 |
+
import audiocraft
|
31 |
+
from . import builders
|
32 |
+
from .encodec import CompressionModel
|
33 |
+
|
34 |
+
|
35 |
+
def get_audiocraft_cache_dir() -> tp.Optional[str]:
|
36 |
+
return os.environ.get('AUDIOCRAFT_CACHE_DIR', None)
|
37 |
+
|
38 |
+
|
39 |
+
def _get_state_dict(
|
40 |
+
file_or_url_or_id: tp.Union[Path, str],
|
41 |
+
filename: tp.Optional[str] = None,
|
42 |
+
device='cpu',
|
43 |
+
cache_dir: tp.Optional[str] = None,
|
44 |
+
):
|
45 |
+
if cache_dir is None:
|
46 |
+
cache_dir = get_audiocraft_cache_dir()
|
47 |
+
# Return the state dict either from a file or url
|
48 |
+
file_or_url_or_id = str(file_or_url_or_id)
|
49 |
+
assert isinstance(file_or_url_or_id, str)
|
50 |
+
|
51 |
+
if os.path.isfile(file_or_url_or_id):
|
52 |
+
return torch.load(file_or_url_or_id, map_location=device)
|
53 |
+
|
54 |
+
if os.path.isdir(file_or_url_or_id):
|
55 |
+
file = f"{file_or_url_or_id}/{filename}"
|
56 |
+
return torch.load(file, map_location=device)
|
57 |
+
|
58 |
+
elif file_or_url_or_id.startswith('https://'):
|
59 |
+
return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
|
60 |
+
|
61 |
+
else:
|
62 |
+
assert filename is not None, "filename needs to be defined if using HF checkpoints"
|
63 |
+
|
64 |
+
file = hf_hub_download(
|
65 |
+
repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir,
|
66 |
+
library_name="audiocraft", library_version=audiocraft.__version__)
|
67 |
+
return torch.load(file, map_location=device)
|
68 |
+
|
69 |
+
|
70 |
+
def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
|
71 |
+
return _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
|
72 |
+
|
73 |
+
|
74 |
+
def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
|
75 |
+
pkg = load_compression_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
|
76 |
+
if 'pretrained' in pkg:
|
77 |
+
return CompressionModel.get_pretrained(pkg['pretrained'], device=device)
|
78 |
+
cfg = OmegaConf.create(pkg['xp.cfg'])
|
79 |
+
cfg.device = str(device)
|
80 |
+
model = builders.get_compression_model(cfg)
|
81 |
+
model.load_state_dict(pkg['best_state'])
|
82 |
+
model.eval()
|
83 |
+
return model
|
84 |
+
|
85 |
+
|
86 |
+
def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
|
87 |
+
return _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
|
88 |
+
|
89 |
+
|
90 |
+
def _delete_param(cfg: DictConfig, full_name: str):
|
91 |
+
parts = full_name.split('.')
|
92 |
+
for part in parts[:-1]:
|
93 |
+
if part in cfg:
|
94 |
+
cfg = cfg[part]
|
95 |
+
else:
|
96 |
+
return
|
97 |
+
OmegaConf.set_struct(cfg, False)
|
98 |
+
if parts[-1] in cfg:
|
99 |
+
del cfg[parts[-1]]
|
100 |
+
OmegaConf.set_struct(cfg, True)
|
101 |
+
|
102 |
+
|
103 |
+
def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
|
104 |
+
pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
|
105 |
+
cfg = OmegaConf.create(pkg['xp.cfg'])
|
106 |
+
cfg.device = str(device)
|
107 |
+
if cfg.device == 'cpu':
|
108 |
+
cfg.dtype = 'float32'
|
109 |
+
else:
|
110 |
+
cfg.dtype = 'float16'
|
111 |
+
_delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
|
112 |
+
_delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
|
113 |
+
_delete_param(cfg, 'conditioners.args.drop_desc_p')
|
114 |
+
model = builders.get_lm_model(cfg)
|
115 |
+
model.load_state_dict(pkg['best_state'])
|
116 |
+
model.eval()
|
117 |
+
model.cfg = cfg
|
118 |
+
return model
|
119 |
+
|
120 |
+
|
121 |
+
def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str],
|
122 |
+
filename: tp.Optional[str] = None,
|
123 |
+
cache_dir: tp.Optional[str] = None):
|
124 |
+
return _get_state_dict(file_or_url_or_id, filename=filename, cache_dir=cache_dir)
|
125 |
+
|
126 |
+
|
127 |
+
def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str],
|
128 |
+
device='cpu',
|
129 |
+
filename: tp.Optional[str] = None,
|
130 |
+
cache_dir: tp.Optional[str] = None):
|
131 |
+
pkg = load_mbd_ckpt(file_or_url_or_id, filename=filename, cache_dir=cache_dir)
|
132 |
+
models = []
|
133 |
+
processors = []
|
134 |
+
cfgs = []
|
135 |
+
sample_rate = pkg['sample_rate']
|
136 |
+
for i in range(pkg['n_bands']):
|
137 |
+
cfg = pkg[i]['cfg']
|
138 |
+
model = builders.get_diffusion_model(cfg)
|
139 |
+
model_dict = pkg[i]['model_state']
|
140 |
+
model.load_state_dict(model_dict)
|
141 |
+
model.to(device)
|
142 |
+
processor = builders.get_processor(cfg=cfg.processor, sample_rate=sample_rate)
|
143 |
+
processor_dict = pkg[i]['processor_state']
|
144 |
+
processor.load_state_dict(processor_dict)
|
145 |
+
processor.to(device)
|
146 |
+
models.append(model)
|
147 |
+
processors.append(processor)
|
148 |
+
cfgs.append(cfg)
|
149 |
+
return models, processors, cfgs
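
A hedged usage sketch for the loaders above. It assumes a checkpoint location (local directory or Hugging Face repo id) that contains the `compression_state_dict.bin` and `state_dict.bin` files these helpers look for; the `./VidMuse` path is a placeholder, not a verified name.

```python
# Hypothetical usage of the loaders above. The checkpoint path is a placeholder;
# it must contain compression_state_dict.bin and state_dict.bin.
import torch
from audiocraft.models.loaders import load_compression_model, load_lm_model

device = 'cuda' if torch.cuda.is_available() else 'cpu'
checkpoint = './VidMuse'   # placeholder: local dir or a Hugging Face repo id

compression_model = load_compression_model(checkpoint, device=device)
lm = load_lm_model(checkpoint, device=device)
print(compression_model.sample_rate, lm.cfg.dtype)
```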
|
audiocraft/models/multibanddiffusion.py
ADDED
@@ -0,0 +1,196 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Multi Band Diffusion models as described in
|
9 |
+
"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion"
|
10 |
+
(paper link).
|
11 |
+
"""
|
12 |
+
|
13 |
+
import typing as tp
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import julius
|
17 |
+
|
18 |
+
from .unet import DiffusionUnet
|
19 |
+
from ..modules.diffusion_schedule import NoiseSchedule
|
20 |
+
from .encodec import CompressionModel
|
21 |
+
from ..solvers.compression import CompressionSolver
|
22 |
+
from .loaders import load_compression_model, load_diffusion_models
|
23 |
+
|
24 |
+
|
25 |
+
class DiffusionProcess:
|
26 |
+
"""Sampling for a diffusion Model.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
model (DiffusionUnet): Diffusion U-Net model.
|
30 |
+
noise_schedule (NoiseSchedule): Noise schedule for diffusion process.
|
31 |
+
"""
|
32 |
+
def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule) -> None:
|
33 |
+
"""
|
34 |
+
"""
|
35 |
+
self.model = model
|
36 |
+
self.schedule = noise_schedule
|
37 |
+
|
38 |
+
def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor,
|
39 |
+
step_list: tp.Optional[tp.List[int]] = None):
|
40 |
+
"""Perform one diffusion process to generate one of the bands.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
            condition (tensor): The embeddings from the compression model.
|
44 |
+
            initial_noise (tensor): The initial noise to start the process.
|
45 |
+
"""
|
46 |
+
return self.schedule.generate_subsampled(model=self.model, initial=initial_noise, step_list=step_list,
|
47 |
+
condition=condition)
|
48 |
+
|
49 |
+
|
50 |
+
class MultiBandDiffusion:
|
51 |
+
"""Sample from multiple diffusion models.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
DPs (list of DiffusionProcess): Diffusion processes.
|
55 |
+
codec_model (CompressionModel): Underlying compression model used to obtain discrete tokens.
|
56 |
+
"""
|
57 |
+
def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: CompressionModel) -> None:
|
58 |
+
self.DPs = DPs
|
59 |
+
self.codec_model = codec_model
|
60 |
+
self.device = next(self.codec_model.parameters()).device
|
61 |
+
|
62 |
+
@property
|
63 |
+
def sample_rate(self) -> int:
|
64 |
+
return self.codec_model.sample_rate
|
65 |
+
|
66 |
+
@staticmethod
|
67 |
+
def get_mbd_musicgen(device=None):
|
68 |
+
"""Load our diffusion models trained for MusicGen."""
|
69 |
+
if device is None:
|
70 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
71 |
+
path = 'facebook/multiband-diffusion'
|
72 |
+
filename = 'mbd_musicgen_32khz.th'
|
73 |
+
name = 'facebook/musicgen-small'
|
74 |
+
codec_model = load_compression_model(name, device=device)
|
75 |
+
models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
|
76 |
+
DPs = []
|
77 |
+
for i in range(len(models)):
|
78 |
+
schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
|
79 |
+
DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
|
80 |
+
return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)
|
81 |
+
|
82 |
+
@staticmethod
|
83 |
+
def get_mbd_24khz(bw: float = 3.0, pretrained: bool = True,
|
84 |
+
device: tp.Optional[tp.Union[torch.device, str]] = None,
|
85 |
+
n_q: tp.Optional[int] = None):
|
86 |
+
"""Get the pretrained Models for MultibandDiffusion.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
bw (float): Bandwidth of the compression model.
|
90 |
+
pretrained (bool): Whether to use / download if necessary the models.
|
91 |
+
device (torch.device or str, optional): Device on which the models are loaded.
|
92 |
+
n_q (int, optional): Number of quantizers to use within the compression model.
|
93 |
+
"""
|
94 |
+
if device is None:
|
95 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
96 |
+
assert bw in [1.5, 3.0, 6.0], f"bandwidth {bw} not available"
|
97 |
+
if n_q is not None:
|
98 |
+
assert n_q in [2, 4, 8]
|
99 |
+
assert {1.5: 2, 3.0: 4, 6.0: 8}[bw] == n_q, \
|
100 |
+
f"bandwidth and number of codebooks missmatch to use n_q = {n_q} bw should be {n_q * (1.5 / 2)}"
|
101 |
+
n_q = {1.5: 2, 3.0: 4, 6.0: 8}[bw]
|
102 |
+
codec_model = CompressionSolver.model_from_checkpoint(
|
103 |
+
'//pretrained/facebook/encodec_24khz', device=device)
|
104 |
+
codec_model.set_num_codebooks(n_q)
|
105 |
+
codec_model = codec_model.to(device)
|
106 |
+
path = 'facebook/multiband-diffusion'
|
107 |
+
filename = f'mbd_comp_{n_q}.pt'
|
108 |
+
models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
|
109 |
+
DPs = []
|
110 |
+
for i in range(len(models)):
|
111 |
+
schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
|
112 |
+
DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
|
113 |
+
return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)
|
114 |
+
|
115 |
+
|
116 |
+
|
117 |
+
@torch.no_grad()
|
118 |
+
def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
|
119 |
+
"""Get the conditioning (i.e. latent reprentatios of the compression model) from a waveform.
|
120 |
+
Args:
|
121 |
+
wav (torch.Tensor): The audio that we want to extract the conditioning from
|
122 |
+
sample_rate (int): sample rate of the audio"""
|
123 |
+
if sample_rate != self.sample_rate:
|
124 |
+
wav = julius.resample_frac(wav, sample_rate, self.sample_rate)
|
125 |
+
codes, scale = self.codec_model.encode(wav)
|
126 |
+
assert scale is None, "Scaled compression models not supported."
|
127 |
+
emb = self.get_emb(codes)
|
128 |
+
return emb
|
129 |
+
|
130 |
+
@torch.no_grad()
|
131 |
+
def get_emb(self, codes: torch.Tensor):
|
132 |
+
"""Get latent representation from the discrete codes
|
133 |
+
        Args:
|
134 |
+
codes (torch.Tensor): discrete tokens"""
|
135 |
+
emb = self.codec_model.decode_latent(codes)
|
136 |
+
return emb
|
137 |
+
|
138 |
+
def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None,
|
139 |
+
step_list: tp.Optional[tp.List[int]] = None):
|
140 |
+
"""Generate Wavform audio from the latent embeddings of the compression model
|
141 |
+
Args:
|
142 |
+
            emb (torch.Tensor): Conditioning embeddings
|
143 |
+
            size (torch.Size, optional): size of the output
|
144 |
+
if None this is computed from the typical upsampling of the model
|
145 |
+
            step_list (list[int], optional): list of Markov chain steps, defaults to 50 linearly spaced steps.
|
146 |
+
"""
|
147 |
+
if size is None:
|
148 |
+
upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate)
|
149 |
+
size = torch.Size([emb.size(0), self.codec_model.channels, emb.size(-1) * upsampling])
|
150 |
+
assert size[0] == emb.size(0)
|
151 |
+
out = torch.zeros(size).to(self.device)
|
152 |
+
for DP in self.DPs:
|
153 |
+
out += DP.generate(condition=emb, step_list=step_list, initial_noise=torch.randn_like(out))
|
154 |
+
return out
|
155 |
+
|
156 |
+
def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1):
|
157 |
+
"""match the eq to the encodec output by matching the standard deviation of some frequency bands
|
158 |
+
Args:
|
159 |
+
wav (torch.Tensor): audio to equalize
|
160 |
+
            ref (torch.Tensor): reference audio from which we match the spectrogram.
|
161 |
+
n_bands (int): number of bands of the eq
|
162 |
+
            strictness (float): how strict the matching is. 0 is no matching, 1 is exact matching.
|
163 |
+
"""
|
164 |
+
split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device)
|
165 |
+
bands = split(wav)
|
166 |
+
bands_ref = split(ref)
|
167 |
+
out = torch.zeros_like(ref)
|
168 |
+
for i in range(n_bands):
|
169 |
+
out += bands[i] * (bands_ref[i].std() / bands[i].std()) ** strictness
|
170 |
+
return out
|
171 |
+
|
172 |
+
def regenerate(self, wav: torch.Tensor, sample_rate: int):
|
173 |
+
"""Regenerate a wavform through compression and diffusion regeneration.
|
174 |
+
Args:
|
175 |
+
wav (torch.Tensor): Original 'ground truth' audio
|
176 |
+
sample_rate (int): sample rate of the input (and output) wav
|
177 |
+
"""
|
178 |
+
if sample_rate != self.codec_model.sample_rate:
|
179 |
+
wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate)
|
180 |
+
emb = self.get_condition(wav, sample_rate=self.codec_model.sample_rate)
|
181 |
+
size = wav.size()
|
182 |
+
out = self.generate(emb, size=size)
|
183 |
+
if sample_rate != self.codec_model.sample_rate:
|
184 |
+
out = julius.resample_frac(out, self.codec_model.sample_rate, sample_rate)
|
185 |
+
return out
|
186 |
+
|
187 |
+
def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32):
|
188 |
+
"""Generate Waveform audio with diffusion from the discrete codes.
|
189 |
+
Args:
|
190 |
+
tokens (torch.Tensor): discrete codes
|
191 |
+
n_bands (int): bands for the eq matching.
|
192 |
+
"""
|
193 |
+
wav_encodec = self.codec_model.decode(tokens)
|
194 |
+
condition = self.get_emb(tokens)
|
195 |
+
wav_diffusion = self.generate(emb=condition, size=wav_encodec.size())
|
196 |
+
return self.re_eq(wav=wav_diffusion, ref=wav_encodec, n_bands=n_bands)
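
A possible way to use `MultiBandDiffusion` as a drop-in decoder for EnCodec tokens. The token tensor here is a placeholder with 4 codebooks (matching `bw=3.0`); in practice it would come from `compression_model.encode()` or from the language model.

```python
# Sketch: decoding EnCodec tokens to audio with MultiBandDiffusion instead of the
# EnCodec decoder. `tokens` is a placeholder code tensor of shape [B, n_q, T].
import torch
from audiocraft.models.multibanddiffusion import MultiBandDiffusion

mbd = MultiBandDiffusion.get_mbd_24khz(bw=3.0)   # downloads the pretrained diffusion models
tokens = torch.zeros(1, 4, 150, dtype=torch.long, device=mbd.device)  # placeholder codes
wav = mbd.tokens_to_wav(tokens)                  # [1, channels, samples] at 24 kHz
print(wav.shape, mbd.sample_rate)
```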
|
audiocraft/models/transformer_module.py
ADDED
@@ -0,0 +1,177 @@
1 |
+
import torch
|
2 |
+
from torch import nn, einsum
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from einops import rearrange, repeat
|
6 |
+
from einops.layers.torch import Rearrange
|
7 |
+
|
8 |
+
class Residual(nn.Module):
|
9 |
+
def __init__(self, fn):
|
10 |
+
super().__init__()
|
11 |
+
self.fn = fn
|
12 |
+
def forward(self, x, **kwargs):
|
13 |
+
return self.fn(x, **kwargs) + x
|
14 |
+
|
15 |
+
class PreNorm(nn.Module):
|
16 |
+
def __init__(self, dim, fn):
|
17 |
+
super().__init__()
|
18 |
+
self.norm = nn.LayerNorm(dim)
|
19 |
+
self.fn = fn
|
20 |
+
def forward(self, x, **kwargs):
|
21 |
+
return self.fn(self.norm(x), **kwargs)
|
22 |
+
|
23 |
+
class FeedForward(nn.Module):
|
24 |
+
def __init__(self, dim, hidden_dim, dropout = 0.):
|
25 |
+
super().__init__()
|
26 |
+
self.net = nn.Sequential(
|
27 |
+
nn.Linear(dim, hidden_dim),
|
28 |
+
nn.GELU(),
|
29 |
+
nn.Dropout(dropout),
|
30 |
+
nn.Linear(hidden_dim, dim),
|
31 |
+
nn.Dropout(dropout)
|
32 |
+
)
|
33 |
+
def forward(self, x):
|
34 |
+
return self.net(x)
|
35 |
+
|
36 |
+
class Attention(nn.Module):
|
37 |
+
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
38 |
+
super().__init__()
|
39 |
+
inner_dim = dim_head * heads
|
40 |
+
project_out = not (heads == 1 and dim_head == dim)
|
41 |
+
|
42 |
+
self.heads = heads
|
43 |
+
self.scale = dim_head ** -0.5
|
44 |
+
|
45 |
+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
46 |
+
|
47 |
+
self.to_out = nn.Sequential(
|
48 |
+
nn.Linear(inner_dim, dim),
|
49 |
+
nn.Dropout(dropout)
|
50 |
+
) if project_out else nn.Identity()
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
b, n, _, h = *x.shape, self.heads
|
54 |
+
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
55 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
56 |
+
|
57 |
+
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
58 |
+
|
59 |
+
attn = dots.softmax(dim=-1)
|
60 |
+
|
61 |
+
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
62 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
63 |
+
out = self.to_out(out)
|
64 |
+
return out
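
A quick shape check for the `Attention` block above: with `heads=8` and `dim_head=64`, a `[batch, tokens, dim]` input comes back with the same shape after the per-head split, softmaxed scores, and output projection. The constants are arbitrary and only chosen to match the CLIP token layout used elsewhere in this repo.

```python
# Shape walkthrough for the einsum attention above.
import torch
from audiocraft.models.transformer_module import Attention

attn = Attention(dim=768, heads=8, dim_head=64)
x = torch.randn(2, 50, 768)        # e.g. 50 CLIP tokens per frame
out = attn(x)
print(out.shape)                   # torch.Size([2, 50, 768])
```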
|
65 |
+
|
66 |
+
|
67 |
+
class ReAttention(nn.Module):
|
68 |
+
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
69 |
+
super().__init__()
|
70 |
+
inner_dim = dim_head * heads
|
71 |
+
self.heads = heads
|
72 |
+
self.scale = dim_head ** -0.5
|
73 |
+
|
74 |
+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
75 |
+
|
76 |
+
self.reattn_weights = nn.Parameter(torch.randn(heads, heads))
|
77 |
+
|
78 |
+
self.reattn_norm = nn.Sequential(
|
79 |
+
Rearrange('b h i j -> b i j h'),
|
80 |
+
nn.LayerNorm(heads),
|
81 |
+
Rearrange('b i j h -> b h i j')
|
82 |
+
)
|
83 |
+
|
84 |
+
self.to_out = nn.Sequential(
|
85 |
+
nn.Linear(inner_dim, dim),
|
86 |
+
nn.Dropout(dropout)
|
87 |
+
)
|
88 |
+
|
89 |
+
def forward(self, x):
|
90 |
+
b, n, _, h = *x.shape, self.heads
|
91 |
+
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
92 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
93 |
+
|
94 |
+
# attention
|
95 |
+
|
96 |
+
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
97 |
+
attn = dots.softmax(dim=-1)
|
98 |
+
|
99 |
+
# re-attention
|
100 |
+
|
101 |
+
attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights)
|
102 |
+
attn = self.reattn_norm(attn)
|
103 |
+
|
104 |
+
# aggregate and out
|
105 |
+
|
106 |
+
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
107 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
108 |
+
out = self.to_out(out)
|
109 |
+
return out
|
110 |
+
|
111 |
+
class LeFF(nn.Module):
|
112 |
+
|
113 |
+
def __init__(self, dim = 192, scale = 4, depth_kernel = 3):
|
114 |
+
super().__init__()
|
115 |
+
|
116 |
+
scale_dim = dim*scale
|
117 |
+
self.up_proj = nn.Sequential(nn.Linear(dim, scale_dim),
|
118 |
+
Rearrange('b n c -> b c n'),
|
119 |
+
nn.BatchNorm1d(scale_dim),
|
120 |
+
nn.GELU(),
|
121 |
+
Rearrange('b c (h w) -> b c h w', h=14, w=14)
|
122 |
+
)
|
123 |
+
|
124 |
+
self.depth_conv = nn.Sequential(nn.Conv2d(scale_dim, scale_dim, kernel_size=depth_kernel, padding=1, groups=scale_dim, bias=False),
|
125 |
+
nn.BatchNorm2d(scale_dim),
|
126 |
+
nn.GELU(),
|
127 |
+
Rearrange('b c h w -> b (h w) c', h=14, w=14)
|
128 |
+
)
|
129 |
+
|
130 |
+
self.down_proj = nn.Sequential(nn.Linear(scale_dim, dim),
|
131 |
+
Rearrange('b n c -> b c n'),
|
132 |
+
nn.BatchNorm1d(dim),
|
133 |
+
nn.GELU(),
|
134 |
+
Rearrange('b c n -> b n c')
|
135 |
+
)
|
136 |
+
|
137 |
+
def forward(self, x):
|
138 |
+
x = self.up_proj(x)
|
139 |
+
x = self.depth_conv(x)
|
140 |
+
x = self.down_proj(x)
|
141 |
+
return x
|
142 |
+
|
143 |
+
|
144 |
+
class LCAttention(nn.Module):
|
145 |
+
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
146 |
+
super().__init__()
|
147 |
+
inner_dim = dim_head * heads
|
148 |
+
project_out = not (heads == 1 and dim_head == dim)
|
149 |
+
|
150 |
+
self.heads = heads
|
151 |
+
self.scale = dim_head ** -0.5
|
152 |
+
|
153 |
+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
154 |
+
|
155 |
+
self.to_out = nn.Sequential(
|
156 |
+
nn.Linear(inner_dim, dim),
|
157 |
+
nn.Dropout(dropout)
|
158 |
+
) if project_out else nn.Identity()
|
159 |
+
|
160 |
+
def forward(self, x):
|
161 |
+
b, n, _, h = *x.shape, self.heads
|
162 |
+
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
163 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
164 |
+
q = q[:, :, -1, :].unsqueeze(2) # Only Lth element use as query
|
165 |
+
|
166 |
+
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
167 |
+
|
168 |
+
attn = dots.softmax(dim=-1)
|
169 |
+
|
170 |
+
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
171 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
172 |
+
out = self.to_out(out)
|
173 |
+
return out
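
Note that `LeFF` hard-codes a 14x14 spatial grid in its `Rearrange` layers, so it only accepts sequences of exactly 196 tokens (e.g. a ViT with 14x14 patches). A small sanity check:

```python
# LeFF requires exactly 196 tokens because of the fixed 14x14 Rearrange above.
import torch
from audiocraft.models.transformer_module import LeFF

leff = LeFF(dim=192, scale=4, depth_kernel=3)
tokens = torch.randn(2, 196, 192)   # batch of 2, 196 patch tokens, width 192
out = leff(tokens)
print(out.shape)                    # torch.Size([2, 196, 192])
```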
|
174 |
+
|
175 |
+
|
176 |
+
|
177 |
+
|
audiocraft/models/unet.py
ADDED
@@ -0,0 +1,214 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Pytorch Unet Module used for diffusion.
|
9 |
+
"""
|
10 |
+
|
11 |
+
from dataclasses import dataclass
|
12 |
+
import typing as tp
|
13 |
+
|
14 |
+
import torch
|
15 |
+
from torch import nn
|
16 |
+
from torch.nn import functional as F
|
17 |
+
from audiocraft.modules.transformer import StreamingTransformer, create_sin_embedding
|
18 |
+
|
19 |
+
|
20 |
+
@dataclass
|
21 |
+
class Output:
|
22 |
+
sample: torch.Tensor
|
23 |
+
|
24 |
+
|
25 |
+
def get_model(cfg, channels: int, side: int, num_steps: int):
|
26 |
+
if cfg.model == 'unet':
|
27 |
+
return DiffusionUnet(
|
28 |
+
chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
|
29 |
+
else:
|
30 |
+
raise RuntimeError('Not Implemented')
|
31 |
+
|
32 |
+
|
33 |
+
class ResBlock(nn.Module):
|
34 |
+
def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4,
|
35 |
+
dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
|
36 |
+
dropout: float = 0.):
|
37 |
+
super().__init__()
|
38 |
+
stride = 1
|
39 |
+
padding = dilation * (kernel - stride) // 2
|
40 |
+
Conv = nn.Conv1d
|
41 |
+
Drop = nn.Dropout1d
|
42 |
+
self.norm1 = nn.GroupNorm(norm_groups, channels)
|
43 |
+
self.conv1 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
|
44 |
+
self.activation1 = activation()
|
45 |
+
self.dropout1 = Drop(dropout)
|
46 |
+
|
47 |
+
self.norm2 = nn.GroupNorm(norm_groups, channels)
|
48 |
+
self.conv2 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
|
49 |
+
self.activation2 = activation()
|
50 |
+
self.dropout2 = Drop(dropout)
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
h = self.dropout1(self.conv1(self.activation1(self.norm1(x))))
|
54 |
+
h = self.dropout2(self.conv2(self.activation2(self.norm2(h))))
|
55 |
+
return x + h
|
56 |
+
|
57 |
+
|
58 |
+
class DecoderLayer(nn.Module):
|
59 |
+
def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
|
60 |
+
norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
|
61 |
+
dropout: float = 0.):
|
62 |
+
super().__init__()
|
63 |
+
padding = (kernel - stride) // 2
|
64 |
+
self.res_blocks = nn.Sequential(
|
65 |
+
*[ResBlock(chin, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
|
66 |
+
for idx in range(res_blocks)])
|
67 |
+
self.norm = nn.GroupNorm(norm_groups, chin)
|
68 |
+
ConvTr = nn.ConvTranspose1d
|
69 |
+
self.convtr = ConvTr(chin, chout, kernel, stride, padding, bias=False)
|
70 |
+
self.activation = activation()
|
71 |
+
|
72 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
73 |
+
x = self.res_blocks(x)
|
74 |
+
x = self.norm(x)
|
75 |
+
x = self.activation(x)
|
76 |
+
x = self.convtr(x)
|
77 |
+
return x
|
78 |
+
|
79 |
+
|
80 |
+
class EncoderLayer(nn.Module):
|
81 |
+
def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
|
82 |
+
norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
|
83 |
+
dropout: float = 0.):
|
84 |
+
super().__init__()
|
85 |
+
padding = (kernel - stride) // 2
|
86 |
+
Conv = nn.Conv1d
|
87 |
+
self.conv = Conv(chin, chout, kernel, stride, padding, bias=False)
|
88 |
+
self.norm = nn.GroupNorm(norm_groups, chout)
|
89 |
+
self.activation = activation()
|
90 |
+
self.res_blocks = nn.Sequential(
|
91 |
+
*[ResBlock(chout, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
|
92 |
+
for idx in range(res_blocks)])
|
93 |
+
|
94 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
95 |
+
B, C, T = x.shape
|
96 |
+
stride, = self.conv.stride
|
97 |
+
pad = (stride - (T % stride)) % stride
|
98 |
+
x = F.pad(x, (0, pad))
|
99 |
+
|
100 |
+
x = self.conv(x)
|
101 |
+
x = self.norm(x)
|
102 |
+
x = self.activation(x)
|
103 |
+
x = self.res_blocks(x)
|
104 |
+
return x
|
105 |
+
|
106 |
+
|
107 |
+
class BLSTM(nn.Module):
|
108 |
+
"""BiLSTM with same hidden units as input dim.
|
109 |
+
"""
|
110 |
+
def __init__(self, dim, layers=2):
|
111 |
+
super().__init__()
|
112 |
+
self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
|
113 |
+
self.linear = nn.Linear(2 * dim, dim)
|
114 |
+
|
115 |
+
def forward(self, x):
|
116 |
+
x = x.permute(2, 0, 1)
|
117 |
+
x = self.lstm(x)[0]
|
118 |
+
x = self.linear(x)
|
119 |
+
x = x.permute(1, 2, 0)
|
120 |
+
return x
|
121 |
+
|
122 |
+
|
123 |
+
class DiffusionUnet(nn.Module):
|
124 |
+
def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2.,
|
125 |
+
max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False, cross_attention: bool = False,
|
126 |
+
bilstm: bool = False, transformer: bool = False,
|
127 |
+
codec_dim: tp.Optional[int] = None, **kwargs):
|
128 |
+
super().__init__()
|
129 |
+
self.encoders = nn.ModuleList()
|
130 |
+
self.decoders = nn.ModuleList()
|
131 |
+
self.embeddings: tp.Optional[nn.ModuleList] = None
|
132 |
+
self.embedding = nn.Embedding(num_steps, hidden)
|
133 |
+
if emb_all_layers:
|
134 |
+
self.embeddings = nn.ModuleList()
|
135 |
+
self.condition_embedding: tp.Optional[nn.Module] = None
|
136 |
+
for d in range(depth):
|
137 |
+
encoder = EncoderLayer(chin, hidden, **kwargs)
|
138 |
+
decoder = DecoderLayer(hidden, chin, **kwargs)
|
139 |
+
self.encoders.append(encoder)
|
140 |
+
self.decoders.insert(0, decoder)
|
141 |
+
if emb_all_layers and d > 0:
|
142 |
+
assert self.embeddings is not None
|
143 |
+
self.embeddings.append(nn.Embedding(num_steps, hidden))
|
144 |
+
chin = hidden
|
145 |
+
hidden = min(int(chin * growth), max_channels)
|
146 |
+
self.bilstm: tp.Optional[nn.Module]
|
147 |
+
if bilstm:
|
148 |
+
self.bilstm = BLSTM(chin)
|
149 |
+
else:
|
150 |
+
self.bilstm = None
|
151 |
+
self.use_transformer = transformer
|
152 |
+
self.cross_attention = False
|
153 |
+
if transformer:
|
154 |
+
self.cross_attention = cross_attention
|
155 |
+
self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False,
|
156 |
+
cross_attention=cross_attention)
|
157 |
+
|
158 |
+
self.use_codec = False
|
159 |
+
if codec_dim is not None:
|
160 |
+
self.conv_codec = nn.Conv1d(codec_dim, chin, 1)
|
161 |
+
self.use_codec = True
|
162 |
+
|
163 |
+
def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: tp.Optional[torch.Tensor] = None):
|
164 |
+
skips = []
|
165 |
+
bs = x.size(0)
|
166 |
+
z = x
|
167 |
+
view_args = [1]
|
168 |
+
if type(step) is torch.Tensor:
|
169 |
+
step_tensor = step
|
170 |
+
else:
|
171 |
+
step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs)
|
172 |
+
|
173 |
+
for idx, encoder in enumerate(self.encoders):
|
174 |
+
z = encoder(z)
|
175 |
+
if idx == 0:
|
176 |
+
z = z + self.embedding(step_tensor).view(bs, -1, *view_args).expand_as(z)
|
177 |
+
elif self.embeddings is not None:
|
178 |
+
z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, *view_args).expand_as(z)
|
179 |
+
|
180 |
+
skips.append(z)
|
181 |
+
|
182 |
+
if self.use_codec: # insert condition in the bottleneck
|
183 |
+
assert condition is not None, "Model defined for conditionnal generation"
|
184 |
+
condition_emb = self.conv_codec(condition) # reshape to the bottleneck dim
|
185 |
+
assert condition_emb.size(-1) <= 2 * z.size(-1), \
|
186 |
+
f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}"
|
187 |
+
if not self.cross_attention:
|
188 |
+
|
189 |
+
condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1))
|
190 |
+
assert z.size() == condition_emb.size()
|
191 |
+
z += condition_emb
|
192 |
+
cross_attention_src = None
|
193 |
+
else:
|
194 |
+
cross_attention_src = condition_emb.permute(0, 2, 1) # B, T, C
|
195 |
+
B, T, C = cross_attention_src.shape
|
196 |
+
positions = torch.arange(T, device=x.device).view(1, -1, 1)
|
197 |
+
pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype)
|
198 |
+
cross_attention_src = cross_attention_src + pos_emb
|
199 |
+
if self.use_transformer:
|
200 |
+
z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1)
|
201 |
+
else:
|
202 |
+
if self.bilstm is None:
|
203 |
+
z = torch.zeros_like(z)
|
204 |
+
else:
|
205 |
+
z = self.bilstm(z)
|
206 |
+
|
207 |
+
for decoder in self.decoders:
|
208 |
+
s = skips.pop(-1)
|
209 |
+
z = z[:, :, :s.shape[2]]
|
210 |
+
z = z + s
|
211 |
+
z = decoder(z)
|
212 |
+
|
213 |
+
z = z[:, :, :x.shape[2]]
|
214 |
+
return Output(z)
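
A minimal smoke test for `DiffusionUnet` with arbitrary hyperparameters, no transformer/BiLSTM bottleneck, and no codec conditioning; it only verifies that the encoder/decoder path returns a tensor with the input length.

```python
# Smoke test: DiffusionUnet preserves the waveform length through its
# encoder/decoder path. Hyperparameters here are arbitrary.
import torch
from audiocraft.models.unet import DiffusionUnet

unet = DiffusionUnet(chin=1, hidden=16, depth=3, num_steps=1000)
x = torch.randn(2, 1, 2400)          # [batch, channels, samples]
out = unet(x, step=10)               # Output dataclass with a .sample field
print(out.sample.shape)              # torch.Size([2, 1, 2400])
```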
|
audiocraft/models/vidmuse.py
ADDED
@@ -0,0 +1,425 @@
1 |
+
# Modified from Audiocraft (https://github.com/facebookresearch/audiocraft)
|
2 |
+
|
3 |
+
import typing as tp
|
4 |
+
import warnings
|
5 |
+
|
6 |
+
import omegaconf
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from .encodec import CompressionModel
|
10 |
+
from .lm import LMModel
|
11 |
+
from .builders import get_debug_compression_model, get_debug_lm_model, get_wrapped_compression_model
|
12 |
+
from .loaders import load_compression_model, load_lm_model
|
13 |
+
from ..data.audio_utils import convert_audio
|
14 |
+
from ..modules.conditioners import ConditioningAttributes, WavCondition
|
15 |
+
from ..utils.autocast import TorchAutocast
|
16 |
+
|
17 |
+
MelodyList = tp.List[tp.Optional[torch.Tensor]]
|
18 |
+
MelodyType = tp.Union[torch.Tensor, MelodyList]
|
19 |
+
|
20 |
+
# backward compatible names mapping
|
21 |
+
_HF_MODEL_CHECKPOINTS_MAP = {
|
22 |
+
"small": "facebook/musicgen-small",
|
23 |
+
"medium": "facebook/musicgen-medium",
|
24 |
+
"large": "facebook/musicgen-large",
|
25 |
+
"melody": "facebook/musicgen-melody",
|
26 |
+
}
|
27 |
+
|
28 |
+
|
29 |
+
class VidMuse:
|
30 |
+
"""VidMuse main model with convenient generation API.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
name (str): name of the model.
|
34 |
+
compression_model (CompressionModel): Compression model
|
35 |
+
used to map audio to invertible discrete representations.
|
36 |
+
lm (LMModel): Language model over discrete representations.
|
37 |
+
max_duration (float, optional): maximum duration the model can produce,
|
38 |
+
otherwise, inferred from the training params.
|
39 |
+
"""
|
40 |
+
def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
|
41 |
+
max_duration: tp.Optional[float] = None):
|
42 |
+
self.name = name
|
43 |
+
self.compression_model = compression_model
|
44 |
+
self.lm = lm
|
45 |
+
self.cfg: tp.Optional[omegaconf.DictConfig] = None
|
46 |
+
# Just to be safe, let's put everything in eval mode.
|
47 |
+
self.compression_model.eval()
|
48 |
+
self.lm.eval()
|
49 |
+
|
50 |
+
if hasattr(lm, 'cfg'):
|
51 |
+
cfg = lm.cfg
|
52 |
+
assert isinstance(cfg, omegaconf.DictConfig)
|
53 |
+
self.cfg = cfg
|
54 |
+
|
55 |
+
if self.cfg is not None:
|
56 |
+
self.compression_model = get_wrapped_compression_model(self.compression_model, self.cfg)
|
57 |
+
|
58 |
+
if max_duration is None:
|
59 |
+
if self.cfg is not None:
|
60 |
+
max_duration = lm.cfg.dataset.segment_duration # type: ignore
|
61 |
+
else:
|
62 |
+
raise ValueError("You must provide max_duration when building directly MusicGen")
|
63 |
+
assert max_duration is not None
|
64 |
+
self.max_duration: float = max_duration
|
65 |
+
self.device = next(iter(lm.parameters())).device
|
66 |
+
|
67 |
+
self.generation_params: dict = {}
|
68 |
+
self.set_generation_params(duration=15) # 15 seconds by default
|
69 |
+
self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
|
70 |
+
if self.device.type == 'cpu':
|
71 |
+
self.autocast = TorchAutocast(enabled=False)
|
72 |
+
else:
|
73 |
+
self.autocast = TorchAutocast(
|
74 |
+
enabled=True, device_type=self.device.type, dtype=torch.float16)
|
75 |
+
|
76 |
+
@property
|
77 |
+
def frame_rate(self) -> float:
|
78 |
+
"""Roughly the number of AR steps per seconds."""
|
79 |
+
return self.compression_model.frame_rate
|
80 |
+
|
81 |
+
@property
|
82 |
+
def sample_rate(self) -> int:
|
83 |
+
"""Sample rate of the generated audio."""
|
84 |
+
return self.compression_model.sample_rate
|
85 |
+
|
86 |
+
@property
|
87 |
+
def audio_channels(self) -> int:
|
88 |
+
"""Audio channels of the generated audio."""
|
89 |
+
return self.compression_model.channels
|
90 |
+
|
91 |
+
@staticmethod
|
92 |
+
def get_pretrained(name: str = 'facebook/musicgen-melody', device=None):
|
93 |
+
"""Return pretrained model, we provide four models:
|
94 |
+
- facebook/musicgen-small (300M), text to music,
|
95 |
+
# see: https://huggingface.co/facebook/musicgen-small
|
96 |
+
- facebook/musicgen-medium (1.5B), text to music,
|
97 |
+
# see: https://huggingface.co/facebook/musicgen-medium
|
98 |
+
- facebook/musicgen-melody (1.5B) text to music and text+melody to music,
|
99 |
+
# see: https://huggingface.co/facebook/musicgen-melody
|
100 |
+
- facebook/musicgen-large (3.3B), text to music,
|
101 |
+
# see: https://huggingface.co/facebook/musicgen-large
|
102 |
+
"""
|
103 |
+
if device is None:
|
104 |
+
if torch.cuda.device_count():
|
105 |
+
device = 'cuda'
|
106 |
+
else:
|
107 |
+
device = 'cpu'
|
108 |
+
|
109 |
+
if name == 'debug':
|
110 |
+
# used only for unit tests
|
111 |
+
compression_model = get_debug_compression_model(device)
|
112 |
+
lm = get_debug_lm_model(device)
|
113 |
+
return VidMuse(name, compression_model, lm, max_duration=30)
|
114 |
+
|
115 |
+
if name in _HF_MODEL_CHECKPOINTS_MAP:
|
116 |
+
# warnings.warn(
|
117 |
+
# "MusicGen pretrained model relying on deprecated checkpoint mapping. " +
|
118 |
+
# f"Please use full pre-trained id instead: facebook/musicgen-{name}")
|
119 |
+
name = _HF_MODEL_CHECKPOINTS_MAP[name]
|
120 |
+
|
121 |
+
lm = load_lm_model(name, device=device)
|
122 |
+
compression_model = load_compression_model(name, device=device)
|
123 |
+
if 'self_wav' in lm.condition_provider.conditioners:
|
124 |
+
lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
|
125 |
+
lm.condition_provider.conditioners['self_wav']._use_masking = False
|
126 |
+
return VidMuse(name, compression_model, lm, max_duration=30)
|
127 |
+
|
128 |
+
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
|
129 |
+
top_p: float = 0.0, temperature: float = 1.0,
|
130 |
+
duration: float = 30.0, cfg_coef: float = 3.0,
|
131 |
+
two_step_cfg: bool = False, extend_stride: float = 29.5):
|
132 |
+
"""Set the generation parameters for VidMuse.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
|
136 |
+
top_k (int, optional): top_k used for sampling. Defaults to 250.
|
137 |
+
top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
|
138 |
+
temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
|
139 |
+
duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
|
140 |
+
cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
|
141 |
+
two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
|
142 |
+
instead of batching together the two. This has some impact on how things
|
143 |
+
are padded but seems to have little impact in practice.
|
144 |
+
extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
|
145 |
+
should we extend the audio each time. Larger values will mean less context is
|
146 |
+
preserved, and shorter value will require extra computations.
|
147 |
+
"""
|
148 |
+
assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
|
149 |
+
self.extend_stride = extend_stride
|
150 |
+
self.duration = duration
|
151 |
+
self.generation_params = {
|
152 |
+
'use_sampling': use_sampling,
|
153 |
+
'temp': temperature,
|
154 |
+
'top_k': top_k,
|
155 |
+
'top_p': top_p,
|
156 |
+
'cfg_coef': cfg_coef,
|
157 |
+
'two_step_cfg': two_step_cfg,
|
158 |
+
}
|
159 |
+
|
160 |
+
def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
|
161 |
+
"""Override the default progress callback."""
|
162 |
+
self._progress_callback = progress_callback
|
163 |
+
|
164 |
+
def generate_unconditional(self, num_samples: int, progress: bool = False,
|
165 |
+
return_tokens: bool = False) -> tp.Union[torch.Tensor,
|
166 |
+
tp.Tuple[torch.Tensor, torch.Tensor]]:
|
167 |
+
"""Generate samples in an unconditional manner.
|
168 |
+
|
169 |
+
Args:
|
170 |
+
num_samples (int): Number of samples to be generated.
|
171 |
+
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
172 |
+
"""
|
173 |
+
descriptions: tp.List[tp.Optional[torch.Tensor]] = [None] * num_samples
|
174 |
+
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
|
175 |
+
tokens = self._generate_tokens(attributes, prompt_tokens, progress)
|
176 |
+
if return_tokens:
|
177 |
+
return self.generate_audio(tokens), tokens
|
178 |
+
return self.generate_audio(tokens)
|
179 |
+
|
180 |
+
def generate(self, descriptions_list: tp.List, progress: bool = False, return_tokens: bool = False) \
|
181 |
+
-> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
|
182 |
+
"""Generate samples conditioned on text.
|
183 |
+
|
184 |
+
Args:
|
185 |
+
descriptions (list of str): A list of strings used as text conditioning.
|
186 |
+
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
187 |
+
"""
|
188 |
+
|
189 |
+
assert isinstance(descriptions_list,list)
|
190 |
+
assert len(descriptions_list)<=2
|
191 |
+
|
192 |
+
assert len(descriptions_list)==2
|
193 |
+
local_descriptions=[descriptions_list[0]]
|
194 |
+
global_descriptions=[descriptions_list[1]]
|
195 |
+
|
196 |
+
local_attributes = torch.stack(local_descriptions)
|
197 |
+
global_attributes = torch.stack(global_descriptions)
|
198 |
+
|
199 |
+
prompt_tokens = None
|
200 |
+
assert prompt_tokens is None
|
201 |
+
|
202 |
+
assert len(descriptions_list)==2
|
203 |
+
tokens = self._generate_tokens([local_attributes, global_attributes], prompt_tokens, progress)
|
204 |
+
|
205 |
+
if return_tokens:
|
206 |
+
return self.generate_audio(tokens), tokens
|
207 |
+
return self.generate_audio(tokens)
|
208 |
+
|
209 |
+
def generate_with_chroma(self, descriptions: tp.List[torch.Tensor], melody_wavs: MelodyType,
|
210 |
+
melody_sample_rate: int, progress: bool = False,
|
211 |
+
return_tokens: bool = False) -> tp.Union[torch.Tensor,
|
212 |
+
tp.Tuple[torch.Tensor, torch.Tensor]]:
|
213 |
+
"""Generate samples conditioned on text and melody.
|
214 |
+
|
215 |
+
Args:
|
216 |
+
descriptions (list of str): A list of strings used as text conditioning.
|
217 |
+
melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
|
218 |
+
melody conditioning. Should have shape [B, C, T] with B matching the description length,
|
219 |
+
C=1 or 2. It can be [C, T] if there is a single description. It can also be
|
220 |
+
a list of [C, T] tensors.
|
221 |
+
melody_sample_rate: (int): Sample rate of the melody waveforms.
|
222 |
+
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
223 |
+
"""
|
224 |
+
if isinstance(melody_wavs, torch.Tensor):
|
225 |
+
if melody_wavs.dim() == 2:
|
226 |
+
melody_wavs = melody_wavs[None]
|
227 |
+
if melody_wavs.dim() != 3:
|
228 |
+
raise ValueError("Melody wavs should have a shape [B, C, T].")
|
229 |
+
melody_wavs = list(melody_wavs)
|
230 |
+
else:
|
231 |
+
for melody in melody_wavs:
|
232 |
+
if melody is not None:
|
233 |
+
assert melody.dim() == 2, "One melody in the list has the wrong number of dims."
|
234 |
+
|
235 |
+
melody_wavs = [
|
236 |
+
convert_audio(wav, melody_sample_rate, self.sample_rate, self.audio_channels)
|
237 |
+
if wav is not None else None
|
238 |
+
for wav in melody_wavs]
|
239 |
+
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
|
240 |
+
melody_wavs=melody_wavs)
|
241 |
+
assert prompt_tokens is None
|
242 |
+
tokens = self._generate_tokens(attributes, prompt_tokens, progress)
|
243 |
+
if return_tokens:
|
244 |
+
return self.generate_audio(tokens), tokens
|
245 |
+
return self.generate_audio(tokens)
|
246 |
+
|
247 |
+
def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
|
248 |
+
descriptions: tp.Optional[tp.List[tp.Optional[torch.Tensor]]] = None,
|
249 |
+
progress: bool = False, return_tokens: bool = False) \
|
250 |
+
-> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
|
251 |
+
"""Generate samples conditioned on audio prompts.
|
252 |
+
|
253 |
+
Args:
|
254 |
+
prompt (torch.Tensor): A batch of waveforms used for continuation.
|
255 |
+
Prompt should be [B, C, T], or [C, T] if only one sample is generated.
|
256 |
+
prompt_sample_rate (int): Sampling rate of the given audio waveforms.
|
257 |
+
descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
|
258 |
+
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
259 |
+
"""
|
260 |
+
if prompt.dim() == 2:
|
261 |
+
prompt = prompt[None]
|
262 |
+
if prompt.dim() != 3:
|
263 |
+
raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
|
264 |
+
prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
|
265 |
+
if descriptions is None:
|
266 |
+
descriptions = [None] * len(prompt)
|
267 |
+
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
|
268 |
+
assert prompt_tokens is not None
|
269 |
+
tokens = self._generate_tokens(attributes, prompt_tokens, progress)
|
270 |
+
if return_tokens:
|
271 |
+
return self.generate_audio(tokens), tokens
|
272 |
+
return self.generate_audio(tokens)
|
273 |
+
|
274 |
+
@torch.no_grad()
|
275 |
+
def _prepare_tokens_and_attributes(
|
276 |
+
self,
|
277 |
+
descriptions: tp.Sequence[tp.Optional[str]],
|
278 |
+
prompt: tp.Optional[torch.Tensor],
|
279 |
+
melody_wavs: tp.Optional[MelodyList] = None,
|
280 |
+
) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
|
281 |
+
"""Prepare model inputs.
|
282 |
+
|
283 |
+
Args:
|
284 |
+
descriptions (list of str): A list of strings used as text conditioning.
|
285 |
+
prompt (torch.Tensor): A batch of waveforms used for continuation.
|
286 |
+
melody_wavs (torch.Tensor, optional): A batch of waveforms
|
287 |
+
used as melody conditioning. Defaults to None.
|
288 |
+
"""
|
289 |
+
attributes = [
|
290 |
+
ConditioningAttributes(text={'description': description})
|
291 |
+
for description in descriptions]
|
292 |
+
|
293 |
+
if melody_wavs is None:
|
294 |
+
for attr in attributes:
|
295 |
+
attr.wav['self_wav'] = WavCondition(
|
296 |
+
torch.zeros((1, 1, 1), device=self.device),
|
297 |
+
torch.tensor([0], device=self.device),
|
298 |
+
sample_rate=[self.sample_rate],
|
299 |
+
path=[None])
|
300 |
+
else:
|
301 |
+
if 'self_wav' not in self.lm.condition_provider.conditioners:
|
302 |
+
raise RuntimeError("This model doesn't support melody conditioning. "
|
303 |
+
"Use the `melody` model.")
|
304 |
+
assert len(melody_wavs) == len(descriptions), \
|
305 |
+
f"number of melody wavs must match number of descriptions! " \
|
306 |
+
f"got melody len={len(melody_wavs)}, and descriptions len={len(descriptions)}"
|
307 |
+
for attr, melody in zip(attributes, melody_wavs):
|
308 |
+
if melody is None:
|
309 |
+
attr.wav['self_wav'] = WavCondition(
|
310 |
+
torch.zeros((1, 1, 1), device=self.device),
|
311 |
+
torch.tensor([0], device=self.device),
|
312 |
+
sample_rate=[self.sample_rate],
|
313 |
+
path=[None])
|
314 |
+
else:
|
315 |
+
attr.wav['self_wav'] = WavCondition(
|
316 |
+
melody[None].to(device=self.device),
|
317 |
+
torch.tensor([melody.shape[-1]], device=self.device),
|
318 |
+
sample_rate=[self.sample_rate],
|
319 |
+
path=[None],
|
320 |
+
)
|
321 |
+
|
322 |
+
if prompt is not None:
|
323 |
+
if descriptions is not None:
|
324 |
+
assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
|
325 |
+
prompt = prompt.to(self.device)
|
326 |
+
prompt_tokens, scale = self.compression_model.encode(prompt)
|
327 |
+
assert scale is None
|
328 |
+
else:
|
329 |
+
prompt_tokens = None
|
330 |
+
return attributes, prompt_tokens
|
331 |
+
|
332 |
+
def _generate_tokens(self, attributes: tp.List,
|
333 |
+
prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
|
334 |
+
"""Generate discrete audio tokens given audio prompt and/or conditions.
|
335 |
+
|
336 |
+
Args:
|
337 |
+
attributes (list of ConditioningAttributes): Conditions used for generation (text/melody).
|
338 |
+
prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
|
339 |
+
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
340 |
+
Returns:
|
341 |
+
torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
|
342 |
+
"""
|
343 |
+
self.max_duration = 30
|
344 |
+
|
345 |
+
total_gen_len = int(self.duration * self.frame_rate)
|
346 |
+
max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
|
347 |
+
current_gen_offset: int = 0
|
348 |
+
|
349 |
+
def _progress_callback(generated_tokens: int, tokens_to_generate: int):
|
350 |
+
generated_tokens += current_gen_offset
|
351 |
+
if self._progress_callback is not None:
|
352 |
+
# Note that total_gen_len might be quite wrong depending on the
|
353 |
+
# codebook pattern used, but with delay it is almost accurate.
|
354 |
+
self._progress_callback(generated_tokens, total_gen_len)
|
355 |
+
else:
|
356 |
+
print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
|
357 |
+
|
358 |
+
if prompt_tokens is not None:
|
359 |
+
assert max_prompt_len >= prompt_tokens.shape[-1], \
|
360 |
+
"Prompt is longer than audio to generate"
|
361 |
+
|
362 |
+
callback = None
|
363 |
+
if progress:
|
364 |
+
callback = _progress_callback
|
365 |
+
|
366 |
+
if self.duration <= self.max_duration:
|
367 |
+
# generate by sampling from LM, simple case.
|
368 |
+
with self.autocast:
|
369 |
+
gen_tokens = self.lm.generate(
|
370 |
+
prompt_tokens, attributes,
|
371 |
+
callback=callback, max_gen_len=total_gen_len, **self.generation_params)
|
372 |
+
|
373 |
+
else:
|
374 |
+
# now this gets a bit messier, we need to handle prompts,
|
375 |
+
# melody conditioning etc.
|
376 |
+
|
377 |
+
assert self.extend_stride is not None, "Stride should be defined to generate beyond max_duration"
|
378 |
+
assert self.extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
|
379 |
+
# ref_wavs = [attr.wav['self_wav'] for attr in attributes]
|
380 |
+
all_tokens = []
|
381 |
+
if prompt_tokens is None: # None
|
382 |
+
prompt_length = 0
|
383 |
+
else:
|
384 |
+
all_tokens.append(prompt_tokens)
|
385 |
+
prompt_length = prompt_tokens.shape[-1]
|
386 |
+
|
387 |
+
stride_tokens = int(self.frame_rate * self.extend_stride) # max_duration - overlap_duration
|
388 |
+
|
389 |
+
self.fps = 2
|
390 |
+
stride_video_frames = int(self.fps * self.extend_stride)
|
391 |
+
|
392 |
+
while current_gen_offset + prompt_length < total_gen_len:
|
393 |
+
time_offset = current_gen_offset / self.frame_rate
|
394 |
+
chunk_duration = min(self.duration - time_offset, self.max_duration)
|
395 |
+
max_gen_len = int(chunk_duration * self.frame_rate)
|
396 |
+
|
397 |
+
with self.autocast:
|
398 |
+
assert len(attributes)==2
|
399 |
+
# import pdb; pdb.set_trace()
|
400 |
+
gen_tokens = self.lm.generate(
|
401 |
+
prompt_tokens, [attributes[0][:,:,:int(chunk_duration*self.fps),:,:], attributes[1]],
|
402 |
+
callback=callback, max_gen_len=max_gen_len, **self.generation_params)
|
403 |
+
|
404 |
+
if prompt_tokens is None:
|
405 |
+
all_tokens.append(gen_tokens)
|
406 |
+
else:
|
407 |
+
all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
|
408 |
+
prompt_tokens = gen_tokens[:, :, stride_tokens:]
|
409 |
+
prompt_length = prompt_tokens.shape[-1]
|
410 |
+
|
411 |
+
if attributes[0].shape[2]-stride_video_frames < self.max_duration*self.fps:
|
412 |
+
attributes[0]=attributes[0][:,:,-self.max_duration*self.fps:,:,:]
|
413 |
+
else:
|
414 |
+
attributes[0]=attributes[0][:,:,stride_video_frames:,:,:]
|
415 |
+
current_gen_offset += stride_tokens
|
416 |
+
|
417 |
+
gen_tokens = torch.cat(all_tokens, dim=-1)
|
418 |
+
return gen_tokens
|
419 |
+
|
420 |
+
def generate_audio(self, gen_tokens: torch.Tensor):
|
421 |
+
"""Generate Audio from tokens"""
|
422 |
+
assert gen_tokens.dim() == 3
|
423 |
+
with torch.no_grad():
|
424 |
+
gen_audio = self.compression_model.decode(gen_tokens, None)
|
425 |
+
return gen_audio
|
audiocraft/modules/__init__.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Modules used for building the models."""
|
7 |
+
|
8 |
+
# flake8: noqa
|
9 |
+
from .conv import (
|
10 |
+
NormConv1d,
|
11 |
+
NormConv2d,
|
12 |
+
NormConvTranspose1d,
|
13 |
+
NormConvTranspose2d,
|
14 |
+
StreamableConv1d,
|
15 |
+
StreamableConvTranspose1d,
|
16 |
+
pad_for_conv1d,
|
17 |
+
pad1d,
|
18 |
+
unpad1d,
|
19 |
+
)
|
20 |
+
from .lstm import StreamableLSTM
|
21 |
+
from .seanet import SEANetEncoder, SEANetDecoder
|
22 |
+
from .transformer import StreamingTransformer
|
audiocraft/modules/activations.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from torch import Tensor
|
10 |
+
from typing import Union, Callable
|
11 |
+
|
12 |
+
|
13 |
+
class CustomGLU(nn.Module):
|
14 |
+
"""Custom Gated Linear Unit activation.
|
15 |
+
Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half
|
16 |
+
of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation
|
17 |
+
function (i.e. sigmoid, swish, etc.).
|
18 |
+
|
19 |
+
Args:
|
20 |
+
activation (nn.Module): The custom activation to apply in the Gated Linear Unit
|
21 |
+
dim (int): the dimension on which to split the input. Default: -1
|
22 |
+
|
23 |
+
Shape:
|
24 |
+
- Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
|
25 |
+
dimensions
|
26 |
+
- Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
|
27 |
+
|
28 |
+
Examples::
|
29 |
+
>>> m = CustomGLU(nn.Sigmoid())
|
30 |
+
>>> input = torch.randn(4, 2)
|
31 |
+
>>> output = m(input)
|
32 |
+
"""
|
33 |
+
def __init__(self, activation: nn.Module, dim: int = -1):
|
34 |
+
super(CustomGLU, self).__init__()
|
35 |
+
self.dim = dim
|
36 |
+
self.activation = activation
|
37 |
+
|
38 |
+
def forward(self, x: Tensor):
|
39 |
+
assert x.shape[self.dim] % 2 == 0 # M = N / 2
|
40 |
+
a, b = torch.chunk(x, 2, dim=self.dim)
|
41 |
+
return a * self.activation(b)
|
42 |
+
|
43 |
+
|
44 |
+
class SwiGLU(CustomGLU):
|
45 |
+
"""SiLU Gated Linear Unit activation.
|
46 |
+
Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is
|
47 |
+
the first half of the input matrices, :math:`b` is the second half.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
dim (int): the dimension on which to split the input. Default: -1
|
51 |
+
"""
|
52 |
+
def __init__(self, dim: int = -1):
|
53 |
+
super(SwiGLU, self).__init__(nn.SiLU(), dim)
|
54 |
+
|
55 |
+
|
56 |
+
class GeGLU(CustomGLU):
|
57 |
+
"""GeLU Gated Linear Unit activation.
|
58 |
+
Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is
|
59 |
+
the first half of the input matrices, :math:`b` is the second half.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
dim (int): the dimension on which to split the input. Default: -1
|
63 |
+
"""
|
64 |
+
def __init__(self, dim: int = -1):
|
65 |
+
super(GeGLU, self).__init__(nn.GELU(), dim)
|
66 |
+
|
67 |
+
|
68 |
+
class ReGLU(CustomGLU):
|
69 |
+
"""ReLU Gated Linear Unit activation.
|
70 |
+
Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is
|
71 |
+
the first half of the input matrices, :math:`b` is the second half.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
dim (int): the dimension on which to split the input. Default: -1
|
75 |
+
"""
|
76 |
+
def __init__(self, dim: int = -1):
|
77 |
+
super(ReGLU, self).__init__(nn.ReLU(), dim)
|
78 |
+
|
79 |
+
|
80 |
+
def get_activation_fn(
|
81 |
+
activation: Union[str, Callable[[Tensor], Tensor]]
|
82 |
+
) -> Union[str, Callable[[Tensor], Tensor]]:
|
83 |
+
"""Helper function to map an activation string to the activation class.
|
84 |
+
If the supplied activation is not a string that is recognized, the activation is passed back.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
activation (str, or Callable[[Tensor], Tensor]): Activation to check
|
88 |
+
"""
|
89 |
+
if isinstance(activation, str):
|
90 |
+
if activation == "reglu":
|
91 |
+
return ReGLU()
|
92 |
+
elif activation == "geglu":
|
93 |
+
return GeGLU()
|
94 |
+
elif activation == "swiglu":
|
95 |
+
return SwiGLU()
|
96 |
+
return activation
|
audiocraft/modules/chroma.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
import typing as tp
|
7 |
+
|
8 |
+
from einops import rearrange
|
9 |
+
from librosa import filters
|
10 |
+
import torch
|
11 |
+
from torch import nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
import torchaudio
|
14 |
+
|
15 |
+
|
16 |
+
class ChromaExtractor(nn.Module):
|
17 |
+
"""Chroma extraction and quantization.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
sample_rate (int): Sample rate for the chroma extraction.
|
21 |
+
n_chroma (int): Number of chroma bins for the chroma extraction.
|
22 |
+
radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12).
|
23 |
+
nfft (int, optional): Number of FFT.
|
24 |
+
winlen (int, optional): Window length.
|
25 |
+
winhop (int, optional): Window hop size.
|
26 |
+
argmax (bool, optional): Whether to use argmax. Defaults to False.
|
27 |
+
norm (float, optional): Norm for chroma normalization. Defaults to inf.
|
28 |
+
"""
|
29 |
+
def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None,
|
30 |
+
winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False,
|
31 |
+
norm: float = torch.inf):
|
32 |
+
super().__init__()
|
33 |
+
self.winlen = winlen or 2 ** radix2_exp
|
34 |
+
self.nfft = nfft or self.winlen
|
35 |
+
self.winhop = winhop or (self.winlen // 4)
|
36 |
+
self.sample_rate = sample_rate
|
37 |
+
self.n_chroma = n_chroma
|
38 |
+
self.norm = norm
|
39 |
+
self.argmax = argmax
|
40 |
+
self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
|
41 |
+
n_chroma=self.n_chroma)), persistent=False)
|
42 |
+
self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
|
43 |
+
hop_length=self.winhop, power=2, center=True,
|
44 |
+
pad=0, normalized=True)
|
45 |
+
|
46 |
+
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
47 |
+
T = wav.shape[-1]
|
48 |
+
# in case we are getting a wav that was dropped out (nullified)
|
49 |
+
# from the conditioner, make sure wav length is no less that nfft
|
50 |
+
if T < self.nfft:
|
51 |
+
pad = self.nfft - T
|
52 |
+
r = 0 if pad % 2 == 0 else 1
|
53 |
+
wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
|
54 |
+
assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"
|
55 |
+
|
56 |
+
spec = self.spec(wav).squeeze(1)
|
57 |
+
raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
|
58 |
+
norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
|
59 |
+
norm_chroma = rearrange(norm_chroma, 'b d t -> b t d')
|
60 |
+
|
61 |
+
if self.argmax:
|
62 |
+
idx = norm_chroma.argmax(-1, keepdim=True)
|
63 |
+
norm_chroma[:] = 0
|
64 |
+
norm_chroma.scatter_(dim=-1, index=idx, value=1)
|
65 |
+
|
66 |
+
return norm_chroma
|
audiocraft/modules/codebooks_patterns.py
ADDED
@@ -0,0 +1,544 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from collections import namedtuple
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from functools import lru_cache
|
10 |
+
import logging
|
11 |
+
import typing as tp
|
12 |
+
|
13 |
+
from abc import ABC, abstractmethod
|
14 |
+
import torch
|
15 |
+
|
16 |
+
LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index)
|
17 |
+
PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
|
21 |
+
@dataclass
|
22 |
+
class Pattern:
|
23 |
+
"""Base implementation of a pattern over a sequence with multiple codebooks.
|
24 |
+
|
25 |
+
The codebook pattern consists in a layout, defining for each sequence step
|
26 |
+
the list of coordinates of each codebook timestep in the resulting interleaved sequence.
|
27 |
+
The first item of the pattern is always an empty list in order to properly insert a special token
|
28 |
+
to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
|
29 |
+
and ``timesteps`` the number of timesteps corresponding to the original sequence.
|
30 |
+
|
31 |
+
The pattern provides convenient methods to build and revert interleaved sequences from it:
|
32 |
+
``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
|
33 |
+
to the interleaved sequence of shape [B, K, S] applying the pattern, with S being the batch size,
|
34 |
+
K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
|
35 |
+
for the output sequence. The unfilled positions are replaced with a special token and the built sequence
|
36 |
+
is returned along with a mask indicating valid tokens.
|
37 |
+
``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
|
38 |
+
of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
|
39 |
+
to fill and specify invalid positions if needed.
|
40 |
+
See the dedicated methods for more details.
|
41 |
+
"""
|
42 |
+
# Pattern layout, for each sequence step, we have a list of coordinates
|
43 |
+
# corresponding to the original codebook timestep and position.
|
44 |
+
# The first list is always an empty list in order to properly insert
|
45 |
+
# a special token to start with.
|
46 |
+
layout: PatternLayout
|
47 |
+
timesteps: int
|
48 |
+
n_q: int
|
49 |
+
|
50 |
+
def __post_init__(self):
|
51 |
+
assert len(self.layout) > 0
|
52 |
+
assert self.layout[0] == []
|
53 |
+
self._validate_layout()
|
54 |
+
self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
|
55 |
+
self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
|
56 |
+
logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
|
57 |
+
|
58 |
+
def _validate_layout(self):
|
59 |
+
"""Runs checks on the layout to ensure a valid pattern is defined.
|
60 |
+
A pattern is considered invalid if:
|
61 |
+
- Multiple timesteps for a same codebook are defined in the same sequence step
|
62 |
+
- The timesteps for a given codebook are not in ascending order as we advance in the sequence
|
63 |
+
(this would mean that we have future timesteps before past timesteps).
|
64 |
+
"""
|
65 |
+
q_timesteps = {q: 0 for q in range(self.n_q)}
|
66 |
+
for s, seq_coords in enumerate(self.layout):
|
67 |
+
if len(seq_coords) > 0:
|
68 |
+
qs = set()
|
69 |
+
for coord in seq_coords:
|
70 |
+
qs.add(coord.q)
|
71 |
+
last_q_timestep = q_timesteps[coord.q]
|
72 |
+
assert coord.t >= last_q_timestep, \
|
73 |
+
f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
|
74 |
+
q_timesteps[coord.q] = coord.t
|
75 |
+
# each sequence step contains at max 1 coordinate per codebook
|
76 |
+
assert len(qs) == len(seq_coords), \
|
77 |
+
f"Multiple entries for a same codebook are found at step {s}"
|
78 |
+
|
79 |
+
@property
|
80 |
+
def num_sequence_steps(self):
|
81 |
+
return len(self.layout) - 1
|
82 |
+
|
83 |
+
@property
|
84 |
+
def max_delay(self):
|
85 |
+
max_t_in_seq_coords = 0
|
86 |
+
for seq_coords in self.layout[1:]:
|
87 |
+
for coords in seq_coords:
|
88 |
+
max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
|
89 |
+
return max_t_in_seq_coords - self.timesteps
|
90 |
+
|
91 |
+
@property
|
92 |
+
def valid_layout(self):
|
93 |
+
valid_step = len(self.layout) - self.max_delay
|
94 |
+
return self.layout[:valid_step]
|
95 |
+
|
96 |
+
def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
|
97 |
+
"""Get codebook coordinates in the layout that corresponds to the specified timestep t
|
98 |
+
and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
|
99 |
+
and the actual codebook coordinates.
|
100 |
+
"""
|
101 |
+
assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
|
102 |
+
if q is not None:
|
103 |
+
assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
|
104 |
+
coords = []
|
105 |
+
for s, seq_codes in enumerate(self.layout):
|
106 |
+
for code in seq_codes:
|
107 |
+
if code.t == t and (q is None or code.q == q):
|
108 |
+
coords.append((s, code))
|
109 |
+
return coords
|
110 |
+
|
111 |
+
def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
|
112 |
+
return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
|
113 |
+
|
114 |
+
def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
|
115 |
+
steps_with_timesteps = self.get_steps_with_timestep(t, q)
|
116 |
+
return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
|
117 |
+
|
118 |
+
def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
|
119 |
+
device: tp.Union[torch.device, str] = 'cpu'):
|
120 |
+
"""Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
|
121 |
+
|
122 |
+
Args:
|
123 |
+
timesteps (int): Maximum number of timesteps steps to consider.
|
124 |
+
keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
|
125 |
+
device (torch.device or str): Device for created tensors.
|
126 |
+
Returns:
|
127 |
+
indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
|
128 |
+
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
|
129 |
+
"""
|
130 |
+
assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
|
131 |
+
assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
|
132 |
+
# use the proper layout based on whether we limit ourselves to valid steps only or not,
|
133 |
+
# note that using the valid_layout will result in a truncated sequence up to the valid steps
|
134 |
+
ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
|
135 |
+
# single item indexing being super slow with pytorch vs. numpy, so we use numpy here
|
136 |
+
indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
|
137 |
+
mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
|
138 |
+
# fill indexes with last sequence step value that will correspond to our special token
|
139 |
+
# the last value is n_q * timesteps as we have flattened z and append special token as the last token
|
140 |
+
# which will correspond to the index: n_q * timesteps
|
141 |
+
indexes[:] = n_q * timesteps
|
142 |
+
# iterate over the pattern and fill scattered indexes and mask
|
143 |
+
for s, sequence_coords in enumerate(ref_layout):
|
144 |
+
for coords in sequence_coords:
|
145 |
+
if coords.t < timesteps:
|
146 |
+
indexes[coords.q, s] = coords.t + coords.q * timesteps
|
147 |
+
mask[coords.q, s] = 1
|
148 |
+
indexes = torch.from_numpy(indexes).to(device)
|
149 |
+
mask = torch.from_numpy(mask).to(device)
|
150 |
+
return indexes, mask
|
151 |
+
|
152 |
+
def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
|
153 |
+
"""Build sequence corresponding to the pattern from the input tensor z.
|
154 |
+
The sequence is built using up to sequence_steps if specified, and non-pattern
|
155 |
+
coordinates are filled with the special token.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
|
159 |
+
special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
|
160 |
+
keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
|
161 |
+
Steps that are beyond valid steps will be replaced by the special_token in that case.
|
162 |
+
Returns:
|
163 |
+
values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
|
164 |
+
corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
|
165 |
+
indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
|
166 |
+
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
|
167 |
+
"""
|
168 |
+
B, K, T = z.shape
|
169 |
+
indexes, mask = self._build_pattern_sequence_scatter_indexes(
|
170 |
+
T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
|
171 |
+
)
|
172 |
+
z = z.view(B, -1)
|
173 |
+
# we append the special token as the last index of our flattened z tensor
|
174 |
+
z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
|
175 |
+
values = z[:, indexes.view(-1)]
|
176 |
+
values = values.view(B, K, indexes.shape[-1])
|
177 |
+
return values, indexes, mask
|
178 |
+
|
179 |
+
def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
|
180 |
+
keep_only_valid_steps: bool = False,
|
181 |
+
is_model_output: bool = False,
|
182 |
+
device: tp.Union[torch.device, str] = 'cpu'):
|
183 |
+
"""Builds scatter indexes required to retrieve the original multi-codebook sequence
|
184 |
+
from interleaving pattern.
|
185 |
+
|
186 |
+
Args:
|
187 |
+
sequence_steps (int): Sequence steps.
|
188 |
+
n_q (int): Number of codebooks.
|
189 |
+
keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
|
190 |
+
Steps that are beyond valid steps will be replaced by the special_token in that case.
|
191 |
+
is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
|
192 |
+
device (torch.device or str): Device for created tensors.
|
193 |
+
Returns:
|
194 |
+
indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
|
195 |
+
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
|
196 |
+
"""
|
197 |
+
ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
|
198 |
+
# TODO(jade): Do we want to further truncate to only valid timesteps here as well?
|
199 |
+
timesteps = self.timesteps
|
200 |
+
assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
|
201 |
+
assert sequence_steps <= len(ref_layout), \
|
202 |
+
f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
|
203 |
+
|
204 |
+
# ensure we take the appropriate indexes to keep the model output from the first special token as well
|
205 |
+
if is_model_output:
|
206 |
+
ref_layout = ref_layout[1:]
|
207 |
+
|
208 |
+
# single item indexing being super slow with pytorch vs. numpy, so we use numpy here
|
209 |
+
indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
|
210 |
+
mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
|
211 |
+
# fill indexes with last sequence step value that will correspond to our special token
|
212 |
+
indexes[:] = n_q * sequence_steps
|
213 |
+
for s, sequence_codes in enumerate(ref_layout):
|
214 |
+
if s < sequence_steps:
|
215 |
+
for code in sequence_codes:
|
216 |
+
if code.t < timesteps:
|
217 |
+
indexes[code.q, code.t] = s + code.q * sequence_steps
|
218 |
+
mask[code.q, code.t] = 1
|
219 |
+
indexes = torch.from_numpy(indexes).to(device)
|
220 |
+
mask = torch.from_numpy(mask).to(device)
|
221 |
+
return indexes, mask
|
222 |
+
|
223 |
+
def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
|
224 |
+
"""Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
|
225 |
+
The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
|
226 |
+
are filled with the special token.
|
227 |
+
|
228 |
+
Args:
|
229 |
+
s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
|
230 |
+
special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
|
231 |
+
Returns:
|
232 |
+
values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
|
233 |
+
corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
|
234 |
+
indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
|
235 |
+
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
|
236 |
+
"""
|
237 |
+
B, K, S = s.shape
|
238 |
+
indexes, mask = self._build_reverted_sequence_scatter_indexes(
|
239 |
+
S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
|
240 |
+
)
|
241 |
+
s = s.view(B, -1)
|
242 |
+
# we append the special token as the last index of our flattened z tensor
|
243 |
+
s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
|
244 |
+
values = s[:, indexes.view(-1)]
|
245 |
+
values = values.view(B, K, indexes.shape[-1])
|
246 |
+
return values, indexes, mask
|
247 |
+
|
248 |
+
def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
|
249 |
+
"""Revert model logits obtained on a sequence built from the pattern
|
250 |
+
back to a tensor matching the original sequence.
|
251 |
+
|
252 |
+
This method is similar to ``revert_pattern_sequence`` with the following specificities:
|
253 |
+
1. It is designed to work with the extra cardinality dimension
|
254 |
+
2. We return the logits for the first sequence item that matches the special_token and
|
255 |
+
which matching target in the original sequence is the first item of the sequence,
|
256 |
+
while we skip the last logits as there is no matching target
|
257 |
+
"""
|
258 |
+
B, card, K, S = logits.shape
|
259 |
+
indexes, mask = self._build_reverted_sequence_scatter_indexes(
|
260 |
+
S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
|
261 |
+
)
|
262 |
+
logits = logits.reshape(B, card, -1)
|
263 |
+
# we append the special token as the last index of our flattened z tensor
|
264 |
+
logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S]
|
265 |
+
values = logits[:, :, indexes.view(-1)]
|
266 |
+
values = values.view(B, card, K, indexes.shape[-1])
|
267 |
+
return values, indexes, mask
|
268 |
+
|
269 |
+
|
270 |
+
class CodebooksPatternProvider(ABC):
|
271 |
+
"""Abstraction around providing pattern for interleaving codebooks.
|
272 |
+
|
273 |
+
The CodebooksPatternProvider abstraction allows to implement various strategies to
|
274 |
+
define interleaving pattern of sequences composed of multiple codebooks. For a given
|
275 |
+
number of codebooks `n_q`, the pattern provider can generate a specified pattern
|
276 |
+
corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
|
277 |
+
can be used to construct a new sequence from the original codes respecting the specified
|
278 |
+
pattern. The pattern is defined as a list of list of code coordinates, code coordinate
|
279 |
+
being a tuple with the original timestep and codebook to build the new sequence.
|
280 |
+
Note that all patterns must start with an empty list that is then used to insert a first
|
281 |
+
sequence step of special tokens in the newly generated sequence.
|
282 |
+
|
283 |
+
Args:
|
284 |
+
n_q (int): number of codebooks.
|
285 |
+
cached (bool): if True, patterns for a given length are cached. In general
|
286 |
+
that should be true for efficiency reason to avoid synchronization points.
|
287 |
+
"""
|
288 |
+
def __init__(self, n_q: int, cached: bool = True):
|
289 |
+
assert n_q > 0
|
290 |
+
self.n_q = n_q
|
291 |
+
self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore
|
292 |
+
|
293 |
+
@abstractmethod
|
294 |
+
def get_pattern(self, timesteps: int) -> Pattern:
|
295 |
+
"""Builds pattern with specific interleaving between codebooks.
|
296 |
+
|
297 |
+
Args:
|
298 |
+
timesteps (int): Total number of timesteps.
|
299 |
+
"""
|
300 |
+
raise NotImplementedError()
|
301 |
+
|
302 |
+
|
303 |
+
class DelayedPatternProvider(CodebooksPatternProvider):
|
304 |
+
"""Provider for delayed pattern across delayed codebooks.
|
305 |
+
Codebooks are delayed in the sequence and sequence steps will contain codebooks
|
306 |
+
from different timesteps.
|
307 |
+
|
308 |
+
Example:
|
309 |
+
Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
|
310 |
+
[[1, 2, 3, 4],
|
311 |
+
[1, 2, 3, 4],
|
312 |
+
[1, 2, 3, 4]]
|
313 |
+
The resulting sequence obtained from the returned pattern is:
|
314 |
+
[[S, 1, 2, 3, 4],
|
315 |
+
[S, S, 1, 2, 3],
|
316 |
+
[S, S, S, 1, 2]]
|
317 |
+
(with S being a special token)
|
318 |
+
|
319 |
+
Args:
|
320 |
+
n_q (int): Number of codebooks.
|
321 |
+
delays (list of int, optional): Delay for each of the codebooks.
|
322 |
+
If delays not defined, each codebook is delayed by 1 compared to the previous one.
|
323 |
+
flatten_first (int): Flatten the first N timesteps.
|
324 |
+
empty_initial (int): Prepend with N empty list of coordinates.
|
325 |
+
"""
|
326 |
+
def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
|
327 |
+
flatten_first: int = 0, empty_initial: int = 0):
|
328 |
+
super().__init__(n_q)
|
329 |
+
if delays is None:
|
330 |
+
delays = list(range(n_q))
|
331 |
+
self.delays = delays
|
332 |
+
self.flatten_first = flatten_first
|
333 |
+
self.empty_initial = empty_initial
|
334 |
+
assert len(self.delays) == self.n_q
|
335 |
+
assert sorted(self.delays) == self.delays
|
336 |
+
|
337 |
+
def get_pattern(self, timesteps: int) -> Pattern:
|
338 |
+
out: PatternLayout = [[]]
|
339 |
+
max_delay = max(self.delays)
|
340 |
+
if self.empty_initial:
|
341 |
+
out += [[] for _ in range(self.empty_initial)]
|
342 |
+
if self.flatten_first:
|
343 |
+
for t in range(min(timesteps, self.flatten_first)):
|
344 |
+
for q in range(self.n_q):
|
345 |
+
out.append([LayoutCoord(t, q)])
|
346 |
+
for t in range(self.flatten_first, timesteps + max_delay):
|
347 |
+
v = []
|
348 |
+
for q, delay in enumerate(self.delays):
|
349 |
+
t_for_q = t - delay
|
350 |
+
if t_for_q >= self.flatten_first:
|
351 |
+
v.append(LayoutCoord(t_for_q, q))
|
352 |
+
out.append(v)
|
353 |
+
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
|
354 |
+
|
355 |
+
|
356 |
+
class ParallelPatternProvider(DelayedPatternProvider):
|
357 |
+
"""Provider for parallel pattern across codebooks.
|
358 |
+
This pattern provider is a special case of the delayed pattern with actually no delay,
|
359 |
+
hence delays=repeat(0, n_q).
|
360 |
+
|
361 |
+
Args:
|
362 |
+
n_q (int): Number of codebooks.
|
363 |
+
"""
|
364 |
+
def __init__(self, n_q: int):
|
365 |
+
super().__init__(n_q, [0] * n_q)
|
366 |
+
|
367 |
+
|
368 |
+
class UnrolledPatternProvider(CodebooksPatternProvider):
|
369 |
+
"""Provider for unrolling codebooks pattern.
|
370 |
+
This pattern provider enables to represent the codebook flattened completely or only to some extend
|
371 |
+
while also specifying a given delay between the flattened codebooks representation, allowing to
|
372 |
+
unroll the codebooks in the sequence.
|
373 |
+
|
374 |
+
Example:
|
375 |
+
1. Flattening of the codebooks.
|
376 |
+
By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
|
377 |
+
taking n_q = 3 and timesteps = 4:
|
378 |
+
[[1, 2, 3, 4],
|
379 |
+
[1, 2, 3, 4],
|
380 |
+
[1, 2, 3, 4]]
|
381 |
+
will result into:
|
382 |
+
[[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
|
383 |
+
[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
|
384 |
+
[1, S, S, 2, S, S, 3, S, S, 4, S, S]]
|
385 |
+
2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
|
386 |
+
for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example
|
387 |
+
taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
|
388 |
+
[[1, 2, 3, 4],
|
389 |
+
[1, 2, 3, 4],
|
390 |
+
[1, 2, 3, 4]]
|
391 |
+
will result into:
|
392 |
+
[[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
|
393 |
+
[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
|
394 |
+
[1, S, S, 2, S, S, 3, S, S, 4, S, S]]
|
395 |
+
3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks
|
396 |
+
allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
|
397 |
+
same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
|
398 |
+
and delays = [0, 3, 3]:
|
399 |
+
[[1, 2, 3, 4],
|
400 |
+
[1, 2, 3, 4],
|
401 |
+
[1, 2, 3, 4]]
|
402 |
+
will result into:
|
403 |
+
[[S, S, S, 1, S, 2, S, 3, S, 4],
|
404 |
+
[S, S, S, 1, S, 2, S, 3, S, 4],
|
405 |
+
[1, 2, 3, S, 4, S, 5, S, 6, S]]
|
406 |
+
|
407 |
+
Args:
|
408 |
+
n_q (int): Number of codebooks.
|
409 |
+
flattening (list of int, optional): Flattening schema over the codebooks. If not defined,
|
410 |
+
the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
|
411 |
+
have n_q extra steps for each timestep.
|
412 |
+
delays (list of int, optional): Delay for each of the codebooks. If not defined,
|
413 |
+
no delay is added and therefore will default to [0] * ``n_q``.
|
414 |
+
Note that two codebooks that will be flattened to the same inner step
|
415 |
+
should have the same delay, otherwise the pattern is considered as invalid.
|
416 |
+
"""
|
417 |
+
FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
|
418 |
+
|
419 |
+
def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
|
420 |
+
delays: tp.Optional[tp.List[int]] = None):
|
421 |
+
super().__init__(n_q)
|
422 |
+
if flattening is None:
|
423 |
+
flattening = list(range(n_q))
|
424 |
+
if delays is None:
|
425 |
+
delays = [0] * n_q
|
426 |
+
assert len(flattening) == n_q
|
427 |
+
assert len(delays) == n_q
|
428 |
+
assert sorted(flattening) == flattening
|
429 |
+
assert sorted(delays) == delays
|
430 |
+
self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
|
431 |
+
self.max_delay = max(delays)
|
432 |
+
|
433 |
+
def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
|
434 |
+
"""Build a flattened codebooks representation as a dictionary of inner step
|
435 |
+
and the actual codebook indices corresponding to the flattened codebook. For convenience, we
|
436 |
+
also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
|
437 |
+
"""
|
438 |
+
flattened_codebooks: dict = {}
|
439 |
+
for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
|
440 |
+
if inner_step not in flattened_codebooks:
|
441 |
+
flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
|
442 |
+
else:
|
443 |
+
flat_codebook = flattened_codebooks[inner_step]
|
444 |
+
assert flat_codebook.delay == delay, (
|
445 |
+
"Delay and flattening between codebooks is inconsistent: ",
|
446 |
+
"two codebooks flattened to the same position should have the same delay."
|
447 |
+
)
|
448 |
+
flat_codebook.codebooks.append(q)
|
449 |
+
flattened_codebooks[inner_step] = flat_codebook
|
450 |
+
return flattened_codebooks
|
451 |
+
|
452 |
+
@property
|
453 |
+
def _num_inner_steps(self):
|
454 |
+
"""Number of inner steps to unroll between timesteps in order to flatten the codebooks.
|
455 |
+
"""
|
456 |
+
return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
|
457 |
+
|
458 |
+
def num_virtual_steps(self, timesteps: int) -> int:
|
459 |
+
return timesteps * self._num_inner_steps + 1
|
460 |
+
|
461 |
+
def get_pattern(self, timesteps: int) -> Pattern:
|
462 |
+
"""Builds pattern for delay across codebooks.
|
463 |
+
|
464 |
+
Args:
|
465 |
+
timesteps (int): Total number of timesteps.
|
466 |
+
"""
|
467 |
+
# the PatternLayout is built as a tuple of sequence position and list of coordinates
|
468 |
+
# so that it can be reordered properly given the required delay between codebooks of given timesteps
|
469 |
+
indexed_out: list = [(-1, [])]
|
470 |
+
max_timesteps = timesteps + self.max_delay
|
471 |
+
for t in range(max_timesteps):
|
472 |
+
# for each timestep, we unroll the flattened codebooks,
|
473 |
+
# emitting the sequence step with the corresponding delay
|
474 |
+
for step in range(self._num_inner_steps):
|
475 |
+
if step in self._flattened_codebooks:
|
476 |
+
# we have codebooks at this virtual step to emit
|
477 |
+
step_codebooks = self._flattened_codebooks[step]
|
478 |
+
t_for_q = t + step_codebooks.delay
|
479 |
+
coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
|
480 |
+
if t_for_q < max_timesteps and t < max_timesteps:
|
481 |
+
indexed_out.append((t_for_q, coords))
|
482 |
+
else:
|
483 |
+
# there is no codebook in this virtual step so we emit an empty list
|
484 |
+
indexed_out.append((t, []))
|
485 |
+
out = [coords for _, coords in sorted(indexed_out)]
|
486 |
+
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
|
487 |
+
|
488 |
+
|
489 |
+
class CoarseFirstPattern(CodebooksPatternProvider):
|
490 |
+
"""First generates all the codebooks #1 (e.g. coarser), then the remaining ones,
|
491 |
+
potentially with delays.
|
492 |
+
|
493 |
+
..Warning:: You must always generate the full training duration at test time, for instance,
|
494 |
+
30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected
|
495 |
+
location. This is due to the non causality of the remaining codebooks with respect to
|
496 |
+
the first ones.
|
497 |
+
|
498 |
+
Args:
|
499 |
+
n_q (int): Number of codebooks.
|
500 |
+
delays (list of int, optional): Delay for each of the codebooks.
|
501 |
+
If delays not defined, each codebook is delayed by 1 compared to the previous one.
|
502 |
+
"""
|
503 |
+
def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
|
504 |
+
super().__init__(n_q)
|
505 |
+
if delays is None:
|
506 |
+
delays = [0] * (n_q - 1)
|
507 |
+
self.delays = delays
|
508 |
+
assert len(self.delays) == self.n_q - 1
|
509 |
+
assert sorted(self.delays) == self.delays
|
510 |
+
|
511 |
+
def get_pattern(self, timesteps: int) -> Pattern:
|
512 |
+
out: PatternLayout = [[]]
|
513 |
+
for t in range(timesteps):
|
514 |
+
out.append([LayoutCoord(t, 0)])
|
515 |
+
max_delay = max(self.delays)
|
516 |
+
for t in range(timesteps + max_delay):
|
517 |
+
v = []
|
518 |
+
for q, delay in enumerate(self.delays):
|
519 |
+
t_for_q = t - delay
|
520 |
+
if t_for_q >= 0:
|
521 |
+
v.append(LayoutCoord(t_for_q, q + 1))
|
522 |
+
out.append(v)
|
523 |
+
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
|
524 |
+
|
525 |
+
|
526 |
+
class MusicLMPattern(CodebooksPatternProvider):
|
527 |
+
"""Almost MusicLM style pattern. This is equivalent to full flattening
|
528 |
+
but in a different order.
|
529 |
+
|
530 |
+
Args:
|
531 |
+
n_q (int): Number of codebooks.
|
532 |
+
group_by (int): Number of codebooks to group together.
|
533 |
+
"""
|
534 |
+
def __init__(self, n_q: int, group_by: int = 2):
|
535 |
+
super().__init__(n_q)
|
536 |
+
self.group_by = group_by
|
537 |
+
|
538 |
+
def get_pattern(self, timesteps: int) -> Pattern:
|
539 |
+
out: PatternLayout = [[]]
|
540 |
+
for offset in range(0, self.n_q, self.group_by):
|
541 |
+
for t in range(timesteps):
|
542 |
+
for q in range(offset, offset + self.group_by):
|
543 |
+
out.append([LayoutCoord(t, q)])
|
544 |
+
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
|
audiocraft/modules/conditioners.py
ADDED
@@ -0,0 +1,1357 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from collections import defaultdict
|
8 |
+
from copy import deepcopy
|
9 |
+
from dataclasses import dataclass, field
|
10 |
+
from itertools import chain
|
11 |
+
import logging
|
12 |
+
import math
|
13 |
+
from pathlib import Path
|
14 |
+
import random
|
15 |
+
import re
|
16 |
+
import typing as tp
|
17 |
+
import warnings
|
18 |
+
|
19 |
+
import einops
|
20 |
+
from num2words import num2words
|
21 |
+
import spacy
|
22 |
+
from transformers import RobertaTokenizer # type: ignore
|
23 |
+
import torch
|
24 |
+
from torch import nn
|
25 |
+
import torch.nn.functional as F
|
26 |
+
from torch.nn.utils.rnn import pad_sequence
|
27 |
+
|
28 |
+
from .chroma import ChromaExtractor
|
29 |
+
from .streaming import StreamingModule
|
30 |
+
from .transformer import create_sin_embedding
|
31 |
+
from ..data.audio import audio_read
|
32 |
+
from ..data.audio_dataset import SegmentInfo
|
33 |
+
from ..data.audio_utils import convert_audio
|
34 |
+
from ..environment import AudioCraftEnvironment
|
35 |
+
from ..quantization import ResidualVectorQuantizer
|
36 |
+
from ..utils.autocast import TorchAutocast
|
37 |
+
from ..utils.cache import EmbeddingCache
|
38 |
+
from ..utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
logger = logging.getLogger(__name__)
|
43 |
+
TextCondition = tp.Optional[str] # a text condition can be a string or None (if doesn't exist)
|
44 |
+
ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask
|
45 |
+
|
46 |
+
|
47 |
+
class WavCondition(tp.NamedTuple):
|
48 |
+
wav: torch.Tensor
|
49 |
+
length: torch.Tensor
|
50 |
+
sample_rate: tp.List[int]
|
51 |
+
path: tp.List[tp.Optional[str]] = []
|
52 |
+
seek_time: tp.List[tp.Optional[float]] = []
|
53 |
+
|
54 |
+
|
55 |
+
class JointEmbedCondition(tp.NamedTuple):
|
56 |
+
wav: torch.Tensor
|
57 |
+
text: tp.List[tp.Optional[str]]
|
58 |
+
length: torch.Tensor
|
59 |
+
sample_rate: tp.List[int]
|
60 |
+
path: tp.List[tp.Optional[str]] = []
|
61 |
+
seek_time: tp.List[tp.Optional[float]] = []
|
62 |
+
|
63 |
+
|
64 |
+
@dataclass
|
65 |
+
class ConditioningAttributes:
|
66 |
+
text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
|
67 |
+
wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
|
68 |
+
joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
|
69 |
+
|
70 |
+
def __getitem__(self, item):
|
71 |
+
return getattr(self, item)
|
72 |
+
|
73 |
+
@property
|
74 |
+
def text_attributes(self):
|
75 |
+
return self.text.keys()
|
76 |
+
|
77 |
+
@property
|
78 |
+
def wav_attributes(self):
|
79 |
+
return self.wav.keys()
|
80 |
+
|
81 |
+
@property
|
82 |
+
def joint_embed_attributes(self):
|
83 |
+
return self.joint_embed.keys()
|
84 |
+
|
85 |
+
@property
|
86 |
+
def attributes(self):
|
87 |
+
return {
|
88 |
+
"text": self.text_attributes,
|
89 |
+
"wav": self.wav_attributes,
|
90 |
+
"joint_embed": self.joint_embed_attributes,
|
91 |
+
}
|
92 |
+
|
93 |
+
def to_flat_dict(self):
|
94 |
+
return {
|
95 |
+
**{f"text.{k}": v for k, v in self.text.items()},
|
96 |
+
**{f"wav.{k}": v for k, v in self.wav.items()},
|
97 |
+
**{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
|
98 |
+
}
|
99 |
+
|
100 |
+
@classmethod
|
101 |
+
def from_flat_dict(cls, x):
|
102 |
+
out = cls()
|
103 |
+
for k, v in x.items():
|
104 |
+
kind, att = k.split(".")
|
105 |
+
out[kind][att] = v
|
106 |
+
return out
|
107 |
+
|
108 |
+
|
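As an illustrative aside (not part of conditioners.py): `ConditioningAttributes` keeps text, wav and joint-embedding conditions in separate dicts, and `to_flat_dict()` / `from_flat_dict()` round-trip them through `"kind.attribute"` keys. A minimal sketch, assuming the `audiocraft` package from this repo and its dependencies are installed:

```python
from audiocraft.modules.conditioners import ConditioningAttributes

attrs = ConditioningAttributes(text={"description": "calm piano over rainy street footage"})
flat = attrs.to_flat_dict()                     # {"text.description": "calm piano over ..."}
restored = ConditioningAttributes.from_flat_dict(flat)
assert restored.text["description"] == attrs.text["description"]
```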
109 |
+
class SegmentWithAttributes(SegmentInfo):
|
110 |
+
"""Base class for all dataclasses that are used for conditioning.
|
111 |
+
All child classes should implement `to_condition_attributes` that converts
|
112 |
+
the existing attributes to a dataclass of type ConditioningAttributes.
|
113 |
+
"""
|
114 |
+
def to_condition_attributes(self) -> ConditioningAttributes:
|
115 |
+
raise NotImplementedError()
|
116 |
+
|
117 |
+
|
118 |
+
def nullify_condition(condition: ConditionType, dim: int = 1):
|
119 |
+
"""Transform an input condition to a null condition.
|
120 |
+
This is done by converting it to a single zero vector, similar
|
121 |
+
to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor])
|
125 |
+
dim (int): The dimension that will be truncated (should be the time dimension)
|
126 |
+
WARNING!: dim should not be the batch dimension!
|
127 |
+
Returns:
|
128 |
+
ConditionType: A tuple of null condition and mask
|
129 |
+
"""
|
130 |
+
assert dim != 0, "dim cannot be the batch dimension!"
|
131 |
+
assert isinstance(condition, tuple) and \
|
132 |
+
isinstance(condition[0], torch.Tensor) and \
|
133 |
+
isinstance(condition[1], torch.Tensor), "'nullify_condition' got an unexpected input type!"
|
134 |
+
cond, mask = condition
|
135 |
+
B = cond.shape[0]
|
136 |
+
last_dim = cond.dim() - 1
|
137 |
+
out = cond.transpose(dim, last_dim)
|
138 |
+
out = 0. * out[..., :1]
|
139 |
+
out = out.transpose(dim, last_dim)
|
140 |
+
mask = torch.zeros((B, 1), device=out.device).int()
|
141 |
+
assert cond.dim() == out.dim()
|
142 |
+
return out, mask
|
143 |
+
|
144 |
+
|
145 |
+
def nullify_wav(cond: WavCondition) -> WavCondition:
|
146 |
+
"""Transform a WavCondition to a nullified WavCondition.
|
147 |
+
It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes.
|
148 |
+
|
149 |
+
Args:
|
150 |
+
cond (WavCondition): Wav condition with wav, tensor of shape [B, T].
|
151 |
+
Returns:
|
152 |
+
WavCondition: Nullified wav condition.
|
153 |
+
"""
|
154 |
+
null_wav, _ = nullify_condition((cond.wav, torch.zeros_like(cond.wav)), dim=cond.wav.dim() - 1)
|
155 |
+
return WavCondition(
|
156 |
+
wav=null_wav,
|
157 |
+
length=torch.tensor([0] * cond.wav.shape[0], device=cond.wav.device),
|
158 |
+
sample_rate=cond.sample_rate,
|
159 |
+
path=[None] * cond.wav.shape[0],
|
160 |
+
seek_time=[None] * cond.wav.shape[0],
|
161 |
+
)
|
162 |
+
|
163 |
+
|
164 |
+
def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
|
165 |
+
"""Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0,
|
166 |
+
and replacing metadata by dummy attributes.
|
167 |
+
|
168 |
+
Args:
|
169 |
+
cond (JointEmbedCondition): Joint embedding condition with wav and text, wav tensor of shape [B, C, T].
|
170 |
+
"""
|
171 |
+
null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1)
|
172 |
+
return JointEmbedCondition(
|
173 |
+
wav=null_wav, text=[None] * len(embed.text),
|
174 |
+
length=torch.LongTensor([0]).to(embed.wav.device),
|
175 |
+
sample_rate=embed.sample_rate,
|
176 |
+
path=[None] * embed.wav.shape[0],
|
177 |
+
seek_time=[0] * embed.wav.shape[0],
|
178 |
+
)
|
179 |
+
|
180 |
+
|
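As an illustrative aside (not part of conditioners.py): nullifying a wav condition replaces the waveform with a single all-zero sample and zeroes the lengths, so downstream code keeps valid shapes when a condition is dropped. A minimal sketch, assuming the `audiocraft` package from this repo is installed:

```python
import torch
from audiocraft.modules.conditioners import WavCondition, nullify_wav

cond = WavCondition(wav=torch.randn(2, 1, 32000), length=torch.tensor([32000, 32000]),
                    sample_rate=[32000, 32000], path=[None, None], seek_time=[None, None])
null = nullify_wav(cond)
print(null.wav.shape, null.length.tolist())  # torch.Size([2, 1, 1]) [0, 0]
```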
181 |
+
class Tokenizer:
|
182 |
+
"""Base tokenizer implementation
|
183 |
+
(in case we want to introduce more advanced tokenizers in the future).
|
184 |
+
"""
|
185 |
+
def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
186 |
+
raise NotImplementedError()
|
187 |
+
|
188 |
+
|
189 |
+
class WhiteSpaceTokenizer(Tokenizer):
|
190 |
+
"""This tokenizer should be used for natural language descriptions.
|
191 |
+
For example:
|
192 |
+
["he didn't, know he's going home.", 'shorter sentence'] =>
|
193 |
+
[[78, 62, 31, 4, 78, 25, 19, 34],
|
194 |
+
[59, 77, 0, 0, 0, 0, 0, 0]]
|
195 |
+
"""
|
196 |
+
PUNCTUATION = "?:!.,;"
|
197 |
+
|
198 |
+
def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
|
199 |
+
lemma: bool = True, stopwords: bool = True) -> None:
|
200 |
+
self.n_bins = n_bins
|
201 |
+
self.pad_idx = pad_idx
|
202 |
+
self.lemma = lemma
|
203 |
+
self.stopwords = stopwords
|
204 |
+
try:
|
205 |
+
self.nlp = spacy.load(language)
|
206 |
+
except IOError:
|
207 |
+
spacy.cli.download(language) # type: ignore
|
208 |
+
self.nlp = spacy.load(language)
|
209 |
+
|
210 |
+
@tp.no_type_check
|
211 |
+
def __call__(self, texts: tp.List[tp.Optional[str]],
|
212 |
+
return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
213 |
+
"""Take a list of strings and convert them to a tensor of indices.
|
214 |
+
|
215 |
+
Args:
|
216 |
+
texts (list[str]): List of strings.
|
217 |
+
return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
|
218 |
+
Returns:
|
219 |
+
tuple[torch.Tensor, torch.Tensor]:
|
220 |
+
- Indices of words in the LUT.
|
221 |
+
- And a mask indicating where the padding tokens are
|
222 |
+
"""
|
223 |
+
output, lengths = [], []
|
224 |
+
texts = deepcopy(texts)
|
225 |
+
for i, text in enumerate(texts):
|
226 |
+
# if current sample doesn't have a certain attribute, replace with pad token
|
227 |
+
if text is None:
|
228 |
+
output.append(torch.Tensor([self.pad_idx]))
|
229 |
+
lengths.append(0)
|
230 |
+
continue
|
231 |
+
|
232 |
+
# convert numbers to words
|
233 |
+
text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text) # type: ignore
|
234 |
+
# normalize text
|
235 |
+
text = self.nlp(text) # type: ignore
|
236 |
+
# remove stopwords
|
237 |
+
if self.stopwords:
|
238 |
+
text = [w for w in text if not w.is_stop] # type: ignore
|
239 |
+
# remove punctuation
|
240 |
+
text = [w for w in text if w.text not in self.PUNCTUATION] # type: ignore
|
241 |
+
# lemmatize if needed
|
242 |
+
text = [getattr(t, "lemma_" if self.lemma else "text") for t in text] # type: ignore
|
243 |
+
|
244 |
+
texts[i] = " ".join(text)
|
245 |
+
lengths.append(len(text))
|
246 |
+
# convert to tensor
|
247 |
+
tokens = torch.Tensor([hash_trick(w, self.n_bins) for w in text])
|
248 |
+
output.append(tokens)
|
249 |
+
|
250 |
+
mask = length_to_mask(torch.IntTensor(lengths)).int()
|
251 |
+
padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t()
|
252 |
+
if return_text:
|
253 |
+
return padded_output, mask, texts # type: ignore
|
254 |
+
return padded_output, mask
|
255 |
+
|
256 |
+
|
257 |
+
class NoopTokenizer(Tokenizer):
|
258 |
+
"""This tokenizer should be used for global conditioners such as: artist, genre, key, etc.
|
259 |
+
The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split
|
260 |
+
strings, so "Jeff Buckley" will get its own index, whereas WhiteSpaceTokenizer will
|
261 |
+
split it to ["Jeff", "Buckley"] and return an index per word.
|
262 |
+
|
263 |
+
For example:
|
264 |
+
["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101]
|
265 |
+
["Metal", "Rock", "Classical"] => [0, 223, 51]
|
266 |
+
"""
|
267 |
+
def __init__(self, n_bins: int, pad_idx: int = 0):
|
268 |
+
self.n_bins = n_bins
|
269 |
+
self.pad_idx = pad_idx
|
270 |
+
|
271 |
+
def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
272 |
+
output, lengths = [], []
|
273 |
+
for text in texts:
|
274 |
+
# if current sample doesn't have a certain attribute, replace with pad token
|
275 |
+
if text is None:
|
276 |
+
output.append(self.pad_idx)
|
277 |
+
lengths.append(0)
|
278 |
+
else:
|
279 |
+
output.append(hash_trick(text, self.n_bins))
|
280 |
+
lengths.append(1)
|
281 |
+
|
282 |
+
tokens = torch.LongTensor(output).unsqueeze(1)
|
283 |
+
mask = length_to_mask(torch.IntTensor(lengths)).int()
|
284 |
+
return tokens, mask
|
285 |
+
|
286 |
+
|
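As an illustrative aside (not part of conditioners.py): `NoopTokenizer` hashes each full string into one of `n_bins` ids, while a missing attribute falls back to `pad_idx` with a zeroed mask entry. A minimal sketch, assuming the `audiocraft` package from this repo is installed:

```python
from audiocraft.modules.conditioners import NoopTokenizer

tokenizer = NoopTokenizer(n_bins=512)
tokens, mask = tokenizer(["Queen", None, "Jeff Buckley"])
print(tokens.shape)     # torch.Size([3, 1]) -- one hashed id per string
print(mask.tolist())    # the middle row is 0 because that sample had no attribute
```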
287 |
+
class BaseConditioner(nn.Module):
|
288 |
+
"""Base model for all conditioner modules.
|
289 |
+
We allow the output dim to be different than the hidden dim for two reasons:
|
290 |
+
1) keep our LUTs small when the vocab is large;
|
291 |
+
2) make all condition dims consistent.
|
292 |
+
|
293 |
+
Args:
|
294 |
+
dim (int): Hidden dim of the model.
|
295 |
+
output_dim (int): Output dim of the conditioner.
|
296 |
+
"""
|
297 |
+
def __init__(self, dim: int, output_dim: int):
|
298 |
+
super().__init__()
|
299 |
+
self.dim = dim
|
300 |
+
self.output_dim = output_dim
|
301 |
+
self.output_proj = nn.Linear(dim, output_dim)
|
302 |
+
|
303 |
+
def tokenize(self, *args, **kwargs) -> tp.Any:
|
304 |
+
"""Should be any part of the processing that will lead to a synchronization
|
305 |
+
point, e.g. BPE tokenization with transfer to the GPU.
|
306 |
+
|
307 |
+
The returned value will be saved and returned later when calling forward().
|
308 |
+
"""
|
309 |
+
raise NotImplementedError()
|
310 |
+
|
311 |
+
def forward(self, inputs: tp.Any) -> ConditionType:
|
312 |
+
"""Gets input that should be used as conditioning (e.g, genre, description or a waveform).
|
313 |
+
Outputs a ConditionType, after the input data was embedded as a dense vector.
|
314 |
+
|
315 |
+
Returns:
|
316 |
+
ConditionType:
|
317 |
+
- A tensor of size [B, T, D] where B is the batch size, T is the length of the
|
318 |
+
output embedding and D is the dimension of the embedding.
|
319 |
+
- And a mask indicating where the padding tokens are.
|
320 |
+
"""
|
321 |
+
raise NotImplementedError()
|
322 |
+
|
323 |
+
|
324 |
+
class TextConditioner(BaseConditioner):
|
325 |
+
...
|
326 |
+
class VideoConditioner():
|
327 |
+
...
|
328 |
+
|
329 |
+
class LUTConditioner(TextConditioner):
|
330 |
+
"""Lookup table TextConditioner.
|
331 |
+
|
332 |
+
Args:
|
333 |
+
n_bins (int): Number of bins.
|
334 |
+
dim (int): Hidden dim of the model (text-encoder/LUT).
|
335 |
+
output_dim (int): Output dim of the conditioner.
|
336 |
+
tokenizer (str): Name of the tokenizer.
|
337 |
+
pad_idx (int, optional): Index for padding token. Defaults to 0.
|
338 |
+
"""
|
339 |
+
def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0):
|
340 |
+
super().__init__(dim, output_dim)
|
341 |
+
self.embed = nn.Embedding(n_bins, dim)
|
342 |
+
self.tokenizer: Tokenizer
|
343 |
+
if tokenizer == 'whitespace':
|
344 |
+
self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
|
345 |
+
elif tokenizer == 'noop':
|
346 |
+
self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)
|
347 |
+
else:
|
348 |
+
raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")
|
349 |
+
|
350 |
+
def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
351 |
+
device = self.embed.weight.device
|
352 |
+
tokens, mask = self.tokenizer(x)
|
353 |
+
tokens, mask = tokens.to(device), mask.to(device)
|
354 |
+
return tokens, mask
|
355 |
+
|
356 |
+
def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType:
|
357 |
+
tokens, mask = inputs
|
358 |
+
embeds = self.embed(tokens)
|
359 |
+
embeds = self.output_proj(embeds)
|
360 |
+
embeds = (embeds * mask.unsqueeze(-1))
|
361 |
+
return embeds, mask
|
362 |
+
|
363 |
+
|
364 |
+
class T5Conditioner(TextConditioner):
|
365 |
+
MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
|
366 |
+
"google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
|
367 |
+
"google/flan-t5-xl", "google/flan-t5-xxl"]
|
368 |
+
MODELS_DIMS = {
|
369 |
+
"t5-small": 512,
|
370 |
+
"t5-base": 768,
|
371 |
+
"t5-large": 1024,
|
372 |
+
"t5-3b": 1024,
|
373 |
+
"t5-11b": 1024,
|
374 |
+
"google/flan-t5-small": 512,
|
375 |
+
"google/flan-t5-base": 768,
|
376 |
+
"google/flan-t5-large": 1024,
|
377 |
+
"google/flan-t5-3b": 1024,
|
378 |
+
"google/flan-t5-11b": 1024,
|
379 |
+
}
|
380 |
+
def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
|
381 |
+
autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
|
382 |
+
normalize_text: bool = False):
|
383 |
+
assert name in self.MODELS, f"Unrecognized t5 model name (should be in {self.MODELS})"
|
384 |
+
super().__init__(self.MODELS_DIMS[name], output_dim)
|
385 |
+
|
386 |
+
|
387 |
+
def forward(self, *args, **kwargs):
|
388 |
+
embeds = torch.zeros([1,1,1])
|
389 |
+
mask = torch.zeros([1,1,1])
|
390 |
+
return embeds, mask
|
391 |
+
|
392 |
+
|
393 |
+
class WaveformConditioner(BaseConditioner):
|
394 |
+
"""Base class for all conditioners that take a waveform as input.
|
395 |
+
Classes that inherit must implement `_get_wav_embedding` that outputs
|
396 |
+
a continuous tensor, and `_downsampling_factor` that returns the down-sampling
|
397 |
+
factor of the embedding model.
|
398 |
+
|
399 |
+
Args:
|
400 |
+
dim (int): The internal representation dimension.
|
401 |
+
output_dim (int): Output dimension.
|
402 |
+
device (tp.Union[torch.device, str]): Device.
|
403 |
+
"""
|
404 |
+
def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
|
405 |
+
super().__init__(dim, output_dim)
|
406 |
+
self.device = device
|
407 |
+
# if False no masking is done, used in ChromaStemConditioner when completing by periodicity a sample.
|
408 |
+
self._use_masking = True
|
409 |
+
|
410 |
+
def tokenize(self, x: WavCondition) -> WavCondition:
|
411 |
+
wav, length, sample_rate, path, seek_time = x
|
412 |
+
assert length is not None
|
413 |
+
return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time)
|
414 |
+
|
415 |
+
def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
|
416 |
+
"""Gets as input a WavCondition and returns a dense embedding."""
|
417 |
+
raise NotImplementedError()
|
418 |
+
|
419 |
+
def _downsampling_factor(self):
|
420 |
+
"""Returns the downsampling factor of the embedding model."""
|
421 |
+
raise NotImplementedError()
|
422 |
+
|
423 |
+
def forward(self, x: WavCondition) -> ConditionType:
|
424 |
+
"""Extract condition embedding and mask from a waveform and its metadata.
|
425 |
+
Args:
|
426 |
+
x (WavCondition): Waveform condition containing raw waveform and metadata.
|
427 |
+
Returns:
|
428 |
+
ConditionType: a dense vector representing the conditioning along with its mask
|
429 |
+
"""
|
430 |
+
wav, lengths, *_ = x
|
431 |
+
with torch.no_grad():
|
432 |
+
embeds = self._get_wav_embedding(x)
|
433 |
+
embeds = embeds.to(self.output_proj.weight)
|
434 |
+
embeds = self.output_proj(embeds)
|
435 |
+
|
436 |
+
if lengths is not None and self._use_masking:
|
437 |
+
lengths = lengths / self._downsampling_factor()
|
438 |
+
mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
|
439 |
+
else:
|
440 |
+
mask = torch.ones_like(embeds[..., 0])
|
441 |
+
embeds = (embeds * mask.unsqueeze(-1))
|
442 |
+
return embeds, mask
|
443 |
+
|
444 |
+
|
445 |
+
class ChromaStemConditioner(WaveformConditioner):
|
446 |
+
"""Chroma conditioner based on stems.
|
447 |
+
The ChromaStemConditioner uses DEMUCS to first filter out drums and bass, as
|
448 |
+
the drums and bass often dominate the chroma leading to the chroma features
|
449 |
+
not containing information about the melody.
|
450 |
+
|
451 |
+
Args:
|
452 |
+
output_dim (int): Output dimension for the conditioner.
|
453 |
+
sample_rate (int): Sample rate for the chroma extractor.
|
454 |
+
n_chroma (int): Number of chroma bins for the chroma extractor.
|
455 |
+
radix2_exp (int): Size of stft window for the chroma extractor (power of 2, e.g. 12 -> 2^12).
|
456 |
+
duration (int): duration used during training. This is later used for correct padding
|
457 |
+
in case we are using chroma as prefix.
|
458 |
+
match_len_on_eval (bool, optional): if True then all chromas are padded to the training
|
459 |
+
duration. Defaults to True.
|
460 |
+
eval_wavs (str, optional): path to a dataset manifest with waveforms; these waveforms are used as
|
461 |
+
conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
|
462 |
+
Defaults to None.
|
463 |
+
n_eval_wavs (int, optional): limits the number of waveforms used for conditioning. Defaults to 0.
|
464 |
+
device (tp.Union[torch.device, str], optional): Device for the conditioner.
|
465 |
+
**kwargs: Additional parameters for the chroma extractor.
|
466 |
+
"""
|
467 |
+
def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
|
468 |
+
duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
|
469 |
+
n_eval_wavs: int = 0, cache_path: tp.Optional[tp.Union[str, Path]] = None,
|
470 |
+
device: tp.Union[torch.device, str] = 'cpu', **kwargs):
|
471 |
+
from demucs import pretrained
|
472 |
+
super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
|
473 |
+
self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32)
|
474 |
+
self.sample_rate = sample_rate
|
475 |
+
self.match_len_on_eval = match_len_on_eval
|
476 |
+
if match_len_on_eval:
|
477 |
+
self._use_masking = False
|
478 |
+
self.duration = duration
|
479 |
+
self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device)
|
480 |
+
stem_sources: list = self.demucs.sources # type: ignore
|
481 |
+
self.stem_indices = torch.LongTensor([stem_sources.index('vocals'), stem_sources.index('other')]).to(device)
|
482 |
+
self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma,
|
483 |
+
radix2_exp=radix2_exp, **kwargs).to(device)
|
484 |
+
self.chroma_len = self._get_chroma_len()
|
485 |
+
self.eval_wavs: tp.Optional[torch.Tensor] = self._load_eval_wavs(eval_wavs, n_eval_wavs)
|
486 |
+
self.cache = None
|
487 |
+
if cache_path is not None:
|
488 |
+
self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
|
489 |
+
compute_embed_fn=self._get_full_chroma_for_cache,
|
490 |
+
extract_embed_fn=self._extract_chroma_chunk)
|
491 |
+
|
492 |
+
def _downsampling_factor(self) -> int:
|
493 |
+
return self.chroma.winhop
|
494 |
+
|
495 |
+
def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) -> tp.Optional[torch.Tensor]:
|
496 |
+
"""Load pre-defined waveforms from a json.
|
497 |
+
These waveforms will be used for chroma extraction during evaluation.
|
498 |
+
This is done to make the evaluation on MusicCaps fair (we shouldn't see the chromas of MusicCaps).
|
499 |
+
"""
|
500 |
+
if path is None:
|
501 |
+
return None
|
502 |
+
|
503 |
+
logger.info(f"Loading evaluation wavs from {path}")
|
504 |
+
from audiocraft.data.audio_dataset import AudioDataset
|
505 |
+
# print(f'self.video_fps:{self.video_fps}')
|
506 |
+
# print(f'self.video_overlap:{self.video_overlap}')
|
507 |
+
# exit()
|
508 |
+
dataset: AudioDataset = AudioDataset.from_meta(
|
509 |
+
path, segment_duration=self.duration, min_audio_duration=self.duration,
|
510 |
+
sample_rate=self.sample_rate, video_fps=self.video_fps, video_overlap=self.video_overlap, channels=1)
|
511 |
+
|
512 |
+
if len(dataset) > 0:
|
513 |
+
eval_wavs = dataset.collater([dataset[i] for i in range(num_samples)]).to(self.device)
|
514 |
+
logger.info(f"Using {len(eval_wavs)} evaluation wavs for chroma-stem conditioner")
|
515 |
+
return eval_wavs
|
516 |
+
else:
|
517 |
+
raise ValueError("Could not find evaluation wavs, check lengths of wavs")
|
518 |
+
|
519 |
+
def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None:
|
520 |
+
self.eval_wavs = eval_wavs
|
521 |
+
|
522 |
+
def has_eval_wavs(self) -> bool:
|
523 |
+
return self.eval_wavs is not None
|
524 |
+
|
525 |
+
def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor:
|
526 |
+
"""Sample wavs from a predefined list."""
|
527 |
+
assert self.eval_wavs is not None, "Cannot sample eval wavs as no eval wavs provided."
|
528 |
+
total_eval_wavs = len(self.eval_wavs)
|
529 |
+
out = self.eval_wavs
|
530 |
+
if num_samples > total_eval_wavs:
|
531 |
+
out = self.eval_wavs.repeat(num_samples // total_eval_wavs + 1, 1, 1)
|
532 |
+
return out[torch.randperm(len(out))][:num_samples]
|
533 |
+
|
534 |
+
def _get_chroma_len(self) -> int:
|
535 |
+
"""Get length of chroma during training."""
|
536 |
+
dummy_wav = torch.zeros((1, int(self.sample_rate * self.duration)), device=self.device)
|
537 |
+
dummy_chr = self.chroma(dummy_wav)
|
538 |
+
return dummy_chr.shape[1]
|
539 |
+
|
540 |
+
@torch.no_grad()
|
541 |
+
def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
|
542 |
+
"""Get parts of the wav that holds the melody, extracting the main stems from the wav."""
|
543 |
+
from demucs.apply import apply_model
|
544 |
+
from demucs.audio import convert_audio
|
545 |
+
with self.autocast:
|
546 |
+
wav = convert_audio(
|
547 |
+
wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels) # type: ignore
|
548 |
+
stems = apply_model(self.demucs, wav, device=self.device)
|
549 |
+
stems = stems[:, self.stem_indices] # extract relevant stems for melody conditioning
|
550 |
+
mix_wav = stems.sum(1) # merge extracted stems to single waveform
|
551 |
+
mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1) # type: ignore
|
552 |
+
return mix_wav
|
553 |
+
|
554 |
+
@torch.no_grad()
|
555 |
+
def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor:
|
556 |
+
"""Extract chroma features from the waveform."""
|
557 |
+
with self.autocast:
|
558 |
+
return self.chroma(wav)
|
559 |
+
|
560 |
+
@torch.no_grad()
|
561 |
+
def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
|
562 |
+
"""Compute wav embedding, applying stem and chroma extraction."""
|
563 |
+
# avoid 0-size tensors when we are working with null conds
|
564 |
+
if wav.shape[-1] == 1:
|
565 |
+
return self._extract_chroma(wav)
|
566 |
+
stems = self._get_stemmed_wav(wav, sample_rate)
|
567 |
+
chroma = self._extract_chroma(stems)
|
568 |
+
return chroma
|
569 |
+
|
570 |
+
@torch.no_grad()
|
571 |
+
def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: WavCondition, idx: int) -> torch.Tensor:
|
572 |
+
"""Extract chroma from the whole audio waveform at the given path."""
|
573 |
+
wav, sr = audio_read(path)
|
574 |
+
wav = wav[None].to(self.device)
|
575 |
+
wav = convert_audio(wav, sr, self.sample_rate, to_channels=1)
|
576 |
+
chroma = self._compute_wav_embedding(wav, self.sample_rate)[0]
|
577 |
+
return chroma
|
578 |
+
|
579 |
+
def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor:
|
580 |
+
"""Extract a chunk of chroma from the full chroma derived from the full waveform."""
|
581 |
+
wav_length = x.wav.shape[-1]
|
582 |
+
seek_time = x.seek_time[idx]
|
583 |
+
assert seek_time is not None, (
|
584 |
+
"WavCondition seek_time is required "
|
585 |
+
"when extracting chroma chunks from pre-computed chroma.")
|
586 |
+
full_chroma = full_chroma.float()
|
587 |
+
frame_rate = self.sample_rate / self._downsampling_factor()
|
588 |
+
target_length = int(frame_rate * wav_length / self.sample_rate)
|
589 |
+
index = int(frame_rate * seek_time)
|
590 |
+
out = full_chroma[index: index + target_length]
|
591 |
+
out = F.pad(out[None], (0, 0, 0, target_length - out.shape[0]))[0]
|
592 |
+
return out.to(self.device)
|
593 |
+
|
594 |
+
@torch.no_grad()
|
595 |
+
def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
|
596 |
+
"""Get the wav embedding from the WavCondition.
|
597 |
+
The conditioner will either extract the embedding on-the-fly computing it from the condition wav directly
|
598 |
+
or will rely on the embedding cache to load the pre-computed embedding if relevant.
|
599 |
+
"""
|
600 |
+
sampled_wav: tp.Optional[torch.Tensor] = None
|
601 |
+
if not self.training and self.eval_wavs is not None:
|
602 |
+
warn_once(logger, "Using precomputed evaluation wavs!")
|
603 |
+
sampled_wav = self._sample_eval_wavs(len(x.wav))
|
604 |
+
|
605 |
+
no_undefined_paths = all(p is not None for p in x.path)
|
606 |
+
no_nullified_cond = x.wav.shape[-1] > 1
|
607 |
+
if sampled_wav is not None:
|
608 |
+
chroma = self._compute_wav_embedding(sampled_wav, self.sample_rate)
|
609 |
+
elif self.cache is not None and no_undefined_paths and no_nullified_cond:
|
610 |
+
paths = [Path(p) for p in x.path if p is not None]
|
611 |
+
chroma = self.cache.get_embed_from_cache(paths, x)
|
612 |
+
else:
|
613 |
+
assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal."
|
614 |
+
chroma = self._compute_wav_embedding(x.wav, x.sample_rate[0])
|
615 |
+
|
616 |
+
if self.match_len_on_eval:
|
617 |
+
B, T, C = chroma.shape
|
618 |
+
if T > self.chroma_len:
|
619 |
+
chroma = chroma[:, :self.chroma_len]
|
620 |
+
logger.debug(f"Chroma was truncated to match length! ({T} -> {chroma.shape[1]})")
|
621 |
+
elif T < self.chroma_len:
|
622 |
+
n_repeat = int(math.ceil(self.chroma_len / T))
|
623 |
+
chroma = chroma.repeat(1, n_repeat, 1)
|
624 |
+
chroma = chroma[:, :self.chroma_len]
|
625 |
+
logger.debug(f"Chroma was repeated to match length! ({T} -> {chroma.shape[1]})")
|
626 |
+
|
627 |
+
return chroma
|
628 |
+
|
629 |
+
def tokenize(self, x: WavCondition) -> WavCondition:
|
630 |
+
"""Apply WavConditioner tokenization and populate cache if needed."""
|
631 |
+
x = super().tokenize(x)
|
632 |
+
no_undefined_paths = all(p is not None for p in x.path)
|
633 |
+
if self.cache is not None and no_undefined_paths:
|
634 |
+
paths = [Path(p) for p in x.path if p is not None]
|
635 |
+
self.cache.populate_embed_cache(paths, x)
|
636 |
+
return x
|
637 |
+
|
638 |
+
|
639 |
+
class JointEmbeddingConditioner(BaseConditioner):
|
640 |
+
"""Joint embedding conditioning supporting both audio or text conditioning.
|
641 |
+
|
642 |
+
Args:
|
643 |
+
dim (int): Dimension.
|
644 |
+
output_dim (int): Output dimension.
|
645 |
+
device (str): Device.
|
646 |
+
attribute (str): Attribute used by the conditioner.
|
647 |
+
autocast_dtype (str): Autocast for the conditioner.
|
648 |
+
quantize (bool): Whether to quantize the CLAP embedding.
|
649 |
+
n_q (int): Number of residual quantizers (used if quantize is true).
|
650 |
+
bins (int): Quantizers' codebooks size (used if quantize is true).
|
651 |
+
kwargs: Additional parameters for residual vector quantizer.
|
652 |
+
"""
|
653 |
+
def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
|
654 |
+
autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True,
|
655 |
+
n_q: int = 12, bins: int = 1024, **kwargs):
|
656 |
+
super().__init__(dim=dim, output_dim=output_dim)
|
657 |
+
self.device = device
|
658 |
+
self.attribute = attribute
|
659 |
+
if autocast_dtype is None or device == 'cpu':
|
660 |
+
self.autocast = TorchAutocast(enabled=False)
|
661 |
+
logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.")
|
662 |
+
else:
|
663 |
+
dtype = getattr(torch, autocast_dtype)
|
664 |
+
assert isinstance(dtype, torch.dtype)
|
665 |
+
logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.")
|
666 |
+
self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
|
667 |
+
# residual vector quantizer to discretize the conditioned embedding
|
668 |
+
self.quantizer: tp.Optional[ResidualVectorQuantizer] = None
|
669 |
+
if quantize:
|
670 |
+
self.quantizer = ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs)
|
671 |
+
|
672 |
+
def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
673 |
+
"""Get joint embedding in latent space from the inputs.
|
674 |
+
|
675 |
+
Returns:
|
676 |
+
tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding
|
677 |
+
and corresponding empty indexes.
|
678 |
+
"""
|
679 |
+
raise NotImplementedError()
|
680 |
+
|
681 |
+
def forward(self, x: JointEmbedCondition) -> ConditionType:
|
682 |
+
with self.autocast:
|
683 |
+
embed, empty_idx = self._get_embed(x)
|
684 |
+
if self.quantizer is not None:
|
685 |
+
embed = embed.view(-1, self.dim, 1)
|
686 |
+
q_res = self.quantizer(embed, frame_rate=1)
|
687 |
+
out_embed = q_res.x.view(-1, self.dim)
|
688 |
+
else:
|
689 |
+
out_embed = embed
|
690 |
+
out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim)
|
691 |
+
mask = torch.ones(*out_embed.shape[:2], device=out_embed.device)
|
692 |
+
mask[empty_idx, :] = 0 # zero-out index where the input is non-existant
|
693 |
+
out_embed = (out_embed * mask.unsqueeze(-1))
|
694 |
+
return out_embed, mask
|
695 |
+
|
696 |
+
def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
|
697 |
+
return x
|
698 |
+
|
699 |
+
|
700 |
+
class CLAPEmbeddingConditioner(JointEmbeddingConditioner):
|
701 |
+
"""Joint Embedding conditioner based on pre-trained CLAP model.
|
702 |
+
|
703 |
+
This CLAP-based conditioner supports a caching mechanism
|
704 |
+
over the computed embeddings for faster training.
|
705 |
+
|
706 |
+
Args:
|
707 |
+
dim (int): Dimension.
|
708 |
+
output_dim (int): Output dimension.
|
709 |
+
device (str): Device.
|
710 |
+
attribute (str): Attribute used by the conditioner.
|
711 |
+
quantize (bool): Whether to quantize the CLAP embedding.
|
712 |
+
n_q (int): Number of residual quantizers (used if quantize is true).
|
713 |
+
bins (int): Quantizers' codebooks size (used if quantize is true).
|
714 |
+
checkpoint (str): Path to CLAP checkpoint.
|
715 |
+
model_arch (str): CLAP model architecture.
|
716 |
+
enable_fusion (bool): Enable fusion for CLAP model.
|
717 |
+
sample_rate (int): Sample rate used by CLAP model.
|
718 |
+
max_audio_length (float): Maximum audio length for CLAP model.
|
719 |
+
audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence.
|
720 |
+
normalize (bool): Whether to normalize the CLAP embedding.
|
721 |
+
text_p (float): Probability of using text representation instead of audio at train time.
|
722 |
+
batch_size (Optional[int]): Batch size for CLAP embedding computation.
|
723 |
+
autocast_dtype (str): Autocast for the conditioner.
|
724 |
+
cache_path (Optional[str]): Path for pre-computed embeddings caching.
|
725 |
+
kwargs: Additional parameters for residual vector quantizer.
|
726 |
+
"""
|
727 |
+
def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
|
728 |
+
quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str,
|
729 |
+
enable_fusion: bool, sample_rate: int, max_audio_length: int, audio_stride: int,
|
730 |
+
normalize: bool, text_p: bool, batch_size: tp.Optional[int] = None,
|
731 |
+
autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs):
|
732 |
+
try:
|
733 |
+
import laion_clap # type: ignore
|
734 |
+
except ImportError:
|
735 |
+
raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'")
|
736 |
+
# warnings.warn("Sample rate for CLAP conditioner was fixed in version v1.1.0, (from 44.1 to 48 kHz). "
|
737 |
+
# "Please retrain all models.")
|
738 |
+
checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint)
|
739 |
+
clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base')
|
740 |
+
clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
|
741 |
+
load_clap_state_dict(clap_model, checkpoint)
|
742 |
+
clap_model.eval()
|
743 |
+
clap_model.to(device)
|
744 |
+
super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute,
|
745 |
+
autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins,
|
746 |
+
**kwargs)
|
747 |
+
self.checkpoint = checkpoint
|
748 |
+
self.enable_fusion = enable_fusion
|
749 |
+
self.model_arch = model_arch
|
750 |
+
self.clap: laion_clap.CLAP_Module
|
751 |
+
self.clap_tokenize: RobertaTokenizer
|
752 |
+
self.clap_sample_rate = sample_rate
|
753 |
+
self.clap_max_frames = int(self.clap_sample_rate * max_audio_length)
|
754 |
+
self.clap_stride = int(self.clap_sample_rate * audio_stride)
|
755 |
+
self.batch_size = batch_size or 1
|
756 |
+
self.normalize = normalize
|
757 |
+
self.text_p = text_p
|
758 |
+
self.__dict__['clap_tokenize'] = clap_tokenize
|
759 |
+
self.__dict__['clap'] = clap_model
|
760 |
+
self.wav_cache, self.text_cache = None, None
|
761 |
+
if cache_path is not None:
|
762 |
+
self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
|
763 |
+
compute_embed_fn=self._get_wav_embedding_for_cache,
|
764 |
+
extract_embed_fn=self._extract_wav_embedding_chunk)
|
765 |
+
self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device,
|
766 |
+
compute_embed_fn=self._get_text_embedding_for_cache)
|
767 |
+
|
768 |
+
def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
|
769 |
+
# we use the default params from CLAP module here as well
|
770 |
+
return self.clap_tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
|
771 |
+
|
772 |
+
def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor:
|
773 |
+
"""Compute text embedding from CLAP model on a given a batch of text.
|
774 |
+
|
775 |
+
Args:
|
776 |
+
text (list[str]): List of text for the batch, with B items.
|
777 |
+
Returns:
|
778 |
+
torch.Tensor: CLAP embedding derived from text, of shape [B, 1, D], with D the CLAP embedding dimension.
|
779 |
+
"""
|
780 |
+
with torch.no_grad():
|
781 |
+
embed = self.clap.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
|
782 |
+
return embed.view(embed.size(0), 1, embed.size(-1))
|
783 |
+
|
784 |
+
def _get_text_embedding_for_cache(self, path: tp.Union[Path, str],
|
785 |
+
x: JointEmbedCondition, idx: int) -> torch.Tensor:
|
786 |
+
"""Get text embedding function for the cache."""
|
787 |
+
text = x.text[idx]
|
788 |
+
text = text if text is not None else ""
|
789 |
+
return self._compute_text_embedding([text])[0]
|
790 |
+
|
791 |
+
def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sample_rates: tp.List[int]) -> torch.Tensor:
|
792 |
+
"""Preprocess wav to expected format by CLAP model.
|
793 |
+
|
794 |
+
Args:
|
795 |
+
wav (torch.Tensor): Audio wav, of shape [B, C, T].
|
796 |
+
length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
|
797 |
+
sample_rates (list[int]): Sample rates for each sample in the batch
|
798 |
+
Returns:
|
799 |
+
torch.Tensor: Audio wav of shape [B, T].
|
800 |
+
"""
|
801 |
+
assert wav.dim() == 3, "Expecting wav to be [B, C, T]"
|
802 |
+
if sample_rates is not None:
|
803 |
+
_wav = []
|
804 |
+
for i, audio in enumerate(wav):
|
805 |
+
sr = sample_rates[i]
|
806 |
+
audio = convert_audio(audio, from_rate=sr, to_rate=self.clap_sample_rate, to_channels=1)
|
807 |
+
_wav.append(audio)
|
808 |
+
wav = torch.stack(_wav, dim=0)
|
809 |
+
wav = wav.mean(dim=1)
|
810 |
+
return wav
|
811 |
+
|
812 |
+
def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor,
|
813 |
+
sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor:
|
814 |
+
"""Compute audio wave embedding from CLAP model.
|
815 |
+
|
816 |
+
Since CLAP operates on fixed-length audio inputs and we need to process longer audio sequences,
|
817 |
+
we calculate the wav embeddings on `clap_max_frames` windows with `clap_stride`-second stride and
|
818 |
+
average the resulting embeddings.
|
819 |
+
|
820 |
+
Args:
|
821 |
+
wav (torch.Tensor): Audio wav, of shape [B, C, T].
|
822 |
+
length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
|
823 |
+
sample_rates (list[int]): Sample rates for each sample in the batch.
|
824 |
+
reduce_mean (bool): Whether to get the average tensor.
|
825 |
+
Returns:
|
826 |
+
torch.Tensor: Audio embedding of shape [B, F, D], F being the number of chunks, D the dimension.
|
827 |
+
"""
|
828 |
+
with torch.no_grad():
|
829 |
+
wav = self._preprocess_wav(wav, length, sample_rates)
|
830 |
+
B, T = wav.shape
|
831 |
+
if T >= self.clap_max_frames:
|
832 |
+
wav = wav.unfold(-1, self.clap_max_frames, self.clap_stride) # [B, F, T]
|
833 |
+
else:
|
834 |
+
wav = wav.view(-1, 1, T) # [B, F, T] with F=1
|
835 |
+
wav = einops.rearrange(wav, 'b f t -> (b f) t')
|
836 |
+
embed_list = []
|
837 |
+
for i in range(0, wav.size(0), self.batch_size):
|
838 |
+
_wav = wav[i:i+self.batch_size, ...]
|
839 |
+
_embed = self.clap.get_audio_embedding_from_data(_wav, use_tensor=True)
|
840 |
+
embed_list.append(_embed)
|
841 |
+
embed = torch.cat(embed_list, dim=0)
|
842 |
+
embed = einops.rearrange(embed, '(b f) d -> b f d', b=B)
|
843 |
+
if reduce_mean:
|
844 |
+
embed = embed.mean(dim=1, keepdim=True)
|
845 |
+
return embed # [B, F, D] with F=1 if reduce_mean is True
|
846 |
+
|
847 |
+
def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path],
|
848 |
+
x: JointEmbedCondition, idx: int) -> torch.Tensor:
|
849 |
+
"""Compute audio wave embedding for the cache.
|
850 |
+
The embedding is computed on a given audio read from file.
|
851 |
+
|
852 |
+
Args:
|
853 |
+
path (str or Path): Path to the full audio file.
|
854 |
+
Returns:
|
855 |
+
torch.Tensor: Single-item tensor of shape [F, D], F being the number of chunks, D the dimension.
|
856 |
+
"""
|
857 |
+
wav, sr = audio_read(path) # [C, T]
|
858 |
+
wav = wav.unsqueeze(0).to(self.device) # [1, C, T]
|
859 |
+
wav_len = torch.LongTensor([wav.shape[-1]]).to(self.device)
|
860 |
+
embed = self._compute_wav_embedding(wav, wav_len, [sr], reduce_mean=False) # [B, F, D]
|
861 |
+
return embed.squeeze(0) # [F, D]
|
862 |
+
|
863 |
+
def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor:
|
864 |
+
"""Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding.
|
865 |
+
|
866 |
+
Args:
|
867 |
+
full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D].
|
868 |
+
x (JointEmbedCondition): Joint embedding condition for the full batch.
|
869 |
+
idx (int): Index considered for the given embedding to extract.
|
870 |
+
Returns:
|
871 |
+
torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D].
|
872 |
+
"""
|
873 |
+
sample_rate = x.sample_rate[idx]
|
874 |
+
seek_time = x.seek_time[idx]
|
875 |
+
seek_time = 0. if seek_time is None else seek_time
|
876 |
+
clap_stride = int(self.clap_stride / self.clap_sample_rate) * sample_rate
|
877 |
+
end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate
|
878 |
+
start_offset = int(seek_time * sample_rate // clap_stride)
|
879 |
+
end_offset = int(end_seek_time * sample_rate // clap_stride)
|
880 |
+
wav_embed = full_embed[start_offset:end_offset, ...]
|
881 |
+
wav_embed = wav_embed.mean(dim=0, keepdim=True)
|
882 |
+
return wav_embed.to(self.device) # [F, D]
|
883 |
+
|
884 |
+
def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
|
885 |
+
"""Get CLAP embedding from a batch of text descriptions."""
|
886 |
+
no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout
|
887 |
+
if self.text_cache is not None and no_nullified_cond:
|
888 |
+
assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided"
|
889 |
+
paths = [Path(p) for p in x.path if p is not None]
|
890 |
+
embed = self.text_cache.get_embed_from_cache(paths, x)
|
891 |
+
else:
|
892 |
+
text = [xi if xi is not None else "" for xi in x.text]
|
893 |
+
embed = self._compute_text_embedding(text)
|
894 |
+
if self.normalize:
|
895 |
+
embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
|
896 |
+
return embed
|
897 |
+
|
898 |
+
def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
|
899 |
+
"""Get CLAP embedding from a batch of audio tensors (and corresponding sample rates)."""
|
900 |
+
no_undefined_paths = all(p is not None for p in x.path)
|
901 |
+
no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout
|
902 |
+
if self.wav_cache is not None and no_undefined_paths and no_nullified_cond:
|
903 |
+
paths = [Path(p) for p in x.path if p is not None]
|
904 |
+
embed = self.wav_cache.get_embed_from_cache(paths, x)
|
905 |
+
else:
|
906 |
+
embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True)
|
907 |
+
if self.normalize:
|
908 |
+
embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
|
909 |
+
return embed
|
910 |
+
|
911 |
+
def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
|
912 |
+
# Trying to limit as much as possible sync points when the cache is warm.
|
913 |
+
no_undefined_paths = all(p is not None for p in x.path)
|
914 |
+
if self.wav_cache is not None and no_undefined_paths:
|
915 |
+
assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
|
916 |
+
paths = [Path(p) for p in x.path if p is not None]
|
917 |
+
self.wav_cache.populate_embed_cache(paths, x)
|
918 |
+
if self.text_cache is not None and no_undefined_paths:
|
919 |
+
assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
|
920 |
+
paths = [Path(p) for p in x.path if p is not None]
|
921 |
+
self.text_cache.populate_embed_cache(paths, x)
|
922 |
+
return x
|
923 |
+
|
924 |
+
def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
925 |
+
"""Extract shared latent representation from either the wav or the text using CLAP."""
|
926 |
+
# decide whether to use text embedding at train time or not
|
927 |
+
use_text_embed = random.random() < self.text_p
|
928 |
+
if self.training and not use_text_embed:
|
929 |
+
embed = self._get_wav_embedding(x)
|
930 |
+
empty_idx = torch.LongTensor([]) # we assume we always have the audio wav
|
931 |
+
else:
|
932 |
+
embed = self._get_text_embedding(x)
|
933 |
+
empty_idx = torch.LongTensor([i for i, xi in enumerate(x.text) if xi is None or xi == ""])
|
934 |
+
return embed, empty_idx
|
935 |
+
|
936 |
+
|
937 |
+
def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes:
|
938 |
+
"""Utility function for nullifying an attribute inside an ConditioningAttributes object.
|
939 |
+
If the condition is of type "wav", then nullify it using `nullify_condition` function.
|
940 |
+
If the condition is of any other type, set its value to None.
|
941 |
+
Works in-place.
|
942 |
+
"""
|
943 |
+
if condition_type not in ['text', 'wav', 'joint_embed']:
|
944 |
+
raise ValueError(
|
945 |
+
"dropout_condition got an unexpected condition type!"
|
946 |
+
f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'"
|
947 |
+
)
|
948 |
+
|
949 |
+
if condition not in getattr(sample, condition_type):
|
950 |
+
raise ValueError(
|
951 |
+
"dropout_condition received an unexpected condition!"
|
952 |
+
f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
|
953 |
+
f" but got '{condition}' of type '{condition_type}'!"
|
954 |
+
)
|
955 |
+
|
956 |
+
if condition_type == 'wav':
|
957 |
+
wav_cond = sample.wav[condition]
|
958 |
+
sample.wav[condition] = nullify_wav(wav_cond)
|
959 |
+
elif condition_type == 'joint_embed':
|
960 |
+
embed = sample.joint_embed[condition]
|
961 |
+
sample.joint_embed[condition] = nullify_joint_embed(embed)
|
962 |
+
else:
|
963 |
+
sample.text[condition] = None
|
964 |
+
|
965 |
+
return sample
|
966 |
+
|
967 |
+
|
968 |
+
class DropoutModule(nn.Module):
|
969 |
+
"""Base module for all dropout modules."""
|
970 |
+
def __init__(self, seed: int = 1234):
|
971 |
+
super().__init__()
|
972 |
+
self.rng = torch.Generator()
|
973 |
+
self.rng.manual_seed(seed)
|
974 |
+
|
975 |
+
|
976 |
+
class AttributeDropout(DropoutModule):
|
977 |
+
"""Dropout with a given probability per attribute.
|
978 |
+
This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes
|
979 |
+
to be dropped out separately. For example, "artist" can be dropped while "genre" remains.
|
980 |
+
This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre"
|
981 |
+
must also be dropped.
|
982 |
+
|
983 |
+
Args:
|
984 |
+
p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example:
|
985 |
+
...
|
986 |
+
"genre": 0.1,
|
987 |
+
"artist": 0.5,
|
988 |
+
"wav": 0.25,
|
989 |
+
...
|
990 |
+
active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
|
991 |
+
seed (int, optional): Random seed.
|
992 |
+
"""
|
993 |
+
def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
|
994 |
+
super().__init__(seed=seed)
|
995 |
+
self.active_on_eval = active_on_eval
|
996 |
+
# construct dict that return the values from p otherwise 0
|
997 |
+
self.p = {}
|
998 |
+
for condition_type, probs in p.items():
|
999 |
+
self.p[condition_type] = defaultdict(lambda: 0, probs)
|
1000 |
+
|
1001 |
+
def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
|
1002 |
+
"""
|
1003 |
+
Args:
|
1004 |
+
samples (list[ConditioningAttributes]): List of conditions.
|
1005 |
+
Returns:
|
1006 |
+
list[ConditioningAttributes]: List of conditions after certain attributes were set to None.
|
1007 |
+
"""
|
1008 |
+
if not self.training and not self.active_on_eval:
|
1009 |
+
return samples
|
1010 |
+
|
1011 |
+
samples = deepcopy(samples)
|
1012 |
+
for condition_type, ps in self.p.items(): # for condition types [text, wav]
|
1013 |
+
for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre])
|
1014 |
+
if torch.rand(1, generator=self.rng).item() < p:
|
1015 |
+
for sample in samples:
|
1016 |
+
dropout_condition(sample, condition_type, condition)
|
1017 |
+
return samples
|
1018 |
+
|
1019 |
+
def __repr__(self):
|
1020 |
+
return f"AttributeDropout({dict(self.p)})"
|
1021 |
+
|
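As an illustrative aside (not part of conditioners.py): `AttributeDropout` can drop one attribute (here "artist") while keeping the others, in contrast to classifier-free guidance dropout which drops everything together. A minimal sketch, assuming the `audiocraft` package from this repo is installed:

```python
from audiocraft.modules.conditioners import AttributeDropout, ConditioningAttributes

dropout = AttributeDropout(p={"text": {"artist": 1.0}})  # always drop "artist", for demonstration
dropout.train()  # the module is a no-op in eval mode unless active_on_eval=True
sample = ConditioningAttributes(text={"artist": "Queen", "genre": "Rock"})
(out,) = dropout([sample])
print(out.text)  # {'artist': None, 'genre': 'Rock'}
```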
1022 |
+
import torch
from torch.nn import Module
from copy import deepcopy


class ClassifierFreeGuidanceDropout(Module):
    """Classifier Free Guidance dropout for tensor inputs.
    All elements in the tensor are dropped with the same probability.

    Args:
        p (float): Probability to apply dropout during training.
        seed (int): Random seed.
    """
    def __init__(self, p: float, seed: int = 1234):
        super().__init__()
        self.p = p
        torch.manual_seed(seed)  # Set the seed for reproducibility

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Args:
            tensor (torch.Tensor): Input tensor.
        Returns:
            torch.Tensor: Tensor after applying dropout.
        """
        if not self.training or self.p <= 0:
            return tensor

        # Create a dropout mask with the same size as the input tensor
        mask = torch.rand(tensor.shape) > self.p
        mask = mask.to(tensor.device)  # Move mask to the same device as the input tensor

        # Apply dropout mask
        dropped_tensor = tensor * mask.float()

        return dropped_tensor

    def __repr__(self):
        return f"ClassifierFreeGuidanceDropout(p={self.p})"

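Unlike `AttributeDropout`, this variant operates directly on a conditioning tensor and zeroes individual elements. A minimal sketch (shapes and the import path are illustrative assumptions):

```python
import torch
from audiocraft.modules.conditioners import ClassifierFreeGuidanceDropout

cfg_dropout = ClassifierFreeGuidanceDropout(p=0.1)
cfg_dropout.train()                  # a no-op in eval mode or when p <= 0
video_emb = torch.randn(2, 30, 768)  # e.g. a [B, T, D] conditioning tensor
dropped = cfg_dropout(video_emb)     # each element is zeroed with probability ~0.1
```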
class ConditioningProvider(nn.Module):
    """Prepare and provide conditions given all the supported conditioners.

    Args:
        conditioners (dict): Dictionary of conditioners.
        device (torch.device or str, optional): Device for conditioners and output condition types.
    """
    def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"):
        super().__init__()
        self.device = device
        self.conditioners = nn.ModuleDict(conditioners)

    @property
    def joint_embed_conditions(self):
        return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)]

    @property
    def has_joint_embed_conditions(self):
        return len(self.joint_embed_conditions) > 0

    @property
    def text_conditions(self):
        return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]

    @property
    def wav_conditions(self):
        return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]

    @property
    def has_wav_condition(self):
        return len(self.wav_conditions) > 0

    def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
        """Match attributes/wavs with existing conditioners in self, and tokenize them accordingly.
        This should be called before starting any real GPU work to avoid synchronization points.
        This will return a dict matching conditioner names to their arbitrary tokenized representations.

        Args:
            inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
                text and wav conditions.
        """
        assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
            "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
            f" but types were {set([type(x) for x in inputs])}"
        )

        output = {}
        text = self._collate_text(inputs)
        wavs = self._collate_wavs(inputs)
        joint_embeds = self._collate_joint_embeds(inputs)

        assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
            f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
            f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
        )

        for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()):
            output[attribute] = self.conditioners[attribute].tokenize(batch)
        return output

    def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
        """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
        The output is for example:
        {
            "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
            "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
            ...
        }

        Args:
            tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
        """
        output = {}
        for attribute, inputs in tokenized.items():
            condition, mask = self.conditioners[attribute](inputs)
            output[attribute] = (condition, mask)
        return output

    def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
        """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
        are the attributes and the values are the aggregated input per attribute.
        For example:
        Input:
        [
            ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
            ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
        ]
        Output:
        {
            "genre": ["Rock", "Hip-hop"],
            "description": ["A rock song with a guitar solo", "A hip-hop verse"]
        }

        Args:
            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
        Returns:
            dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
        """
        out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
        texts = [x.text for x in samples]
        for text in texts:
            for condition in self.text_conditions:
                out[condition].append(text[condition])
        return out

    def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]:
        """Generate a dict where the keys are attributes by which we fetch similar wavs,
        and the values are Tensors of wavs according to said attributes.

        *Note*: by the time the samples reach this function, each sample should have some waveform
        inside the "wav" attribute. It should be either:
        1. A real waveform
        2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
        3. A null waveform due to it being dropped in a dropout module (nullified by dropout)

        Args:
            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
        Returns:
            dict[str, WavCondition]: A dictionary mapping an attribute name to wavs.
        """
        wavs = defaultdict(list)
        lengths = defaultdict(list)
        sample_rates = defaultdict(list)
        paths = defaultdict(list)
        seek_times = defaultdict(list)
        out: tp.Dict[str, WavCondition] = {}

        for sample in samples:
            for attribute in self.wav_conditions:
                wav, length, sample_rate, path, seek_time = sample.wav[attribute]
                assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]"
                assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1"
                # mono-channel conditioning
                wav = wav.mean(1, keepdim=True)  # [1, 1, T]
                wavs[attribute].append(wav.flatten())  # [T]
                lengths[attribute].append(length)
                sample_rates[attribute].extend(sample_rate)
                paths[attribute].extend(path)
                seek_times[attribute].extend(seek_time)

        # stack all wavs to a single tensor
        for attribute in self.wav_conditions:
            stacked_wav, _ = collate(wavs[attribute], dim=0)
            out[attribute] = WavCondition(
                stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute],
                paths[attribute], seek_times[attribute])

        return out

    def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]:
        """Generate a dict where the keys are attributes by which we compute joint embeddings,
        and the values are Tensors of pre-computed embeddings and the corresponding text attributes.

        Args:
            samples (list[ConditioningAttributes]): List of ConditioningAttributes samples.
        Returns:
            A dictionary mapping an attribute name to joint embeddings.
        """
        texts = defaultdict(list)
        wavs = defaultdict(list)
        lengths = defaultdict(list)
        sample_rates = defaultdict(list)
        paths = defaultdict(list)
        seek_times = defaultdict(list)
        channels: int = 0

        out = {}
        for sample in samples:
            for attribute in self.joint_embed_conditions:
                wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute]
                assert wav.dim() == 3
                if channels == 0:
                    channels = wav.size(1)
                else:
                    assert channels == wav.size(1), "not all audio has same number of channels in batch"
                assert wav.size(0) == 1, "Expecting single-wav batch in the collate method"
                wav = einops.rearrange(wav, "b c t -> (b c t)")  # [1, C, T] => [C * T]
                wavs[attribute].append(wav)
                texts[attribute].extend(text)
                lengths[attribute].append(length)
                sample_rates[attribute].extend(sample_rate)
                paths[attribute].extend(path)
                seek_times[attribute].extend(seek_time)

        for attribute in self.joint_embed_conditions:
            stacked_texts = texts[attribute]
            stacked_paths = paths[attribute]
            stacked_seek_times = seek_times[attribute]
            stacked_wavs = pad_sequence(wavs[attribute]).to(self.device)
            stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels)
            stacked_sample_rates = sample_rates[attribute]
            stacked_lengths = torch.cat(lengths[attribute]).to(self.device)
            assert stacked_lengths.size(0) == stacked_wavs.size(0)
            assert len(stacked_sample_rates) == stacked_wavs.size(0)
            assert len(stacked_texts) == stacked_wavs.size(0)
            out[attribute] = JointEmbedCondition(
                text=stacked_texts, wav=stacked_wavs,
                length=stacked_lengths, sample_rate=stacked_sample_rates,
                path=stacked_paths, seek_time=stacked_seek_times)

        return out

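A sketch of the intended two-step flow, tokenization on CPU followed by embedding computation. It assumes the `T5Conditioner` defined earlier in this file with a `(name, output_dim, finetune, device)` style constructor, so treat the exact arguments and attribute name as illustrative:

```python
from audiocraft.modules.conditioners import (
    ConditioningAttributes, ConditioningProvider, T5Conditioner)

# One conditioner per attribute it serves; "description" is an illustrative key.
provider = ConditioningProvider(
    {"description": T5Conditioner("t5-base", output_dim=512, finetune=False, device="cpu")})

attrs = [ConditioningAttributes(text={"description": "upbeat electronic track"})]
tokenized = provider.tokenize(attrs)  # CPU-side tokenization, no GPU sync points
conditions = provider(tokenized)      # {"description": (embedding [B, T, D], mask [B, T])}
emb, mask = conditions["description"]
```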
class ConditionFuser(StreamingModule):
    """Condition fuser handles the logic to combine the different conditions
    to the actual model input.

    Args:
        fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse
            each condition. For example:
            {
                "prepend": ["description"],
                "sum": ["genre", "bpm"],
                "cross": ["description"],
            }
        cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
        cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used.
    """
    FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]

    def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
                 cross_attention_pos_emb_scale: float = 1.0):
        super().__init__()
        assert all(
            [k in self.FUSING_METHODS for k in fuse2cond.keys()]
        ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
        self.cross_attention_pos_emb = cross_attention_pos_emb
        self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
        self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
        self.cond2fuse: tp.Dict[str, str] = {}
        for fuse_method, conditions in fuse2cond.items():
            for condition in conditions:
                self.cond2fuse[condition] = fuse_method

    def forward(
        self,
        input: torch.Tensor,
        conditions: tp.Dict[str, ConditionType]
    ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Fuse the conditions to the provided model input.

        Args:
            input (torch.Tensor): Transformer input.
            conditions (dict[str, ConditionType]): Dict of conditions.
        Returns:
            tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input
                after the conditions have been fused. The second output tensor is the tensor
                used for cross-attention or None if no cross attention inputs exist.
        """
        B, T, _ = input.shape

        if 'offsets' in self._streaming_state:
            first_step = False
            offsets = self._streaming_state['offsets']
        else:
            first_step = True
            offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)

        assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
            f"given conditions contain unknown attributes for fuser, " \
            f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
        cross_attention_output = None
        # print(f'-------------conditioners.py line1386----------condition_tensors:{conditions}')
        # print(f'-------------conditioners.py line1387----------condition_tensors.items:{conditions.items}')
        # exit()
        for cond_type, (cond, cond_mask) in conditions.items():
            op = self.cond2fuse[cond_type]
            if op == 'sum':
                input += cond
            elif op == 'input_interpolate':
                cond = einops.rearrange(cond, "b t d -> b d t")
                cond = F.interpolate(cond, size=input.shape[1])
                input += einops.rearrange(cond, "b d t -> b t d")
            elif op == 'prepend':
                if first_step:
                    input = torch.cat([cond, input], dim=1)
            elif op == 'cross':
                if cross_attention_output is not None:
                    cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
                else:
                    cross_attention_output = cond
            else:
                raise ValueError(f"unknown op ({op})")

        if self.cross_attention_pos_emb and cross_attention_output is not None:
            positions = torch.arange(
                cross_attention_output.shape[1],
                device=cross_attention_output.device
            ).view(1, -1, 1)
            pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
            cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb

        if self._is_streaming:
            self._streaming_state['offsets'] = offsets + T

        return input, cross_attention_output
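A minimal sketch of how the fuser routes conditions into the transformer input (the attribute name, shapes, and dimensions are illustrative):

```python
import torch
from audiocraft.modules.conditioners import ConditionFuser

fuser = ConditionFuser(fuse2cond={"cross": ["description"]})

B, T, D = 2, 50, 512
transformer_input = torch.randn(B, T, D)
conditions = {"description": (torch.randn(B, 8, D), torch.ones(B, 8))}

fused_input, cross_input = fuser(transformer_input, conditions)
# fused_input is unchanged here (no "sum"/"prepend" conditions); cross_input is [B, 8, D]
# and would be fed to the transformer's cross-attention layers.
```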
audiocraft/modules/conv.py
ADDED
@@ -0,0 +1,243 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import typing as tp
import warnings

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm


CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
                                 'time_group_norm'])


def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
    assert norm in CONV_NORMALIZATIONS
    if norm == 'weight_norm':
        return weight_norm(module)
    elif norm == 'spectral_norm':
        return spectral_norm(module)
    else:
        # We already checked that norm is in CONV_NORMALIZATIONS, so any other choice
        # doesn't need reparametrization.
        return module


def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
    """Return the proper normalization module. If causal is True, this will ensure the returned
    module is causal, or return an error if the normalization doesn't support causal evaluation.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'time_group_norm':
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    else:
        return nn.Identity()


def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """See `pad_for_conv1d`."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length


def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
    """Pad for a convolution to make sure that the last window is full.
    Extra padding is added at the end. This is required to ensure that we can rebuild
    an output of the same length, as otherwise, even with padding, some time steps
    might get removed.
    For instance, with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
            1 2 3 4         # once you removed padding, we are missing one time step!
    """
    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, extra_padding))


def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small inputs.
    If this is the case, we insert extra 0 padding to the right before the reflection happens.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == 'reflect':
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)


def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, handling properly zero padding. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left: end]

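A short worked example of the padding arithmetic above, matching the numbers in the `pad_for_conv1d` docstring (the import path is an assumption):

```python
import torch
from audiocraft.modules.conv import get_extra_padding_for_conv1d, pad_for_conv1d

# kernel_size=4, stride=2 => padding_total = 4 - 2 = 2.
x = torch.randn(1, 1, 5)  # length 5, as in the docstring example
extra = get_extra_padding_for_conv1d(x, kernel_size=4, stride=2, padding_total=2)
# n_frames = (5 - 4 + 2) / 2 + 1 = 2.5 -> ceil = 3 frames
# ideal_length = (3 - 1) * 2 + (4 - 2) = 6, so one extra sample is needed
assert extra == 1
assert pad_for_conv1d(x, 4, 2, padding_total=2).shape[-1] == 6
```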
class NormConv1d(nn.Module):
    """Wrapper around Conv1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConv2d(nn.Module):
    """Wrapper around Conv2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConvTranspose1d(nn.Module):
    """Wrapper around ConvTranspose1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class NormConvTranspose2d(nn.Module):
    """Wrapper around ConvTranspose2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class StreamableConv1d(nn.Module):
    """Conv1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
                 pad_mode: str = 'reflect'):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn("StreamableConv1d has been initialized with stride > 1 and dilation > 1"
                          f" (kernel_size={kernel_size} stride={stride}, dilation={dilation}).")
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        B, C, T = x.shape
        kernel_size = self.conv.conv.kernel_size[0]
        stride = self.conv.conv.stride[0]
        dilation = self.conv.conv.dilation[0]
        kernel_size = (kernel_size - 1) * dilation + 1  # effective kernel size with dilations
        padding_total = kernel_size - stride
        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
        if self.causal:
            # Left padding for causal
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
        return self.conv(x)


class StreamableConvTranspose1d(nn.Module):
    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Dict[str, tp.Any] = {}):
        super().__init__()
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.

    def forward(self, x):
        kernel_size = self.convtr.convtr.kernel_size[0]
        stride = self.convtr.convtr.stride[0]
        padding_total = kernel_size - stride

        y = self.convtr(x)

        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
        # removed at the very end, when keeping only the right length for the output,
        # as removing it here would require also passing the length at the matching layer
        # in the encoder.
        if self.causal:
            # Trim the padding on the right according to the specified ratio
            # if trim_right_ratio = 1.0, trim everything from right
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        return y
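A shape-level sketch of the streamable convolutions used as a causal downsample/upsample pair (channel counts and lengths are illustrative):

```python
import torch
from audiocraft.modules.conv import StreamableConv1d, StreamableConvTranspose1d

down = StreamableConv1d(1, 16, kernel_size=4, stride=2, causal=True)
up = StreamableConvTranspose1d(16, 1, kernel_size=4, stride=2, causal=True, trim_right_ratio=1.0)

x = torch.randn(3, 1, 24000)  # [B, C, T], e.g. 1 second of 24 kHz audio
y = down(x)                   # [3, 16, 12000]: causal left padding keeps T / stride frames
x_hat = up(y)                 # [3, 1, 24000]: the transposed conv undoes the downsampling
```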
audiocraft/modules/diffusion_schedule.py
ADDED
@@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Functions for Noise Schedule, defines diffusion process, reverse process and data processor.
"""

from collections import namedtuple
import random
import typing as tp
import julius
import torch

TrainingItem = namedtuple("TrainingItem", "noisy noise step")


def betas_from_alpha_bar(alpha_bar):
    alphas = torch.cat([torch.Tensor([alpha_bar[0]]), alpha_bar[1:] / alpha_bar[:-1]])
    return 1 - alphas


class SampleProcessor(torch.nn.Module):
    def project_sample(self, x: torch.Tensor):
        """Project the original sample to the 'space' where the diffusion will happen."""
        return x

    def return_sample(self, z: torch.Tensor):
        """Project back from diffusion space to the actual sample space."""
        return z


class MultiBandProcessor(SampleProcessor):
    """
    MultiBand sample processor. The input audio is split across
    frequency bands evenly distributed in mel-scale.

    Each band will be rescaled to match the power distribution
    of Gaussian noise in that band, using online metrics
    computed on the first few samples.

    Args:
        n_bands (int): Number of mel-bands to split the signal over.
        sample_rate (int): Sample rate of the audio.
        num_samples (int): Number of samples to use to fit the rescaling
            for each band. The processor won't be stable
            until it has seen that many samples.
        power_std (float or list/tensor): The rescaling factor computed to match the
            power of Gaussian noise in each band is taken to
            that power, i.e. `1.` means full correction of the energy
            in each band, and values less than `1` means only partial
            correction. Can be used to balance the relative importance
            of low vs. high freq in typical audio signals.
    """
    def __init__(self, n_bands: int = 8, sample_rate: float = 24_000,
                 num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.):
        super().__init__()
        self.n_bands = n_bands
        self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands)
        self.num_samples = num_samples
        self.power_std = power_std
        if isinstance(power_std, list):
            assert len(power_std) == n_bands
            power_std = torch.tensor(power_std)
        self.register_buffer('counts', torch.zeros(1))
        self.register_buffer('sum_x', torch.zeros(n_bands))
        self.register_buffer('sum_x2', torch.zeros(n_bands))
        self.register_buffer('sum_target_x2', torch.zeros(n_bands))
        self.counts: torch.Tensor
        self.sum_x: torch.Tensor
        self.sum_x2: torch.Tensor
        self.sum_target_x2: torch.Tensor

    @property
    def mean(self):
        mean = self.sum_x / self.counts
        return mean

    @property
    def std(self):
        std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
        return std

    @property
    def target_std(self):
        target_std = self.sum_target_x2 / self.counts
        return target_std

    def project_sample(self, x: torch.Tensor):
        assert x.dim() == 3
        bands = self.split_bands(x)
        if self.counts.item() < self.num_samples:
            ref_bands = self.split_bands(torch.randn_like(x))
            self.counts += len(x)
            self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1)
            self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
            self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
        rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std  # same output size
        bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1)
        return bands.sum(dim=0)

    def return_sample(self, x: torch.Tensor):
        assert x.dim() == 3
        bands = self.split_bands(x)
        rescale = (self.std / self.target_std) ** self.power_std
        bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1)
        return bands.sum(dim=0)

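A quick sanity check of the `betas_from_alpha_bar` helper defined at the top of this file: since alpha_bar_t is the cumulative product of (1 - beta_s), the helper should recover the original schedule up to floating-point error (the import path is an assumption):

```python
import torch
from audiocraft.modules.diffusion_schedule import betas_from_alpha_bar

betas = torch.linspace(1e-4, 0.02, 10)
alpha_bar = (1 - betas).cumprod(dim=0)
recovered = betas_from_alpha_bar(alpha_bar)
assert torch.allclose(recovered, betas, atol=1e-6)
```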
class NoiseSchedule:
    """Noise schedule for diffusion.

    Args:
        beta_t0 (float): Variance of the first diffusion step.
        beta_t1 (float): Variance of the last diffusion step.
        beta_exp (float): Power schedule exponent.
        num_steps (int): Number of diffusion steps.
        variance (str): Choice of the sigma value for the denoising eq. Choices: "beta" or "beta_tilde".
        clip (float): Clipping value for the denoising steps.
        rescale (float): Rescaling value to avoid vanishing signals, unused by default (i.e. 1).
        repartition (str): Shape of the schedule; only the power schedule is supported.
        sample_processor (SampleProcessor): Module that normalizes data to better match the Gaussian distribution.
        noise_scale (float): Scaling factor for the noise.
    """
    def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta',
                 clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1,
                 repartition: str = "power", alpha_sigmoid: dict = {}, n_bands: tp.Optional[int] = None,
                 sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs):

        self.beta_t0 = beta_t0
        self.beta_t1 = beta_t1
        self.variance = variance
        self.num_steps = num_steps
        self.clip = clip
        self.sample_processor = sample_processor
        self.rescale = rescale
        self.n_bands = n_bands
        self.noise_scale = noise_scale
        assert n_bands is None
        if repartition == "power":
            self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps,
                                        device=device, dtype=torch.float) ** beta_exp
        else:
            raise RuntimeError('Not implemented')
        self.rng = random.Random(1234)

    def get_beta(self, step: tp.Union[int, torch.Tensor]):
        if self.n_bands is None:
            return self.betas[step]
        else:
            return self.betas[:, step]  # [n_bands, len(step)]

    def get_initial_noise(self, x: torch.Tensor):
        if self.n_bands is None:
            return torch.randn_like(x)
        return torch.randn((x.size(0), self.n_bands, x.size(2)))

    def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor:
        """Return 'alpha_bar', either for a given step, or as a tensor with its value for each step."""
        if step is None:
            return (1 - self.betas).cumprod(dim=-1)  # works for single and multi bands
        if type(step) is int:
            return (1 - self.betas[:step + 1]).prod()
        else:
            return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1)

    def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem:
        """Create a noisy data item for diffusion model training.

        Args:
            x (torch.Tensor): Clean audio data torch.tensor(bs, 1, T).
            tensor_step (bool): If tensor_step = False, only one step t is sampled,
                the whole batch is diffused to the same step and t is an int.
                If tensor_step = True, t is a tensor of size (x.size(0),) and every
                element of the batch is diffused to an independently sampled step.
        """
        step: tp.Union[int, torch.Tensor]
        if tensor_step:
            bs = x.size(0)
            step = torch.randint(0, self.num_steps, size=(bs,), device=x.device)
        else:
            step = self.rng.randrange(self.num_steps)
        alpha_bar = self.get_alpha_bar(step)  # [batch_size, n_bands, 1]

        x = self.sample_processor.project_sample(x)
        noise = torch.randn_like(x)
        noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale
        return TrainingItem(noisy, noise, step)

    def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None,
                 condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
        """Full DDPM reverse process.

        Args:
            model (nn.Module): Diffusion model.
            initial (tensor): Initial noise.
            condition (tensor): Input conditioning tensor (e.g. EnCodec compressed representation).
            return_list (bool): Whether to return the whole process or only the sampled point.
        """
        alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
        current = initial
        iterates = [initial]
        for step in range(self.num_steps)[::-1]:
            with torch.no_grad():
                estimate = model(current, step, condition=condition).sample
            alpha = 1 - self.betas[step]
            previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
            previous_alpha_bar = self.get_alpha_bar(step=step - 1)
            if step == 0:
                sigma2 = 0
            elif self.variance == 'beta':
                sigma2 = 1 - alpha
            elif self.variance == 'beta_tilde':
                sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
            elif self.variance == 'none':
                sigma2 = 0
            else:
                raise ValueError(f'Invalid variance type {self.variance}')

            if sigma2 > 0:
                previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
            if self.clip:
                previous = previous.clamp(-self.clip, self.clip)
            current = previous
            alpha_bar = previous_alpha_bar
            if step == 0:
                previous *= self.rescale
            if return_list:
                iterates.append(previous.cpu())

        if return_list:
            return iterates
        else:
            return self.sample_processor.return_sample(previous)

    def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None,
                            condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
        """Reverse process that only goes through Markov chain states in step_list."""
        if step_list is None:
            step_list = list(range(1000))[::-50] + [0]
        alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
        alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu()
        betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled)
        current = initial * self.noise_scale
        iterates = [current]
        for idx, step in enumerate(step_list[:-1]):
            with torch.no_grad():
                estimate = model(current, step, condition=condition).sample * self.noise_scale
            alpha = 1 - betas_subsampled[-1 - idx]
            previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
            previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1])
            if step == step_list[-2]:
                sigma2 = 0
                previous_alpha_bar = torch.tensor(1.0)
            else:
                sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
            if sigma2 > 0:
                previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
            if self.clip:
                previous = previous.clamp(-self.clip, self.clip)
            current = previous
            alpha_bar = previous_alpha_bar
            if step == 0:
                previous *= self.rescale
            if return_list:
                iterates.append(previous.cpu())
        if return_list:
            return iterates
        else:
            return self.sample_processor.return_sample(previous)
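A minimal sketch of drawing a training item from the schedule; it applies noisy = sqrt(alpha_bar) * x / rescale + sqrt(1 - alpha_bar) * noise * noise_scale, with a per-example step when tensor_step=True (batch shapes and the import path are illustrative):

```python
import torch
from audiocraft.modules.diffusion_schedule import NoiseSchedule

schedule = NoiseSchedule(beta_t0=1e-4, beta_t1=0.02, num_steps=1000, device='cpu')
x = torch.randn(4, 1, 16000)  # clean audio batch [B, 1, T]
item = schedule.get_training_item(x, tensor_step=True)
print(item.noisy.shape, item.step.shape)  # torch.Size([4, 1, 16000]) torch.Size([4])
```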
audiocraft/modules/lstm.py
ADDED
@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from torch import nn


class StreamableLSTM(nn.Module):
    """LSTM without worrying about the hidden state, nor the layout of the data.
    Expects input as convolutional layout.
    """
    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)

    def forward(self, x):
        x = x.permute(2, 0, 1)
        y, _ = self.lstm(x)
        if self.skip:
            y = y + x
        y = y.permute(1, 2, 0)
        return y
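A shape-level usage sketch: the module accepts the convolutional [B, C, T] layout and handles the permutation to and from the [T, B, C] layout that `nn.LSTM` expects (the import path is an assumption):

```python
import torch
from audiocraft.modules.lstm import StreamableLSTM

lstm = StreamableLSTM(dimension=32, num_layers=2, skip=True)
x = torch.randn(4, 32, 100)  # convolutional layout [B, C, T]
y = lstm(x)                  # same layout and shape, with an optional residual skip
assert y.shape == x.shape
```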