PyTorch
Zeyue7 committed on
Commit 0f7df35 · 1 Parent(s): 309599c

VidMuse-cvpr

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +92 -1
  2. audiocraft/__init__.py +26 -0
  3. audiocraft/adversarial/__init__.py +22 -0
  4. audiocraft/adversarial/discriminators/__init__.py +10 -0
  5. audiocraft/adversarial/discriminators/base.py +34 -0
  6. audiocraft/adversarial/discriminators/mpd.py +106 -0
  7. audiocraft/adversarial/discriminators/msd.py +126 -0
  8. audiocraft/adversarial/discriminators/msstftd.py +134 -0
  9. audiocraft/adversarial/losses.py +228 -0
  10. audiocraft/data/__init__.py +10 -0
  11. audiocraft/data/audio.py +231 -0
  12. audiocraft/data/audio_dataset.py +694 -0
  13. audiocraft/data/audio_utils.py +176 -0
  14. audiocraft/data/info_audio_dataset.py +111 -0
  15. audiocraft/data/music_dataset.py +307 -0
  16. audiocraft/data/sound_dataset.py +330 -0
  17. audiocraft/data/video.py +83 -0
  18. audiocraft/data/zip.py +76 -0
  19. audiocraft/environment.py +176 -0
  20. audiocraft/losses/__init__.py +21 -0
  21. audiocraft/losses/balancer.py +136 -0
  22. audiocraft/losses/sisnr.py +97 -0
  23. audiocraft/losses/specloss.py +149 -0
  24. audiocraft/losses/stftloss.py +207 -0
  25. audiocraft/metrics/__init__.py +14 -0
  26. audiocraft/metrics/chroma_cosinesim.py +72 -0
  27. audiocraft/metrics/clap_consistency.py +84 -0
  28. audiocraft/metrics/fad.py +329 -0
  29. audiocraft/metrics/kld.py +220 -0
  30. audiocraft/metrics/rvm.py +110 -0
  31. audiocraft/metrics/visqol.py +216 -0
  32. audiocraft/models/__init__.py +18 -0
  33. audiocraft/models/audiogen.py +267 -0
  34. audiocraft/models/builders.py +268 -0
  35. audiocraft/models/encodec.py +580 -0
  36. audiocraft/models/lm.py +685 -0
  37. audiocraft/models/lm_back.py +698 -0
  38. audiocraft/models/loaders.py +149 -0
  39. audiocraft/models/multibanddiffusion.py +196 -0
  40. audiocraft/models/transformer_module.py +177 -0
  41. audiocraft/models/unet.py +214 -0
  42. audiocraft/models/vidmuse.py +425 -0
  43. audiocraft/modules/__init__.py +22 -0
  44. audiocraft/modules/activations.py +96 -0
  45. audiocraft/modules/chroma.py +66 -0
  46. audiocraft/modules/codebooks_patterns.py +544 -0
  47. audiocraft/modules/conditioners.py +1357 -0
  48. audiocraft/modules/conv.py +243 -0
  49. audiocraft/modules/diffusion_schedule.py +272 -0
  50. audiocraft/modules/lstm.py +25 -0
README.md CHANGED
@@ -1,3 +1,94 @@
  ---
- license: cc-by-nc-4.0
+ license: cc-by-4.0
  ---
+
+ # VidMuse
+
+ ## VidMuse: A Simple Video-to-Music Generation Framework with Long-Short-Term Modeling
+
+ [TL;DR]: VidMuse is a framework for generating high-fidelity music aligned with video content using Long-Short-Term modeling; it has been accepted to CVPR 2025.
+
+ ### Links
+ - **[Paper](https://arxiv.org/pdf/2406.04321)**: Explore the research behind VidMuse.
+ - **[Project](https://vidmuse.github.io/)**: Visit the official project page for more information and updates.
+ - **[Dataset](https://huggingface.co/datasets/HKUSTAudio/VidMuse-Dataset)**: Download the dataset used in the paper.
+
+ ## Clone the repository
+ ```bash
+ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/HKUSTAudio/VidMuse
+ cd VidMuse
+ ```
+
+ ## Usage
+
+ 1. First, install the [`VidMuse` library](https://github.com/ZeyueT/VidMuse):
+ ```bash
+ conda create -n VidMuse python=3.9
+ conda activate VidMuse
+ pip install git+https://github.com/ZeyueT/VidMuse.git
+ ```
+
+ 2. Install ffmpeg:
+ ```bash
+ sudo apt-get install ffmpeg
+ # Or if you are using Anaconda or Miniconda
+ conda install "ffmpeg<5" -c conda-forge
+ ```
+
+ 3. Run the following Python code:
+
+ ```py
+ from video_processor import VideoProcessor, merge_video_audio
+ from audiocraft.models import VidMuse
+ import scipy.io.wavfile
+
+ # Path to the video
+ video_path = 'sample.mp4'
+ # Initialize the video processor
+ processor = VideoProcessor()
+ # Process the video to obtain the local/global frame tensors and the video duration
+ local_video_tensor, global_video_tensor, duration = processor.process(video_path)
+
+ progress = True
+ USE_DIFFUSION = False
+
+ # Load the pre-trained VidMuse model
+ MODEL = VidMuse.get_pretrained('HKUSTAudio/VidMuse')
+ # Set generation parameters for the model based on the video duration
+ MODEL.set_generation_params(duration=duration)
+
+ try:
+     # Generate music conditioned on the local and global video tensors
+     outputs = MODEL.generate([local_video_tensor, global_video_tensor], progress=progress, return_tokens=USE_DIFFUSION)
+ except RuntimeError as e:
+     print(e)
+     raise
+
+ # Detach the outputs from the computation graph and convert to a CPU float tensor
+ outputs = outputs.detach().cpu().float()
+
+ sampling_rate = 32000
+ output_wav_path = "vidmuse_sample.wav"
+ # Write the generated audio to a WAV file
+ scipy.io.wavfile.write(output_wav_path, rate=sampling_rate, data=outputs[0, 0].numpy())
+
+ output_video_path = "vidmuse_sample.mp4"
+ # Merge the original video with the generated music
+ merge_video_audio(video_path, output_wav_path, output_video_path)
+ ```
+
+ ## Citation
+ If you find our work useful, please consider citing:
+
+ ```
+ @article{tian2024vidmuse,
+   title={Vidmuse: A simple video-to-music generation framework with long-short-term modeling},
+   author={Tian, Zeyue and Liu, Zhaoyang and Yuan, Ruibin and Pan, Jiahao and Liu, Qifeng and Tan, Xu and Chen, Qifeng and Xue, Wei and Guo, Yike},
+   journal={arXiv preprint arXiv:2406.04321},
+   year={2024}
+ }
+ ```
audiocraft/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ AudioCraft is a general framework for training audio generative models.
8
+ At the moment we provide the training code for:
9
+
10
+ - [MusicGen](https://arxiv.org/abs/2306.05284), a state-of-the-art
11
+ text-to-music and melody+text autoregressive generative model.
12
+ For the solver, see `audiocraft.solvers.musicgen.MusicGenSolver`, and for the model,
13
+ `audiocraft.models.musicgen.MusicGen`.
14
+ - [AudioGen](https://arxiv.org/abs/2209.15352), a state-of-the-art
15
+ text-to-general-audio generative model.
16
+ - [EnCodec](https://arxiv.org/abs/2210.13438), efficient and high fidelity
17
+ neural audio codec which provides an excellent tokenizer for autoregressive language models.
18
+ See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`.
19
+ - [MultiBandDiffusion](TODO), alternative diffusion-based decoder compatible with EnCodec that
20
+ improves the perceived quality and reduces the artifacts coming from adversarial decoders.
21
+ """
22
+
23
+ # flake8: noqa
24
+ from . import data, modules, models
25
+
26
+ __version__ = '1.2.0a1'
audiocraft/adversarial/__init__.py ADDED
@@ -0,0 +1,22 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Adversarial losses and discriminator architectures."""
7
+
8
+ # flake8: noqa
9
+ from .discriminators import (
10
+ MultiPeriodDiscriminator,
11
+ MultiScaleDiscriminator,
12
+ MultiScaleSTFTDiscriminator
13
+ )
14
+ from .losses import (
15
+ AdversarialLoss,
16
+ AdvLossType,
17
+ get_adv_criterion,
18
+ get_fake_criterion,
19
+ get_real_criterion,
20
+ FeatLossType,
21
+ FeatureMatchingLoss
22
+ )
audiocraft/adversarial/discriminators/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # flake8: noqa
8
+ from .mpd import MultiPeriodDiscriminator
9
+ from .msd import MultiScaleDiscriminator
10
+ from .msstftd import MultiScaleSTFTDiscriminator
audiocraft/adversarial/discriminators/base.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from abc import ABC, abstractmethod
8
+ import typing as tp
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+
14
+ FeatureMapType = tp.List[torch.Tensor]
15
+ LogitsType = torch.Tensor
16
+ MultiDiscriminatorOutputType = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
17
+
18
+
19
+ class MultiDiscriminator(ABC, nn.Module):
20
+ """Base implementation for discriminators composed of sub-discriminators acting at different scales.
21
+ """
22
+ def __init__(self):
23
+ super().__init__()
24
+
25
+ @abstractmethod
26
+ def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
27
+ ...
28
+
29
+ @property
30
+ @abstractmethod
31
+ def num_discriminators(self) -> int:
32
+ """Number of discriminators.
33
+ """
34
+ ...
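
As a hedged illustration (not part of this commit), the contract that `MultiDiscriminator` fixes — one logits tensor plus one list of feature maps per sub-discriminator — can be satisfied by any ensemble shaped like the sketch below. `ToyDiscriminator` and its layer sizes are invented for the example, and running it assumes the `audiocraft` package from this repository and its dependencies are installed.

```python
import torch
import torch.nn as nn

from audiocraft.adversarial.discriminators.base import (
    MultiDiscriminator, MultiDiscriminatorOutputType)


class ToyDiscriminator(nn.Module):
    """Minimal waveform discriminator, used only to illustrate the output format."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(1, 8, kernel_size=15, stride=4, padding=7)
        self.head = nn.Conv1d(8, 1, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor):
        feat = torch.relu(self.conv(x))
        return self.head(feat), [feat]  # logits, list of feature maps


class ToyMultiDiscriminator(MultiDiscriminator):
    """Hypothetical ensemble satisfying the MultiDiscriminator interface."""
    def __init__(self, n: int = 2):
        super().__init__()
        self.discriminators = nn.ModuleList([ToyDiscriminator() for _ in range(n)])

    @property
    def num_discriminators(self) -> int:
        return len(self.discriminators)

    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        logits, fmaps = [], []
        for disc in self.discriminators:
            logit, fmap = disc(x)
            logits.append(logit)
            fmaps.append(fmap)
        return logits, fmaps


wav = torch.randn(2, 1, 16000)                  # [batch, channels, time]
all_logits, all_fmaps = ToyMultiDiscriminator()(wav)
print(len(all_logits), len(all_fmaps))          # 2 2
```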
audiocraft/adversarial/discriminators/mpd.py ADDED
@@ -0,0 +1,106 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from ...modules import NormConv2d
14
+ from .base import MultiDiscriminator, MultiDiscriminatorOutputType
15
+
16
+
17
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
18
+ return int((kernel_size * dilation - dilation) / 2)
19
+
20
+
21
+ class PeriodDiscriminator(nn.Module):
22
+ """Period sub-discriminator.
23
+
24
+ Args:
25
+ period (int): Period between samples of audio.
26
+ in_channels (int): Number of input channels.
27
+ out_channels (int): Number of output channels.
28
+ n_layers (int): Number of convolutional layers.
29
+ kernel_sizes (list of int): Kernel sizes for convolutions.
30
+ stride (int): Stride for convolutions.
31
+ filters (int): Initial number of filters in convolutions.
32
+ filters_scale (int): Multiplier of number of filters as we increase depth.
33
+ max_filters (int): Maximum number of filters.
34
+ norm (str): Normalization method.
35
+ activation (str): Activation function.
36
+ activation_params (dict): Parameters to provide to the activation function.
37
+ """
38
+ def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1,
39
+ n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3,
40
+ filters: int = 8, filters_scale: int = 4, max_filters: int = 1024,
41
+ norm: str = 'weight_norm', activation: str = 'LeakyReLU',
42
+ activation_params: dict = {'negative_slope': 0.2}):
43
+ super().__init__()
44
+ self.period = period
45
+ self.n_layers = n_layers
46
+ self.activation = getattr(torch.nn, activation)(**activation_params)
47
+ self.convs = nn.ModuleList()
48
+ in_chs = in_channels
49
+ for i in range(self.n_layers):
50
+ out_chs = min(filters * (filters_scale ** (i + 1)), max_filters)
51
+ eff_stride = 1 if i == self.n_layers - 1 else stride
52
+ self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_sizes[0], 1), stride=(eff_stride, 1),
53
+ padding=((kernel_sizes[0] - 1) // 2, 0), norm=norm))
54
+ in_chs = out_chs
55
+ self.conv_post = NormConv2d(in_chs, out_channels, kernel_size=(kernel_sizes[1], 1), stride=1,
56
+ padding=((kernel_sizes[1] - 1) // 2, 0), norm=norm)
57
+
58
+ def forward(self, x: torch.Tensor):
59
+ fmap = []
60
+ # 1d to 2d
61
+ b, c, t = x.shape
62
+ if t % self.period != 0: # pad first
63
+ n_pad = self.period - (t % self.period)
64
+ x = F.pad(x, (0, n_pad), 'reflect')
65
+ t = t + n_pad
66
+ x = x.view(b, c, t // self.period, self.period)
67
+
68
+ for conv in self.convs:
69
+ x = conv(x)
70
+ x = self.activation(x)
71
+ fmap.append(x)
72
+ x = self.conv_post(x)
73
+ fmap.append(x)
74
+ # x = torch.flatten(x, 1, -1)
75
+
76
+ return x, fmap
77
+
78
+
79
+ class MultiPeriodDiscriminator(MultiDiscriminator):
80
+ """Multi-Period (MPD) Discriminator.
81
+
82
+ Args:
83
+ in_channels (int): Number of input channels.
84
+ out_channels (int): Number of output channels.
85
+ periods (Sequence[int]): Periods between samples of audio for the sub-discriminators.
86
+ **kwargs: Additional args for `PeriodDiscriminator`
87
+ """
88
+ def __init__(self, in_channels: int = 1, out_channels: int = 1,
89
+ periods: tp.Sequence[int] = [2, 3, 5, 7, 11], **kwargs):
90
+ super().__init__()
91
+ self.discriminators = nn.ModuleList([
92
+ PeriodDiscriminator(p, in_channels, out_channels, **kwargs) for p in periods
93
+ ])
94
+
95
+ @property
96
+ def num_discriminators(self):
97
+ return len(self.discriminators)
98
+
99
+ def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
100
+ logits = []
101
+ fmaps = []
102
+ for disc in self.discriminators:
103
+ logit, fmap = disc(x)
104
+ logits.append(logit)
105
+ fmaps.append(fmap)
106
+ return logits, fmaps
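
A hedged smoke test for the multi-period discriminator above (illustrative only; it assumes the `audiocraft` package from this repository and its dependencies are installed):

```python
import torch
from audiocraft.adversarial.discriminators import MultiPeriodDiscriminator

# One second of random mono audio at 16 kHz: [batch, channels, time].
wav = torch.randn(4, 1, 16000)

mpd = MultiPeriodDiscriminator()          # default periods: 2, 3, 5, 7, 11
logits, fmaps = mpd(wav)

print(len(logits))        # 5 sub-discriminators, one logit tensor each
print(logits[0].shape)    # [batch, 1, strided time // period, period]
print(len(fmaps[0]))      # one feature map per conv layer plus the final projection
```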
audiocraft/adversarial/discriminators/msd.py ADDED
@@ -0,0 +1,126 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from ...modules import NormConv1d
14
+ from .base import MultiDiscriminator, MultiDiscriminatorOutputType
15
+
16
+
17
+ class ScaleDiscriminator(nn.Module):
18
+ """Waveform sub-discriminator.
19
+
20
+ Args:
21
+ in_channels (int): Number of input channels.
22
+ out_channels (int): Number of output channels.
23
+ kernel_sizes (Sequence[int]): Kernel sizes for first and last convolutions.
24
+ filters (int): Number of initial filters for convolutions.
25
+ max_filters (int): Maximum number of filters.
26
+ downsample_scales (Sequence[int]): Scale for downsampling implemented as strided convolutions.
27
+ inner_kernel_sizes (Sequence[int] or None): Kernel sizes for inner convolutions.
28
+ groups (Sequence[int] or None): Groups for inner convolutions.
29
+ strides (Sequence[int] or None): Strides for inner convolutions.
30
+ paddings (Sequence[int] or None): Paddings for inner convolutions.
31
+ norm (str): Normalization method.
32
+ activation (str): Activation function.
33
+ activation_params (dict): Parameters to provide to the activation function.
34
+ pad (str): Padding for initial convolution.
35
+ pad_params (dict): Parameters to provide to the padding module.
36
+ """
37
+ def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Sequence[int] = [5, 3],
38
+ filters: int = 16, max_filters: int = 1024, downsample_scales: tp.Sequence[int] = [4, 4, 4, 4],
39
+ inner_kernel_sizes: tp.Optional[tp.Sequence[int]] = None, groups: tp.Optional[tp.Sequence[int]] = None,
40
+ strides: tp.Optional[tp.Sequence[int]] = None, paddings: tp.Optional[tp.Sequence[int]] = None,
41
+ norm: str = 'weight_norm', activation: str = 'LeakyReLU',
42
+ activation_params: dict = {'negative_slope': 0.2}, pad: str = 'ReflectionPad1d',
43
+ pad_params: dict = {}):
44
+ super().__init__()
45
+ assert len(kernel_sizes) == 2
46
+ assert kernel_sizes[0] % 2 == 1
47
+ assert kernel_sizes[1] % 2 == 1
48
+ assert (inner_kernel_sizes is None or len(inner_kernel_sizes) == len(downsample_scales))
49
+ assert (groups is None or len(groups) == len(downsample_scales))
50
+ assert (strides is None or len(strides) == len(downsample_scales))
51
+ assert (paddings is None or len(paddings) == len(downsample_scales))
52
+ self.activation = getattr(torch.nn, activation)(**activation_params)
53
+ self.convs = nn.ModuleList()
54
+ self.convs.append(
55
+ nn.Sequential(
56
+ getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params),
57
+ NormConv1d(in_channels, filters, kernel_size=np.prod(kernel_sizes), stride=1, norm=norm)
58
+ )
59
+ )
60
+
61
+ in_chs = filters
62
+ for i, downsample_scale in enumerate(downsample_scales):
63
+ out_chs = min(in_chs * downsample_scale, max_filters)
64
+ default_kernel_size = downsample_scale * 10 + 1
65
+ default_stride = downsample_scale
66
+ default_padding = (default_kernel_size - 1) // 2
67
+ default_groups = in_chs // 4
68
+ self.convs.append(
69
+ NormConv1d(in_chs, out_chs,
70
+ kernel_size=inner_kernel_sizes[i] if inner_kernel_sizes else default_kernel_size,
71
+ stride=strides[i] if strides else default_stride,
72
+ groups=groups[i] if groups else default_groups,
73
+ padding=paddings[i] if paddings else default_padding,
74
+ norm=norm))
75
+ in_chs = out_chs
76
+
77
+ out_chs = min(in_chs * 2, max_filters)
78
+ self.convs.append(NormConv1d(in_chs, out_chs, kernel_size=kernel_sizes[0], stride=1,
79
+ padding=(kernel_sizes[0] - 1) // 2, norm=norm))
80
+ self.conv_post = NormConv1d(out_chs, out_channels, kernel_size=kernel_sizes[1], stride=1,
81
+ padding=(kernel_sizes[1] - 1) // 2, norm=norm)
82
+
83
+ def forward(self, x: torch.Tensor):
84
+ fmap = []
85
+ for layer in self.convs:
86
+ x = layer(x)
87
+ x = self.activation(x)
88
+ fmap.append(x)
89
+ x = self.conv_post(x)
90
+ fmap.append(x)
91
+ # x = torch.flatten(x, 1, -1)
92
+ return x, fmap
93
+
94
+
95
+ class MultiScaleDiscriminator(MultiDiscriminator):
96
+ """Multi-Scale (MSD) Discriminator,
97
+
98
+ Args:
99
+ in_channels (int): Number of input channels.
100
+ out_channels (int): Number of output channels.
101
+ downsample_factor (int): Downsampling factor between the different scales.
102
+ scale_norms (Sequence[str]): Normalization for each sub-discriminator.
103
+ **kwargs: Additional args for ScaleDiscriminator.
104
+ """
105
+ def __init__(self, in_channels: int = 1, out_channels: int = 1, downsample_factor: int = 2,
106
+ scale_norms: tp.Sequence[str] = ['weight_norm', 'weight_norm', 'weight_norm'], **kwargs):
107
+ super().__init__()
108
+ self.discriminators = nn.ModuleList([
109
+ ScaleDiscriminator(in_channels, out_channels, norm=norm, **kwargs) for norm in scale_norms
110
+ ])
111
+ self.downsample = nn.AvgPool1d(downsample_factor * 2, downsample_factor, padding=downsample_factor)
112
+
113
+ @property
114
+ def num_discriminators(self):
115
+ return len(self.discriminators)
116
+
117
+ def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
118
+ logits = []
119
+ fmaps = []
120
+ for i, disc in enumerate(self.discriminators):
121
+ if i != 0:
122
+ x = self.downsample(x)
123
+ logit, fmap = disc(x)
124
+ logits.append(logit)
125
+ fmaps.append(fmap)
126
+ return logits, fmaps
audiocraft/adversarial/discriminators/msstftd.py ADDED
@@ -0,0 +1,134 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ import torchaudio
10
+ import torch
11
+ from torch import nn
12
+ from einops import rearrange
13
+
14
+ from ...modules import NormConv2d
15
+ from .base import MultiDiscriminator, MultiDiscriminatorOutputType
16
+
17
+
18
+ def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
19
+ return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)
20
+
21
+
22
+ class DiscriminatorSTFT(nn.Module):
23
+ """STFT sub-discriminator.
24
+
25
+ Args:
26
+ filters (int): Number of filters in convolutions.
27
+ in_channels (int): Number of input channels.
28
+ out_channels (int): Number of output channels.
29
+ n_fft (int): Size of FFT for each scale.
30
+ hop_length (int): Length of hop between STFT windows for each scale.
31
+ kernel_size (tuple of int): Inner Conv2d kernel sizes.
32
+ stride (tuple of int): Inner Conv2d strides.
33
+ dilations (list of int): Inner Conv2d dilation on the time dimension.
34
+ win_length (int): Window size for each scale.
35
+ normalized (bool): Whether to normalize by magnitude after stft.
36
+ norm (str): Normalization method.
37
+ activation (str): Activation function.
38
+ activation_params (dict): Parameters to provide to the activation function.
39
+ growth (int): Growth factor for the filters.
40
+ """
41
+ def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
42
+ n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
43
+ filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
44
+ stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
45
+ activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
46
+ super().__init__()
47
+ assert len(kernel_size) == 2
48
+ assert len(stride) == 2
49
+ self.filters = filters
50
+ self.in_channels = in_channels
51
+ self.out_channels = out_channels
52
+ self.n_fft = n_fft
53
+ self.hop_length = hop_length
54
+ self.win_length = win_length
55
+ self.normalized = normalized
56
+ self.activation = getattr(torch.nn, activation)(**activation_params)
57
+ self.spec_transform = torchaudio.transforms.Spectrogram(
58
+ n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
59
+ normalized=self.normalized, center=False, pad_mode=None, power=None)
60
+ spec_channels = 2 * self.in_channels
61
+ self.convs = nn.ModuleList()
62
+ self.convs.append(
63
+ NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
64
+ )
65
+ in_chs = min(filters_scale * self.filters, max_filters)
66
+ for i, dilation in enumerate(dilations):
67
+ out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
68
+ self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
69
+ dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
70
+ norm=norm))
71
+ in_chs = out_chs
72
+ out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
73
+ self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]),
74
+ padding=get_2d_padding((kernel_size[0], kernel_size[0])),
75
+ norm=norm))
76
+ self.conv_post = NormConv2d(out_chs, self.out_channels,
77
+ kernel_size=(kernel_size[0], kernel_size[0]),
78
+ padding=get_2d_padding((kernel_size[0], kernel_size[0])),
79
+ norm=norm)
80
+
81
+ def forward(self, x: torch.Tensor):
82
+ fmap = []
83
+ z = self.spec_transform(x) # [B, 2, Freq, Frames, 2]
84
+ z = torch.cat([z.real, z.imag], dim=1)
85
+ z = rearrange(z, 'b c w t -> b c t w')
86
+ for i, layer in enumerate(self.convs):
87
+ z = layer(z)
88
+ z = self.activation(z)
89
+ fmap.append(z)
90
+ z = self.conv_post(z)
91
+ return z, fmap
92
+
93
+
94
+ class MultiScaleSTFTDiscriminator(MultiDiscriminator):
95
+ """Multi-Scale STFT (MS-STFT) discriminator.
96
+
97
+ Args:
98
+ filters (int): Number of filters in convolutions.
99
+ in_channels (int): Number of input channels.
100
+ out_channels (int): Number of output channels.
101
+ sep_channels (bool): Separate channels to distinct samples for stereo support.
102
+ n_ffts (Sequence[int]): Size of FFT for each scale.
103
+ hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
104
+ win_lengths (Sequence[int]): Window size for each scale.
105
+ **kwargs: Additional args for DiscriminatorSTFT.
106
+ """
107
+ def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, sep_channels: bool = False,
108
+ n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128],
109
+ win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs):
110
+ super().__init__()
111
+ assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
112
+ self.sep_channels = sep_channels
113
+ self.discriminators = nn.ModuleList([
114
+ DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
115
+ n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs)
116
+ for i in range(len(n_ffts))
117
+ ])
118
+
119
+ @property
120
+ def num_discriminators(self):
121
+ return len(self.discriminators)
122
+
123
+ def _separate_channels(self, x: torch.Tensor) -> torch.Tensor:
124
+ B, C, T = x.shape
125
+ return x.view(-1, 1, T)
126
+
127
+ def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
128
+ logits = []
129
+ fmaps = []
130
+ for disc in self.discriminators:
131
+ logit, fmap = disc(x)
132
+ logits.append(logit)
133
+ fmaps.append(fmap)
134
+ return logits, fmaps
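
Similarly, a hedged sketch of exercising the MS-STFT discriminator defined above; `filters=32` is the only required argument, and the three default scales use FFT sizes 1024/2048/512 (again assuming the package and its dependencies are installed):

```python
import torch
from audiocraft.adversarial.discriminators import MultiScaleSTFTDiscriminator

# Roughly one second of random mono audio at 32 kHz.
wav = torch.randn(2, 1, 32000)

msstftd = MultiScaleSTFTDiscriminator(filters=32)
logits, fmaps = msstftd(wav)

assert len(logits) == msstftd.num_discriminators == 3
# Each sub-discriminator stacks real+imag spectrogram channels and returns a 2D logit map.
print([logit.shape for logit in logits])
```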
audiocraft/adversarial/losses.py ADDED
@@ -0,0 +1,228 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Utility module to handle adversarial losses without requiring to mess up the main training loop.
9
+ """
10
+
11
+ import typing as tp
12
+
13
+ import flashy
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+
19
+ ADVERSARIAL_LOSSES = ['mse', 'hinge', 'hinge2']
20
+
21
+
22
+ AdvLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor], torch.Tensor]]
23
+ FeatLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
24
+
25
+
26
+ class AdversarialLoss(nn.Module):
27
+ """Adversary training wrapper.
28
+
29
+ Args:
30
+ adversary (nn.Module): The adversary module will be used to estimate the logits given the fake and real samples.
31
+ We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]``
32
+ where the first item is a list of logits and the second item is a list of feature maps.
33
+ optimizer (torch.optim.Optimizer): Optimizer used for training the given module.
34
+ loss (AdvLossType): Loss function for generator training.
35
+ loss_real (AdvLossType): Loss function for adversarial training on logits from real samples.
36
+ loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples.
37
+ loss_feat (FeatLossType): Feature matching loss function for generator training.
38
+ normalize (bool): Whether to normalize by number of sub-discriminators.
39
+
40
+ Example of usage:
41
+ adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake)
42
+ for real in loader:
43
+ noise = torch.randn(...)
44
+ fake = model(noise)
45
+ adv_loss.train_adv(fake, real)
46
+ loss, _ = adv_loss(fake, real)
47
+ loss.backward()
48
+ """
49
+ def __init__(self,
50
+ adversary: nn.Module,
51
+ optimizer: torch.optim.Optimizer,
52
+ loss: AdvLossType,
53
+ loss_real: AdvLossType,
54
+ loss_fake: AdvLossType,
55
+ loss_feat: tp.Optional[FeatLossType] = None,
56
+ normalize: bool = True):
57
+ super().__init__()
58
+ self.adversary: nn.Module = adversary
59
+ flashy.distrib.broadcast_model(self.adversary)
60
+ self.optimizer = optimizer
61
+ self.loss = loss
62
+ self.loss_real = loss_real
63
+ self.loss_fake = loss_fake
64
+ self.loss_feat = loss_feat
65
+ self.normalize = normalize
66
+
67
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
68
+ # Add the optimizer state dict inside our own.
69
+ super()._save_to_state_dict(destination, prefix, keep_vars)
70
+ destination[prefix + 'optimizer'] = self.optimizer.state_dict()
71
+ return destination
72
+
73
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
74
+ # Load optimizer state.
75
+ self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer'))
76
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
77
+
78
+ def get_adversary_pred(self, x):
79
+ """Run adversary model, validating expected output format."""
80
+ logits, fmaps = self.adversary(x)
81
+ assert isinstance(logits, list) and all([isinstance(t, torch.Tensor) for t in logits]), \
82
+ f'Expecting a list of tensors as logits but {type(logits)} found.'
83
+ assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.'
84
+ for fmap in fmaps:
85
+ assert isinstance(fmap, list) and all([isinstance(f, torch.Tensor) for f in fmap]), \
86
+ f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
87
+ return logits, fmaps
88
+
89
+ def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
90
+ """Train the adversary with the given fake and real example.
91
+
92
+ We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]].
93
+ The first item being the logits and second item being a list of feature maps for each sub-discriminator.
94
+
95
+ This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`)
96
+ and call the optimizer.
97
+ """
98
+ loss = torch.tensor(0., device=fake.device)
99
+ all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach())
100
+ all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach())
101
+ n_sub_adversaries = len(all_logits_fake_is_fake)
102
+ for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake):
103
+ loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake)
104
+
105
+ if self.normalize:
106
+ loss /= n_sub_adversaries
107
+
108
+ self.optimizer.zero_grad()
109
+ with flashy.distrib.eager_sync_model(self.adversary):
110
+ loss.backward()
111
+ self.optimizer.step()
112
+
113
+ return loss
114
+
115
+ def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
116
+ """Return the loss for the generator, i.e. trying to fool the adversary,
117
+ and feature matching loss if provided.
118
+ """
119
+ adv = torch.tensor(0., device=fake.device)
120
+ feat = torch.tensor(0., device=fake.device)
121
+ with flashy.utils.readonly(self.adversary):
122
+ all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake)
123
+ all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real)
124
+ n_sub_adversaries = len(all_logits_fake_is_fake)
125
+ for logit_fake_is_fake in all_logits_fake_is_fake:
126
+ adv += self.loss(logit_fake_is_fake)
127
+ if self.loss_feat:
128
+ for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real):
129
+ feat += self.loss_feat(fmap_fake, fmap_real)
130
+
131
+ if self.normalize:
132
+ adv /= n_sub_adversaries
133
+ feat /= n_sub_adversaries
134
+
135
+ return adv, feat
136
+
137
+
138
+ def get_adv_criterion(loss_type: str) -> tp.Callable:
139
+ assert loss_type in ADVERSARIAL_LOSSES
140
+ if loss_type == 'mse':
141
+ return mse_loss
142
+ elif loss_type == 'hinge':
143
+ return hinge_loss
144
+ elif loss_type == 'hinge2':
145
+ return hinge2_loss
146
+ raise ValueError('Unsupported loss')
147
+
148
+
149
+ def get_fake_criterion(loss_type: str) -> tp.Callable:
150
+ assert loss_type in ADVERSARIAL_LOSSES
151
+ if loss_type == 'mse':
152
+ return mse_fake_loss
153
+ elif loss_type in ['hinge', 'hinge2']:
154
+ return hinge_fake_loss
155
+ raise ValueError('Unsupported loss')
156
+
157
+
158
+ def get_real_criterion(loss_type: str) -> tp.Callable:
159
+ assert loss_type in ADVERSARIAL_LOSSES
160
+ if loss_type == 'mse':
161
+ return mse_real_loss
162
+ elif loss_type in ['hinge', 'hinge2']:
163
+ return hinge_real_loss
164
+ raise ValueError('Unsupported loss')
165
+
166
+
167
+ def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
168
+ return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
169
+
170
+
171
+ def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
172
+ return F.mse_loss(x, torch.tensor(0., device=x.device).expand_as(x))
173
+
174
+
175
+ def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
176
+ return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
177
+
178
+
179
+ def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
180
+ return -torch.mean(torch.min(-x - 1, torch.tensor(0., device=x.device).expand_as(x)))
181
+
182
+
183
+ def mse_loss(x: torch.Tensor) -> torch.Tensor:
184
+ if x.numel() == 0:
185
+ return torch.tensor([0.0], device=x.device)
186
+ return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
187
+
188
+
189
+ def hinge_loss(x: torch.Tensor) -> torch.Tensor:
190
+ if x.numel() == 0:
191
+ return torch.tensor([0.0], device=x.device)
192
+ return -x.mean()
193
+
194
+
195
+ def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
196
+ if x.numel() == 0:
197
+ return torch.tensor([0.0], device=x.device)
198
+ return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
199
+
200
+
201
+ class FeatureMatchingLoss(nn.Module):
202
+ """Feature matching loss for adversarial training.
203
+
204
+ Args:
205
+ loss (nn.Module): Loss to use for feature matching (default=torch.nn.L1).
206
+ normalize (bool): Whether to normalize the loss.
207
+ by number of feature maps.
208
+ """
209
+ def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: bool = True):
210
+ super().__init__()
211
+ self.loss = loss
212
+ self.normalize = normalize
213
+
214
+ def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
215
+ assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
216
+ feat_loss = torch.tensor(0., device=fmap_fake[0].device)
217
+ feat_scale = torch.tensor(0., device=fmap_fake[0].device)
218
+ n_fmaps = 0
219
+ for (feat_fake, feat_real) in zip(fmap_fake, fmap_real):
220
+ assert feat_fake.shape == feat_real.shape
221
+ n_fmaps += 1
222
+ feat_loss += self.loss(feat_fake, feat_real)
223
+ feat_scale += torch.mean(torch.abs(feat_real))
224
+
225
+ if self.normalize:
226
+ feat_loss /= n_fmaps
227
+
228
+ return feat_loss
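
A hedged sketch of how these pieces are typically combined — criteria from the factory functions, a discriminator as the adversary, and feature matching for the generator loss. It is illustrative only (not taken from this repository's training configs) and assumes `flashy` is installed; outside a distributed run its broadcast/sync helpers behave as no-ops.

```python
import torch
from audiocraft.adversarial import (
    AdversarialLoss, FeatureMatchingLoss, MultiScaleSTFTDiscriminator,
    get_adv_criterion, get_fake_criterion, get_real_criterion)

loss_type = 'hinge'
adversary = MultiScaleSTFTDiscriminator(filters=32)
optimizer = torch.optim.Adam(adversary.parameters(), lr=1e-4)

adv_loss = AdversarialLoss(
    adversary, optimizer,
    loss=get_adv_criterion(loss_type),
    loss_real=get_real_criterion(loss_type),
    loss_fake=get_fake_criterion(loss_type),
    loss_feat=FeatureMatchingLoss())

real = torch.randn(2, 1, 32000)
# Stand-in for generator output; requires_grad so the generator loss can backpropagate.
fake = torch.randn(2, 1, 32000, requires_grad=True)

d_loss = adv_loss.train_adv(fake, real)   # one discriminator update (backward + step)
g_adv, g_feat = adv_loss(fake, real)      # generator-side adversarial + feature-matching terms
(g_adv + g_feat).backward()
```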
audiocraft/data/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Audio loading and writing support. Datasets for raw audio
7
+ or also including some metadata."""
8
+
9
+ # flake8: noqa
10
+ from . import audio, audio_dataset, info_audio_dataset, music_dataset, sound_dataset
audiocraft/data/audio.py ADDED
@@ -0,0 +1,231 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Audio IO methods are defined in this module (info, read, write).
9
+ We rely on av library for faster read when possible, otherwise on torchaudio.
10
+ """
11
+
12
+ from dataclasses import dataclass
13
+ from pathlib import Path
14
+ import logging
15
+ import typing as tp
16
+
17
+ import numpy as np
18
+ import soundfile
19
+ import torch
20
+ from torch.nn import functional as F
21
+
22
+ import av
23
+ import subprocess as sp
24
+
25
+ from .audio_utils import f32_pcm, normalize_audio
26
+
27
+ _av_initialized = False
28
+
29
+
30
+ def _init_av():
31
+ global _av_initialized
32
+ if _av_initialized:
33
+ return
34
+ logger = logging.getLogger('libav.mp3')
35
+ logger.setLevel(logging.ERROR)
36
+ _av_initialized = True
37
+
38
+
39
+ @dataclass(frozen=True)
40
+ class AudioFileInfo:
41
+ sample_rate: int
42
+ duration: float
43
+ channels: int
44
+
45
+
46
+ def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
47
+ _init_av()
48
+ with av.open(str(filepath)) as af:
49
+ stream = af.streams.audio[0]
50
+ sample_rate = stream.codec_context.sample_rate
51
+ duration = float(stream.duration * stream.time_base)
52
+ channels = stream.channels
53
+ return AudioFileInfo(sample_rate, duration, channels)
54
+
55
+
56
+ def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
57
+ info = soundfile.info(filepath)
58
+ return AudioFileInfo(info.samplerate, info.duration, info.channels)
59
+
60
+
61
+ def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
62
+ # torchaudio no longer returns useful duration information for some formats like mp3s.
63
+ filepath = Path(filepath)
64
+ if filepath.suffix in ['.flac', '.ogg']: # TODO: Validate .ogg can be safely read with av_info
65
+ # ffmpeg has some weird issue with flac.
66
+ return _soundfile_info(filepath)
67
+ else:
68
+ return _av_info(filepath)
69
+
70
+
71
+ def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
72
+ """FFMPEG-based audio file reading using PyAV bindings.
73
+ Soundfile cannot read mp3 and av_read is more efficient than torchaudio.
74
+
75
+ Args:
76
+ filepath (str or Path): Path to audio file to read.
77
+ seek_time (float): Time at which to start reading in the file.
78
+ duration (float): Duration to read from the file. If set to -1, the whole file is read.
79
+ Returns:
80
+ tuple of torch.Tensor, int: Tuple containing audio data and sample rate
81
+ """
82
+ _init_av()
83
+ with av.open(str(filepath)) as af:
84
+ stream = af.streams.audio[0]
85
+ sr = stream.codec_context.sample_rate
86
+ num_frames = int(sr * duration) if duration >= 0 else -1
87
+ frame_offset = int(sr * seek_time)
88
+ # we need a small negative offset otherwise we get some edge artifact
89
+ # from the mp3 decoder.
90
+ af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream)
91
+ frames = []
92
+ length = 0
93
+ for frame in af.decode(streams=stream.index):
94
+ current_offset = int(frame.rate * frame.pts * frame.time_base)
95
+ strip = max(0, frame_offset - current_offset)
96
+ buf = torch.from_numpy(frame.to_ndarray())
97
+ if buf.shape[0] != stream.channels:
98
+ buf = buf.view(-1, stream.channels).t()
99
+ buf = buf[:, strip:]
100
+ frames.append(buf)
101
+ length += buf.shape[1]
102
+ if num_frames > 0 and length >= num_frames:
103
+ break
104
+ assert frames
105
+ # If the above assert fails, it is likely because we seeked past the end of file point,
106
+ # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
107
+ # This will need proper debugging, in due time.
108
+ wav = torch.cat(frames, dim=1)
109
+ assert wav.shape[0] == stream.channels
110
+ if num_frames > 0:
111
+ wav = wav[:, :num_frames]
112
+ return f32_pcm(wav), sr
113
+
114
+
115
+ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
116
+ duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
117
+ """Read audio by picking the most appropriate backend tool based on the audio format.
118
+
119
+ Args:
120
+ filepath (str or Path): Path to audio file to read.
121
+ seek_time (float): Time at which to start reading in the file.
122
+ duration (float): Duration to read from the file. If set to -1, the whole file is read.
123
+ pad (bool): Pad output audio if not reaching expected duration.
124
+ Returns:
125
+ tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
126
+ """
127
+ fp = Path(filepath)
128
+ if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
129
+ # There is some bug with ffmpeg and reading flac
130
+ info = _soundfile_info(filepath)
131
+ frames = -1 if duration <= 0 else int(duration * info.sample_rate)
132
+ frame_offset = int(seek_time * info.sample_rate)
133
+ wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
134
+ assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
135
+ wav = torch.from_numpy(wav).t().contiguous()
136
+ if len(wav.shape) == 1:
137
+ wav = torch.unsqueeze(wav, 0)
138
+ else:
139
+ wav, sr = _av_read(filepath, seek_time, duration)
140
+
141
+ if pad and duration > 0:
142
+ expected_frames = int(duration * sr)
143
+ wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
144
+ return wav, sr
145
+
146
+
147
+ def _piping_to_ffmpeg(out_path: tp.Union[str, Path], wav: torch.Tensor, sample_rate: int, flags: tp.List[str]):
148
+ # ffmpeg is always installed and torchaudio is a bit unstable lately, so let's bypass it entirely.
149
+ assert wav.dim() == 2, wav.shape
150
+ command = [
151
+ 'ffmpeg',
152
+ '-loglevel', 'error',
153
+ '-y', '-f', 'f32le', '-ar', str(sample_rate), '-ac', str(wav.shape[0]),
154
+ '-i', '-'] + flags + [str(out_path)]
155
+ input_ = f32_pcm(wav).t().detach().cpu().numpy().tobytes()
156
+ sp.run(command, input=input_, check=True)
157
+
158
+
159
+ def audio_write(stem_name: tp.Union[str, Path],
160
+ wav: torch.Tensor, sample_rate: int,
161
+ format: str = 'wav', mp3_rate: int = 320, ogg_rate: tp.Optional[int] = None,
162
+ normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1,
163
+ rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
164
+ loudness_compressor: bool = False,
165
+ log_clipping: bool = True, make_parent_dir: bool = True,
166
+ add_suffix: bool = True) -> Path:
167
+ """Convenience function for saving audio to disk. Returns the filename the audio was written to.
168
+
169
+ Args:
170
+ stem_name (str or Path): Filename without extension which will be added automatically.
171
+ wav (torch.Tensor): Audio data to save.
172
+ sample_rate (int): Sample rate of audio data.
173
+ format (str): Either "wav", "mp3", "ogg", or "flac".
174
+ mp3_rate (int): kbps when using mp3s.
175
+ ogg_rate (int): kbps when using ogg/vorbis. If not provided, let ffmpeg decide for itself.
176
+ normalize (bool): if `True` (default), normalizes according to the prescribed
177
+ strategy (see after). If `False`, the strategy is only used in case clipping
178
+ would happen.
179
+ strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
180
+ i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
181
+ with extra headroom to avoid clipping. 'clip' just clips.
182
+ peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
183
+ rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
184
+ than the `peak_clip` one to avoid further clipping.
185
+ loudness_headroom_db (float): Target loudness for loudness normalization.
186
+ loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
187
+ log_clipping (bool): If True, basic logging on stderr when clipping still
188
+ occurs despite strategy (only for 'rms').
189
+ make_parent_dir (bool): Make parent directory if it doesn't exist.
190
+ Returns:
191
+ Path: Path of the saved audio.
192
+ """
193
+ assert wav.dtype.is_floating_point, "wav is not floating point"
194
+ if wav.dim() == 1:
195
+ wav = wav[None]
196
+ elif wav.dim() > 2:
197
+ raise ValueError("Input wav should be at most 2 dimension.")
198
+ assert wav.isfinite().all()
199
+ wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
200
+ rms_headroom_db, loudness_headroom_db, loudness_compressor,
201
+ log_clipping=log_clipping, sample_rate=sample_rate,
202
+ stem_name=str(stem_name))
203
+ if format == 'mp3':
204
+ suffix = '.mp3'
205
+ flags = ['-f', 'mp3', '-c:a', 'libmp3lame', '-b:a', f'{mp3_rate}k']
206
+ elif format == 'wav':
207
+ suffix = '.wav'
208
+ flags = ['-f', 'wav', '-c:a', 'pcm_s16le']
209
+ elif format == 'ogg':
210
+ suffix = '.ogg'
211
+ flags = ['-f', 'ogg', '-c:a', 'libvorbis']
212
+ if ogg_rate is not None:
213
+ flags += ['-b:a', f'{ogg_rate}k']
214
+ elif format == 'flac':
215
+ suffix = '.flac'
216
+ flags = ['-f', 'flac']
217
+ else:
218
+ raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
219
+ if not add_suffix:
220
+ suffix = ''
221
+ path = Path(str(stem_name) + suffix)
222
+ if make_parent_dir:
223
+ path.parent.mkdir(exist_ok=True, parents=True)
224
+ try:
225
+ _piping_to_ffmpeg(path, wav, sample_rate, flags)
226
+ except Exception:
227
+ if path.exists():
228
+ # we do not want to leave half written files around.
229
+ path.unlink()
230
+ raise
231
+ return path
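
A hedged round-trip sketch for the IO helpers above. It assumes `ffmpeg` is available on the PATH (required by `audio_write`) and simply writes a scratch WAV file in the current directory:

```python
import torch
from audiocraft.data.audio import audio_info, audio_read, audio_write

sr = 32000
t = torch.arange(sr) / sr
# One second of a 440 Hz sine tone, shaped [channels, time].
wav = 0.5 * torch.sin(2 * torch.pi * 440 * t)[None]

path = audio_write('tone_demo', wav, sr, format='wav', strategy='peak')
print(path)                      # tone_demo.wav (suffix added automatically)
print(audio_info(path))          # AudioFileInfo(sample_rate=32000, duration~1.0, channels=1)

# Read back half a second starting at 0.25 s, padding if the file were shorter.
loaded, loaded_sr = audio_read(path, seek_time=0.25, duration=0.5, pad=True)
print(loaded.shape, loaded_sr)   # torch.Size([1, 16000]) 32000
```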
audiocraft/data/audio_dataset.py ADDED
@@ -0,0 +1,694 @@
1
+ # Modified from Audiocraft (https://github.com/facebookresearch/audiocraft)
2
+
3
+ """AudioDataset support. In order to handle a larger number of files
4
+ without having to scan the folders again, we precompute some metadata
5
+ (filename, sample rate, duration), and use that to efficiently sample audio segments.
6
+ """
7
+ import argparse
8
+ import copy
9
+ from concurrent.futures import ThreadPoolExecutor, Future
10
+ from dataclasses import dataclass, fields
11
+ from contextlib import ExitStack
12
+ from functools import lru_cache
13
+ import gzip
14
+ import json
15
+ import logging
16
+ import os
17
+ from pathlib import Path
18
+ import random
19
+ import sys
20
+ import typing as tp
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ import numpy as np
25
+
26
+ from .audio import audio_read, audio_info
27
+ from .video import video_read_local, video_read_global
28
+ from .audio_utils import convert_audio
29
+ from .zip import PathInZip
30
+ import h5py
31
+ try:
32
+ import dora
33
+ except ImportError:
34
+ dora = None # type: ignore
35
+
36
+
37
+ @dataclass(order=True)
38
+ class BaseInfo:
39
+
40
+ @classmethod
41
+ def _dict2fields(cls, dictionary: dict):
42
+ return {
43
+ field.name: dictionary[field.name]
44
+ for field in fields(cls) if field.name in dictionary
45
+ }
46
+
47
+ @classmethod
48
+ def from_dict(cls, dictionary: dict):
49
+ _dictionary = cls._dict2fields(dictionary)
50
+ return cls(**_dictionary)
51
+
52
+ def to_dict(self):
53
+ return {
54
+ field.name: self.__getattribute__(field.name)
55
+ for field in fields(self)
56
+ }
57
+
58
+
59
+ @dataclass(order=True)
60
+ class AudioMeta(BaseInfo):
61
+ path: str
62
+ video_path: str
63
+ duration: float
64
+ sample_rate: int
65
+ amplitude: tp.Optional[float] = None
66
+ weight: tp.Optional[float] = None
67
+ # info_path is used to load additional information about the audio file that is stored in zip files.
68
+ info_path: tp.Optional[PathInZip] = None
69
+
70
+ @classmethod
71
+ def from_dict(cls, dictionary: dict):
72
+ # print(f'dictionary:{dictionary}')
73
+ # print(f'cls:{cls}')
74
+ base = cls._dict2fields(dictionary)
75
+ # print(f'base:{base}')
76
+ if 'info_path' in base and base['info_path'] is not None:
77
+ base['info_path'] = PathInZip(base['info_path'])
78
+ # print(f'base:{base}')
79
+ # exit()
80
+ return cls(**base)
81
+
82
+ def to_dict(self):
83
+ d = super().to_dict()
84
+ if d['info_path'] is not None:
85
+ d['info_path'] = str(d['info_path'])
86
+ return d
87
+
88
+
89
+ @dataclass(order=True)
90
+ class SegmentInfo(BaseInfo):
91
+ meta: AudioMeta
92
+ seek_time: float
93
+ # The following values are given once the audio is processed, e.g.
94
+ # at the target sample rate and target number of channels.
95
+ n_frames: int # actual number of frames without padding
96
+ total_frames: int # total number of frames, padding included
97
+ sample_rate: int # actual sample rate
98
+ channels: int # number of audio channels.
99
+
100
+
101
+ DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
102
+
103
+ logger = logging.getLogger(__name__)
104
+
105
+
106
+ def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
107
+ """AudioMeta from a path to an audio file.
108
+
109
+ Args:
110
+ file_path (str): Resolved path of valid audio file.
111
+ minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
112
+ Returns:
113
+ AudioMeta: Audio file path and its metadata.
114
+ """
115
+ info = audio_info(file_path)
116
+ amplitude: tp.Optional[float] = None
117
+ if not minimal:
118
+ wav, sr = audio_read(file_path)
119
+ amplitude = wav.abs().max().item()
120
+ return AudioMeta(file_path, info.duration, info.sample_rate, amplitude)
121
+
122
+
123
+ def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
124
+ """If Dora is available as a dependency, try to resolve potential relative paths
125
+ in list of AudioMeta. This method is expected to be used when loading meta from file.
126
+
127
+ Args:
128
+ m (AudioMeta): Audio meta to resolve.
129
+ fast (bool): If True, uses a really fast check for determining if a file
130
+ is already absolute or not. Only valid on Linux/Mac.
131
+ Returns:
132
+ AudioMeta: Audio meta with resolved path.
133
+ """
134
+ def is_abs(m):
135
+ if fast:
136
+ return str(m)[0] == '/'
137
+ else:
138
+ os.path.isabs(str(m))
139
+
140
+ if not dora:
141
+ return m
142
+
143
+ if not is_abs(m.path):
144
+ m.path = dora.git_save.to_absolute_path(m.path)
145
+ if m.info_path is not None and not is_abs(m.info_path.zip_path):
146
+ m.info_path.zip_path = dora.git_save.to_absolute_path(m.path)
147
+ return m
148
+
149
+
150
+ def find_audio_files(path: tp.Union[Path, str],
151
+ exts: tp.List[str] = DEFAULT_EXTS,
152
+ resolve: bool = True,
153
+ minimal: bool = True,
154
+ progress: bool = False,
155
+ workers: int = 0) -> tp.List[AudioMeta]:
156
+ """Build a list of AudioMeta from a given path,
157
+ collecting relevant audio files and fetching meta info.
158
+
159
+ Args:
160
+ path (str or Path): Path to folder containing audio files.
161
+ exts (list of str): List of file extensions to consider for audio files.
162
+ minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
163
+ progress (bool): Whether to log progress on audio files collection.
164
+ workers (int): number of parallel workers, if 0, use only the current thread.
165
+ Returns:
166
+ list of AudioMeta: List of audio file path and its metadata.
167
+ """
168
+ audio_files = []
169
+ futures: tp.List[Future] = []
170
+ pool: tp.Optional[ThreadPoolExecutor] = None
171
+ with ExitStack() as stack:
172
+ if workers > 0:
173
+ pool = ThreadPoolExecutor(workers)
174
+ stack.enter_context(pool)
175
+
176
+ if progress:
177
+ print("Finding audio files...")
178
+ for root, folders, files in os.walk(path, followlinks=True):
179
+ for file in files:
180
+ full_path = Path(root) / file
181
+ if full_path.suffix.lower() in exts:
182
+ audio_files.append(full_path)
183
+ if pool is not None:
184
+ futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal))
185
+ if progress:
186
+ print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)
187
+
188
+ if progress:
189
+ print("Getting audio metadata...")
190
+ meta: tp.List[AudioMeta] = []
191
+ for idx, file_path in enumerate(audio_files):
192
+ try:
193
+ if pool is None:
194
+ m = _get_audio_meta(str(file_path), minimal)
195
+ else:
196
+ m = futures[idx].result()
197
+ if resolve:
198
+ m = _resolve_audio_meta(m)
199
+ except Exception as err:
200
+ print("Error with", str(file_path), err, file=sys.stderr)
201
+ continue
202
+ meta.append(m)
203
+ if progress:
204
+ print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
205
+ meta.sort()
206
+ return meta
207
+
208
+
209
+ def load_audio_meta(path: tp.Union[str, Path],
210
+ resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
211
+ """Load list of AudioMeta from an optionally compressed json file.
212
+
213
+ Args:
214
+ path (str or Path): Path to JSON file.
215
+ resolve (bool): Whether to resolve the path from AudioMeta (default=True).
216
+ fast (bool): activates some tricks to make things faster.
217
+ Returns:
218
+ list of AudioMeta: List of audio file path and its total duration.
219
+ """
220
+ open_fn = gzip.open if str(path).lower().endswith('.gz') else open
221
+ with open_fn(path, 'rb') as fp: # type: ignore
222
+ lines = fp.readlines()
223
+ meta = []
224
+ for line in lines:
225
+ d = json.loads(line)
226
+ # print(f'line:{d}')
227
+ m = AudioMeta.from_dict(d)
228
+ # print(f'm:{m}')
229
+
230
+ if resolve:
231
+ m = _resolve_audio_meta(m, fast=fast)
232
+ # print(f'm:{m}')
233
+ meta.append(m)
234
+ # exit()
235
+ # print(f'meta:{meta}')
236
+ # exit()
237
+ return meta
238
+
239
+
240
+ def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
241
+ """Save the audio metadata to the file pointer as json.
242
+
243
+ Args:
244
+ path (str or Path): Path to JSON file.
245
+ metadata (list of BaseAudioMeta): List of audio meta to save.
246
+ """
247
+ Path(path).parent.mkdir(exist_ok=True, parents=True)
248
+ open_fn = gzip.open if str(path).lower().endswith('.gz') else open
249
+ with open_fn(path, 'wb') as fp: # type: ignore
250
+ for m in meta:
251
+ json_str = json.dumps(m.to_dict()) + '\n'
252
+ json_bytes = json_str.encode('utf-8')
253
+ fp.write(json_bytes)
254
+
255
+
256
+ class AudioDataset:
257
+ """Base audio dataset.
258
+
259
+ The dataset takes a list of AudioMeta and create a dataset composed of segments of audio
260
+ and potentially additional information, by creating random segments from the list of audio
261
+ files referenced in the metadata and applying minimal data pre-processing such as resampling,
262
+ mixing of channels, padding, etc.
263
+
264
+ If no segment_duration value is provided, the AudioDataset will return the full wav for each
265
+ audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
266
+ duration, applying padding if required.
267
+
268
+ By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
269
+ allows returning a tuple containing the torch Tensor and additional metadata on the segment and the
270
+ original audio meta.
271
+
272
+ Note that you can call `start_epoch(epoch)` in order to get
273
+ a deterministic "randomization" for `shuffle=True`.
274
+ For a given epoch and dataset index, this will always return the same extract.
275
+ You can get back some diversity by setting the `shuffle_seed` param.
276
+
277
+ Args:
278
+ meta (list of AudioMeta): List of audio files metadata.
279
+ segment_duration (float, optional): Optional segment duration of audio to load.
280
+ If not specified, the dataset will load the full audio segment from the file.
281
+ shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
282
+ sample_rate (int): Target sample rate of the loaded audio samples.
283
+ channels (int): Target number of channels of the loaded audio samples.
284
+ sample_on_duration (bool): Set to `True` to sample segments with probability
285
+ dependent on audio file duration. This is only used if `segment_duration` is provided.
286
+ sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
287
+ `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
288
+ of the file duration and file weight. This is only used if `segment_duration` is provided.
289
+ min_segment_ratio (float): Minimum segment ratio to use when the audio file
290
+ is shorter than the desired segment.
291
+ max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
292
+ return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
293
+ min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided
294
+ audio shorter than this will be filtered out.
295
+ max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided
296
+ audio longer than this will be filtered out.
297
+ shuffle_seed (int): Can be used to further randomize the deterministic per-epoch shuffling.
298
+ load_wav (bool): if False, skip loading the wav and return a tensor of zeros
299
+ with the expected segment_duration (which must be provided if load_wav is False).
300
+ permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration`
301
+ are False. Will ensure a permutation on files when going through the dataset.
302
+ In that case the epoch number must be provided in order for the model
303
+ to continue the permutation across epochs. In that case, it is assumed
304
+ that `num_samples = total_batch_size * num_updates_per_epoch`, with
305
+ `total_batch_size` the overall batch size accounting for all gpus.
306
+ """
307
+ def __init__(self,
308
+ meta: tp.List[AudioMeta],
309
+ segment_duration: tp.Optional[float] = None,
310
+ shuffle: bool = True,
311
+ num_samples: int = 10_000,
312
+ sample_rate: int = 48_000,
313
+ video_fps: int = 2,
314
+ video_overlap: int = 2,
315
+ if_add_gobal: bool = False,
316
+ global_mode: str = "average",
317
+ global_num_frames: int = 64,
318
+ global_feature_path: str = '',
319
+ channels: int = 2,
320
+ pad: bool = True,
321
+ sample_on_duration: bool = True,
322
+ sample_on_weight: bool = True,
323
+ min_segment_ratio: float = 0.5,
324
+ max_read_retry: int = 10,
325
+ return_info: bool = False,
326
+ min_audio_duration: tp.Optional[float] = None,
327
+ max_audio_duration: tp.Optional[float] = None,
328
+ shuffle_seed: int = 0,
329
+ load_wav: bool = True,
330
+ permutation_on_files: bool = False,
331
+ ):
332
+ assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta."
333
+ assert segment_duration is None or segment_duration > 0
334
+ assert segment_duration is None or min_segment_ratio >= 0
335
+ self.segment_duration = segment_duration
336
+ self.min_segment_ratio = min_segment_ratio
337
+ self.max_audio_duration = max_audio_duration
338
+ self.min_audio_duration = min_audio_duration
339
+ if self.min_audio_duration is not None and self.max_audio_duration is not None:
340
+ assert self.min_audio_duration <= self.max_audio_duration
341
+ self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
342
+ assert len(self.meta) # Fail fast if all data has been filtered.
343
+ self.total_duration = sum(d.duration for d in self.meta)
344
+
345
+ if segment_duration is None:
346
+ num_samples = len(self.meta)
347
+ self.num_samples = num_samples
348
+ self.shuffle = shuffle
349
+ self.sample_rate = sample_rate
350
+ self.video_fps = video_fps
351
+ self.video_overlap = video_overlap
352
+ self.if_add_gobal = if_add_gobal
353
+ self.global_mode = global_mode
354
+ self.global_num_frames = global_num_frames
355
+ self.global_feature_path = global_feature_path
356
+ self.channels = channels
357
+ self.pad = pad
358
+ self.sample_on_weight = sample_on_weight
359
+ self.sample_on_duration = sample_on_duration
360
+ self.sampling_probabilities = self._get_sampling_probabilities()
361
+ self.max_read_retry = max_read_retry
362
+ self.return_info = return_info
363
+ self.shuffle_seed = shuffle_seed
364
+ self.current_epoch: tp.Optional[int] = None
365
+ self.load_wav = load_wav
366
+ if not load_wav:
367
+ assert segment_duration is not None
368
+ self.permutation_on_files = permutation_on_files
369
+ if permutation_on_files:
370
+ assert not self.sample_on_duration
371
+ assert not self.sample_on_weight
372
+ assert self.shuffle
373
+
374
+ def start_epoch(self, epoch: int):
375
+ self.current_epoch = epoch
376
+
377
+ def __len__(self):
378
+ return self.num_samples
379
+
380
+ def _get_sampling_probabilities(self, normalized: bool = True):
381
+ """Return the sampling probabilities for each file inside `self.meta`."""
382
+ scores: tp.List[float] = []
383
+ for file_meta in self.meta:
384
+ score = 1.
385
+ if self.sample_on_weight and file_meta.weight is not None:
386
+ score *= file_meta.weight
387
+ if self.sample_on_duration:
388
+ score *= file_meta.duration
389
+ scores.append(score)
390
+ probabilities = torch.tensor(scores)
391
+ if normalized:
392
+ probabilities /= probabilities.sum()
393
+ return probabilities
394
+
395
+ @staticmethod
396
+ @lru_cache(16)
397
+ def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int):
398
+ # Used to keep the most recent files permutation in memory implicitly.
399
+ # This will work unless someone is using a lot of Datasets in parallel.
400
+ rng = torch.Generator()
401
+ rng.manual_seed(base_seed + permutation_index)
402
+ return torch.randperm(num_files, generator=rng)
403
+
404
+ def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
405
+ """Sample a given file from `self.meta`. Can be overridden in subclasses.
406
+ This is only called if `segment_duration` is not None.
407
+
408
+ You must use the provided random number generator `rng` for reproducibility.
409
+ You can further make use of the index accessed.
410
+ """
411
+ if self.permutation_on_files:
412
+ assert self.current_epoch is not None
413
+ total_index = self.current_epoch * len(self) + index
414
+ permutation_index = total_index // len(self.meta)
415
+ relative_index = total_index % len(self.meta)
416
+ permutation = AudioDataset._get_file_permutation(
417
+ len(self.meta), permutation_index, self.shuffle_seed)
418
+ file_index = permutation[relative_index]
419
+ return self.meta[file_index]
420
+
421
+ if not self.sample_on_weight and not self.sample_on_duration:
422
+ file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
423
+ else:
424
+ file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())
425
+
426
+ return self.meta[file_index]
427
+
428
+ def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1):
429
+ # Override this method in subclass if needed.
430
+
431
+ if self.load_wav:
432
+ return audio_read(path, seek_time, duration, pad=False)
433
+ else:
434
+ assert self.segment_duration is not None
435
+ n_frames = int(self.sample_rate * self.segment_duration)
436
+ return torch.zeros(self.channels, n_frames), self.sample_rate
437
+
438
+ def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
439
+ if self.segment_duration is None:
440
+ file_meta = self.meta[index]
+ seek_time = 0.  # full files are read from the start; used by the video readers below
441
+ out, sr = audio_read(file_meta.path)
442
+ out = convert_audio(out, sr, self.sample_rate, self.channels)
443
+ n_frames = out.shape[-1]
444
445
+
446
+ if self.if_add_gobal:
447
+ # global_feature_path
448
+ if self.global_feature_path!='' and os.path.exists(self.global_feature_path):
449
+ ytb_id = file_meta.video_path.split('/')[-1][:11]
450
+ with h5py.File(f'{self.global_feature_path}/{ytb_id}.h5', 'r') as file:
451
+ data = file['global_video_array'][:]
452
+ global_video = np.array(data)
453
+ local_video = video_read_local(file_meta.video_path, target_fps=self.video_fps, seek_time=seek_time, duration=self.segment_duration)
454
+ else:
455
+
456
+ local_video, global_video = video_read_global(file_meta.video_path, target_fps=self.video_fps, seek_time=seek_time, duration=self.segment_duration, global_mode=self.global_mode, global_num_frames=self.global_num_frames)
457
+
458
+ else: # local only
459
+ video = video_read_local(file_meta.video_path, target_fps=self.video_fps, seek_time=seek_time, duration=self.segment_duration)
460
+
461
+ segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
462
+ sample_rate=self.sample_rate, channels=out.shape[0])
463
+ else:
464
+ rng = torch.Generator()
465
+ if self.shuffle:
466
+ # We use the index plus extra randomness: totally random if we don't know the epoch,
467
+ # otherwise we make use of the epoch number and optional shuffle_seed.
468
+ if self.current_epoch is None:
469
+ rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
470
+ else:
471
+ rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
472
+ else:
473
+ # We only use index
474
+ rng.manual_seed(index)
475
+
476
+ for retry in range(self.max_read_retry):
477
+ file_meta = self.sample_file(index, rng)
478
+ # We add some variance in the file position even if the audio file is smaller than the segment,
479
+ # without ending up with empty segments
480
+
481
+ overlap = self.video_overlap
482
+ segment_duration_no_overlap = self.segment_duration - overlap
483
+ max_seek = max(0, file_meta.duration - segment_duration_no_overlap * self.min_segment_ratio)
484
+ max_value = max_seek
485
+ random_value = torch.rand(1, generator=rng).item() * max_value
486
+ base_seek_time = segment_duration_no_overlap * int(random_value // segment_duration_no_overlap)
487
+
488
+ seek_time = random.randint(base_seek_time, base_seek_time + overlap)
489
+ seek_time = min(max_seek, seek_time)
490
+
491
+ try:
492
+ out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
493
+
494
+ out = convert_audio(out, sr, self.sample_rate, self.channels)
495
+ n_frames = out.shape[-1]
496
+ target_frames = int(self.segment_duration * self.sample_rate)
497
+
498
+ if self.if_add_gobal:
499
+ if self.global_feature_path!='' and os.path.exists(self.global_feature_path):
500
+ ytb_id = file_meta.video_path.split('/')[-1][:11]
501
+ with h5py.File(f'{self.global_feature_path}/{ytb_id}.h5', 'r') as file:
502
+ data = file['global_video_array'][:]
503
+ global_video = np.array(data)
504
+ indices = np.linspace(0, global_video.shape[1]-1, num=self.global_num_frames, endpoint=True).round().astype(int)
505
+ global_video = global_video[:,indices,:,:]
506
+ global_video = torch.from_numpy(global_video)
507
+ local_video = video_read_local(file_meta.video_path, target_fps=self.video_fps, seek_time=seek_time, duration=self.segment_duration)
508
+ else:
509
+ local_video, global_video = video_read_global(file_meta.video_path, target_fps=self.video_fps, seek_time=seek_time, duration=self.segment_duration, global_mode=self.global_mode, global_num_frames=self.global_num_frames)
510
+ assert global_video.shape[1]==self.global_num_frames
511
+
512
+ n_frames_video = local_video.shape[1]
513
+
514
+ else: # local only
515
+ video = video_read_local(file_meta.video_path, target_fps=self.video_fps, seek_time=seek_time, duration=self.segment_duration)
516
+ n_frames_video = video.shape[1]
517
+
518
+ target_frames_video = int(self.segment_duration * self.video_fps)
519
+
520
+ if self.pad:
521
+ out = F.pad(out, (0, target_frames - n_frames))
522
+
523
+ segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
524
+ sample_rate=self.sample_rate, channels=out.shape[0])
525
+ except Exception as exc:
526
+ logger.warning("Error opening file %s: %r", file_meta.path, exc)
527
+ if retry == self.max_read_retry - 1:
528
+ raise
529
+ else:
530
+ break
531
+ if self.if_add_gobal:
532
+ if self.return_info:
533
+ # Returns the wav and additional information on the wave segment
534
+ return out, [local_video, global_video], segment_info
535
+ else:
536
+ return out, [local_video, global_video]
537
+ else:
538
+ if self.return_info:
539
+ # Returns the wav and additional information on the wave segment
540
+ return out, [video], segment_info
541
+ else:
542
+ return out, [video]
543
+
544
+ def collater(self, samples):
545
+ """The collater function has to be provided to the dataloader
546
+ if AudioDataset has return_info=True in order to properly collate
547
+ the samples of a batch.
548
+ """
549
+ if self.segment_duration is None and len(samples) > 1:
550
+ assert self.pad, "Must allow padding when batching examples of different durations."
551
+
552
+ # In this case the audio reaching the collater is of variable length as segment_duration=None.
553
+ to_pad = self.segment_duration is None and self.pad
554
+ if to_pad:
555
+ max_len = max([wav.shape[-1] for wav, *_ in samples])
556
+
557
+ def _pad_wav(wav):
558
+ return F.pad(wav, (0, max_len - wav.shape[-1]))
559
+
560
+ if self.return_info:
561
+ if len(samples) > 0:
562
+ assert len(samples[0]) == 3
563
+ assert isinstance(samples[0][0], torch.Tensor)
564
+ assert isinstance(samples[0][1], list)
565
+ assert isinstance(samples[0][2], SegmentInfo)
566
+
567
+
568
+ wavs = [wav for wav, _, _ in samples]
569
+ video_lists = [video_list for _, video_list, _ in samples]
570
+ segment_infos = [copy.deepcopy(info) for _, _, info in samples]
571
+ wav = torch.stack(wavs)
572
+
573
+ assert isinstance(video_lists[0],list)
574
+ if len(video_lists[0])==1:
575
+ videos=[video_list[0] for video_list in video_lists]
576
+ if to_pad:
577
+ # Each wav could be of a different duration as they are not segmented.
578
+ for i in range(len(samples)):
579
+ # Determines the total length of the signal with padding, so we update here as we pad.
580
+ segment_infos[i].total_frames = max_len
581
+ wavs[i] = _pad_wav(wavs[i])
582
+ video = torch.stack(videos)
583
+
584
+ return wav, [video], segment_infos
585
+
586
+ elif len(video_lists[0])==2:
587
+
588
+ local_videos=[video_list[0] for video_list in video_lists]
589
+ global_videos=[video_list[1] for video_list in video_lists]
590
+
591
+ if to_pad:
592
+ # Each wav could be of a different duration as they are not segmented.
593
+ for i in range(len(samples)):
594
+ # Determines the total length of the signal with padding, so we update here as we pad.
595
+ segment_infos[i].total_frames = max_len
596
+ wavs[i] = _pad_wav(wavs[i])
597
+ local_video = torch.stack(local_videos)
598
+ global_video = torch.stack(global_videos)
599
+
600
+ return wav, [local_video, global_video], segment_infos
601
+
602
+ else:
603
+ assert isinstance(samples[0], torch.Tensor)
604
+ if to_pad:
605
+ samples = [_pad_wav(s) for s in samples]
606
+ return torch.stack(samples)
607
+
608
+ def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
609
+ """Filters out audio files with audio durations that will not allow to sample examples from them."""
610
+ orig_len = len(meta)
611
+
612
+ # Filter data that is too short.
613
+ if self.min_audio_duration is not None:
614
+ meta = [m for m in meta if m.duration >= self.min_audio_duration]
615
+
616
+ # Filter data that is too long.
617
+ if self.max_audio_duration is not None:
618
+ meta = [m for m in meta if m.duration <= self.max_audio_duration]
619
+
620
+ filtered_len = len(meta)
621
+ removed_percentage = 100*(1-float(filtered_len)/orig_len)
622
+ msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
623
+ if removed_percentage < 10:
624
+ logging.debug(msg)
625
+ else:
626
+ logging.warning(msg)
627
+ return meta
628
+
629
+ @classmethod
630
+ def from_meta(cls, root: tp.Union[str, Path], **kwargs):
631
+ """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.
632
+
633
+ Args:
634
+ root (str or Path): Path to the manifest file, or to a folder containing a data.jsonl(.gz) manifest.
635
+ kwargs: Additional keyword arguments for the AudioDataset.
636
+ """
637
+ root = Path(root)
638
+ if root.is_dir():
639
+ if (root / 'data.jsonl').exists():
640
+ root = root / 'data.jsonl'
641
+ elif (root / 'data.jsonl.gz').exists():
642
+ root = root / 'data.jsonl.gz'
643
+ else:
644
+ raise ValueError("Don't know where to read metadata from in the dir. "
645
+ "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
646
+ meta = load_audio_meta(root)
647
+
648
+ return cls(meta, **kwargs)
649
+
650
+ @classmethod
651
+ def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
652
+ exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
653
+ """Instantiate AudioDataset from a path containing (possibly nested) audio files.
654
+
655
+ Args:
656
+ root (str or Path): Path to root folder containing audio files.
657
+ minimal_meta (bool): Whether to only load minimal metadata or not.
658
+ exts (list of str): Extensions for audio files.
659
+ kwargs: Additional keyword arguments for the AudioDataset.
660
+ """
661
+ root = Path(root)
662
+ if root.is_file():
663
+ meta = load_audio_meta(root, resolve=True)
664
+ else:
665
+ meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
666
+ return cls(meta, **kwargs)
667
+
668
+
669
+ def main():
670
+ logging.basicConfig(stream=sys.stderr, level=logging.INFO)
671
+ parser = argparse.ArgumentParser(
672
+ prog='audio_dataset',
673
+ description='Generate .jsonl files by scanning a folder.')
674
+ parser.add_argument('root', help='Root folder with all the audio files')
675
+ parser.add_argument('output_meta_file',
676
+ help='Output file to store the metadata.')
677
+ parser.add_argument('--complete',
678
+ action='store_false', dest='minimal', default=True,
679
+ help='Retrieve all metadata, even the ones that are expensive '
680
+ 'to compute (e.g. normalization).')
681
+ parser.add_argument('--resolve',
682
+ action='store_true', default=False,
683
+ help='Resolve the paths to be absolute and with no symlinks.')
684
+ parser.add_argument('--workers',
685
+ default=10, type=int,
686
+ help='Number of workers.')
687
+ args = parser.parse_args()
688
+ meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
689
+ resolve=args.resolve, minimal=args.minimal, workers=args.workers)
690
+ save_audio_meta(args.output_meta_file, meta)
691
+
692
+
693
+ if __name__ == '__main__':
694
+ main()
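A minimal usage sketch for the dataset defined above (not part of the commit). The manifest folder, segment length and batch size are illustrative, and each manifest entry is assumed to carry the `video_path` that `__getitem__` reads.

```python
# Hedged sketch: load a manifest and iterate (wav, [local_video, global_video], infos) batches.
from torch.utils.data import DataLoader
from audiocraft.data.audio_dataset import AudioDataset

# A plain audio manifest can be generated with the CLI at the bottom of this file, e.g.
#   python -m audiocraft.data.audio_dataset <audio_root> egs/train/data.jsonl.gz
dataset = AudioDataset.from_meta(
    "egs/train",            # hypothetical folder containing data.jsonl(.gz)
    segment_duration=30.0,  # sample 30 s windows
    sample_rate=32_000,
    channels=1,
    video_fps=2,
    if_add_gobal=True,      # also return the globally sub-sampled frame stream
    return_info=True,
)
# return_info=True requires the custom collater so SegmentInfo objects are batched correctly.
loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collater)
wav, (local_video, global_video), infos = next(iter(loader))
```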
audiocraft/data/audio_utils.py ADDED
@@ -0,0 +1,176 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Various utilities for audio convertion (pcm format, sample rate and channels),
7
+ and volume normalization."""
8
+ import sys
9
+ import typing as tp
10
+
11
+ import julius
12
+ import torch
13
+ import torchaudio
14
+
15
+
16
+ def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
17
+ """Convert audio to the given number of channels.
18
+
19
+ Args:
20
+ wav (torch.Tensor): Audio wave of shape [B, C, T].
21
+ channels (int): Expected number of channels as output.
22
+ Returns:
23
+ torch.Tensor: Downmixed or unchanged audio wave [B, C, T].
24
+ """
25
+ *shape, src_channels, length = wav.shape
26
+ if src_channels == channels:
27
+ pass
28
+ elif channels == 1:
29
+ # Case 1:
30
+ # The caller asked 1-channel audio, and the stream has multiple
31
+ # channels, downmix all channels.
32
+ wav = wav.mean(dim=-2, keepdim=True)
33
+ elif src_channels == 1:
34
+ # Case 2:
35
+ # The caller asked for multiple channels, but the input file has
36
+ # a single channel, replicate the audio over all channels.
37
+ wav = wav.expand(*shape, channels, length)
38
+ elif src_channels >= channels:
39
+ # Case 3:
40
+ # The caller asked for multiple channels, and the input file has
41
+ # more channels than requested. In that case return the first channels.
42
+ wav = wav[..., :channels, :]
43
+ else:
44
+ # Case 4: What is a reasonable choice here?
45
+ raise ValueError('The audio file has fewer channels than requested but is not mono.')
46
+ return wav
47
+
48
+
49
+ def convert_audio(wav: torch.Tensor, from_rate: float,
50
+ to_rate: float, to_channels: int) -> torch.Tensor:
51
+ """Convert audio to new sample rate and number of audio channels."""
52
+ wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
53
+ wav = convert_audio_channels(wav, to_channels)
54
+ return wav
55
+
56
+
57
+ def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
58
+ loudness_compressor: bool = False, energy_floor: float = 2e-3):
59
+ """Normalize an input signal to a user loudness in dB LKFS.
60
+ Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
61
+
62
+ Args:
63
+ wav (torch.Tensor): Input multichannel audio data.
64
+ sample_rate (int): Sample rate.
65
+ loudness_headroom_db (float): Target loudness of the output in dB LUFS.
66
+ loudness_compressor (bool): Uses tanh for soft clipping.
67
+ energy_floor (float): anything below that RMS level will not be rescaled.
68
+ Returns:
69
+ torch.Tensor: Loudness normalized output data.
70
+ """
71
+ energy = wav.pow(2).mean().sqrt().item()
72
+ if energy < energy_floor:
73
+ return wav
74
+ transform = torchaudio.transforms.Loudness(sample_rate)
75
+ input_loudness_db = transform(wav).item()
76
+ # calculate the gain needed to scale to the desired loudness level
77
+ delta_loudness = -loudness_headroom_db - input_loudness_db
78
+ gain = 10.0 ** (delta_loudness / 20.0)
79
+ output = gain * wav
80
+ if loudness_compressor:
81
+ output = torch.tanh(output)
82
+ assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
83
+ return output
84
+
85
+
86
+ def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
87
+ """Utility function to clip the audio with logging if specified."""
88
+ max_scale = wav.abs().max()
89
+ if log_clipping and max_scale > 1:
90
+ clamp_prob = (wav.abs() > 1).float().mean().item()
91
+ print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):",
92
+ clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr)
93
+ wav.clamp_(-1, 1)
94
+
95
+
96
+ def normalize_audio(wav: torch.Tensor, normalize: bool = True,
97
+ strategy: str = 'peak', peak_clip_headroom_db: float = 1,
98
+ rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
99
+ loudness_compressor: bool = False, log_clipping: bool = False,
100
+ sample_rate: tp.Optional[int] = None,
101
+ stem_name: tp.Optional[str] = None) -> torch.Tensor:
102
+ """Normalize the audio according to the prescribed strategy (see after).
103
+
104
+ Args:
105
+ wav (torch.Tensor): Audio data.
106
+ normalize (bool): if `True` (default), normalizes according to the prescribed
107
+ strategy (see after). If `False`, the strategy is only used in case clipping
108
+ would happen.
109
+ strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
110
+ i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
111
+ with extra headroom to avoid clipping. 'clip' just clips.
112
+ peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
113
+ rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
114
+ than the `peak_clip` one to avoid further clipping.
115
+ loudness_headroom_db (float): Target loudness for loudness normalization.
116
+ loudness_compressor (bool): If True, uses tanh based soft clipping.
117
+ log_clipping (bool): If True, basic logging on stderr when clipping still
118
+ occurs despite strategy (only for 'rms').
119
+ sample_rate (int): Sample rate for the audio data (required for loudness).
120
+ stem_name (str, optional): Stem name for clipping logging.
121
+ Returns:
122
+ torch.Tensor: Normalized audio.
123
+ """
124
+ scale_peak = 10 ** (-peak_clip_headroom_db / 20)
125
+ scale_rms = 10 ** (-rms_headroom_db / 20)
126
+ if strategy == 'peak':
127
+ rescaling = (scale_peak / wav.abs().max())
128
+ if normalize or rescaling < 1:
129
+ wav = wav * rescaling
130
+ elif strategy == 'clip':
131
+ wav = wav.clamp(-scale_peak, scale_peak)
132
+ elif strategy == 'rms':
133
+ mono = wav.mean(dim=0)
134
+ rescaling = scale_rms / mono.pow(2).mean().sqrt()
135
+ if normalize or rescaling < 1:
136
+ wav = wav * rescaling
137
+ _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
138
+ elif strategy == 'loudness':
139
+ assert sample_rate is not None, "Loudness normalization requires sample rate."
140
+ wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
141
+ _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
142
+ else:
143
+ assert wav.abs().max() < 1
144
+ assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
145
+ return wav
146
+
147
+
148
+ def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
149
+ """Convert audio to float 32 bits PCM format.
150
+ """
151
+ if wav.dtype.is_floating_point:
152
+ return wav
153
+ elif wav.dtype == torch.int16:
154
+ return wav.float() / 2**15
155
+ elif wav.dtype == torch.int32:
156
+ return wav.float() / 2**31
157
+ raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
158
+
159
+
160
+ def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
161
+ """Convert audio to int 16 bits PCM format.
162
+
163
+ ..Warning:: There exist many formulas for doing this conversion. None is perfect
164
+ due to the asymmetry of the int16 range. One either has possible clipping, a DC offset,
165
+ or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
166
+ it is possible that `i16_pcm(f32_pcm(wav)) != wav`.
167
+ """
168
+ if wav.dtype.is_floating_point:
169
+ assert wav.abs().max() <= 1
170
+ candidate = (wav * 2 ** 15).round()
171
+ if candidate.max() >= 2 ** 15: # clipping would occur
172
+ candidate = (wav * (2 ** 15 - 1)).round()
173
+ return candidate.short()
174
+ else:
175
+ assert wav.dtype == torch.int16
176
+ return wav
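A short sketch of the conversion and normalization helpers above; the waveform and parameter values are only illustrative.

```python
import torch
from audiocraft.data.audio_utils import convert_audio, normalize_audio, f32_pcm, i16_pcm

wav = torch.randn(2, 44_100) * 0.1                 # 1 s of quiet stereo noise at 44.1 kHz
wav = convert_audio(wav, from_rate=44_100, to_rate=32_000, to_channels=1)  # mono @ 32 kHz

# 'rms' rescales with 18 dB of headroom, then clips anything still outside [-1, 1].
wav = normalize_audio(wav, strategy='rms', rms_headroom_db=18, log_clipping=True)

# Round-trip between float32 and int16 PCM (lossy near full scale, see the warning above).
wav_i16 = i16_pcm(wav)
wav_f32 = f32_pcm(wav_i16)
```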
audiocraft/data/info_audio_dataset.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Base classes for the datasets that also provide non-audio metadata,
7
+ e.g. description, text transcription etc.
8
+ """
9
+ from dataclasses import dataclass
10
+ import logging
11
+ import math
12
+ import re
13
+ import typing as tp
14
+
15
+ import torch
16
+
17
+ from .audio_dataset import AudioDataset, AudioMeta
18
+ from ..environment import AudioCraftEnvironment
19
+ from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
26
+ """Monkey-patch meta to match cluster specificities."""
27
+ meta.path = AudioCraftEnvironment.apply_dataset_mappers(meta.path)
28
+ if meta.info_path is not None:
29
+ meta.info_path.zip_path = AudioCraftEnvironment.apply_dataset_mappers(meta.info_path.zip_path)
30
+ return meta
31
+
32
+
33
+ def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
34
+ """Monkey-patch all meta to match cluster specificities."""
35
+ return [_clusterify_meta(m) for m in meta]
36
+
37
+
38
+ @dataclass
39
+ class AudioInfo(SegmentWithAttributes):
40
+ """Dummy SegmentInfo with empty attributes.
41
+
42
+ The InfoAudioDataset is expected to return metadata that inherits
43
+ from SegmentWithAttributes class and can return conditioning attributes.
44
+
45
+ This basically guarantees all datasets will be compatible with current
46
+ solvers that contain conditioners requiring this.
47
+ """
48
+ audio_tokens: tp.Optional[torch.Tensor] = None # populated when using cached batch for training a LM.
49
+
50
+ def to_condition_attributes(self) -> ConditioningAttributes:
51
+ return ConditioningAttributes()
52
+
53
+
54
+ class InfoAudioDataset(AudioDataset):
55
+ """AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform.
56
+
57
+ See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
58
+ """
59
+ def __init__(self, meta: tp.List[AudioMeta], **kwargs):
60
+ super().__init__(clusterify_all_meta(meta), **kwargs)
61
+
62
+ def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
63
+ if not self.return_info:
64
+ wav = super().__getitem__(index)
65
+ assert isinstance(wav, torch.Tensor)
66
+ return wav
67
+ wav, video_list, meta = super().__getitem__(index)
68
+ assert isinstance(video_list, list)
69
+ return wav, video_list, AudioInfo(**meta.to_dict())
70
+
71
+
72
+ def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
73
+ """Preprocess a single keyword or possible a list of keywords."""
74
+ if isinstance(value, list):
75
+ return get_keyword_list(value)
76
+ else:
77
+ return get_keyword(value)
78
+
79
+
80
+ def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
81
+ """Preprocess a single keyword."""
82
+ if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
83
+ return None
84
+ else:
85
+ return value.strip()
86
+
87
+
88
+ def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
89
+ """Preprocess a single keyword."""
90
+ if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
91
+ return None
92
+ else:
93
+ return value.strip().lower()
94
+
95
+
96
+ def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
97
+ """Preprocess a list of keywords."""
98
+ if isinstance(values, str):
99
+ values = [v.strip() for v in re.split(r'[,\s]', values)]
100
+ elif isinstance(values, float) and math.isnan(values):
101
+ values = []
102
+ if not isinstance(values, list):
103
+ logger.debug(f"Unexpected keyword list {values}")
104
+ values = [str(values)]
105
+
106
+ kws = [get_keyword(v) for v in values]
107
+ kw_list = [k for k in kws if k is not None]
108
+ if len(kw_list) == 0:
109
+ return None
110
+ else:
111
+ return kw_list
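Illustrative behaviour of the keyword helpers above (the inputs are made up; the outputs follow directly from the code).

```python
from audiocraft.data.info_audio_dataset import get_keyword, get_keyword_list, get_string

get_keyword("  Rock ")                    # -> "rock"      (stripped and lower-cased)
get_keyword("None")                       # -> None        (sentinel strings are dropped)
get_keyword_list("piano, strings drums")  # -> ["piano", "strings", "drums"]
get_string("  My Title ")                 # -> "My Title"  (stripped, case preserved)
```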
audiocraft/data/music_dataset.py ADDED
@@ -0,0 +1,307 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Dataset of music tracks with rich metadata.
7
+ """
8
+ from dataclasses import dataclass, field, fields, replace
9
+ import gzip
10
+ import json
11
+ import logging
12
+ from pathlib import Path
13
+ import random
14
+ import typing as tp
15
+
16
+ import torch
17
+
18
+ from .info_audio_dataset import (
19
+ InfoAudioDataset,
20
+ AudioInfo,
21
+ get_keyword_list,
22
+ get_keyword,
23
+ get_string
24
+ )
25
+ from ..modules.conditioners import (
26
+ ConditioningAttributes,
27
+ JointEmbedCondition,
28
+ WavCondition,
29
+ )
30
+ from ..utils.utils import warn_once
31
+
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ @dataclass
37
+ class MusicInfo(AudioInfo):
38
+ """Segment info augmented with music metadata.
39
+ """
40
+ # music-specific metadata
41
+ title: tp.Optional[str] = None
42
+ artist: tp.Optional[str] = None # anonymized artist id, used to ensure no overlap between splits
43
+ key: tp.Optional[str] = None
44
+ bpm: tp.Optional[float] = None
45
+ genre: tp.Optional[str] = None
46
+ moods: tp.Optional[list] = None
47
+ keywords: tp.Optional[list] = None
48
+ description: tp.Optional[str] = None
49
+ name: tp.Optional[str] = None
50
+ instrument: tp.Optional[str] = None
51
+ # original wav accompanying the metadata
52
+ self_wav: tp.Optional[WavCondition] = None
53
+ # dict mapping attributes names to tuple of wav, text and metadata
54
+ joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
55
+
56
+ @property
57
+ def has_music_meta(self) -> bool:
58
+ return self.name is not None
59
+
60
+ def to_condition_attributes(self) -> ConditioningAttributes:
61
+ out = ConditioningAttributes()
62
+ for _field in fields(self):
63
+ key, value = _field.name, getattr(self, _field.name)
64
+ if key == 'self_wav':
65
+ out.wav[key] = value
66
+ elif key == 'joint_embed':
67
+ for embed_attribute, embed_cond in value.items():
68
+ out.joint_embed[embed_attribute] = embed_cond
69
+ else:
70
+ if isinstance(value, list):
71
+ value = ' '.join(value)
72
+ out.text[key] = value
73
+ return out
74
+
75
+ @staticmethod
76
+ def attribute_getter(attribute):
77
+ if attribute == 'bpm':
78
+ preprocess_func = get_bpm
79
+ elif attribute == 'key':
80
+ preprocess_func = get_musical_key
81
+ elif attribute in ['moods', 'keywords']:
82
+ preprocess_func = get_keyword_list
83
+ elif attribute in ['genre', 'name', 'instrument']:
84
+ preprocess_func = get_keyword
85
+ elif attribute in ['title', 'artist', 'description']:
86
+ preprocess_func = get_string
87
+ else:
88
+ preprocess_func = None
89
+ return preprocess_func
90
+
91
+ @classmethod
92
+ def from_dict(cls, dictionary: dict, fields_required: bool = False):
93
+ _dictionary: tp.Dict[str, tp.Any] = {}
94
+
95
+ # allow a subset of attributes to not be loaded from the dictionary
96
+ # these attributes may be populated later
97
+ post_init_attributes = ['self_wav', 'joint_embed']
98
+ optional_fields = ['keywords']
99
+
100
+ for _field in fields(cls):
101
+ if _field.name in post_init_attributes:
102
+ continue
103
+ elif _field.name not in dictionary:
104
+ if fields_required and _field.name not in optional_fields:
105
+ raise KeyError(f"Unexpected missing key: {_field.name}")
106
+ else:
107
+ preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
108
+ value = dictionary[_field.name]
109
+ if preprocess_func:
110
+ value = preprocess_func(value)
111
+ _dictionary[_field.name] = value
112
+ return cls(**_dictionary)
113
+
114
+
115
+ def augment_music_info_description(music_info: MusicInfo, merge_text_p: float = 0.,
116
+ drop_desc_p: float = 0., drop_other_p: float = 0.) -> MusicInfo:
117
+ """Augment MusicInfo description with additional metadata fields and potential dropout.
118
+ Additional textual attributes are added given probability 'merge_text_p' and
119
+ the original textual description is dropped from the augmented description given probability drop_desc_p.
120
+
121
+ Args:
122
+ music_info (MusicInfo): The music metadata to augment.
123
+ merge_text_p (float): Probability of merging additional metadata to the description.
124
+ If provided value is 0, then no merging is performed.
125
+ drop_desc_p (float): Probability of dropping the original description on text merge.
126
+ If the provided value is 0, then no dropout is performed.
127
+ drop_other_p (float): Probability of dropping the other fields used for text augmentation.
128
+ Returns:
129
+ MusicInfo: The MusicInfo with augmented textual description.
130
+ """
131
+ def is_valid_field(field_name: str, field_value: tp.Any) -> bool:
132
+ valid_field_name = field_name in ['key', 'bpm', 'genre', 'moods', 'instrument', 'keywords']
133
+ valid_field_value = field_value is not None and isinstance(field_value, (int, float, str, list))
134
+ keep_field = random.uniform(0, 1) < drop_other_p
135
+ return valid_field_name and valid_field_value and keep_field
136
+
137
+ def process_value(v: tp.Any) -> str:
138
+ if isinstance(v, (int, float, str)):
139
+ return str(v)
140
+ if isinstance(v, list):
141
+ return ", ".join(v)
142
+ else:
143
+ raise ValueError(f"Unknown type for text value! ({type(v), v})")
144
+
145
+ description = music_info.description
146
+
147
+ metadata_text = ""
148
+ if random.uniform(0, 1) < merge_text_p:
149
+ meta_pairs = [f'{_field.name}: {process_value(getattr(music_info, _field.name))}'
150
+ for _field in fields(music_info) if is_valid_field(_field.name, getattr(music_info, _field.name))]
151
+ random.shuffle(meta_pairs)
152
+ metadata_text = ". ".join(meta_pairs)
153
+ description = description if not random.uniform(0, 1) < drop_desc_p else None
154
+ logger.debug(f"Applying text augmentation on MMI info. description: {description}, metadata: {metadata_text}")
155
+
156
+ if description is None:
157
+ description = metadata_text if len(metadata_text) > 1 else None
158
+ else:
159
+ description = ". ".join([description.rstrip('.'), metadata_text])
160
+ description = description.strip() if description else None
161
+
162
+ music_info = replace(music_info)
163
+ music_info.description = description
164
+ return music_info
165
+
166
+
167
+ class Paraphraser:
168
+ def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_p: float = 0.):
169
+ self.paraphrase_p = paraphrase_p
170
+ open_fn = gzip.open if str(paraphrase_source).lower().endswith('.gz') else open
171
+ with open_fn(paraphrase_source, 'rb') as f: # type: ignore
172
+ self.paraphrase_source = json.loads(f.read())
173
+ logger.info(f"loaded paraphrasing source from: {paraphrase_source}")
174
+
175
+ def sample_paraphrase(self, audio_path: str, description: str):
176
+ if random.random() >= self.paraphrase_p:
177
+ return description
178
+ info_path = Path(audio_path).with_suffix('.json')
179
+ if info_path not in self.paraphrase_source:
180
+ warn_once(logger, f"{info_path} not in paraphrase source!")
181
+ return description
182
+ new_desc = random.choice(self.paraphrase_source[info_path])
183
+ logger.debug(f"{description} -> {new_desc}")
184
+ return new_desc
185
+
186
+
187
+ class MusicDataset(InfoAudioDataset):
188
+ """Music dataset is an AudioDataset with music-related metadata.
189
+
190
+ Args:
191
+ info_fields_required (bool): Whether to enforce having required fields.
192
+ merge_text_p (float): Probability of merging additional metadata to the description.
193
+ drop_desc_p (float): Probability of dropping the original description on text merge.
194
+ drop_other_p (float): Probability of dropping the other fields used for text augmentation.
195
+ joint_embed_attributes (list[str]): A list of attributes for which joint embedding metadata is returned.
196
+ paraphrase_source (str, optional): Path to the .json or .json.gz file containing the
197
+ paraphrases for the description. The json should be a dict whose keys are the
198
+ original info path (e.g. track_path.json) and each value is a list of possible
199
+ paraphrases.
200
+ paraphrase_p (float): probability of taking a paraphrase.
201
+
202
+ See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
203
+ """
204
+ def __init__(self, *args, info_fields_required: bool = True,
205
+ merge_text_p: float = 0., drop_desc_p: float = 0., drop_other_p: float = 0.,
206
+ joint_embed_attributes: tp.List[str] = [],
207
+ paraphrase_source: tp.Optional[str] = None, paraphrase_p: float = 0,
208
+ **kwargs):
209
+ kwargs['return_info'] = True # We require the info for each song of the dataset.
210
+ super().__init__(*args, **kwargs)
211
+ self.info_fields_required = info_fields_required
212
+ self.merge_text_p = merge_text_p
213
+ self.drop_desc_p = drop_desc_p
214
+ self.drop_other_p = drop_other_p
215
+ self.joint_embed_attributes = joint_embed_attributes
216
+ self.paraphraser = None
217
+ if paraphrase_source is not None:
218
+ self.paraphraser = Paraphraser(paraphrase_source, paraphrase_p)
219
+
220
+ def __getitem__(self, index):
221
+ wav, video_list, info = super().__getitem__(index)
222
+ assert isinstance(video_list, list)
223
+
224
+
225
+ if len(video_list)==1:
226
+ video=video_list[0]
227
+ info_data = info.to_dict()
228
+ music_info_path = Path(info.meta.path).with_suffix('.json')
229
+ if Path(music_info_path).exists():
230
+ with open(music_info_path, 'r') as json_file:
231
+ music_data = json.load(json_file)
232
+ music_data.update(info_data)
233
+ music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required)
234
+ if self.paraphraser is not None:
235
+ music_info.description = self.paraphraser.sample_paraphrase(music_info.meta.path, music_info.description)
236
+ if self.merge_text_p:
237
+ music_info = augment_music_info_description(
238
+ music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p)
239
+ else:
240
+ music_info = MusicInfo.from_dict(info_data, fields_required=False)
241
+
242
+ music_info.self_wav = WavCondition(
243
+ wav=wav[None], length=torch.tensor([info.n_frames]),
244
+ sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
245
+
246
+ for att in self.joint_embed_attributes:
247
+ att_value = getattr(music_info, att)
248
+ joint_embed_cond = JointEmbedCondition(
249
+ wav[None], [att_value], torch.tensor([info.n_frames]),
250
+ sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
251
+ music_info.joint_embed[att] = joint_embed_cond
252
+
253
+ return wav, [video], music_info
254
+
255
+ elif len(video_list)==2:
256
+ local_video=video_list[0]
257
+ global_video=video_list[1]
258
+
259
+ info_data = info.to_dict()
260
+ music_info_path = Path(info.meta.path).with_suffix('.json')
261
+
262
+ if Path(music_info_path).exists():
263
+ with open(music_info_path, 'r') as json_file:
264
+ music_data = json.load(json_file)
265
+ music_data.update(info_data)
266
+ music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required)
267
+ if self.paraphraser is not None:
268
+ music_info.description = self.paraphraser.sample_paraphrase(music_info.meta.path, music_info.description)
269
+ if self.merge_text_p:
270
+ music_info = augment_music_info_description(
271
+ music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p)
272
+ else:
273
+ music_info = MusicInfo.from_dict(info_data, fields_required=False)
274
+
275
+ music_info.self_wav = WavCondition(
276
+ wav=wav[None], length=torch.tensor([info.n_frames]),
277
+ sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
278
+
279
+ for att in self.joint_embed_attributes:
280
+ att_value = getattr(music_info, att)
281
+ joint_embed_cond = JointEmbedCondition(
282
+ wav[None], [att_value], torch.tensor([info.n_frames]),
283
+ sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
284
+ music_info.joint_embed[att] = joint_embed_cond
285
+
286
+ return wav, [local_video, global_video], music_info
287
+
288
+
289
+ def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]:
290
+ """Preprocess key keywords, discarding them if there are multiple key defined."""
291
+ if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
292
+ return None
293
+ elif ',' in value:
294
+ # For now, we discard when multiple keys are defined separated with comas
295
+ return None
296
+ else:
297
+ return value.strip().lower()
298
+
299
+
300
+ def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]:
301
+ """Preprocess to a float."""
302
+ if value is None:
303
+ return None
304
+ try:
305
+ return float(value)
306
+ except ValueError:
307
+ return None
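Illustrative behaviour of the metadata preprocessing helpers above (the inputs are made up).

```python
from audiocraft.data.music_dataset import get_bpm, get_musical_key

get_bpm("124")                 # -> 124.0
get_bpm("fast")                # -> None      (not parseable as a float)
get_musical_key(" C Minor ")   # -> "c minor"
get_musical_key("C Minor, A")  # -> None      (multiple keys are discarded)
```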
audiocraft/data/sound_dataset.py ADDED
@@ -0,0 +1,330 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Dataset of audio with a simple description.
7
+ """
8
+
9
+ from dataclasses import dataclass, fields, replace
10
+ import json
11
+ from pathlib import Path
12
+ import random
13
+ import typing as tp
14
+
15
+ import numpy as np
16
+ import torch
17
+
18
+ from .info_audio_dataset import (
19
+ InfoAudioDataset,
20
+ get_keyword_or_keyword_list
21
+ )
22
+ from ..modules.conditioners import (
23
+ ConditioningAttributes,
24
+ SegmentWithAttributes,
25
+ WavCondition,
26
+ )
27
+
28
+
29
+ EPS = torch.finfo(torch.float32).eps
30
+ TARGET_LEVEL_LOWER = -35
31
+ TARGET_LEVEL_UPPER = -15
32
+
33
+
34
+ @dataclass
35
+ class SoundInfo(SegmentWithAttributes):
36
+ """Segment info augmented with Sound metadata.
37
+ """
38
+ description: tp.Optional[str] = None
39
+ self_wav: tp.Optional[torch.Tensor] = None
40
+
41
+ @property
42
+ def has_sound_meta(self) -> bool:
43
+ return self.description is not None
44
+
45
+ def to_condition_attributes(self) -> ConditioningAttributes:
46
+ out = ConditioningAttributes()
47
+
48
+ for _field in fields(self):
49
+ key, value = _field.name, getattr(self, _field.name)
50
+ if key == 'self_wav':
51
+ out.wav[key] = value
52
+ else:
53
+ out.text[key] = value
54
+ return out
55
+
56
+ @staticmethod
57
+ def attribute_getter(attribute):
58
+ if attribute == 'description':
59
+ preprocess_func = get_keyword_or_keyword_list
60
+ else:
61
+ preprocess_func = None
62
+ return preprocess_func
63
+
64
+ @classmethod
65
+ def from_dict(cls, dictionary: dict, fields_required: bool = False):
66
+ _dictionary: tp.Dict[str, tp.Any] = {}
67
+
68
+ # allow a subset of attributes to not be loaded from the dictionary
69
+ # these attributes may be populated later
70
+ post_init_attributes = ['self_wav']
71
+
72
+ for _field in fields(cls):
73
+ if _field.name in post_init_attributes:
74
+ continue
75
+ elif _field.name not in dictionary:
76
+ if fields_required:
77
+ raise KeyError(f"Unexpected missing key: {_field.name}")
78
+ else:
79
+ preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
80
+ value = dictionary[_field.name]
81
+ if preprocess_func:
82
+ value = preprocess_func(value)
83
+ _dictionary[_field.name] = value
84
+ return cls(**_dictionary)
85
+
86
+
87
+ class SoundDataset(InfoAudioDataset):
88
+ """Sound audio dataset: Audio dataset with environmental sound-specific metadata.
89
+
90
+ Args:
91
+ info_fields_required (bool): Whether all the mandatory metadata fields should be in the loaded metadata.
92
+ external_metadata_source (tp.Optional[str]): Folder containing JSON metadata for the corresponding dataset.
93
+ The metadata files contained in this folder are expected to match the stem of the audio file with
94
+ a json extension.
95
+ aug_p (float): Probability of performing audio mixing augmentation on the batch.
96
+ mix_p (float): Proportion of batch items that are mixed together when applying audio mixing augmentation.
97
+ mix_snr_low (int): Lower bound for the SNR value sampled for mixing augmentation.
98
+ mix_snr_high (int): Upper bound for the SNR value sampled for mixing augmentation.
99
+ mix_min_overlap (float): Minimum overlap between audio files when performing mixing augmentation.
100
+ kwargs: Additional arguments for AudioDataset.
101
+
102
+ See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
103
+ """
104
+ def __init__(
105
+ self,
106
+ *args,
107
+ info_fields_required: bool = True,
108
+ external_metadata_source: tp.Optional[str] = None,
109
+ aug_p: float = 0.,
110
+ mix_p: float = 0.,
111
+ mix_snr_low: int = -5,
112
+ mix_snr_high: int = 5,
113
+ mix_min_overlap: float = 0.5,
114
+ **kwargs
115
+ ):
116
+ kwargs['return_info'] = True # We require the info for each song of the dataset.
117
+ super().__init__(*args, **kwargs)
118
+ self.info_fields_required = info_fields_required
119
+ self.external_metadata_source = external_metadata_source
120
+ self.aug_p = aug_p
121
+ self.mix_p = mix_p
122
+ if self.aug_p > 0:
123
+ assert self.mix_p > 0, "Expecting some mixing proportion mix_p if aug_p > 0"
124
+ assert self.channels == 1, "SoundDataset with audio mixing considers only monophonic audio"
125
+ self.mix_snr_low = mix_snr_low
126
+ self.mix_snr_high = mix_snr_high
127
+ self.mix_min_overlap = mix_min_overlap
128
+
129
+ def _get_info_path(self, path: tp.Union[str, Path]) -> Path:
130
+ """Get path of JSON with metadata (description, etc.).
131
+ If there exists a JSON with the same name as 'path.name', then it will be used.
132
+ Else, such JSON will be searched for in an external json source folder if it exists.
133
+ """
134
+ info_path = Path(path).with_suffix('.json')
135
+ if Path(info_path).exists():
136
+ return info_path
137
+ elif self.external_metadata_source and (Path(self.external_metadata_source) / info_path.name).exists():
138
+ return Path(self.external_metadata_source) / info_path.name
139
+ else:
140
+ raise Exception(f"Unable to find a metadata JSON for path: {path}")
141
+
142
+ def __getitem__(self, index):
143
+ wav, info = super().__getitem__(index)
144
+ info_data = info.to_dict()
145
+ info_path = self._get_info_path(info.meta.path)
146
+ if Path(info_path).exists():
147
+ with open(info_path, 'r') as json_file:
148
+ sound_data = json.load(json_file)
149
+ sound_data.update(info_data)
150
+ sound_info = SoundInfo.from_dict(sound_data, fields_required=self.info_fields_required)
151
+ # if there are multiple descriptions, sample one randomly
152
+ if isinstance(sound_info.description, list):
153
+ sound_info.description = random.choice(sound_info.description)
154
+ else:
155
+ sound_info = SoundInfo.from_dict(info_data, fields_required=False)
156
+
157
+ sound_info.self_wav = WavCondition(
158
+ wav=wav[None], length=torch.tensor([info.n_frames]),
159
+ sample_rate=[sound_info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
160
+
161
+ return wav, sound_info
162
+
163
+ def collater(self, samples):
164
+ # when training, audio mixing is performed in the collate function
165
+ wav, sound_info = super().collater(samples) # SoundDataset always returns infos
166
+ if self.aug_p > 0:
167
+ wav, sound_info = mix_samples(wav, sound_info, self.aug_p, self.mix_p,
168
+ snr_low=self.mix_snr_low, snr_high=self.mix_snr_high,
169
+ min_overlap=self.mix_min_overlap)
170
+ return wav, sound_info
171
+
172
+
173
+ def rms_f(x: torch.Tensor) -> torch.Tensor:
174
+ return (x ** 2).mean(1).pow(0.5)
175
+
176
+
177
+ def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Tensor:
178
+ """Normalize the signal to the target level."""
179
+ rms = rms_f(audio)
180
+ scalar = 10 ** (target_level / 20) / (rms + EPS)
181
+ audio = audio * scalar.unsqueeze(1)
182
+ return audio
183
+
184
+
185
+ def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) -> torch.Tensor:
186
+ return (abs(audio) > clipping_threshold).any(1)
187
+
188
+
189
+ def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -> torch.Tensor:
190
+ start = random.randint(0, int(src.shape[1] * (1 - min_overlap)))
191
+ remainder = src.shape[1] - start
192
+ if dst.shape[1] > remainder:
193
+ src[:, start:] = src[:, start:] + dst[:, :remainder]
194
+ else:
195
+ src[:, start:start+dst.shape[1]] = src[:, start:start+dst.shape[1]] + dst
196
+ return src
197
+
198
+
199
+ def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_overlap: float,
200
+ target_level: int = -25, clipping_threshold: float = 0.99) -> torch.Tensor:
201
+ """Function to mix clean speech and noise at various SNR levels.
202
+
203
+ Args:
204
+ clean (torch.Tensor): Clean audio source to mix, of shape [B, T].
205
+ noise (torch.Tensor): Noise audio source to mix, of shape [B, T].
206
+ snr (int): SNR level when mixing.
207
+ min_overlap (float): Minimum overlap between the two mixed sources.
208
+ target_level (int): Gain level in dB.
209
+ clipping_threshold (float): Threshold for clipping the audio.
210
+ Returns:
211
+ torch.Tensor: The mixed audio, of shape [B, T].
212
+ """
213
+ if clean.shape[1] > noise.shape[1]:
214
+ noise = torch.nn.functional.pad(noise, (0, clean.shape[1] - noise.shape[1]))
215
+ else:
216
+ noise = noise[:, :clean.shape[1]]
217
+
218
+ # normalizing to -25 dB FS
219
+ clean = clean / (clean.max(1)[0].abs().unsqueeze(1) + EPS)
220
+ clean = normalize(clean, target_level)
221
+ rmsclean = rms_f(clean)
222
+
223
+ noise = noise / (noise.max(1)[0].abs().unsqueeze(1) + EPS)
224
+ noise = normalize(noise, target_level)
225
+ rmsnoise = rms_f(noise)
226
+
227
+ # set the noise level for a given SNR
228
+ noisescalar = (rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)).unsqueeze(1)
229
+ noisenewlevel = noise * noisescalar
230
+
231
+ # mix noise and clean speech
232
+ noisyspeech = mix_pair(clean, noisenewlevel, min_overlap)
233
+
234
+ # randomly select RMS value between -15 dBFS and -35 dBFS and normalize noisyspeech with that value
235
+ # there is a chance of clipping that might happen with very less probability, which is not a major issue.
236
+ noisy_rms_level = np.random.randint(TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER)
237
+ rmsnoisy = rms_f(noisyspeech)
238
+ scalarnoisy = (10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)).unsqueeze(1)
239
+ noisyspeech = noisyspeech * scalarnoisy
240
+ clean = clean * scalarnoisy
241
+ noisenewlevel = noisenewlevel * scalarnoisy
242
+
243
+ # final check to see if there are any amplitudes exceeding +/- 1. If so, normalize all the signals accordingly
244
+ clipped = is_clipped(noisyspeech)
245
+ if clipped.any():
246
+ noisyspeech_maxamplevel = noisyspeech[clipped].max(1)[0].abs().unsqueeze(1) / (clipping_threshold - EPS)
247
+ noisyspeech[clipped] = noisyspeech[clipped] / noisyspeech_maxamplevel
248
+
249
+ return noisyspeech
250
+
251
+
252
+ def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high: int, min_overlap: float):
253
+ if snr_low == snr_high:
254
+ snr = snr_low
255
+ else:
256
+ snr = np.random.randint(snr_low, snr_high)
257
+ mix = snr_mixer(src, dst, snr, min_overlap)
258
+ return mix
259
+
260
+
261
+ def mix_text(src_text: str, dst_text: str):
262
+ """Mix text from different sources by concatenating them."""
263
+ if src_text == dst_text:
264
+ return src_text
265
+ return src_text + " " + dst_text
266
+
267
+
268
+ def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: float, mix_p: float,
269
+ snr_low: int, snr_high: int, min_overlap: float):
270
+ """Mix samples within a batch, summing the waveforms and concatenating the text infos.
271
+
272
+ Args:
273
+ wavs (torch.Tensor): Audio tensors of shape [B, C, T].
274
+ infos (list[SoundInfo]): List of SoundInfo items corresponding to the audio.
275
+ aug_p (float): Augmentation probability.
276
+ mix_p (float): Proportion of items in the batch to mix (and merge) together.
277
+ snr_low (int): Lowerbound for sampling SNR.
278
+ snr_high (int): Upperbound for sampling SNR.
279
+ min_overlap (float): Minimum overlap between mixed samples.
280
+ Returns:
281
+ tuple[torch.Tensor, list[SoundInfo]]: A tuple containing the mixed wavs
282
+ and mixed SoundInfo for the given batch.
283
+ """
284
+ # no mixing to perform within the batch
285
+ if mix_p == 0:
286
+ return wavs, infos
287
+
288
+ if random.uniform(0, 1) < aug_p:
289
+ # perform all augmentations on waveforms as [B, T]
290
+ # randomly picking pairs of audio to mix
291
+ assert wavs.size(1) == 1, f"Mix samples requires monophonic audio but C={wavs.size(1)}"
292
+ wavs = wavs.mean(dim=1, keepdim=False)
293
+ B, T = wavs.shape
294
+ k = int(mix_p * B)
295
+ mixed_sources_idx = torch.randperm(B)[:k]
296
+ mixed_targets_idx = torch.randperm(B)[:k]
297
+ aug_wavs = snr_mix(
298
+ wavs[mixed_sources_idx],
299
+ wavs[mixed_targets_idx],
300
+ snr_low,
301
+ snr_high,
302
+ min_overlap,
303
+ )
304
+ # mixing textual descriptions in metadata
305
+ descriptions = [info.description for info in infos]
306
+ aug_infos = []
307
+ for i, j in zip(mixed_sources_idx, mixed_targets_idx):
308
+ text = mix_text(descriptions[i], descriptions[j])
309
+ m = replace(infos[i])
310
+ m.description = text
311
+ aug_infos.append(m)
312
+
313
+ # back to [B, C, T]
314
+ aug_wavs = aug_wavs.unsqueeze(1)
315
+ assert aug_wavs.shape[0] > 0, "Samples mixing returned empty batch."
316
+ assert aug_wavs.dim() == 3, f"Returned wav should be [B, C, T] but dim = {aug_wavs.dim()}"
317
+ assert aug_wavs.shape[0] == len(aug_infos), "Mismatch between number of wavs and infos in the batch"
318
+
319
+ return aug_wavs, aug_infos # [B, C, T]
320
+ else:
321
+ # randomly pick samples in the batch to match
322
+ # the batch size when performing audio mixing
323
+ B, C, T = wavs.shape
324
+ k = int(mix_p * B)
325
+ wav_idx = torch.randperm(B)[:k]
326
+ wavs = wavs[wav_idx]
327
+ infos = [infos[i] for i in wav_idx]
328
+ assert wavs.shape[0] == len(infos), "Mismatch between number of wavs and infos in the batch"
329
+
330
+ return wavs, infos # [B, C, T]
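
The SNR mixing above boils down to a single gain applied to the noise before summation. Below is a minimal, self-contained sketch of that relationship; the helper names and the `EPS` constant are illustrative stand-ins for the module's `rms_f` and `EPS`, not part of the diff.

```python
# Illustrative sketch only: how snr_mixer's noise gain relates to the target SNR.
import torch

EPS = 1e-10  # stand-in for the module-level EPS constant

def rms(x: torch.Tensor) -> torch.Tensor:
    # Root-mean-square energy per batch item, x is [B, T].
    return torch.sqrt(x.pow(2).mean(dim=1))

def scale_noise_for_snr(clean: torch.Tensor, noise: torch.Tensor, snr_db: float) -> torch.Tensor:
    # SNR(dB) = 20 * log10(rms(clean) / rms(noise)), hence the noise gain below.
    gain = rms(clean) / (10 ** (snr_db / 20)) / (rms(noise) + EPS)
    return noise * gain.unsqueeze(1)

clean, noise = torch.randn(2, 16000), torch.randn(2, 16000)
scaled = scale_noise_for_snr(clean, noise, snr_db=10)
print(20 * torch.log10(rms(clean) / rms(scaled)))  # roughly 10 dB per item
```
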
audiocraft/data/video.py ADDED
@@ -0,0 +1,83 @@
1
+ import decord
2
+ from decord import VideoReader
3
+ from decord import cpu
4
+ import torch
5
+ import math
6
+ import einops
7
+ import torchvision.transforms as transforms
8
+
9
+ def adjust_video_duration(video_tensor, duration, target_fps):
10
+ current_duration = video_tensor.shape[1]
11
+ target_duration = duration * target_fps
12
+
13
+ if current_duration > target_duration:
14
+ video_tensor = video_tensor[:, :target_duration]
15
+ elif current_duration < target_duration:
16
+ last_frame = video_tensor[:, -1:]
17
+ repeat_times = target_duration - current_duration
18
+ video_tensor = torch.cat((video_tensor, last_frame.repeat(1, repeat_times, 1, 1)), dim=1)
19
+
20
+ return video_tensor
21
+
22
+ def video_read_local(filepath, seek_time=0., duration=-1, target_fps=2):
23
+ vr = VideoReader(filepath, ctx=cpu(0))
24
+ fps = vr.get_avg_fps()
25
+
26
+ if duration > 0:
27
+ total_frames_to_read = target_fps * duration
28
+ frame_interval = int(math.ceil(fps / target_fps))
29
+ start_frame = int(seek_time * fps)
30
+ end_frame = start_frame + frame_interval * total_frames_to_read
31
+ frame_ids = list(range(start_frame, min(end_frame, len(vr)), frame_interval))
32
+ else:
33
+ frame_ids = list(range(0, len(vr), int(math.ceil(fps / target_fps))))
34
+
35
+ frames = vr.get_batch(frame_ids)
36
+ frames = torch.from_numpy(frames.asnumpy()).permute(0, 3, 1, 2) # [N, H, W, C] -> [N, C, H, W]
37
+
38
+ resize_transform = transforms.Resize((224, 224))
39
+ frames = [resize_transform(frame) for frame in frames]
40
+ video_tensor = torch.stack(frames)
41
+ video_tensor = einops.rearrange(video_tensor, 't c h w -> c t h w') # [T, C, H, W] -> [C, T, H, W]
42
+ video_tensor = adjust_video_duration(video_tensor, duration, target_fps)
43
+ assert video_tensor.shape[1] == duration * target_fps, f"the shape of video_tensor is {video_tensor.shape}"
44
+
45
+ return video_tensor
46
+
47
+
48
+ def video_read_global(filepath, seek_time=0., duration=-1, target_fps=2, global_mode='average', global_num_frames=32):
49
+ vr = VideoReader(filepath, ctx=cpu(0))
50
+ fps = vr.get_avg_fps()
51
+ frame_count = len(vr)
52
+
53
+ if duration > 0:
54
+ total_frames_to_read = target_fps * duration
55
+ frame_interval = int(math.ceil(fps / target_fps))
56
+ start_frame = int(seek_time * fps)
57
+ end_frame = start_frame + frame_interval * total_frames_to_read
58
+ frame_ids = list(range(start_frame, min(end_frame, frame_count), frame_interval))
59
+ else:
60
+ frame_ids = list(range(0, frame_count, int(math.ceil(fps / target_fps))))
61
+
62
+ local_frames = vr.get_batch(frame_ids)
63
+ local_frames = torch.from_numpy(local_frames.asnumpy()).permute(0, 3, 1, 2) # [N, H, W, C] -> [N, C, H, W]
64
+
65
+ resize_transform = transforms.Resize((224, 224))
66
+ local_frames = [resize_transform(frame) for frame in local_frames]
67
+ local_video_tensor = torch.stack(local_frames)
68
+ local_video_tensor = einops.rearrange(local_video_tensor, 't c h w -> c t h w') # [T, C, H, W] -> [C, T, H, W]
69
+ local_video_tensor = adjust_video_duration(local_video_tensor, duration, target_fps)
70
+
71
+ if global_mode=='average':
72
+ global_frame_ids = torch.linspace(0, frame_count - 1, global_num_frames).long()
73
+
74
+ global_frames = vr.get_batch(global_frame_ids)
75
+ global_frames = torch.from_numpy(global_frames.asnumpy()).permute(0, 3, 1, 2) # [N, H, W, C] -> [N, C, H, W]
76
+
77
+ global_frames = [resize_transform(frame) for frame in global_frames]
78
+ global_video_tensor = torch.stack(global_frames)
79
+ global_video_tensor = einops.rearrange(global_video_tensor, 't c h w -> c t h w') # [T, C, H, W] -> [C, T, H, W]
80
+
81
+ assert global_video_tensor.shape[1] == global_num_frames, f"the shape of global_video_tensor is {global_video_tensor.shape}"
82
+ return local_video_tensor, global_video_tensor
83
+
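
A hypothetical usage sketch for the two readers above (the video path is a placeholder and `decord` must be installed): `video_read_local` returns frames sampled at `target_fps`, while `video_read_global` additionally returns `global_num_frames` frames sampled uniformly over the whole clip.

```python
# Placeholder path; assumes the VidMuse package and decord are installed.
from audiocraft.data.video import video_read_local, video_read_global

# 30 s clip at 2 fps, resized to 224x224 -> [C, T, H, W] = [3, 60, 224, 224]
local = video_read_local("example.mp4", seek_time=0., duration=30, target_fps=2)

# Local frames plus 32 globally (uniformly) sampled frames -> [3, 32, 224, 224]
local, global_ = video_read_global("example.mp4", duration=30, target_fps=2,
                                   global_mode='average', global_num_frames=32)
print(local.shape, global_.shape)
```
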
audiocraft/data/zip.py ADDED
@@ -0,0 +1,76 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Utility for reading some info from inside a zip file.
7
+ """
8
+
9
+ import typing
10
+ import zipfile
11
+
12
+ from dataclasses import dataclass
13
+ from functools import lru_cache
14
+ from typing_extensions import Literal
15
+
16
+
17
+ DEFAULT_SIZE = 32
18
+ MODE = Literal['r', 'w', 'x', 'a']
19
+
20
+
21
+ @dataclass(order=True)
22
+ class PathInZip:
23
+ """Hold a path of file within a zip file.
24
+
25
+ Args:
26
+ path (str): The convention is <path_to_zip>:<relative_path_inside_zip>.
27
+ Let's assume there is a zip file /some/location/foo.zip
28
+ and inside of it is a json file located at /data/file1.json,
29
+ Then we expect path = "/some/location/foo.zip:/data/file1.json".
30
+ """
31
+
32
+ INFO_PATH_SEP = ':'
33
+ zip_path: str
34
+ file_path: str
35
+
36
+ def __init__(self, path: str) -> None:
37
+ split_path = path.split(self.INFO_PATH_SEP)
38
+ assert len(split_path) == 2
39
+ self.zip_path, self.file_path = split_path
40
+
41
+ @classmethod
42
+ def from_paths(cls, zip_path: str, file_path: str):
43
+ return cls(zip_path + cls.INFO_PATH_SEP + file_path)
44
+
45
+ def __str__(self) -> str:
46
+ return self.zip_path + self.INFO_PATH_SEP + self.file_path
47
+
48
+
49
+ def _open_zip(path: str, mode: MODE = 'r'):
50
+ return zipfile.ZipFile(path, mode)
51
+
52
+
53
+ _cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip)
54
+
55
+
56
+ def set_zip_cache_size(max_size: int):
57
+ """Sets the maximal LRU caching for zip file opening.
58
+
59
+ Args:
60
+ max_size (int): the maximal LRU cache.
61
+ """
62
+ global _cached_open_zip
63
+ _cached_open_zip = lru_cache(max_size)(_open_zip)
64
+
65
+
66
+ def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
67
+ """Opens a file stored inside a zip and returns a file-like object.
68
+
69
+ Args:
70
+ path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of.
71
+ mode (str): The mode in which to open the file.
72
+ Returns:
73
+ A file-like object for PathInZip.
74
+ """
75
+ zf = _cached_open_zip(path_in_zip.zip_path)
76
+ return zf.open(path_in_zip.file_path)
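
A hypothetical usage sketch for the helpers above; the zip path and member name are placeholders.

```python
import json
from audiocraft.data.zip import PathInZip, open_file_in_zip, set_zip_cache_size

set_zip_cache_size(8)  # optional: shrink the LRU cache of open zip handles

meta_path = PathInZip("/some/location/foo.zip:data/file1.json")  # <zip>:<member>
with open_file_in_zip(meta_path) as fh:
    meta = json.load(fh)
```
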
audiocraft/environment.py ADDED
@@ -0,0 +1,176 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Provides cluster and tools configuration across clusters (slurm, dora, utilities).
9
+ """
10
+
11
+ import logging
12
+ import os
13
+ from pathlib import Path
14
+ import re
15
+ import typing as tp
16
+
17
+ import omegaconf
18
+
19
+ from .utils.cluster import _guess_cluster_type
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class AudioCraftEnvironment:
26
+ """Environment configuration for teams and clusters.
27
+
28
+ AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment
29
+ or a declared variable, and from the loaded team configuration. Additionally, the AudioCraftEnvironment
30
+ provides pointers to a reference folder resolved automatically across clusters that is shared across team members,
31
+ allowing them to share sigs or other files needed to run jobs. Finally, it provides dataset mappers to automatically
32
+ map dataset file paths to new locations across clusters, allowing the same manifest of files to be used across clusters.
33
+
34
+ The cluster type is identified automatically and base configuration file is read from config/teams.yaml.
35
+ Use the following environment variables to specify the cluster, team or configuration:
36
+
37
+ AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
38
+ cannot be inferred automatically.
39
+ AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration.
40
+ If not set, configuration is read from config/teams.yaml.
41
+ AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
42
+ Cluster configurations are shared across teams to match compute allocation,
43
+ specify your cluster configuration in the configuration file under a key mapping
44
+ your team name.
45
+ """
46
+ _instance = None
47
+ DEFAULT_TEAM = "default"
48
+
49
+ def __init__(self) -> None:
50
+ """Loads configuration."""
51
+ self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM)
52
+ cluster_type = _guess_cluster_type()
53
+ cluster = os.getenv(
54
+ "AUDIOCRAFT_CLUSTER", cluster_type.value
55
+ )
56
+ logger.info("Detecting cluster type %s", cluster_type)
57
+
58
+ self.cluster: str = cluster
59
+
60
+ config_path = os.getenv(
61
+ "AUDIOCRAFT_CONFIG",
62
+ Path(__file__)
63
+ .parent.parent.joinpath("config/teams", self.team)
64
+ .with_suffix(".yaml"),
65
+ )
66
+ self.config = omegaconf.OmegaConf.load(config_path)
67
+ self._dataset_mappers = []
68
+ cluster_config = self._get_cluster_config()
69
+ if "dataset_mappers" in cluster_config:
70
+ for pattern, repl in cluster_config["dataset_mappers"].items():
71
+ regex = re.compile(pattern)
72
+ self._dataset_mappers.append((regex, repl))
73
+
74
+ def _get_cluster_config(self) -> omegaconf.DictConfig:
75
+ assert isinstance(self.config, omegaconf.DictConfig)
76
+ return self.config[self.cluster]
77
+
78
+ @classmethod
79
+ def instance(cls):
80
+ if cls._instance is None:
81
+ cls._instance = cls()
82
+ return cls._instance
83
+
84
+ @classmethod
85
+ def reset(cls):
86
+ """Clears the environment and forces a reload on next invocation."""
87
+ cls._instance = None
88
+
89
+ @classmethod
90
+ def get_team(cls) -> str:
91
+ """Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.
92
+ If not defined, defaults to "default".
93
+ """
94
+ return cls.instance().team
95
+
96
+ @classmethod
97
+ def get_cluster(cls) -> str:
98
+ """Gets the detected cluster.
99
+ This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
100
+ """
101
+ return cls.instance().cluster
102
+
103
+ @classmethod
104
+ def get_dora_dir(cls) -> Path:
105
+ """Gets the path to the dora directory for the current team and cluster.
106
+ Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
107
+ """
108
+ cluster_config = cls.instance()._get_cluster_config()
109
+ dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
110
+ logger.warning(f"Dora directory: {dora_dir}")
111
+ return Path(dora_dir)
112
+
113
+ @classmethod
114
+ def get_reference_dir(cls) -> Path:
115
+ """Gets the path to the reference directory for the current team and cluster.
116
+ Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
117
+ """
118
+ cluster_config = cls.instance()._get_cluster_config()
119
+ return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))
120
+
121
+ @classmethod
122
+ def get_slurm_exclude(cls) -> tp.Optional[str]:
123
+ """Get the list of nodes to exclude for that cluster."""
124
+ cluster_config = cls.instance()._get_cluster_config()
125
+ return cluster_config.get("slurm_exclude")
126
+
127
+ @classmethod
128
+ def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
129
+ """Gets the requested partitions for the current team and cluster as a comma-separated string.
130
+
131
+ Args:
132
+ partition_types (list[str], optional): partition types to retrieve. Values must be
133
+ from ['global', 'team']. If not provided, the global partition is returned.
134
+ """
135
+ if not partition_types:
136
+ partition_types = ["global"]
137
+
138
+ cluster_config = cls.instance()._get_cluster_config()
139
+ partitions = [
140
+ cluster_config["partitions"][partition_type]
141
+ for partition_type in partition_types
142
+ ]
143
+ return ",".join(partitions)
144
+
145
+ @classmethod
146
+ def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
147
+ """Converts reference placeholder in path with configured reference dir to resolve paths.
148
+
149
+ Args:
150
+ path (str or Path): Path to resolve.
151
+ Returns:
152
+ Path: Resolved path.
153
+ """
154
+ path = str(path)
155
+
156
+ if path.startswith("//reference"):
157
+ reference_dir = cls.get_reference_dir()
158
+ logger.warn(f"Reference directory: {reference_dir}")
159
+ assert (
160
+ reference_dir.exists() and reference_dir.is_dir()
161
+ ), f"Reference directory does not exist: {reference_dir}."
162
+ path = re.sub("^//reference", str(reference_dir), path)
163
+
164
+ return Path(path)
165
+
166
+ @classmethod
167
+ def apply_dataset_mappers(cls, path: str) -> str:
168
+ """Applies dataset mapping regex rules as defined in the configuration.
169
+ If no rules are defined, the path is returned as-is.
170
+ """
171
+ instance = cls.instance()
172
+
173
+ for pattern, repl in instance._dataset_mappers:
174
+ path = pattern.sub(repl, path)
175
+
176
+ return path
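
A hedged usage sketch of the two resolution helpers; it assumes a valid `config/teams/<team>.yaml` is available, the cluster type can be detected (or `AUDIOCRAFT_CLUSTER` is set), and the configured reference directory exists. The paths below are placeholders.

```python
from audiocraft.environment import AudioCraftEnvironment

# "//reference/..." placeholders expand against the configured reference_dir.
ckpt = AudioCraftEnvironment.resolve_reference_path("//reference/some/checkpoint.pt")

# Manifest paths can be written once and remapped per cluster via the regex rules
# declared under `dataset_mappers` in the team configuration.
wav = AudioCraftEnvironment.apply_dataset_mappers("/datasets/audio/example.wav")
```
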
audiocraft/losses/__init__.py ADDED
@@ -0,0 +1,21 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Loss related classes and functions. In particular the loss balancer from
7
+ EnCodec, and the usual spectral losses."""
8
+
9
+ # flake8: noqa
10
+ from .balancer import Balancer
11
+ from .sisnr import SISNR
12
+ from .stftloss import (
13
+ LogSTFTMagnitudeLoss,
14
+ MRSTFTLoss,
15
+ SpectralConvergenceLoss,
16
+ STFTLoss
17
+ )
18
+ from .specloss import (
19
+ MelSpectrogramL1Loss,
20
+ MultiScaleMelSpectrogramLoss,
21
+ )
audiocraft/losses/balancer.py ADDED
@@ -0,0 +1,136 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ import flashy
10
+ import torch
11
+ from torch import autograd
12
+
13
+
14
+ class Balancer:
15
+ """Loss balancer.
16
+
17
+ The loss balancer combines losses together to compute gradients for the backward.
18
+ Given `y = f(...)`, and a number of losses `l1(y, ...)`, `l2(y, ...)`, with `...`
19
+ not having any dependence on `f`, the balancer can efficiently normalize the partial gradients
20
+ `d l1 / d y`, `d l2 / dy` before summing them in order to achieve a desired ratio between
21
+ the losses. For instance if `weights = {'l1': 2, 'l2': 1}`, 66% of the gradient
22
+ going into `f(...)` will come from `l1` on average, and 33% from `l2`. This allows for an easy
23
+ interpration of the weights even if the intrisic scale of `l1`, `l2` ... is unknown.
24
+
25
+ Noting `g1 = d l1 / dy`, etc., the balanced gradient `G` will be
26
+ (with `avg` an exponential moving average over the updates),
27
+
28
+ G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i)
29
+
30
+ If `balance_grads` is False, this is deactivated, and instead the gradient will just be the
31
+ standard sum of the partial gradients with the given weights.
32
+
33
+ A call to the backward method of the balancer will compute the the partial gradients,
34
+ combining all the losses and potentially rescaling the gradients,
35
+ which can help stabilize the training and reason about multiple losses with varying scales.
36
+ The obtained gradient with respect to `y` is then back-propagated to `f(...)`.
37
+
38
+ Expected usage:
39
+
40
+ weights = {'loss_a': 1, 'loss_b': 4}
41
+ balancer = Balancer(weights, ...)
42
+ losses: dict = {}
43
+ losses['loss_a'] = compute_loss_a(x, y)
44
+ losses['loss_b'] = compute_loss_b(x, y)
45
+ if model.training():
46
+ effective_loss = balancer.backward(losses, x)
47
+
48
+ Args:
49
+ weights (dict[str, float]): Weight coefficient for each loss. The balancer expect the losses keys
50
+ from the backward method to match the weights keys to assign weight to each of the provided loss.
51
+ balance_grads (bool): Whether to rescale gradients so that weights reflect the fraction of the
52
+ overall gradient, rather than a constant multiplier.
53
+ total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
54
+ emay_decay (float): EMA decay for averaging the norms.
55
+ per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds
56
+ when rescaling the gradients.
57
+ epsilon (float): Epsilon value for numerical stability.
58
+ monitor (bool): If True, stores in `self.metrics` the relative ratio between the norm of the gradients
59
+ coming from each loss, when calling `backward()`.
60
+ """
61
+ def __init__(self, weights: tp.Dict[str, float], balance_grads: bool = True, total_norm: float = 1.,
62
+ ema_decay: float = 0.999, per_batch_item: bool = True, epsilon: float = 1e-12,
63
+ monitor: bool = False):
64
+ self.weights = weights
65
+ self.per_batch_item = per_batch_item
66
+ self.total_norm = total_norm or 1.
67
+ self.averager = flashy.averager(ema_decay or 1.)
68
+ self.epsilon = epsilon
69
+ self.monitor = monitor
70
+ self.balance_grads = balance_grads
71
+ self._metrics: tp.Dict[str, tp.Any] = {}
72
+
73
+ @property
74
+ def metrics(self):
75
+ return self._metrics
76
+
77
+ def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor) -> torch.Tensor:
78
+ """Compute the backward and return the effective train loss, e.g. the loss obtained from
79
+ computing the effective weights. If `balance_grads` is True, the effective weights
80
+ are the one that needs to be applied to each gradient to respect the desired relative
81
+ scale of gradients coming from each loss.
82
+
83
+ Args:
84
+ losses (Dict[str, torch.Tensor]): dictionary with the same keys as `self.weights`.
85
+ input (torch.Tensor): the input of the losses, typically the output of the model.
86
+ This should be the single point of dependence between the losses
87
+ and the model being trained.
88
+ """
89
+ norms = {}
90
+ grads = {}
91
+ for name, loss in losses.items():
92
+ # Compute partial derivative of the less with respect to the input.
93
+ grad, = autograd.grad(loss, [input], retain_graph=True)
94
+ if self.per_batch_item:
95
+ # We do not average the gradient over the batch dimension.
96
+ dims = tuple(range(1, grad.dim()))
97
+ norm = grad.norm(dim=dims, p=2).mean()
98
+ else:
99
+ norm = grad.norm(p=2)
100
+ norms[name] = norm
101
+ grads[name] = grad
102
+
103
+ count = 1
104
+ if self.per_batch_item:
105
+ count = len(grad)
106
+ # Average norms across workers. Theoretically we should average the
107
+ # squared norm, then take the sqrt, but it worked fine like that.
108
+ avg_norms = flashy.distrib.average_metrics(self.averager(norms), count)
109
+ # We approximate the total norm of the gradient as the sums of the norms.
110
+ # Obviously this can be very incorrect if all gradients are aligned, but it works fine.
111
+ total = sum(avg_norms.values())
112
+
113
+ self._metrics = {}
114
+ if self.monitor:
115
+ # Store the ratio of the total gradient represented by each loss.
116
+ for k, v in avg_norms.items():
117
+ self._metrics[f'ratio_{k}'] = v / total
118
+
119
+ total_weights = sum([self.weights[k] for k in avg_norms])
120
+ assert total_weights > 0.
121
+ desired_ratios = {k: w / total_weights for k, w in self.weights.items()}
122
+
123
+ out_grad = torch.zeros_like(input)
124
+ effective_loss = torch.tensor(0., device=input.device, dtype=input.dtype)
125
+ for name, avg_norm in avg_norms.items():
126
+ if self.balance_grads:
127
+ # g_balanced = g / avg(||g||) * total_norm * desired_ratio
128
+ scale = desired_ratios[name] * self.total_norm / (self.epsilon + avg_norm)
129
+ else:
130
+ # We just do regular weighted sum of the gradients.
131
+ scale = self.weights[name]
132
+ out_grad.add_(grads[name], alpha=scale)
133
+ effective_loss += scale * losses[name].detach()
134
+ # Send the computed partial derivative with respect to the output of the model to the model.
135
+ input.backward(out_grad)
136
+ return effective_loss
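
The core of the balancing rule, stripped of the EMA and distributed averaging, is sketched below: each per-loss gradient is rescaled so that its norm contributes `w_i / sum(w)` of a reference `total_norm`. This is an illustrative simplification, not a drop-in replacement for the class above.

```python
import torch

x = torch.randn(4, 8, requires_grad=True)
y = x * 2  # stand-in for the model output `y = f(...)`

losses = {'l1': y.abs().mean(), 'l2': (y ** 2).mean()}
weights = {'l1': 2.0, 'l2': 1.0}
total_norm, eps = 1.0, 1e-12

grads = {k: torch.autograd.grad(l, [y], retain_graph=True)[0] for k, l in losses.items()}
norms = {k: g.norm(p=2) for k, g in grads.items()}
ratios = {k: w / sum(weights.values()) for k, w in weights.items()}

# G = sum_i total_norm * g_i / ||g_i|| * w_i / sum(w), as in the class docstring.
out_grad = sum(ratios[k] * total_norm / (eps + norms[k]) * grads[k] for k in grads)
y.backward(out_grad)  # x.grad now mixes the losses at a 2:1 gradient-norm ratio
```
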
audiocraft/losses/sisnr.py ADDED
@@ -0,0 +1,97 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import typing as tp
9
+
10
+ import torch
11
+ from torch import nn
12
+ from torch.nn import functional as F
13
+
14
+
15
+ def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
16
+ """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
17
+ with K the kernel size, by extracting frames with the given stride.
18
+ This will pad the input so that `F = ceil(T / K)`.
19
+ see https://github.com/pytorch/pytorch/issues/60466
20
+ """
21
+ *shape, length = a.shape
22
+ n_frames = math.ceil(length / stride)
23
+ tgt_length = (n_frames - 1) * stride + kernel_size
24
+ a = F.pad(a, (0, tgt_length - length))
25
+ strides = list(a.stride())
26
+ assert strides[-1] == 1, "data should be contiguous"
27
+ strides = strides[:-1] + [stride, 1]
28
+ return a.as_strided([*shape, n_frames, kernel_size], strides)
29
+
30
+
31
+ def _center(x: torch.Tensor) -> torch.Tensor:
32
+ return x - x.mean(-1, True)
33
+
34
+
35
+ def _norm2(x: torch.Tensor) -> torch.Tensor:
36
+ return x.pow(2).sum(-1, True)
37
+
38
+
39
+ class SISNR(nn.Module):
40
+ """SISNR loss.
41
+
42
+ Input should be [B, C, T], output is scalar.
43
+
44
+ ..Warning:: This function returns the opposite of the SI-SNR (e.g. `-1 * regular_SI_SNR`).
45
+ Consequently, lower scores are better in terms of reconstruction quality,
46
+ in particular, it should be negative if training goes well. This done this way so
47
+ that this module can also be used as a loss function for training model.
48
+
49
+ Args:
50
+ sample_rate (int): Sample rate.
51
+ segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on
52
+ entire audio only.
53
+ overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap.
54
+ epsilon (float): Epsilon value for numerical stability.
55
+ """
56
+ def __init__(
57
+ self,
58
+ sample_rate: int = 16000,
59
+ segment: tp.Optional[float] = 20,
60
+ overlap: float = 0.5,
61
+ epsilon: float = torch.finfo(torch.float32).eps,
62
+ ):
63
+ super().__init__()
64
+ self.sample_rate = sample_rate
65
+ self.segment = segment
66
+ self.overlap = overlap
67
+ self.epsilon = epsilon
68
+
69
+ def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
70
+ B, C, T = ref_sig.shape
71
+ assert ref_sig.shape == out_sig.shape
72
+
73
+ if self.segment is None:
74
+ frame = T
75
+ stride = T
76
+ else:
77
+ frame = int(self.segment * self.sample_rate)
78
+ stride = int(frame * (1 - self.overlap))
79
+
80
+ epsilon = self.epsilon * frame # make epsilon prop to frame size.
81
+
82
+ gt = _unfold(ref_sig, frame, stride)
83
+ est = _unfold(out_sig, frame, stride)
84
+ if self.segment is None:
85
+ assert gt.shape[-1] == 1
86
+
87
+ gt = _center(gt)
88
+ est = _center(est)
89
+ dot = torch.einsum("bcft,bcft->bcf", gt, est)
90
+
91
+ proj = dot[:, :, :, None] * gt / (epsilon + _norm2(gt))
92
+ noise = est - proj
93
+
94
+ sisnr = 10 * (
95
+ torch.log10(epsilon + _norm2(proj)) - torch.log10(epsilon + _norm2(noise))
96
+ )
97
+ return -1 * sisnr[..., 0].mean()
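
For reference, the projection in `forward` implements the standard SI-SNR definition. A minimal sketch for a single pair of 1-D signals (no chunking, illustrative only):

```python
import torch

def si_snr(est: torch.Tensor, ref: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    est, ref = est - est.mean(), ref - ref.mean()
    # s_target = <est, ref> / ||ref||^2 * ref, the part of `est` aligned with `ref`.
    proj = (est * ref).sum() * ref / (ref.pow(2).sum() + eps)
    noise = est - proj
    return 10 * torch.log10((proj.pow(2).sum() + eps) / (noise.pow(2).sum() + eps))

ref = torch.randn(16000)
est = ref + 0.1 * torch.randn(16000)
print(si_snr(est, ref))  # higher is better; the SISNR module returns the negated value
```
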
audiocraft/losses/specloss.py ADDED
@@ -0,0 +1,149 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ import numpy as np
10
+ from torchaudio.transforms import MelSpectrogram
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+
15
+ from ..modules import pad_for_conv1d
16
+
17
+
18
+ class MelSpectrogramWrapper(nn.Module):
19
+ """Wrapper around MelSpectrogram torchaudio transform providing proper padding
20
+ and additional post-processing including log scaling.
21
+
22
+ Args:
24
+ n_fft (int): Number of fft.
25
+ hop_length (int): Hop size.
26
+ win_length (int): Window length.
27
+ n_mels (int): Number of mel bins.
28
+ sample_rate (int): Sample rate.
29
+ f_min (float or None): Minimum frequency.
30
+ f_max (float or None): Maximum frequency.
31
+ log (bool): Whether to scale with log.
32
+ normalized (bool): Whether to normalize the melspectrogram.
33
+ floor_level (float): Floor level based on human perception (default=1e-5).
34
+ """
35
+ def __init__(self, n_fft: int = 1024, hop_length: int = 256, win_length: tp.Optional[int] = None,
36
+ n_mels: int = 80, sample_rate: float = 22050, f_min: float = 0.0, f_max: tp.Optional[float] = None,
37
+ log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
38
+ super().__init__()
39
+ self.n_fft = n_fft
40
+ hop_length = int(hop_length)
41
+ self.hop_length = hop_length
42
+ self.mel_transform = MelSpectrogram(n_mels=n_mels, sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
43
+ win_length=win_length, f_min=f_min, f_max=f_max, normalized=normalized,
44
+ window_fn=torch.hann_window, center=False)
45
+ self.floor_level = floor_level
46
+ self.log = log
47
+
48
+ def forward(self, x):
49
+ p = int((self.n_fft - self.hop_length) // 2)
50
+ if len(x.shape) == 2:
51
+ x = x.unsqueeze(1)
52
+ x = F.pad(x, (p, p), "reflect")
53
+ # Make sure that all the frames are full.
54
+ # The combination of `pad_for_conv1d` and the above padding
55
+ # will make the output of size ceil(T / hop).
56
+ x = pad_for_conv1d(x, self.n_fft, self.hop_length)
57
+ self.mel_transform.to(x.device)
58
+ mel_spec = self.mel_transform(x)
59
+ B, C, freqs, frame = mel_spec.shape
60
+ if self.log:
61
+ mel_spec = torch.log10(self.floor_level + mel_spec)
62
+ return mel_spec.reshape(B, C * freqs, frame)
63
+
64
+
65
+ class MelSpectrogramL1Loss(torch.nn.Module):
66
+ """L1 Loss on MelSpectrogram.
67
+
68
+ Args:
69
+ sample_rate (int): Sample rate.
70
+ n_fft (int): Number of fft.
71
+ hop_length (int): Hop size.
72
+ win_length (int): Window length.
73
+ n_mels (int): Number of mel bins.
74
+ f_min (float or None): Minimum frequency.
75
+ f_max (float or None): Maximum frequency.
76
+ log (bool): Whether to scale with log.
77
+ normalized (bool): Whether to normalize the melspectrogram.
78
+ floor_level (float): Floor level value based on human perception (default=1e-5).
79
+ """
80
+ def __init__(self, sample_rate: int, n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024,
81
+ n_mels: int = 80, f_min: float = 0.0, f_max: tp.Optional[float] = None,
82
+ log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
83
+ super().__init__()
84
+ self.l1 = torch.nn.L1Loss()
85
+ self.melspec = MelSpectrogramWrapper(n_fft=n_fft, hop_length=hop_length, win_length=win_length,
86
+ n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
87
+ log=log, normalized=normalized, floor_level=floor_level)
88
+
89
+ def forward(self, x, y):
90
+ self.melspec.to(x.device)
91
+ s_x = self.melspec(x)
92
+ s_y = self.melspec(y)
93
+ return self.l1(s_x, s_y)
94
+
95
+
96
+ class MultiScaleMelSpectrogramLoss(nn.Module):
97
+ """Multi-Scale spectrogram loss (msspec).
98
+
99
+ Args:
100
+ sample_rate (int): Sample rate.
101
+ range_start (int): Power of 2 to use for the first scale.
102
+ range_stop (int): Power of 2 to use for the last scale.
103
+ n_mels (int): Number of mel bins.
104
+ f_min (float): Minimum frequency.
105
+ f_max (float or None): Maximum frequency.
106
+ normalized (bool): Whether to normalize the melspectrogram.
107
+ alphas (bool): Whether to use alphas as coefficients or not.
108
+ floor_level (float): Floor level value based on human perception (default=1e-5).
109
+ """
110
+ def __init__(self, sample_rate: int, range_start: int = 6, range_end: int = 11,
111
+ n_mels: int = 64, f_min: float = 0.0, f_max: tp.Optional[float] = None,
112
+ normalized: bool = False, alphas: bool = True, floor_level: float = 1e-5):
113
+ super().__init__()
114
+ l1s = list()
115
+ l2s = list()
116
+ self.alphas = list()
117
+ self.total = 0
118
+ self.normalized = normalized
119
+ for i in range(range_start, range_end):
120
+ l1s.append(
121
+ MelSpectrogramWrapper(n_fft=2 ** i, hop_length=(2 ** i) / 4, win_length=2 ** i,
122
+ n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
123
+ log=False, normalized=normalized, floor_level=floor_level))
124
+ l2s.append(
125
+ MelSpectrogramWrapper(n_fft=2 ** i, hop_length=(2 ** i) / 4, win_length=2 ** i,
126
+ n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
127
+ log=True, normalized=normalized, floor_level=floor_level))
128
+ if alphas:
129
+ self.alphas.append(np.sqrt(2 ** i - 1))
130
+ else:
131
+ self.alphas.append(1)
132
+ self.total += self.alphas[-1] + 1
133
+
134
+ self.l1s = nn.ModuleList(l1s)
135
+ self.l2s = nn.ModuleList(l2s)
136
+
137
+ def forward(self, x, y):
138
+ loss = 0.0
139
+ self.l1s.to(x.device)
140
+ self.l2s.to(x.device)
141
+ for i in range(len(self.alphas)):
142
+ s_x_1 = self.l1s[i](x)
143
+ s_y_1 = self.l1s[i](y)
144
+ s_x_2 = self.l2s[i](x)
145
+ s_y_2 = self.l2s[i](y)
146
+ loss += F.l1_loss(s_x_1, s_y_1) + self.alphas[i] * F.mse_loss(s_x_2, s_y_2)
147
+ if self.normalized:
148
+ loss = loss / self.total
149
+ return loss
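
A simplified sketch of the quantity `MelSpectrogramL1Loss` compares, written against torchaudio directly and ignoring the exact padding that `pad_for_conv1d` handles in the wrapper above:

```python
import torch
import torch.nn.functional as F
from torchaudio.transforms import MelSpectrogram

sample_rate, floor_level = 22050, 1e-5
mel = MelSpectrogram(sample_rate=sample_rate, n_fft=1024, hop_length=256, n_mels=80)

x = torch.randn(2, 1, sample_rate)  # [B, C, T], one second of audio
y = torch.randn(2, 1, sample_rate)

s_x = torch.log10(floor_level + mel(x))
s_y = torch.log10(floor_level + mel(y))
print(F.l1_loss(s_x, s_y))
```
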
audiocraft/losses/stftloss.py ADDED
@@ -0,0 +1,207 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Adapted from MIT code under the original license
7
+ # Copyright 2019 Tomoki Hayashi
8
+ # MIT License (https://opensource.org/licenses/MIT)
9
+ import typing as tp
10
+
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+
15
+
16
+ # TODO: Replace with torchaudio.STFT?
17
+ def _stft(x: torch.Tensor, fft_size: int, hop_length: int, win_length: int,
18
+ window: tp.Optional[torch.Tensor], normalized: bool) -> torch.Tensor:
19
+ """Perform STFT and convert to magnitude spectrogram.
20
+
21
+ Args:
22
+ x: Input signal tensor (B, C, T).
23
+ fft_size (int): FFT size.
24
+ hop_length (int): Hop size.
25
+ win_length (int): Window length.
26
+ window (torch.Tensor or None): Window function type.
27
+ normalized (bool): Whether to normalize the STFT or not.
28
+
29
+ Returns:
30
+ torch.Tensor: Magnitude spectrogram (B, C, #frames, fft_size // 2 + 1).
31
+ """
32
+ B, C, T = x.shape
33
+ x_stft = torch.stft(
34
+ x.view(-1, T), fft_size, hop_length, win_length, window,
35
+ normalized=normalized, return_complex=True,
36
+ )
37
+ x_stft = x_stft.view(B, C, *x_stft.shape[1:])
38
+ real = x_stft.real
39
+ imag = x_stft.imag
40
+
41
+ # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
42
+ return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
43
+
44
+
45
+ class SpectralConvergenceLoss(nn.Module):
46
+ """Spectral convergence loss.
47
+ """
48
+ def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
49
+ super().__init__()
50
+ self.epsilon = epsilon
51
+
52
+ def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
53
+ """Calculate forward propagation.
54
+
55
+ Args:
56
+ x_mag: Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
57
+ y_mag: Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
58
+ Returns:
59
+ torch.Tensor: Spectral convergence loss value.
60
+ """
61
+ return torch.norm(y_mag - x_mag, p="fro") / (torch.norm(y_mag, p="fro") + self.epsilon)
62
+
63
+
64
+ class LogSTFTMagnitudeLoss(nn.Module):
65
+ """Log STFT magnitude loss.
66
+
67
+ Args:
68
+ epsilon (float): Epsilon value for numerical stability.
69
+ """
70
+ def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
71
+ super().__init__()
72
+ self.epsilon = epsilon
73
+
74
+ def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
75
+ """Calculate forward propagation.
76
+
77
+ Args:
78
+ x_mag (torch.Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
79
+ y_mag (torch.Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
80
+ Returns:
81
+ torch.Tensor: Log STFT magnitude loss value.
82
+ """
83
+ return F.l1_loss(torch.log(self.epsilon + y_mag), torch.log(self.epsilon + x_mag))
84
+
85
+
86
+ class STFTLosses(nn.Module):
87
+ """STFT losses.
88
+
89
+ Args:
90
+ n_fft (int): Size of FFT.
91
+ hop_length (int): Hop length.
92
+ win_length (int): Window length.
93
+ window (str): Window function type.
94
+ normalized (bool): Whether to use normalized STFT or not.
95
+ epsilon (float): Epsilon for numerical stability.
96
+ """
97
+ def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600,
98
+ window: str = "hann_window", normalized: bool = False,
99
+ epsilon: float = torch.finfo(torch.float32).eps):
100
+ super().__init__()
101
+ self.n_fft = n_fft
102
+ self.hop_length = hop_length
103
+ self.win_length = win_length
104
+ self.normalized = normalized
105
+ self.register_buffer("window", getattr(torch, window)(win_length))
106
+ self.spectral_convergence_loss = SpectralConvergenceLoss(epsilon)
107
+ self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss(epsilon)
108
+
109
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
110
+ """Calculate forward propagation.
111
+
112
+ Args:
113
+ x (torch.Tensor): Predicted signal (B, T).
114
+ y (torch.Tensor): Groundtruth signal (B, T).
115
+ Returns:
116
+ torch.Tensor: Spectral convergence loss value.
117
+ torch.Tensor: Log STFT magnitude loss value.
118
+ """
119
+ x_mag = _stft(x, self.n_fft, self.hop_length,
120
+ self.win_length, self.window, self.normalized) # type: ignore
121
+ y_mag = _stft(y, self.n_fft, self.hop_length,
122
+ self.win_length, self.window, self.normalized) # type: ignore
123
+ sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
124
+ mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
125
+
126
+ return sc_loss, mag_loss
127
+
128
+
129
+ class STFTLoss(nn.Module):
130
+ """Single Resolution STFT loss.
131
+
132
+ Args:
133
+ n_fft (int): Size of FFT.
134
+ hop_length (int): Hop length.
135
+ win_length (int): Window length.
136
+ window (str): Window function type.
137
+ normalized (bool): Whether to use normalized STFT or not.
138
+ epsilon (float): Epsilon for numerical stability.
139
+ factor_sc (float): Coefficient for the spectral loss.
140
+ factor_mag (float): Coefficient for the magnitude loss.
141
+ """
142
+ def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600,
143
+ window: str = "hann_window", normalized: bool = False,
144
+ factor_sc: float = 0.1, factor_mag: float = 0.1,
145
+ epsilon: float = torch.finfo(torch.float32).eps):
146
+ super().__init__()
147
+ self.loss = STFTLosses(n_fft, hop_length, win_length, window, normalized, epsilon)
148
+ self.factor_sc = factor_sc
149
+ self.factor_mag = factor_mag
150
+
151
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
152
+ """Calculate forward propagation.
153
+
154
+ Args:
155
+ x (torch.Tensor): Predicted signal (B, T).
156
+ y (torch.Tensor): Groundtruth signal (B, T).
157
+ Returns:
158
+ torch.Tensor: Single resolution STFT loss.
159
+ """
160
+ sc_loss, mag_loss = self.loss(x, y)
161
+ return self.factor_sc * sc_loss + self.factor_mag * mag_loss
162
+
163
+
164
+ class MRSTFTLoss(nn.Module):
165
+ """Multi resolution STFT loss.
166
+
167
+ Args:
168
+ n_ffts (Sequence[int]): Sequence of FFT sizes.
169
+ hop_lengths (Sequence[int]): Sequence of hop sizes.
170
+ win_lengths (Sequence[int]): Sequence of window lengths.
171
+ window (str): Window function type.
172
+ factor_sc (float): Coefficient for the spectral loss.
173
+ factor_mag (float): Coefficient for the magnitude loss.
174
+ normalized (bool): Whether to use normalized STFT or not.
175
+ epsilon (float): Epsilon for numerical stability.
176
+ """
177
+ def __init__(self, n_ffts: tp.Sequence[int] = [1024, 2048, 512], hop_lengths: tp.Sequence[int] = [120, 240, 50],
178
+ win_lengths: tp.Sequence[int] = [600, 1200, 240], window: str = "hann_window",
179
+ factor_sc: float = 0.1, factor_mag: float = 0.1,
180
+ normalized: bool = False, epsilon: float = torch.finfo(torch.float32).eps):
181
+ super().__init__()
182
+ assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
183
+ self.stft_losses = torch.nn.ModuleList()
184
+ for fs, ss, wl in zip(n_ffts, hop_lengths, win_lengths):
185
+ self.stft_losses += [STFTLosses(fs, ss, wl, window, normalized, epsilon)]
186
+ self.factor_sc = factor_sc
187
+ self.factor_mag = factor_mag
188
+
189
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
190
+ """Calculate forward propagation.
191
+
192
+ Args:
193
+ x (torch.Tensor): Predicted signal (B, T).
194
+ y (torch.Tensor): Groundtruth signal (B, T).
195
+ Returns:
196
+ torch.Tensor: Multi resolution STFT loss.
197
+ """
198
+ sc_loss = torch.Tensor([0.0])
199
+ mag_loss = torch.Tensor([0.0])
200
+ for f in self.stft_losses:
201
+ sc_l, mag_l = f(x, y)
202
+ sc_loss += sc_l
203
+ mag_loss += mag_l
204
+ sc_loss /= len(self.stft_losses)
205
+ mag_loss /= len(self.stft_losses)
206
+
207
+ return self.factor_sc * sc_loss + self.factor_mag * mag_loss
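
The two terms combined above can be written out directly from `torch.stft` magnitudes. A minimal single-resolution sketch, mirroring the formulas in `SpectralConvergenceLoss` and `LogSTFTMagnitudeLoss`:

```python
import torch
import torch.nn.functional as F

def stft_mag(x: torch.Tensor, n_fft: int = 1024, hop: int = 120, win: int = 600) -> torch.Tensor:
    window = torch.hann_window(win)
    spec = torch.stft(x, n_fft, hop, win, window, return_complex=True)
    return torch.sqrt(torch.clamp(spec.real ** 2 + spec.imag ** 2, min=1e-7))

x, y = torch.randn(2, 16000), torch.randn(2, 16000)  # predicted / ground truth, [B, T]
x_mag, y_mag = stft_mag(x), stft_mag(y)

eps = torch.finfo(torch.float32).eps
sc_loss = torch.norm(y_mag - x_mag, p="fro") / (torch.norm(y_mag, p="fro") + eps)
mag_loss = F.l1_loss(torch.log(eps + y_mag), torch.log(eps + x_mag))
print(0.1 * sc_loss + 0.1 * mag_loss)  # default factor_sc = factor_mag = 0.1
```
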
audiocraft/metrics/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Metrics like CLAP score, FAD, KLD, Visqol, Chroma similarity, etc.
7
+ """
8
+ # flake8: noqa
9
+ from .clap_consistency import CLAPTextConsistencyMetric, TextConsistencyMetric
10
+ from .chroma_cosinesim import ChromaCosineSimilarityMetric
11
+ from .fad import FrechetAudioDistanceMetric
12
+ from .kld import KLDivergenceMetric, PasstKLDivergenceMetric
13
+ from .rvm import RelativeVolumeMel
14
+ from .visqol import ViSQOL
audiocraft/metrics/chroma_cosinesim.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torchmetrics
9
+
10
+ from ..data.audio_utils import convert_audio
11
+ from ..modules.chroma import ChromaExtractor
12
+
13
+
14
+ class ChromaCosineSimilarityMetric(torchmetrics.Metric):
15
+ """Chroma cosine similarity metric.
16
+
17
+ This metric extracts a chromagram for a reference waveform and
18
+ a generated waveform and compares each frame using the cosine similarity
19
+ function. The output is the mean cosine similarity.
20
+
21
+ Args:
22
+ sample_rate (int): Sample rate used by the chroma extractor.
23
+ n_chroma (int): Number of chroma used by the chroma extractor.
24
+ radix2_exp (int): Exponent for the chroma extractor.
25
+ argmax (bool): Whether the chroma extractor uses argmax.
26
+ eps (float): Epsilon for cosine similarity computation.
27
+ """
28
+ def __init__(self, sample_rate: int, n_chroma: int, radix2_exp: int, argmax: bool, eps: float = 1e-8):
29
+ super().__init__()
30
+ self.chroma_sample_rate = sample_rate
31
+ self.n_chroma = n_chroma
32
+ self.eps = eps
33
+ self.chroma_extractor = ChromaExtractor(sample_rate=self.chroma_sample_rate, n_chroma=self.n_chroma,
34
+ radix2_exp=radix2_exp, argmax=argmax)
35
+ self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
36
+ self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
37
+
38
+ def update(self, preds: torch.Tensor, targets: torch.Tensor,
39
+ sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
40
+ """Compute cosine similarity between chromagrams and accumulate scores over the dataset."""
41
+ if preds.size(0) == 0:
42
+ return
43
+
44
+ assert preds.shape == targets.shape, (
45
+ f"Preds and target shapes mismatch: preds={preds.shape}, targets={targets.shape}")
46
+ assert preds.size(0) == sizes.size(0), (
47
+ f"Number of items in preds ({preds.shape}) mismatch ",
48
+ f"with sizes ({sizes.shape})")
49
+ assert preds.size(0) == sample_rates.size(0), (
50
+ f"Number of items in preds ({preds.shape}) mismatch ",
51
+ f"with sample_rates ({sample_rates.shape})")
52
+ assert torch.all(sample_rates == sample_rates[0].item()), "All sample rates are not the same in the batch"
53
+
54
+ device = self.weight.device
55
+ preds, targets = preds.to(device), targets.to(device) # type: ignore
56
+ sample_rate = sample_rates[0].item()
57
+ preds = convert_audio(preds, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
58
+ targets = convert_audio(targets, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
59
+ gt_chroma = self.chroma_extractor(targets)
60
+ gen_chroma = self.chroma_extractor(preds)
61
+ chroma_lens = (sizes / self.chroma_extractor.winhop).ceil().int()
62
+ for i in range(len(gt_chroma)):
63
+ t = int(chroma_lens[i].item())
64
+ cosine_sim = torch.nn.functional.cosine_similarity(
65
+ gt_chroma[i, :t], gen_chroma[i, :t], dim=1, eps=self.eps)
66
+ self.cosine_sum += cosine_sim.sum(dim=0) # type: ignore
67
+ self.weight += torch.tensor(t) # type: ignore
68
+
69
+ def compute(self) -> float:
70
+ """Computes the average cosine similarty across all generated/target chromagrams pairs."""
71
+ assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0" # type: ignore
72
+ return (self.cosine_sum / self.weight).item() # type: ignore
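
The frame-wise accumulation in `update`/`compute` reduces to the following, shown here on dummy chroma tensors instead of the actual `ChromaExtractor` output (illustrative only):

```python
import torch
import torch.nn.functional as F

gt_chroma = torch.rand(4, 120, 12)   # [B, frames, n_chroma], references
gen_chroma = torch.rand(4, 120, 12)  # same shape, generations
valid_lens = torch.tensor([120, 100, 120, 80])  # frames actually covered by audio

cosine_sum, weight = torch.tensor(0.), torch.tensor(0.)
for i in range(len(gt_chroma)):
    t = int(valid_lens[i])
    sim = F.cosine_similarity(gt_chroma[i, :t], gen_chroma[i, :t], dim=1, eps=1e-8)
    cosine_sum += sim.sum()
    weight += t

print((cosine_sum / weight).item())  # mean cosine similarity over all valid frames
```
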
audiocraft/metrics/clap_consistency.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from pathlib import Path
8
+ import typing as tp
9
+
10
+ import torch
11
+ import torchmetrics
12
+ from transformers import RobertaTokenizer # type: ignore
13
+
14
+ from ..data.audio_utils import convert_audio
15
+ from ..environment import AudioCraftEnvironment
16
+ from ..utils.utils import load_clap_state_dict
17
+
18
+ try:
19
+ import laion_clap # type: ignore
20
+ except ImportError:
21
+ laion_clap = None
22
+
23
+
24
+ class TextConsistencyMetric(torchmetrics.Metric):
25
+ """Text consistency metric measuring consistency between audio and text pairs."""
26
+
27
+ def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
28
+ raise NotImplementedError("implement how to update the metric from the audio and text pairs.")
29
+
30
+ def compute(self):
31
+ raise NotImplementedError("implement how to compute the final metric score.")
32
+
33
+
34
+ class CLAPTextConsistencyMetric(TextConsistencyMetric):
35
+ """Text consistency metric relying on Contrastive Language-Audio Pretraining (CLAP).
36
+
37
+ This metric is similar to the MuLan Cycle Consistency from MusicLM (https://arxiv.org/pdf/2301.11325.pdf)
38
+ or the CLAP score used in Make-An-Audio (https://arxiv.org/pdf/2301.12661v1.pdf).
39
+
40
+ As a joint audio-text embedding model, a pretrained CLAP model can be used to quantify the
41
+ similarity between audio-text pairs. We compute the CLAP embeddings from the text descriptions as
42
+ well as the generated audio based on them, and define the MCC metric as the average cosine similarity
43
+ between these embeddings.
44
+
45
+ Model implementation & pre-trained checkpoints: https://github.com/LAION-AI/CLAP
46
+ """
47
+ def __init__(self, model_path: tp.Union[str, Path], model_arch: str = 'HTSAT-tiny', enable_fusion: bool = False):
48
+ super().__init__()
49
+ if laion_clap is None:
50
+ raise ImportError("Please install CLAP to compute text consistency: 'pip install laion_clap'")
51
+ self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
52
+ self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
53
+ self._initialize_model(model_path, model_arch, enable_fusion)
54
+
55
+ def _initialize_model(self, model_path: tp.Union[str, Path], model_arch: str, enable_fusion: bool):
56
+ model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
57
+ self.tokenize = RobertaTokenizer.from_pretrained('roberta-base')
58
+ self.model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
59
+ self.model_sample_rate = 48_000
60
+ load_clap_state_dict(self.model, model_path)
61
+ self.model.eval()
62
+
63
+ def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
64
+ # we use the default params from CLAP module here as well
65
+ return self.tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
66
+
67
+ def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
68
+ """Compute cosine similarity between audio and text pairs and accumulate scores over the dataset."""
69
+ assert audio.size(0) == len(text), "Number of audio and text samples should match"
70
+ assert torch.all(sample_rates == sample_rates[0].item()), "All items in batch should have the same sample rate"
71
+ sample_rate = int(sample_rates[0].item())
72
+ # convert audio batch to 48kHz monophonic audio with no channel dimension: [B, C, T] -> [B, T]
73
+ audio = convert_audio(audio, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1).mean(dim=1)
74
+ audio_embeddings = self.model.get_audio_embedding_from_data(audio, use_tensor=True)
75
+ text_embeddings = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
76
+ # cosine similarity between the text and the audio embedding
77
+ cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_embeddings, dim=1, eps=1e-8)
78
+ self.cosine_sum += cosine_sim.sum(dim=0)
79
+ self.weight += torch.tensor(cosine_sim.size(0))
80
+
81
+ def compute(self):
82
+ """Computes the average cosine similarty across all audio/text pairs."""
83
+ assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0" # type: ignore
84
+ return (self.cosine_sum / self.weight).item() # type: ignore
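
Once the CLAP embeddings are computed, the score itself is just a batched cosine similarity. The sketch below uses random stand-in embeddings, since `get_audio_embedding_from_data` / `get_text_embedding` require a CLAP checkpoint:

```python
import torch
import torch.nn.functional as F

audio_embeddings = F.normalize(torch.randn(8, 512), dim=-1)  # stand-in for CLAP audio embeddings
text_embeddings = F.normalize(torch.randn(8, 512), dim=-1)   # stand-in for CLAP text embeddings

cosine_sim = F.cosine_similarity(audio_embeddings, text_embeddings, dim=1, eps=1e-8)
print(cosine_sim.mean().item())  # average consistency score over the batch
```
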
audiocraft/metrics/fad.py ADDED
@@ -0,0 +1,329 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ from pathlib import Path
9
+ import os
10
+ import subprocess
11
+ import tempfile
12
+ import typing as tp
13
+
14
+ from audiocraft.data.audio import audio_write
15
+ from audiocraft.data.audio_utils import convert_audio
16
+ import flashy
17
+ import torch
18
+ import torchmetrics
19
+
20
+ from ..environment import AudioCraftEnvironment
21
+
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+ VGGISH_SAMPLE_RATE = 16_000
26
+ VGGISH_CHANNELS = 1
27
+
28
+
29
+ class FrechetAudioDistanceMetric(torchmetrics.Metric):
30
+ """Fréchet Audio Distance computation based on official TensorFlow implementation from Google Research.
31
+
32
+ From: D.C. Dowson & B.V. Landau The Fréchet distance between
33
+ multivariate normal distributions
34
+ https://doi.org/10.1016/0047-259X(82)90077-X
35
+ The Fréchet distance between two multivariate gaussians,
36
+ `X ~ N(mu_x, sigma_x)` and `Y ~ N(mu_y, sigma_y)`, is `d^2`.
37
+ d^2 = (mu_x - mu_y)^2 + Tr(sigma_x + sigma_y - 2 * sqrt(sigma_x*sigma_y))
38
+ = (mu_x - mu_y)^2 + Tr(sigma_x) + Tr(sigma_y)
39
+ - 2 * Tr(sqrt(sigma_x*sigma_y)))
40
+
41
+ To use this FAD computation metric, you need to have the proper Frechet Audio Distance tool setup
42
+ from: https://github.com/google-research/google-research/tree/master/frechet_audio_distance
43
+ We provide the instructions below as a reference, but we do not guarantee further support
45
+ for the frechet_audio_distance installation. This was tested with python 3.10, cuda 11.8, tensorflow 2.12.0.
45
+
46
+ We recommend installing the frechet_audio_distance library in a dedicated env (e.g. conda).
47
+
48
+ 1. Get the code and models following the repository instructions. We used the steps below:
49
+ git clone git@github.com:google-research/google-research.git
50
+ git clone git@github.com:tensorflow/models.git
51
+ mkdir google-research/tensorflow_models
52
+ touch google-research/tensorflow_models/__init__.py
53
+ cp -r models/research/audioset google-research/tensorflow_models/
54
+ touch google-research/tensorflow_models/audioset/__init__.py
55
+ echo "from .vggish import mel_features, vggish_params, vggish_slim" > \
56
+ google-research/tensorflow_models/audioset/__init__.py
57
+ # we can now remove the tensorflow models repository
58
+ # rm -r models
59
+ cd google-research
60
+ Follow the instructions to download the vggish checkpoint. AudioCraft base configuration
61
+ assumes it is placed in the AudioCraft reference dir.
62
+
63
+ Note that we made the following changes for the code to work with TensorFlow 2.X and python 3:
64
+ - Update xrange for range in:
65
+ https://github.com/google-research/google-research/blob/master/frechet_audio_distance/audioset_model.py
66
+ - Update `tf_record = tf.python_io.tf_record_iterator(filename).next()` to
67
+ `tf_record = tf.python_io.tf_record_iterator(filename).__next__()` in
68
+ https://github.com/google-research/google-research/blob/master/frechet_audio_distance/fad_utils.py
69
+ - Update `import vggish_params as params` to `from . import vggish_params as params` in:
70
+ https://github.com/tensorflow/models/blob/master/research/audioset/vggish/vggish_slim.py
71
+ - Add flag to provide a given batch size for running the AudioSet model in:
72
+ https://github.com/google-research/google-research/blob/master/frechet_audio_distance/create_embeddings_main.py
73
+ ```
74
+ flags.DEFINE_integer('batch_size', 64,
75
+ 'Number of samples in the batch for AudioSet model.')
76
+ ```
77
+ Ensure you pass the flag to the create_embeddings_beam.create_pipeline function, adding:
78
+ `batch_size=FLAGS.batch_size` to the provided parameters.
79
+
80
+ 2. Follow instructions for the library installation and a valid TensorFlow installation
81
+ ```
82
+ # e.g. instructions from: https://www.tensorflow.org/install/pip
83
+ conda install -c conda-forge cudatoolkit=11.8.0
84
+ python3 -m pip install nvidia-cudnn-cu11==8.6.0.163 tensorflow==2.12.*
85
+ mkdir -p $CONDA_PREFIX/etc/conda/activate.d
86
+ echo 'CUDNN_PATH=$(dirname $(python -c "import nvidia.cudnn;print(nvidia.cudnn.__file__)"))' \
87
+ >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
88
+ echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/:$CUDNN_PATH/lib' \
89
+ >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
90
+ source $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
91
+ # Verify install: on a machine with GPU device
92
+ python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"
93
+ ```
94
+
95
+ Now install frechet_audio_distance required dependencies:
96
+ ```
97
+ # We assume we already have TensorFlow installed from the above steps
98
+ pip install apache-beam numpy scipy tf_slim
99
+ ```
100
+
101
+ Finally, follow the remaining library instructions to ensure you have a working frechet_audio_distance setup
102
+ (you may want to specify the --model_ckpt flag pointing to the model's path).
103
+
104
+ 3. AudioCraft's FrechetAudioDistanceMetric requires 2 environment variables pointing to the python executable
105
+ and Tensorflow library path from the above installation steps:
106
+ export TF_PYTHON_EXE="<PATH_TO_THE_ENV_PYTHON_BINARY>"
107
+ export TF_LIBRARY_PATH="<PATH_TO_THE_ENV_CUDNN_LIBRARY>"
108
+
109
+ e.g. assuming we have installed everything in a dedicated conda env
110
+ with python 3.10 that is currently active:
111
+ export TF_PYTHON_EXE="$CONDA_PREFIX/bin/python"
112
+ export TF_LIBRARY_PATH="$CONDA_PREFIX/lib/python3.10/site-packages/nvidia/cudnn/lib"
113
+
114
+ Finally you may want to export the following variable:
115
+ export TF_FORCE_GPU_ALLOW_GROWTH=true
116
+ See: https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth
117
+
118
+ You can save those environment variables in your training conda env, when currently active:
119
+ `$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh`
120
+ e.g. assuming the env with TensorFlow and frechet_audio_distance install is named ac_eval,
121
+ and the training conda env is named audiocraft:
122
+ ```
123
+ # activate training env
124
+ conda activate audiocraft
125
+ # get path to all envs
126
+ CONDA_ENV_DIR=$(dirname $CONDA_PREFIX)
127
+ # export pointers to evaluation env for using TensorFlow in FrechetAudioDistanceMetric
128
+ touch $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
129
+ echo 'export TF_PYTHON_EXE="$CONDA_ENV_DIR/ac_eval/bin/python"' >> \
130
+ $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
131
+ echo 'export TF_LIBRARY_PATH="$CONDA_ENV_DIR/ac_eval/lib/python3.10/site-packages/nvidia/cudnn/lib"' >> \
132
+ $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
133
+ # optionally:
134
+ echo 'export TF_FORCE_GPU_ALLOW_GROWTH=true' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
135
+ # you may need to reactivate the audiocraft env for this to take effect
136
+ ```
137
+
138
+ Args:
139
+ bin (Path or str): Path to installed frechet audio distance code.
140
+ model_path (Path or str): Path to Tensorflow checkpoint for the model
141
+ used to compute statistics over the embedding beams.
142
+ format (str): Audio format used to save files.
143
+ log_folder (Path or str, optional): Path where to write process logs.
144
+ """
145
+ def __init__(self, bin: tp.Union[Path, str], model_path: tp.Union[Path, str],
146
+ format: str = "wav", batch_size: tp.Optional[int] = None,
147
+ log_folder: tp.Optional[tp.Union[Path, str]] = None):
148
+ super().__init__()
149
+ self.model_sample_rate = VGGISH_SAMPLE_RATE
150
+ self.model_channels = VGGISH_CHANNELS
151
+ self.model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
152
+ assert Path(self.model_path).exists(), f"Could not find provided model checkpoint path at: {self.model_path}"
153
+ self.format = format
154
+ self.batch_size = batch_size
155
+ self.bin = bin
156
+ self.tf_env = {"PYTHONPATH": str(self.bin)}
157
+ self.python_path = os.environ.get('TF_PYTHON_EXE') or 'python'
158
+ logger.info("Python exe for TF is %s", self.python_path)
159
+ if 'TF_LIBRARY_PATH' in os.environ:
160
+ self.tf_env['LD_LIBRARY_PATH'] = os.environ['TF_LIBRARY_PATH']
161
+ if 'TF_FORCE_GPU_ALLOW_GROWTH' in os.environ:
162
+ self.tf_env['TF_FORCE_GPU_ALLOW_GROWTH'] = os.environ['TF_FORCE_GPU_ALLOW_GROWTH']
163
+ logger.info("Env for TF is %r", self.tf_env)
164
+ self.reset(log_folder)
165
+ self.add_state("total_files", default=torch.tensor(0.), dist_reduce_fx="sum")
166
+
167
+ def reset(self, log_folder: tp.Optional[tp.Union[Path, str]] = None):
168
+ """Reset torchmetrics.Metrics state."""
169
+ log_folder = Path(log_folder or tempfile.mkdtemp())
170
+ self.tmp_dir = log_folder / 'fad'
171
+ self.tmp_dir.mkdir(exist_ok=True)
172
+ self.samples_tests_dir = self.tmp_dir / 'tests'
173
+ self.samples_tests_dir.mkdir(exist_ok=True)
174
+ self.samples_background_dir = self.tmp_dir / 'background'
175
+ self.samples_background_dir.mkdir(exist_ok=True)
176
+ self.manifest_tests = self.tmp_dir / 'files_tests.csv'
177
+ self.manifest_background = self.tmp_dir / 'files_background.csv'
178
+ self.stats_tests_dir = self.tmp_dir / 'stats_tests'
179
+ self.stats_background_dir = self.tmp_dir / 'stats_background'
180
+ self.counter = 0
181
+
182
+ def update(self, preds: torch.Tensor, targets: torch.Tensor,
183
+ sizes: torch.Tensor, sample_rates: torch.Tensor,
184
+ stems: tp.Optional[tp.List[str]] = None):
185
+ """Update torchmetrics.Metrics by saving the audio and updating the manifest file."""
186
+ assert preds.shape == targets.shape, f"preds={preds.shape} != targets={targets.shape}"
187
+ num_samples = preds.shape[0]
188
+ assert num_samples == sizes.size(0) and num_samples == sample_rates.size(0)
189
+ assert stems is None or num_samples == len(set(stems))
190
+ for i in range(num_samples):
191
+ self.total_files += 1 # type: ignore
192
+ self.counter += 1
193
+ wav_len = int(sizes[i].item())
194
+ sample_rate = int(sample_rates[i].item())
195
+ pred_wav = preds[i]
196
+ target_wav = targets[i]
197
+ pred_wav = pred_wav[..., :wav_len]
198
+ target_wav = target_wav[..., :wav_len]
199
+ stem_name = stems[i] if stems is not None else f'sample_{self.counter}_{flashy.distrib.rank()}'
200
+ # dump audio files
201
+ try:
202
+ pred_wav = convert_audio(
203
+ pred_wav.unsqueeze(0), from_rate=sample_rate,
204
+ to_rate=self.model_sample_rate, to_channels=1).squeeze(0)
205
+ audio_write(
206
+ self.samples_tests_dir / stem_name, pred_wav, sample_rate=self.model_sample_rate,
207
+ format=self.format, strategy="peak")
208
+ except Exception as e:
209
+ logger.error(f"Exception occurred when saving test files for FAD computation: {repr(e)} - {e}")
210
+ try:
211
+ # for the ground truth audio, we enforce the 'peak' strategy to avoid modifying
212
+ # the original audio when writing it
213
+ target_wav = convert_audio(
214
+ target_wav.unsqueeze(0), from_rate=sample_rate,
215
+ to_rate=self.model_sample_rate, to_channels=1).squeeze(0)
216
+ audio_write(
217
+ self.samples_background_dir / stem_name, target_wav, sample_rate=self.model_sample_rate,
218
+ format=self.format, strategy="peak")
219
+ except Exception as e:
220
+ logger.error(f"Exception occurred when saving background files for FAD computation: {repr(e)} - {e}")
221
+
222
+ def _get_samples_name(self, is_background: bool):
223
+ return 'background' if is_background else 'tests'
224
+
225
+ def _create_embedding_beams(self, is_background: bool, gpu_index: tp.Optional[int] = None):
226
+ if is_background:
227
+ input_samples_dir = self.samples_background_dir
228
+ input_filename = self.manifest_background
229
+ stats_name = self.stats_background_dir
230
+ else:
231
+ input_samples_dir = self.samples_tests_dir
232
+ input_filename = self.manifest_tests
233
+ stats_name = self.stats_tests_dir
234
+ beams_name = self._get_samples_name(is_background)
235
+ log_file = self.tmp_dir / f'fad_logs_create_beams_{beams_name}.log'
236
+
237
+ logger.info(f"Scanning samples folder to fetch list of files: {input_samples_dir}")
238
+ with open(input_filename, "w") as fout:
239
+ for path in Path(input_samples_dir).glob(f"*.{self.format}"):
240
+ fout.write(f"{str(path)}\n")
241
+
242
+ cmd = [
243
+ self.python_path, "-m",
244
+ "frechet_audio_distance.create_embeddings_main",
245
+ "--model_ckpt", f"{self.model_path}",
246
+ "--input_files", f"{str(input_filename)}",
247
+ "--stats", f"{str(stats_name)}",
248
+ ]
249
+ if self.batch_size is not None:
250
+ cmd += ["--batch_size", str(self.batch_size)]
251
+ logger.info(f"Launching frechet_audio_distance embeddings main method: {' '.join(cmd)} on {beams_name}")
252
+ env = os.environ.copy()  # copy so we do not mutate the parent process environment
253
+ if gpu_index is not None:
254
+ env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
255
+ process = subprocess.Popen(
256
+ cmd, stdout=open(log_file, "w"), env={**env, **self.tf_env}, stderr=subprocess.STDOUT)
257
+ return process, log_file
258
+
259
+ def _compute_fad_score(self, gpu_index: tp.Optional[int] = None):
260
+ cmd = [
261
+ self.python_path, "-m", "frechet_audio_distance.compute_fad",
262
+ "--test_stats", f"{str(self.stats_tests_dir)}",
263
+ "--background_stats", f"{str(self.stats_background_dir)}",
264
+ ]
265
+ logger.info(f"Launching frechet_audio_distance compute fad method: {' '.join(cmd)}")
266
+ env = os.environ.copy()  # copy so we do not mutate the parent process environment
267
+ if gpu_index is not None:
268
+ env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
269
+ result = subprocess.run(cmd, env={**env, **self.tf_env}, capture_output=True)
270
+ if result.returncode:
271
+ logger.error(
272
+ "Error with FAD computation from stats: \n %s \n %s",
273
+ result.stdout.decode(), result.stderr.decode()
274
+ )
275
+ raise RuntimeError("Error while executing FAD computation from stats")
276
+ try:
277
+ # result is "FAD: (d+).(d+)" hence we remove the prefix with (d+) being one digit or more
278
+ fad_score = float(result.stdout[4:])
279
+ return fad_score
280
+ except Exception as e:
281
+ raise RuntimeError(f"Error parsing FAD score from command stdout: {e}")
282
+
283
+ def _log_process_result(self, returncode: int, log_file: tp.Union[Path, str], is_background: bool) -> None:
284
+ beams_name = self._get_samples_name(is_background)
285
+ if returncode:
286
+ with open(log_file, "r") as f:
287
+ error_log = f.read()
288
+ logger.error(error_log)
289
+ os._exit(1)
290
+ else:
291
+ logger.info(f"Successfully computed embedding beams on {beams_name} samples.")
292
+
293
+ def _parallel_create_embedding_beams(self, num_of_gpus: int):
294
+ assert num_of_gpus > 0
295
+ logger.info("Creating embeddings beams in a parallel manner on different GPUs")
296
+ tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False, gpu_index=0)
297
+ bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True, gpu_index=1)
298
+ tests_beams_code = tests_beams_process.wait()
299
+ bg_beams_code = bg_beams_process.wait()
300
+ self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False)
301
+ self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True)
302
+
303
+ def _sequential_create_embedding_beams(self):
304
+ logger.info("Creating embeddings beams in a sequential manner")
305
+ tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False)
306
+ tests_beams_code = tests_beams_process.wait()
307
+ self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False)
308
+ bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True)
309
+ bg_beams_code = bg_beams_process.wait()
310
+ self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True)
311
+
312
+ @flashy.distrib.rank_zero_only
313
+ def _local_compute_frechet_audio_distance(self):
314
+ """Compute Frechet Audio Distance score calling TensorFlow API."""
315
+ num_of_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
316
+ if num_of_gpus > 1:
317
+ self._parallel_create_embedding_beams(num_of_gpus)
318
+ else:
319
+ self._sequential_create_embedding_beams()
320
+ fad_score = self._compute_fad_score(gpu_index=0)
321
+ return fad_score
322
+
323
+ def compute(self) -> float:
324
+ """Compute metrics."""
325
+ assert self.total_files.item() > 0, "No files dumped for FAD computation!" # type: ignore
326
+ fad_score = self._local_compute_frechet_audio_distance()
327
+ logger.warning(f"FAD score = {fad_score}")
328
+ fad_score = flashy.distrib.broadcast_object(fad_score, src=0)
329
+ return fad_score
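A minimal usage sketch for `FrechetAudioDistanceMetric`, assuming the frechet_audio_distance setup described in the docstring above, that the class is exported from `audiocraft.metrics` as in upstream AudioCraft, and placeholder paths and tensors (the checkpoint path and waveforms are illustrative only):

```python
import torch
from audiocraft.metrics import FrechetAudioDistanceMetric  # assumed export, as in upstream AudioCraft

# Both paths are placeholders: the frechet_audio_distance checkout and the VGGish checkpoint.
fad = FrechetAudioDistanceMetric(
    bin="google-research",            # added to PYTHONPATH of the TF subprocess
    model_path="vggish_model.ckpt",   # hypothetical checkpoint location
    format="wav",
    batch_size=64,
)

# Dummy batch: two mono clips of 1 s at 32 kHz, with their true lengths and sample rates.
preds = torch.randn(2, 1, 32000)      # generated audio
targets = torch.randn(2, 1, 32000)    # reference audio
sizes = torch.tensor([32000, 32000])
sample_rates = torch.tensor([32000, 32000])

fad.update(preds, targets, sizes, sample_rates)
score = fad.compute()  # spawns the TF embedding jobs, then computes FAD from the stats
print(score)
```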
audiocraft/metrics/kld.py ADDED
@@ -0,0 +1,220 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import contextlib
8
+ from functools import partial
9
+ import logging
10
+ import os
11
+ import typing as tp
12
+
13
+ import torch
14
+ import torchmetrics
15
+
16
+ from ..data.audio_utils import convert_audio
17
+
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class _patch_passt_stft:
23
+ """Decorator to patch torch.stft in PaSST."""
24
+ def __init__(self):
25
+ self.old_stft = torch.stft
26
+
27
+ def __enter__(self):
28
+ # return_complex is a mandatory parameter in latest torch versions
29
+ # torch is throwing RuntimeErrors when not set
30
+ torch.stft = partial(torch.stft, return_complex=False)
31
+
32
+ def __exit__(self, *exc):
33
+ torch.stft = self.old_stft
34
+
35
+
36
+ def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
37
+ """Computes the elementwise KL-Divergence loss between probability distributions
38
+ from generated samples and target samples.
39
+
40
+ Args:
41
+ pred_probs (torch.Tensor): Probabilities for each label obtained
42
+ from a classifier on generated audio. Expected shape is [B, num_classes].
43
+ target_probs (torch.Tensor): Probabilities for each label obtained
44
+ from a classifier on target audio. Expected shape is [B, num_classes].
45
+ epsilon (float): Epsilon value.
46
+ Returns:
47
+ kld (torch.Tensor): KLD loss between each generated sample and target pair.
48
+ """
49
+ kl_div = torch.nn.functional.kl_div((pred_probs + epsilon).log(), target_probs, reduction="none")
50
+ return kl_div.sum(-1)
51
+
52
+
53
+ class KLDivergenceMetric(torchmetrics.Metric):
54
+ """Base implementation for KL Divergence metric.
55
+
56
+ The KL divergence is measured between probability distributions
57
+ of class predictions returned by a pre-trained audio classification model.
58
+ When the KL-divergence is low, the generated audio is expected to
59
+ have similar acoustic characteristics as the reference audio,
60
+ according to the classifier.
61
+ """
62
+ def __init__(self):
63
+ super().__init__()
64
+ self.add_state("kld_pq_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
65
+ self.add_state("kld_qp_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
66
+ self.add_state("kld_all_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
67
+ self.add_state("weight", default=torch.tensor(0), dist_reduce_fx="sum")
68
+
69
+ def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
70
+ sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
71
+ """Get model output given provided input tensor.
72
+
73
+ Args:
74
+ x (torch.Tensor): Input audio tensor of shape [B, C, T].
75
+ sizes (torch.Tensor): Actual audio sample length, of shape [B].
76
+ sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
77
+ Returns:
78
+ probs (torch.Tensor): Probabilities over labels, of shape [B, num_classes].
79
+ """
80
+ raise NotImplementedError("implement method to extract label distributions from the model.")
81
+
82
+ def update(self, preds: torch.Tensor, targets: torch.Tensor,
83
+ sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
84
+ """Calculates running KL-Divergence loss between batches of audio
85
+ preds (generated) and target (ground-truth)
86
+ Args:
87
+ preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T].
88
+ targets (torch.Tensor): Target samples to compare against, of shape [B, C, T].
89
+ sizes (torch.Tensor): Actual audio sample length, of shape [B].
90
+ sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
91
+ """
92
+ assert preds.shape == targets.shape
93
+ assert preds.size(0) > 0, "Cannot update the loss with empty tensors"
94
+ preds_probs = self._get_label_distribution(preds, sizes, sample_rates)
95
+ targets_probs = self._get_label_distribution(targets, sizes, sample_rates)
96
+ if preds_probs is not None and targets_probs is not None:
97
+ assert preds_probs.shape == targets_probs.shape
98
+ kld_scores = kl_divergence(preds_probs, targets_probs)
99
+ assert not torch.isnan(kld_scores).any(), "kld_scores contains NaN value(s)!"
100
+ self.kld_pq_sum += torch.sum(kld_scores)
101
+ kld_qp_scores = kl_divergence(targets_probs, preds_probs)
102
+ self.kld_qp_sum += torch.sum(kld_qp_scores)
103
+ self.weight += torch.tensor(kld_scores.size(0))
104
+
105
+ def compute(self) -> dict:
106
+ """Computes KL-Divergence across all evaluated pred/target pairs."""
107
+ weight: float = float(self.weight.item()) # type: ignore
108
+ assert weight > 0, "Unable to compute with total number of comparisons <= 0"
109
+ logger.info(f"Computing KL divergence on a total of {weight} samples")
110
+ kld_pq = self.kld_pq_sum.item() / weight # type: ignore
111
+ kld_qp = self.kld_qp_sum.item() / weight # type: ignore
112
+ kld_both = kld_pq + kld_qp
113
+ return {'kld': kld_pq, 'kld_pq': kld_pq, 'kld_qp': kld_qp, 'kld_both': kld_both}
114
+
115
+
116
+ class PasstKLDivergenceMetric(KLDivergenceMetric):
117
+ """KL-Divergence metric based on pre-trained PASST classifier on AudioSet.
118
+
119
+ From: PaSST: Efficient Training of Audio Transformers with Patchout
120
+ Paper: https://arxiv.org/abs/2110.05069
121
+ Implementation: https://github.com/kkoutini/PaSST
122
+
123
+ Follow instructions from the github repo:
124
+ ```
125
+ pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'
126
+ ```
127
+
128
+ Args:
129
+ pretrained_length (float, optional): Audio duration used for the pretrained model.
130
+ """
131
+ def __init__(self, pretrained_length: tp.Optional[float] = None):
132
+ super().__init__()
133
+ self._initialize_model(pretrained_length)
134
+
135
+ def _initialize_model(self, pretrained_length: tp.Optional[float] = None):
136
+ """Initialize underlying PaSST audio classifier."""
137
+ model, sr, max_frames, min_frames = self._load_base_model(pretrained_length)
138
+ self.min_input_frames = min_frames
139
+ self.max_input_frames = max_frames
140
+ self.model_sample_rate = sr
141
+ self.model = model
142
+ self.model.eval()
143
+ self.model.to(self.device)
144
+
145
+ def _load_base_model(self, pretrained_length: tp.Optional[float]):
146
+ """Load pretrained model from PaSST."""
147
+ try:
148
+ if pretrained_length == 30:
149
+ from hear21passt.base30sec import get_basic_model # type: ignore
150
+ max_duration = 30
151
+ elif pretrained_length == 20:
152
+ from hear21passt.base20sec import get_basic_model # type: ignore
153
+ max_duration = 20
154
+ else:
155
+ from hear21passt.base import get_basic_model # type: ignore
156
+ # Original PASST was trained on AudioSet with 10s-long audio samples
157
+ max_duration = 10
158
+ min_duration = 0.15
160
+ except ModuleNotFoundError:
161
+ raise ModuleNotFoundError(
162
+ "Please install hear21passt to compute KL divergence: ",
163
+ "pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'"
164
+ )
165
+ model_sample_rate = 32_000
166
+ max_input_frames = int(max_duration * model_sample_rate)
167
+ min_input_frames = int(min_duration * model_sample_rate)
168
+ with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f):
169
+ model = get_basic_model(mode='logits')
170
+ return model, model_sample_rate, max_input_frames, min_input_frames
171
+
172
+ def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len: int) -> tp.List[torch.Tensor]:
173
+ """Process audio to feed to the pretrained model."""
174
+ wav = wav.unsqueeze(0)
175
+ wav = wav[..., :wav_len]
176
+ wav = convert_audio(wav, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1)
177
+ wav = wav.squeeze(0)
178
+ # we don't pad but return a list of audio segments as this otherwise affects the KLD computation
179
+ segments = torch.split(wav, self.max_input_frames, dim=-1)
180
+ valid_segments = []
181
+ for s in segments:
182
+ # ignoring too small segments that are breaking the model inference
183
+ if s.size(-1) > self.min_input_frames:
184
+ valid_segments.append(s)
185
+ return [s[None] for s in valid_segments]
186
+
187
+ def _get_model_preds(self, wav: torch.Tensor) -> torch.Tensor:
188
+ """Run the pretrained model and get the predictions."""
189
+ assert wav.dim() == 3, f"Unexpected number of dims for preprocessed wav: {wav.shape}"
190
+ wav = wav.mean(dim=1)
191
+ # PaSST is printing a lot of garbage that we are not interested in
192
+ with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
193
+ with torch.no_grad(), _patch_passt_stft():
194
+ logits = self.model(wav.to(self.device))
195
+ probs = torch.softmax(logits, dim=-1)
196
+ return probs
197
+
198
+ def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
199
+ sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
200
+ """Get model output given provided input tensor.
201
+
202
+ Args:
203
+ x (torch.Tensor): Input audio tensor of shape [B, C, T].
204
+ sizes (torch.Tensor): Actual audio sample length, of shape [B].
205
+ sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
206
+ Returns:
207
+ probs (torch.Tensor, optional): Probabilities over labels, of shape [B, num_classes].
208
+ """
209
+ all_probs: tp.List[torch.Tensor] = []
210
+ for i, wav in enumerate(x):
211
+ sample_rate = int(sample_rates[i].item())
212
+ wav_len = int(sizes[i].item())
213
+ wav_segments = self._process_audio(wav, sample_rate, wav_len)
214
+ for segment in wav_segments:
215
+ probs = self._get_model_preds(segment).mean(dim=0)
216
+ all_probs.append(probs)
217
+ if len(all_probs) > 0:
218
+ return torch.stack(all_probs, dim=0)
219
+ else:
220
+ return None
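A short sketch of how `PasstKLDivergenceMetric` could be exercised, assuming it is exported from `audiocraft.metrics` as in upstream AudioCraft and that `hear21passt` is installed as per the docstring; the tensors are placeholders:

```python
import torch
from audiocraft.metrics import PasstKLDivergenceMetric  # assumed export, as in upstream AudioCraft

# Uses the default 10s PaSST model; requires hear21passt (see the docstring above).
kld = PasstKLDivergenceMetric()

# Dummy batch: two mono 5 s clips at 32 kHz.
preds = torch.randn(2, 1, 160_000)
targets = torch.randn(2, 1, 160_000)
sizes = torch.tensor([160_000, 160_000])
sample_rates = torch.tensor([32_000, 32_000])

kld.update(preds, targets, sizes, sample_rates)
scores = kld.compute()  # dict with 'kld', 'kld_pq', 'kld_qp', 'kld_both'
print(scores["kld"])
```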
audiocraft/metrics/rvm.py ADDED
@@ -0,0 +1,110 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+ import torch
9
+ from torch import nn
10
+ import torchaudio
11
+
12
+
13
+ def db_to_scale(volume: tp.Union[float, torch.Tensor]):
14
+ return 10 ** (volume / 20)
15
+
16
+
17
+ def scale_to_db(scale: torch.Tensor, min_volume: float = -120):
18
+ min_scale = db_to_scale(min_volume)
19
+ return 20 * torch.log10(scale.clamp(min=min_scale))
20
+
21
+
22
+ class RelativeVolumeMel(nn.Module):
23
+ """Relative volume melspectrogram measure.
24
+
25
+ Computes a measure of distance over two mel spectrogram that is interpretable in terms
26
+ of decibels. Given `x_ref` and `x_est` two waveforms of shape `[*, T]`, it will
27
+ first renormalize both by the volume (RMS) of `x_ref`.
28
+
29
+ ..Warning:: This class returns the volume of the distortion at the spectrogram level,
30
+ e.g. low negative values reflect lower distortion levels. For an SNR (like reported
31
+ in the MultiBandDiffusion paper), just take `-rvm`.
32
+
33
+ Then it computes the mel spectrogram `z_ref` and `z_est` and compute volume of the difference
34
+ relative to the volume of `z_ref` for each time-frequency bin. It further adds some limits, e.g.
35
+ clamping the values between -25 and 25 dB (controlled by `min_relative_volume` and `max_relative_volume`)
36
+ with the goal of avoiding the loss being dominated by parts where the reference is almost silent.
37
+ Indeed, volumes in dB can take unbounded values both towards -oo and +oo, which can make the final
38
+ average metric harder to interpret. Besides, anything below -30 dB of attenuation would sound extremely
39
+ good (for a neural network output, although sound engineers typically aim for much lower attenuations).
40
+ Similarly, anything above +30 dB would just be completely missing the target, and there is no point
41
+ in measuring by exactly how much it missed it. -25, 25 is a more conservative range, but also more
42
+ in line with what neural nets currently can achieve.
43
+
44
+ For instance, a Relative Volume Mel (RVM) score of -10 dB means that on average, the delta between
45
+ the target and reference mel-spec is 10 dB lower than the reference mel-spec value.
46
+
47
+ The metric can be aggregated over a given frequency band in order to have different insights for
48
+ different regions of the spectrum. `num_aggregated_bands` controls the number of bands.
49
+
50
+ ..Warning:: While this function is optimized for interpretability, nothing was done to ensure it
51
+ is numerically stable when computing its gradient. We thus advise against using it as a training loss.
52
+
53
+ Args:
54
+ sample_rate (int): Sample rate of the input audio.
55
+ n_mels (int): Number of mel bands to use.
56
+ n_fft (int): Number of frequency bins for the STFT.
57
+ hop_length (int): Hop length of the STFT and the mel-spectrogram.
58
+ min_relative_volume (float): The error `z_ref - z_est` volume is given relative to
59
+ the volume of `z_ref`. If error is smaller than -25 dB of `z_ref`, then it is clamped.
60
+ max_relative_volume (float): Same as `min_relative_volume` but clamping if the error is larger than that.
61
+ max_initial_gain (float): When rescaling the audio at the very beginning, we will limit the gain
62
+ to that amount, to avoid rescaling near silence. Given in dB.
63
+ min_activity_volume (float): When computing the reference level from `z_ref`, will clamp low volume
64
+ bins to that amount. This is effectively our "zero" level for the reference mel-spectrogram,
65
+ and anything below that will be considered equally.
66
+ num_aggregated_bands (int): Number of bands to keep when computing the average RVM value.
67
+ For instance, a value of 3 would give 3 scores, roughly for low, mid and high freqs.
68
+ """
69
+ def __init__(self, sample_rate: int = 24000, n_mels: int = 80, n_fft: int = 512,
70
+ hop_length: int = 128, min_relative_volume: float = -25,
71
+ max_relative_volume: float = 25, max_initial_gain: float = 25,
72
+ min_activity_volume: float = -25,
73
+ num_aggregated_bands: int = 4) -> None:
74
+ super().__init__()
75
+ self.melspec = torchaudio.transforms.MelSpectrogram(
76
+ n_mels=n_mels, n_fft=n_fft, hop_length=hop_length,
77
+ normalized=True, sample_rate=sample_rate, power=2)
78
+ self.min_relative_volume = min_relative_volume
79
+ self.max_relative_volume = max_relative_volume
80
+ self.max_initial_gain = max_initial_gain
81
+ self.min_activity_volume = min_activity_volume
82
+ self.num_aggregated_bands = num_aggregated_bands
83
+
84
+ def forward(self, estimate: torch.Tensor, ground_truth: torch.Tensor) -> tp.Dict[str, torch.Tensor]:
85
+ """Compute RVM metric between estimate and reference samples.
86
+
87
+ Args:
88
+ estimate (torch.Tensor): Estimate sample.
89
+ ground_truth (torch.Tensor): Reference sample.
90
+
91
+ Returns:
92
+ dict[str, torch.Tensor]: Metrics with keys `rvm` for the overall average, and `rvm_{k}`
93
+ for the RVM over the k-th band (k=0..num_aggregated_bands - 1).
94
+ """
95
+ min_scale = db_to_scale(-self.max_initial_gain)
96
+ std = ground_truth.pow(2).mean().sqrt().clamp(min=min_scale)
97
+ z_gt = self.melspec(ground_truth / std).sqrt()
98
+ z_est = self.melspec(estimate / std).sqrt()
99
+
100
+ delta = z_gt - z_est
101
+ ref_db = scale_to_db(z_gt, self.min_activity_volume)
102
+ delta_db = scale_to_db(delta.abs(), min_volume=-120)
103
+ relative_db = (delta_db - ref_db).clamp(self.min_relative_volume, self.max_relative_volume)
104
+ dims = list(range(relative_db.dim()))
105
+ dims.remove(dims[-2])
106
+ losses_per_band = relative_db.mean(dim=dims)
107
+ aggregated = [chunk.mean() for chunk in losses_per_band.chunk(self.num_aggregated_bands, dim=0)]
108
+ metrics = {f'rvm_{index}': value for index, value in enumerate(aggregated)}
109
+ metrics['rvm'] = losses_per_band.mean()
110
+ return metrics
audiocraft/metrics/visqol.py ADDED
@@ -0,0 +1,216 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import csv
8
+ import json
9
+ import logging
10
+ from pathlib import Path
11
+ import tempfile
12
+ import typing as tp
13
+ import subprocess
14
+ import shutil
15
+
16
+ import torch
17
+ import torchaudio
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class ViSQOL:
23
+ """ViSQOL wrapper to run ViSQOL from Python using a pre-installed binary.
24
+
25
+ To learn more about ViSQOL and how to build ViSQOL binary using bazel, please refer to the
26
+ instructions available in the open source repository: https://github.com/google/visqol
27
+
28
+ ViSQOL is capable of running in two modes:
29
+
30
+ Audio Mode:
31
+ When running in audio mode, input signals must have a 48kHz sample rate. Input should be resampled to 48kHz.
32
+ Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison.
33
+ Audio mode uses support vector regression, with the maximum range at ~4.75.
34
+
35
+ Speech Mode:
36
+ When running in speech mode, ViSQOL uses a wideband model. It therefore expects input sample rates of 16kHz.
37
+ Input should be resampled to 16kHz.
38
+ As part of the speech mode processing, a root mean square implementation for voice activity detection
39
+ is performed on the reference signal to determine what parts of the signal have voice activity and
40
+ should therefore be included in the comparison. The signal is normalized before performing the voice
41
+ activity detection.
42
+ Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison.
43
+ Speech mode is scaled to have a maximum MOS of 5.0 to match previous version behavior.
44
+
45
+ For more details, check the guidelines: https://github.com/google/visqol#general-guidelines-for-input
46
+
47
+ Args:
48
+ visqol_bin (str): Path to the ViSQOL binary.
49
+ mode (str): ViSQOL computation mode, expecting "audio" or "speech".
50
+ model (str): Name of the model to use for similarity to quality model.
51
+ debug (bool): Whether to also get debug metrics from ViSQOL or not.
52
+ """
53
+ SAMPLE_RATES_MODES = {"audio": 48_000, "speech": 16_000}
54
+ ALLOWED_SAMPLE_RATES = frozenset(SAMPLE_RATES_MODES.values())
55
+
56
+ def __init__(self, bin: tp.Union[Path, str], mode: str = "audio",
57
+ model: str = "libsvm_nu_svr_model.txt", debug: bool = False):
58
+ assert bin is not None and Path(bin).exists(), f"Could not find ViSQOL binary in specified path: {bin}"
59
+ self.visqol_bin = str(bin)
60
+ self.visqol_mode = mode
61
+ self.target_sr = self._get_target_sr(self.visqol_mode)
62
+ self.model = model
63
+ self.debug = debug
64
+ assert Path(self.visqol_model).exists(), \
65
+ f"Could not find the specified model in ViSQOL install: {self.visqol_model}"
66
+
67
+ def _get_target_sr(self, mode: str) -> int:
68
+ # returns target sampling rate for the corresponding ViSQOL mode.
69
+ if mode not in ViSQOL.SAMPLE_RATES_MODES:
70
+ raise ValueError(
71
+ f"Unsupported mode! Allowed are: {', '.join(ViSQOL.SAMPLE_RATES_MODES.keys())}"
72
+ )
73
+ return ViSQOL.SAMPLE_RATES_MODES[mode]
74
+
75
+ def _prepare_files(
76
+ self, ref_sig: torch.Tensor, deg_sig: torch.Tensor, sr: int, target_sr: int, pad_with_silence: bool = False
77
+ ):
78
+ # prepare files for ViSQOL evaluation.
79
+ assert target_sr in ViSQOL.ALLOWED_SAMPLE_RATES
80
+ assert len(ref_sig) == len(deg_sig), (
81
+ "Expects same number of ref and degraded inputs",
82
+ f" but ref len {len(ref_sig)} != deg len {len(deg_sig)}"
83
+ )
84
+ # resample audio if needed
85
+ if sr != target_sr:
86
+ transform = torchaudio.transforms.Resample(sr, target_sr)
87
+ pad = int(0.5 * target_sr)
88
+ rs_ref = []
89
+ rs_deg = []
90
+ for i in range(len(ref_sig)):
91
+ rs_ref_i = transform(ref_sig[i])
92
+ rs_deg_i = transform(deg_sig[i])
93
+ if pad_with_silence:
94
+ rs_ref_i = torch.nn.functional.pad(rs_ref_i, (pad, pad), mode='constant', value=0)
95
+ rs_deg_i = torch.nn.functional.pad(rs_deg_i, (pad, pad), mode='constant', value=0)
96
+ rs_ref.append(rs_ref_i)
97
+ rs_deg.append(rs_deg_i)
98
+ ref_sig = torch.stack(rs_ref)
99
+ deg_sig = torch.stack(rs_deg)
100
+ # save audio chunks to tmp dir and create csv
101
+ tmp_dir = Path(tempfile.mkdtemp())
102
+ try:
103
+ tmp_input_csv_path = tmp_dir / "input.csv"
104
+ tmp_results_csv_path = tmp_dir / "results.csv"
105
+ tmp_debug_json_path = tmp_dir / "debug.json"
106
+ with open(tmp_input_csv_path, "w") as csv_file:
107
+ csv_writer = csv.writer(csv_file)
108
+ csv_writer.writerow(["reference", "degraded"])
109
+ for i in range(len(ref_sig)):
110
+ tmp_ref_filename = tmp_dir / f"ref_{i}.wav"
111
+ tmp_deg_filename = tmp_dir / f"deg_{i}.wav"
112
+ torchaudio.save(
113
+ tmp_ref_filename,
114
+ torch.clamp(ref_sig[i], min=-0.99, max=0.99),
115
+ sample_rate=target_sr,
116
+ bits_per_sample=16,
117
+ encoding="PCM_S"
118
+ )
119
+ torchaudio.save(
120
+ tmp_deg_filename,
121
+ torch.clamp(deg_sig[i], min=-0.99, max=0.99),
122
+ sample_rate=target_sr,
123
+ bits_per_sample=16,
124
+ encoding="PCM_S"
125
+ )
126
+ csv_writer.writerow([str(tmp_ref_filename), str(tmp_deg_filename)])
127
+ return tmp_dir, tmp_input_csv_path, tmp_results_csv_path, tmp_debug_json_path
128
+ except Exception as e:
129
+ logger.error("Exception occurred when preparing files for ViSQOL: %s", e)
130
+ return tmp_dir, None, None, None
131
+
132
+ def _flush_files(self, tmp_dir: tp.Union[Path, str]):
133
+ # flush tmp files used to compute ViSQOL.
134
+ shutil.rmtree(str(tmp_dir))
135
+
136
+ def _collect_moslqo_score(self, results_csv_path: tp.Union[Path, str]) -> float:
137
+ # collect results for each evaluated pair and return averaged moslqo score.
138
+ with open(results_csv_path, "r") as csv_file:
139
+ reader = csv.DictReader(csv_file)
140
+ moslqo_scores = [float(row["moslqo"]) for row in reader]
141
+ if len(moslqo_scores) > 0:
142
+ return sum(moslqo_scores) / len(moslqo_scores)
143
+ else:
144
+ return 0.0
145
+
146
+ def _collect_debug_data(self, debug_json_path: tp.Union[Path, str]) -> dict:
147
+ # collect debug data for the visqol inference.
148
+ with open(debug_json_path, "r") as f:
149
+ data = json.load(f)
150
+ return data
151
+
152
+ @property
153
+ def visqol_model(self):
154
+ return f'{self.visqol_bin}/model/{self.model}'
155
+
156
+ def _run_visqol(
157
+ self,
158
+ input_csv_path: tp.Union[Path, str],
159
+ results_csv_path: tp.Union[Path, str],
160
+ debug_csv_path: tp.Optional[tp.Union[Path, str]],
161
+ ):
162
+ input_csv_path = str(input_csv_path)
163
+ results_csv_path = str(results_csv_path)
164
+ debug_csv_path = str(debug_csv_path) if debug_csv_path is not None else None
165
+ cmd = [
166
+ f'{self.visqol_bin}/bazel-bin/visqol',
167
+ '--batch_input_csv', f'{input_csv_path}',
168
+ '--results_csv', f'{results_csv_path}'
169
+ ]
170
+ if debug_csv_path is not None:
171
+ cmd += ['--output_debug', f'{debug_csv_path}']
172
+ if self.visqol_mode == "speech":
173
+ cmd += ['--use_speech_mode']
174
+ cmd += ['--similarity_to_quality_model', f'{self.visqol_model}']
175
+ result = subprocess.run(cmd, capture_output=True)
176
+ if result.returncode:
177
+ logger.error("Error with visqol: \n %s \n %s", result.stdout.decode(), result.stderr.decode())
178
+ raise RuntimeError("Error while executing visqol")
179
+ result.check_returncode()
180
+
181
+ def __call__(
182
+ self,
183
+ ref_sig: torch.Tensor,
184
+ deg_sig: torch.Tensor,
185
+ sr: int,
186
+ pad_with_silence: bool = False,
187
+ ):
188
+ """Calculate the ViSQOL metric for a pair of audio signals at a given sample rate.
189
+ Args:
190
+ ref_sig (torch.Tensor): Reference signals as [B, C, T].
191
+ deg_sig (torch.Tensor): Degraded signals as [B, C, T].
192
+ sr (int): Sample rate of the two audio signals.
193
+ pad_with_silence (bool): Whether to pad the file with silences as recommended
194
+ in visqol guidelines (see: https://github.com/google/visqol#general-guidelines-for-input).
195
+ Returns:
196
+ float: The ViSQOL score or mean score for the batch.
197
+ """
198
+ logger.debug(f"Calculating visqol with mode={self.visqol_mode} on {len(ref_sig)} samples")
199
+ tmp_dir, input_csv, results_csv, debug_json = self._prepare_files(
200
+ ref_sig, deg_sig, sr, self.target_sr, pad_with_silence
201
+ )
202
+ try:
203
+ if input_csv and results_csv:
204
+ self._run_visqol(
205
+ input_csv,
206
+ results_csv,
207
+ debug_json if self.debug else None,
208
+ )
209
+ mosqol = self._collect_moslqo_score(results_csv)
210
+ return mosqol
211
+ else:
212
+ raise RuntimeError("Something unexpected happened when running VISQOL!")
213
+ except Exception as e:
214
+ logger.error("Exception occurred when running ViSQOL: %s", e)
215
+ finally:
216
+ self._flush_files(tmp_dir)
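A sketch of driving the `ViSQOL` wrapper, assuming a ViSQOL checkout built with bazel at a placeholder path and the usual `audiocraft.metrics` export:

```python
import torch
from audiocraft.metrics import ViSQOL  # assumed export, as in upstream AudioCraft

# `bin` must point at a ViSQOL checkout containing bazel-bin/visqol and model/ (path is a placeholder).
visqol = ViSQOL(bin="/opt/visqol", mode="audio")

ref = torch.randn(1, 1, 48000)            # reference batch [B, C, T] at 48 kHz
deg = ref + 0.05 * torch.randn_like(ref)  # degraded copy
score = visqol(ref, deg, sr=48000, pad_with_silence=True)
print(score)  # mean MOS-LQO over the batch
```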
audiocraft/models/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Models for EnCodec, AudioGen, VidMuse, as well as the generic LMModel.
8
+ """
9
+ # flake8: noqa
10
+ from . import builders, loaders
11
+ from .encodec import (
12
+ CompressionModel, EncodecModel, DAC,
13
+ HFEncodecModel, HFEncodecCompressionModel)
14
+ from .audiogen import AudioGen
15
+ from .lm import LMModel
16
+ from .multibanddiffusion import MultiBandDiffusion
17
+ from .vidmuse import VidMuse
18
+ from .unet import DiffusionUnet
audiocraft/models/audiogen.py ADDED
@@ -0,0 +1,267 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Main model for using AudioGen. This will combine all the required components
9
+ and provide easy access to the generation API.
10
+ """
11
+
12
+ import typing as tp
13
+
14
+ import torch
15
+
16
+ from .encodec import CompressionModel
17
+ from .lm import LMModel
18
+ from .builders import get_debug_compression_model, get_debug_lm_model
19
+ from .loaders import load_compression_model, load_lm_model
20
+ from ..data.audio_utils import convert_audio
21
+ from ..modules.conditioners import ConditioningAttributes
22
+ from ..utils.autocast import TorchAutocast
23
+
24
+
25
+ class AudioGen:
26
+ """AudioGen main model with convenient generation API.
27
+
28
+ Args:
29
+ name (str): name of the model.
30
+ compression_model (CompressionModel): Compression model
31
+ used to map audio to invertible discrete representations.
32
+ lm (LMModel): Language model over discrete representations.
33
+ max_duration (float, optional): maximum duration the model can produce;
35
+ if not provided, it is inferred from the training params.
35
+ """
36
+ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
37
+ max_duration: tp.Optional[float] = None):
38
+ self.name = name
39
+ self.compression_model = compression_model
40
+ self.lm = lm
41
+ # Just to be safe, let's put everything in eval mode.
42
+ self.compression_model.eval()
43
+ self.lm.eval()
44
+
45
+ if max_duration is None:
46
+ if hasattr(lm, 'cfg'):
47
+ max_duration = lm.cfg.dataset.segment_duration # type: ignore
48
+ else:
49
+ raise ValueError("You must provide max_duration when building directly AudioGen")
50
+ assert max_duration is not None
51
+ self.max_duration: float = max_duration
52
+ self.device = next(iter(lm.parameters())).device
53
+ self.generation_params: dict = {}
54
+ self.set_generation_params(duration=5) # 5 seconds by default
55
+ self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
56
+ if self.device.type == 'cpu':
57
+ self.autocast = TorchAutocast(enabled=False)
58
+ else:
59
+ self.autocast = TorchAutocast(
60
+ enabled=True, device_type=self.device.type, dtype=torch.float16)
61
+
62
+ @property
63
+ def frame_rate(self) -> float:
64
+ """Roughly the number of AR steps per seconds."""
65
+ return self.compression_model.frame_rate
66
+
67
+ @property
68
+ def sample_rate(self) -> int:
69
+ """Sample rate of the generated audio."""
70
+ return self.compression_model.sample_rate
71
+
72
+ @property
73
+ def audio_channels(self) -> int:
74
+ """Audio channels of the generated audio."""
75
+ return self.compression_model.channels
76
+
77
+ @staticmethod
78
+ def get_pretrained(name: str = 'facebook/audiogen-medium', device=None):
79
+ """Return pretrained model, we provide a single model for now:
80
+ - facebook/audiogen-medium (1.5B), text to sound,
81
+ # see: https://huggingface.co/facebook/audiogen-medium
82
+ """
83
+ if device is None:
84
+ if torch.cuda.device_count():
85
+ device = 'cuda'
86
+ else:
87
+ device = 'cpu'
88
+
89
+ if name == 'debug':
90
+ # used only for unit tests
91
+ compression_model = get_debug_compression_model(device, sample_rate=16000)
92
+ lm = get_debug_lm_model(device)
93
+ return AudioGen(name, compression_model, lm, max_duration=10)
94
+
95
+ compression_model = load_compression_model(name, device=device)
96
+ lm = load_lm_model(name, device=device)
97
+ assert 'self_wav' not in lm.condition_provider.conditioners, \
98
+ "AudioGen does not support waveform conditioning for now"
99
+ return AudioGen(name, compression_model, lm)
100
+
101
+ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
102
+ top_p: float = 0.0, temperature: float = 1.0,
103
+ duration: float = 10.0, cfg_coef: float = 3.0,
104
+ two_step_cfg: bool = False, extend_stride: float = 2):
105
+ """Set the generation parameters for AudioGen.
106
+
107
+ Args:
108
+ use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
109
+ top_k (int, optional): top_k used for sampling. Defaults to 250.
110
+ top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
111
+ temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
112
+ duration (float, optional): Duration of the generated waveform. Defaults to 10.0.
113
+ cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
114
+ two_step_cfg (bool, optional): If True, performs 2 forward passes for Classifier Free Guidance,
115
+ instead of batching together the two. This has some impact on how things
116
+ are padded but seems to have little impact in practice.
117
+ extend_stride: when doing extended generation (i.e. more than 10 seconds), by how much
118
+ should we extend the audio each time. Larger values will mean less context is
119
+ preserved, and shorter values will require extra computations.
120
+ """
121
+ assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
122
+ self.extend_stride = extend_stride
123
+ self.duration = duration
124
+ self.generation_params = {
125
+ 'use_sampling': use_sampling,
126
+ 'temp': temperature,
127
+ 'top_k': top_k,
128
+ 'top_p': top_p,
129
+ 'cfg_coef': cfg_coef,
130
+ 'two_step_cfg': two_step_cfg,
131
+ }
132
+
133
+ def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
134
+ """Override the default progress callback."""
135
+ self._progress_callback = progress_callback
136
+
137
+ def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor:
138
+ """Generate samples conditioned on text.
139
+
140
+ Args:
141
+ descriptions (list of str): A list of strings used as text conditioning.
142
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
143
+ """
144
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
145
+ assert prompt_tokens is None
146
+ return self._generate_tokens(attributes, prompt_tokens, progress)
147
+
148
+ def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
149
+ descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
150
+ progress: bool = False) -> torch.Tensor:
151
+ """Generate samples conditioned on audio prompts.
152
+
153
+ Args:
154
+ prompt (torch.Tensor): A batch of waveforms used for continuation.
155
+ Prompt should be [B, C, T], or [C, T] if only one sample is generated.
156
+ prompt_sample_rate (int): Sampling rate of the given audio waveforms.
157
+ descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
158
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
159
+ """
160
+ if prompt.dim() == 2:
161
+ prompt = prompt[None]
162
+ if prompt.dim() != 3:
163
+ raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
164
+ prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
165
+ if descriptions is None:
166
+ descriptions = [None] * len(prompt)
167
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
168
+ assert prompt_tokens is not None
169
+ return self._generate_tokens(attributes, prompt_tokens, progress)
170
+
171
+ @torch.no_grad()
172
+ def _prepare_tokens_and_attributes(
173
+ self,
174
+ descriptions: tp.Sequence[tp.Optional[str]],
175
+ prompt: tp.Optional[torch.Tensor],
176
+ ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
177
+ """Prepare model inputs.
178
+
179
+ Args:
180
+ descriptions (list of str): A list of strings used as text conditioning.
181
+ prompt (torch.Tensor): A batch of waveforms used for continuation.
182
+ """
183
+ attributes = [
184
+ ConditioningAttributes(text={'description': description})
185
+ for description in descriptions]
186
+
187
+ if prompt is not None:
188
+ if descriptions is not None:
189
+ assert len(descriptions) == len(prompt), "Prompt and nb. descriptions don't match"
190
+ prompt = prompt.to(self.device)
191
+ prompt_tokens, scale = self.compression_model.encode(prompt)
192
+ assert scale is None
193
+ else:
194
+ prompt_tokens = None
195
+ return attributes, prompt_tokens
196
+
197
+ def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
198
+ prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
199
+ """Generate discrete audio tokens given audio prompt and/or conditions.
200
+
201
+ Args:
202
+ attributes (list of ConditioningAttributes): Conditions used for generation (here text).
203
+ prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
204
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
205
+ Returns:
206
+ torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
207
+ """
208
+ total_gen_len = int(self.duration * self.frame_rate)
209
+ max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
210
+ current_gen_offset: int = 0
211
+
212
+ def _progress_callback(generated_tokens: int, tokens_to_generate: int):
213
+ generated_tokens += current_gen_offset
214
+ if self._progress_callback is not None:
215
+ # Note that total_gen_len might be quite wrong depending on the
216
+ # codebook pattern used, but with delay it is almost accurate.
217
+ self._progress_callback(generated_tokens, total_gen_len)
218
+ else:
219
+ print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
220
+
221
+ if prompt_tokens is not None:
222
+ assert max_prompt_len >= prompt_tokens.shape[-1], \
223
+ "Prompt is longer than audio to generate"
224
+
225
+ callback = None
226
+ if progress:
227
+ callback = _progress_callback
228
+
229
+ if self.duration <= self.max_duration:
230
+ # generate by sampling from LM, simple case.
231
+ with self.autocast:
232
+ gen_tokens = self.lm.generate(
233
+ prompt_tokens, attributes,
234
+ callback=callback, max_gen_len=total_gen_len, **self.generation_params)
235
+
236
+ else:
237
+ all_tokens = []
238
+ if prompt_tokens is None:
239
+ prompt_length = 0
240
+ else:
241
+ all_tokens.append(prompt_tokens)
242
+ prompt_length = prompt_tokens.shape[-1]
243
+
244
+ stride_tokens = int(self.frame_rate * self.extend_stride)
245
+ while current_gen_offset + prompt_length < total_gen_len:
246
+ time_offset = current_gen_offset / self.frame_rate
247
+ chunk_duration = min(self.duration - time_offset, self.max_duration)
248
+ max_gen_len = int(chunk_duration * self.frame_rate)
249
+ with self.autocast:
250
+ gen_tokens = self.lm.generate(
251
+ prompt_tokens, attributes,
252
+ callback=callback, max_gen_len=max_gen_len, **self.generation_params)
253
+ if prompt_tokens is None:
254
+ all_tokens.append(gen_tokens)
255
+ else:
256
+ all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
257
+ prompt_tokens = gen_tokens[:, :, stride_tokens:]
258
+ prompt_length = prompt_tokens.shape[-1]
259
+ current_gen_offset += stride_tokens
260
+
261
+ gen_tokens = torch.cat(all_tokens, dim=-1)
262
+
263
+ # generate audio
264
+ assert gen_tokens.dim() == 3
265
+ with torch.no_grad():
266
+ gen_audio = self.compression_model.decode(gen_tokens, None)
267
+ return gen_audio
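A minimal generation sketch for the `AudioGen` wrapper above, using the `facebook/audiogen-medium` weights referenced in `get_pretrained` and dumping the results with `audio_write` from `audiocraft.data.audio` (the same helper used by the FAD metric); paths and prompts are placeholders:

```python
from audiocraft.models import AudioGen
from audiocraft.data.audio import audio_write

model = AudioGen.get_pretrained('facebook/audiogen-medium')
model.set_generation_params(duration=5, top_k=250, cfg_coef=3.0)

wavs = model.generate(['dog barking in the distance', 'footsteps on gravel'], progress=True)
for i, wav in enumerate(wavs):
    # each wav is [C, T] at model.sample_rate
    audio_write(f'audiogen_sample_{i}', wav.cpu(), sample_rate=model.sample_rate, strategy="peak")
```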
audiocraft/models/builders.py ADDED
@@ -0,0 +1,268 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ All the functions to build the relevant models and modules
9
+ from the Hydra config.
10
+ """
11
+
12
+ import typing as tp
13
+
14
+ import audiocraft
15
+ import omegaconf
16
+ import torch
17
+
18
+ from .encodec import CompressionModel, EncodecModel, InterleaveStereoCompressionModel
19
+ from .lm import LMModel
20
+ from ..modules.codebooks_patterns import (
21
+ CodebooksPatternProvider,
22
+ DelayedPatternProvider,
23
+ MusicLMPattern,
24
+ ParallelPatternProvider,
25
+ UnrolledPatternProvider,
26
+ CoarseFirstPattern,
27
+ )
28
+ from ..modules.conditioners import (
29
+ BaseConditioner,
30
+ ChromaStemConditioner,
31
+ CLAPEmbeddingConditioner,
32
+ ConditionFuser,
33
+ ConditioningProvider,
34
+ LUTConditioner,
35
+ T5Conditioner,
36
+ )
37
+ # T5Conditioner
38
+ from .unet import DiffusionUnet
39
+ from .. import quantization as qt
40
+ from ..utils.utils import dict_from_config
41
+ from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor
42
+ from omegaconf import OmegaConf
43
+
44
+ def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
45
+ klass = {
46
+ 'no_quant': qt.DummyQuantizer,
47
+ 'rvq': qt.ResidualVectorQuantizer
48
+ }[quantizer]
49
+ kwargs = dict_from_config(getattr(cfg, quantizer))
50
+ if quantizer != 'no_quant':
51
+ kwargs['dimension'] = dimension
52
+ return klass(**kwargs)
53
+
54
+
55
+ def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
56
+ if encoder_name == 'seanet':
57
+ kwargs = dict_from_config(getattr(cfg, 'seanet'))
58
+ encoder_override_kwargs = kwargs.pop('encoder')
59
+ decoder_override_kwargs = kwargs.pop('decoder')
60
+ encoder_kwargs = {**kwargs, **encoder_override_kwargs}
61
+ decoder_kwargs = {**kwargs, **decoder_override_kwargs}
62
+ encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
63
+ decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
64
+ return encoder, decoder
65
+ else:
66
+ raise KeyError(f"Unexpected audio encoder {encoder_name}")
67
+
68
+
69
+ def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
70
+ """Instantiate a compression model."""
71
+ # cfg=eval(cfg)
72
+ cfg = OmegaConf.create(cfg)
73
+
74
+ if cfg.compression_model == 'encodec':
75
+ kwargs = dict_from_config(getattr(cfg, 'encodec'))
76
+ encoder_name = kwargs.pop('autoencoder')
77
+ quantizer_name = kwargs.pop('quantizer')
78
+ encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
79
+ quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
80
+ frame_rate = kwargs['sample_rate'] // encoder.hop_length
81
+ renormalize = kwargs.pop('renormalize', False)
82
+ # deprecated params
83
+ kwargs.pop('renorm', None)
84
+ return EncodecModel(encoder, decoder, quantizer,
85
+ frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
86
+ else:
87
+ raise KeyError(f"Unexpected compression model {cfg.compression_model}")
88
+
89
+
90
+ def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
91
+ """Instantiate a transformer LM."""
92
+
93
+ if cfg.lm_model == 'transformer_lm':
94
+ kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
95
+ n_q = kwargs['n_q']
96
+ q_modeling = kwargs.pop('q_modeling', None)
97
+ codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
98
+ attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
99
+ cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
100
+ cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
101
+ fuser = get_condition_fuser(cfg)
102
+ condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
103
+ if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programmatically
104
+ kwargs['cross_attention'] = True
105
+ if codebooks_pattern_cfg.modeling is None:
106
+ assert q_modeling is not None, \
107
+ "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
108
+ codebooks_pattern_cfg = omegaconf.OmegaConf.create(
109
+ {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
110
+ )
111
+
112
+ pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
113
+ return LMModel(
114
+ pattern_provider=pattern_provider,
115
+ condition_provider=condition_provider,
116
+
117
+ visual_encoder=cfg.video.visual_encoder,
118
+ if_add_gobal=cfg.video.add_global.if_add_gobal,
119
+
120
+ fuser=fuser,
121
+ cfg_dropout=cfg_prob,
122
+ cfg_coef=cfg_coef,
123
+ attribute_dropout=attribute_dropout,
124
+ dtype=getattr(torch, cfg.dtype),
125
+ device=cfg.device,
126
+ **kwargs
127
+ ).to(cfg.device)
128
+ else:
129
+ raise KeyError(f"Unexpected LM model {cfg.lm_model}")
130
+
131
+
132
+ def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
133
+ """Instantiate a conditioning model."""
134
+ device = cfg.device
135
+ duration = cfg.dataset.segment_duration
136
+ cfg = getattr(cfg, 'conditioners')
137
+ dict_cfg = {} if cfg is None else dict_from_config(cfg)
138
+ conditioners: tp.Dict[str, BaseConditioner] = {}
139
+ condition_provider_args = dict_cfg.pop('args', {})
140
+ condition_provider_args.pop('merge_text_conditions_p', None)
141
+ condition_provider_args.pop('drop_desc_p', None)
142
+
143
+ for cond, cond_cfg in dict_cfg.items():
144
+ model_type = cond_cfg['model']
145
+ model_args = cond_cfg[model_type]
146
+ if model_type == 't5':
147
+ conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
148
+ elif model_type == 'lut':
149
+ conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
150
+ elif model_type == 'chroma_stem':
151
+ conditioners[str(cond)] = ChromaStemConditioner(
152
+ output_dim=output_dim,
153
+ duration=duration,
154
+ device=device,
155
+ **model_args
156
+ )
157
+ elif model_type == 'clap':
158
+ conditioners[str(cond)] = CLAPEmbeddingConditioner(
159
+ output_dim=output_dim,
160
+ device=device,
161
+ **model_args
162
+ )
163
+ else:
164
+ raise ValueError(f"Unrecognized conditioning model: {model_type}")
165
+ conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
166
+ return conditioner
167
+
168
+
169
+ def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
170
+ """Instantiate a condition fuser object."""
171
+ fuser_cfg = getattr(cfg, 'fuser')
172
+ fuser_methods = ['sum', 'cross', 'prepend', 'input_interpolate']
173
+ fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
174
+ kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
175
+ fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
176
+ return fuser
177
+
178
+
179
+ def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
180
+ """Instantiate a codebooks pattern provider object."""
181
+ pattern_providers = {
182
+ 'parallel': ParallelPatternProvider,
183
+ 'delay': DelayedPatternProvider,
184
+ 'unroll': UnrolledPatternProvider,
185
+ 'coarse_first': CoarseFirstPattern,
186
+ 'musiclm': MusicLMPattern,
187
+ }
188
+ name = cfg.modeling
189
+ kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
190
+ klass = pattern_providers[name]
191
+ return klass(n_q, **kwargs)
192
+
193
+
194
+ def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
195
+ """Instantiate a debug compression model to be used for unit tests."""
196
+ assert sample_rate in [16000, 32000], "unsupported sample rate for debug compression model"
197
+ model_ratios = {
198
+ 16000: [10, 8, 8], # 25 Hz at 16kHz
199
+ 32000: [10, 8, 16] # 25 Hz at 32kHz
200
+ }
201
+ ratios: tp.List[int] = model_ratios[sample_rate]
202
+ frame_rate = 25
203
+ seanet_kwargs: dict = {
204
+ 'n_filters': 4,
205
+ 'n_residual_layers': 1,
206
+ 'dimension': 32,
207
+ 'ratios': ratios,
208
+ }
209
+ encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
210
+ decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
211
+ quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
212
+ init_x = torch.randn(8, 32, 128)
213
+ quantizer(init_x, 1) # initialize kmeans etc.
214
+ compression_model = EncodecModel(
215
+ encoder, decoder, quantizer,
216
+ frame_rate=frame_rate, sample_rate=sample_rate, channels=1).to(device)
217
+ return compression_model.eval()
218
+
219
+
220
+ def get_diffusion_model(cfg: omegaconf.DictConfig):
221
+ # TODO Find a way to infer the channels from dset
222
+ channels = cfg.channels
223
+ num_steps = cfg.schedule.num_steps
224
+ return DiffusionUnet(
225
+ chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
226
+
227
+
228
+ def get_processor(cfg, sample_rate: int = 24000):
229
+ sample_processor = SampleProcessor()
230
+ if cfg.use:
231
+ kw = dict(cfg)
232
+ kw.pop('use')
233
+ kw.pop('name')
234
+ if cfg.name == "multi_band_processor":
235
+ sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw)
236
+ return sample_processor
237
+
238
+
239
+ def get_debug_lm_model(device='cpu'):
240
+ """Instantiate a debug LM to be used for unit tests."""
241
+ pattern = DelayedPatternProvider(n_q=4)
242
+ dim = 16
243
+ providers = {
244
+ 'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"),
245
+ }
246
+ condition_provider = ConditioningProvider(providers)
247
+ fuser = ConditionFuser(
248
+ {'cross': ['description'], 'prepend': [],
249
+ 'sum': [], 'input_interpolate': []})
250
+ lm = LMModel(
251
+ pattern, condition_provider, fuser,
252
+ n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
253
+ cross_attention=True, causal=True)
254
+ return lm.to(device).eval()
255
+
256
+
257
+ def get_wrapped_compression_model(
258
+ compression_model: CompressionModel,
259
+ cfg: omegaconf.DictConfig) -> CompressionModel:
260
+ if hasattr(cfg, 'interleave_stereo_codebooks'):
261
+ if cfg.interleave_stereo_codebooks.use:
262
+ kwargs = dict_from_config(cfg.interleave_stereo_codebooks)
263
+ kwargs.pop('use')
264
+ compression_model = InterleaveStereoCompressionModel(compression_model, **kwargs)
265
+ if hasattr(cfg, 'compression_model_n_q'):
266
+ if cfg.compression_model_n_q is not None:
267
+ compression_model.set_num_codebooks(cfg.compression_model_n_q)
268
+ return compression_model
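The builders above are driven by an OmegaConf config. A minimal usage sketch, assuming the repository ships a Hydra/OmegaConf config that defines the `compression_model`, `encodec`, `lm_model`, `transformer_lm` and `video` sections; the config path below is a hypothetical placeholder, not a file guaranteed to exist:

```python
# Minimal sketch, assuming a config file with the sections consumed by the builders;
# the path below is a hypothetical placeholder.
import omegaconf
from audiocraft.models import builders

cfg = omegaconf.OmegaConf.load("config/vidmuse.yaml")  # hypothetical path
cfg.device = "cuda"  # the builders move the instantiated modules to cfg.device

compression_model = builders.get_compression_model(cfg)  # EnCodec audio tokenizer
lm = builders.get_lm_model(cfg)                          # video-conditioned transformer LM
print(compression_model.frame_rate, lm.num_codebooks)
```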
audiocraft/models/encodec.py ADDED
@@ -0,0 +1,580 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Compression models or wrapper around existing models.
7
+ Also defines the main interface that a model must follow to be usable as an audio tokenizer.
8
+ """
9
+
10
+ from abc import ABC, abstractmethod
11
+ import logging
12
+ import math
13
+ from pathlib import Path
14
+ import typing as tp
15
+
16
+ from einops import rearrange
17
+ import numpy as np
18
+ import torch
19
+ from torch import nn
20
+ from transformers import EncodecModel as HFEncodecModel
21
+ from .. import quantization as qt
22
+ # from semanticodec import SemantiCodec
23
+
24
+ logger = logging.getLogger()
25
+
26
+
27
+ class CompressionModel(ABC, nn.Module):
28
+ """Base API for all compression models that aim at being used as audio tokenizers
29
+ with a language model.
30
+ """
31
+
32
+ @abstractmethod
33
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
34
+ ...
35
+
36
+ @abstractmethod
37
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
38
+ """See `EncodecModel.encode`."""
39
+ ...
40
+
41
+ @abstractmethod
42
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
43
+ """See `EncodecModel.decode`."""
44
+ ...
45
+
46
+ @abstractmethod
47
+ def decode_latent(self, codes: torch.Tensor):
48
+ """Decode from the discrete codes to continuous latent space."""
49
+ ...
50
+
51
+ @property
52
+ @abstractmethod
53
+ def channels(self) -> int:
54
+ ...
55
+
56
+ @property
57
+ @abstractmethod
58
+ def frame_rate(self) -> float:
59
+ ...
60
+
61
+ @property
62
+ @abstractmethod
63
+ def sample_rate(self) -> int:
64
+ ...
65
+
66
+ @property
67
+ @abstractmethod
68
+ def cardinality(self) -> int:
69
+ ...
70
+
71
+ @property
72
+ @abstractmethod
73
+ def num_codebooks(self) -> int:
74
+ ...
75
+
76
+ @property
77
+ @abstractmethod
78
+ def total_codebooks(self) -> int:
79
+ ...
80
+
81
+ @abstractmethod
82
+ def set_num_codebooks(self, n: int):
83
+ """Set the active number of codebooks used by the quantizer."""
84
+ ...
85
+
86
+ @staticmethod
87
+ def get_pretrained(
88
+ name: str, device: tp.Union[torch.device, str] = 'cpu'
89
+ ) -> 'CompressionModel':
90
+ """Instantiate a CompressionModel from a given pretrained model.
91
+
92
+ Args:
93
+ name (Path or str): name of the pretrained model. See after.
94
+ device (torch.device or str): Device on which the model is loaded.
95
+
96
+ Pretrained models:
97
+ - dac_44khz (https://github.com/descriptinc/descript-audio-codec)
98
+ - dac_24khz (same)
99
+ - facebook/encodec_24khz (https://huggingface.co/facebook/encodec_24khz)
100
+ - facebook/encodec_32khz (https://huggingface.co/facebook/encodec_32khz)
101
+ - your own model on Hugging Face. Export instructions to come...
102
+ """
103
+
104
+ from . import builders, loaders
105
+ model: CompressionModel
106
+ if name in ['dac_44khz', 'dac_24khz']:
107
+ model_type = name.split('_')[1]
108
+ logger.info("Getting pretrained compression model from DAC %s", model_type)
109
+ model = DAC(model_type)
110
+
111
+ elif name in ['semantic_16khz']:
112
+ model_type = name.split('_')[1]
113
+ logger.info("Getting pretrained compression model from Semantic Codec %s", model_type)
114
+ model = Semantic_Codec(model_type)
115
+
116
+ elif name in ['debug_compression_model']:
117
+ logger.info("Getting pretrained compression model for debug")
118
+ model = builders.get_debug_compression_model()
119
+ elif Path(name).exists():
120
+ # We assume here if the path exists that it is in fact an AC checkpoint
121
+ # that was exported using `audiocraft.utils.export` functions.
122
+ model = loaders.load_compression_model(name, device=device)
123
+ else:
124
+ logger.info("Getting pretrained compression model from HF %s", name)
125
+ hf_model = HFEncodecModel.from_pretrained(name)
126
+ model = HFEncodecCompressionModel(hf_model).to(device)
127
+ return model.to(device).eval()
128
+
129
+
130
+ class EncodecModel(CompressionModel):
131
+ """Encodec model operating on the raw waveform.
132
+
133
+ Args:
134
+ encoder (nn.Module): Encoder network.
135
+ decoder (nn.Module): Decoder network.
136
+ quantizer (qt.BaseQuantizer): Quantizer network.
137
+ frame_rate (int): Frame rate for the latent representation.
138
+ sample_rate (int): Audio sample rate.
139
+ channels (int): Number of audio channels.
140
+ causal (bool): Whether to use a causal version of the model.
141
+ renormalize (bool): Whether to renormalize the audio before running the model.
142
+ """
143
+ # we need assignment to override the property in the abstract class,
144
+ # I couldn't find a better way...
145
+ frame_rate: float = 0
146
+ sample_rate: int = 0
147
+ channels: int = 0
148
+
149
+ def __init__(self,
150
+ encoder: nn.Module,
151
+ decoder: nn.Module,
152
+ quantizer: qt.BaseQuantizer,
153
+ frame_rate: int,
154
+ sample_rate: int,
155
+ channels: int,
156
+ causal: bool = False,
157
+ renormalize: bool = False):
158
+ super().__init__()
159
+ self.encoder = encoder
160
+ self.decoder = decoder
161
+ self.quantizer = quantizer
162
+ self.frame_rate = frame_rate
163
+ self.sample_rate = sample_rate
164
+ self.channels = channels
165
+ self.renormalize = renormalize
166
+ self.causal = causal
167
+ if self.causal:
168
+ # we force disabling here to avoid handling linear overlap of segments
169
+ # as supported in original EnCodec codebase.
170
+ assert not self.renormalize, 'Causal model does not support renormalize'
171
+
172
+ @property
173
+ def total_codebooks(self):
174
+ """Total number of quantizer codebooks available."""
175
+ return self.quantizer.total_codebooks
176
+
177
+ @property
178
+ def num_codebooks(self):
179
+ """Active number of codebooks used by the quantizer."""
180
+ return self.quantizer.num_codebooks
181
+
182
+ def set_num_codebooks(self, n: int):
183
+ """Set the active number of codebooks used by the quantizer."""
184
+ self.quantizer.set_num_codebooks(n)
185
+
186
+ @property
187
+ def cardinality(self):
188
+ """Cardinality of each codebook."""
189
+ return self.quantizer.bins
190
+
191
+ def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
192
+ scale: tp.Optional[torch.Tensor]
193
+ if self.renormalize:
194
+ mono = x.mean(dim=1, keepdim=True)
195
+ volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
196
+ scale = 1e-8 + volume
197
+ x = x / scale
198
+ scale = scale.view(-1, 1)
199
+ else:
200
+ scale = None
201
+ return x, scale
202
+
203
+ def postprocess(self,
204
+ x: torch.Tensor,
205
+ scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
206
+ if scale is not None:
207
+ assert self.renormalize
208
+ x = x * scale.view(-1, 1, 1)
209
+ return x
210
+
211
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
212
+ assert x.dim() == 3
213
+ length = x.shape[-1]
214
+ x, scale = self.preprocess(x)
215
+
216
+ emb = self.encoder(x)
217
+ q_res = self.quantizer(emb, self.frame_rate)
218
+ out = self.decoder(q_res.x)
219
+
220
+ # remove extra padding added by the encoder and decoder
221
+ assert out.shape[-1] >= length, (out.shape[-1], length)
222
+ out = out[..., :length]
223
+
224
+ q_res.x = self.postprocess(out, scale)
225
+
226
+ return q_res
227
+
228
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
229
+ """Encode the given input tensor to quantized representation along with scale parameter.
230
+
231
+ Args:
232
+ x (torch.Tensor): Float tensor of shape [B, C, T]
233
+
234
+ Returns:
235
+ codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
236
+ codes: a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
237
+ scale: a float tensor containing the scale for audio renormalization.
238
+ """
239
+ assert x.dim() == 3
240
+ x, scale = self.preprocess(x)
241
+ emb = self.encoder(x)
242
+ codes = self.quantizer.encode(emb)
243
+ return codes, scale
244
+
245
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
246
+ """Decode the given codes to a reconstructed representation, using the scale to perform
247
+ audio denormalization if needed.
248
+
249
+ Args:
250
+ codes (torch.Tensor): Int tensor of shape [B, K, T]
251
+ scale (torch.Tensor, optional): Float tensor containing the scale value.
252
+
253
+ Returns:
254
+ out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
255
+ """
256
+ emb = self.decode_latent(codes)
257
+ out = self.decoder(emb)
258
+ out = self.postprocess(out, scale)
259
+ # out contains extra padding added by the encoder and decoder
260
+ return out
261
+
262
+ def decode_latent(self, codes: torch.Tensor):
263
+ """Decode from the discrete codes to continuous latent space."""
264
+ return self.quantizer.decode(codes)
265
+
266
+
267
+ class DAC(CompressionModel):
268
+ def __init__(self, model_type: str = "44khz"):
269
+ super().__init__()
270
+ try:
271
+ import dac.utils
272
+ except ImportError:
273
+ raise RuntimeError("Could not import dac, make sure it is installed, "
274
+ "please run `pip install descript-audio-codec`")
275
+ self.model = dac.utils.load_model(model_type=model_type)
276
+ self.n_quantizers = self.total_codebooks
277
+ self.model.eval()
278
+
279
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
280
+ # We don't support training with this.
281
+ raise NotImplementedError("Forward and training with DAC not supported.")
282
+
283
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
284
+ codes = self.model.encode(x, self.n_quantizers)[1]  # x: [B, 1, T_samples] -> codes: [B, n_quantizers, T_frames]
285
+
286
+ return codes[:, :self.n_quantizers], None
287
+
288
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
289
+ assert scale is None
290
+ z_q = self.decode_latent(codes)
291
+ return self.model.decode(z_q)
292
+
293
+ def decode_latent(self, codes: torch.Tensor):
294
+ """Decode from the discrete codes to continuous latent space."""
295
+ return self.model.quantizer.from_codes(codes)[0]
296
+
297
+ @property
298
+ def channels(self) -> int:
299
+ return 1
300
+
301
+ @property
302
+ def frame_rate(self) -> float:
303
+ return self.model.sample_rate / self.model.hop_length
304
+
305
+ @property
306
+ def sample_rate(self) -> int:
307
+ return self.model.sample_rate
308
+
309
+ @property
310
+ def cardinality(self) -> int:
311
+ return self.model.codebook_size
312
+
313
+ @property
314
+ def num_codebooks(self) -> int:
315
+ return self.n_quantizers
316
+
317
+ @property
318
+ def total_codebooks(self) -> int:
319
+ return self.model.n_codebooks
320
+
321
+ def set_num_codebooks(self, n: int):
322
+ """Set the active number of codebooks used by the quantizer.
323
+ """
324
+ assert n >= 1
325
+ assert n <= self.total_codebooks
326
+ self.n_quantizers = n
327
+
328
+
329
+
330
+
331
+ class Semantic_Codec(CompressionModel):
332
+ def __init__(self, model_type: str = "16khz"):
333
+ super().__init__()
334
+ try:
335
+ from semanticodec import SemantiCodec
336
+ except ImportError:
337
+ raise RuntimeError("Could not import semanticodec, make sure it is installed, "
338
+ "please run `pip install git+https://github.com/haoheliu/SemantiCodec-inference.git`")
339
+ self.model = SemantiCodec(token_rate=100, semantic_vocab_size=16384)
340
+ # self.n_quantizers = self.total_codebooks
341
+ self.n_quantizers = 2
342
+ self.model.sample_rate = 16000
343
+ self.model.cardinality = 16384
344
+ self.model.frame_rate = 50
345
+ self.model.eval()
346
+
347
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
348
+ # We don't support training with this.
349
+ raise NotImplementedError("Forward and training with SemantiCodec not supported.")
350
+
351
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
352
+ codes = self.model.encode(x)
353
+ # codes = self.model.encode(x, self.n_quantizers)[1]
354
+ return codes[:, :self.n_quantizers], None
355
+
356
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
357
+ assert scale is None
358
+ z_q = self.decode_latent(codes)
359
+ return self.model.decode(z_q)
360
+
361
+ def decode_latent(self, codes: torch.Tensor):
362
+ """Decode from the discrete codes to continuous latent space."""
363
+ return self.model.quantizer.from_codes(codes)[0]
364
+
365
+ @property
366
+ def channels(self) -> int:
367
+ return 1
368
+
369
+ @property
370
+ def frame_rate(self) -> float:
371
+ return self.model.frame_rate
372
+ # return self.model.sample_rate / self.model.hop_length
373
+
374
+ @property
375
+ def sample_rate(self) -> int:
376
+ return self.model.sample_rate
377
+
378
+ @property
379
+ def cardinality(self) -> int:
380
+ return self.model.cardinality
381
+
382
+ @property
383
+ def num_codebooks(self) -> int:
384
+ return self.n_quantizers
385
+
386
+ @property
387
+ def total_codebooks(self) -> int:
388
+ return self.model.n_codebooks
389
+
390
+ def set_num_codebooks(self, n: int):
391
+ """Set the active number of codebooks used by the quantizer.
392
+ """
393
+ assert n >= 1
394
+ assert n <= self.total_codebooks
395
+ self.n_quantizers = n
396
+
397
+ class HFEncodecCompressionModel(CompressionModel):
398
+ """Wrapper around HuggingFace Encodec.
399
+ """
400
+ def __init__(self, model: HFEncodecModel):
401
+ super().__init__()
402
+ self.model = model
403
+ bws = self.model.config.target_bandwidths
404
+ num_codebooks = [
405
+ bw * 1000 / (self.frame_rate * math.log2(self.cardinality))
406
+ for bw in bws
407
+ ]
408
+ deltas = [nc - int(nc) for nc in num_codebooks]
409
+ # Checking we didn't do some bad maths and we indeed have integers!
410
+ assert all(delta <= 1e-3 for delta in deltas), deltas
411
+ self.possible_num_codebooks = [int(nc) for nc in num_codebooks]
412
+ self.set_num_codebooks(max(self.possible_num_codebooks))
413
+
414
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
415
+ # We don't support training with this.
416
+ raise NotImplementedError("Forward and training with HF EncodecModel not supported.")
417
+
418
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
419
+ bandwidth_index = self.possible_num_codebooks.index(self.num_codebooks)
420
+ bandwidth = self.model.config.target_bandwidths[bandwidth_index]
421
+ res = self.model.encode(x, None, bandwidth)
422
+ assert len(res[0]) == 1
423
+ assert len(res[1]) == 1
424
+ return res[0][0], res[1][0]
425
+
426
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
427
+ if scale is None:
428
+ scales = [None] # type: ignore
429
+ else:
430
+ scales = scale # type: ignore
431
+ res = self.model.decode(codes[None], scales)
432
+ return res[0]
433
+
434
+ def decode_latent(self, codes: torch.Tensor):
435
+ """Decode from the discrete codes to continuous latent space."""
436
+ return self.model.quantizer.decode(codes.transpose(0, 1))
437
+
438
+ @property
439
+ def channels(self) -> int:
440
+ return self.model.config.audio_channels
441
+
442
+ @property
443
+ def frame_rate(self) -> float:
444
+ hop_length = int(np.prod(self.model.config.upsampling_ratios))
445
+ return self.sample_rate / hop_length
446
+
447
+ @property
448
+ def sample_rate(self) -> int:
449
+ return self.model.config.sampling_rate
450
+
451
+ @property
452
+ def cardinality(self) -> int:
453
+ return self.model.config.codebook_size
454
+
455
+ @property
456
+ def num_codebooks(self) -> int:
457
+ return self._num_codebooks
458
+
459
+ @property
460
+ def total_codebooks(self) -> int:
461
+ return max(self.possible_num_codebooks)
462
+
463
+ def set_num_codebooks(self, n: int):
464
+ """Set the active number of codebooks used by the quantizer.
465
+ """
466
+ if n not in self.possible_num_codebooks:
467
+ raise ValueError(f"Allowed values for num codebooks: {self.possible_num_codebooks}")
468
+ self._num_codebooks = n
469
+
470
+
471
+ class InterleaveStereoCompressionModel(CompressionModel):
472
+ """Wraps a CompressionModel to support stereo inputs. The wrapped model
473
+ will be applied independently to the left and right channels, and both codebooks
474
+ will be interleaved. If the wrapped model returns a representation `[B, K ,T]` per
475
+ channel, then the output will be `[B, K * 2, T]` or `[B, K, T * 2]` depending on
476
+ `per_timestep`.
477
+
478
+ Args:
479
+ model (CompressionModel): Compression model to wrap.
480
+ per_timestep (bool): Whether to interleave on the timestep dimension
481
+ or on the codebooks dimension.
482
+ """
483
+ def __init__(self, model: CompressionModel, per_timestep: bool = False):
484
+ super().__init__()
485
+ self.model = model
486
+ self.per_timestep = per_timestep
487
+ assert self.model.channels == 1, "Wrapped model is expected to be for monophonic audio"
488
+
489
+ @property
490
+ def total_codebooks(self):
491
+ return self.model.total_codebooks
492
+
493
+ @property
494
+ def num_codebooks(self):
495
+ """Active number of codebooks used by the quantizer.
496
+
497
+ ..Warning:: this reports the number of codebooks after the interleaving
498
+ of the codebooks!
499
+ """
500
+ return self.model.num_codebooks if self.per_timestep else self.model.num_codebooks * 2
501
+
502
+ def set_num_codebooks(self, n: int):
503
+ """Set the active number of codebooks used by the quantizer.
504
+
505
+ ..Warning:: this sets the number of codebooks before the interleaving!
506
+ """
507
+ self.model.set_num_codebooks(n)
508
+
509
+ @property
510
+ def num_virtual_steps(self) -> float:
511
+ """Return the number of virtual steps, e.g. one real step
512
+ will be split into that many steps.
513
+ """
514
+ return 2 if self.per_timestep else 1
515
+
516
+ @property
517
+ def frame_rate(self) -> float:
518
+ return self.model.frame_rate * self.num_virtual_steps
519
+
520
+ @property
521
+ def sample_rate(self) -> int:
522
+ return self.model.sample_rate
523
+
524
+ @property
525
+ def channels(self) -> int:
526
+ return 2
527
+
528
+ @property
529
+ def cardinality(self):
530
+ """Cardinality of each codebook.
531
+ """
532
+ return self.model.cardinality
533
+
534
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
535
+ raise NotImplementedError("Not supported, use encode and decode.")
536
+
537
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
538
+ B, C, T = x.shape
539
+ assert C == self.channels, f"Expecting stereo audio but audio num channels is {C}"
540
+
541
+ indices_c0, scales_c0 = self.model.encode(x[:, 0, ...].unsqueeze(1))
542
+ indices_c1, scales_c1 = self.model.encode(x[:, 1, ...].unsqueeze(1))
543
+ indices = torch.stack([indices_c0, indices_c1], dim=0)
544
+ scales: tp.Optional[torch.Tensor] = None
545
+ if scales_c0 is not None and scales_c1 is not None:
546
+ scales = torch.stack([scales_c0, scales_c1], dim=1)
547
+
548
+ if self.per_timestep:
549
+ indices = rearrange(indices, 'c b k t -> b k (t c)', c=2)
550
+ else:
551
+ indices = rearrange(indices, 'c b k t -> b (k c) t', c=2)
552
+
553
+ return (indices, scales)
554
+
555
+ def get_left_right_codes(self, codes: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
556
+ if self.per_timestep:
557
+ codes = rearrange(codes, 'b k (t c) -> c b k t', c=2)
558
+ else:
559
+ codes = rearrange(codes, 'b (k c) t -> c b k t', c=2)
560
+ return codes[0], codes[1]
561
+
562
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
563
+ B, K, T = codes.shape
564
+ assert T % self.num_virtual_steps == 0, "Provided codes' number of timesteps does not match"
565
+ assert K == self.num_codebooks, "Provided codes' number of codebooks does not match"
566
+
567
+ scale_c0, scale_c1 = None, None
568
+ if scale is not None:
569
+ assert scale.size(0) == B and scale.size(1) == 2, f"Scale has unexpected shape: {scale.shape}"
570
+ scale_c0 = scale[0, ...]
571
+ scale_c1 = scale[1, ...]
572
+
573
+ codes_c0, codes_c1 = self.get_left_right_codes(codes)
574
+ audio_c0 = self.model.decode(codes_c0, scale_c0)
575
+ audio_c1 = self.model.decode(codes_c1, scale_c1)
576
+ return torch.cat([audio_c0, audio_c1], dim=1)
577
+
578
+ def decode_latent(self, codes: torch.Tensor):
579
+ """Decode from the discrete codes to continuous latent space."""
580
+ raise NotImplementedError("Not supported by interleaved stereo wrapped models.")
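To illustrate the tokenizer interface defined in this file, here is a minimal round-trip sketch using `CompressionModel.get_pretrained`, `encode` and `decode`. The pretrained name is one of those listed in the `get_pretrained` docstring; fetching it requires network access to Hugging Face.

```python
# Minimal round-trip sketch for the CompressionModel interface above.
# Assumes the facebook/encodec_32khz checkpoint can be downloaded.
import torch
from audiocraft.models.encodec import CompressionModel

model = CompressionModel.get_pretrained("facebook/encodec_32khz", device="cpu")
wav = torch.randn(1, model.channels, model.sample_rate)  # one second of dummy audio [B, C, T]
with torch.no_grad():
    codes, scale = model.encode(wav)    # codes: [B, K, T_frames]
    recon = model.decode(codes, scale)  # recon: [B, C, ~T]
print(codes.shape, recon.shape)
```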
audiocraft/models/lm.py ADDED
@@ -0,0 +1,685 @@
+ # Modified from Audiocraft (https://github.com/facebookresearch/audiocraft)
2
+
3
+ from dataclasses import dataclass
4
+ from functools import partial
5
+ import logging
6
+ import math
7
+ import typing as tp
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from ..utils import utils
13
+ from ..modules.streaming import StreamingModule, State
14
+ from ..modules.transformer import StreamingTransformer, create_norm_fn
15
+
16
+ import time
17
+ from ..modules.conditioners import (
18
+ ConditionFuser,
19
+ ClassifierFreeGuidanceDropout,
20
+ AttributeDropout,
21
+ ConditioningProvider,
22
+ ConditioningAttributes,
23
+ ConditionType,
24
+ )
25
+ from ..modules.codebooks_patterns import CodebooksPatternProvider
26
+ from ..modules.activations import get_activation_fn
27
+ import warnings
28
+ warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.transforms._transforms_video")
29
+ import torch.nn.init as init
30
+ import os
31
+
32
+ import logging
33
+ import random
34
+ import sys
35
+ import einops
36
+ from .transformer_module import Attention, PreNorm, FeedForward
37
+ from transformers import AutoProcessor, CLIPVisionModelWithProjection, VideoMAEModel
38
+
39
+ logger = logging.getLogger(__name__)
40
+ ConditionTensors = tp.Dict[str, ConditionType]
41
+ CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
42
+
43
+ def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
44
+ """LM layer initialization.
45
+ Inspired from xlformers: https://github.com/fairinternal/xlformers
46
+
47
+ Args:
48
+ method (str): Method name for init function. Valid options are:
49
+ 'gaussian', 'uniform'.
50
+ input_dim (int): Input dimension of the initialized module.
51
+ init_depth (int, optional): Optional init depth value used to rescale
52
+ the standard deviation if defined.
53
+ """
54
+ # Compute std
55
+ std = 1 / math.sqrt(input_dim)
56
+ # Rescale with depth
57
+ if init_depth is not None:
58
+ std = std / math.sqrt(2 * init_depth)
59
+
60
+ if method == 'gaussian':
61
+ return partial(
62
+ torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
63
+ )
64
+ elif method == 'uniform':
65
+ bound = math.sqrt(3) * std # ensure the standard deviation is std
66
+ return partial(torch.nn.init.uniform_, a=-bound, b=bound)
67
+ else:
68
+ raise ValueError("Unsupported layer initialization method")
69
+
70
+
71
+ def init_layer(m: nn.Module,
72
+ method: str,
73
+ init_depth: tp.Optional[int] = None,
74
+ zero_bias_init: bool = False):
75
+ """Wrapper around `get_init_fn` for proper initialization of LM modules.
76
+
77
+ Args:
78
+ m (nn.Module): Module to initialize.
79
+ method (str): Method name for the init function.
80
+ init_depth (int, optional): Optional init depth value used to rescale
81
+ the standard deviation if defined.
82
+ zero_bias_init (bool): Whether to initialize the bias to 0 or not.
83
+ """
84
+ if isinstance(m, nn.Linear):
85
+ init_fn = get_init_fn(method, m.in_features, init_depth=init_depth)
86
+ if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
87
+ weight = m.weight.float()
88
+ init_fn(weight)
89
+ m.weight.data[:] = weight.half()
90
+ else:
91
+ init_fn(m.weight)
92
+ if zero_bias_init and m.bias is not None:
93
+ nn.init.constant_(m.bias, 0)
94
+ elif isinstance(m, nn.Embedding):
95
+ init_fn = get_init_fn(method, m.embedding_dim, init_depth=None)
96
+ if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
97
+ weight = m.weight.float()
98
+ init_fn(weight)
99
+ m.weight.data[:] = weight.half()
100
+ else:
101
+ init_fn(m.weight)
102
+
103
+
104
+ class ScaledEmbedding(nn.Embedding):
105
+ """Boost learning rate for embeddings (with scale).
106
+ """
107
+ def __init__(self, *args, lr=None, **kwargs):
108
+ super().__init__(*args, **kwargs)
109
+ self.lr = lr
110
+
111
+ def make_optim_group(self):
112
+ group = {"params": list(self.parameters())}
113
+ if self.lr is not None:
114
+ group["lr"] = self.lr
115
+ return group
116
+
117
+
118
+ @dataclass
119
+ class LMOutput:
120
+ # The logits are already re-aligned with the input codes
121
+ # hence no extra shift is required, e.g. when computing CE
122
+ logits: torch.Tensor # [B, K, T, card]
123
+ mask: torch.Tensor # [B, K, T]
124
+
125
+
126
+ class Transformer(nn.Module):
127
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
128
+ super().__init__()
129
+ self.layers = nn.ModuleList([])
130
+ self.norm = nn.LayerNorm(dim)
131
+ for _ in range(depth):
132
+ self.layers.append(nn.ModuleList([
133
+ PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
134
+ PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
135
+ ]))
136
+
137
+ def forward(self, x):
138
+ for attn, ff in self.layers:
139
+ x = attn(x) + x
140
+ x = ff(x) + x
141
+ return self.norm(x)
142
+
143
+
144
+ class MultiHeadCrossAttention(nn.Module):
145
+ def __init__(self, x1, num_heads):
146
+ super().__init__()
147
+ self.num_heads = num_heads
148
+ self.depth = x1 // num_heads
149
+
150
+ self.query = nn.Linear(x1, x1)
151
+ self.key = nn.Linear(x1, x1)
152
+ self.value = nn.Linear(x1, x1)
153
+
154
+ self.final_linear = nn.Linear(x1, x1)
155
+
156
+ self.norm1 = nn.LayerNorm(x1)
157
+ self.norm2 = nn.LayerNorm(x1)
158
+
159
+ init.constant_(self.final_linear.weight, 0)
160
+ if self.final_linear.bias is not None:
161
+ init.constant_(self.final_linear.bias, 0)
162
+
163
+ def split_heads(self, x, batch_size):
164
+ x = x.view(batch_size, -1, self.num_heads, self.depth)
165
+ return x.permute(0, 2, 1, 3)
166
+
167
+ def forward(self, tensor_A, tensor_B):
168
+ batch_size = tensor_A.size(0)
169
+
170
+ Q = self.split_heads(self.query(tensor_A), batch_size)
171
+ K = self.split_heads(self.key(tensor_B), batch_size)
172
+ V = self.split_heads(self.value(tensor_B), batch_size)
173
+
174
+ attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.depth ** 0.5)
175
+ attention_scores = torch.softmax(attention_scores, dim=-1)
176
+
177
+ attention_output = torch.matmul(attention_scores, V)
178
+ attention_output = attention_output.permute(0, 2, 1, 3).contiguous()
179
+
180
+ output = attention_output.view(batch_size, -1, self.num_heads * self.depth)
181
+
182
+ output = self.norm1(output + tensor_A)
183
+ output = self.norm2(self.final_linear(output) + output)
184
+ return output
185
+
186
+
187
+ def evenly_sample_or_duplicate_frames(video_tensor, target_frames=32):
188
+ num_frames = video_tensor.size(0)
189
+ if target_frames <= num_frames:
190
+ indices = torch.linspace(0, num_frames - 1, steps=target_frames).long()
191
+ return video_tensor[indices]
192
+ else:
193
+ scale_factor = target_frames / num_frames
194
+ repeated_indices = (torch.arange(target_frames) / scale_factor).long()
195
+ return video_tensor[repeated_indices]
196
+
197
+ class LMModel(StreamingModule):
198
+ """Transformer-based language model on multiple streams of codes.
199
+
200
+ Args:
201
+ pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
202
+ condition_provider (MusicConditioningProvider): Conditioning provider from metadata.
203
+ fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
204
+ n_q (int): Number of parallel streams to model.
205
+ card (int): Cardinality, vocabulary size.
206
+ dim (int): Dimension of the transformer encoder.
207
+ num_heads (int): Number of heads for the transformer encoder.
208
+ hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
209
+ norm (str): Normalization method.
210
+ norm_first (bool): Use pre-norm instead of post-norm.
211
+ emb_lr (float, optional): Embedding-specific learning rate.
212
+ bias_proj (bool): Use bias for output projections.
213
+ weight_init (str, optional): Method for weight initialization.
214
+ depthwise_init (str, optional): Method for depthwise weight initialization.
215
+ zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
216
+ cfg_dropout (float): Classifier-free guidance dropout.
217
+ cfg_coef (float): Classifier-free guidance coefficient.
218
+ attribute_dropout (dict): Attribute dropout probabilities.
219
+ two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
220
+ **kwargs: Additional parameters for the transformer encoder.
221
+ """
222
+
223
+ def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
224
+ visual_encoder,
225
+ if_add_gobal,
226
+ fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
227
+ hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
228
+ emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
229
+ weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
230
+ zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
231
+ attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False,
232
+ depth=2,
233
+ temporal_dim=768,
234
+ dim_head=64,
235
+ **kwargs):
236
+ super().__init__()
237
+ self.cfg_coef = cfg_coef
238
+ self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
239
+ self.att_dropout = AttributeDropout(p=attribute_dropout)
240
+ self.condition_provider = condition_provider
241
+ self.visual_encoder = visual_encoder
242
+ self.if_add_gobal = if_add_gobal
243
+ self.temporal_dim = temporal_dim
244
+
245
+ self.fuser = fuser
246
+ self.card = card
247
+ embed_dim = self.card + 1
248
+ self.n_q = n_q
249
+ self.dim = dim
250
+ self.pattern_provider = pattern_provider
251
+ self.two_step_cfg = two_step_cfg
252
+ self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
253
+ if 'activation' in kwargs:
254
+ kwargs['activation'] = get_activation_fn(kwargs['activation'])
255
+ self.transformer = StreamingTransformer(
256
+ d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
257
+ norm=norm, norm_first=norm_first, **kwargs)
258
+
259
+
260
+ self.out_norm: tp.Optional[nn.Module] = None
261
+ if norm_first:
262
+ self.out_norm = create_norm_fn(norm, dim)
263
+ self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
264
+ self._init_weights(weight_init, depthwise_init, zero_bias_init)
265
+ self._fsdp: tp.Optional[nn.Module]
266
+ self.__dict__['_fsdp'] = None
267
+
268
+ if self.visual_encoder == 'clip':
269
+ self.visual_encoder_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
270
+ self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
271
+
272
+ else:
273
+ raise ValueError(f'Unsupported visual encoder: {self.visual_encoder}. Please configure a supported video encoder (currently only clip).')
276
+
277
+ if self.visual_encoder == 'clip':
278
+ temporal_dim = 768
279
+ self.local_pos_embedding = nn.Parameter(torch.randn(1, 50, temporal_dim))
280
+ self.visual_encoder_model = self.visual_encoder_model.eval()
281
+ for param in self.visual_encoder_model.parameters():
282
+ param.requires_grad = False
283
+
284
+ self.local_temporal_transformer = Transformer(temporal_dim, depth, num_heads, dim_head, temporal_dim*hidden_scale, 0.) # [768, 4, 16, 64, 768*4]
285
+
286
+ if self.if_add_gobal:
287
+ if self.visual_encoder == 'clip':
288
+ self.global_pos_embedding = nn.Parameter(torch.randn(1, 50, temporal_dim))
289
+
290
+ self.global_temporal_transformer = Transformer(temporal_dim, depth, num_heads, dim_head, temporal_dim*hidden_scale, 0.) # [768, 4, 16, 64, 768*4]
291
+
292
+ cross_attention_num_heads = 3 # MultiHeadCrossAttention
293
+ self.multi_head_cross_attention = MultiHeadCrossAttention(temporal_dim, cross_attention_num_heads)
294
+
295
+ self.visual_feature_proj = nn.Linear(temporal_dim, dim)
296
+
297
+
298
+ def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
299
+ """Initialization of the transformer module weights.
300
+
301
+ Args:
302
+ weight_init (str, optional): Weight initialization strategy. See `get_init_fn` for valid options.
303
+ depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
304
+ 'current' where the depth corresponds to the current layer index or 'global' where the total number
305
+ of layer is used as depth. If not set, no depthwise initialization strategy is used.
306
+ zero_bias_init (bool): Whether to initialize bias to zero or not.
307
+ """
308
+ assert depthwise_init is None or depthwise_init in ['current', 'global']
309
+ assert depthwise_init is None or weight_init is not None, \
310
+ "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
311
+ assert not zero_bias_init or weight_init is not None, \
312
+ "If 'zero_bias_init', a 'weight_init' method should be provided"
313
+
314
+ if weight_init is None:
315
+ return
316
+
317
+ for emb_layer in self.emb:
318
+ init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
319
+
320
+ for layer_idx, tr_layer in enumerate(self.transformer.layers):
321
+ depth = None
322
+ if depthwise_init == 'current':
323
+ depth = layer_idx + 1
324
+ elif depthwise_init == 'global':
325
+ depth = len(self.transformer.layers)
326
+ init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
327
+ tr_layer.apply(init_fn)
328
+
329
+ for linear in self.linears:
330
+ init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
331
+
332
+
333
+ @property
334
+ def special_token_id(self) -> int:
335
+ return self.card
336
+
337
+ @property
338
+ def num_codebooks(self) -> int:
339
+ return self.n_q
340
+
341
+ def compute_video_emb(self, video_tensor_list: tp.List, device: str) -> torch.Tensor:
342
+ assert isinstance(video_tensor_list, list)
343
+ assert self.if_add_gobal
344
+ assert len(video_tensor_list) == 2
345
+
346
+ [local_video_tensor, global_video_tensor] = video_tensor_list
347
+ local_image = local_video_tensor.to(dtype=torch.float32)
348
+ global_image = global_video_tensor.to(dtype=torch.float32)
349
+
350
+ local_batch_size, _, local_time_length, _, _ = local_image.size()
351
+ local_image = einops.rearrange(local_image, 'b c t h w -> (b t) c h w')
352
+
353
+ global_batch_size, _, global_time_length, _, _ = global_image.size()
354
+ global_image = einops.rearrange(global_image, 'b c t h w -> (b t) c h w')
355
+
356
+ local_temporal_transformer = self.local_temporal_transformer
357
+ global_temporal_transformer = self.global_temporal_transformer
358
+
359
+ local_video_inputs = self.processor(images=local_image.float(), return_tensors="pt")
360
+ local_pixel_values = local_video_inputs['pixel_values'].to(device)
361
+
362
+ global_video_inputs = self.processor(images=global_image.float(), return_tensors="pt")
363
+ global_pixel_values = global_video_inputs['pixel_values'].to(device)
364
+
365
+ if self.visual_encoder == 'clip':
366
+ with torch.no_grad():
367
+ local_video_hidden = self.visual_encoder_model(pixel_values=local_pixel_values).last_hidden_state
368
+ local_video_hidden += self.local_pos_embedding
369
+ local_video_hidden = local_temporal_transformer(local_video_hidden)
370
+ local_video_hidden = einops.rearrange(
371
+ local_video_hidden, '(b t) q h -> b (t q) h',
372
+ b=local_batch_size, t=local_time_length
373
+ )
374
+
375
+ with torch.no_grad():
376
+ global_video_hidden = self.visual_encoder_model(pixel_values=global_pixel_values).last_hidden_state
377
+ global_video_hidden += self.global_pos_embedding
378
+ global_video_hidden = global_temporal_transformer(global_video_hidden)
379
+ global_video_hidden = einops.rearrange(
380
+ global_video_hidden, '(b t) q h -> b (t q) h',
381
+ b=global_batch_size, t=global_time_length
382
+ )
383
+
384
+ video_hidden = self.multi_head_cross_attention(local_video_hidden, global_video_hidden)
385
+ video_emb = self.visual_feature_proj(video_hidden)
386
+
387
+ return video_emb
388
+
389
+
390
+ def forward(self, sequence: torch.Tensor,
391
+ conditions: tp.List[ConditioningAttributes],
392
+ video_tensor_list: tp.List,
393
+ precomputed_video_emb: tp.Optional[torch.Tensor] = None  # newly added: optionally pass a pre-computed video embedding
394
+ ) -> torch.Tensor:
395
+
396
+ B, K, S = sequence.shape
397
+ assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks"
398
+ input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
399
+ self.device = input_.device
400
+ assert self.device != "cpu"
401
+
402
+ if precomputed_video_emb is None:
403
+ video_emb = self.compute_video_emb(video_tensor_list, device=self.device)
404
+ else:
405
+ video_emb = precomputed_video_emb
406
+
407
+ out = self.transformer(input_, cross_attention_src=video_emb)
408
+ if self.out_norm:
409
+ out = self.out_norm(out)
410
+ logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1)
411
+
412
+ if len(self.fuser.fuse2cond['prepend']) > 0:
413
+ logits = logits[:, :, -S:]
414
+ return logits # [B, K, S, card]
415
+
416
+
417
+ def compute_predictions(
418
+ self, codes: torch.Tensor,
419
+ conditions: tp.List[ConditioningAttributes],
420
+ condition_tensors_list: tp.List) -> LMOutput:
421
+ """Given an input tensor of codes [B, K, T] and list of conditions, runs the model
422
+ forward using the specified codes interleaving pattern.
423
+
424
+ Args:
425
+ codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
426
+ K the number of codebooks and T the number of timesteps.
427
+ conditions (list of ConditioningAttributes): Conditions to use when modeling
428
+ the given codes. Note that when evaluating multiple time with the same conditioning
429
+ you should pre-compute those and pass them as condition_tensors.
430
+ condition_tensors_list (list): Pre-computed conditioning
431
+ tensors, see conditions.
432
+ Returns:
433
+ LMOutput: Language model outputs
434
+ logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
435
+ i.e. the first item corresponds to logits to predict the first code, meaning that
436
+ no additional shifting of codes and logits is required.
437
+ mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
438
+ Given the specified interleaving strategies, parts of the logits and codes should
439
+ not be considered as valid predictions because of invalid context.
440
+ """
441
+ B, K, T = codes.shape
442
+ codes = codes.contiguous()
443
+
444
+ assert isinstance(condition_tensors_list, list)
445
+ pattern = self.pattern_provider.get_pattern(T)
446
+ sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
447
+ codes, self.special_token_id, keep_only_valid_steps=True
448
+ )
449
+
450
+ model = self if self._fsdp is None else self._fsdp
451
+ logits = model(sequence_codes, conditions, condition_tensors_list) # [B, K, S, card]
452
+
453
+
454
+ logits = logits.permute(0, 3, 1, 2) # [B, card, K, S]
455
+ logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
456
+ logits, float('nan'), keep_only_valid_steps=True
457
+ )
458
+ logits = logits.permute(0, 2, 3, 1) # [B, K, T, card]
459
+ logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T]
460
+ return LMOutput(logits, logits_mask)
461
+
462
+
463
+ def _sample_next_token(
464
+ self,
465
+ sequence: torch.Tensor,
466
+ cfg_conditions_list: tp.List,
467
+ unconditional_state: State,
468
+ use_sampling: bool = False,
469
+ temp: float = 1.0,
470
+ top_k: int = 0,
471
+ top_p: float = 0.0,
472
+ cfg_coef: tp.Optional[float] = None,
473
+ two_step_cfg: tp.Optional[bool] = None,
474
+ precomputed_video_emb: tp.Optional[torch.Tensor] = None  # newly added: optionally pass a pre-computed video embedding
475
+ ) -> torch.Tensor:
476
+ """Sample next token from the model given a sequence and a set of conditions. The model supports
477
+ multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
478
+
479
+ Args:
480
+ sequence (torch.Tensor): Current sequence of shape [B, K, S]
481
+ with K corresponding to the number of codebooks and S the number of sequence steps.
482
+ S = 1 in streaming mode, except for the first step that contains a bigger prompt.
483
+ cfg_conditions_list (list): Local and global conditions. If CFG is used,
484
+ should be twice the batch size, being the concatenation of the conditions + null conditions.
485
+ use_sampling (bool): Whether to use a sampling strategy or not.
486
+ temp (float): Sampling temperature.
487
+ top_k (int): K for "top-k" sampling.
488
+ top_p (float): P for "top-p" sampling.
489
+ cfg_coef (float, optional): classifier free guidance coefficient.
490
+ Returns:
491
+ next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
492
+ """
493
+ B = sequence.shape[0]
494
+ cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
495
+ model = self if self._fsdp is None else self._fsdp
496
+ two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
497
+
498
+ assert isinstance(cfg_conditions_list, list)
499
+ assert len(cfg_conditions_list) == 2
500
+ local_cfg_conditions = cfg_conditions_list[0]
501
+ global_cfg_conditions = cfg_conditions_list[1]
502
+
503
+ if two_step_cfg and local_cfg_conditions != {}:
504
+ assert isinstance(local_cfg_conditions, tuple), type(local_cfg_conditions)
505
+ local_condition_tensors, local_null_condition_tensors = local_cfg_conditions
506
+ global_condition_tensors, global_null_condition_tensors = global_cfg_conditions
507
+ cond_logits = model(sequence, conditions=[], condition_tensors=[local_condition_tensors, global_condition_tensors])
508
+
509
+ state = self.get_streaming_state()
510
+ self.set_streaming_state(unconditional_state)
511
+ uncond_logits = model(sequence, conditions=[], condition_tensors=[local_null_condition_tensors, global_null_condition_tensors])
512
+ unconditional_state.update(self.get_streaming_state())
513
+ self.set_streaming_state(state)
514
+ logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_coef
515
+ else:
516
+ local_condition_tensors = cfg_conditions_list[0].to(sequence.device)
517
+ global_condition_tensors = cfg_conditions_list[1].to(sequence.device)
518
+ sequence = torch.cat([sequence, sequence], dim=0)
519
+
520
+ if precomputed_video_emb is None:
521
+ video_emb = self.compute_video_emb([cfg_conditions_list[0], cfg_conditions_list[1]], device=sequence.device)
522
+ else:
523
+ video_emb = precomputed_video_emb
524
+
525
+ all_logits = model(
526
+ sequence,
527
+ conditions=[],
528
+ video_tensor_list=[],
529
+ precomputed_video_emb=video_emb
530
+ )
531
+ cond_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card]
532
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
533
+
534
+ logits = logits.permute(0, 1, 3, 2) # [B, K, card, T]
535
+ logits = logits[..., -1] # [B x K x card]
536
+
537
+ # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
538
+ if use_sampling and temp > 0.0:
539
+ probs = torch.softmax(logits / temp, dim=-1)
540
+ if top_p > 0.0:
541
+ next_token = utils.sample_top_p(probs, p=top_p)
542
+ elif top_k > 0:
543
+ next_token = utils.sample_top_k(probs, k=top_k)
544
+ else:
545
+ next_token = utils.multinomial(probs, num_samples=1)
546
+ else:
547
+ next_token = torch.argmax(logits, dim=-1, keepdim=True)
548
+ return next_token
549
+
550
+
551
+ @torch.no_grad()
552
+ def generate(self,
553
+ prompt: tp.Optional[torch.Tensor] = None,
554
+ conditions_list: tp.List = [],
555
+ num_samples: tp.Optional[int] = None,
556
+ max_gen_len: int = 256,
557
+ use_sampling: bool = True,
558
+ temp: float = 1.0,
559
+ top_k: int = 250,
560
+ top_p: float = 0.0,
561
+ cfg_coef: tp.Optional[float] = None,
562
+ two_step_cfg: tp.Optional[bool] = None,
563
+ remove_prompts: bool = False,
564
+ check: bool = False,
565
+ callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
566
+ """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
567
+ be performed in a greedy fashion or using sampling with top K and top P strategies.
568
+
569
+ Args:
570
+ prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
571
+ conditions_list (list, optional): List of local and global conditions.
572
+ num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
573
+ max_gen_len (int): Maximum generation length.
574
+ use_sampling (bool): Whether to use a sampling strategy or not.
575
+ temp (float): Sampling temperature.
576
+ top_k (int): K for "top-k" sampling.
577
+ top_p (float): P for "top-p" sampling.
578
+ cfg_coef (float, optional): Classifier-free guidance coefficient.
579
+ two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
580
+ remove_prompts (bool): Whether to remove prompts from generation or not.
581
+ check (bool): Whether to apply further checks on generated sequence.
582
+ callback (Callback, optional): Callback function to report generation progress.
583
+ Returns:
584
+ torch.Tensor: Generated tokens.
585
+ """
586
+ assert not self.training, "generation shouldn't be used in training mode."
587
+ first_param = next(iter(self.parameters()))
588
+ device = first_param.device
589
+ assert isinstance(conditions_list, list)
590
+
591
+ assert len(conditions_list) == 2
592
+ local_conditions = conditions_list[0]
593
+ global_conditions = conditions_list[1]
594
+ # Checking all input shapes are consistent.
595
+ possible_num_samples = []
596
+ if num_samples is not None:
597
+ possible_num_samples.append(num_samples)
598
+ elif prompt is not None:
599
+ possible_num_samples.append(prompt.shape[0])
600
+ elif local_conditions is not None:
601
+ possible_num_samples.append(len(local_conditions))
602
+ else:
603
+ possible_num_samples.append(1)
604
+
605
+ assert all(x == possible_num_samples[0] for x in possible_num_samples), "Inconsistent inputs shapes"
606
+ num_samples = possible_num_samples[0]
607
+
608
+ local_cfg_conditions: CFGConditions
609
+ global_cfg_conditions: CFGConditions
610
+ two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
611
+ local_null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(local_conditions)
612
+ local_cfg_conditions = torch.cat((local_conditions, local_null_conditions), dim=0)
613
+ global_null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(global_conditions)
614
+ global_cfg_conditions = torch.cat((global_conditions, global_null_conditions), dim=0)
615
+
616
+ if prompt is None:
617
+ assert num_samples > 0
618
+ prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
619
+
620
+ B, K, T = prompt.shape
621
+ start_offset = T
622
+ assert start_offset < max_gen_len
623
+
624
+ pattern = self.pattern_provider.get_pattern(max_gen_len)
625
+ # this token is used as default value for codes that are not generated yet
626
+ unknown_token = -1
627
+
628
+
629
+ gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
630
+ gen_codes[..., :start_offset] = prompt
631
+ gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
632
+ start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
633
+ assert start_offset_sequence is not None
634
+
635
+ video_emb = self.compute_video_emb([local_cfg_conditions, global_cfg_conditions], device=device)
636
+
637
+ with self.streaming():
638
+ unconditional_state = self.get_streaming_state()
639
+ prev_offset = 0
640
+ gen_sequence_len = gen_sequence.shape[-1]
641
+
642
+ for offset in range(start_offset_sequence, gen_sequence_len):
643
+ curr_sequence = gen_sequence[..., prev_offset:offset]
644
+ curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
645
+ if check:
646
+ assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
647
+ assert not (curr_sequence == unknown_token).any()
648
+ next_token = self._sample_next_token(
649
+ curr_sequence,
650
+ [local_cfg_conditions, global_cfg_conditions],
651
+ unconditional_state,
652
+ use_sampling,
653
+ temp,
654
+ top_k,
655
+ top_p,
656
+ cfg_coef=cfg_coef,
657
+ two_step_cfg=two_step_cfg,
658
+ precomputed_video_emb=video_emb  # reuse the video embedding computed once before the decoding loop
659
+ )
660
+ valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
661
+ next_token[~valid_mask] = self.special_token_id
662
+ gen_sequence[..., offset:offset+1] = torch.where(
663
+ gen_sequence[..., offset:offset+1] == unknown_token,
664
+ next_token,
665
+ gen_sequence[..., offset:offset+1]
666
+ )
667
+ prev_offset = offset
668
+ if callback is not None:
669
+ callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
670
+
671
+ unconditional_state.clear()
672
+ assert not (gen_sequence == unknown_token).any()
673
+ assert (
674
+ gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
675
+ ).all()
676
+ out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
677
+
678
+ assert (out_codes[..., :max_gen_len] != unknown_token).all()
679
+ assert (out_mask[..., :max_gen_len] == 1).all()
680
+
681
+ out_start_offset = start_offset if remove_prompts else 0
682
+ out_codes = out_codes[..., out_start_offset:max_gen_len]
683
+
684
+ assert (out_codes >= 0).all() and (out_codes <= self.card).all()
685
+ return out_codes
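
The generation loop above only ever writes into positions that are still marked as unknown, so prompt tokens and pattern special tokens survive the whole loop. A minimal, self-contained sketch of that `torch.where` update (the tensor values below are made up for illustration and are not taken from the repository):

```python
import torch

# Sketch of the masked update inside the generation loop: freshly sampled tokens
# only replace entries still equal to unknown_token (-1); existing prompt tokens
# are left untouched. Shapes and values are illustrative only.
unknown_token = -1
gen_step = torch.tensor([[[unknown_token], [7], [unknown_token]]])  # [B=1, K=3, 1]
next_token = torch.tensor([[[11], [99], [13]]])                     # sampled this step
updated = torch.where(gen_step == unknown_token, next_token, gen_step)
print(updated.squeeze(-1).tolist())  # [[11, 7, 13]] -> the prompt token 7 is preserved
```

After the loop, `revert_pattern_sequence` maps the interleaved sequence back to `[B, K, T]` codes, and the final asserts check that every returned position is a valid codebook index.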
audiocraft/models/lm_back.py ADDED
@@ -0,0 +1,698 @@
1
+ # Modified from Audiocraft (https://github.com/facebookresearch/audiocraft)
2
+
3
+ from dataclasses import dataclass
4
+ from functools import partial
5
+ import logging
6
+ import math
7
+ import typing as tp
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from ..utils import utils
13
+ from ..modules.streaming import StreamingModule, State
14
+ from ..modules.transformer import StreamingTransformer, create_norm_fn
15
+
16
+ import time
17
+ from ..modules.conditioners import (
18
+ ConditionFuser,
19
+ ClassifierFreeGuidanceDropout,
20
+ AttributeDropout,
21
+ ConditioningProvider,
22
+ ConditioningAttributes,
23
+ ConditionType,
24
+ )
25
+ from ..modules.codebooks_patterns import CodebooksPatternProvider
26
+ from ..modules.activations import get_activation_fn
27
+ import warnings
28
+ warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.transforms._transforms_video")
29
+ import torch.nn.init as init
30
+ import os
31
+
32
+ import logging
33
+ import random
34
+ import sys
35
+ import einops
36
+ from .transformer_module import Attention, PreNorm, FeedForward
37
+ from transformers import AutoProcessor, CLIPVisionModelWithProjection, VideoMAEModel
38
+
39
+
40
+ logger = logging.getLogger(__name__)
41
+ ConditionTensors = tp.Dict[str, ConditionType]
42
+ CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
43
+
44
+
45
+ def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
46
+ """LM layer initialization.
47
+ Inspired from xlformers: https://github.com/fairinternal/xlformers
48
+
49
+ Args:
50
+ method (str): Method name for init function. Valid options are:
51
+ 'gaussian', 'uniform'.
52
+ input_dim (int): Input dimension of the initialized module.
53
+ init_depth (int, optional): Optional init depth value used to rescale
54
+ the standard deviation if defined.
55
+ """
56
+ # Compute std
57
+ std = 1 / math.sqrt(input_dim)
58
+ # Rescale with depth
59
+ if init_depth is not None:
60
+ std = std / math.sqrt(2 * init_depth)
61
+
62
+ if method == 'gaussian':
63
+ return partial(
64
+ torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
65
+ )
66
+ elif method == 'uniform':
67
+ bound = math.sqrt(3) * std # ensure the standard deviation is `std`
68
+ return partial(torch.nn.init.uniform_, a=-bound, b=bound)
69
+ else:
70
+ raise ValueError("Unsupported layer initialization method")
71
+
72
+
73
+ def init_layer(m: nn.Module,
74
+ method: str,
75
+ init_depth: tp.Optional[int] = None,
76
+ zero_bias_init: bool = False):
77
+ """Wrapper around ``get_init_fn`` for proper initialization of LM modules.
78
+
79
+ Args:
80
+ m (nn.Module): Module to initialize.
81
+ method (str): Method name for the init function.
82
+ init_depth (int, optional): Optional init depth value used to rescale
83
+ the standard deviation if defined.
84
+ zero_bias_init (bool): Whether to initialize the bias to 0 or not.
85
+ """
86
+ if isinstance(m, nn.Linear):
87
+ init_fn = get_init_fn(method, m.in_features, init_depth=init_depth)
88
+ if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
89
+ weight = m.weight.float()
90
+ init_fn(weight)
91
+ m.weight.data[:] = weight.half()
92
+ else:
93
+ init_fn(m.weight)
94
+ if zero_bias_init and m.bias is not None:
95
+ nn.init.constant_(m.bias, 0)
96
+ elif isinstance(m, nn.Embedding):
97
+ init_fn = get_init_fn(method, m.embedding_dim, init_depth=None)
98
+ if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
99
+ weight = m.weight.float()
100
+ init_fn(weight)
101
+ m.weight.data[:] = weight.half()
102
+ else:
103
+ init_fn(m.weight)
104
+
105
+
106
+ class ScaledEmbedding(nn.Embedding):
107
+ """Boost learning rate for embeddings (with `scale`).
108
+ """
109
+ def __init__(self, *args, lr=None, **kwargs):
110
+ super().__init__(*args, **kwargs)
111
+ self.lr = lr
112
+
113
+ def make_optim_group(self):
114
+ group = {"params": list(self.parameters())}
115
+ if self.lr is not None:
116
+ group["lr"] = self.lr
117
+ return group
118
+
119
+
120
+ @dataclass
121
+ class LMOutput:
122
+ # The logits are already re-aligned with the input codes
123
+ # hence no extra shift is required, e.g. when computing CE
124
+ logits: torch.Tensor # [B, K, T, card]
125
+ mask: torch.Tensor # [B, K, T]
126
+
127
+ class Transformer(nn.Module):
128
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
129
+ super().__init__()
130
+ self.layers = nn.ModuleList([])
131
+ self.norm = nn.LayerNorm(dim)
132
+ for _ in range(depth):
133
+ self.layers.append(nn.ModuleList([
134
+ PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
135
+ PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
136
+ ]))
137
+
138
+ def forward(self, x):
139
+ for attn, ff in self.layers:
140
+ x = attn(x) + x
141
+ x = ff(x) + x
142
+ return self.norm(x)
143
+
144
+ class MultiHeadCrossAttention(nn.Module):
145
+ def __init__(self, x1, num_heads):
146
+ super().__init__()
147
+ self.num_heads = num_heads
148
+ self.depth = x1 // num_heads
149
+
150
+ self.query = nn.Linear(x1, x1)
151
+ self.key = nn.Linear(x1, x1)
152
+ self.value = nn.Linear(x1, x1)
153
+
154
+ self.final_linear = nn.Linear(x1, x1)
155
+
156
+ self.norm1 = nn.LayerNorm(x1)
157
+ self.norm2 = nn.LayerNorm(x1)
158
+
159
+ init.constant_(self.final_linear.weight, 0)
160
+ if self.final_linear.bias is not None:
161
+ init.constant_(self.final_linear.bias, 0)
162
+
163
+ def split_heads(self, x, batch_size):
164
+ x = x.view(batch_size, -1, self.num_heads, self.depth)
165
+ return x.permute(0, 2, 1, 3)
166
+
167
+ def forward(self, tensor_A, tensor_B):
168
+ batch_size = tensor_A.size(0)
169
+
170
+ Q = self.split_heads(self.query(tensor_A), batch_size)
171
+ K = self.split_heads(self.key(tensor_B), batch_size)
172
+ V = self.split_heads(self.value(tensor_B), batch_size)
173
+
174
+ attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.depth ** 0.5)
175
+ attention_scores = torch.softmax(attention_scores, dim=-1)
176
+
177
+ attention_output = torch.matmul(attention_scores, V)
178
+ attention_output = attention_output.permute(0, 2, 1, 3).contiguous()
179
+
180
+ output = attention_output.view(batch_size, -1, self.num_heads * self.depth)
181
+
182
+ output = self.norm1(output + tensor_A)
183
+ output = self.norm2(self.final_linear(output) + output)
184
+ return output
185
+
186
+
187
+ def evenly_sample_or_duplicate_frames(video_tensor, target_frames=32):
188
+ num_frames = video_tensor.size(0)
189
+ if target_frames <= num_frames:
190
+ indices = torch.linspace(0, num_frames - 1, steps=target_frames).long()
191
+ return video_tensor[indices]
192
+ else:
193
+ scale_factor = target_frames / num_frames
194
+ repeated_indices = (torch.arange(target_frames) / scale_factor).long()
195
+ return video_tensor[repeated_indices]
196
+
197
+ class LMModel(StreamingModule):
198
+ """Transformer-based language model on multiple streams of codes.
199
+
200
+ Args:
201
+ pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
202
+ condition_provider (MusicConditioningProvider): Conditioning provider from metadata.
203
+ fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
204
+ n_q (int): Number of parallel streams to model.
205
+ card (int): Cardinality, vocabulary size.
206
+ dim (int): Dimension of the transformer encoder.
207
+ num_heads (int): Number of heads for the transformer encoder.
208
+ hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
209
+ norm (str): Normalization method.
210
+ norm_first (bool): Use pre-norm instead of post-norm.
211
+ emb_lr (float, optional): Embedding-specific learning rate.
212
+ bias_proj (bool): Use bias for output projections.
213
+ weight_init (str, optional): Method for weight initialization.
214
+ depthwise_init (str, optional): Method for depthwise weight initialization.
215
+ zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
216
+ cfg_dropout (float): Classifier-free guidance dropout.
217
+ cfg_coef (float): Classifier-free guidance coefficient.
218
+ attribute_dropout (dict): Attribute dropout probabilities.
219
+ two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
220
+ **kwargs: Additional parameters for the transformer encoder.
221
+ """
222
+
223
+ def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
224
+ visual_encoder,
225
+ if_add_gobal,
226
+ fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
227
+ hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
228
+ emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
229
+ weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
230
+ zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
231
+ attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False,
232
+ depth=2,
233
+ temporal_dim=768,
234
+ dim_head=64,
235
+ **kwargs):
236
+ super().__init__()
237
+ self.cfg_coef = cfg_coef
238
+ self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
239
+ self.att_dropout = AttributeDropout(p=attribute_dropout)
240
+ self.condition_provider = condition_provider
241
+ self.visual_encoder=visual_encoder
242
+ self.if_add_gobal=if_add_gobal
243
+ self.temporal_dim=temporal_dim
244
+
245
+ self.fuser = fuser
246
+ self.card = card
247
+ embed_dim = self.card + 1
248
+ self.n_q = n_q
249
+ self.dim = dim
250
+ self.pattern_provider = pattern_provider
251
+ self.two_step_cfg = two_step_cfg
252
+ self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
253
+ if 'activation' in kwargs:
254
+ kwargs['activation'] = get_activation_fn(kwargs['activation'])
255
+ self.transformer = StreamingTransformer(
256
+ d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
257
+ norm=norm, norm_first=norm_first, **kwargs)
258
+
259
+
260
+ self.out_norm: tp.Optional[nn.Module] = None
261
+ if norm_first:
262
+ self.out_norm = create_norm_fn(norm, dim)
263
+ self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
264
+ self._init_weights(weight_init, depthwise_init, zero_bias_init)
265
+ self._fsdp: tp.Optional[nn.Module]
266
+ self.__dict__['_fsdp'] = None
267
+
268
+ if self.visual_encoder=='clip':
269
+ self.visual_encoder_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
270
+ self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
271
+
272
+ else:
273
+ print(f'the encoder now is:{self.visual_encoder}')
274
+ print(f'please input the right video encoder.')
275
+ exit()
276
+
277
+ if self.visual_encoder=='clip':
278
+ temporal_dim=768
279
+ self.local_pos_embedding = nn.Parameter(torch.randn(1, 50, temporal_dim))
280
+ self.visual_encoder_model = self.visual_encoder_model.eval()
281
+ for param in self.visual_encoder_model.parameters():
282
+ param.requires_grad = False
283
+
284
+ self.local_temporal_transformer = Transformer(temporal_dim, depth, num_heads, dim_head, temporal_dim*hidden_scale, 0.) # [768, 4, 16, 64, 768*4]
285
+
286
+ if self.if_add_gobal:
287
+ if self.visual_encoder=='clip':
288
+ self.global_pos_embedding = nn.Parameter(torch.randn(1, 50, temporal_dim))
289
+
290
+ self.global_temporal_transformer = Transformer(temporal_dim, depth, num_heads, dim_head, temporal_dim*hidden_scale, 0.) # [768, 4, 16, 64, 768*4]
291
+
292
+ cross_attention_num_heads = 3 # MultiHeadCrossAttention
293
+ self.multi_head_cross_attention = MultiHeadCrossAttention(temporal_dim, cross_attention_num_heads)
294
+
295
+ self.visual_feature_proj = nn.Linear(temporal_dim, dim)
296
+
297
+
298
+ def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
299
+ """Initialization of the transformer module weights.
300
+
301
+ Args:
302
+ weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options.
303
+ depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
304
+ 'current' where the depth corresponds to the current layer index or 'global' where the total number
305
+ of layer is used as depth. If not set, no depthwise initialization strategy is used.
306
+ zero_bias_init (bool): Whether to initialize bias to zero or not.
307
+ """
308
+ assert depthwise_init is None or depthwise_init in ['current', 'global']
309
+ assert depthwise_init is None or weight_init is not None, \
310
+ "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
311
+ assert not zero_bias_init or weight_init is not None, \
312
+ "If 'zero_bias_init', a 'weight_init' method should be provided"
313
+
314
+ if weight_init is None:
315
+ return
316
+
317
+ for emb_layer in self.emb:
318
+ init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
319
+
320
+ for layer_idx, tr_layer in enumerate(self.transformer.layers):
321
+ depth = None
322
+ if depthwise_init == 'current':
323
+ depth = layer_idx + 1
324
+ elif depthwise_init == 'global':
325
+ depth = len(self.transformer.layers)
326
+ init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
327
+ tr_layer.apply(init_fn)
328
+
329
+ for linear in self.linears:
330
+ init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
331
+
332
+ @property
333
+ def special_token_id(self) -> int:
334
+ return self.card
335
+
336
+ @property
337
+ def num_codebooks(self) -> int:
338
+ return self.n_q
339
+
340
+ def forward(self, sequence: torch.Tensor,
341
+ conditions: tp.List[ConditioningAttributes],
342
+ video_tensor_list: tp.List) -> torch.Tensor:
343
+ """Apply language model on sequence and conditions.
344
+ Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
345
+ S the sequence steps, return the logits with shape [B, card, K, S].
346
+
347
+ Args:
348
+ sequence (torch.Tensor): Sequence of codes to model, of shape [B, K, S].
349
+ conditions (list of ConditioningAttributes): Conditions to use when modeling
350
+ the given codes. Note that when evaluating multiple times with the same conditioning
351
+ you should pre-compute those and pass them as `condition_tensors`.
352
+ # condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
353
+ # tensors, see `conditions`.
354
+ video_tensor_list (list of torch.Tensor): Local and global video tensors of shape [B, C, T, H, W].
355
+ Returns:
356
+ torch.Tensor: Logits.
357
+ """
358
+
359
+ B, K, S = sequence.shape
360
+ assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks"
361
+ input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
362
+ self.device = input_.device
363
+ assert self.device != "cpu"
364
+
365
+ if self.visual_encoder=='clip':
366
+ visual_encoder_model = self.visual_encoder_model
367
+ processor = self.processor
368
+
369
+ assert isinstance(video_tensor_list, list)
370
+
371
+ assert self.if_add_gobal
372
+ assert len(video_tensor_list)==2
373
+
374
+ [local_video_tensor, global_video_tensor] = video_tensor_list
375
+ local_image = local_video_tensor.to(dtype=torch.float32)
376
+ global_image = global_video_tensor.to(dtype=torch.float32)
377
+
378
+ local_batch_size,_,local_time_length,_,_ = local_image.size()
379
+ local_image = einops.rearrange(local_image, 'b c t h w -> (b t) c h w')
380
+
381
+ global_batch_size,_,global_time_length,_,_ = global_image.size()
382
+ global_image = einops.rearrange(global_image, 'b c t h w -> (b t) c h w')
383
+
384
+ local_temporal_transformer = self.local_temporal_transformer
385
+ global_temporal_transformer = self.global_temporal_transformer
386
+
387
+ local_video_inputs = processor(images=local_image.float(), return_tensors="pt")
388
+ local_pixel_values = local_video_inputs['pixel_values'].to(self.device)
389
+
390
+ global_video_inputs = processor(images=global_image.float(), return_tensors="pt")
391
+ global_pixel_values = global_video_inputs['pixel_values'].to(self.device)
392
+
393
+ if self.visual_encoder=='clip':
394
+ with torch.no_grad():
395
+ local_video_hidden = visual_encoder_model(pixel_values=local_pixel_values).last_hidden_state
396
+ local_video_hidden += self.local_pos_embedding
397
+ local_video_hidden = local_temporal_transformer(local_video_hidden)
398
+ local_video_hidden = einops.rearrange(local_video_hidden, '(b t) q h -> b (t q) h',b=local_batch_size,t=local_time_length)
399
+ with torch.no_grad():
400
+ global_video_hidden = visual_encoder_model(pixel_values=global_pixel_values).last_hidden_state
401
+ global_video_hidden += self.global_pos_embedding
402
+ global_video_hidden = global_temporal_transformer(global_video_hidden)
403
+ global_video_hidden = einops.rearrange(global_video_hidden, '(b t) q h -> b (t q) h',b=global_batch_size,t=global_time_length)
404
+
405
+ assert local_batch_size==global_batch_size
406
+ video_hidden = self.multi_head_cross_attention(local_video_hidden, global_video_hidden)
407
+ video_emb = self.visual_feature_proj(video_hidden)
408
+
409
+ out = self.transformer(input_, cross_attention_src=video_emb)
410
+ if self.out_norm:
411
+ out = self.out_norm(out)
412
+ logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1)
413
+
414
+ # remove the prefix from the model outputs
415
+ if len(self.fuser.fuse2cond['prepend']) > 0:
416
+ logits = logits[:, :, -S:]
417
+ return logits # [B, K, S, card]
418
+
419
+
420
+ def compute_predictions(
421
+ self, codes: torch.Tensor,
422
+ conditions: tp.List[ConditioningAttributes],
423
+ condition_tensors_list: tp.List) -> LMOutput:
424
+ """Given an input tensor of codes [B, K, T] and list of conditions, runs the model
425
+ forward using the specified codes interleaving pattern.
426
+
427
+ Args:
428
+ codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
429
+ K the number of codebooks and T the number of timesteps.
430
+ conditions (list of ConditioningAttributes): conditionings to use when modeling
431
+ the given codes. Note that when evaluating multiple times with the same conditioning
432
+ you should pre-compute those and pass them as `condition_tensors`.
433
+ condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning
434
+ tensors, see `conditions`.
435
+ Returns:
436
+ LMOutput: Language model outputs
437
+ logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
438
+ i.e. the first item corresponds to logits to predict the first code, meaning that
439
+ no additional shifting of codes and logits is required.
440
+ mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
441
+ Given the specified interleaving strategies, parts of the logits and codes should
442
+ not be considered as valid predictions because of invalid context.
443
+ """
444
+ B, K, T = codes.shape
445
+ codes = codes.contiguous()
446
+
447
+ assert isinstance(condition_tensors_list,list)
448
+ # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
449
+ pattern = self.pattern_provider.get_pattern(T)
450
+ sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
451
+ codes, self.special_token_id, keep_only_valid_steps=True
452
+ )
453
+
454
+ model = self if self._fsdp is None else self._fsdp
455
+ logits = model(sequence_codes, conditions, condition_tensors_list) # [B, K, S, card]
456
+
457
+ # apply model on pattern sequence
458
+ # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
459
+ # and provide the corresponding mask over invalid positions of tokens
460
+ logits = logits.permute(0, 3, 1, 2) # [B, card, K, S]
461
+ # note: we use nans as special token to make it obvious if we feed unexpected logits
462
+ logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
463
+ logits, float('nan'), keep_only_valid_steps=True
464
+ )
465
+ logits = logits.permute(0, 2, 3, 1) # [B, K, T, card]
466
+ logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T]
467
+ return LMOutput(logits, logits_mask)
468
+
469
+ def _sample_next_token(self,
470
+ sequence: torch.Tensor,
471
+ cfg_conditions_list: tp.List,
472
+ unconditional_state: State,
473
+ use_sampling: bool = False,
474
+ temp: float = 1.0,
475
+ top_k: int = 0,
476
+ top_p: float = 0.0,
477
+ cfg_coef: tp.Optional[float] = None,
478
+ two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor:
479
+ """Sample next token from the model given a sequence and a set of conditions. The model supports
480
+ multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
481
+
482
+ Args:
483
+ sequence (torch.Tensor): Current sequence of shape [B, K, S]
484
+ with K corresponding to the number of codebooks and S the number of sequence steps.
485
+ S = 1 in streaming mode, except for the first step that contains a bigger prompt.
486
+ condition_tensors (dict[str, ConditionType): Set of conditions. If CFG is used,
487
+ should be twice the batch size, being the concatenation of the conditions + null conditions.
488
+ use_sampling (bool): Whether to use a sampling strategy or not.
489
+ temp (float): Sampling temperature.
490
+ top_k (int): K for "top-k" sampling.
491
+ top_p (float): P for "top-p" sampling.
492
+ cfg_coef (float, optional): classifier free guidance coefficient
493
+ Returns:
494
+ next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
495
+ """
496
+ B = sequence.shape[0]
497
+ cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
498
+ model = self if self._fsdp is None else self._fsdp
499
+ two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
500
+
501
+ assert isinstance(cfg_conditions_list,list)
502
+
503
+ assert len(cfg_conditions_list)==2
504
+ local_cfg_conditions=cfg_conditions_list[0]
505
+ global_cfg_conditions=cfg_conditions_list[1]
506
+ if two_step_cfg and local_cfg_conditions != {}:
507
+ assert isinstance(local_cfg_conditions, tuple), type(local_cfg_conditions)
508
+ local_condition_tensors, local_null_condition_tensors = local_cfg_conditions
509
+ global_condition_tensors, global_null_condition_tensors = global_cfg_conditions
510
+ cond_logits = model(sequence, conditions=[], video_tensor_list=[local_condition_tensors, global_condition_tensors])
511
+
512
+ state = self.get_streaming_state()
513
+ self.set_streaming_state(unconditional_state)
514
+ uncond_logits = model(sequence, conditions=[], video_tensor_list=[local_null_condition_tensors, global_null_condition_tensors])
515
+ unconditional_state.update(self.get_streaming_state())
516
+ self.set_streaming_state(state)
517
+ logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_coef
518
+ else:
519
+ local_condition_tensors = local_cfg_conditions
520
+ sequence = torch.cat([sequence, sequence], dim=0)
521
+ local_condition_tensors = local_condition_tensors.to(sequence.device)
522
+
523
+ global_condition_tensors = global_cfg_conditions
524
+ global_condition_tensors = global_condition_tensors.to(sequence.device)
525
+
526
+ all_logits = model(
527
+ sequence,
528
+ conditions=[], video_tensor_list=[local_condition_tensors, global_condition_tensors])
529
+ cond_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card]
530
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
531
+
532
+ logits = logits.permute(0, 1, 3, 2) # [B, K, card, T]
533
+ logits = logits[..., -1] # [B x K x card]
534
+
535
+ # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
536
+ if use_sampling and temp > 0.0:
537
+ probs = torch.softmax(logits / temp, dim=-1)
538
+ if top_p > 0.0:
539
+ next_token = utils.sample_top_p(probs, p=top_p)
540
+ elif top_k > 0:
541
+ next_token = utils.sample_top_k(probs, k=top_k)
542
+ else:
543
+ next_token = utils.multinomial(probs, num_samples=1)
544
+ else:
545
+ next_token = torch.argmax(logits, dim=-1, keepdim=True)
546
+ return next_token
547
+
548
+ @torch.no_grad()
549
+ def generate(self,
550
+ prompt: tp.Optional[torch.Tensor] = None,
551
+ conditions_list: tp.List = [],
552
+ num_samples: tp.Optional[int] = None,
553
+ max_gen_len: int = 256,
554
+ use_sampling: bool = True,
555
+ temp: float = 1.0,
556
+ top_k: int = 250,
557
+ top_p: float = 0.0,
558
+ cfg_coef: tp.Optional[float] = None,
559
+ two_step_cfg: tp.Optional[bool] = None,
560
+ remove_prompts: bool = False,
561
+ check: bool = False,
562
+ callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
563
+ """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
564
+ be performed in a greedy fashion or using sampling with top K and top P strategies.
565
+
566
+ Args:
567
+ prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
568
+ conditions_list (list, optional): Local and global conditions used for generation.
569
+ num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
570
+ max_gen_len (int): Maximum generation length.
571
+ use_sampling (bool): Whether to use a sampling strategy or not.
572
+ temp (float): Sampling temperature.
573
+ top_k (int): K for "top-k" sampling.
574
+ top_p (float): P for "top-p" sampling.
575
+ cfg_coef (float, optional): Classifier-free guidance coefficient.
576
+ two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
577
+ remove_prompts (bool): Whether to remove prompts from generation or not.
578
+ check (bool): Whether to apply further checks on generated sequence.
579
+ callback (Callback, optional): Callback function to report generation progress.
580
+ Returns:
581
+ torch.Tensor: Generated tokens.
582
+ """
583
+ assert not self.training, "generation shouldn't be used in training mode."
584
+ first_param = next(iter(self.parameters()))
585
+ device = first_param.device
586
+ assert isinstance(conditions_list,list)
587
+
588
+ assert len(conditions_list)==2
589
+ local_conditions=conditions_list[0]
590
+ global_conditions=conditions_list[1]
591
+ # Checking all input shapes are consistent.
592
+ possible_num_samples = []
593
+ if num_samples is not None:
594
+ possible_num_samples.append(num_samples)
595
+ elif prompt is not None:
596
+ possible_num_samples.append(prompt.shape[0])
597
+ elif local_conditions is not None:
598
+ possible_num_samples.append(len(local_conditions))
599
+ else:
600
+ possible_num_samples.append(1)
601
+
602
+ assert all(x == possible_num_samples[0] for x in possible_num_samples), "Inconsistent inputs shapes"
603
+ num_samples = possible_num_samples[0]
604
+
605
+ # below we create set of local_conditions: one conditional and one unconditional
606
+ # to do that we merge the regular condition together with the null condition
607
+ # we then do 1 forward pass instead of 2.
608
+ # the reason for that is two-fold:
609
+ # 1. it is about x2 faster than doing 2 forward passes
610
+ # 2. avoid the streaming API treating the 2 passes as part of different time steps
611
+ # We also support doing two different passes, in particular to ensure that
612
+ # the padding structure is exactly the same between train and test.
613
+ # With a batch size of 1, this can be slower though.
614
+ local_cfg_conditions: CFGConditions
615
+ global_cfg_conditions: CFGConditions
616
+ two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
617
+ local_null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(local_conditions)
618
+ local_cfg_conditions = torch.cat((local_conditions,local_null_conditions), dim=0)
619
+
620
+ global_null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(global_conditions)
621
+ global_cfg_conditions = torch.cat((global_conditions,global_null_conditions), dim=0)
622
+
623
+ if prompt is None:
624
+ assert num_samples > 0
625
+ prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
626
+
627
+ B, K, T = prompt.shape
628
+ start_offset = T
629
+ assert start_offset < max_gen_len
630
+
631
+ pattern = self.pattern_provider.get_pattern(max_gen_len)
632
+ # this token is used as default value for codes that are not generated yet
633
+ unknown_token = -1
634
+
635
+ # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
636
+ gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
637
+ # filling the gen_codes with the prompt if needed
638
+ gen_codes[..., :start_offset] = prompt
639
+ # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
640
+ gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
641
+ # retrieve the start_offset in the sequence:
642
+ # it is the first sequence step that contains the `start_offset` timestep
643
+ start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
644
+ assert start_offset_sequence is not None
645
+
646
+ with self.streaming():
647
+ unconditional_state = self.get_streaming_state()
648
+ prev_offset = 0
649
+ gen_sequence_len = gen_sequence.shape[-1] # gen_sequence shape is [B, K, S]
650
+
651
+ for offset in range(start_offset_sequence, gen_sequence_len):
652
+ # get current sequence (note that the streaming API is providing the caching over previous offsets)
653
+ curr_sequence = gen_sequence[..., prev_offset:offset]
654
+ curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
655
+ if check:
656
+ # check coherence between mask and sequence
657
+ assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
658
+ # should never happen as gen_sequence is filled progressively
659
+ assert not (curr_sequence == unknown_token).any()
660
+ # sample next token from the model, next token shape is [B, K, 1]
661
+ next_token = self._sample_next_token(
662
+ curr_sequence, [local_cfg_conditions, global_cfg_conditions], unconditional_state, use_sampling, temp, top_k, top_p,
663
+ cfg_coef=cfg_coef, two_step_cfg=two_step_cfg)
664
+ # ensure the tokens that should be masked are properly set to special_token_id
665
+ # as the model never output special_token_id
666
+ valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
667
+ next_token[~valid_mask] = self.special_token_id
668
+ # ensure we don't overwrite prompt tokens, we only write over unknown tokens
669
+ # (then mask tokens should be left as is as well, which is correct)
670
+ gen_sequence[..., offset:offset+1] = torch.where(
671
+ gen_sequence[..., offset:offset+1] == unknown_token,
672
+ next_token, gen_sequence[..., offset:offset+1]
673
+ )
674
+ prev_offset = offset
675
+ if callback is not None:
676
+ callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
677
+
678
+ unconditional_state.clear()
679
+ # ensure sequence has been entirely filled
680
+ assert not (gen_sequence == unknown_token).any()
681
+ # ensure gen_sequence pattern and mask are matching
682
+ # which means the gen_sequence is valid according to the pattern
683
+ assert (
684
+ gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
685
+ ).all()
686
+ # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
687
+ out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
688
+
689
+ # sanity checks over the returned codes and corresponding masks
690
+ assert (out_codes[..., :max_gen_len] != unknown_token).all()
691
+ assert (out_mask[..., :max_gen_len] == 1).all()
692
+
693
+ out_start_offset = start_offset if remove_prompts else 0
694
+ out_codes = out_codes[..., out_start_offset:max_gen_len]
695
+
696
+ # ensure the returned codes are all valid
697
+ assert (out_codes >= 0).all() and (out_codes <= self.card).all()
698
+ return out_codes
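
As the comments in `generate` above explain, the conditional and unconditional passes are batched together so the model only runs once per step; `_sample_next_token` then splits the doubled batch and blends the two logit sets. A small self-contained sketch of that recombination (tensor sizes are made up for illustration):

```python
import torch

# Classifier-free guidance recombination as used in _sample_next_token above:
# the model sees a doubled batch (conditional + null conditions), the output is
# split back into its two halves and blended with cfg_coef. Sizes are illustrative.
B, K, T, card = 2, 4, 1, 2048
all_logits = torch.randn(2 * B, K, T, card)           # stand-in for the model output
cond_logits, uncond_logits = all_logits.split(B, dim=0)
cfg_coef = 3.0
logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
print(logits.shape)  # torch.Size([2, 4, 1, 2048])
```

With `cfg_coef = 1.0` this reduces to the conditional logits; larger values push the samples further toward the video conditioning.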
audiocraft/models/loaders.py ADDED
@@ -0,0 +1,149 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Utility functions to load from the checkpoints.
9
+ Each checkpoint is a torch.saved dict with the following keys:
10
+ - 'xp.cfg': the hydra config as dumped during training. This should be used
11
+ to rebuild the object using the audiocraft.models.builders functions,
12
+ - 'model_best_state': a readily loadable best state for the model, including
13
+ the conditioner. The model obtained from `xp.cfg` should be compatible
14
+ with this state dict. In the case of a LM, the encodec model would not be
15
+ bundled along but instead provided separately.
16
+
17
+ Those functions also support loading from a remote location with the Torch Hub API.
18
+ They also support overriding some parameters, in particular the device and dtype
19
+ of the returned model.
20
+ """
21
+
22
+ from pathlib import Path
23
+ from huggingface_hub import hf_hub_download
24
+ import typing as tp
25
+ import os
26
+
27
+ from omegaconf import OmegaConf, DictConfig
28
+ import torch
29
+
30
+ import audiocraft
31
+ from . import builders
32
+ from .encodec import CompressionModel
33
+
34
+
35
+ def get_audiocraft_cache_dir() -> tp.Optional[str]:
36
+ return os.environ.get('AUDIOCRAFT_CACHE_DIR', None)
37
+
38
+
39
+ def _get_state_dict(
40
+ file_or_url_or_id: tp.Union[Path, str],
41
+ filename: tp.Optional[str] = None,
42
+ device='cpu',
43
+ cache_dir: tp.Optional[str] = None,
44
+ ):
45
+ if cache_dir is None:
46
+ cache_dir = get_audiocraft_cache_dir()
47
+ # Return the state dict either from a file or url
48
+ file_or_url_or_id = str(file_or_url_or_id)
49
+ assert isinstance(file_or_url_or_id, str)
50
+
51
+ if os.path.isfile(file_or_url_or_id):
52
+ return torch.load(file_or_url_or_id, map_location=device)
53
+
54
+ if os.path.isdir(file_or_url_or_id):
55
+ file = f"{file_or_url_or_id}/{filename}"
56
+ return torch.load(file, map_location=device)
57
+
58
+ elif file_or_url_or_id.startswith('https://'):
59
+ return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
60
+
61
+ else:
62
+ assert filename is not None, "filename needs to be defined if using HF checkpoints"
63
+
64
+ file = hf_hub_download(
65
+ repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir,
66
+ library_name="audiocraft", library_version=audiocraft.__version__)
67
+ return torch.load(file, map_location=device)
68
+
69
+
70
+ def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
71
+ return _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
72
+
73
+
74
+ def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
75
+ pkg = load_compression_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
76
+ if 'pretrained' in pkg:
77
+ return CompressionModel.get_pretrained(pkg['pretrained'], device=device)
78
+ cfg = OmegaConf.create(pkg['xp.cfg'])
79
+ cfg.device = str(device)
80
+ model = builders.get_compression_model(cfg)
81
+ model.load_state_dict(pkg['best_state'])
82
+ model.eval()
83
+ return model
84
+
85
+
86
+ def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
87
+ return _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
88
+
89
+
90
+ def _delete_param(cfg: DictConfig, full_name: str):
91
+ parts = full_name.split('.')
92
+ for part in parts[:-1]:
93
+ if part in cfg:
94
+ cfg = cfg[part]
95
+ else:
96
+ return
97
+ OmegaConf.set_struct(cfg, False)
98
+ if parts[-1] in cfg:
99
+ del cfg[parts[-1]]
100
+ OmegaConf.set_struct(cfg, True)
101
+
102
+
103
+ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
104
+ pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
105
+ cfg = OmegaConf.create(pkg['xp.cfg'])
106
+ cfg.device = str(device)
107
+ if cfg.device == 'cpu':
108
+ cfg.dtype = 'float32'
109
+ else:
110
+ cfg.dtype = 'float16'
111
+ _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
112
+ _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
113
+ _delete_param(cfg, 'conditioners.args.drop_desc_p')
114
+ model = builders.get_lm_model(cfg)
115
+ model.load_state_dict(pkg['best_state'])
116
+ model.eval()
117
+ model.cfg = cfg
118
+ return model
119
+
120
+
121
+ def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str],
122
+ filename: tp.Optional[str] = None,
123
+ cache_dir: tp.Optional[str] = None):
124
+ return _get_state_dict(file_or_url_or_id, filename=filename, cache_dir=cache_dir)
125
+
126
+
127
+ def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str],
128
+ device='cpu',
129
+ filename: tp.Optional[str] = None,
130
+ cache_dir: tp.Optional[str] = None):
131
+ pkg = load_mbd_ckpt(file_or_url_or_id, filename=filename, cache_dir=cache_dir)
132
+ models = []
133
+ processors = []
134
+ cfgs = []
135
+ sample_rate = pkg['sample_rate']
136
+ for i in range(pkg['n_bands']):
137
+ cfg = pkg[i]['cfg']
138
+ model = builders.get_diffusion_model(cfg)
139
+ model_dict = pkg[i]['model_state']
140
+ model.load_state_dict(model_dict)
141
+ model.to(device)
142
+ processor = builders.get_processor(cfg=cfg.processor, sample_rate=sample_rate)
143
+ processor_dict = pkg[i]['processor_state']
144
+ processor.load_state_dict(processor_dict)
145
+ processor.to(device)
146
+ models.append(model)
147
+ processors.append(processor)
148
+ cfgs.append(cfg)
149
+ return models, processors, cfgs
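
A hypothetical usage sketch of the loaders above: `_get_state_dict` accepts a local directory, so pointing the helpers at a checkout that contains `compression_state_dict.bin` and `state_dict.bin` should work. The directory path below is an assumption for illustration, not something the repository defines.

```python
# Hypothetical usage sketch (the path is an assumption, not part of the repository).
from audiocraft.models.loaders import load_compression_model, load_lm_model

ckpt_dir = "./VidMuse"  # local folder containing compression_state_dict.bin and state_dict.bin
compression_model = load_compression_model(ckpt_dir, device="cpu")
lm = load_lm_model(ckpt_dir, device="cpu")
print(type(compression_model).__name__, type(lm).__name__)
```

Note that on CPU `load_lm_model` forces `cfg.dtype = 'float32'` before rebuilding the model, as shown in the code above.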
audiocraft/models/multibanddiffusion.py ADDED
@@ -0,0 +1,196 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Multi Band Diffusion models as described in
9
+ "From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion"
10
+ (paper link).
11
+ """
12
+
13
+ import typing as tp
14
+
15
+ import torch
16
+ import julius
17
+
18
+ from .unet import DiffusionUnet
19
+ from ..modules.diffusion_schedule import NoiseSchedule
20
+ from .encodec import CompressionModel
21
+ from ..solvers.compression import CompressionSolver
22
+ from .loaders import load_compression_model, load_diffusion_models
23
+
24
+
25
+ class DiffusionProcess:
26
+ """Sampling for a diffusion Model.
27
+
28
+ Args:
29
+ model (DiffusionUnet): Diffusion U-Net model.
30
+ noise_schedule (NoiseSchedule): Noise schedule for diffusion process.
31
+ """
32
+ def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule) -> None:
33
+ """
34
+ """
35
+ self.model = model
36
+ self.schedule = noise_schedule
37
+
38
+ def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor,
39
+ step_list: tp.Optional[tp.List[int]] = None):
40
+ """Perform one diffusion process to generate one of the bands.
41
+
42
+ Args:
43
+ condition (tensor): The embeddings from the compression model.
44
+ initial_noise (tensor): The initial noise to start the process.
45
+ """
46
+ return self.schedule.generate_subsampled(model=self.model, initial=initial_noise, step_list=step_list,
47
+ condition=condition)
48
+
49
+
50
+ class MultiBandDiffusion:
51
+ """Sample from multiple diffusion models.
52
+
53
+ Args:
54
+ DPs (list of DiffusionProcess): Diffusion processes.
55
+ codec_model (CompressionModel): Underlying compression model used to obtain discrete tokens.
56
+ """
57
+ def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: CompressionModel) -> None:
58
+ self.DPs = DPs
59
+ self.codec_model = codec_model
60
+ self.device = next(self.codec_model.parameters()).device
61
+
62
+ @property
63
+ def sample_rate(self) -> int:
64
+ return self.codec_model.sample_rate
65
+
66
+ @staticmethod
67
+ def get_mbd_musicgen(device=None):
68
+ """Load our diffusion models trained for MusicGen."""
69
+ if device is None:
70
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
71
+ path = 'facebook/multiband-diffusion'
72
+ filename = 'mbd_musicgen_32khz.th'
73
+ name = 'facebook/musicgen-small'
74
+ codec_model = load_compression_model(name, device=device)
75
+ models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
76
+ DPs = []
77
+ for i in range(len(models)):
78
+ schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
79
+ DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
80
+ return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)
81
+
82
+ @staticmethod
83
+ def get_mbd_24khz(bw: float = 3.0, pretrained: bool = True,
84
+ device: tp.Optional[tp.Union[torch.device, str]] = None,
85
+ n_q: tp.Optional[int] = None):
86
+ """Get the pretrained Models for MultibandDiffusion.
87
+
88
+ Args:
89
+ bw (float): Bandwidth of the compression model.
90
+ pretrained (bool): Whether to use / download if necessary the models.
91
+ device (torch.device or str, optional): Device on which the models are loaded.
92
+ n_q (int, optional): Number of quantizers to use within the compression model.
93
+ """
94
+ if device is None:
95
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
96
+ assert bw in [1.5, 3.0, 6.0], f"bandwidth {bw} not available"
97
+ if n_q is not None:
98
+ assert n_q in [2, 4, 8]
99
+ assert {1.5: 2, 3.0: 4, 6.0: 8}[bw] == n_q, \
100
+ f"bandwidth and number of codebooks missmatch to use n_q = {n_q} bw should be {n_q * (1.5 / 2)}"
101
+ n_q = {1.5: 2, 3.0: 4, 6.0: 8}[bw]
102
+ codec_model = CompressionSolver.model_from_checkpoint(
103
+ '//pretrained/facebook/encodec_24khz', device=device)
104
+ codec_model.set_num_codebooks(n_q)
105
+ codec_model = codec_model.to(device)
106
+ path = 'facebook/multiband-diffusion'
107
+ filename = f'mbd_comp_{n_q}.pt'
108
+ models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
109
+ DPs = []
110
+ for i in range(len(models)):
111
+ schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
112
+ DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
113
+ return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)
114
+
115
+ return MultiBandDiffusion(DPs, codec_model)
116
+
117
+ @torch.no_grad()
118
+ def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
119
+ """Get the conditioning (i.e. latent reprentatios of the compression model) from a waveform.
120
+ Args:
121
+ wav (torch.Tensor): The audio that we want to extract the conditioning from
122
+ sample_rate (int): sample rate of the audio"""
123
+ if sample_rate != self.sample_rate:
124
+ wav = julius.resample_frac(wav, sample_rate, self.sample_rate)
125
+ codes, scale = self.codec_model.encode(wav)
126
+ assert scale is None, "Scaled compression models not supported."
127
+ emb = self.get_emb(codes)
128
+ return emb
129
+
130
+ @torch.no_grad()
131
+ def get_emb(self, codes: torch.Tensor):
132
+ """Get latent representation from the discrete codes
133
+ Args:
134
+ codes (torch.Tensor): discrete tokens"""
135
+ emb = self.codec_model.decode_latent(codes)
136
+ return emb
137
+
138
+ def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None,
139
+ step_list: tp.Optional[tp.List[int]] = None):
140
+ """Generate Wavform audio from the latent embeddings of the compression model
141
+ Args:
142
+ emb (torch.Tensor): Conditioning embeddings
143
+ size (torch.Size, optional): Size of the output;
144
+ if None this is computed from the typical upsampling of the model
145
+ step_list (list[int], optional): List of Markov chain steps, defaults to 50 linearly spaced steps.
146
+ """
147
+ if size is None:
148
+ upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate)
149
+ size = torch.Size([emb.size(0), self.codec_model.channels, emb.size(-1) * upsampling])
150
+ assert size[0] == emb.size(0)
151
+ out = torch.zeros(size).to(self.device)
152
+ for DP in self.DPs:
153
+ out += DP.generate(condition=emb, step_list=step_list, initial_noise=torch.randn_like(out))
154
+ return out
155
+
156
+ def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1):
157
+ """match the eq to the encodec output by matching the standard deviation of some frequency bands
158
+ Args:
159
+ wav (torch.Tensor): audio to equalize
160
+ ref (torch.Tensor): reference audio whose spectrogram we match.
161
+ n_bands (int): number of bands of the eq
162
+ strictness (float): how strict the matching is. 0 is no matching, 1 is exact matching.
163
+ """
164
+ split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device)
165
+ bands = split(wav)
166
+ bands_ref = split(ref)
167
+ out = torch.zeros_like(ref)
168
+ for i in range(n_bands):
169
+ out += bands[i] * (bands_ref[i].std() / bands[i].std()) ** strictness
170
+ return out
171
+
172
+ def regenerate(self, wav: torch.Tensor, sample_rate: int):
173
+ """Regenerate a wavform through compression and diffusion regeneration.
174
+ Args:
175
+ wav (torch.Tensor): Original 'ground truth' audio
176
+ sample_rate (int): sample rate of the input (and output) wav
177
+ """
178
+ if sample_rate != self.codec_model.sample_rate:
179
+ wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate)
180
+ emb = self.get_condition(wav, sample_rate=self.codec_model.sample_rate)
181
+ size = wav.size()
182
+ out = self.generate(emb, size=size)
183
+ if sample_rate != self.codec_model.sample_rate:
184
+ out = julius.resample_frac(out, self.codec_model.sample_rate, sample_rate)
185
+ return out
186
+
187
+ def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32):
188
+ """Generate Waveform audio with diffusion from the discrete codes.
189
+ Args:
190
+ tokens (torch.Tensor): discrete codes
191
+ n_bands (int): bands for the eq matching.
192
+ """
193
+ wav_encodec = self.codec_model.decode(tokens)
194
+ condition = self.get_emb(tokens)
195
+ wav_diffusion = self.generate(emb=condition, size=wav_encodec.size())
196
+ return self.re_eq(wav=wav_diffusion, ref=wav_encodec, n_bands=n_bands)
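
A hypothetical end-to-end sketch of `MultiBandDiffusion` as an alternative decoder for generated tokens. It downloads the pretrained diffusion models from `facebook/multiband-diffusion`, and the token tensor below is random, so the output is noise rather than music:

```python
import torch
from audiocraft.models.multibanddiffusion import MultiBandDiffusion

# Hypothetical sketch: decode discrete tokens with Multi-Band Diffusion instead of
# the EnCodec decoder. Requires network access for the pretrained checkpoints.
mbd = MultiBandDiffusion.get_mbd_musicgen()
tokens = torch.randint(0, 2048, (1, 4, 150), device=mbd.device)  # stand-in codes [B, n_q, T]
wav = mbd.tokens_to_wav(tokens)                                   # waveform at mbd.sample_rate
print(wav.shape, mbd.sample_rate)
```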
audiocraft/models/transformer_module.py ADDED
@@ -0,0 +1,177 @@
1
+ import torch
2
+ from torch import nn, einsum
3
+ import torch.nn.functional as F
4
+
5
+ from einops import rearrange, repeat
6
+ from einops.layers.torch import Rearrange
7
+
8
+ class Residual(nn.Module):
9
+ def __init__(self, fn):
10
+ super().__init__()
11
+ self.fn = fn
12
+ def forward(self, x, **kwargs):
13
+ return self.fn(x, **kwargs) + x
14
+
15
+ class PreNorm(nn.Module):
16
+ def __init__(self, dim, fn):
17
+ super().__init__()
18
+ self.norm = nn.LayerNorm(dim)
19
+ self.fn = fn
20
+ def forward(self, x, **kwargs):
21
+ return self.fn(self.norm(x), **kwargs)
22
+
23
+ class FeedForward(nn.Module):
24
+ def __init__(self, dim, hidden_dim, dropout = 0.):
25
+ super().__init__()
26
+ self.net = nn.Sequential(
27
+ nn.Linear(dim, hidden_dim),
28
+ nn.GELU(),
29
+ nn.Dropout(dropout),
30
+ nn.Linear(hidden_dim, dim),
31
+ nn.Dropout(dropout)
32
+ )
33
+ def forward(self, x):
34
+ return self.net(x)
35
+
36
+ class Attention(nn.Module):
37
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
38
+ super().__init__()
39
+ inner_dim = dim_head * heads
40
+ project_out = not (heads == 1 and dim_head == dim)
41
+
42
+ self.heads = heads
43
+ self.scale = dim_head ** -0.5
44
+
45
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
46
+
47
+ self.to_out = nn.Sequential(
48
+ nn.Linear(inner_dim, dim),
49
+ nn.Dropout(dropout)
50
+ ) if project_out else nn.Identity()
51
+
52
+ def forward(self, x):
53
+ b, n, _, h = *x.shape, self.heads
54
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
55
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
56
+
57
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
58
+
59
+ attn = dots.softmax(dim=-1)
60
+
61
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
62
+ out = rearrange(out, 'b h n d -> b n (h d)')
63
+ out = self.to_out(out)
64
+ return out
65
+
66
+
67
+ class ReAttention(nn.Module):
68
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
69
+ super().__init__()
70
+ inner_dim = dim_head * heads
71
+ self.heads = heads
72
+ self.scale = dim_head ** -0.5
73
+
74
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
75
+
76
+ self.reattn_weights = nn.Parameter(torch.randn(heads, heads))
77
+
78
+ self.reattn_norm = nn.Sequential(
79
+ Rearrange('b h i j -> b i j h'),
80
+ nn.LayerNorm(heads),
81
+ Rearrange('b i j h -> b h i j')
82
+ )
83
+
84
+ self.to_out = nn.Sequential(
85
+ nn.Linear(inner_dim, dim),
86
+ nn.Dropout(dropout)
87
+ )
88
+
89
+ def forward(self, x):
90
+ b, n, _, h = *x.shape, self.heads
91
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
92
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
93
+
94
+ # attention
95
+
96
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
97
+ attn = dots.softmax(dim=-1)
98
+
99
+ # re-attention
100
+
101
+ attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights)
102
+ attn = self.reattn_norm(attn)
103
+
104
+ # aggregate and out
105
+
106
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
107
+ out = rearrange(out, 'b h n d -> b n (h d)')
108
+ out = self.to_out(out)
109
+ return out
110
+
111
+ class LeFF(nn.Module):
112
+
113
+ def __init__(self, dim = 192, scale = 4, depth_kernel = 3):
114
+ super().__init__()
115
+
116
+ scale_dim = dim*scale
117
+ self.up_proj = nn.Sequential(nn.Linear(dim, scale_dim),
118
+ Rearrange('b n c -> b c n'),
119
+ nn.BatchNorm1d(scale_dim),
120
+ nn.GELU(),
121
+ Rearrange('b c (h w) -> b c h w', h=14, w=14)
122
+ )
123
+
124
+ self.depth_conv = nn.Sequential(nn.Conv2d(scale_dim, scale_dim, kernel_size=depth_kernel, padding=1, groups=scale_dim, bias=False),
125
+ nn.BatchNorm2d(scale_dim),
126
+ nn.GELU(),
127
+ Rearrange('b c h w -> b (h w) c', h=14, w=14)
128
+ )
129
+
130
+ self.down_proj = nn.Sequential(nn.Linear(scale_dim, dim),
131
+ Rearrange('b n c -> b c n'),
132
+ nn.BatchNorm1d(dim),
133
+ nn.GELU(),
134
+ Rearrange('b c n -> b n c')
135
+ )
136
+
137
+ def forward(self, x):
138
+ x = self.up_proj(x)
139
+ x = self.depth_conv(x)
140
+ x = self.down_proj(x)
141
+ return x
142
+
143
+
144
+ class LCAttention(nn.Module):
145
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
146
+ super().__init__()
147
+ inner_dim = dim_head * heads
148
+ project_out = not (heads == 1 and dim_head == dim)
149
+
150
+ self.heads = heads
151
+ self.scale = dim_head ** -0.5
152
+
153
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
154
+
155
+ self.to_out = nn.Sequential(
156
+ nn.Linear(inner_dim, dim),
157
+ nn.Dropout(dropout)
158
+ ) if project_out else nn.Identity()
159
+
160
+ def forward(self, x):
161
+ b, n, _, h = *x.shape, self.heads
162
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
163
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
164
+ q = q[:, :, -1, :].unsqueeze(2) # only the last (Lth) element is used as the query
165
+
166
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
167
+
168
+ attn = dots.softmax(dim=-1)
169
+
170
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
171
+ out = rearrange(out, 'b h n d -> b n (h d)')
172
+ out = self.to_out(out)
173
+ return out
174
+
175
+
176
+
177
+
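
These blocks are the building pieces of the small temporal `Transformer` defined in `lm.py`/`lm_back.py`. A short illustrative sketch of how they compose (the batch and sequence sizes are made up; `dim=768` matches the CLIP hidden size assumed elsewhere in the repository):

```python
import torch
from audiocraft.models.transformer_module import Attention, FeedForward, PreNorm

# One pre-norm Transformer block assembled from the modules above.
# Sizes are illustrative only.
x = torch.randn(2, 50, 768)                        # (batch, tokens, dim)
attn = PreNorm(768, Attention(768, heads=16, dim_head=64))
ff = PreNorm(768, FeedForward(768, 768 * 4))
x = attn(x) + x                                    # attention with residual
x = ff(x) + x                                      # feed-forward with residual
print(x.shape)                                     # torch.Size([2, 50, 768])
```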
audiocraft/models/unet.py ADDED
@@ -0,0 +1,214 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ PyTorch U-Net module used for diffusion.
+ """
+
+ from dataclasses import dataclass
+ import typing as tp
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from audiocraft.modules.transformer import StreamingTransformer, create_sin_embedding
+
+
+ @dataclass
+ class Output:
+     sample: torch.Tensor
+
+
+ def get_model(cfg, channels: int, side: int, num_steps: int):
+     if cfg.model == 'unet':
+         return DiffusionUnet(
+             chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
+     else:
+         raise RuntimeError('Not Implemented')
+
+
+ class ResBlock(nn.Module):
+     def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4,
+                  dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
+                  dropout: float = 0.):
+         super().__init__()
+         stride = 1
+         padding = dilation * (kernel - stride) // 2
+         Conv = nn.Conv1d
+         Drop = nn.Dropout1d
+         self.norm1 = nn.GroupNorm(norm_groups, channels)
+         self.conv1 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
+         self.activation1 = activation()
+         self.dropout1 = Drop(dropout)
+
+         self.norm2 = nn.GroupNorm(norm_groups, channels)
+         self.conv2 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
+         self.activation2 = activation()
+         self.dropout2 = Drop(dropout)
+
+     def forward(self, x):
+         h = self.dropout1(self.conv1(self.activation1(self.norm1(x))))
+         h = self.dropout2(self.conv2(self.activation2(self.norm2(h))))
+         return x + h
+
+
+ class DecoderLayer(nn.Module):
+     def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
+                  norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
+                  dropout: float = 0.):
+         super().__init__()
+         padding = (kernel - stride) // 2
+         self.res_blocks = nn.Sequential(
+             *[ResBlock(chin, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
+               for idx in range(res_blocks)])
+         self.norm = nn.GroupNorm(norm_groups, chin)
+         ConvTr = nn.ConvTranspose1d
+         self.convtr = ConvTr(chin, chout, kernel, stride, padding, bias=False)
+         self.activation = activation()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.res_blocks(x)
+         x = self.norm(x)
+         x = self.activation(x)
+         x = self.convtr(x)
+         return x
+
+
+ class EncoderLayer(nn.Module):
+     def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
+                  norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
+                  dropout: float = 0.):
+         super().__init__()
+         padding = (kernel - stride) // 2
+         Conv = nn.Conv1d
+         self.conv = Conv(chin, chout, kernel, stride, padding, bias=False)
+         self.norm = nn.GroupNorm(norm_groups, chout)
+         self.activation = activation()
+         self.res_blocks = nn.Sequential(
+             *[ResBlock(chout, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
+               for idx in range(res_blocks)])
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         B, C, T = x.shape
+         stride, = self.conv.stride
+         pad = (stride - (T % stride)) % stride
+         x = F.pad(x, (0, pad))
+
+         x = self.conv(x)
+         x = self.norm(x)
+         x = self.activation(x)
+         x = self.res_blocks(x)
+         return x
+
+
+ class BLSTM(nn.Module):
+     """BiLSTM with the same number of hidden units as the input dimension."""
+     def __init__(self, dim, layers=2):
+         super().__init__()
+         self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
+         self.linear = nn.Linear(2 * dim, dim)
+
+     def forward(self, x):
+         x = x.permute(2, 0, 1)
+         x = self.lstm(x)[0]
+         x = self.linear(x)
+         x = x.permute(1, 2, 0)
+         return x
+
+
+ class DiffusionUnet(nn.Module):
+     def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2.,
+                  max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False,
+                  cross_attention: bool = False, bilstm: bool = False, transformer: bool = False,
+                  codec_dim: tp.Optional[int] = None, **kwargs):
+         super().__init__()
+         self.encoders = nn.ModuleList()
+         self.decoders = nn.ModuleList()
+         self.embeddings: tp.Optional[nn.ModuleList] = None
+         self.embedding = nn.Embedding(num_steps, hidden)
+         if emb_all_layers:
+             self.embeddings = nn.ModuleList()
+         self.condition_embedding: tp.Optional[nn.Module] = None
+         for d in range(depth):
+             encoder = EncoderLayer(chin, hidden, **kwargs)
+             decoder = DecoderLayer(hidden, chin, **kwargs)
+             self.encoders.append(encoder)
+             self.decoders.insert(0, decoder)
+             if emb_all_layers and d > 0:
+                 assert self.embeddings is not None
+                 self.embeddings.append(nn.Embedding(num_steps, hidden))
+             chin = hidden
+             hidden = min(int(chin * growth), max_channels)
+         self.bilstm: tp.Optional[nn.Module]
+         if bilstm:
+             self.bilstm = BLSTM(chin)
+         else:
+             self.bilstm = None
+         self.use_transformer = transformer
+         self.cross_attention = False
+         if transformer:
+             self.cross_attention = cross_attention
+             self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False,
+                                                     cross_attention=cross_attention)
+
+         self.use_codec = False
+         if codec_dim is not None:
+             self.conv_codec = nn.Conv1d(codec_dim, chin, 1)
+             self.use_codec = True
+
+     def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: tp.Optional[torch.Tensor] = None):
+         skips = []
+         bs = x.size(0)
+         z = x
+         view_args = [1]
+         if type(step) is torch.Tensor:
+             step_tensor = step
+         else:
+             step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs)
+
+         for idx, encoder in enumerate(self.encoders):
+             z = encoder(z)
+             if idx == 0:
+                 z = z + self.embedding(step_tensor).view(bs, -1, *view_args).expand_as(z)
+             elif self.embeddings is not None:
+                 z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, *view_args).expand_as(z)
+
+             skips.append(z)
+
+         if self.use_codec:  # insert condition in the bottleneck
+             assert condition is not None, "Model defined for conditional generation"
+             condition_emb = self.conv_codec(condition)  # reshape to the bottleneck dim
+             assert condition_emb.size(-1) <= 2 * z.size(-1), \
+                 f"You are downsampling the conditioning with a factor >= 2: {condition_emb.size(-1)=} and {z.size(-1)=}"
+             if not self.cross_attention:
+                 condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1))
+                 assert z.size() == condition_emb.size()
+                 z += condition_emb
+                 cross_attention_src = None
+             else:
+                 cross_attention_src = condition_emb.permute(0, 2, 1)  # B, T, C
+                 B, T, C = cross_attention_src.shape
+                 positions = torch.arange(T, device=x.device).view(1, -1, 1)
+                 pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype)
+                 cross_attention_src = cross_attention_src + pos_emb
+         if self.use_transformer:
+             z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1)
+         else:
+             if self.bilstm is None:
+                 z = torch.zeros_like(z)
+             else:
+                 z = self.bilstm(z)
+
+         for decoder in self.decoders:
+             s = skips.pop(-1)
+             z = z[:, :, :s.shape[2]]
+             z = z + s
+             z = decoder(z)
+
+         z = z[:, :, :x.shape[2]]
+         return Output(z)
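To make the tensor contract of `DiffusionUnet` concrete, here is a minimal smoke test with toy hyper-parameters (in the real pipeline they come from `cfg.diffusion_unet` via `get_model`). Note that when neither `transformer` nor `bilstm` is enabled, the bottleneck activations are zeroed and only the skip connections carry information.

```python
import torch
from audiocraft.models.unet import DiffusionUnet, Output

# Toy configuration; the actual values are supplied by cfg.diffusion_unet.
model = DiffusionUnet(chin=3, hidden=24, depth=3, num_steps=1000, bilstm=True).eval()

x = torch.randn(2, 3, 1024)        # [B, C, T] noisy input
step = torch.tensor([10, 250])     # one diffusion step index per batch item
with torch.no_grad():
    out = model(x, step)
assert isinstance(out, Output)
print(out.sample.shape)            # torch.Size([2, 3, 1024])
```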
audiocraft/models/vidmuse.py ADDED
@@ -0,0 +1,425 @@
1
+ # Modified from Audiocraft (https://github.com/facebookresearch/audiocraft)
2
+
3
+ import typing as tp
4
+ import warnings
5
+
6
+ import omegaconf
7
+ import torch
8
+
9
+ from .encodec import CompressionModel
10
+ from .lm import LMModel
11
+ from .builders import get_debug_compression_model, get_debug_lm_model, get_wrapped_compression_model
12
+ from .loaders import load_compression_model, load_lm_model
13
+ from ..data.audio_utils import convert_audio
14
+ from ..modules.conditioners import ConditioningAttributes, WavCondition
15
+ from ..utils.autocast import TorchAutocast
16
+
17
+ MelodyList = tp.List[tp.Optional[torch.Tensor]]
18
+ MelodyType = tp.Union[torch.Tensor, MelodyList]
19
+
20
+ # backward compatible names mapping
21
+ _HF_MODEL_CHECKPOINTS_MAP = {
22
+ "small": "facebook/musicgen-small",
23
+ "medium": "facebook/musicgen-medium",
24
+ "large": "facebook/musicgen-large",
25
+ "melody": "facebook/musicgen-melody",
26
+ }
27
+
28
+
29
+ class VidMuse:
30
+ """VidMuse main model with convenient generation API.
31
+
32
+ Args:
33
+ name (str): name of the model.
34
+ compression_model (CompressionModel): Compression model
35
+ used to map audio to invertible discrete representations.
36
+ lm (LMModel): Language model over discrete representations.
37
+ max_duration (float, optional): maximum duration the model can produce,
38
+ otherwise, inferred from the training params.
39
+ """
40
+ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
41
+ max_duration: tp.Optional[float] = None):
42
+ self.name = name
43
+ self.compression_model = compression_model
44
+ self.lm = lm
45
+ self.cfg: tp.Optional[omegaconf.DictConfig] = None
46
+ # Just to be safe, let's put everything in eval mode.
47
+ self.compression_model.eval()
48
+ self.lm.eval()
49
+
50
+ if hasattr(lm, 'cfg'):
51
+ cfg = lm.cfg
52
+ assert isinstance(cfg, omegaconf.DictConfig)
53
+ self.cfg = cfg
54
+
55
+ if self.cfg is not None:
56
+ self.compression_model = get_wrapped_compression_model(self.compression_model, self.cfg)
57
+
58
+ if max_duration is None:
59
+ if self.cfg is not None:
60
+ max_duration = lm.cfg.dataset.segment_duration # type: ignore
61
+ else:
62
+ raise ValueError("You must provide max_duration when building VidMuse directly.")
63
+ assert max_duration is not None
64
+ self.max_duration: float = max_duration
65
+ self.device = next(iter(lm.parameters())).device
66
+
67
+ self.generation_params: dict = {}
68
+ self.set_generation_params(duration=15) # 15 seconds by default
69
+ self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
70
+ if self.device.type == 'cpu':
71
+ self.autocast = TorchAutocast(enabled=False)
72
+ else:
73
+ self.autocast = TorchAutocast(
74
+ enabled=True, device_type=self.device.type, dtype=torch.float16)
75
+
76
+ @property
77
+ def frame_rate(self) -> float:
78
+ """Roughly the number of AR steps per seconds."""
79
+ return self.compression_model.frame_rate
80
+
81
+ @property
82
+ def sample_rate(self) -> int:
83
+ """Sample rate of the generated audio."""
84
+ return self.compression_model.sample_rate
85
+
86
+ @property
87
+ def audio_channels(self) -> int:
88
+ """Audio channels of the generated audio."""
89
+ return self.compression_model.channels
90
+
91
+ @staticmethod
92
+ def get_pretrained(name: str = 'facebook/musicgen-melody', device=None):
93
+ """Return a pretrained model. Four checkpoints are provided:
94
+ - facebook/musicgen-small (300M), text to music,
95
+ # see: https://huggingface.co/facebook/musicgen-small
96
+ - facebook/musicgen-medium (1.5B), text to music,
97
+ # see: https://huggingface.co/facebook/musicgen-medium
98
+ - facebook/musicgen-melody (1.5B) text to music and text+melody to music,
99
+ # see: https://huggingface.co/facebook/musicgen-melody
100
+ - facebook/musicgen-large (3.3B), text to music,
101
+ # see: https://huggingface.co/facebook/musicgen-large
102
+ """
103
+ if device is None:
104
+ if torch.cuda.device_count():
105
+ device = 'cuda'
106
+ else:
107
+ device = 'cpu'
108
+
109
+ if name == 'debug':
110
+ # used only for unit tests
111
+ compression_model = get_debug_compression_model(device)
112
+ lm = get_debug_lm_model(device)
113
+ return VidMuse(name, compression_model, lm, max_duration=30)
114
+
115
+ if name in _HF_MODEL_CHECKPOINTS_MAP:
116
+ # warnings.warn(
117
+ # "MusicGen pretrained model relying on deprecated checkpoint mapping. " +
118
+ # f"Please use full pre-trained id instead: facebook/musicgen-{name}")
119
+ name = _HF_MODEL_CHECKPOINTS_MAP[name]
120
+
121
+ lm = load_lm_model(name, device=device)
122
+ compression_model = load_compression_model(name, device=device)
123
+ if 'self_wav' in lm.condition_provider.conditioners:
124
+ lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
125
+ lm.condition_provider.conditioners['self_wav']._use_masking = False
126
+ return VidMuse(name, compression_model, lm, max_duration=30)
127
+
128
+ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
129
+ top_p: float = 0.0, temperature: float = 1.0,
130
+ duration: float = 30.0, cfg_coef: float = 3.0,
131
+ two_step_cfg: bool = False, extend_stride: float = 29.5):
132
+ """Set the generation parameters for VidMuse.
133
+
134
+ Args:
135
+ use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
136
+ top_k (int, optional): top_k used for sampling. Defaults to 250.
137
+ top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
138
+ temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
139
+ duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
140
+ cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
141
+ two_step_cfg (bool, optional): If True, performs two forward passes for classifier-free guidance,
142
+ instead of batching together the two. This has some impact on how things
143
+ are padded but seems to have little impact in practice.
144
+ extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
145
+ should we extend the audio each time. Larger values will mean less context is
146
+ preserved, and shorter values will require extra computation.
147
+ """
148
+ assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
149
+ self.extend_stride = extend_stride
150
+ self.duration = duration
151
+ self.generation_params = {
152
+ 'use_sampling': use_sampling,
153
+ 'temp': temperature,
154
+ 'top_k': top_k,
155
+ 'top_p': top_p,
156
+ 'cfg_coef': cfg_coef,
157
+ 'two_step_cfg': two_step_cfg,
158
+ }
159
+
160
+ def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
161
+ """Override the default progress callback."""
162
+ self._progress_callback = progress_callback
163
+
164
+ def generate_unconditional(self, num_samples: int, progress: bool = False,
165
+ return_tokens: bool = False) -> tp.Union[torch.Tensor,
166
+ tp.Tuple[torch.Tensor, torch.Tensor]]:
167
+ """Generate samples in an unconditional manner.
168
+
169
+ Args:
170
+ num_samples (int): Number of samples to be generated.
171
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
172
+ """
173
+ descriptions: tp.List[tp.Optional[torch.Tensor]] = [None] * num_samples
174
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
175
+ tokens = self._generate_tokens(attributes, prompt_tokens, progress)
176
+ if return_tokens:
177
+ return self.generate_audio(tokens), tokens
178
+ return self.generate_audio(tokens)
179
+
180
+ def generate(self, descriptions_list: tp.List, progress: bool = False, return_tokens: bool = False) \
181
+ -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
182
+ """Generate samples conditioned on the input video.
183
+
184
+ Args:
185
+ descriptions_list (list): Local and global video conditioning tensors.
186
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
187
+ """
188
+
189
+ assert isinstance(descriptions_list,list)
190
+ assert len(descriptions_list)<=2
191
+
192
+ assert len(descriptions_list)==2
193
+ local_descriptions=[descriptions_list[0]]
194
+ global_descriptions=[descriptions_list[1]]
195
+
196
+ local_attributes = torch.stack(local_descriptions)
197
+ global_attributes = torch.stack(global_descriptions)
198
+
199
+ prompt_tokens = None
200
+ assert prompt_tokens is None
201
+
202
+ assert len(descriptions_list)==2
203
+ tokens = self._generate_tokens([local_attributes, global_attributes], prompt_tokens, progress)
204
+
205
+ if return_tokens:
206
+ return self.generate_audio(tokens), tokens
207
+ return self.generate_audio(tokens)
208
+
209
+ def generate_with_chroma(self, descriptions: tp.List[torch.Tensor], melody_wavs: MelodyType,
210
+ melody_sample_rate: int, progress: bool = False,
211
+ return_tokens: bool = False) -> tp.Union[torch.Tensor,
212
+ tp.Tuple[torch.Tensor, torch.Tensor]]:
213
+ """Generate samples conditioned on text and melody.
214
+
215
+ Args:
216
+ descriptions (list of str): A list of strings used as text conditioning.
217
+ melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
218
+ melody conditioning. Should have shape [B, C, T] with B matching the description length,
219
+ C=1 or 2. It can be [C, T] if there is a single description. It can also be
220
+ a list of [C, T] tensors.
221
+ melody_sample_rate: (int): Sample rate of the melody waveforms.
222
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
223
+ """
224
+ if isinstance(melody_wavs, torch.Tensor):
225
+ if melody_wavs.dim() == 2:
226
+ melody_wavs = melody_wavs[None]
227
+ if melody_wavs.dim() != 3:
228
+ raise ValueError("Melody wavs should have a shape [B, C, T].")
229
+ melody_wavs = list(melody_wavs)
230
+ else:
231
+ for melody in melody_wavs:
232
+ if melody is not None:
233
+ assert melody.dim() == 2, "One melody in the list has the wrong number of dims."
234
+
235
+ melody_wavs = [
236
+ convert_audio(wav, melody_sample_rate, self.sample_rate, self.audio_channels)
237
+ if wav is not None else None
238
+ for wav in melody_wavs]
239
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
240
+ melody_wavs=melody_wavs)
241
+ assert prompt_tokens is None
242
+ tokens = self._generate_tokens(attributes, prompt_tokens, progress)
243
+ if return_tokens:
244
+ return self.generate_audio(tokens), tokens
245
+ return self.generate_audio(tokens)
246
+
247
+ def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
248
+ descriptions: tp.Optional[tp.List[tp.Optional[torch.Tensor]]] = None,
249
+ progress: bool = False, return_tokens: bool = False) \
250
+ -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
251
+ """Generate samples conditioned on audio prompts.
252
+
253
+ Args:
254
+ prompt (torch.Tensor): A batch of waveforms used for continuation.
255
+ Prompt should be [B, C, T], or [C, T] if only one sample is generated.
256
+ prompt_sample_rate (int): Sampling rate of the given audio waveforms.
257
+ descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
258
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
259
+ """
260
+ if prompt.dim() == 2:
261
+ prompt = prompt[None]
262
+ if prompt.dim() != 3:
263
+ raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
264
+ prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
265
+ if descriptions is None:
266
+ descriptions = [None] * len(prompt)
267
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
268
+ assert prompt_tokens is not None
269
+ tokens = self._generate_tokens(attributes, prompt_tokens, progress)
270
+ if return_tokens:
271
+ return self.generate_audio(tokens), tokens
272
+ return self.generate_audio(tokens)
273
+
274
+ @torch.no_grad()
275
+ def _prepare_tokens_and_attributes(
276
+ self,
277
+ descriptions: tp.Sequence[tp.Optional[str]],
278
+ prompt: tp.Optional[torch.Tensor],
279
+ melody_wavs: tp.Optional[MelodyList] = None,
280
+ ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
281
+ """Prepare model inputs.
282
+
283
+ Args:
284
+ descriptions (list of str): A list of strings used as text conditioning.
285
+ prompt (torch.Tensor): A batch of waveforms used for continuation.
286
+ melody_wavs (torch.Tensor, optional): A batch of waveforms
287
+ used as melody conditioning. Defaults to None.
288
+ """
289
+ attributes = [
290
+ ConditioningAttributes(text={'description': description})
291
+ for description in descriptions]
292
+
293
+ if melody_wavs is None:
294
+ for attr in attributes:
295
+ attr.wav['self_wav'] = WavCondition(
296
+ torch.zeros((1, 1, 1), device=self.device),
297
+ torch.tensor([0], device=self.device),
298
+ sample_rate=[self.sample_rate],
299
+ path=[None])
300
+ else:
301
+ if 'self_wav' not in self.lm.condition_provider.conditioners:
302
+ raise RuntimeError("This model doesn't support melody conditioning. "
303
+ "Use the `melody` model.")
304
+ assert len(melody_wavs) == len(descriptions), \
305
+ f"number of melody wavs must match number of descriptions! " \
306
+ f"got melody len={len(melody_wavs)}, and descriptions len={len(descriptions)}"
307
+ for attr, melody in zip(attributes, melody_wavs):
308
+ if melody is None:
309
+ attr.wav['self_wav'] = WavCondition(
310
+ torch.zeros((1, 1, 1), device=self.device),
311
+ torch.tensor([0], device=self.device),
312
+ sample_rate=[self.sample_rate],
313
+ path=[None])
314
+ else:
315
+ attr.wav['self_wav'] = WavCondition(
316
+ melody[None].to(device=self.device),
317
+ torch.tensor([melody.shape[-1]], device=self.device),
318
+ sample_rate=[self.sample_rate],
319
+ path=[None],
320
+ )
321
+
322
+ if prompt is not None:
323
+ if descriptions is not None:
324
+ assert len(descriptions) == len(prompt), "Prompt and number of descriptions don't match."
325
+ prompt = prompt.to(self.device)
326
+ prompt_tokens, scale = self.compression_model.encode(prompt)
327
+ assert scale is None
328
+ else:
329
+ prompt_tokens = None
330
+ return attributes, prompt_tokens
331
+
332
+ def _generate_tokens(self, attributes: tp.List,
333
+ prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
334
+ """Generate discrete audio tokens given audio prompt and/or conditions.
335
+
336
+ Args:
337
+ attributes (list of ConditioningAttributes): Conditions used for generation (text/melody).
338
+ prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
339
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
340
+ Returns:
341
+ torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
342
+ """
343
+ self.max_duration = 30
344
+
345
+ total_gen_len = int(self.duration * self.frame_rate)
346
+ max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
347
+ current_gen_offset: int = 0
348
+
349
+ def _progress_callback(generated_tokens: int, tokens_to_generate: int):
350
+ generated_tokens += current_gen_offset
351
+ if self._progress_callback is not None:
352
+ # Note that total_gen_len might be quite wrong depending on the
353
+ # codebook pattern used, but with delay it is almost accurate.
354
+ self._progress_callback(generated_tokens, total_gen_len)
355
+ else:
356
+ print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
357
+
358
+ if prompt_tokens is not None:
359
+ assert max_prompt_len >= prompt_tokens.shape[-1], \
360
+ "Prompt is longer than audio to generate"
361
+
362
+ callback = None
363
+ if progress:
364
+ callback = _progress_callback
365
+
366
+ if self.duration <= self.max_duration:
367
+ # generate by sampling from LM, simple case.
368
+ with self.autocast:
369
+ gen_tokens = self.lm.generate(
370
+ prompt_tokens, attributes,
371
+ callback=callback, max_gen_len=total_gen_len, **self.generation_params)
372
+
373
+ else:
374
+ # now this gets a bit messier, we need to handle prompts,
375
+ # melody conditioning etc.
376
+
377
+ assert self.extend_stride is not None, "Stride should be defined to generate beyond max_duration"
378
+ assert self.extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
379
+ # ref_wavs = [attr.wav['self_wav'] for attr in attributes]
380
+ all_tokens = []
381
+ if prompt_tokens is None: # None
382
+ prompt_length = 0
383
+ else:
384
+ all_tokens.append(prompt_tokens)
385
+ prompt_length = prompt_tokens.shape[-1]
386
+
387
+ stride_tokens = int(self.frame_rate * self.extend_stride) # max_duration - overlap_duration
388
+
389
+ self.fps = 2
390
+ stride_video_frames = int(self.fps * self.extend_stride)
391
+
392
+ while current_gen_offset + prompt_length < total_gen_len:
393
+ time_offset = current_gen_offset / self.frame_rate
394
+ chunk_duration = min(self.duration - time_offset, self.max_duration)
395
+ max_gen_len = int(chunk_duration * self.frame_rate)
396
+
397
+ with self.autocast:
398
+ assert len(attributes)==2
399
+ # import pdb; pdb.set_trace()
400
+ gen_tokens = self.lm.generate(
401
+ prompt_tokens, [attributes[0][:,:,:int(chunk_duration*self.fps),:,:], attributes[1]],
402
+ callback=callback, max_gen_len=max_gen_len, **self.generation_params)
403
+
404
+ if prompt_tokens is None:
405
+ all_tokens.append(gen_tokens)
406
+ else:
407
+ all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
408
+ prompt_tokens = gen_tokens[:, :, stride_tokens:]
409
+ prompt_length = prompt_tokens.shape[-1]
410
+
411
+ if attributes[0].shape[2]-stride_video_frames < self.max_duration*self.fps:
412
+ attributes[0]=attributes[0][:,:,-self.max_duration*self.fps:,:,:]
413
+ else:
414
+ attributes[0]=attributes[0][:,:,stride_video_frames:,:,:]
415
+ current_gen_offset += stride_tokens
416
+
417
+ gen_tokens = torch.cat(all_tokens, dim=-1)
418
+ return gen_tokens
419
+
420
+ def generate_audio(self, gen_tokens: torch.Tensor):
421
+ """Generate Audio from tokens"""
422
+ assert gen_tokens.dim() == 3
423
+ with torch.no_grad():
424
+ gen_audio = self.compression_model.decode(gen_tokens, None)
425
+ return gen_audio
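A sketch of the end-to-end generation API above. This is not a verbatim recipe from the repo: the checkpoint id is a placeholder, and the video tensor layout is an assumption inferred from the slicing in `_generate_tokens` (time on the frame axis, sampled at `self.fps = 2`); the actual preprocessing lives in the repo's data/inference utilities.

```python
import torch
from audiocraft.models.vidmuse import VidMuse

# Placeholder checkpoint id; substitute the VidMuse weights you are actually using.
model = VidMuse.get_pretrained('HKUSTAudio/VidMuse')
model.set_generation_params(duration=60, extend_stride=29.5)  # >30 s triggers windowed generation

# Local (frame-level) and global (whole-clip) video conditioning; shapes are assumptions,
# with the frame/time axis in the position sliced by _generate_tokens after stacking.
local_video = torch.randn(3, 120, 224, 224)    # 60 s of frames at fps = 2
global_video = torch.randn(3, 64, 224, 224)    # globally sampled frames

wav = model.generate([local_video, global_video], progress=True)
print(wav.shape)                               # [B, audio_channels, samples]
```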
audiocraft/modules/__init__.py ADDED
@@ -0,0 +1,22 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ """Modules used for building the models."""
+
+ # flake8: noqa
+ from .conv import (
+     NormConv1d,
+     NormConv2d,
+     NormConvTranspose1d,
+     NormConvTranspose2d,
+     StreamableConv1d,
+     StreamableConvTranspose1d,
+     pad_for_conv1d,
+     pad1d,
+     unpad1d,
+ )
+ from .lstm import StreamableLSTM
+ from .seanet import SEANetEncoder, SEANetDecoder
+ from .transformer import StreamingTransformer
audiocraft/modules/activations.py ADDED
@@ -0,0 +1,96 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ from typing import Union, Callable
+
+
+ class CustomGLU(nn.Module):
+     """Custom Gated Linear Unit activation.
+     Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half
+     of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation
+     function (i.e. sigmoid, swish, etc.).
+
+     Args:
+         activation (nn.Module): The custom activation to apply in the Gated Linear Unit.
+         dim (int): the dimension on which to split the input. Default: -1
+
+     Shape:
+         - Input: :math:`(\ast_1, N, \ast_2)` where `*` means any number of additional
+           dimensions
+         - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
+
+     Examples::
+         >>> m = CustomGLU(nn.Sigmoid())
+         >>> input = torch.randn(4, 2)
+         >>> output = m(input)
+     """
+     def __init__(self, activation: nn.Module, dim: int = -1):
+         super(CustomGLU, self).__init__()
+         self.dim = dim
+         self.activation = activation
+
+     def forward(self, x: Tensor):
+         assert x.shape[self.dim] % 2 == 0  # M = N / 2
+         a, b = torch.chunk(x, 2, dim=self.dim)
+         return a * self.activation(b)
+
+
+ class SwiGLU(CustomGLU):
+     """SiLU Gated Linear Unit activation.
+     Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is
+     the first half of the input matrices, :math:`b` is the second half.
+
+     Args:
+         dim (int): the dimension on which to split the input. Default: -1
+     """
+     def __init__(self, dim: int = -1):
+         super(SwiGLU, self).__init__(nn.SiLU(), dim)
+
+
+ class GeGLU(CustomGLU):
+     """GeLU Gated Linear Unit activation.
+     Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is
+     the first half of the input matrices, :math:`b` is the second half.
+
+     Args:
+         dim (int): the dimension on which to split the input. Default: -1
+     """
+     def __init__(self, dim: int = -1):
+         super(GeGLU, self).__init__(nn.GELU(), dim)
+
+
+ class ReGLU(CustomGLU):
+     """ReLU Gated Linear Unit activation.
+     Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is
+     the first half of the input matrices, :math:`b` is the second half.
+
+     Args:
+         dim (int): the dimension on which to split the input. Default: -1
+     """
+     def __init__(self, dim: int = -1):
+         super(ReGLU, self).__init__(nn.ReLU(), dim)
+
+
+ def get_activation_fn(
+     activation: Union[str, Callable[[Tensor], Tensor]]
+ ) -> Union[str, Callable[[Tensor], Tensor]]:
+     """Helper function to map an activation string to the activation class.
+     If the supplied activation is not a string that is recognized, the activation is passed back.
+
+     Args:
+         activation (str, or Callable[[Tensor], Tensor]): Activation to check.
+     """
+     if isinstance(activation, str):
+         if activation == "reglu":
+             return ReGLU()
+         elif activation == "geglu":
+             return GeGLU()
+         elif activation == "swiglu":
+             return SwiGLU()
+     return activation
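For reference, a small usage sketch of the helpers above: the GLU variants halve the dimension they gate, and `get_activation_fn` only maps the three recognized strings, passing anything else through unchanged.

```python
import torch
from audiocraft.modules.activations import SwiGLU, get_activation_fn

x = torch.randn(4, 10, 512)
act = get_activation_fn("swiglu")        # returns a SwiGLU() instance
assert isinstance(act, SwiGLU)
y = act(x)                               # gates the last dimension: 512 -> 256
print(y.shape)                           # torch.Size([4, 10, 256])

# Unrecognized names or callables are passed back as-is.
gelu = get_activation_fn(torch.nn.functional.gelu)
```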
audiocraft/modules/chroma.py ADDED
@@ -0,0 +1,66 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ import typing as tp
+
+ from einops import rearrange
+ from librosa import filters
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ import torchaudio
+
+
+ class ChromaExtractor(nn.Module):
+     """Chroma extraction and quantization.
+
+     Args:
+         sample_rate (int): Sample rate for the chroma extraction.
+         n_chroma (int): Number of chroma bins for the chroma extraction.
+         radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12).
+         nfft (int, optional): Number of FFT.
+         winlen (int, optional): Window length.
+         winhop (int, optional): Window hop size.
+         argmax (bool, optional): Whether to use argmax. Defaults to False.
+         norm (float, optional): Norm for chroma normalization. Defaults to inf.
+     """
+     def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None,
+                  winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False,
+                  norm: float = torch.inf):
+         super().__init__()
+         self.winlen = winlen or 2 ** radix2_exp
+         self.nfft = nfft or self.winlen
+         self.winhop = winhop or (self.winlen // 4)
+         self.sample_rate = sample_rate
+         self.n_chroma = n_chroma
+         self.norm = norm
+         self.argmax = argmax
+         self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
+                                                                        n_chroma=self.n_chroma)), persistent=False)
+         self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
+                                                       hop_length=self.winhop, power=2, center=True,
+                                                       pad=0, normalized=True)
+
+     def forward(self, wav: torch.Tensor) -> torch.Tensor:
+         T = wav.shape[-1]
+         # in case we are getting a wav that was dropped out (nullified)
+         # from the conditioner, make sure the wav length is no less than nfft
+         if T < self.nfft:
+             pad = self.nfft - T
+             r = 0 if pad % 2 == 0 else 1
+             wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
+             assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"
+
+         spec = self.spec(wav).squeeze(1)
+         raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
+         norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
+         norm_chroma = rearrange(norm_chroma, 'b d t -> b t d')
+
+         if self.argmax:
+             idx = norm_chroma.argmax(-1, keepdim=True)
+             norm_chroma[:] = 0
+             norm_chroma.scatter_(dim=-1, index=idx, value=1)
+
+         return norm_chroma
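A minimal sketch of how the extractor is typically called, assuming mono waveforms shaped [B, 1, T]; with `argmax=True` each frame becomes a one-hot chroma vector.

```python
import torch
from audiocraft.modules.chroma import ChromaExtractor

sr = 32000
chroma = ChromaExtractor(sample_rate=sr, n_chroma=12, radix2_exp=12, argmax=True)

wav = torch.randn(2, 1, sr * 2)           # [B, 1, T], two seconds of mono audio
with torch.no_grad():
    c = chroma(wav)                       # [B, frames, n_chroma]
print(c.shape)
```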
audiocraft/modules/codebooks_patterns.py ADDED
@@ -0,0 +1,544 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from collections import namedtuple
8
+ from dataclasses import dataclass
9
+ from functools import lru_cache
10
+ import logging
11
+ import typing as tp
12
+
13
+ from abc import ABC, abstractmethod
14
+ import torch
15
+
16
+ LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index)
17
+ PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ @dataclass
22
+ class Pattern:
23
+ """Base implementation of a pattern over a sequence with multiple codebooks.
24
+
25
+ The codebook pattern consists of a layout, defining for each sequence step
26
+ the list of coordinates of each codebook timestep in the resulting interleaved sequence.
27
+ The first item of the pattern is always an empty list in order to properly insert a special token
28
+ to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
29
+ and ``timesteps`` the number of timesteps corresponding to the original sequence.
30
+
31
+ The pattern provides convenient methods to build and revert interleaved sequences from it:
32
+ ``build_pattern_sequence`` maps a given dense input tensor of a multi-codebook sequence from [B, K, T]
33
+ to the interleaved sequence of shape [B, K, S] by applying the pattern, with B being the batch size,
34
+ K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
35
+ for the output sequence. The unfilled positions are replaced with a special token and the built sequence
36
+ is returned along with a mask indicating valid tokens.
37
+ ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
38
+ of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
39
+ to fill and specify invalid positions if needed.
40
+ See the dedicated methods for more details.
41
+ """
42
+ # Pattern layout, for each sequence step, we have a list of coordinates
43
+ # corresponding to the original codebook timestep and position.
44
+ # The first list is always an empty list in order to properly insert
45
+ # a special token to start with.
46
+ layout: PatternLayout
47
+ timesteps: int
48
+ n_q: int
49
+
50
+ def __post_init__(self):
51
+ assert len(self.layout) > 0
52
+ assert self.layout[0] == []
53
+ self._validate_layout()
54
+ self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
55
+ self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
56
+ logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
57
+
58
+ def _validate_layout(self):
59
+ """Runs checks on the layout to ensure a valid pattern is defined.
60
+ A pattern is considered invalid if:
61
+ - Multiple timesteps for a same codebook are defined in the same sequence step
62
+ - The timesteps for a given codebook are not in ascending order as we advance in the sequence
63
+ (this would mean that we have future timesteps before past timesteps).
64
+ """
65
+ q_timesteps = {q: 0 for q in range(self.n_q)}
66
+ for s, seq_coords in enumerate(self.layout):
67
+ if len(seq_coords) > 0:
68
+ qs = set()
69
+ for coord in seq_coords:
70
+ qs.add(coord.q)
71
+ last_q_timestep = q_timesteps[coord.q]
72
+ assert coord.t >= last_q_timestep, \
73
+ f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
74
+ q_timesteps[coord.q] = coord.t
75
+ # each sequence step contains at max 1 coordinate per codebook
76
+ assert len(qs) == len(seq_coords), \
77
+ f"Multiple entries for a same codebook are found at step {s}"
78
+
79
+ @property
80
+ def num_sequence_steps(self):
81
+ return len(self.layout) - 1
82
+
83
+ @property
84
+ def max_delay(self):
85
+ max_t_in_seq_coords = 0
86
+ for seq_coords in self.layout[1:]:
87
+ for coords in seq_coords:
88
+ max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
89
+ return max_t_in_seq_coords - self.timesteps
90
+
91
+ @property
92
+ def valid_layout(self):
93
+ valid_step = len(self.layout) - self.max_delay
94
+ return self.layout[:valid_step]
95
+
96
+ def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
97
+ """Get codebook coordinates in the layout that corresponds to the specified timestep t
98
+ and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
99
+ and the actual codebook coordinates.
100
+ """
101
+ assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
102
+ if q is not None:
103
+ assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
104
+ coords = []
105
+ for s, seq_codes in enumerate(self.layout):
106
+ for code in seq_codes:
107
+ if code.t == t and (q is None or code.q == q):
108
+ coords.append((s, code))
109
+ return coords
110
+
111
+ def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
112
+ return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
113
+
114
+ def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
115
+ steps_with_timesteps = self.get_steps_with_timestep(t, q)
116
+ return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
117
+
118
+ def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
119
+ device: tp.Union[torch.device, str] = 'cpu'):
120
+ """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
121
+
122
+ Args:
123
+ timesteps (int): Maximum number of timesteps steps to consider.
124
+ keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
125
+ device (torch.device or str): Device for created tensors.
126
+ Returns:
127
+ indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
128
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
129
+ """
130
+ assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
131
+ assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
132
+ # use the proper layout based on whether we limit ourselves to valid steps only or not,
133
+ # note that using the valid_layout will result in a truncated sequence up to the valid steps
134
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
135
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
136
+ indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
137
+ mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
138
+ # fill indexes with last sequence step value that will correspond to our special token
139
+ # the last value is n_q * timesteps as we have flattened z and append special token as the last token
140
+ # which will correspond to the index: n_q * timesteps
141
+ indexes[:] = n_q * timesteps
142
+ # iterate over the pattern and fill scattered indexes and mask
143
+ for s, sequence_coords in enumerate(ref_layout):
144
+ for coords in sequence_coords:
145
+ if coords.t < timesteps:
146
+ indexes[coords.q, s] = coords.t + coords.q * timesteps
147
+ mask[coords.q, s] = 1
148
+ indexes = torch.from_numpy(indexes).to(device)
149
+ mask = torch.from_numpy(mask).to(device)
150
+ return indexes, mask
151
+
152
+ def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
153
+ """Build sequence corresponding to the pattern from the input tensor z.
154
+ The sequence is built using up to sequence_steps if specified, and non-pattern
155
+ coordinates are filled with the special token.
156
+
157
+ Args:
158
+ z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
159
+ special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
160
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
161
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
162
+ Returns:
163
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
164
+ corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
165
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
166
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
167
+ """
168
+ B, K, T = z.shape
169
+ indexes, mask = self._build_pattern_sequence_scatter_indexes(
170
+ T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
171
+ )
172
+ z = z.view(B, -1)
173
+ # we append the special token as the last index of our flattened z tensor
174
+ z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
175
+ values = z[:, indexes.view(-1)]
176
+ values = values.view(B, K, indexes.shape[-1])
177
+ return values, indexes, mask
178
+
179
+ def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
180
+ keep_only_valid_steps: bool = False,
181
+ is_model_output: bool = False,
182
+ device: tp.Union[torch.device, str] = 'cpu'):
183
+ """Builds scatter indexes required to retrieve the original multi-codebook sequence
184
+ from interleaving pattern.
185
+
186
+ Args:
187
+ sequence_steps (int): Sequence steps.
188
+ n_q (int): Number of codebooks.
189
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
190
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
191
+ is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
192
+ device (torch.device or str): Device for created tensors.
193
+ Returns:
194
+ indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
195
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
196
+ """
197
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
198
+ # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
199
+ timesteps = self.timesteps
200
+ assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
201
+ assert sequence_steps <= len(ref_layout), \
202
+ f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
203
+
204
+ # ensure we take the appropriate indexes to keep the model output from the first special token as well
205
+ if is_model_output:
206
+ ref_layout = ref_layout[1:]
207
+
208
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
209
+ indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
210
+ mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
211
+ # fill indexes with last sequence step value that will correspond to our special token
212
+ indexes[:] = n_q * sequence_steps
213
+ for s, sequence_codes in enumerate(ref_layout):
214
+ if s < sequence_steps:
215
+ for code in sequence_codes:
216
+ if code.t < timesteps:
217
+ indexes[code.q, code.t] = s + code.q * sequence_steps
218
+ mask[code.q, code.t] = 1
219
+ indexes = torch.from_numpy(indexes).to(device)
220
+ mask = torch.from_numpy(mask).to(device)
221
+ return indexes, mask
222
+
223
+ def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
224
+ """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
225
+ The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
226
+ are filled with the special token.
227
+
228
+ Args:
229
+ s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
230
+ special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
231
+ Returns:
232
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
233
+ corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
234
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
235
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
236
+ """
237
+ B, K, S = s.shape
238
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
239
+ S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
240
+ )
241
+ s = s.view(B, -1)
242
+ # we append the special token as the last index of our flattened z tensor
243
+ s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
244
+ values = s[:, indexes.view(-1)]
245
+ values = values.view(B, K, indexes.shape[-1])
246
+ return values, indexes, mask
247
+
248
+ def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
249
+ """Revert model logits obtained on a sequence built from the pattern
250
+ back to a tensor matching the original sequence.
251
+
252
+ This method is similar to ``revert_pattern_sequence`` with the following specificities:
253
+ 1. It is designed to work with the extra cardinality dimension
254
+ 2. We return the logits for the first sequence item that matches the special_token and
255
+ which matching target in the original sequence is the first item of the sequence,
256
+ while we skip the last logits as there is no matching target
257
+ """
258
+ B, card, K, S = logits.shape
259
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
260
+ S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
261
+ )
262
+ logits = logits.reshape(B, card, -1)
263
+ # we append the special token as the last index of our flattened z tensor
264
+ logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S]
265
+ values = logits[:, :, indexes.view(-1)]
266
+ values = values.view(B, card, K, indexes.shape[-1])
267
+ return values, indexes, mask
268
+
269
+
270
+ class CodebooksPatternProvider(ABC):
271
+ """Abstraction around providing pattern for interleaving codebooks.
272
+
273
+ The CodebooksPatternProvider abstraction allows to implement various strategies to
274
+ define interleaving pattern of sequences composed of multiple codebooks. For a given
275
+ number of codebooks `n_q`, the pattern provider can generate a specified pattern
276
+ corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
277
+ can be used to construct a new sequence from the original codes respecting the specified
278
+ pattern. The pattern is defined as a list of list of code coordinates, code coordinate
279
+ being a tuple with the original timestep and codebook to build the new sequence.
280
+ Note that all patterns must start with an empty list that is then used to insert a first
281
+ sequence step of special tokens in the newly generated sequence.
282
+
283
+ Args:
284
+ n_q (int): number of codebooks.
285
+ cached (bool): if True, patterns for a given length are cached. In general
286
+ that should be true for efficiency reasons, to avoid synchronization points.
287
+ """
288
+ def __init__(self, n_q: int, cached: bool = True):
289
+ assert n_q > 0
290
+ self.n_q = n_q
291
+ self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore
292
+
293
+ @abstractmethod
294
+ def get_pattern(self, timesteps: int) -> Pattern:
295
+ """Builds pattern with specific interleaving between codebooks.
296
+
297
+ Args:
298
+ timesteps (int): Total number of timesteps.
299
+ """
300
+ raise NotImplementedError()
301
+
302
+
303
+ class DelayedPatternProvider(CodebooksPatternProvider):
304
+ """Provider for delayed pattern across delayed codebooks.
305
+ Codebooks are delayed in the sequence and sequence steps will contain codebooks
306
+ from different timesteps.
307
+
308
+ Example:
309
+ Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
310
+ [[1, 2, 3, 4],
311
+ [1, 2, 3, 4],
312
+ [1, 2, 3, 4]]
313
+ The resulting sequence obtained from the returned pattern is:
314
+ [[S, 1, 2, 3, 4],
315
+ [S, S, 1, 2, 3],
316
+ [S, S, S, 1, 2]]
317
+ (with S being a special token)
318
+
319
+ Args:
320
+ n_q (int): Number of codebooks.
321
+ delays (list of int, optional): Delay for each of the codebooks.
322
+ If delays not defined, each codebook is delayed by 1 compared to the previous one.
323
+ flatten_first (int): Flatten the first N timesteps.
324
+ empty_initial (int): Prepend with N empty list of coordinates.
325
+ """
326
+ def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
327
+ flatten_first: int = 0, empty_initial: int = 0):
328
+ super().__init__(n_q)
329
+ if delays is None:
330
+ delays = list(range(n_q))
331
+ self.delays = delays
332
+ self.flatten_first = flatten_first
333
+ self.empty_initial = empty_initial
334
+ assert len(self.delays) == self.n_q
335
+ assert sorted(self.delays) == self.delays
336
+
337
+ def get_pattern(self, timesteps: int) -> Pattern:
338
+ out: PatternLayout = [[]]
339
+ max_delay = max(self.delays)
340
+ if self.empty_initial:
341
+ out += [[] for _ in range(self.empty_initial)]
342
+ if self.flatten_first:
343
+ for t in range(min(timesteps, self.flatten_first)):
344
+ for q in range(self.n_q):
345
+ out.append([LayoutCoord(t, q)])
346
+ for t in range(self.flatten_first, timesteps + max_delay):
347
+ v = []
348
+ for q, delay in enumerate(self.delays):
349
+ t_for_q = t - delay
350
+ if t_for_q >= self.flatten_first:
351
+ v.append(LayoutCoord(t_for_q, q))
352
+ out.append(v)
353
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
354
+
355
+
356
+ class ParallelPatternProvider(DelayedPatternProvider):
357
+ """Provider for parallel pattern across codebooks.
358
+ This pattern provider is a special case of the delayed pattern with actually no delay,
359
+ hence delays=repeat(0, n_q).
360
+
361
+ Args:
362
+ n_q (int): Number of codebooks.
363
+ """
364
+ def __init__(self, n_q: int):
365
+ super().__init__(n_q, [0] * n_q)
366
+
367
+
368
+ class UnrolledPatternProvider(CodebooksPatternProvider):
369
+ """Provider for unrolling codebooks pattern.
370
+ This pattern provider makes it possible to represent the codebook flattened completely or only to some extent
371
+ while also specifying a given delay between the flattened codebooks representation, allowing to
372
+ unroll the codebooks in the sequence.
373
+
374
+ Example:
375
+ 1. Flattening of the codebooks.
376
+ By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
377
+ taking n_q = 3 and timesteps = 4:
378
+ [[1, 2, 3, 4],
379
+ [1, 2, 3, 4],
380
+ [1, 2, 3, 4]]
381
+ will result into:
382
+ [[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
383
+ [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
384
+ [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
385
+ 2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
386
+ for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example
387
+ taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
388
+ [[1, 2, 3, 4],
389
+ [1, 2, 3, 4],
390
+ [1, 2, 3, 4]]
391
+ will result into:
392
+ [[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
393
+ [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
394
+ [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
395
+ 3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks
396
+ allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
397
+ same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
398
+ and delays = [0, 3, 3]:
399
+ [[1, 2, 3, 4],
400
+ [1, 2, 3, 4],
401
+ [1, 2, 3, 4]]
402
+ will result into:
403
+ [[S, S, S, 1, S, 2, S, 3, S, 4],
404
+ [S, S, S, 1, S, 2, S, 3, S, 4],
405
+ [1, 2, 3, S, 4, S, 5, S, 6, S]]
406
+
407
+ Args:
408
+ n_q (int): Number of codebooks.
409
+ flattening (list of int, optional): Flattening schema over the codebooks. If not defined,
410
+ the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
411
+ have n_q extra steps for each timestep.
412
+ delays (list of int, optional): Delay for each of the codebooks. If not defined,
413
+ no delay is added and therefore will default to [0] * ``n_q``.
414
+ Note that two codebooks that will be flattened to the same inner step
415
+ should have the same delay, otherwise the pattern is considered as invalid.
416
+ """
417
+ FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
418
+
419
+ def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
420
+ delays: tp.Optional[tp.List[int]] = None):
421
+ super().__init__(n_q)
422
+ if flattening is None:
423
+ flattening = list(range(n_q))
424
+ if delays is None:
425
+ delays = [0] * n_q
426
+ assert len(flattening) == n_q
427
+ assert len(delays) == n_q
428
+ assert sorted(flattening) == flattening
429
+ assert sorted(delays) == delays
430
+ self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
431
+ self.max_delay = max(delays)
432
+
433
+ def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
434
+ """Build a flattened codebooks representation as a dictionary of inner step
435
+ and the actual codebook indices corresponding to the flattened codebook. For convenience, we
436
+ also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
437
+ """
438
+ flattened_codebooks: dict = {}
439
+ for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
440
+ if inner_step not in flattened_codebooks:
441
+ flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
442
+ else:
443
+ flat_codebook = flattened_codebooks[inner_step]
444
+ assert flat_codebook.delay == delay, (
445
+ "Delay and flattening between codebooks is inconsistent: ",
446
+ "two codebooks flattened to the same position should have the same delay."
447
+ )
448
+ flat_codebook.codebooks.append(q)
449
+ flattened_codebooks[inner_step] = flat_codebook
450
+ return flattened_codebooks
451
+
452
+ @property
453
+ def _num_inner_steps(self):
454
+ """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
455
+ """
456
+ return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
457
+
458
+ def num_virtual_steps(self, timesteps: int) -> int:
459
+ return timesteps * self._num_inner_steps + 1
460
+
461
+ def get_pattern(self, timesteps: int) -> Pattern:
462
+ """Builds the unrolled pattern, with optional delays across codebooks.
463
+
464
+ Args:
465
+ timesteps (int): Total number of timesteps.
466
+ """
467
+ # the PatternLayout is built as a tuple of sequence position and list of coordinates
468
+ # so that it can be reordered properly given the required delay between codebooks of given timesteps
469
+ indexed_out: list = [(-1, [])]
470
+ max_timesteps = timesteps + self.max_delay
471
+ for t in range(max_timesteps):
472
+ # for each timestep, we unroll the flattened codebooks,
473
+ # emitting the sequence step with the corresponding delay
474
+ for step in range(self._num_inner_steps):
475
+ if step in self._flattened_codebooks:
476
+ # we have codebooks at this virtual step to emit
477
+ step_codebooks = self._flattened_codebooks[step]
478
+ t_for_q = t + step_codebooks.delay
479
+ coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
480
+ if t_for_q < max_timesteps and t < max_timesteps:
481
+ indexed_out.append((t_for_q, coords))
482
+ else:
483
+ # there is no codebook in this virtual step so we emit an empty list
484
+ indexed_out.append((t, []))
485
+ out = [coords for _, coords in sorted(indexed_out)]
486
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
487
+
488
+
489
+ class CoarseFirstPattern(CodebooksPatternProvider):
490
+ """First generates all the codebooks #1 (e.g. coarser), then the remaining ones,
491
+ potentially with delays.
492
+
493
+ ..Warning:: You must always generate the full training duration at test time, for instance,
494
+ 30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected
495
+ location. This is due to the non-causality of the remaining codebooks with respect to
496
+ the first ones.
497
+
498
+ Args:
499
+ n_q (int): Number of codebooks.
500
+ delays (list of int, optional): Delay for each of the codebooks.
501
+ If not defined, no extra delay is applied, i.e. the delays default to [0] * (n_q - 1).
502
+ """
503
+ def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
504
+ super().__init__(n_q)
505
+ if delays is None:
506
+ delays = [0] * (n_q - 1)
507
+ self.delays = delays
508
+ assert len(self.delays) == self.n_q - 1
509
+ assert sorted(self.delays) == self.delays
510
+
511
+ def get_pattern(self, timesteps: int) -> Pattern:
512
+ out: PatternLayout = [[]]
513
+ for t in range(timesteps):
514
+ out.append([LayoutCoord(t, 0)])
515
+ max_delay = max(self.delays)
516
+ for t in range(timesteps + max_delay):
517
+ v = []
518
+ for q, delay in enumerate(self.delays):
519
+ t_for_q = t - delay
520
+ if t_for_q >= 0:
521
+ v.append(LayoutCoord(t_for_q, q + 1))
522
+ out.append(v)
523
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
524
+
525
+
526
+ class MusicLMPattern(CodebooksPatternProvider):
527
+ """Almost MusicLM style pattern. This is equivalent to full flattening
528
+ but in a different order.
529
+
530
+ Args:
531
+ n_q (int): Number of codebooks.
532
+ group_by (int): Number of codebooks to group together.
533
+ """
534
+ def __init__(self, n_q: int, group_by: int = 2):
535
+ super().__init__(n_q)
536
+ self.group_by = group_by
537
+
538
+ def get_pattern(self, timesteps: int) -> Pattern:
539
+ out: PatternLayout = [[]]
540
+ for offset in range(0, self.n_q, self.group_by):
541
+ for t in range(timesteps):
542
+ for q in range(offset, offset + self.group_by):
543
+ out.append([LayoutCoord(t, q)])
544
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
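To make the pattern semantics above concrete, here is a minimal usage sketch (not part of the committed file) reusing the flattening and delays from example 3 of the `UnrolledPatternProvider` docstring; it only relies on the methods shown above and assumes the package is importable.

```python
# Minimal usage sketch for the pattern providers defined in this file.
# Values mirror example 3 of the UnrolledPatternProvider docstring.
from audiocraft.modules.codebooks_patterns import UnrolledPatternProvider

provider = UnrolledPatternProvider(n_q=3, flattening=[0, 1, 1], delays=[0, 3, 3])
pattern = provider.get_pattern(timesteps=4)  # maps sequence steps to (timestep, codebook) coords

# Two inner steps per timestep (codebook 0 alone, then codebooks 1 and 2 together),
# plus one initial special step; the largest per-codebook delay is 3.
print(provider.num_virtual_steps(4))  # 9
print(provider.max_delay)             # 3
```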
audiocraft/modules/conditioners.py ADDED
@@ -0,0 +1,1357 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from collections import defaultdict
8
+ from copy import deepcopy
9
+ from dataclasses import dataclass, field
10
+ from itertools import chain
11
+ import logging
12
+ import math
13
+ from pathlib import Path
14
+ import random
15
+ import re
16
+ import typing as tp
17
+ import warnings
18
+
19
+ import einops
20
+ from num2words import num2words
21
+ import spacy
22
+ from transformers import RobertaTokenizer # type: ignore
23
+ import torch
24
+ from torch import nn
25
+ import torch.nn.functional as F
26
+ from torch.nn.utils.rnn import pad_sequence
27
+
28
+ from .chroma import ChromaExtractor
29
+ from .streaming import StreamingModule
30
+ from .transformer import create_sin_embedding
31
+ from ..data.audio import audio_read
32
+ from ..data.audio_dataset import SegmentInfo
33
+ from ..data.audio_utils import convert_audio
34
+ from ..environment import AudioCraftEnvironment
35
+ from ..quantization import ResidualVectorQuantizer
36
+ from ..utils.autocast import TorchAutocast
37
+ from ..utils.cache import EmbeddingCache
38
+ from ..utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once
39
+
40
+
41
+
42
+ logger = logging.getLogger(__name__)
43
+ TextCondition = tp.Optional[str] # a text condition can be a string or None (if it doesn't exist)
44
+ ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask
45
+
46
+
47
+ class WavCondition(tp.NamedTuple):
48
+ wav: torch.Tensor
49
+ length: torch.Tensor
50
+ sample_rate: tp.List[int]
51
+ path: tp.List[tp.Optional[str]] = []
52
+ seek_time: tp.List[tp.Optional[float]] = []
53
+
54
+
55
+ class JointEmbedCondition(tp.NamedTuple):
56
+ wav: torch.Tensor
57
+ text: tp.List[tp.Optional[str]]
58
+ length: torch.Tensor
59
+ sample_rate: tp.List[int]
60
+ path: tp.List[tp.Optional[str]] = []
61
+ seek_time: tp.List[tp.Optional[float]] = []
62
+
63
+
64
+ @dataclass
65
+ class ConditioningAttributes:
66
+ text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
67
+ wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
68
+ joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
69
+
70
+ def __getitem__(self, item):
71
+ return getattr(self, item)
72
+
73
+ @property
74
+ def text_attributes(self):
75
+ return self.text.keys()
76
+
77
+ @property
78
+ def wav_attributes(self):
79
+ return self.wav.keys()
80
+
81
+ @property
82
+ def joint_embed_attributes(self):
83
+ return self.joint_embed.keys()
84
+
85
+ @property
86
+ def attributes(self):
87
+ return {
88
+ "text": self.text_attributes,
89
+ "wav": self.wav_attributes,
90
+ "joint_embed": self.joint_embed_attributes,
91
+ }
92
+
93
+ def to_flat_dict(self):
94
+ return {
95
+ **{f"text.{k}": v for k, v in self.text.items()},
96
+ **{f"wav.{k}": v for k, v in self.wav.items()},
97
+ **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
98
+ }
99
+
100
+ @classmethod
101
+ def from_flat_dict(cls, x):
102
+ out = cls()
103
+ for k, v in x.items():
104
+ kind, att = k.split(".")
105
+ out[kind][att] = v
106
+ return out
107
+
108
+
109
+ class SegmentWithAttributes(SegmentInfo):
110
+ """Base class for all dataclasses that are used for conditioning.
111
+ All child classes should implement `to_condition_attributes` that converts
112
+ the existing attributes to a dataclass of type ConditioningAttributes.
113
+ """
114
+ def to_condition_attributes(self) -> ConditioningAttributes:
115
+ raise NotImplementedError()
116
+
117
+
118
+ def nullify_condition(condition: ConditionType, dim: int = 1):
119
+ """Transform an input condition to a null condition.
120
+ This is done by converting it to a single zero vector, similarly
121
+ to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.
122
+
123
+ Args:
124
+ condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor])
125
+ dim (int): The dimension that will be truncated (should be the time dimension)
126
+ WARNING!: dim should not be the batch dimension!
127
+ Returns:
128
+ ConditionType: A tuple of null condition and mask
129
+ """
130
+ assert dim != 0, "dim cannot be the batch dimension!"
131
+ assert isinstance(condition, tuple) and \
132
+ isinstance(condition[0], torch.Tensor) and \
133
+ isinstance(condition[1], torch.Tensor), "'nullify_condition' got an unexpected input type!"
134
+ cond, mask = condition
135
+ B = cond.shape[0]
136
+ last_dim = cond.dim() - 1
137
+ out = cond.transpose(dim, last_dim)
138
+ out = 0. * out[..., :1]
139
+ out = out.transpose(dim, last_dim)
140
+ mask = torch.zeros((B, 1), device=out.device).int()
141
+ assert cond.dim() == out.dim()
142
+ return out, mask
143
+
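A small sketch of what `nullify_condition` does to shapes (assuming the module and its heavy dependencies are importable; the tensors here are arbitrary examples):

```python
# The dense condition [B, T, D] collapses to a single zeroed step [B, 1, D],
# and the returned mask is all zeros, marking the condition as absent.
import torch
from audiocraft.modules.conditioners import nullify_condition

cond = torch.randn(4, 50, 512)   # [B, T, D]
mask = torch.ones(4, 50)         # [B, T]
null_cond, null_mask = nullify_condition((cond, mask), dim=1)
print(null_cond.shape, null_mask.shape)  # torch.Size([4, 1, 512]) torch.Size([4, 1])
print(int(null_mask.sum()))              # 0
```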
144
+
145
+ def nullify_wav(cond: WavCondition) -> WavCondition:
146
+ """Transform a WavCondition to a nullified WavCondition.
147
+ It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes.
148
+
149
+ Args:
150
+ cond (WavCondition): Wav condition with wav, tensor of shape [B, T].
151
+ Returns:
152
+ WavCondition: Nullified wav condition.
153
+ """
154
+ null_wav, _ = nullify_condition((cond.wav, torch.zeros_like(cond.wav)), dim=cond.wav.dim() - 1)
155
+ return WavCondition(
156
+ wav=null_wav,
157
+ length=torch.tensor([0] * cond.wav.shape[0], device=cond.wav.device),
158
+ sample_rate=cond.sample_rate,
159
+ path=[None] * cond.wav.shape[0],
160
+ seek_time=[None] * cond.wav.shape[0],
161
+ )
162
+
163
+
164
+ def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
165
+ """Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0,
166
+ and replacing metadata by dummy attributes.
167
+
168
+ Args:
169
+ embed (JointEmbedCondition): Joint embedding condition with wav and text, wav tensor of shape [B, C, T].
170
+ """
171
+ null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1)
172
+ return JointEmbedCondition(
173
+ wav=null_wav, text=[None] * len(embed.text),
174
+ length=torch.LongTensor([0]).to(embed.wav.device),
175
+ sample_rate=embed.sample_rate,
176
+ path=[None] * embed.wav.shape[0],
177
+ seek_time=[0] * embed.wav.shape[0],
178
+ )
179
+
180
+
181
+ class Tokenizer:
182
+ """Base tokenizer implementation
183
+ (in case we want to introduce more advanced tokenizers in the future).
184
+ """
185
+ def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
186
+ raise NotImplementedError()
187
+
188
+
189
+ class WhiteSpaceTokenizer(Tokenizer):
190
+ """This tokenizer should be used for natural language descriptions.
191
+ For example:
192
+ ["he didn't, know he's going home.", 'shorter sentence'] =>
193
+ [[78, 62, 31, 4, 78, 25, 19, 34],
194
+ [59, 77, 0, 0, 0, 0, 0, 0]]
195
+ """
196
+ PUNCTUATION = "?:!.,;"
197
+
198
+ def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
199
+ lemma: bool = True, stopwords: bool = True) -> None:
200
+ self.n_bins = n_bins
201
+ self.pad_idx = pad_idx
202
+ self.lemma = lemma
203
+ self.stopwords = stopwords
204
+ try:
205
+ self.nlp = spacy.load(language)
206
+ except IOError:
207
+ spacy.cli.download(language) # type: ignore
208
+ self.nlp = spacy.load(language)
209
+
210
+ @tp.no_type_check
211
+ def __call__(self, texts: tp.List[tp.Optional[str]],
212
+ return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]:
213
+ """Take a list of strings and convert them to a tensor of indices.
214
+
215
+ Args:
216
+ texts (list[str]): List of strings.
217
+ return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
218
+ Returns:
219
+ tuple[torch.Tensor, torch.Tensor]:
220
+ - Indices of words in the LUT.
221
+ - And a mask indicating where the padding tokens are
222
+ """
223
+ output, lengths = [], []
224
+ texts = deepcopy(texts)
225
+ for i, text in enumerate(texts):
226
+ # if current sample doesn't have a certain attribute, replace with pad token
227
+ if text is None:
228
+ output.append(torch.Tensor([self.pad_idx]))
229
+ lengths.append(0)
230
+ continue
231
+
232
+ # convert numbers to words
233
+ text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text) # type: ignore
234
+ # normalize text
235
+ text = self.nlp(text) # type: ignore
236
+ # remove stopwords
237
+ if self.stopwords:
238
+ text = [w for w in text if not w.is_stop] # type: ignore
239
+ # remove punctuation
240
+ text = [w for w in text if w.text not in self.PUNCTUATION] # type: ignore
241
+ # lemmatize if needed
242
+ text = [getattr(t, "lemma_" if self.lemma else "text") for t in text] # type: ignore
243
+
244
+ texts[i] = " ".join(text)
245
+ lengths.append(len(text))
246
+ # convert to tensor
247
+ tokens = torch.Tensor([hash_trick(w, self.n_bins) for w in text])
248
+ output.append(tokens)
249
+
250
+ mask = length_to_mask(torch.IntTensor(lengths)).int()
251
+ padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t()
252
+ if return_text:
253
+ return padded_output, mask, texts # type: ignore
254
+ return padded_output, mask
255
+
256
+
257
+ class NoopTokenizer(Tokenizer):
258
+ """This tokenizer should be used for global conditioners such as: artist, genre, key, etc.
259
+ The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split
260
+ strings, so "Jeff Buckley" will get its own index, whereas WhiteSpaceTokenizer will
261
+ split it into ["Jeff", "Buckley"] and return an index per word.
262
+
263
+ For example:
264
+ ["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101]
265
+ ["Metal", "Rock", "Classical"] => [0, 223, 51]
266
+ """
267
+ def __init__(self, n_bins: int, pad_idx: int = 0):
268
+ self.n_bins = n_bins
269
+ self.pad_idx = pad_idx
270
+
271
+ def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
272
+ output, lengths = [], []
273
+ for text in texts:
274
+ # if current sample doesn't have a certain attribute, replace with pad token
275
+ if text is None:
276
+ output.append(self.pad_idx)
277
+ lengths.append(0)
278
+ else:
279
+ output.append(hash_trick(text, self.n_bins))
280
+ lengths.append(1)
281
+
282
+ tokens = torch.LongTensor(output).unsqueeze(1)
283
+ mask = length_to_mask(torch.IntTensor(lengths)).int()
284
+ return tokens, mask
285
+
286
+
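To make the difference between the two tokenizers concrete, here is a small self-contained sketch; `hash_to_bin` is an illustrative stand-in for the `hash_trick` helper used above, not the library function itself:

```python
# NoopTokenizer hashes the whole string to one bin; WhiteSpaceTokenizer hashes each word.
def hash_to_bin(word: str, n_bins: int) -> int:
    # Stand-in for audiocraft's hash_trick (actual index values will differ).
    return hash(word) % n_bins

texts = ["Jeff Buckley", "Queen"]
n_bins = 1024
noop_ids = [hash_to_bin(t, n_bins) for t in texts]                      # one index per string
ws_ids = [[hash_to_bin(w, n_bins) for w in t.split()] for t in texts]   # one index per word
print(noop_ids)  # e.g. [743, 55]
print(ws_ids)    # e.g. [[112, 980], [55]]
```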
287
+ class BaseConditioner(nn.Module):
288
+ """Base model for all conditioner modules.
289
+ We allow the output dim to be different than the hidden dim for two reasons:
290
+ 1) keep our LUTs small when the vocab is large;
291
+ 2) make all condition dims consistent.
292
+
293
+ Args:
294
+ dim (int): Hidden dim of the model.
295
+ output_dim (int): Output dim of the conditioner.
296
+ """
297
+ def __init__(self, dim: int, output_dim: int):
298
+ super().__init__()
299
+ self.dim = dim
300
+ self.output_dim = output_dim
301
+ self.output_proj = nn.Linear(dim, output_dim)
302
+
303
+ def tokenize(self, *args, **kwargs) -> tp.Any:
304
+ """Should be any part of the processing that will lead to a synchronization
305
+ point, e.g. BPE tokenization with transfer to the GPU.
306
+
307
+ The returned value will be saved and return later when calling forward().
308
+ """
309
+ raise NotImplementedError()
310
+
311
+ def forward(self, inputs: tp.Any) -> ConditionType:
312
+ """Gets input that should be used as conditioning (e.g, genre, description or a waveform).
313
+ Outputs a ConditionType, after the input data was embedded as a dense vector.
314
+
315
+ Returns:
316
+ ConditionType:
317
+ - A tensor of size [B, T, D] where B is the batch size, T is the length of the
318
+ output embedding and D is the dimension of the embedding.
319
+ - And a mask indicating where the padding tokens are.
320
+ """
321
+ raise NotImplementedError()
322
+
323
+
324
+ class TextConditioner(BaseConditioner):
325
+ ...
326
+ class VideoConditioner():
327
+ ...
328
+
329
+ class LUTConditioner(TextConditioner):
330
+ """Lookup table TextConditioner.
331
+
332
+ Args:
333
+ n_bins (int): Number of bins.
334
+ dim (int): Hidden dim of the model (text-encoder/LUT).
335
+ output_dim (int): Output dim of the conditioner.
336
+ tokenizer (str): Name of the tokenizer.
337
+ pad_idx (int, optional): Index for padding token. Defaults to 0.
338
+ """
339
+ def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0):
340
+ super().__init__(dim, output_dim)
341
+ self.embed = nn.Embedding(n_bins, dim)
342
+ self.tokenizer: Tokenizer
343
+ if tokenizer == 'whitespace':
344
+ self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
345
+ elif tokenizer == 'noop':
346
+ self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)
347
+ else:
348
+ raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")
349
+
350
+ def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
351
+ device = self.embed.weight.device
352
+ tokens, mask = self.tokenizer(x)
353
+ tokens, mask = tokens.to(device), mask.to(device)
354
+ return tokens, mask
355
+
356
+ def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType:
357
+ tokens, mask = inputs
358
+ embeds = self.embed(tokens)
359
+ embeds = self.output_proj(embeds)
360
+ embeds = (embeds * mask.unsqueeze(-1))
361
+ return embeds, mask
362
+
363
+
364
+ class T5Conditioner(TextConditioner):
365
+ MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
366
+ "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
367
+ "google/flan-t5-xl", "google/flan-t5-xxl"]
368
+ MODELS_DIMS = {
369
+ "t5-small": 512,
370
+ "t5-base": 768,
371
+ "t5-large": 1024,
372
+ "t5-3b": 1024,
373
+ "t5-11b": 1024,
374
+ "google/flan-t5-small": 512,
375
+ "google/flan-t5-base": 768,
376
+ "google/flan-t5-large": 1024,
377
+ "google/flan-t5-xl": 2048,
378
+ "google/flan-t5-xxl": 4096,
379
+ }
380
+ def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
381
+ autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
382
+ normalize_text: bool = False):
383
+ assert name in self.MODELS, f"Unrecognized t5 model name (should be in {self.MODELS})"
384
+ super().__init__(self.MODELS_DIMS[name], output_dim)
385
+
386
+
387
+ def forward(self, inputs: tp.Any) -> ConditionType:
+ # The T5 text encoder is not loaded in this VidMuse variant; this stub returns dummy tensors.
+ embeds = torch.zeros([1, 1, 1])
+ mask = torch.zeros([1, 1, 1])
+ return embeds, mask
391
+
392
+
393
+ class WaveformConditioner(BaseConditioner):
394
+ """Base class for all conditioners that take a waveform as input.
395
+ Classes that inherit must implement `_get_wav_embedding` that outputs
396
+ a continuous tensor, and `_downsampling_factor` that returns the down-sampling
397
+ factor of the embedding model.
398
+
399
+ Args:
400
+ dim (int): The internal representation dimension.
401
+ output_dim (int): Output dimension.
402
+ device (tp.Union[torch.device, str]): Device.
403
+ """
404
+ def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
405
+ super().__init__(dim, output_dim)
406
+ self.device = device
407
+ # if False, no masking is done; used in ChromaStemConditioner when completing a sample by periodicity.
408
+ self._use_masking = True
409
+
410
+ def tokenize(self, x: WavCondition) -> WavCondition:
411
+ wav, length, sample_rate, path, seek_time = x
412
+ assert length is not None
413
+ return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time)
414
+
415
+ def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
416
+ """Gets as input a WavCondition and returns a dense embedding."""
417
+ raise NotImplementedError()
418
+
419
+ def _downsampling_factor(self):
420
+ """Returns the downsampling factor of the embedding model."""
421
+ raise NotImplementedError()
422
+
423
+ def forward(self, x: WavCondition) -> ConditionType:
424
+ """Extract condition embedding and mask from a waveform and its metadata.
425
+ Args:
426
+ x (WavCondition): Waveform condition containing raw waveform and metadata.
427
+ Returns:
428
+ ConditionType: a dense vector representing the conditioning along with its mask
429
+ """
430
+ wav, lengths, *_ = x
431
+ with torch.no_grad():
432
+ embeds = self._get_wav_embedding(x)
433
+ embeds = embeds.to(self.output_proj.weight)
434
+ embeds = self.output_proj(embeds)
435
+
436
+ if lengths is not None and self._use_masking:
437
+ lengths = lengths / self._downsampling_factor()
438
+ mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
439
+ else:
440
+ mask = torch.ones_like(embeds[..., 0])
441
+ embeds = (embeds * mask.unsqueeze(-1))
442
+ return embeds, mask
443
+
444
+
445
+ class ChromaStemConditioner(WaveformConditioner):
446
+ """Chroma conditioner based on stems.
447
+ The ChromaStemConditioner uses DEMUCS to first filter out drums and bass, as
448
+ the drums and bass often dominate the chroma leading to the chroma features
449
+ not containing information about the melody.
450
+
451
+ Args:
452
+ output_dim (int): Output dimension for the conditioner.
453
+ sample_rate (int): Sample rate for the chroma extractor.
454
+ n_chroma (int): Number of chroma bins for the chroma extractor.
455
+ radix2_exp (int): Size of stft window for the chroma extractor (power of 2, e.g. 12 -> 2^12).
456
+ duration (int): duration used during training. This is later used for correct padding
457
+ in case we are using chroma as prefix.
458
+ match_len_on_eval (bool, optional): if True then all chromas are padded to the training
459
+ duration. Defaults to False.
460
+ eval_wavs (str, optional): Path to a dataset manifest with waveforms; these waveforms are used as
461
+ conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
462
+ Defaults to None.
463
+ n_eval_wavs (int, optional): limits the number of waveforms used for conditioning. Defaults to 0.
464
+ device (tp.Union[torch.device, str], optional): Device for the conditioner.
465
+ **kwargs: Additional parameters for the chroma extractor.
466
+ """
467
+ def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
468
+ duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
469
+ n_eval_wavs: int = 0, cache_path: tp.Optional[tp.Union[str, Path]] = None,
470
+ device: tp.Union[torch.device, str] = 'cpu', **kwargs):
471
+ from demucs import pretrained
472
+ super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
473
+ self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32)
474
+ self.sample_rate = sample_rate
475
+ self.match_len_on_eval = match_len_on_eval
476
+ if match_len_on_eval:
477
+ self._use_masking = False
478
+ self.duration = duration
479
+ self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device)
480
+ stem_sources: list = self.demucs.sources # type: ignore
481
+ self.stem_indices = torch.LongTensor([stem_sources.index('vocals'), stem_sources.index('other')]).to(device)
482
+ self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma,
483
+ radix2_exp=radix2_exp, **kwargs).to(device)
484
+ self.chroma_len = self._get_chroma_len()
485
+ self.eval_wavs: tp.Optional[torch.Tensor] = self._load_eval_wavs(eval_wavs, n_eval_wavs)
486
+ self.cache = None
487
+ if cache_path is not None:
488
+ self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
489
+ compute_embed_fn=self._get_full_chroma_for_cache,
490
+ extract_embed_fn=self._extract_chroma_chunk)
491
+
492
+ def _downsampling_factor(self) -> int:
493
+ return self.chroma.winhop
494
+
495
+ def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) -> tp.Optional[torch.Tensor]:
496
+ """Load pre-defined waveforms from a json.
497
+ These waveforms will be used for chroma extraction during evaluation.
498
+ This is done to make the evaluation on MusicCaps fair (we shouldn't see the chromas of MusicCaps).
499
+ """
500
+ if path is None:
501
+ return None
502
+
503
+ logger.info(f"Loading evaluation wavs from {path}")
504
+ from audiocraft.data.audio_dataset import AudioDataset
505
+ # print(f'self.video_fps:{self.video_fps}')
506
+ # print(f'self.video_overlap:{self.video_overlap}')
507
+ # exit()
508
+ dataset: AudioDataset = AudioDataset.from_meta(
509
+ path, segment_duration=self.duration, min_audio_duration=self.duration,
510
+ sample_rate=self.sample_rate, video_fps=self.video_fps, video_overlap=self.video_overlap, channels=1)
511
+
512
+ if len(dataset) > 0:
513
+ eval_wavs = dataset.collater([dataset[i] for i in range(num_samples)]).to(self.device)
514
+ logger.info(f"Using {len(eval_wavs)} evaluation wavs for chroma-stem conditioner")
515
+ return eval_wavs
516
+ else:
517
+ raise ValueError("Could not find evaluation wavs, check lengths of wavs")
518
+
519
+ def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None:
520
+ self.eval_wavs = eval_wavs
521
+
522
+ def has_eval_wavs(self) -> bool:
523
+ return self.eval_wavs is not None
524
+
525
+ def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor:
526
+ """Sample wavs from a predefined list."""
527
+ assert self.eval_wavs is not None, "Cannot sample eval wavs as no eval wavs provided."
528
+ total_eval_wavs = len(self.eval_wavs)
529
+ out = self.eval_wavs
530
+ if num_samples > total_eval_wavs:
531
+ out = self.eval_wavs.repeat(num_samples // total_eval_wavs + 1, 1, 1)
532
+ return out[torch.randperm(len(out))][:num_samples]
533
+
534
+ def _get_chroma_len(self) -> int:
535
+ """Get length of chroma during training."""
536
+ dummy_wav = torch.zeros((1, int(self.sample_rate * self.duration)), device=self.device)
537
+ dummy_chr = self.chroma(dummy_wav)
538
+ return dummy_chr.shape[1]
539
+
540
+ @torch.no_grad()
541
+ def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
542
+ """Get parts of the wav that holds the melody, extracting the main stems from the wav."""
543
+ from demucs.apply import apply_model
544
+ from demucs.audio import convert_audio
545
+ with self.autocast:
546
+ wav = convert_audio(
547
+ wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels) # type: ignore
548
+ stems = apply_model(self.demucs, wav, device=self.device)
549
+ stems = stems[:, self.stem_indices] # extract relevant stems for melody conditioning
550
+ mix_wav = stems.sum(1) # merge extracted stems to single waveform
551
+ mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1) # type: ignore
552
+ return mix_wav
553
+
554
+ @torch.no_grad()
555
+ def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor:
556
+ """Extract chroma features from the waveform."""
557
+ with self.autocast:
558
+ return self.chroma(wav)
559
+
560
+ @torch.no_grad()
561
+ def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
562
+ """Compute wav embedding, applying stem and chroma extraction."""
563
+ # avoid 0-size tensors when we are working with null conds
564
+ if wav.shape[-1] == 1:
565
+ return self._extract_chroma(wav)
566
+ stems = self._get_stemmed_wav(wav, sample_rate)
567
+ chroma = self._extract_chroma(stems)
568
+ return chroma
569
+
570
+ @torch.no_grad()
571
+ def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: WavCondition, idx: int) -> torch.Tensor:
572
+ """Extract chroma from the whole audio waveform at the given path."""
573
+ wav, sr = audio_read(path)
574
+ wav = wav[None].to(self.device)
575
+ wav = convert_audio(wav, sr, self.sample_rate, to_channels=1)
576
+ chroma = self._compute_wav_embedding(wav, self.sample_rate)[0]
577
+ return chroma
578
+
579
+ def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor:
580
+ """Extract a chunk of chroma from the full chroma derived from the full waveform."""
581
+ wav_length = x.wav.shape[-1]
582
+ seek_time = x.seek_time[idx]
583
+ assert seek_time is not None, (
584
+ "WavCondition seek_time is required "
585
+ "when extracting chroma chunks from pre-computed chroma.")
586
+ full_chroma = full_chroma.float()
587
+ frame_rate = self.sample_rate / self._downsampling_factor()
588
+ target_length = int(frame_rate * wav_length / self.sample_rate)
589
+ index = int(frame_rate * seek_time)
590
+ out = full_chroma[index: index + target_length]
591
+ out = F.pad(out[None], (0, 0, 0, target_length - out.shape[0]))[0]
592
+ return out.to(self.device)
593
+
594
+ @torch.no_grad()
595
+ def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
596
+ """Get the wav embedding from the WavCondition.
597
+ The conditioner will either extract the embedding on-the-fly computing it from the condition wav directly
598
+ or will rely on the embedding cache to load the pre-computed embedding if relevant.
599
+ """
600
+ sampled_wav: tp.Optional[torch.Tensor] = None
601
+ if not self.training and self.eval_wavs is not None:
602
+ warn_once(logger, "Using precomputed evaluation wavs!")
603
+ sampled_wav = self._sample_eval_wavs(len(x.wav))
604
+
605
+ no_undefined_paths = all(p is not None for p in x.path)
606
+ no_nullified_cond = x.wav.shape[-1] > 1
607
+ if sampled_wav is not None:
608
+ chroma = self._compute_wav_embedding(sampled_wav, self.sample_rate)
609
+ elif self.cache is not None and no_undefined_paths and no_nullified_cond:
610
+ paths = [Path(p) for p in x.path if p is not None]
611
+ chroma = self.cache.get_embed_from_cache(paths, x)
612
+ else:
613
+ assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal."
614
+ chroma = self._compute_wav_embedding(x.wav, x.sample_rate[0])
615
+
616
+ if self.match_len_on_eval:
617
+ B, T, C = chroma.shape
618
+ if T > self.chroma_len:
619
+ chroma = chroma[:, :self.chroma_len]
620
+ logger.debug(f"Chroma was truncated to match length! ({T} -> {chroma.shape[1]})")
621
+ elif T < self.chroma_len:
622
+ n_repeat = int(math.ceil(self.chroma_len / T))
623
+ chroma = chroma.repeat(1, n_repeat, 1)
624
+ chroma = chroma[:, :self.chroma_len]
625
+ logger.debug(f"Chroma was repeated to match length! ({T} -> {chroma.shape[1]})")
626
+
627
+ return chroma
628
+
629
+ def tokenize(self, x: WavCondition) -> WavCondition:
630
+ """Apply WavConditioner tokenization and populate cache if needed."""
631
+ x = super().tokenize(x)
632
+ no_undefined_paths = all(p is not None for p in x.path)
633
+ if self.cache is not None and no_undefined_paths:
634
+ paths = [Path(p) for p in x.path if p is not None]
635
+ self.cache.populate_embed_cache(paths, x)
636
+ return x
637
+
638
+
639
+ class JointEmbeddingConditioner(BaseConditioner):
640
+ """Joint embedding conditioning supporting either audio or text conditioning.
641
+
642
+ Args:
643
+ dim (int): Dimension.
644
+ output_dim (int): Output dimension.
645
+ device (str): Device.
646
+ attribute (str): Attribute used by the conditioner.
647
+ autocast_dtype (str): Autocast for the conditioner.
648
+ quantize (bool): Whether to quantize the CLAP embedding.
649
+ n_q (int): Number of residual quantizers (used if quantize is true).
650
+ bins (int): Quantizers' codebooks size (used if quantize is true).
651
+ kwargs: Additional parameters for residual vector quantizer.
652
+ """
653
+ def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
654
+ autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True,
655
+ n_q: int = 12, bins: int = 1024, **kwargs):
656
+ super().__init__(dim=dim, output_dim=output_dim)
657
+ self.device = device
658
+ self.attribute = attribute
659
+ if autocast_dtype is None or device == 'cpu':
660
+ self.autocast = TorchAutocast(enabled=False)
661
+ logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.")
662
+ else:
663
+ dtype = getattr(torch, autocast_dtype)
664
+ assert isinstance(dtype, torch.dtype)
665
+ logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.")
666
+ self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
667
+ # residual vector quantizer to discretize the conditioned embedding
668
+ self.quantizer: tp.Optional[ResidualVectorQuantizer] = None
669
+ if quantize:
670
+ self.quantizer = ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs)
671
+
672
+ def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
673
+ """Get joint embedding in latent space from the inputs.
674
+
675
+ Returns:
676
+ tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding
677
+ and corresponding empty indexes.
678
+ """
679
+ raise NotImplementedError()
680
+
681
+ def forward(self, x: JointEmbedCondition) -> ConditionType:
682
+ with self.autocast:
683
+ embed, empty_idx = self._get_embed(x)
684
+ if self.quantizer is not None:
685
+ embed = embed.view(-1, self.dim, 1)
686
+ q_res = self.quantizer(embed, frame_rate=1)
687
+ out_embed = q_res.x.view(-1, self.dim)
688
+ else:
689
+ out_embed = embed
690
+ out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim)
691
+ mask = torch.ones(*out_embed.shape[:2], device=out_embed.device)
692
+ mask[empty_idx, :] = 0 # zero out indices where the input is non-existent
693
+ out_embed = (out_embed * mask.unsqueeze(-1))
694
+ return out_embed, mask
695
+
696
+ def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
697
+ return x
698
+
699
+
700
+ class CLAPEmbeddingConditioner(JointEmbeddingConditioner):
701
+ """Joint Embedding conditioner based on pre-trained CLAP model.
702
+
703
+ This CLAP-based conditioner supports a caching mechanism
704
+ over the computed embeddings for faster training.
705
+
706
+ Args:
707
+ dim (int): Dimension.
708
+ output_dim (int): Output dimension.
709
+ device (str): Device.
710
+ attribute (str): Attribute used by the conditioner.
711
+ quantize (bool): Whether to quantize the CLAP embedding.
712
+ n_q (int): Number of residual quantizers (used if quantize is true).
713
+ bins (int): Quantizers' codebooks size (used if quantize is true).
714
+ checkpoint (str): Path to CLAP checkpoint.
715
+ model_arch (str): CLAP model architecture.
716
+ enable_fusion (bool): Enable fusion for CLAP model.
717
+ sample_rate (int): Sample rate used by CLAP model.
718
+ max_audio_length (float): Maximum audio length for CLAP model.
719
+ audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence.
720
+ normalize (bool): Whether to normalize the CLAP embedding.
721
+ text_p (float): Probability of using text representation instead of audio at train time.
722
+ batch_size (Optional[int]): Batch size for CLAP embedding computation.
723
+ autocast_dtype (str): Autocast for the conditioner.
724
+ cache_path (Optional[str]): Path for pre-computed embeddings caching.
725
+ kwargs: Additional parameters for residual vector quantizer.
726
+ """
727
+ def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
728
+ quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str,
729
+ enable_fusion: bool, sample_rate: int, max_audio_length: int, audio_stride: int,
730
+ normalize: bool, text_p: float, batch_size: tp.Optional[int] = None,
731
+ autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs):
732
+ try:
733
+ import laion_clap # type: ignore
734
+ except ImportError:
735
+ raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'")
736
+ # warnings.warn("Sample rate for CLAP conditioner was fixed in version v1.1.0, (from 44.1 to 48 kHz). "
737
+ # "Please retrain all models.")
738
+ checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint)
739
+ clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base')
740
+ clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
741
+ load_clap_state_dict(clap_model, checkpoint)
742
+ clap_model.eval()
743
+ clap_model.to(device)
744
+ super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute,
745
+ autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins,
746
+ **kwargs)
747
+ self.checkpoint = checkpoint
748
+ self.enable_fusion = enable_fusion
749
+ self.model_arch = model_arch
750
+ self.clap: laion_clap.CLAP_Module
751
+ self.clap_tokenize: RobertaTokenizer
752
+ self.clap_sample_rate = sample_rate
753
+ self.clap_max_frames = int(self.clap_sample_rate * max_audio_length)
754
+ self.clap_stride = int(self.clap_sample_rate * audio_stride)
755
+ self.batch_size = batch_size or 1
756
+ self.normalize = normalize
757
+ self.text_p = text_p
758
+ self.__dict__['clap_tokenize'] = clap_tokenize
759
+ self.__dict__['clap'] = clap_model
760
+ self.wav_cache, self.text_cache = None, None
761
+ if cache_path is not None:
762
+ self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
763
+ compute_embed_fn=self._get_wav_embedding_for_cache,
764
+ extract_embed_fn=self._extract_wav_embedding_chunk)
765
+ self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device,
766
+ compute_embed_fn=self._get_text_embedding_for_cache)
767
+
768
+ def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
769
+ # we use the default params from CLAP module here as well
770
+ return self.clap_tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
771
+
772
+ def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor:
773
+ """Compute text embedding from CLAP model on a given a batch of text.
774
+
775
+ Args:
776
+ text (list[str]): List of text for the batch, with B items.
777
+ Returns:
778
+ torch.Tensor: CLAP embedding derived from text, of shape [B, 1, D], with D the CLAP embedding dimension.
779
+ """
780
+ with torch.no_grad():
781
+ embed = self.clap.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
782
+ return embed.view(embed.size(0), 1, embed.size(-1))
783
+
784
+ def _get_text_embedding_for_cache(self, path: tp.Union[Path, str],
785
+ x: JointEmbedCondition, idx: int) -> torch.Tensor:
786
+ """Get text embedding function for the cache."""
787
+ text = x.text[idx]
788
+ text = text if text is not None else ""
789
+ return self._compute_text_embedding([text])[0]
790
+
791
+ def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sample_rates: tp.List[int]) -> torch.Tensor:
792
+ """Preprocess wav to expected format by CLAP model.
793
+
794
+ Args:
795
+ wav (torch.Tensor): Audio wav, of shape [B, C, T].
796
+ length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
797
+ sample_rates (list[int]): Sample rates for each sample in the batch
798
+ Returns:
799
+ torch.Tensor: Audio wav of shape [B, T].
800
+ """
801
+ assert wav.dim() == 3, "Expecting wav to be [B, C, T]"
802
+ if sample_rates is not None:
803
+ _wav = []
804
+ for i, audio in enumerate(wav):
805
+ sr = sample_rates[i]
806
+ audio = convert_audio(audio, from_rate=sr, to_rate=self.clap_sample_rate, to_channels=1)
807
+ _wav.append(audio)
808
+ wav = torch.stack(_wav, dim=0)
809
+ wav = wav.mean(dim=1)
810
+ return wav
811
+
812
+ def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor,
813
+ sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor:
814
+ """Compute audio wave embedding from CLAP model.
815
+
816
+ Since CLAP operates on fixed-length audio inputs and we need to process longer audio sequences,
817
+ we calculate the wav embeddings on windows of `clap_max_frames` samples with a stride of `clap_stride` samples and
818
+ average the resulting embeddings.
819
+
820
+ Args:
821
+ wav (torch.Tensor): Audio wav, of shape [B, C, T].
822
+ length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
823
+ sample_rates (list[int]): Sample rates for each sample in the batch.
824
+ reduce_mean (bool): Whether to get the average tensor.
825
+ Returns:
826
+ torch.Tensor: Audio embedding of shape [B, F, D], F being the number of chunks, D the dimension.
827
+ """
828
+ with torch.no_grad():
829
+ wav = self._preprocess_wav(wav, length, sample_rates)
830
+ B, T = wav.shape
831
+ if T >= self.clap_max_frames:
832
+ wav = wav.unfold(-1, self.clap_max_frames, self.clap_stride) # [B, F, T]
833
+ else:
834
+ wav = wav.view(-1, 1, T) # [B, F, T] with F=1
835
+ wav = einops.rearrange(wav, 'b f t -> (b f) t')
836
+ embed_list = []
837
+ for i in range(0, wav.size(0), self.batch_size):
838
+ _wav = wav[i:i+self.batch_size, ...]
839
+ _embed = self.clap.get_audio_embedding_from_data(_wav, use_tensor=True)
840
+ embed_list.append(_embed)
841
+ embed = torch.cat(embed_list, dim=0)
842
+ embed = einops.rearrange(embed, '(b f) d -> b f d', b=B)
843
+ if reduce_mean:
844
+ embed = embed.mean(dim=1, keepdim=True)
845
+ return embed # [B, F, D] with F=1 if reduce_mean is True
846
+
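The windowing logic above relies on `torch.Tensor.unfold`; a short sketch of that call in isolation (the sample rate and window/stride values here are arbitrary, not the conditioner's configuration):

```python
# unfold cuts a long mono batch [B, T] into F fixed-size windows of length `window`
# taken every `stride` samples, yielding [B, F, window].
import torch

wav = torch.randn(2, 48000 * 10)           # 2 items, 10 s at 48 kHz
window, stride = 48000 * 2, 48000          # 2 s windows, 1 s stride (illustrative)
chunks = wav.unfold(-1, window, stride)
print(chunks.shape)                        # torch.Size([2, 9, 96000])
```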
847
+ def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path],
848
+ x: JointEmbedCondition, idx: int) -> torch.Tensor:
849
+ """Compute audio wave embedding for the cache.
850
+ The embedding is computed on a given audio read from file.
851
+
852
+ Args:
853
+ path (str or Path): Path to the full audio file.
854
+ Returns:
855
+ torch.Tensor: Single-item tensor of shape [F, D], F being the number of chunks, D the dimension.
856
+ """
857
+ wav, sr = audio_read(path) # [C, T]
858
+ wav = wav.unsqueeze(0).to(self.device) # [1, C, T]
859
+ wav_len = torch.LongTensor([wav.shape[-1]]).to(self.device)
860
+ embed = self._compute_wav_embedding(wav, wav_len, [sr], reduce_mean=False) # [B, F, D]
861
+ return embed.squeeze(0) # [F, D]
862
+
863
+ def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor:
864
+ """Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding.
865
+
866
+ Args:
867
+ full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D].
868
+ x (JointEmbedCondition): Joint embedding condition for the full batch.
869
+ idx (int): Index considered for the given embedding to extract.
870
+ Returns:
871
+ torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D].
872
+ """
873
+ sample_rate = x.sample_rate[idx]
874
+ seek_time = x.seek_time[idx]
875
+ seek_time = 0. if seek_time is None else seek_time
876
+ clap_stride = int(self.clap_stride / self.clap_sample_rate) * sample_rate
877
+ end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate
878
+ start_offset = int(seek_time * sample_rate // clap_stride)
879
+ end_offset = int(end_seek_time * sample_rate // clap_stride)
880
+ wav_embed = full_embed[start_offset:end_offset, ...]
881
+ wav_embed = wav_embed.mean(dim=0, keepdim=True)
882
+ return wav_embed.to(self.device) # [F, D]
883
+
884
+ def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
885
+ """Get CLAP embedding from a batch of text descriptions."""
886
+ no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout
887
+ if self.text_cache is not None and no_nullified_cond:
888
+ assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided"
889
+ paths = [Path(p) for p in x.path if p is not None]
890
+ embed = self.text_cache.get_embed_from_cache(paths, x)
891
+ else:
892
+ text = [xi if xi is not None else "" for xi in x.text]
893
+ embed = self._compute_text_embedding(text)
894
+ if self.normalize:
895
+ embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
896
+ return embed
897
+
898
+ def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
899
+ """Get CLAP embedding from a batch of audio tensors (and corresponding sample rates)."""
900
+ no_undefined_paths = all(p is not None for p in x.path)
901
+ no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout
902
+ if self.wav_cache is not None and no_undefined_paths and no_nullified_cond:
903
+ paths = [Path(p) for p in x.path if p is not None]
904
+ embed = self.wav_cache.get_embed_from_cache(paths, x)
905
+ else:
906
+ embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True)
907
+ if self.normalize:
908
+ embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
909
+ return embed
910
+
911
+ def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
912
+ # Trying to limit as much as possible sync points when the cache is warm.
913
+ no_undefined_paths = all(p is not None for p in x.path)
914
+ if self.wav_cache is not None and no_undefined_paths:
915
+ assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
916
+ paths = [Path(p) for p in x.path if p is not None]
917
+ self.wav_cache.populate_embed_cache(paths, x)
918
+ if self.text_cache is not None and no_undefined_paths:
919
+ assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
920
+ paths = [Path(p) for p in x.path if p is not None]
921
+ self.text_cache.populate_embed_cache(paths, x)
922
+ return x
923
+
924
+ def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
925
+ """Extract shared latent representation from either the wav or the text using CLAP."""
926
+ # decide whether to use text embedding at train time or not
927
+ use_text_embed = random.random() < self.text_p
928
+ if self.training and not use_text_embed:
929
+ embed = self._get_wav_embedding(x)
930
+ empty_idx = torch.LongTensor([]) # we assume we always have the audio wav
931
+ else:
932
+ embed = self._get_text_embedding(x)
933
+ empty_idx = torch.LongTensor([i for i, xi in enumerate(x.text) if xi is None or xi == ""])
934
+ return embed, empty_idx
935
+
936
+
937
+ def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes:
938
+ """Utility function for nullifying an attribute inside a ConditioningAttributes object.
939
+ If the condition is of type "wav", then nullify it using `nullify_condition` function.
940
+ If the condition is of any other type, set its value to None.
941
+ Works in-place.
942
+ """
943
+ if condition_type not in ['text', 'wav', 'joint_embed']:
944
+ raise ValueError(
945
+ "dropout_condition got an unexpected condition type!"
946
+ f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'"
947
+ )
948
+
949
+ if condition not in getattr(sample, condition_type):
950
+ raise ValueError(
951
+ "dropout_condition received an unexpected condition!"
952
+ f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
953
+ f" but got '{condition}' of type '{condition_type}'!"
954
+ )
955
+
956
+ if condition_type == 'wav':
957
+ wav_cond = sample.wav[condition]
958
+ sample.wav[condition] = nullify_wav(wav_cond)
959
+ elif condition_type == 'joint_embed':
960
+ embed = sample.joint_embed[condition]
961
+ sample.joint_embed[condition] = nullify_joint_embed(embed)
962
+ else:
963
+ sample.text[condition] = None
964
+
965
+ return sample
966
+
967
+
968
+ class DropoutModule(nn.Module):
969
+ """Base module for all dropout modules."""
970
+ def __init__(self, seed: int = 1234):
971
+ super().__init__()
972
+ self.rng = torch.Generator()
973
+ self.rng.manual_seed(seed)
974
+
975
+
976
+ class AttributeDropout(DropoutModule):
977
+ """Dropout with a given probability per attribute.
978
+ This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes
979
+ to be dropped out separately. For example, "artist" can be dropped while "genre" remains.
980
+ This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre"
981
+ must also be dropped.
982
+
983
+ Args:
984
+ p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example:
985
+ ...
986
+ "genre": 0.1,
987
+ "artist": 0.5,
988
+ "wav": 0.25,
989
+ ...
990
+ active_on_eval (bool, optional): Whether the dropout is active at eval. Defaults to False.
991
+ seed (int, optional): Random seed.
992
+ """
993
+ def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
994
+ super().__init__(seed=seed)
995
+ self.active_on_eval = active_on_eval
996
+ # construct dict that returns the values from p, otherwise 0
997
+ self.p = {}
998
+ for condition_type, probs in p.items():
999
+ self.p[condition_type] = defaultdict(lambda: 0, probs)
1000
+
1001
+ def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
1002
+ """
1003
+ Args:
1004
+ samples (list[ConditioningAttributes]): List of conditions.
1005
+ Returns:
1006
+ list[ConditioningAttributes]: List of conditions after certain attributes were set to None.
1007
+ """
1008
+ if not self.training and not self.active_on_eval:
1009
+ return samples
1010
+
1011
+ samples = deepcopy(samples)
1012
+ for condition_type, ps in self.p.items(): # for condition types [text, wav]
1013
+ for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre])
1014
+ if torch.rand(1, generator=self.rng).item() < p:
1015
+ for sample in samples:
1016
+ dropout_condition(sample, condition_type, condition)
1017
+ return samples
1018
+
1019
+ def __repr__(self):
1020
+ return f"AttributeDropout({dict(self.p)})"
1021
+
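A minimal sketch of the per-attribute behaviour described above (the attribute names come from the docstring example and are illustrative; running it requires the repository's dependencies to be installed):

```python
from audiocraft.modules.conditioners import AttributeDropout, ConditioningAttributes

samples = [ConditioningAttributes(text={"genre": "Rock",
                                        "description": "A rock song with a guitar solo"})]
dropout = AttributeDropout(p={"text": {"genre": 0.5}})  # only "genre" can be dropped
dropped = dropout(samples)
# "genre" may now be None while "description" is always left untouched.
print(dropped[0].text)
```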
1022
+
1026
+ class ClassifierFreeGuidanceDropout(nn.Module):
1027
+ """Classifier Free Guidance dropout for tensor inputs.
1028
+ All elements in the tensor are dropped with the same probability.
1029
+
1030
+ Args:
1031
+ p (float): Probability to apply dropout during training.
1032
+ seed (int): Random seed.
1033
+ """
1034
+ def __init__(self, p: float, seed: int = 1234):
1035
+ super().__init__()
1036
+ self.p = p
1037
+ torch.manual_seed(seed) # Set the seed for reproducibility
1038
+
1039
+ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
1040
+ """
1041
+ Args:
1042
+ tensor (torch.Tensor): Input tensor.
1043
+ Returns:
1044
+ torch.Tensor: Tensor after applying dropout.
1045
+ """
1046
+ if not self.training or self.p <= 0:
1047
+ return tensor
1048
+
1049
+ # Create a dropout mask with the same size as the input tensor
1050
+ mask = torch.rand(tensor.shape) > self.p
1051
+ mask = mask.to(tensor.device) # Move mask to the same device as the input tensor
1052
+
1053
+ # Apply dropout mask
1054
+ dropped_tensor = tensor * mask.float()
1055
+
1056
+ return dropped_tensor
1057
+
1058
+ def __repr__(self):
1059
+ return f"ClassifierFreeGuidanceDropout(p={self.p})"
1060
+
1061
+
1062
+ class ConditioningProvider(nn.Module):
1063
+ """Prepare and provide conditions given all the supported conditioners.
1064
+
1065
+ Args:
1066
+ conditioners (dict): Dictionary of conditioners.
1067
+ device (torch.device or str, optional): Device for conditioners and output condition types.
1068
+ """
1069
+ def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"):
1070
+ super().__init__()
1071
+ self.device = device
1072
+ self.conditioners = nn.ModuleDict(conditioners)
1073
+
1074
+ @property
1075
+ def joint_embed_conditions(self):
1076
+ return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)]
1077
+
1078
+ @property
1079
+ def has_joint_embed_conditions(self):
1080
+ return len(self.joint_embed_conditions) > 0
1081
+
1082
+ @property
1083
+ def text_conditions(self):
1084
+ return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
1085
+
1086
+ @property
1087
+ def wav_conditions(self):
1088
+ return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]
1089
+
1090
+ @property
1091
+ def has_wav_condition(self):
1092
+ return len(self.wav_conditions) > 0
1093
+
1094
+ def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
1095
+ """Match attributes/wavs with existing conditioners in self, and tokenize them accordingly.
1096
+ This should be called before starting any real GPU work to avoid synchronization points.
1097
+ This will return a dict matching conditioner names to their arbitrary tokenized representations.
1098
+
1099
+ Args:
1100
+ inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
1101
+ text and wav conditions.
1102
+ """
1103
+ assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
1104
+ "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
1105
+ f" but types were {set([type(x) for x in inputs])}"
1106
+ )
1107
+
1108
+ output = {}
1109
+ text = self._collate_text(inputs)
1110
+ wavs = self._collate_wavs(inputs)
1111
+ joint_embeds = self._collate_joint_embeds(inputs)
1112
+
1113
+ assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
1114
+ f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
1115
+ f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
1116
+ )
1117
+
1118
+ for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()):
1119
+ output[attribute] = self.conditioners[attribute].tokenize(batch)
1120
+ return output
1121
+
1122
+ def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
1123
+ """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
1124
+ The output is for example:
1125
+ {
1126
+ "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
1127
+ "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
1128
+ ...
1129
+ }
1130
+
1131
+ Args:
1132
+ tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
1133
+ """
1134
+ output = {}
1135
+ for attribute, inputs in tokenized.items():
1136
+ condition, mask = self.conditioners[attribute](inputs)
1137
+ output[attribute] = (condition, mask)
1138
+ return output
1139
+
1140
+ def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
1141
+ """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
1142
+ are the attributes and the values are the aggregated input per attribute.
1143
+ For example:
1144
+ Input:
1145
+ [
1146
+ ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
1147
+ ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
1148
+ ]
1149
+ Output:
1150
+ {
1151
+ "genre": ["Rock", "Hip-hop"],
1152
+ "description": ["A rock song with a guitar solo", "A hip-hop verse"]
1153
+ }
1154
+
1155
+ Args:
1156
+ samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
1157
+ Returns:
1158
+ dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
1159
+ """
1160
+ out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
1161
+ texts = [x.text for x in samples]
1162
+ for text in texts:
1163
+ for condition in self.text_conditions:
1164
+ out[condition].append(text[condition])
1165
+ return out
1166
+
1167
+ def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]:
1168
+ """Generate a dict where the keys are attributes by which we fetch similar wavs,
1169
+ and the values are Tensors of wavs according to said attributes.
1170
+
1171
+ *Note*: by the time the samples reach this function, each sample should have some waveform
1172
+ inside the "wav" attribute. It should be either:
1173
+ 1. A real waveform
1174
+ 2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
1175
+ 3. A null waveform due to it being dropped in a dropout module (nullified by dropout)
1176
+
1177
+ Args:
1178
+ samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
1179
+ Returns:
1180
+ dict[str, WavCondition]: A dictionary mapping an attribute name to wavs.
1181
+ """
1182
+ wavs = defaultdict(list)
1183
+ lengths = defaultdict(list)
1184
+ sample_rates = defaultdict(list)
1185
+ paths = defaultdict(list)
1186
+ seek_times = defaultdict(list)
1187
+ out: tp.Dict[str, WavCondition] = {}
1188
+
1189
+ for sample in samples:
1190
+ for attribute in self.wav_conditions:
1191
+ wav, length, sample_rate, path, seek_time = sample.wav[attribute]
1192
+ assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]"
1193
+ assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1"
1194
+ # mono-channel conditioning
1195
+ wav = wav.mean(1, keepdim=True) # [1, 1, T]
1196
+ wavs[attribute].append(wav.flatten()) # [T]
1197
+ lengths[attribute].append(length)
1198
+ sample_rates[attribute].extend(sample_rate)
1199
+ paths[attribute].extend(path)
1200
+ seek_times[attribute].extend(seek_time)
1201
+
1202
+ # stack all wavs to a single tensor
1203
+ for attribute in self.wav_conditions:
1204
+ stacked_wav, _ = collate(wavs[attribute], dim=0)
1205
+ out[attribute] = WavCondition(
1206
+ stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute],
1207
+ paths[attribute], seek_times[attribute])
1208
+
1209
+ return out
1210
+
1211
+ def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]:
1212
+ """Generate a dict where the keys are attributes by which we compute joint embeddings,
1213
+ and the values are Tensors of pre-computed embeddings and the corresponding text attributes.
1214
+
1215
+ Args:
1216
+ samples (list[ConditioningAttributes]): List of ConditioningAttributes samples.
1217
+ Returns:
1218
+ A dictionary mapping an attribute name to joint embeddings.
1219
+ """
1220
+ texts = defaultdict(list)
1221
+ wavs = defaultdict(list)
1222
+ lengths = defaultdict(list)
1223
+ sample_rates = defaultdict(list)
1224
+ paths = defaultdict(list)
1225
+ seek_times = defaultdict(list)
1226
+ channels: int = 0
1227
+
1228
+ out = {}
1229
+ for sample in samples:
1230
+ for attribute in self.joint_embed_conditions:
1231
+ wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute]
1232
+ assert wav.dim() == 3
1233
+ if channels == 0:
1234
+ channels = wav.size(1)
1235
+ else:
1236
+ assert channels == wav.size(1), "not all audio has same number of channels in batch"
1237
+ assert wav.size(0) == 1, "Expecting single-wav batch in the collate method"
1238
+ wav = einops.rearrange(wav, "b c t -> (b c t)") # [1, C, T] => [C * T]
1239
+ wavs[attribute].append(wav)
1240
+ texts[attribute].extend(text)
1241
+ lengths[attribute].append(length)
1242
+ sample_rates[attribute].extend(sample_rate)
1243
+ paths[attribute].extend(path)
1244
+ seek_times[attribute].extend(seek_time)
1245
+
1246
+ for attribute in self.joint_embed_conditions:
1247
+ stacked_texts = texts[attribute]
1248
+ stacked_paths = paths[attribute]
1249
+ stacked_seek_times = seek_times[attribute]
1250
+ stacked_wavs = pad_sequence(wavs[attribute]).to(self.device)
1251
+ stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels)
1252
+ stacked_sample_rates = sample_rates[attribute]
1253
+ stacked_lengths = torch.cat(lengths[attribute]).to(self.device)
1254
+ assert stacked_lengths.size(0) == stacked_wavs.size(0)
1255
+ assert len(stacked_sample_rates) == stacked_wavs.size(0)
1256
+ assert len(stacked_texts) == stacked_wavs.size(0)
1257
+ out[attribute] = JointEmbedCondition(
1258
+ text=stacked_texts, wav=stacked_wavs,
1259
+ length=stacked_lengths, sample_rate=stacked_sample_rates,
1260
+ path=stacked_paths, seek_time=stacked_seek_times)
1261
+
1262
+ return out
1263
+
1264
+
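The `tokenize`/`forward` split above is meant to keep CPU-side tokenization ahead of any GPU work. Below is a rough usage sketch, where `provider` and `attributes` are hypothetical stand-ins for a configured `ConditioningProvider` and a batch of dataset attributes built elsewhere in the training code.

```python
# Hypothetical usage sketch; `provider` (a ConditioningProvider) and `attributes`
# (a list of ConditioningAttributes) are assumed to be constructed by the training code.
tokenized = provider.tokenize(attributes)      # CPU-side work, no GPU sync points
condition_tensors = provider(tokenized)        # {attribute: (embedding [B, T, D], mask [B, T])}
for name, (embedding, mask) in condition_tensors.items():
    print(name, tuple(embedding.shape), tuple(mask.shape))
```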
+ class ConditionFuser(StreamingModule):
+     """Condition fuser handles the logic to combine the different conditions
+     to the actual model input.
+
+     Args:
+         fuse2cond (tp.Dict[str, tp.List[str]]): A dictionary that says how to fuse
+             each condition. For example:
+             {
+                 "prepend": ["description"],
+                 "sum": ["genre", "bpm"],
+                 "cross": ["description"],
+             }
+         cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
+         cross_attention_pos_emb_scale (float): Scale for positional embeddings in cross attention if used.
+     """
+     FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]
+
+     def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
+                  cross_attention_pos_emb_scale: float = 1.0):
+         super().__init__()
+         assert all(
+             [k in self.FUSING_METHODS for k in fuse2cond.keys()]
+         ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
+         self.cross_attention_pos_emb = cross_attention_pos_emb
+         self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
+         self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
+         self.cond2fuse: tp.Dict[str, str] = {}
+         for fuse_method, conditions in fuse2cond.items():
+             for condition in conditions:
+                 self.cond2fuse[condition] = fuse_method
+
+     def forward(
+         self,
+         input: torch.Tensor,
+         conditions: tp.Dict[str, ConditionType]
+     ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+         """Fuse the conditions to the provided model input.
+
+         Args:
+             input (torch.Tensor): Transformer input.
+             conditions (dict[str, ConditionType]): Dict of conditions.
+         Returns:
+             tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input
+                 after the conditions have been fused. The second output tensor is the tensor
+                 used for cross-attention, or None if no cross-attention inputs exist.
+         """
+         B, T, _ = input.shape
+
+         if 'offsets' in self._streaming_state:
+             first_step = False
+             offsets = self._streaming_state['offsets']
+         else:
+             first_step = True
+             offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
+
+         assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
+             f"given conditions contain unknown attributes for fuser, " \
+             f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
+         cross_attention_output = None
+         for cond_type, (cond, cond_mask) in conditions.items():
+             op = self.cond2fuse[cond_type]
+             if op == 'sum':
+                 input += cond
+             elif op == 'input_interpolate':
+                 cond = einops.rearrange(cond, "b t d -> b d t")
+                 cond = F.interpolate(cond, size=input.shape[1])
+                 input += einops.rearrange(cond, "b d t -> b t d")
+             elif op == 'prepend':
+                 if first_step:
+                     input = torch.cat([cond, input], dim=1)
+             elif op == 'cross':
+                 if cross_attention_output is not None:
+                     cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
+                 else:
+                     cross_attention_output = cond
+             else:
+                 raise ValueError(f"unknown op ({op})")
+
+         if self.cross_attention_pos_emb and cross_attention_output is not None:
+             positions = torch.arange(
+                 cross_attention_output.shape[1],
+                 device=cross_attention_output.device
+             ).view(1, -1, 1)
+             pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
+             cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb
+
+         if self._is_streaming:
+             self._streaming_state['offsets'] = offsets + T
+
+         return input, cross_attention_output
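A small sketch of the fusing logic above, with random tensors standing in for real condition embeddings. The attribute names and dimensions are illustrative only, and the sketch assumes the `StreamingModule` base class behaves as in upstream AudioCraft (no streaming state outside a streaming context).

```python
# Illustrative sketch only; "description" and "self_wav" are placeholder attribute names
# and the condition embedding width must match the transformer width (512 here).
import torch
from audiocraft.modules.conditioners import ConditionFuser

fuser = ConditionFuser(fuse2cond={"cross": ["description"], "prepend": ["self_wav"]})
x = torch.randn(2, 50, 512)                                       # transformer input [B, T, D]
conditions = {
    "description": (torch.randn(2, 8, 512), torch.ones(2, 8)),    # routed to cross-attention
    "self_wav": (torch.randn(2, 4, 512), torch.ones(2, 4)),       # prepended to the input
}
fused_input, cross_src = fuser(x, conditions)
# fused_input: [2, 54, 512] (4 prepended tokens), cross_src: [2, 8, 512]
```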
audiocraft/modules/conv.py ADDED
@@ -0,0 +1,243 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import math
+ import typing as tp
+ import warnings
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from torch.nn.utils import spectral_norm, weight_norm
+
+
+ CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
+                                  'time_group_norm'])
+
+
+ def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
+     assert norm in CONV_NORMALIZATIONS
+     if norm == 'weight_norm':
+         return weight_norm(module)
+     elif norm == 'spectral_norm':
+         return spectral_norm(module)
+     else:
+         # We already checked that `norm` is in CONV_NORMALIZATIONS,
+         # so any other choice doesn't need reparametrization.
+         return module
+
+
+ def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
+     """Return the proper normalization module. If causal is True, this will ensure the returned
+     module is causal, or return an error if the normalization doesn't support causal evaluation.
+     """
+     assert norm in CONV_NORMALIZATIONS
+     if norm == 'time_group_norm':
+         if causal:
+             raise ValueError("GroupNorm doesn't support causal evaluation.")
+         assert isinstance(module, nn.modules.conv._ConvNd)
+         return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
+     else:
+         return nn.Identity()
+
+
+ def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
+                                  padding_total: int = 0) -> int:
+     """See `pad_for_conv1d`."""
+     length = x.shape[-1]
+     n_frames = (length - kernel_size + padding_total) / stride + 1
+     ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+     return ideal_length - length
+
+
+ def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
+     """Pad for a convolution to make sure that the last window is full.
+     Extra padding is added at the end. This is required to ensure that we can rebuild
+     an output of the same length, as otherwise, even with padding, some time steps
+     might get removed.
+     For instance, with total padding = 4, kernel size = 4, stride = 2:
+         0 0 1 2 3 4 5 0 0   # (0s are padding)
+         1   2   3           # (output frames of a convolution, last 0 is never used)
+         0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
+             1 2 3 4         # once the padding is removed, we are missing one time step!
+     """
+     extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+     return F.pad(x, (0, extra_padding))
+
+
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
+     """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
+     If this is the case, we insert extra 0 padding to the right before the reflection happens.
+     """
+     length = x.shape[-1]
+     padding_left, padding_right = paddings
+     assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+     if mode == 'reflect':
+         max_pad = max(padding_left, padding_right)
+         extra_pad = 0
+         if length <= max_pad:
+             extra_pad = max_pad - length + 1
+             x = F.pad(x, (0, extra_pad))
+         padded = F.pad(x, paddings, mode, value)
+         end = padded.shape[-1] - extra_pad
+         return padded[..., :end]
+     else:
+         return F.pad(x, paddings, mode, value)
+
+
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
+     """Remove padding from x, handling properly zero padding. Only for 1d!"""
+     padding_left, padding_right = paddings
+     assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+     assert (padding_left + padding_right) <= x.shape[-1]
+     end = x.shape[-1] - padding_right
+     return x[..., padding_left: end]
+
+
+ class NormConv1d(nn.Module):
+     """Wrapper around Conv1d and normalization applied to this conv
+     to provide a uniform interface across normalization approaches.
+     """
+     def __init__(self, *args, causal: bool = False, norm: str = 'none',
+                  norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+         super().__init__()
+         self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
+         self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
+         self.norm_type = norm
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.norm(x)
+         return x
+
+
+ class NormConv2d(nn.Module):
+     """Wrapper around Conv2d and normalization applied to this conv
+     to provide a uniform interface across normalization approaches.
+     """
+     def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+         super().__init__()
+         self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
+         self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
+         self.norm_type = norm
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.norm(x)
+         return x
+
+
+ class NormConvTranspose1d(nn.Module):
+     """Wrapper around ConvTranspose1d and normalization applied to this conv
+     to provide a uniform interface across normalization approaches.
+     """
+     def __init__(self, *args, causal: bool = False, norm: str = 'none',
+                  norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+         super().__init__()
+         self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
+         self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
+         self.norm_type = norm
+
+     def forward(self, x):
+         x = self.convtr(x)
+         x = self.norm(x)
+         return x
+
+
+ class NormConvTranspose2d(nn.Module):
+     """Wrapper around ConvTranspose2d and normalization applied to this conv
+     to provide a uniform interface across normalization approaches.
+     """
+     def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+         super().__init__()
+         self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
+         self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
+
+     def forward(self, x):
+         x = self.convtr(x)
+         x = self.norm(x)
+         return x
+
+
+ class StreamableConv1d(nn.Module):
+     """Conv1d with some builtin handling of asymmetric or causal padding
+     and normalization.
+     """
+     def __init__(self, in_channels: int, out_channels: int,
+                  kernel_size: int, stride: int = 1, dilation: int = 1,
+                  groups: int = 1, bias: bool = True, causal: bool = False,
+                  norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
+                  pad_mode: str = 'reflect'):
+         super().__init__()
+         # warn user on unusual setup between dilation and stride
+         if stride > 1 and dilation > 1:
+             warnings.warn("StreamableConv1d has been initialized with stride > 1 and dilation > 1"
+                           f" (kernel_size={kernel_size} stride={stride}, dilation={dilation}).")
+         self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
+                                dilation=dilation, groups=groups, bias=bias, causal=causal,
+                                norm=norm, norm_kwargs=norm_kwargs)
+         self.causal = causal
+         self.pad_mode = pad_mode
+
+     def forward(self, x):
+         B, C, T = x.shape
+         kernel_size = self.conv.conv.kernel_size[0]
+         stride = self.conv.conv.stride[0]
+         dilation = self.conv.conv.dilation[0]
+         kernel_size = (kernel_size - 1) * dilation + 1  # effective kernel size with dilations
+         padding_total = kernel_size - stride
+         extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+         if self.causal:
+             # Left padding for causal
+             x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
+         else:
+             # Asymmetric padding required for odd strides
+             padding_right = padding_total // 2
+             padding_left = padding_total - padding_right
+             x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
+         return self.conv(x)
+
+
+ class StreamableConvTranspose1d(nn.Module):
+     """ConvTranspose1d with some builtin handling of asymmetric or causal padding
+     and normalization.
+     """
+     def __init__(self, in_channels: int, out_channels: int,
+                  kernel_size: int, stride: int = 1, causal: bool = False,
+                  norm: str = 'none', trim_right_ratio: float = 1.,
+                  norm_kwargs: tp.Dict[str, tp.Any] = {}):
+         super().__init__()
+         self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
+                                           causal=causal, norm=norm, norm_kwargs=norm_kwargs)
+         self.causal = causal
+         self.trim_right_ratio = trim_right_ratio
+         assert self.causal or self.trim_right_ratio == 1., \
+             "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
+         assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
+
+     def forward(self, x):
+         kernel_size = self.convtr.convtr.kernel_size[0]
+         stride = self.convtr.convtr.stride[0]
+         padding_total = kernel_size - stride
+
+         y = self.convtr(x)
+
+         # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
+         # removed at the very end, when keeping only the right length for the output,
+         # as removing it here would require also passing the length at the matching layer
+         # in the encoder.
+         if self.causal:
+             # Trim the padding on the right according to the specified ratio
+             # if trim_right_ratio = 1.0, trim everything from right
+             padding_right = math.ceil(padding_total * self.trim_right_ratio)
+             padding_left = padding_total - padding_right
+             y = unpad1d(y, (padding_left, padding_right))
+         else:
+             # Asymmetric padding required for odd strides
+             padding_right = padding_total // 2
+             padding_left = padding_total - padding_right
+             y = unpad1d(y, (padding_left, padding_right))
+         return y
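A quick sanity sketch (not part of the library) of the padding arithmetic above: with causal padding, `StreamableConv1d` keeps the output length at `ceil(T / stride)`, and `get_extra_padding_for_conv1d` tells you how much right padding is needed so the last window is complete.

```python
# Sanity sketch of the causal padding logic; kernel_size=7, stride=2 are arbitrary choices.
import math
import torch
from audiocraft.modules.conv import StreamableConv1d, get_extra_padding_for_conv1d

x = torch.randn(1, 1, 100)                                  # [B, C, T]
conv = StreamableConv1d(1, 8, kernel_size=7, stride=2, causal=True)
y = conv(x)
assert y.shape[-1] == math.ceil(100 / 2)                    # 50 frames, no lost time step

# padding_total = kernel_size - stride = 5 for this layer;
# extra == 0 here since 100 samples already line up with stride 2.
extra = get_extra_padding_for_conv1d(x, kernel_size=7, stride=2, padding_total=5)
```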
audiocraft/modules/diffusion_schedule.py ADDED
@@ -0,0 +1,272 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Functions for the noise schedule: defines the diffusion process, the reverse process and the data processor.
+ """
+
+ from collections import namedtuple
+ import random
+ import typing as tp
+ import julius
+ import torch
+
+ TrainingItem = namedtuple("TrainingItem", "noisy noise step")
+
+
+ def betas_from_alpha_bar(alpha_bar):
+     alphas = torch.cat([torch.Tensor([alpha_bar[0]]), alpha_bar[1:] / alpha_bar[:-1]])
+     return 1 - alphas
+
+
+ class SampleProcessor(torch.nn.Module):
+     def project_sample(self, x: torch.Tensor):
+         """Project the original sample to the 'space' where the diffusion will happen."""
+         return x
+
+     def return_sample(self, z: torch.Tensor):
+         """Project back from diffusion space to the actual sample space."""
+         return z
+
+
+ class MultiBandProcessor(SampleProcessor):
+     """
+     Multi-band sample processor. The input audio is split across
+     frequency bands evenly distributed in mel-scale.
+
+     Each band will be rescaled to match the power distribution
+     of Gaussian noise in that band, using online metrics
+     computed on the first few samples.
+
+     Args:
+         n_bands (int): Number of mel-bands to split the signal over.
+         sample_rate (int): Sample rate of the audio.
+         num_samples (int): Number of samples to use to fit the rescaling
+             for each band. The processor won't be stable
+             until it has seen that many samples.
+         power_std (float or list/tensor): The rescaling factor computed to match the
+             power of Gaussian noise in each band is taken to
+             that power, i.e. `1.` means full correction of the energy
+             in each band, and values less than `1` mean only partial
+             correction. Can be used to balance the relative importance
+             of low vs. high frequencies in typical audio signals.
+     """
+     def __init__(self, n_bands: int = 8, sample_rate: float = 24_000,
+                  num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.):
+         super().__init__()
+         self.n_bands = n_bands
+         self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands)
+         self.num_samples = num_samples
+         self.power_std = power_std
+         if isinstance(power_std, list):
+             assert len(power_std) == n_bands
+             power_std = torch.tensor(power_std)
+         self.register_buffer('counts', torch.zeros(1))
+         self.register_buffer('sum_x', torch.zeros(n_bands))
+         self.register_buffer('sum_x2', torch.zeros(n_bands))
+         self.register_buffer('sum_target_x2', torch.zeros(n_bands))
+         self.counts: torch.Tensor
+         self.sum_x: torch.Tensor
+         self.sum_x2: torch.Tensor
+         self.sum_target_x2: torch.Tensor
+
+     @property
+     def mean(self):
+         mean = self.sum_x / self.counts
+         return mean
+
+     @property
+     def std(self):
+         std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
+         return std
+
+     @property
+     def target_std(self):
+         target_std = self.sum_target_x2 / self.counts
+         return target_std
+
+     def project_sample(self, x: torch.Tensor):
+         assert x.dim() == 3
+         bands = self.split_bands(x)
+         if self.counts.item() < self.num_samples:
+             ref_bands = self.split_bands(torch.randn_like(x))
+             self.counts += len(x)
+             self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1)
+             self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
+             self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
+         rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std  # same output size
+         bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1)
+         return bands.sum(dim=0)
+
+     def return_sample(self, x: torch.Tensor):
+         assert x.dim() == 3
+         bands = self.split_bands(x)
+         rescale = (self.std / self.target_std) ** self.power_std
+         bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1)
+         return bands.sum(dim=0)
+
+
+ class NoiseSchedule:
+     """Noise schedule for diffusion.
+
+     Args:
+         beta_t0 (float): Variance of the first diffusion step.
+         beta_t1 (float): Variance of the last diffusion step.
+         beta_exp (float): Power schedule exponent.
+         num_steps (int): Number of diffusion steps.
+         variance (str): Choice of the sigma value for the denoising equation. Choices: "beta" or "beta_tilde".
+         clip (float): Clipping value for the denoising steps.
+         rescale (float): Rescaling value to avoid vanishing signals, unused by default (i.e. 1.).
+         repartition (str): Shape of the schedule; only the power schedule is supported.
+         sample_processor (SampleProcessor): Module that normalizes data to better match the Gaussian distribution.
+         noise_scale (float): Scaling factor for the noise.
+     """
+     def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta',
+                  clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1,
+                  repartition: str = "power", alpha_sigmoid: dict = {}, n_bands: tp.Optional[int] = None,
+                  sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs):
+
+         self.beta_t0 = beta_t0
+         self.beta_t1 = beta_t1
+         self.variance = variance
+         self.num_steps = num_steps
+         self.clip = clip
+         self.sample_processor = sample_processor
+         self.rescale = rescale
+         self.n_bands = n_bands
+         self.noise_scale = noise_scale
+         assert n_bands is None
+         if repartition == "power":
+             self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps,
+                                         device=device, dtype=torch.float) ** beta_exp
+         else:
+             raise RuntimeError('Not implemented')
+         self.rng = random.Random(1234)
+
+     def get_beta(self, step: tp.Union[int, torch.Tensor]):
+         if self.n_bands is None:
+             return self.betas[step]
+         else:
+             return self.betas[:, step]  # [n_bands, len(step)]
+
+     def get_initial_noise(self, x: torch.Tensor):
+         if self.n_bands is None:
+             return torch.randn_like(x)
+         return torch.randn((x.size(0), self.n_bands, x.size(2)))
+
+     def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor:
+         """Return 'alpha_bar', either for a given step, or as a tensor with its value for each step."""
+         if step is None:
+             return (1 - self.betas).cumprod(dim=-1)  # works for single and multi bands
+         if type(step) is int:
+             return (1 - self.betas[:step + 1]).prod()
+         else:
+             return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1)
+
+     def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem:
+         """Create a noisy data item for diffusion model training:
+
+         Args:
+             x (torch.Tensor): Clean audio data, torch.tensor(bs, 1, T).
+             tensor_step (bool): If False, only one step t is sampled: the whole batch is diffused
+                 to the same step and t is an int. If True, t is a tensor of size (x.size(0),)
+                 and every element of the batch is diffused to an independently sampled step.
+         """
+         step: tp.Union[int, torch.Tensor]
+         if tensor_step:
+             bs = x.size(0)
+             step = torch.randint(0, self.num_steps, size=(bs,), device=x.device)
+         else:
+             step = self.rng.randrange(self.num_steps)
+         alpha_bar = self.get_alpha_bar(step)  # [batch_size, n_bands, 1]
+
+         x = self.sample_processor.project_sample(x)
+         noise = torch.randn_like(x)
+         noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale
+         return TrainingItem(noisy, noise, step)
+
+     def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None,
+                  condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
+         """Full DDPM reverse process.
+
+         Args:
+             model (nn.Module): Diffusion model.
+             initial (tensor): Initial noise.
+             condition (tensor): Input conditioning tensor (e.g. the EnCodec compressed representation).
+             return_list (bool): Whether to return the whole process or only the sampled point.
+         """
+         alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
+         current = initial
+         iterates = [initial]
+         for step in range(self.num_steps)[::-1]:
+             with torch.no_grad():
+                 estimate = model(current, step, condition=condition).sample
+             alpha = 1 - self.betas[step]
+             previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
+             previous_alpha_bar = self.get_alpha_bar(step=step - 1)
+             if step == 0:
+                 sigma2 = 0
+             elif self.variance == 'beta':
+                 sigma2 = 1 - alpha
+             elif self.variance == 'beta_tilde':
+                 sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
+             elif self.variance == 'none':
+                 sigma2 = 0
+             else:
+                 raise ValueError(f'Invalid variance type {self.variance}')
+
+             if sigma2 > 0:
+                 previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
+             if self.clip:
+                 previous = previous.clamp(-self.clip, self.clip)
+             current = previous
+             alpha_bar = previous_alpha_bar
+             if step == 0:
+                 previous *= self.rescale
+             if return_list:
+                 iterates.append(previous.cpu())
+
+         if return_list:
+             return iterates
+         else:
+             return self.sample_processor.return_sample(previous)
+
+     def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None,
+                             condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
+         """Reverse process that only goes through Markov chain states in step_list."""
+         if step_list is None:
+             step_list = list(range(1000))[::-50] + [0]
+         alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
+         alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu()
+         betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled)
+         current = initial * self.noise_scale
+         iterates = [current]
+         for idx, step in enumerate(step_list[:-1]):
+             with torch.no_grad():
+                 estimate = model(current, step, condition=condition).sample * self.noise_scale
+             alpha = 1 - betas_subsampled[-1 - idx]
+             previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
+             previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1])
+             if step == step_list[-2]:
+                 sigma2 = 0
+                 previous_alpha_bar = torch.tensor(1.0)
+             else:
+                 sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
+             if sigma2 > 0:
+                 previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
+             if self.clip:
+                 previous = previous.clamp(-self.clip, self.clip)
+             current = previous
+             alpha_bar = previous_alpha_bar
+             if step == 0:
+                 previous *= self.rescale
+             if return_list:
+                 iterates.append(previous.cpu())
+         if return_list:
+             return iterates
+         else:
+             return self.sample_processor.return_sample(previous)
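A short sketch of the forward diffusion step implemented by `get_training_item` above; the device is set to CPU here purely for illustration, since the class defaults to `'cuda'`.

```python
# Sketch of the forward diffusion step; shapes follow the docstring ([B, 1, T] audio).
import torch
from audiocraft.modules.diffusion_schedule import NoiseSchedule

schedule = NoiseSchedule(beta_t0=1e-4, beta_t1=0.02, num_steps=1000, device='cpu')
x = torch.randn(4, 1, 16000)                     # a batch of clean signals, stand-in data
item = schedule.get_training_item(x, tensor_step=True)
# item.noisy has the same shape as x, item.step is a (4,) tensor of per-sample steps,
# and a diffusion model is trained to predict item.noise from (item.noisy, item.step).
```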
audiocraft/modules/lstm.py ADDED
@@ -0,0 +1,25 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from torch import nn
+
+
+ class StreamableLSTM(nn.Module):
+     """LSTM without worrying about the hidden state, nor the layout of the data.
+     Expects input as convolutional layout.
+     """
+     def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
+         super().__init__()
+         self.skip = skip
+         self.lstm = nn.LSTM(dimension, dimension, num_layers)
+
+     def forward(self, x):
+         x = x.permute(2, 0, 1)
+         y, _ = self.lstm(x)
+         if self.skip:
+             y = y + x
+         y = y.permute(1, 2, 0)
+         return y
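Finally, a shape sketch for `StreamableLSTM`: it takes and returns the convolutional layout `[B, C, T]`, with an optional residual (skip) connection.

```python
# Shape sketch for StreamableLSTM; the dimensions below are arbitrary examples.
import torch
from audiocraft.modules.lstm import StreamableLSTM

lstm = StreamableLSTM(dimension=128, num_layers=2, skip=True)
x = torch.randn(2, 128, 75)        # [B, C, T], e.g. latent frames from a codec
y = lstm(x)
assert y.shape == x.shape          # same [B, C, T] layout on the way out
```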