Nesbitt committed on
Commit 4068b97 · 1 Parent(s): 4068279

Initial Commit

README.md CHANGED
@@ -1,13 +1,74 @@
- ---
- title: MVSEP MDX23 Music Separation Model
- emoji: 😻
- colorFrom: yellow
- colorTo: yellow
- sdk: gradio
- sdk_version: 3.35.2
- app_file: app.py
- pinned: false
- license: agpl-3.0
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # MVSEP-MDX23-music-separation-model
+ Model for the [Sound demixing challenge 2023: Music Demixing Track - MDX'23](https://www.aicrowd.com/challenges/sound-demixing-challenge-2023). The model separates music into 4 stems: "bass", "drums", "vocals" and "other". It took 3rd place in the challenge (Leaderboard C).
+
+ The model is based on the [Demucs4](https://github.com/facebookresearch/demucs) and [MDX](https://github.com/kuielab/mdx-net) neural net architectures, with some MDX weights taken from the [Ultimate Vocal Remover](https://github.com/Anjok07/ultimatevocalremovergui) project (thanks to [Kimberley Jensen](https://github.com/KimberleyJensen) for the great high-quality vocal models). Brought to you by [MVSep.com](https://mvsep.com).
+ ## Usage
+
+ ```
+ python inference.py --input_audio mixture1.wav mixture2.wav --output_folder ./results/
+ ```
+
+ With this command, the audio files "mixture1.wav" and "mixture2.wav" will be processed and the results will be stored in the `./results/` folder in WAV format.
+
+ ### All available keys
+ * `--input_audio` - input audio location. You can provide multiple files at once. **Required**
+ * `--output_folder` - output audio folder. **Required**
+ * `--cpu` - use CPU instead of GPU for processing. Can be very slow.
+ * `--overlap_large` - overlap of split audio chunks for light models. Closer to 1.0 is slower, but gives better quality. Default: 0.6.
+ * `--overlap_small` - overlap of split audio chunks for heavy models. Closer to 1.0 is slower, but gives better quality. Default: 0.5.
+ * `--single_onnx` - use only a single ONNX model for vocals. Can be useful if you don't have enough GPU memory.
+ * `--chunk_size` - chunk size for ONNX models. Set it lower to reduce GPU memory consumption. Default: 1000000.
+ * `--large_gpu` - store all models on the GPU for faster processing of multiple audio files. Requires at least 11 GB of free GPU memory.
+ * `--use_kim_model_1` - use the first version of the Kim vocal model (as it was in the contest).
+ * `--only_vocals` - only create vocals and instrumental. Skips bass, drums and other. Processing will be faster. A combined example is shown below.
+
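For illustration, here is one way several of these keys could be combined on a single command line (the file names and option values below are placeholders, not defaults from this repository):

```
python inference.py \
    --input_audio song1.wav song2.wav \
    --output_folder ./results/ \
    --only_vocals \
    --overlap_large 0.8 \
    --overlap_small 0.7 \
    --chunk_size 500000
```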
+ ### Notes
+ * If you don't have enough GPU memory, you can use the CPU (`--cpu`), but it will be slow. Alternatively, you can use a single ONNX model (`--single_onnx`), which decreases quality a little bit. Reducing the chunk size also helps (e.g. `--chunk_size 200000`).
+ * The current revision of the code requires less GPU memory, but it processes multiple files more slowly. If you want the old fast method, use the `--large_gpu` argument. It requires > 11 GB of GPU memory, but works faster.
+ * There is a [Google Colab version](https://colab.research.google.com/github/jarredou/MVSEP-MDX23-Colab_v2/blob/main/MVSep-MDX23-Colab.ipynb) of this code.
+
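The separation can also be driven from Python, as `app.py` in this commit does, by passing an options dictionary to `predict_with_model` from `inference.py`. A minimal sketch mirroring the keys used in `app.py` (file names are placeholders):

```
# Sketch only: assumes the repository dependencies are installed and the
# script is run from the repository root, next to inference.py.
from inference import predict_with_model

options = {
    'input_audio': ['mixture1.wav'],  # list of input files (placeholder name)
    'output_folder': 'results',       # stems are written here as WAV files
    'cpu': False,                     # True forces CPU processing (slow)
    'single_onnx': False,             # True uses a single ONNX vocal model
    'overlap_large': 0.6,
    'overlap_small': 0.5,
    'chunk_size': 1000000,
    'large_gpu': False,               # True keeps all models on the GPU (> 11 GB)
}
predict_with_model(options)
```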
+ ## Quality comparison
+
+ Quality comparison with the best separation models, performed on the [MultiSong Dataset](https://mvsep.com/quality_checker/leaderboard2.php?sort=bass).
+
+ | Algorithm | SDR bass | SDR drums | SDR other | SDR vocals | SDR instrumental |
+ | ------------- |:---------:|:----------:|:----------:|:----------:|:------------------:|
+ | MVSEP MDX23 | 12.5034 | 11.6870 | 6.5378 | 9.5138 | 15.8213 |
+ | Demucs HT 4 | 12.1006 | 11.3037 | 5.7728 | 8.3555 | 13.9902 |
+ | Demucs 3 | 10.6947 | 10.2744 | 5.3580 | 8.1335 | 14.4409 |
+ | MDX B | --- | --- | --- | 8.5118 | 14.8192 |
+
+ * Note: SDR is the signal-to-distortion ratio. Larger is better.
+
+ ## GUI
+
+ ![GUI Window](https://github.com/ZFTurbo/MVSEP-MDX23-music-separation-model/blob/main/images/MVSep-Window.png)
+
+ * Script for the GUI (based on PyQt5): [gui.py](gui.py).
+ * You can download a [standalone program for Windows here](https://github.com/ZFTurbo/MVSEP-MDX23-music-separation-model/releases/download/v1.0.1/MVSep-MDX23_v1.0.1.zip) (~730 MB). Unzip the archive and double-click `run.bat` to start the program. On the first run it will download PyTorch with CUDA support (~2.8 GB) and some neural net models.
+ * The program will download all needed neural net models from the internet on the first run.
+ * The GUI supports drag & drop of multiple files.
+ * A progress bar is available.
+
+ ## Changes
+
+ ### v1.0.1
+ * GUI settings updated; you can now control all available options.
+ * Kim vocal model updated from version 1 to version 2; you can still use version 1 with the `--use_kim_model_1` parameter.
+ * Added the option to generate only the vocals/instrumental pair if you don't need the bass, drums and other stems. Use the `--only_vocals` parameter.
+ * The standalone program was updated and is smaller now; the GUI downloads torch/CUDA on the first run.
+
+ ## Citation
+
+ * [arXiv paper](https://arxiv.org/abs/2305.07489)
+
+ ```
+ @misc{solovyev2023benchmarks,
+   title={Benchmarks and leaderboards for sound demixing tasks},
+   author={Roman Solovyev and Alexander Stempkovskiy and Tatiana Habruseva},
+   year={2023},
+   eprint={2305.07489},
+   archivePrefix={arXiv},
+   primaryClass={cs.SD}
+ }
+ ```
app.py ADDED
@@ -0,0 +1,143 @@
1
+ import os
2
+ import time
3
+ import soundfile as sf
4
+ import numpy as np
5
+ import tempfile
6
+ from scipy.io import wavfile
7
+ from pytube import YouTube
8
+ from gradio import Interface, components as gr
9
+ from moviepy.editor import AudioFileClip
10
+ from inference import EnsembleDemucsMDXMusicSeparationModel, predict_with_model
11
+ import torch
12
+
13
+ def download_youtube_video_as_wav(youtube_url):
14
+ output_dir = "downloads"
15
+ os.makedirs(output_dir, exist_ok=True)
16
+ output_file = os.path.join(output_dir, "temp.mp4")
17
+
18
+ try:
19
+ yt = YouTube(youtube_url)
20
+ yt.streams.filter(only_audio=True).first().download(filename=output_file)
21
+ print("Download completed successfully.")
22
+ except Exception as e:
23
+ print(f"An error occurred while downloading the video: {e}")
24
+ return None
25
+
26
+ # Convert mp4 audio to wav
27
+ wav_file = os.path.join(output_dir, "mixture.wav")
28
+ clip = AudioFileClip(output_file)
29
+ clip.write_audiofile(wav_file)
30
+
31
+ return wav_file
32
+
33
+
34
+ def check_file_readiness(filepath):
35
+ num_same_size_checks = 0
36
+ last_size = -1
37
+
38
+ while num_same_size_checks < 5:
39
+ current_size = os.path.getsize(filepath)
40
+
41
+ if current_size == last_size:
42
+ num_same_size_checks += 1
43
+ else:
44
+ num_same_size_checks = 0
45
+ last_size = current_size
46
+
47
+ time.sleep(1)
48
+
49
+ # If the loop finished, it means the file size has not changed for 5 seconds
50
+ # which indicates that the file is ready
51
+ return True
52
+
53
+
54
+
55
+ def separate_music_file_wrapper(input_string, use_cpu, use_single_onnx, large_overlap, small_overlap, chunk_size, use_large_gpu):
56
+ input_files = []
57
+
58
+ if input_string.startswith("https://www.youtube.com") or input_string.startswith("https://youtu.be"):
59
+ output_file = download_youtube_video_as_wav(input_string)
60
+ if output_file is not None:
61
+ input_files.append(output_file)
62
+ elif os.path.isdir(input_string):
63
+ input_directory = input_string
64
+ input_files = [os.path.join(input_directory, f) for f in os.listdir(input_directory) if f.endswith('.wav')]
65
+ else:
66
+ raise ValueError("Invalid input! Please provide a valid YouTube link or a directory path.")
67
+
68
+ options = {
69
+ 'input_audio': input_files,
70
+ 'output_folder': 'results',
71
+ 'cpu': use_cpu,
72
+ 'single_onnx': use_single_onnx,
73
+ 'overlap_large': large_overlap,
74
+ 'overlap_small': small_overlap,
75
+ 'chunk_size': chunk_size,
76
+ 'large_gpu': use_large_gpu,
77
+ }
78
+
79
+ predict_with_model(options)
80
+
81
+ # Clear GPU cache
82
+ if torch.cuda.is_available():
83
+ torch.cuda.empty_cache()
84
+
85
+ output_files = {}
86
+ for f in input_files:
87
+ audio_file_name = os.path.splitext(os.path.basename(f))[0]
88
+ output_files["vocals"] = os.path.join(options['output_folder'], audio_file_name + "_vocals.wav")
89
+ output_files["instrumental"] = os.path.join(options['output_folder'], audio_file_name + "_instrum.wav")
90
+ output_files["instrumental2"] = os.path.join(options['output_folder'], audio_file_name + "_instrum2.wav") # For the second instrumental output
91
+ output_files["bass"] = os.path.join(options['output_folder'], audio_file_name + "_bass.wav")
92
+ output_files["drums"] = os.path.join(options['output_folder'], audio_file_name + "_drums.wav")
93
+ output_files["other"] = os.path.join(options['output_folder'], audio_file_name + "_other.wav")
94
+
95
+
96
+ # Check the readiness of the files
97
+ output_files_ready = []
98
+ for k, v in output_files.items():
99
+ if os.path.exists(v) and check_file_readiness(v):
100
+ output_files_ready.append(v)
101
+ else:
102
+ empty_data = np.zeros((44100, 2)) # 2 channels, 1 second of silence at 44100Hz
103
+ empty_file = tempfile.mktemp('.wav')
104
+ wavfile.write(empty_file, 44100, empty_data.astype(np.int16)) # Cast to int16 as wavfile does not support float32
105
+ output_files_ready.append(empty_file)
106
+
107
+ return tuple(output_files_ready)
108
+
109
+ description = """
110
+ # ZFTurbo Web-UI
111
+ Web-UI by [Ma5onic](https://github.com/Ma5onic)
112
+ ## Options:
113
+ - **Use CPU Only:** Select this if you don't have enough GPU memory. It will be slower.
114
+ - **Use Single ONNX:** Select this to use a single ONNX model. It will decrease quality a little bit but can help with GPU memory usage.
115
+ - **Large Overlap:** The overlap for large chunks. Adjust as needed.
116
+ - **Small Overlap:** The overlap for small chunks. Adjust as needed.
117
+ - **Chunk Size:** The size of chunks to be processed at a time. Reduce this if facing memory issues.
118
+ - **Use Fast Large GPU Version:** Select this to use the old fast method that requires > 11 GB of GPU memory. It will work faster.
119
+ """
120
+
121
+ iface = Interface(
122
+ fn=separate_music_file_wrapper,
123
+ inputs=[
124
+ gr.Text(label="Input Directory or YouTube Link"),
125
+ gr.Checkbox(label="Use CPU Only", value=False),
126
+ gr.Checkbox(label="Use Single ONNX", value=False),
127
+ gr.Number(label="Large Overlap", value=0.6),
128
+ gr.Number(label="Small Overlap", value=0.5),
129
+ gr.Number(label="Chunk Size", value=1000000),
130
+ gr.Checkbox(label="Use Fast Large GPU Version", value=False)
131
+ ],
132
+ outputs=[
133
+ gr.Audio(label="Vocals"),
134
+ gr.Audio(label="Instrumental"),
135
+ gr.Audio(label="Instrumental 2"),
136
+ gr.Audio(label="Bass"),
137
+ gr.Audio(label="Drums"),
138
+ gr.Audio(label="Other"),
139
+ ],
140
+ description=description,
141
+ )
142
+
143
+ iface.queue().launch(debug=True, share=False)
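Assuming the packages imported at the top of `app.py` (gradio, pytube, moviepy, soundfile, scipy, torch) are installed next to the repository code, the web UI defined above would be started by running the script directly:

```
python app.py
```

The final `iface.queue().launch(debug=True, share=False)` call then serves the Gradio interface locally.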
demucs3/demucs.py ADDED
@@ -0,0 +1,447 @@
1
+ # Copyright (c) Meta, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import typing as tp
9
+
10
+ import julius
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+
15
+ from .states import capture_init
16
+ from .utils import center_trim, unfold
17
+ from .transformer import LayerScale
18
+
19
+
20
+ class BLSTM(nn.Module):
21
+ """
22
+ BiLSTM with same hidden units as input dim.
23
+ If `max_steps` is not None, the input will be split into overlapping
24
+ chunks and the LSTM applied separately on each chunk.
25
+ """
26
+ def __init__(self, dim, layers=1, max_steps=None, skip=False):
27
+ super().__init__()
28
+ assert max_steps is None or max_steps % 4 == 0
29
+ self.max_steps = max_steps
30
+ self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
31
+ self.linear = nn.Linear(2 * dim, dim)
32
+ self.skip = skip
33
+
34
+ def forward(self, x):
35
+ B, C, T = x.shape
36
+ y = x
37
+ framed = False
38
+ if self.max_steps is not None and T > self.max_steps:
39
+ width = self.max_steps
40
+ stride = width // 2
41
+ frames = unfold(x, width, stride)
42
+ nframes = frames.shape[2]
43
+ framed = True
44
+ x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)
45
+
46
+ x = x.permute(2, 0, 1)
47
+
48
+ x = self.lstm(x)[0]
49
+ x = self.linear(x)
50
+ x = x.permute(1, 2, 0)
51
+ if framed:
52
+ out = []
53
+ frames = x.reshape(B, -1, C, width)
54
+ limit = stride // 2
55
+ for k in range(nframes):
56
+ if k == 0:
57
+ out.append(frames[:, k, :, :-limit])
58
+ elif k == nframes - 1:
59
+ out.append(frames[:, k, :, limit:])
60
+ else:
61
+ out.append(frames[:, k, :, limit:-limit])
62
+ out = torch.cat(out, -1)
63
+ out = out[..., :T]
64
+ x = out
65
+ if self.skip:
66
+ x = x + y
67
+ return x
68
+
69
+
70
+ def rescale_conv(conv, reference):
71
+ """Rescale initial weight scale. It is unclear why it helps but it certainly does.
72
+ """
73
+ std = conv.weight.std().detach()
74
+ scale = (std / reference)**0.5
75
+ conv.weight.data /= scale
76
+ if conv.bias is not None:
77
+ conv.bias.data /= scale
78
+
79
+
80
+ def rescale_module(module, reference):
81
+ for sub in module.modules():
82
+ if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)):
83
+ rescale_conv(sub, reference)
84
+
85
+
86
+ class DConv(nn.Module):
87
+ """
88
+ New residual branches in each encoder layer.
89
+ This alternates dilated convolutions, potentially with LSTMs and attention.
90
+ Also before entering each residual branch, dimension is projected on a smaller subspace,
91
+ e.g. of dim `channels // compress`.
92
+ """
93
+ def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
94
+ norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True,
95
+ kernel=3, dilate=True):
96
+ """
97
+ Args:
98
+ channels: input/output channels for residual branch.
99
+ compress: amount of channel compression inside the branch.
100
+ depth: number of layers in the residual branch. Each layer has its own
101
+ projection, and potentially LSTM and attention.
102
+ init: initial scale for LayerNorm.
103
+ norm: use GroupNorm.
104
+ attn: use LocalAttention.
105
+ heads: number of heads for the LocalAttention.
106
+ ndecay: number of decay controls in the LocalAttention.
107
+ lstm: use LSTM.
108
+ gelu: Use GELU activation.
109
+ kernel: kernel size for the (dilated) convolutions.
110
+ dilate: if true, use dilation, increasing with the depth.
111
+ """
112
+
113
+ super().__init__()
114
+ assert kernel % 2 == 1
115
+ self.channels = channels
116
+ self.compress = compress
117
+ self.depth = abs(depth)
118
+ dilate = depth > 0
119
+
120
+ norm_fn: tp.Callable[[int], nn.Module]
121
+ norm_fn = lambda d: nn.Identity() # noqa
122
+ if norm:
123
+ norm_fn = lambda d: nn.GroupNorm(1, d) # noqa
124
+
125
+ hidden = int(channels / compress)
126
+
127
+ act: tp.Type[nn.Module]
128
+ if gelu:
129
+ act = nn.GELU
130
+ else:
131
+ act = nn.ReLU
132
+
133
+ self.layers = nn.ModuleList([])
134
+ for d in range(self.depth):
135
+ dilation = 2 ** d if dilate else 1
136
+ padding = dilation * (kernel // 2)
137
+ mods = [
138
+ nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
139
+ norm_fn(hidden), act(),
140
+ nn.Conv1d(hidden, 2 * channels, 1),
141
+ norm_fn(2 * channels), nn.GLU(1),
142
+ LayerScale(channels, init),
143
+ ]
144
+ if attn:
145
+ mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
146
+ if lstm:
147
+ mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
148
+ layer = nn.Sequential(*mods)
149
+ self.layers.append(layer)
150
+
151
+ def forward(self, x):
152
+ for layer in self.layers:
153
+ x = x + layer(x)
154
+ return x
155
+
156
+
157
+ class LocalState(nn.Module):
158
+ """Local state allows to have attention based only on data (no positional embedding),
159
+ but while setting a constraint on the time window (e.g. decaying penalty term).
160
+
161
+ Also a failed experiment with trying to provide some frequency based attention.
162
+ """
163
+ def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
164
+ super().__init__()
165
+ assert channels % heads == 0, (channels, heads)
166
+ self.heads = heads
167
+ self.nfreqs = nfreqs
168
+ self.ndecay = ndecay
169
+ self.content = nn.Conv1d(channels, channels, 1)
170
+ self.query = nn.Conv1d(channels, channels, 1)
171
+ self.key = nn.Conv1d(channels, channels, 1)
172
+ if nfreqs:
173
+ self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
174
+ if ndecay:
175
+ self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
176
+ # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
177
+ self.query_decay.weight.data *= 0.01
178
+ assert self.query_decay.bias is not None # stupid type checker
179
+ self.query_decay.bias.data[:] = -2
180
+ self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)
181
+
182
+ def forward(self, x):
183
+ B, C, T = x.shape
184
+ heads = self.heads
185
+ indexes = torch.arange(T, device=x.device, dtype=x.dtype)
186
+ # left index are keys, right index are queries
187
+ delta = indexes[:, None] - indexes[None, :]
188
+
189
+ queries = self.query(x).view(B, heads, -1, T)
190
+ keys = self.key(x).view(B, heads, -1, T)
191
+ # t are keys, s are queries
192
+ dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
193
+ dots /= keys.shape[2]**0.5
194
+ if self.nfreqs:
195
+ periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
196
+ freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
197
+ freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5
198
+ dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
199
+ if self.ndecay:
200
+ decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
201
+ decay_q = self.query_decay(x).view(B, heads, -1, T)
202
+ decay_q = torch.sigmoid(decay_q) / 2
203
+ decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
204
+ dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)
205
+
206
+ # Kill self reference.
207
+ dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
208
+ weights = torch.softmax(dots, dim=2)
209
+
210
+ content = self.content(x).view(B, heads, -1, T)
211
+ result = torch.einsum("bhts,bhct->bhcs", weights, content)
212
+ if self.nfreqs:
213
+ time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
214
+ result = torch.cat([result, time_sig], 2)
215
+ result = result.reshape(B, -1, T)
216
+ return x + self.proj(result)
217
+
218
+
219
+ class Demucs(nn.Module):
220
+ @capture_init
221
+ def __init__(self,
222
+ sources,
223
+ # Channels
224
+ audio_channels=2,
225
+ channels=64,
226
+ growth=2.,
227
+ # Main structure
228
+ depth=6,
229
+ rewrite=True,
230
+ lstm_layers=0,
231
+ # Convolutions
232
+ kernel_size=8,
233
+ stride=4,
234
+ context=1,
235
+ # Activations
236
+ gelu=True,
237
+ glu=True,
238
+ # Normalization
239
+ norm_starts=4,
240
+ norm_groups=4,
241
+ # DConv residual branch
242
+ dconv_mode=1,
243
+ dconv_depth=2,
244
+ dconv_comp=4,
245
+ dconv_attn=4,
246
+ dconv_lstm=4,
247
+ dconv_init=1e-4,
248
+ # Pre/post processing
249
+ normalize=True,
250
+ resample=True,
251
+ # Weight init
252
+ rescale=0.1,
253
+ # Metadata
254
+ samplerate=44100,
255
+ segment=4 * 10):
256
+ """
257
+ Args:
258
+ sources (list[str]): list of source names
259
+ audio_channels (int): stereo or mono
260
+ channels (int): first convolution channels
261
+ depth (int): number of encoder/decoder layers
262
+ growth (float): multiply (resp divide) number of channels by that
263
+ for each layer of the encoder (resp decoder)
264
+ depth (int): number of layers in the encoder and in the decoder.
265
+ rewrite (bool): add 1x1 convolution to each layer.
266
+ lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated
267
+ by default, as this is now replaced by the smaller and faster small LSTMs
268
+ in the DConv branches.
269
+ kernel_size (int): kernel size for convolutions
270
+ stride (int): stride for convolutions
271
+ context (int): kernel size of the convolution in the
272
+ decoder before the transposed convolution. If > 1,
273
+ will provide some context from neighboring time steps.
274
+ gelu: use GELU activation function.
275
+ glu (bool): use glu instead of ReLU for the 1x1 rewrite conv.
276
+ norm_starts: layer at which group norm starts being used.
277
+ decoder layers are numbered in reverse order.
278
+ norm_groups: number of groups for group norm.
279
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
280
+ dconv_depth: depth of residual DConv branch.
281
+ dconv_comp: compression of DConv branch.
282
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
283
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
284
+ dconv_init: initial scale for the DConv branch LayerScale.
285
+ normalize (bool): normalizes the input audio on the fly, and scales back
286
+ the output by the same amount.
287
+ resample (bool): upsample x2 the input and downsample /2 the output.
288
+ rescale (int): rescale initial weights of convolutions
289
+ to get their standard deviation closer to `rescale`.
290
+ samplerate (int): stored as meta information for easing
291
+ future evaluations of the model.
292
+ segment (float): duration of the chunks of audio to ideally evaluate the model on.
293
+ This is used by `demucs.apply.apply_model`.
294
+ """
295
+
296
+ super().__init__()
297
+ self.audio_channels = audio_channels
298
+ self.sources = sources
299
+ self.kernel_size = kernel_size
300
+ self.context = context
301
+ self.stride = stride
302
+ self.depth = depth
303
+ self.resample = resample
304
+ self.channels = channels
305
+ self.normalize = normalize
306
+ self.samplerate = samplerate
307
+ self.segment = segment
308
+ self.encoder = nn.ModuleList()
309
+ self.decoder = nn.ModuleList()
310
+ self.skip_scales = nn.ModuleList()
311
+
312
+ if glu:
313
+ activation = nn.GLU(dim=1)
314
+ ch_scale = 2
315
+ else:
316
+ activation = nn.ReLU()
317
+ ch_scale = 1
318
+ if gelu:
319
+ act2 = nn.GELU
320
+ else:
321
+ act2 = nn.ReLU
322
+
323
+ in_channels = audio_channels
324
+ padding = 0
325
+ for index in range(depth):
326
+ norm_fn = lambda d: nn.Identity() # noqa
327
+ if index >= norm_starts:
328
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
329
+
330
+ encode = []
331
+ encode += [
332
+ nn.Conv1d(in_channels, channels, kernel_size, stride),
333
+ norm_fn(channels),
334
+ act2(),
335
+ ]
336
+ attn = index >= dconv_attn
337
+ lstm = index >= dconv_lstm
338
+ if dconv_mode & 1:
339
+ encode += [DConv(channels, depth=dconv_depth, init=dconv_init,
340
+ compress=dconv_comp, attn=attn, lstm=lstm)]
341
+ if rewrite:
342
+ encode += [
343
+ nn.Conv1d(channels, ch_scale * channels, 1),
344
+ norm_fn(ch_scale * channels), activation]
345
+ self.encoder.append(nn.Sequential(*encode))
346
+
347
+ decode = []
348
+ if index > 0:
349
+ out_channels = in_channels
350
+ else:
351
+ out_channels = len(self.sources) * audio_channels
352
+ if rewrite:
353
+ decode += [
354
+ nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context),
355
+ norm_fn(ch_scale * channels), activation]
356
+ if dconv_mode & 2:
357
+ decode += [DConv(channels, depth=dconv_depth, init=dconv_init,
358
+ compress=dconv_comp, attn=attn, lstm=lstm)]
359
+ decode += [nn.ConvTranspose1d(channels, out_channels,
360
+ kernel_size, stride, padding=padding)]
361
+ if index > 0:
362
+ decode += [norm_fn(out_channels), act2()]
363
+ self.decoder.insert(0, nn.Sequential(*decode))
364
+ in_channels = channels
365
+ channels = int(growth * channels)
366
+
367
+ channels = in_channels
368
+ if lstm_layers:
369
+ self.lstm = BLSTM(channels, lstm_layers)
370
+ else:
371
+ self.lstm = None
372
+
373
+ if rescale:
374
+ rescale_module(self, reference=rescale)
375
+
376
+ def valid_length(self, length):
377
+ """
378
+ Return the nearest valid length to use with the model so that
379
+ there is no time steps left over in a convolution, e.g. for all
380
+ layers, size of the input - kernel_size % stride = 0.
381
+
382
+ Note that input are automatically padded if necessary to ensure that the output
383
+ has the same length as the input.
384
+ """
385
+ if self.resample:
386
+ length *= 2
387
+
388
+ for _ in range(self.depth):
389
+ length = math.ceil((length - self.kernel_size) / self.stride) + 1
390
+ length = max(1, length)
391
+
392
+ for idx in range(self.depth):
393
+ length = (length - 1) * self.stride + self.kernel_size
394
+
395
+ if self.resample:
396
+ length = math.ceil(length / 2)
397
+ return int(length)
398
+
399
+ def forward(self, mix):
400
+ x = mix
401
+ length = x.shape[-1]
402
+
403
+ if self.normalize:
404
+ mono = mix.mean(dim=1, keepdim=True)
405
+ mean = mono.mean(dim=-1, keepdim=True)
406
+ std = mono.std(dim=-1, keepdim=True)
407
+ x = (x - mean) / (1e-5 + std)
408
+ else:
409
+ mean = 0
410
+ std = 1
411
+
412
+ delta = self.valid_length(length) - length
413
+ x = F.pad(x, (delta // 2, delta - delta // 2))
414
+
415
+ if self.resample:
416
+ x = julius.resample_frac(x, 1, 2)
417
+
418
+ saved = []
419
+ for encode in self.encoder:
420
+ x = encode(x)
421
+ saved.append(x)
422
+
423
+ if self.lstm:
424
+ x = self.lstm(x)
425
+
426
+ for decode in self.decoder:
427
+ skip = saved.pop(-1)
428
+ skip = center_trim(skip, x)
429
+ x = decode(x + skip)
430
+
431
+ if self.resample:
432
+ x = julius.resample_frac(x, 2, 1)
433
+ x = x * std + mean
434
+ x = center_trim(x, length)
435
+ x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
436
+ return x
437
+
438
+ def load_state_dict(self, state, strict=True):
439
+ # fix a mismatch with previous generation Demucs models.
440
+ for idx in range(self.depth):
441
+ for a in ['encoder', 'decoder']:
442
+ for b in ['bias', 'weight']:
443
+ new = f'{a}.{idx}.3.{b}'
444
+ old = f'{a}.{idx}.2.{b}'
445
+ if old in state and new not in state:
446
+ state[new] = state.pop(old)
447
+ super().load_state_dict(state, strict=strict)
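As a quick orientation for the file above (this snippet is not part of the commit): `Demucs` is a plain `nn.Module`, so a minimal sketch of instantiating it and pushing a random stereo mix through it could look as follows, assuming `demucs3` is importable as a package and `julius` is installed. The default configuration is large, so this is slow on CPU.

```
# Minimal usage sketch (illustrative only, not from the repository).
import torch
from demucs3.demucs import Demucs

model = Demucs(sources=["drums", "bass", "other", "vocals"]).eval()
mix = torch.randn(1, 2, 44100 * 4)   # (batch, audio_channels, samples), ~4 s at 44.1 kHz
with torch.no_grad():
    stems = model(mix)                # -> (batch, len(sources), audio_channels, samples)
print(stems.shape)                    # torch.Size([1, 4, 2, 176400])
```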
demucs3/hdemucs.py ADDED
@@ -0,0 +1,782 @@
1
+ # Copyright (c) Meta, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ This code contains the spectrogram and Hybrid version of Demucs.
8
+ """
9
+ from copy import deepcopy
10
+ import math
11
+ import typing as tp
12
+
13
+ from openunmix.filtering import wiener
14
+ import torch
15
+ from torch import nn
16
+ from torch.nn import functional as F
17
+
18
+ from .demucs import DConv, rescale_module
19
+ from .states import capture_init
20
+ from .spec import spectro, ispectro
21
+
22
+
23
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
24
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
25
+ If this is the case, we insert extra 0 padding to the right before the reflection happen."""
26
+ x0 = x
27
+ length = x.shape[-1]
28
+ padding_left, padding_right = paddings
29
+ if mode == 'reflect':
30
+ max_pad = max(padding_left, padding_right)
31
+ if length <= max_pad:
32
+ extra_pad = max_pad - length + 1
33
+ extra_pad_right = min(padding_right, extra_pad)
34
+ extra_pad_left = extra_pad - extra_pad_right
35
+ paddings = (padding_left - extra_pad_left, padding_right - extra_pad_right)
36
+ x = F.pad(x, (extra_pad_left, extra_pad_right))
37
+ out = F.pad(x, paddings, mode, value)
38
+ assert out.shape[-1] == length + padding_left + padding_right
39
+ assert (out[..., padding_left: padding_left + length] == x0).all()
40
+ return out
41
+
42
+
43
+ class ScaledEmbedding(nn.Module):
44
+ """
45
+ Boost learning rate for embeddings (with `scale`).
46
+ Also, can make embeddings continuous with `smooth`.
47
+ """
48
+ def __init__(self, num_embeddings: int, embedding_dim: int,
49
+ scale: float = 10., smooth=False):
50
+ super().__init__()
51
+ self.embedding = nn.Embedding(num_embeddings, embedding_dim)
52
+ if smooth:
53
+ weight = torch.cumsum(self.embedding.weight.data, dim=0)
54
+ # when summing gaussians, the scale rises as sqrt(n), so we normalize by that.
55
+ weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
56
+ self.embedding.weight.data[:] = weight
57
+ self.embedding.weight.data /= scale
58
+ self.scale = scale
59
+
60
+ @property
61
+ def weight(self):
62
+ return self.embedding.weight * self.scale
63
+
64
+ def forward(self, x):
65
+ out = self.embedding(x) * self.scale
66
+ return out
67
+
68
+
69
+ class HEncLayer(nn.Module):
70
+ def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
71
+ freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True,
72
+ rewrite=True):
73
+ """Encoder layer. This used both by the time and the frequency branch.
74
+
75
+ Args:
76
+ chin: number of input channels.
77
+ chout: number of output channels.
78
+ norm_groups: number of groups for group norm.
79
+ empty: used to make a layer with just the first conv. this is used
80
+ before merging the time and freq. branches.
81
+ freq: this is acting on frequencies.
82
+ dconv: insert DConv residual branches.
83
+ norm: use GroupNorm.
84
+ context: context size for the 1x1 conv.
85
+ dconv_kw: list of kwargs for the DConv class.
86
+ pad: pad the input. Padding is done so that the output size is
87
+ always the input size / stride.
88
+ rewrite: add 1x1 conv at the end of the layer.
89
+ """
90
+ super().__init__()
91
+ norm_fn = lambda d: nn.Identity() # noqa
92
+ if norm:
93
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
94
+ if pad:
95
+ pad = kernel_size // 4
96
+ else:
97
+ pad = 0
98
+ klass = nn.Conv1d
99
+ self.freq = freq
100
+ self.kernel_size = kernel_size
101
+ self.stride = stride
102
+ self.empty = empty
103
+ self.norm = norm
104
+ self.pad = pad
105
+ if freq:
106
+ kernel_size = [kernel_size, 1]
107
+ stride = [stride, 1]
108
+ pad = [pad, 0]
109
+ klass = nn.Conv2d
110
+ self.conv = klass(chin, chout, kernel_size, stride, pad)
111
+ if self.empty:
112
+ return
113
+ self.norm1 = norm_fn(chout)
114
+ self.rewrite = None
115
+ if rewrite:
116
+ self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
117
+ self.norm2 = norm_fn(2 * chout)
118
+
119
+ self.dconv = None
120
+ if dconv:
121
+ self.dconv = DConv(chout, **dconv_kw)
122
+
123
+ def forward(self, x, inject=None):
124
+ """
125
+ `inject` is used to inject the result from the time branch into the frequency branch,
126
+ when both have the same stride.
127
+ """
128
+ if not self.freq and x.dim() == 4:
129
+ B, C, Fr, T = x.shape
130
+ x = x.view(B, -1, T)
131
+
132
+ if not self.freq:
133
+ le = x.shape[-1]
134
+ if not le % self.stride == 0:
135
+ x = F.pad(x, (0, self.stride - (le % self.stride)))
136
+ y = self.conv(x)
137
+ if self.empty:
138
+ return y
139
+ if inject is not None:
140
+ assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
141
+ if inject.dim() == 3 and y.dim() == 4:
142
+ inject = inject[:, :, None]
143
+ y = y + inject
144
+ y = F.gelu(self.norm1(y))
145
+ if self.dconv:
146
+ if self.freq:
147
+ B, C, Fr, T = y.shape
148
+ y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
149
+ y = self.dconv(y)
150
+ if self.freq:
151
+ y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
152
+ if self.rewrite:
153
+ z = self.norm2(self.rewrite(y))
154
+ z = F.glu(z, dim=1)
155
+ else:
156
+ z = y
157
+ return z
158
+
159
+
160
+ class MultiWrap(nn.Module):
161
+ """
162
+ Takes one layer and replicate it N times. each replica will act
163
+ on a frequency band. All is done so that if the N replica have the same weights,
164
+ then this is exactly equivalent to applying the original module on all frequencies.
165
+
166
+ This is a bit over-engineered to avoid edge artifacts when splitting
167
+ the frequency bands, but it is possible the naive implementation would work as well...
168
+ """
169
+ def __init__(self, layer, split_ratios):
170
+ """
171
+ Args:
172
+ layer: module to clone, must be either HEncLayer or HDecLayer.
173
+ split_ratios: list of float indicating which ratio to keep for each band.
174
+ """
175
+ super().__init__()
176
+ self.split_ratios = split_ratios
177
+ self.layers = nn.ModuleList()
178
+ self.conv = isinstance(layer, HEncLayer)
179
+ assert not layer.norm
180
+ assert layer.freq
181
+ assert layer.pad
182
+ if not self.conv:
183
+ assert not layer.context_freq
184
+ for k in range(len(split_ratios) + 1):
185
+ lay = deepcopy(layer)
186
+ if self.conv:
187
+ lay.conv.padding = (0, 0)
188
+ else:
189
+ lay.pad = False
190
+ for m in lay.modules():
191
+ if hasattr(m, 'reset_parameters'):
192
+ m.reset_parameters()
193
+ self.layers.append(lay)
194
+
195
+ def forward(self, x, skip=None, length=None):
196
+ B, C, Fr, T = x.shape
197
+
198
+ ratios = list(self.split_ratios) + [1]
199
+ start = 0
200
+ outs = []
201
+ for ratio, layer in zip(ratios, self.layers):
202
+ if self.conv:
203
+ pad = layer.kernel_size // 4
204
+ if ratio == 1:
205
+ limit = Fr
206
+ frames = -1
207
+ else:
208
+ limit = int(round(Fr * ratio))
209
+ le = limit - start
210
+ if start == 0:
211
+ le += pad
212
+ frames = round((le - layer.kernel_size) / layer.stride + 1)
213
+ limit = start + (frames - 1) * layer.stride + layer.kernel_size
214
+ if start == 0:
215
+ limit -= pad
216
+ assert limit - start > 0, (limit, start)
217
+ assert limit <= Fr, (limit, Fr)
218
+ y = x[:, :, start:limit, :]
219
+ if start == 0:
220
+ y = F.pad(y, (0, 0, pad, 0))
221
+ if ratio == 1:
222
+ y = F.pad(y, (0, 0, 0, pad))
223
+ outs.append(layer(y))
224
+ start = limit - layer.kernel_size + layer.stride
225
+ else:
226
+ if ratio == 1:
227
+ limit = Fr
228
+ else:
229
+ limit = int(round(Fr * ratio))
230
+ last = layer.last
231
+ layer.last = True
232
+
233
+ y = x[:, :, start:limit]
234
+ s = skip[:, :, start:limit]
235
+ out, _ = layer(y, s, None)
236
+ if outs:
237
+ outs[-1][:, :, -layer.stride:] += (
238
+ out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1))
239
+ out = out[:, :, layer.stride:]
240
+ if ratio == 1:
241
+ out = out[:, :, :-layer.stride // 2, :]
242
+ if start == 0:
243
+ out = out[:, :, layer.stride // 2:, :]
244
+ outs.append(out)
245
+ layer.last = last
246
+ start = limit
247
+ out = torch.cat(outs, dim=2)
248
+ if not self.conv and not last:
249
+ out = F.gelu(out)
250
+ if self.conv:
251
+ return out
252
+ else:
253
+ return out, None
254
+
255
+
256
+ class HDecLayer(nn.Module):
257
+ def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,
258
+ freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True,
259
+ context_freq=True, rewrite=True):
260
+ """
261
+ Same as HEncLayer but for decoder. See `HEncLayer` for documentation.
262
+ """
263
+ super().__init__()
264
+ norm_fn = lambda d: nn.Identity() # noqa
265
+ if norm:
266
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
267
+ if pad:
268
+ pad = kernel_size // 4
269
+ else:
270
+ pad = 0
271
+ self.pad = pad
272
+ self.last = last
273
+ self.freq = freq
274
+ self.chin = chin
275
+ self.empty = empty
276
+ self.stride = stride
277
+ self.kernel_size = kernel_size
278
+ self.norm = norm
279
+ self.context_freq = context_freq
280
+ klass = nn.Conv1d
281
+ klass_tr = nn.ConvTranspose1d
282
+ if freq:
283
+ kernel_size = [kernel_size, 1]
284
+ stride = [stride, 1]
285
+ klass = nn.Conv2d
286
+ klass_tr = nn.ConvTranspose2d
287
+ self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
288
+ self.norm2 = norm_fn(chout)
289
+ if self.empty:
290
+ return
291
+ self.rewrite = None
292
+ if rewrite:
293
+ if context_freq:
294
+ self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
295
+ else:
296
+ self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,
297
+ [0, context])
298
+ self.norm1 = norm_fn(2 * chin)
299
+
300
+ self.dconv = None
301
+ if dconv:
302
+ self.dconv = DConv(chin, **dconv_kw)
303
+
304
+ def forward(self, x, skip, length):
305
+ if self.freq and x.dim() == 3:
306
+ B, C, T = x.shape
307
+ x = x.view(B, self.chin, -1, T)
308
+
309
+ if not self.empty:
310
+ x = x + skip
311
+
312
+ if self.rewrite:
313
+ y = F.glu(self.norm1(self.rewrite(x)), dim=1)
314
+ else:
315
+ y = x
316
+ if self.dconv:
317
+ if self.freq:
318
+ B, C, Fr, T = y.shape
319
+ y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
320
+ y = self.dconv(y)
321
+ if self.freq:
322
+ y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
323
+ else:
324
+ y = x
325
+ assert skip is None
326
+ z = self.norm2(self.conv_tr(y))
327
+ if self.freq:
328
+ if self.pad:
329
+ z = z[..., self.pad:-self.pad, :]
330
+ else:
331
+ z = z[..., self.pad:self.pad + length]
332
+ assert z.shape[-1] == length, (z.shape[-1], length)
333
+ if not self.last:
334
+ z = F.gelu(z)
335
+ return z, y
336
+
337
+
338
+ class HDemucs(nn.Module):
339
+ """
340
+ Spectrogram and hybrid Demucs model.
341
+ The spectrogram model has the same structure as Demucs, except the first few layers are over the
342
+ frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
343
+ Frequency layers can still access information across time steps thanks to the DConv residual.
344
+
345
+ Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
346
+ as the frequency branch and then the two are combined. The opposite happens in the decoder.
347
+
348
+ Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
349
+ or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
350
+ Open Unmix implementation [Stoter et al. 2019].
351
+
352
+ The loss is always on the temporal domain, by backpropagating through the above
353
+ output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks
354
+ a bit Wiener filtering, as doing more iteration at test time will change the spectrogram
355
+ contribution, without changing the one from the waveform, which will lead to worse performance.
356
+ I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
357
+ CaC on the other hand provides similar performance for hybrid, and works naturally with
358
+ hybrid models.
359
+
360
+ This model also uses frequency embeddings to improve efficiency of convolutions
361
+ over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
362
+
363
+ Unlike classic Demucs, there is no resampling here, and normalization is always applied.
364
+ """
365
+ @capture_init
366
+ def __init__(self,
367
+ sources,
368
+ # Channels
369
+ audio_channels=2,
370
+ channels=48,
371
+ channels_time=None,
372
+ growth=2,
373
+ # STFT
374
+ nfft=4096,
375
+ wiener_iters=0,
376
+ end_iters=0,
377
+ wiener_residual=False,
378
+ cac=True,
379
+ # Main structure
380
+ depth=6,
381
+ rewrite=True,
382
+ hybrid=True,
383
+ hybrid_old=False,
384
+ # Frequency branch
385
+ multi_freqs=None,
386
+ multi_freqs_depth=2,
387
+ freq_emb=0.2,
388
+ emb_scale=10,
389
+ emb_smooth=True,
390
+ # Convolutions
391
+ kernel_size=8,
392
+ time_stride=2,
393
+ stride=4,
394
+ context=1,
395
+ context_enc=0,
396
+ # Normalization
397
+ norm_starts=4,
398
+ norm_groups=4,
399
+ # DConv residual branch
400
+ dconv_mode=1,
401
+ dconv_depth=2,
402
+ dconv_comp=4,
403
+ dconv_attn=4,
404
+ dconv_lstm=4,
405
+ dconv_init=1e-4,
406
+ # Weight init
407
+ rescale=0.1,
408
+ # Metadata
409
+ samplerate=44100,
410
+ segment=4 * 10):
411
+ """
412
+ Args:
413
+ sources (list[str]): list of source names.
414
+ audio_channels (int): input/output audio channels.
415
+ channels (int): initial number of hidden channels.
416
+ channels_time: if not None, use a different `channels` value for the time branch.
417
+ growth: increase the number of hidden channels by this factor at each layer.
418
+ nfft: number of fft bins. Note that changing this require careful computation of
419
+ various shape parameters and will not work out of the box for hybrid models.
420
+ wiener_iters: when using Wiener filtering, number of iterations at test time.
421
+ end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
422
+ wiener_residual: add residual source before wiener filtering.
423
+ cac: uses complex as channels, i.e. complex numbers are 2 channels each
424
+ in input and output. no further processing is done before ISTFT.
425
+ depth (int): number of layers in the encoder and in the decoder.
426
+ rewrite (bool): add 1x1 convolution to each layer.
427
+ hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only.
428
+ hybrid_old: some models trained for MDX had a padding bug. This replicates
429
+ this bug to avoid retraining them.
430
+ multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
431
+ multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
432
+ layers will be wrapped.
433
+ freq_emb: add frequency embedding after the first frequency layer if > 0,
434
+ the actual value controls the weight of the embedding.
435
+ emb_scale: equivalent to scaling the embedding learning rate
436
+ emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
437
+ kernel_size: kernel_size for encoder and decoder layers.
438
+ stride: stride for encoder and decoder layers.
439
+ time_stride: stride for the final time layer, after the merge.
440
+ context: context for 1x1 conv in the decoder.
441
+ context_enc: context for 1x1 conv in the encoder.
442
+ norm_starts: layer at which group norm starts being used.
443
+ decoder layers are numbered in reverse order.
444
+ norm_groups: number of groups for group norm.
445
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
446
+ dconv_depth: depth of residual DConv branch.
447
+ dconv_comp: compression of DConv branch.
448
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
449
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
450
+ dconv_init: initial scale for the DConv branch LayerScale.
451
+ rescale: weight rescaling trick
452
+
453
+ """
454
+ super().__init__()
455
+ self.cac = cac
456
+ self.wiener_residual = wiener_residual
457
+ self.audio_channels = audio_channels
458
+ self.sources = sources
459
+ self.kernel_size = kernel_size
460
+ self.context = context
461
+ self.stride = stride
462
+ self.depth = depth
463
+ self.channels = channels
464
+ self.samplerate = samplerate
465
+ self.segment = segment
466
+
467
+ self.nfft = nfft
468
+ self.hop_length = nfft // 4
469
+ self.wiener_iters = wiener_iters
470
+ self.end_iters = end_iters
471
+ self.freq_emb = None
472
+ self.hybrid = hybrid
473
+ self.hybrid_old = hybrid_old
474
+ if hybrid_old:
475
+ assert hybrid, "hybrid_old must come with hybrid=True"
476
+ if hybrid:
477
+ assert wiener_iters == end_iters
478
+
479
+ self.encoder = nn.ModuleList()
480
+ self.decoder = nn.ModuleList()
481
+
482
+ if hybrid:
483
+ self.tencoder = nn.ModuleList()
484
+ self.tdecoder = nn.ModuleList()
485
+
486
+ chin = audio_channels
487
+ chin_z = chin # number of channels for the freq branch
488
+ if self.cac:
489
+ chin_z *= 2
490
+ chout = channels_time or channels
491
+ chout_z = channels
492
+ freqs = nfft // 2
493
+
494
+ for index in range(depth):
495
+ lstm = index >= dconv_lstm
496
+ attn = index >= dconv_attn
497
+ norm = index >= norm_starts
498
+ freq = freqs > 1
499
+ stri = stride
500
+ ker = kernel_size
501
+ if not freq:
502
+ assert freqs == 1
503
+ ker = time_stride * 2
504
+ stri = time_stride
505
+
506
+ pad = True
507
+ last_freq = False
508
+ if freq and freqs <= kernel_size:
509
+ ker = freqs
510
+ pad = False
511
+ last_freq = True
512
+
513
+ kw = {
514
+ 'kernel_size': ker,
515
+ 'stride': stri,
516
+ 'freq': freq,
517
+ 'pad': pad,
518
+ 'norm': norm,
519
+ 'rewrite': rewrite,
520
+ 'norm_groups': norm_groups,
521
+ 'dconv_kw': {
522
+ 'lstm': lstm,
523
+ 'attn': attn,
524
+ 'depth': dconv_depth,
525
+ 'compress': dconv_comp,
526
+ 'init': dconv_init,
527
+ 'gelu': True,
528
+ }
529
+ }
530
+ kwt = dict(kw)
531
+ kwt['freq'] = 0
532
+ kwt['kernel_size'] = kernel_size
533
+ kwt['stride'] = stride
534
+ kwt['pad'] = True
535
+ kw_dec = dict(kw)
536
+ multi = False
537
+ if multi_freqs and index < multi_freqs_depth:
538
+ multi = True
539
+ kw_dec['context_freq'] = False
540
+
541
+ if last_freq:
542
+ chout_z = max(chout, chout_z)
543
+ chout = chout_z
544
+
545
+ enc = HEncLayer(chin_z, chout_z,
546
+ dconv=dconv_mode & 1, context=context_enc, **kw)
547
+ if hybrid and freq:
548
+ tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc,
549
+ empty=last_freq, **kwt)
550
+ self.tencoder.append(tenc)
551
+
552
+ if multi:
553
+ enc = MultiWrap(enc, multi_freqs)
554
+ self.encoder.append(enc)
555
+ if index == 0:
556
+ chin = self.audio_channels * len(self.sources)
557
+ chin_z = chin
558
+ if self.cac:
559
+ chin_z *= 2
560
+ dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2,
561
+ last=index == 0, context=context, **kw_dec)
562
+ if multi:
563
+ dec = MultiWrap(dec, multi_freqs)
564
+ if hybrid and freq:
565
+ tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq,
566
+ last=index == 0, context=context, **kwt)
567
+ self.tdecoder.insert(0, tdec)
568
+ self.decoder.insert(0, dec)
569
+
570
+ chin = chout
571
+ chin_z = chout_z
572
+ chout = int(growth * chout)
573
+ chout_z = int(growth * chout_z)
574
+ if freq:
575
+ if freqs <= kernel_size:
576
+ freqs = 1
577
+ else:
578
+ freqs //= stride
579
+ if index == 0 and freq_emb:
580
+ self.freq_emb = ScaledEmbedding(
581
+ freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
582
+ self.freq_emb_scale = freq_emb
583
+
584
+ if rescale:
585
+ rescale_module(self, reference=rescale)
586
+
587
+ def _spec(self, x):
588
+ hl = self.hop_length
589
+ nfft = self.nfft
590
+ x0 = x # noqa
591
+
592
+ if self.hybrid:
593
+ # We re-pad the signal in order to keep the property
594
+ # that the size of the output is exactly the size of the input
595
+ # divided by the stride (here hop_length), when divisible.
596
+ # This is achieved by padding by 1/4th of the kernel size (here nfft).
597
+ # which is not supported by torch.stft.
598
+ # Having all convolution operations follow this convention allow to easily
599
+ # align the time and frequency branches later on.
600
+ assert hl == nfft // 4
601
+ le = int(math.ceil(x.shape[-1] / hl))
602
+ pad = hl // 2 * 3
603
+ if not self.hybrid_old:
604
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode='reflect')
605
+ else:
606
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]))
607
+
608
+ z = spectro(x, nfft, hl)[..., :-1, :]
609
+ if self.hybrid:
610
+ assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
611
+ z = z[..., 2:2+le]
612
+ return z
613
+
614
+ def _ispec(self, z, length=None, scale=0):
615
+ hl = self.hop_length // (4 ** scale)
616
+ z = F.pad(z, (0, 0, 0, 1))
617
+ if self.hybrid:
618
+ z = F.pad(z, (2, 2))
619
+ pad = hl // 2 * 3
620
+ if not self.hybrid_old:
621
+ le = hl * int(math.ceil(length / hl)) + 2 * pad
622
+ else:
623
+ le = hl * int(math.ceil(length / hl))
624
+ x = ispectro(z, hl, length=le)
625
+ if not self.hybrid_old:
626
+ x = x[..., pad:pad + length]
627
+ else:
628
+ x = x[..., :length]
629
+ else:
630
+ x = ispectro(z, hl, length)
631
+ return x
632
+
633
+ def _magnitude(self, z):
634
+ # return the magnitude of the spectrogram, except when cac is True,
635
+ # in which case we just move the complex dimension to the channel one.
636
+ if self.cac:
637
+ B, C, Fr, T = z.shape
638
+ m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
639
+ m = m.reshape(B, C * 2, Fr, T)
640
+ else:
641
+ m = z.abs()
642
+ return m
643
+
644
+ def _mask(self, z, m):
645
+ # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
646
+ # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
647
+ niters = self.wiener_iters
648
+ if self.cac:
649
+ B, S, C, Fr, T = m.shape
650
+ out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
651
+ out = torch.view_as_complex(out.contiguous())
652
+ return out
653
+ if self.training:
654
+ niters = self.end_iters
655
+ if niters < 0:
656
+ z = z[:, None]
657
+ return z / (1e-8 + z.abs()) * m
658
+ else:
659
+ return self._wiener(m, z, niters)
660
+
661
+ def _wiener(self, mag_out, mix_stft, niters):
662
+ # apply wiener filtering from OpenUnmix.
663
+ init = mix_stft.dtype
664
+ wiener_win_len = 300
665
+ residual = self.wiener_residual
666
+
667
+ B, S, C, Fq, T = mag_out.shape
668
+ mag_out = mag_out.permute(0, 4, 3, 2, 1)
669
+ mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
670
+
671
+ outs = []
672
+ for sample in range(B):
673
+ pos = 0
674
+ out = []
675
+ for pos in range(0, T, wiener_win_len):
676
+ frame = slice(pos, pos + wiener_win_len)
677
+ z_out = wiener(
678
+ mag_out[sample, frame], mix_stft[sample, frame], niters,
679
+ residual=residual)
680
+ out.append(z_out.transpose(-1, -2))
681
+ outs.append(torch.cat(out, dim=0))
682
+ out = torch.view_as_complex(torch.stack(outs, 0))
683
+ out = out.permute(0, 4, 3, 2, 1).contiguous()
684
+ if residual:
685
+ out = out[:, :-1]
686
+ assert list(out.shape) == [B, S, C, Fq, T]
687
+ return out.to(init)
688
+
689
+ def forward(self, mix):
690
+ x = mix
691
+ length = x.shape[-1]
692
+
693
+ z = self._spec(mix)
694
+ mag = self._magnitude(z)
695
+ x = mag
696
+
697
+ B, C, Fq, T = x.shape
698
+
699
+ # unlike previous Demucs, we always normalize because it is easier.
700
+ mean = x.mean(dim=(1, 2, 3), keepdim=True)
701
+ std = x.std(dim=(1, 2, 3), keepdim=True)
702
+ x = (x - mean) / (1e-5 + std)
703
+ # x will be the freq. branch input.
704
+
705
+ if self.hybrid:
706
+ # Prepare the time branch input.
707
+ xt = mix
708
+ meant = xt.mean(dim=(1, 2), keepdim=True)
709
+ stdt = xt.std(dim=(1, 2), keepdim=True)
710
+ xt = (xt - meant) / (1e-5 + stdt)
711
+
712
+ # okay, this is a giant mess I know...
713
+ saved = [] # skip connections, freq.
714
+ saved_t = [] # skip connections, time.
715
+ lengths = [] # saved lengths to properly remove padding, freq branch.
716
+ lengths_t = [] # saved lengths for time branch.
717
+ for idx, encode in enumerate(self.encoder):
718
+ lengths.append(x.shape[-1])
719
+ inject = None
720
+ if self.hybrid and idx < len(self.tencoder):
721
+ # we have not yet merged branches.
722
+ lengths_t.append(xt.shape[-1])
723
+ tenc = self.tencoder[idx]
724
+ xt = tenc(xt)
725
+ if not tenc.empty:
726
+ # save for skip connection
727
+ saved_t.append(xt)
728
+ else:
729
+ # tenc contains just the first conv., so that now time and freq.
730
+ # branches have the same shape and can be merged.
731
+ inject = xt
732
+ x = encode(x, inject)
733
+ if idx == 0 and self.freq_emb is not None:
734
+ # add frequency embedding to allow for non equivariant convolutions
735
+ # over the frequency axis.
736
+ frs = torch.arange(x.shape[-2], device=x.device)
737
+ emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
738
+ x = x + self.freq_emb_scale * emb
739
+
740
+ saved.append(x)
741
+
742
+ x = torch.zeros_like(x)
743
+ if self.hybrid:
744
+ xt = torch.zeros_like(x)
745
+ # initialize everything to zero (signal will go through u-net skips).
746
+
747
+ for idx, decode in enumerate(self.decoder):
748
+ skip = saved.pop(-1)
749
+ x, pre = decode(x, skip, lengths.pop(-1))
750
+ # `pre` contains the output just before final transposed convolution,
751
+ # which is used when the freq. and time branch separate.
752
+
753
+ if self.hybrid:
754
+ offset = self.depth - len(self.tdecoder)
755
+ if self.hybrid and idx >= offset:
756
+ tdec = self.tdecoder[idx - offset]
757
+ length_t = lengths_t.pop(-1)
758
+ if tdec.empty:
759
+ assert pre.shape[2] == 1, pre.shape
760
+ pre = pre[:, :, 0]
761
+ xt, _ = tdec(pre, None, length_t)
762
+ else:
763
+ skip = saved_t.pop(-1)
764
+ xt, _ = tdec(xt, skip, length_t)
765
+
766
+ # Let's make sure we used all stored skip connections.
767
+ assert len(saved) == 0
768
+ assert len(lengths_t) == 0
769
+ assert len(saved_t) == 0
770
+
771
+ S = len(self.sources)
772
+ x = x.view(B, S, -1, Fq, T)
773
+ x = x * std[:, None] + mean[:, None]
774
+
775
+ zout = self._mask(z, x)
776
+ x = self._ispec(zout, length)
777
+
778
+ if self.hybrid:
779
+ xt = xt.view(B, S, -1, length)
780
+ xt = xt * stdt[:, None] + meant[:, None]
781
+ x = xt + x
782
+ return x
demucs3/htdemucs.py ADDED
@@ -0,0 +1,648 @@
1
+ # Copyright (c) Meta, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # First author is Simon Rouard.
7
+ """
8
+ This code contains the spectrogram and Hybrid version of Demucs.
9
+ """
10
+ import math
11
+
12
+ from openunmix.filtering import wiener
13
+ import torch
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+ from fractions import Fraction
17
+ from einops import rearrange
18
+
19
+ from .transformer import CrossTransformerEncoder
20
+
21
+ from .demucs import rescale_module
22
+ from .states import capture_init
23
+ from .spec import spectro, ispectro
24
+ from .hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer
25
+
26
+
27
+ class HTDemucs(nn.Module):
28
+ """
29
+ Spectrogram and hybrid Demucs model.
30
+ The spectrogram model has the same structure as Demucs, except the first few layers are over the
31
+ frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
32
+ Frequency layers can still access information across time steps thanks to the DConv residual.
33
+
34
+ Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
35
+ as the frequency branch and then the two are combined. The opposite happens in the decoder.
36
+
37
+ Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
38
+ or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
39
+ Open Unmix implementation [Stoter et al. 2019].
40
+
41
+ The loss is always in the time domain, obtained by backpropagating through the above
42
+ output methods and the iSTFT. This makes it easy to define hybrid models. However, it slightly
43
+ breaks Wiener filtering, as doing more iterations at test time changes the spectrogram
44
+ contribution without changing the waveform one, which leads to worse performance.
45
+ I tried using the residual option in the OpenUnmix Wiener implementation, but it didn't help.
46
+ CaC, on the other hand, provides similar performance and works naturally with
47
+ hybrid models.
48
+
49
+ This model also uses frequency embeddings to improve the efficiency of convolutions
50
+ over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
51
+
52
+ Unlike classic Demucs, there is no resampling here, and normalization is always applied.
53
+ """
54
+
55
+ @capture_init
56
+ def __init__(
57
+ self,
58
+ sources,
59
+ # Channels
60
+ audio_channels=2,
61
+ channels=48,
62
+ channels_time=None,
63
+ growth=2,
64
+ # STFT
65
+ nfft=4096,
66
+ wiener_iters=0,
67
+ end_iters=0,
68
+ wiener_residual=False,
69
+ cac=True,
70
+ # Main structure
71
+ depth=4,
72
+ rewrite=True,
73
+ # Frequency branch
74
+ multi_freqs=None,
75
+ multi_freqs_depth=3,
76
+ freq_emb=0.2,
77
+ emb_scale=10,
78
+ emb_smooth=True,
79
+ # Convolutions
80
+ kernel_size=8,
81
+ time_stride=2,
82
+ stride=4,
83
+ context=1,
84
+ context_enc=0,
85
+ # Normalization
86
+ norm_starts=4,
87
+ norm_groups=4,
88
+ # DConv residual branch
89
+ dconv_mode=1,
90
+ dconv_depth=2,
91
+ dconv_comp=8,
92
+ dconv_init=1e-3,
93
+ # Before the Transformer
94
+ bottom_channels=0,
95
+ # Transformer
96
+ t_layers=5,
97
+ t_emb="sin",
98
+ t_hidden_scale=4.0,
99
+ t_heads=8,
100
+ t_dropout=0.0,
101
+ t_max_positions=10000,
102
+ t_norm_in=True,
103
+ t_norm_in_group=False,
104
+ t_group_norm=False,
105
+ t_norm_first=True,
106
+ t_norm_out=True,
107
+ t_max_period=10000.0,
108
+ t_weight_decay=0.0,
109
+ t_lr=None,
110
+ t_layer_scale=True,
111
+ t_gelu=True,
112
+ t_weight_pos_embed=1.0,
113
+ t_sin_random_shift=0,
114
+ t_cape_mean_normalize=True,
115
+ t_cape_augment=True,
116
+ t_cape_glob_loc_scale=[5000.0, 1.0, 1.4],
117
+ t_sparse_self_attn=False,
118
+ t_sparse_cross_attn=False,
119
+ t_mask_type="diag",
120
+ t_mask_random_seed=42,
121
+ t_sparse_attn_window=500,
122
+ t_global_window=100,
123
+ t_sparsity=0.95,
124
+ t_auto_sparsity=False,
125
+ # ------ Particular parameters
126
+ t_cross_first=False,
127
+ # Weight init
128
+ rescale=0.1,
129
+ # Metadata
130
+ samplerate=44100,
131
+ segment=10,
132
+ use_train_segment=True,
133
+ ):
134
+ """
135
+ Args:
136
+ sources (list[str]): list of source names.
137
+ audio_channels (int): input/output audio channels.
138
+ channels (int): initial number of hidden channels.
139
+ channels_time: if not None, use a different `channels` value for the time branch.
140
+ growth: increase the number of hidden channels by this factor at each layer.
141
+ nfft: number of fft bins. Note that changing this requires careful computation of
142
+ various shape parameters and will not work out of the box for hybrid models.
143
+ wiener_iters: when using Wiener filtering, number of iterations at test time.
144
+ end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
145
+ wiener_residual: add residual source before wiener filtering.
146
+ cac: uses complex as channels, i.e. complex numbers are 2 channels each
147
+ in input and output. no further processing is done before ISTFT.
148
+ depth (int): number of layers in the encoder and in the decoder.
149
+ rewrite (bool): add 1x1 convolution to each layer.
150
+ multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
151
+ multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
152
+ layers will be wrapped.
153
+ freq_emb: add frequency embedding after the first frequency layer if > 0,
154
+ the actual value controls the weight of the embedding.
155
+ emb_scale: equivalent to scaling the embedding learning rate
156
+ emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
157
+ kernel_size: kernel_size for encoder and decoder layers.
158
+ stride: stride for encoder and decoder layers.
159
+ time_stride: stride for the final time layer, after the merge.
160
+ context: context for 1x1 conv in the decoder.
161
+ context_enc: context for 1x1 conv in the encoder.
162
+ norm_starts: layer at which group norm starts being used.
163
+ decoder layers are numbered in reverse order.
164
+ norm_groups: number of groups for group norm.
165
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
166
+ dconv_depth: depth of residual DConv branch.
167
+ dconv_comp: compression of DConv branch.
168
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
169
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
170
+ dconv_init: initial scale for the DConv branch LayerScale.
171
+ bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the
172
+ transformer in order to change the number of channels
173
+ t_layers: number of layers in each branch (waveform and spec) of the transformer
174
+ t_emb: "sin", "cape" or "scaled"
175
+ t_hidden_scale: the hidden scale of the Feedforward parts of the transformer
176
+ for instance if C = 384 (the number of channels in the transformer) and
177
+ t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
178
+ 384 * 4 = 1536
179
+ t_heads: number of heads for the transformer
180
+ t_dropout: dropout in the transformer
181
+ t_max_positions: max_positions for the "scaled" positional embedding, only
182
+ useful if t_emb="scaled"
183
+ t_norm_in: (bool) norm before adding the positional embedding and getting into the
184
+ transformer layers
185
+ t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the
186
+ timesteps (GroupNorm with group=1)
187
+ t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the
188
+ timesteps (GroupNorm with group=1)
189
+ t_norm_first: (bool) if True the norm is before the attention and before the FFN
190
+ t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
191
+ t_max_period: (float) denominator in the sinusoidal embedding expression
192
+ t_weight_decay: (float) weight decay for the transformer
193
+ t_lr: (float) specific learning rate for the transformer
194
+ t_layer_scale: (bool) Layer Scale for the transformer
195
+ t_gelu: (bool) activations of the transformer are GELU if True, ReLU otherwise
196
+ t_weight_pos_embed: (float) weighting of the positional embedding
197
+ t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings
198
+ see: https://arxiv.org/abs/2106.03143
199
+ t_cape_augment: (bool) if t_emb="cape", must be True during training and False
200
+ during the inference, see: https://arxiv.org/abs/2106.03143
201
+ t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
202
+ see: https://arxiv.org/abs/2106.03143
203
+ t_sparse_self_attn: (bool) if True, the self attentions are sparse
204
+ t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
205
+ unless you designed really specific masks)
206
+ t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
207
+ with '_' between: i.e. "diag_jmask_random" (note that this is permutation
208
+ invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
209
+ t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
210
+ that generated the random part of the mask
211
+ t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and
212
+ a key (j), the mask is True if |i-j|<=t_sparse_attn_window
213
+ t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
214
+ and mask[:, :t_global_window] will be True
215
+ t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
216
+ level of the random part of the mask.
217
+ t_cross_first: (bool) if True cross attention is the first layer of the
218
+ transformer (False seems to be better)
219
+ rescale: weight rescaling trick
220
+ use_train_segment: (bool) if True, the actual size that is used during the
221
+ training is used during inference.
222
+ """
223
+ super().__init__()
224
+ self.cac = cac
225
+ self.wiener_residual = wiener_residual
226
+ self.audio_channels = audio_channels
227
+ self.sources = sources
228
+ self.kernel_size = kernel_size
229
+ self.context = context
230
+ self.stride = stride
231
+ self.depth = depth
232
+ self.bottom_channels = bottom_channels
233
+ self.channels = channels
234
+ self.samplerate = samplerate
235
+ self.segment = segment
236
+ self.use_train_segment = use_train_segment
237
+ self.nfft = nfft
238
+ self.hop_length = nfft // 4
239
+ self.wiener_iters = wiener_iters
240
+ self.end_iters = end_iters
241
+ self.freq_emb = None
242
+ assert wiener_iters == end_iters
243
+
244
+ self.encoder = nn.ModuleList()
245
+ self.decoder = nn.ModuleList()
246
+
247
+ self.tencoder = nn.ModuleList()
248
+ self.tdecoder = nn.ModuleList()
249
+
250
+ chin = audio_channels
251
+ chin_z = chin # number of channels for the freq branch
252
+ if self.cac:
253
+ chin_z *= 2
254
+ chout = channels_time or channels
255
+ chout_z = channels
256
+ freqs = nfft // 2
257
+
258
+ for index in range(depth):
259
+ norm = index >= norm_starts
260
+ freq = freqs > 1
261
+ stri = stride
262
+ ker = kernel_size
263
+ if not freq:
264
+ assert freqs == 1
265
+ ker = time_stride * 2
266
+ stri = time_stride
267
+
268
+ pad = True
269
+ last_freq = False
270
+ if freq and freqs <= kernel_size:
271
+ ker = freqs
272
+ pad = False
273
+ last_freq = True
274
+
275
+ kw = {
276
+ "kernel_size": ker,
277
+ "stride": stri,
278
+ "freq": freq,
279
+ "pad": pad,
280
+ "norm": norm,
281
+ "rewrite": rewrite,
282
+ "norm_groups": norm_groups,
283
+ "dconv_kw": {
284
+ "depth": dconv_depth,
285
+ "compress": dconv_comp,
286
+ "init": dconv_init,
287
+ "gelu": True,
288
+ },
289
+ }
290
+ kwt = dict(kw)
291
+ kwt["freq"] = 0
292
+ kwt["kernel_size"] = kernel_size
293
+ kwt["stride"] = stride
294
+ kwt["pad"] = True
295
+ kw_dec = dict(kw)
296
+ multi = False
297
+ if multi_freqs and index < multi_freqs_depth:
298
+ multi = True
299
+ kw_dec["context_freq"] = False
300
+
301
+ if last_freq:
302
+ chout_z = max(chout, chout_z)
303
+ chout = chout_z
304
+
305
+ enc = HEncLayer(
306
+ chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw
307
+ )
308
+ if freq:
309
+ tenc = HEncLayer(
310
+ chin,
311
+ chout,
312
+ dconv=dconv_mode & 1,
313
+ context=context_enc,
314
+ empty=last_freq,
315
+ **kwt
316
+ )
317
+ self.tencoder.append(tenc)
318
+
319
+ if multi:
320
+ enc = MultiWrap(enc, multi_freqs)
321
+ self.encoder.append(enc)
322
+ if index == 0:
323
+ chin = self.audio_channels * len(self.sources)
324
+ chin_z = chin
325
+ if self.cac:
326
+ chin_z *= 2
327
+ dec = HDecLayer(
328
+ chout_z,
329
+ chin_z,
330
+ dconv=dconv_mode & 2,
331
+ last=index == 0,
332
+ context=context,
333
+ **kw_dec
334
+ )
335
+ if multi:
336
+ dec = MultiWrap(dec, multi_freqs)
337
+ if freq:
338
+ tdec = HDecLayer(
339
+ chout,
340
+ chin,
341
+ dconv=dconv_mode & 2,
342
+ empty=last_freq,
343
+ last=index == 0,
344
+ context=context,
345
+ **kwt
346
+ )
347
+ self.tdecoder.insert(0, tdec)
348
+ self.decoder.insert(0, dec)
349
+
350
+ chin = chout
351
+ chin_z = chout_z
352
+ chout = int(growth * chout)
353
+ chout_z = int(growth * chout_z)
354
+ if freq:
355
+ if freqs <= kernel_size:
356
+ freqs = 1
357
+ else:
358
+ freqs //= stride
359
+ if index == 0 and freq_emb:
360
+ self.freq_emb = ScaledEmbedding(
361
+ freqs, chin_z, smooth=emb_smooth, scale=emb_scale
362
+ )
363
+ self.freq_emb_scale = freq_emb
364
+
365
+ if rescale:
366
+ rescale_module(self, reference=rescale)
367
+
368
+ transformer_channels = channels * growth ** (depth - 1)
369
+ if bottom_channels:
370
+ self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
371
+ self.channel_downsampler = nn.Conv1d(
372
+ bottom_channels, transformer_channels, 1
373
+ )
374
+ self.channel_upsampler_t = nn.Conv1d(
375
+ transformer_channels, bottom_channels, 1
376
+ )
377
+ self.channel_downsampler_t = nn.Conv1d(
378
+ bottom_channels, transformer_channels, 1
379
+ )
380
+
381
+ transformer_channels = bottom_channels
382
+
383
+ if t_layers > 0:
384
+ self.crosstransformer = CrossTransformerEncoder(
385
+ dim=transformer_channels,
386
+ emb=t_emb,
387
+ hidden_scale=t_hidden_scale,
388
+ num_heads=t_heads,
389
+ num_layers=t_layers,
390
+ cross_first=t_cross_first,
391
+ dropout=t_dropout,
392
+ max_positions=t_max_positions,
393
+ norm_in=t_norm_in,
394
+ norm_in_group=t_norm_in_group,
395
+ group_norm=t_group_norm,
396
+ norm_first=t_norm_first,
397
+ norm_out=t_norm_out,
398
+ max_period=t_max_period,
399
+ weight_decay=t_weight_decay,
400
+ lr=t_lr,
401
+ layer_scale=t_layer_scale,
402
+ gelu=t_gelu,
403
+ sin_random_shift=t_sin_random_shift,
404
+ weight_pos_embed=t_weight_pos_embed,
405
+ cape_mean_normalize=t_cape_mean_normalize,
406
+ cape_augment=t_cape_augment,
407
+ cape_glob_loc_scale=t_cape_glob_loc_scale,
408
+ sparse_self_attn=t_sparse_self_attn,
409
+ sparse_cross_attn=t_sparse_cross_attn,
410
+ mask_type=t_mask_type,
411
+ mask_random_seed=t_mask_random_seed,
412
+ sparse_attn_window=t_sparse_attn_window,
413
+ global_window=t_global_window,
414
+ sparsity=t_sparsity,
415
+ auto_sparsity=t_auto_sparsity,
416
+ )
417
+ else:
418
+ self.crosstransformer = None
419
+
420
+ def _spec(self, x):
421
+ hl = self.hop_length
422
+ nfft = self.nfft
423
+ x0 = x # noqa
424
+
425
+ # We re-pad the signal in order to keep the property
426
+ # that the size of the output is exactly the size of the input
427
+ # divided by the stride (here hop_length), when divisible.
428
+ # This is achieved by padding by 1/4th of the kernel size (here nfft),
429
+ # which is not supported by torch.stft.
430
+ # Having all convolution operations follow this convention allows us to easily
431
+ # align the time and frequency branches later on.
432
+ assert hl == nfft // 4
433
+ le = int(math.ceil(x.shape[-1] / hl))
434
+ pad = hl // 2 * 3
435
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")
436
+
437
+ z = spectro(x, nfft, hl)[..., :-1, :]
438
+ assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
439
+ z = z[..., 2: 2 + le]
440
+ return z
441
+
442
+ def _ispec(self, z, length=None, scale=0):
443
+ hl = self.hop_length // (4**scale)
444
+ z = F.pad(z, (0, 0, 0, 1))
445
+ z = F.pad(z, (2, 2))
446
+ pad = hl // 2 * 3
447
+ le = hl * int(math.ceil(length / hl)) + 2 * pad
448
+ x = ispectro(z, hl, length=le)
449
+ x = x[..., pad: pad + length]
450
+ return x
451
+
452
+ def _magnitude(self, z):
453
+ # return the magnitude of the spectrogram, except when cac is True,
454
+ # in which case we just move the complex dimension to the channel one.
455
+ if self.cac:
456
+ B, C, Fr, T = z.shape
457
+ m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
458
+ m = m.reshape(B, C * 2, Fr, T)
459
+ else:
460
+ m = z.abs()
461
+ return m
462
+
463
+ def _mask(self, z, m):
464
+ # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
465
+ # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
466
+ niters = self.wiener_iters
467
+ if self.cac:
468
+ B, S, C, Fr, T = m.shape
469
+ out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
470
+ out = torch.view_as_complex(out.contiguous())
471
+ return out
472
+ if self.training:
473
+ niters = self.end_iters
474
+ if niters < 0:
475
+ z = z[:, None]
476
+ return z / (1e-8 + z.abs()) * m
477
+ else:
478
+ return self._wiener(m, z, niters)
479
+
480
+ def _wiener(self, mag_out, mix_stft, niters):
481
+ # apply wiener filtering from OpenUnmix.
482
+ init = mix_stft.dtype
483
+ wiener_win_len = 300
484
+ residual = self.wiener_residual
485
+
486
+ B, S, C, Fq, T = mag_out.shape
487
+ mag_out = mag_out.permute(0, 4, 3, 2, 1)
488
+ mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
489
+
490
+ outs = []
491
+ for sample in range(B):
492
+ pos = 0
493
+ out = []
494
+ for pos in range(0, T, wiener_win_len):
495
+ frame = slice(pos, pos + wiener_win_len)
496
+ z_out = wiener(
497
+ mag_out[sample, frame],
498
+ mix_stft[sample, frame],
499
+ niters,
500
+ residual=residual,
501
+ )
502
+ out.append(z_out.transpose(-1, -2))
503
+ outs.append(torch.cat(out, dim=0))
504
+ out = torch.view_as_complex(torch.stack(outs, 0))
505
+ out = out.permute(0, 4, 3, 2, 1).contiguous()
506
+ if residual:
507
+ out = out[:, :-1]
508
+ assert list(out.shape) == [B, S, C, Fq, T]
509
+ return out.to(init)
510
+
511
+ def valid_length(self, length: int):
512
+ """
513
+ Return a length that is appropriate for evaluation.
514
+ In our case, always return the training length, unless
515
+ it is smaller than the given length, in which case this
516
+ raises an error.
517
+ """
518
+ if not self.use_train_segment:
519
+ return length
520
+ training_length = int(self.segment * self.samplerate)
521
+ if training_length < length:
522
+ raise ValueError(
523
+ f"Given length {length} is longer than "
524
+ f"training length {training_length}")
525
+ return training_length
526
+
527
+ def forward(self, mix):
528
+ length = mix.shape[-1]
529
+ length_pre_pad = None
530
+ if self.use_train_segment:
531
+ if self.training:
532
+ self.segment = Fraction(mix.shape[-1], self.samplerate)
533
+ else:
534
+ training_length = int(self.segment * self.samplerate)
535
+ if mix.shape[-1] < training_length:
536
+ length_pre_pad = mix.shape[-1]
537
+ mix = F.pad(mix, (0, training_length - length_pre_pad))
538
+ z = self._spec(mix)
539
+ mag = self._magnitude(z)
540
+ x = mag
541
+
542
+ B, C, Fq, T = x.shape
543
+
544
+ # unlike previous Demucs, we always normalize because it is easier.
545
+ mean = x.mean(dim=(1, 2, 3), keepdim=True)
546
+ std = x.std(dim=(1, 2, 3), keepdim=True)
547
+ x = (x - mean) / (1e-5 + std)
548
+ # x will be the freq. branch input.
549
+
550
+ # Prepare the time branch input.
551
+ xt = mix
552
+ meant = xt.mean(dim=(1, 2), keepdim=True)
553
+ stdt = xt.std(dim=(1, 2), keepdim=True)
554
+ xt = (xt - meant) / (1e-5 + stdt)
555
+
556
+ # okay, this is a giant mess I know...
557
+ saved = [] # skip connections, freq.
558
+ saved_t = [] # skip connections, time.
559
+ lengths = [] # saved lengths to properly remove padding, freq branch.
560
+ lengths_t = [] # saved lengths for time branch.
561
+ for idx, encode in enumerate(self.encoder):
562
+ lengths.append(x.shape[-1])
563
+ inject = None
564
+ if idx < len(self.tencoder):
565
+ # we have not yet merged branches.
566
+ lengths_t.append(xt.shape[-1])
567
+ tenc = self.tencoder[idx]
568
+ xt = tenc(xt)
569
+ if not tenc.empty:
570
+ # save for skip connection
571
+ saved_t.append(xt)
572
+ else:
573
+ # tenc contains just the first conv., so that now time and freq.
574
+ # branches have the same shape and can be merged.
575
+ inject = xt
576
+ x = encode(x, inject)
577
+ if idx == 0 and self.freq_emb is not None:
578
+ # add frequency embedding to allow for non equivariant convolutions
579
+ # over the frequency axis.
580
+ frs = torch.arange(x.shape[-2], device=x.device)
581
+ emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
582
+ x = x + self.freq_emb_scale * emb
583
+
584
+ saved.append(x)
585
+ if self.crosstransformer:
586
+ if self.bottom_channels:
587
+ b, c, f, t = x.shape
588
+ x = rearrange(x, "b c f t-> b c (f t)")
589
+ x = self.channel_upsampler(x)
590
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
591
+ xt = self.channel_upsampler_t(xt)
592
+
593
+ x, xt = self.crosstransformer(x, xt)
594
+
595
+ if self.bottom_channels:
596
+ x = rearrange(x, "b c f t-> b c (f t)")
597
+ x = self.channel_downsampler(x)
598
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
599
+ xt = self.channel_downsampler_t(xt)
600
+
601
+ for idx, decode in enumerate(self.decoder):
602
+ skip = saved.pop(-1)
603
+ x, pre = decode(x, skip, lengths.pop(-1))
604
+ # `pre` contains the output just before final transposed convolution,
605
+ # which is used when the freq. and time branch separate.
606
+
607
+ offset = self.depth - len(self.tdecoder)
608
+ if idx >= offset:
609
+ tdec = self.tdecoder[idx - offset]
610
+ length_t = lengths_t.pop(-1)
611
+ if tdec.empty:
612
+ assert pre.shape[2] == 1, pre.shape
613
+ pre = pre[:, :, 0]
614
+ xt, _ = tdec(pre, None, length_t)
615
+ else:
616
+ skip = saved_t.pop(-1)
617
+ xt, _ = tdec(xt, skip, length_t)
618
+
619
+ # Let's make sure we used all stored skip connections.
620
+ assert len(saved) == 0
621
+ assert len(lengths_t) == 0
622
+ assert len(saved_t) == 0
623
+
624
+ S = len(self.sources)
625
+ x = x.view(B, S, -1, Fq, T)
626
+ x = x * std[:, None] + mean[:, None]
627
+
628
+ zout = self._mask(z, x)
629
+ if self.use_train_segment:
630
+ if self.training:
631
+ x = self._ispec(zout, length)
632
+ else:
633
+ x = self._ispec(zout, training_length)
634
+ else:
635
+ x = self._ispec(zout, length)
636
+
637
+ if self.use_train_segment:
638
+ if self.training:
639
+ xt = xt.view(B, S, -1, length)
640
+ else:
641
+ xt = xt.view(B, S, -1, training_length)
642
+ else:
643
+ xt = xt.view(B, S, -1, length)
644
+ xt = xt * stdt[:, None] + meant[:, None]
645
+ x = xt + x
646
+ if length_pre_pad:
647
+ x = x[..., :length_pre_pad]
648
+ return x
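For orientation, here is a hedged usage sketch of the class above. It assumes the `demucs3` package and its dependencies (torch, einops, openunmix and the sibling modules) are importable; the source list, segment and input length are illustrative, and real separation goes through `inference.py` with pretrained weights rather than a randomly initialised model.

```python
import torch
from demucs3.htdemucs import HTDemucs

sources = ["drums", "bass", "other", "vocals"]
model = HTDemucs(sources, nfft=4096, depth=4, segment=10).eval()

wav = torch.randn(1, 2, 44100 * 5)   # (batch, stereo channels, samples): 5 s at 44.1 kHz
with torch.no_grad():
    out = model(wav)                  # shorter inputs are padded up to the training segment
print(out.shape)                      # torch.Size([1, 4, 2, 220500]): one stem per source
```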
demucs3/spec.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Meta, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Convenience wrapper to perform STFT and iSTFT"""
7
+
8
+ import torch as th
9
+
10
+
11
+ def spectro(x, n_fft=512, hop_length=None, pad=0):
12
+ *other, length = x.shape
13
+ x = x.reshape(-1, length)
14
+ z = th.stft(x,
15
+ n_fft * (1 + pad),
16
+ hop_length or n_fft // 4,
17
+ window=th.hann_window(n_fft).to(x),
18
+ win_length=n_fft,
19
+ normalized=True,
20
+ center=True,
21
+ return_complex=True,
22
+ pad_mode='reflect')
23
+ _, freqs, frame = z.shape
24
+ return z.view(*other, freqs, frame)
25
+
26
+
27
+ def ispectro(z, hop_length=None, length=None, pad=0):
28
+ *other, freqs, frames = z.shape
29
+ n_fft = 2 * freqs - 2
30
+ z = z.view(-1, freqs, frames)
31
+ win_length = n_fft // (1 + pad)
32
+ x = th.istft(z,
33
+ n_fft,
34
+ hop_length,
35
+ window=th.hann_window(win_length).to(z.real),
36
+ win_length=win_length,
37
+ normalized=True,
38
+ length=length,
39
+ center=True)
40
+ _, length = x.shape
41
+ return x.view(*other, length)
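A quick round-trip sketch for the two wrappers above (shapes are illustrative): with `normalized=True`, a Hann window and `center=True`, `ispectro` should reconstruct the input of `spectro` up to float32 rounding error.

```python
import torch
from demucs3.spec import spectro, ispectro

x = torch.randn(2, 2, 4096 * 4)                    # (batch, channels, samples)
z = spectro(x, n_fft=4096)                          # complex STFT, shape (2, 2, 2049, 17)
y = ispectro(z, hop_length=4096 // 4, length=x.shape[-1])
print(z.shape, y.shape)
print((x - y).abs().max())                          # tiny reconstruction error
```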
demucs3/states.py ADDED
@@ -0,0 +1,148 @@
1
+ # Copyright (c) Meta, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Utilities to save and load models.
8
+ """
9
+ from contextlib import contextmanager
10
+
11
+ import functools
12
+ import hashlib
13
+ import inspect
14
+ import io
15
+ from pathlib import Path
16
+ import warnings
17
+
18
+ from omegaconf import OmegaConf
19
+ from diffq import DiffQuantizer, UniformQuantizer, restore_quantized_state
20
+ import torch
21
+
22
+
23
+ def get_quantizer(model, args, optimizer=None):
24
+ """Return the quantizer given the XP quantization args."""
25
+ quantizer = None
26
+ if args.diffq:
27
+ quantizer = DiffQuantizer(
28
+ model, min_size=args.min_size, group_size=args.group_size)
29
+ if optimizer is not None:
30
+ quantizer.setup_optimizer(optimizer)
31
+ elif args.qat:
32
+ quantizer = UniformQuantizer(
33
+ model, bits=args.qat, min_size=args.min_size)
34
+ return quantizer
35
+
36
+
37
+ def load_model(path_or_package, strict=False):
38
+ """Load a model from the given serialized model, either given as a dict (already loaded)
39
+ or a path to a file on disk."""
40
+ if isinstance(path_or_package, dict):
41
+ package = path_or_package
42
+ elif isinstance(path_or_package, (str, Path)):
43
+ with warnings.catch_warnings():
44
+ warnings.simplefilter("ignore")
45
+ path = path_or_package
46
+ package = torch.load(path, 'cpu')
47
+ else:
48
+ raise ValueError(f"Invalid type for {path_or_package}.")
49
+
50
+ klass = package["klass"]
51
+ args = package["args"]
52
+ kwargs = package["kwargs"]
53
+
54
+ if strict:
55
+ model = klass(*args, **kwargs)
56
+ else:
57
+ sig = inspect.signature(klass)
58
+ for key in list(kwargs):
59
+ if key not in sig.parameters:
60
+ warnings.warn("Dropping nonexistent parameter " + key)
61
+ del kwargs[key]
62
+ model = klass(*args, **kwargs)
63
+
64
+ state = package["state"]
65
+
66
+ set_state(model, state)
67
+ return model
68
+
69
+
70
+ def get_state(model, quantizer, half=False):
71
+ """Get the state from a model, potentially with quantization applied.
72
+ If `half` is True, the model is stored in half precision, which shouldn't impact performance
73
+ but halves the state size."""
74
+ if quantizer is None:
75
+ dtype = torch.half if half else None
76
+ state = {k: p.data.to(device='cpu', dtype=dtype) for k, p in model.state_dict().items()}
77
+ else:
78
+ state = quantizer.get_quantized_state()
79
+ state['__quantized'] = True
80
+ return state
81
+
82
+
83
+ def set_state(model, state, quantizer=None):
84
+ """Set the state on a given model."""
85
+ if state.get('__quantized'):
86
+ if quantizer is not None:
87
+ quantizer.restore_quantized_state(model, state['quantized'])
88
+ else:
89
+ restore_quantized_state(model, state)
90
+ else:
91
+ model.load_state_dict(state)
92
+ return state
93
+
94
+
95
+ def save_with_checksum(content, path):
96
+ """Save the given value on disk, along with a sha256 hash.
97
+ Should be used with the output of either `serialize_model` or `get_state`."""
98
+ buf = io.BytesIO()
99
+ torch.save(content, buf)
100
+ sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]
101
+
102
+ path = path.parent / (path.stem + "-" + sig + path.suffix)
103
+ path.write_bytes(buf.getvalue())
104
+
105
+
106
+ def serialize_model(model, training_args, quantizer=None, half=True):
107
+ args, kwargs = model._init_args_kwargs
108
+ klass = model.__class__
109
+
110
+ state = get_state(model, quantizer, half)
111
+ return {
112
+ 'klass': klass,
113
+ 'args': args,
114
+ 'kwargs': kwargs,
115
+ 'state': state,
116
+ 'training_args': OmegaConf.to_container(training_args, resolve=True),
117
+ }
118
+
119
+
120
+ def copy_state(state):
121
+ return {k: v.cpu().clone() for k, v in state.items()}
122
+
123
+
124
+ @contextmanager
125
+ def swap_state(model, state):
126
+ """
127
+ Context manager that swaps the state of a model, e.g:
128
+
129
+ # model is in old state
130
+ with swap_state(model, new_state):
131
+ # model in new state
132
+ # model back to old state
133
+ """
134
+ old_state = copy_state(model.state_dict())
135
+ model.load_state_dict(state, strict=False)
136
+ try:
137
+ yield
138
+ finally:
139
+ model.load_state_dict(old_state)
140
+
141
+
142
+ def capture_init(init):
143
+ @functools.wraps(init)
144
+ def __init__(self, *args, **kwargs):
145
+ self._init_args_kwargs = (args, kwargs)
146
+ init(self, *args, **kwargs)
147
+
148
+ return __init__
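A small usage sketch for the helpers above (illustrative; importing `demucs3.states` pulls in omegaconf and diffq, so those need to be installed). `copy_state` snapshots parameters on CPU and `swap_state` temporarily loads another state dict:

```python
import torch
from torch import nn
from demucs3.states import copy_state, swap_state

model = nn.Linear(4, 4)
old = copy_state(model.state_dict())
zeros = {k: torch.zeros_like(v) for k, v in old.items()}

with swap_state(model, zeros):
    assert model.weight.abs().sum() == 0             # inside the context, the swapped state is active
assert torch.allclose(model.weight, old["weight"])   # original parameters are restored on exit
```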
demucs3/transformer.py ADDED
@@ -0,0 +1,839 @@
1
+ # Copyright (c) 2019-present, Meta, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # First author is Simon Rouard.
7
+
8
+ import random
9
+ import typing as tp
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import numpy as np
15
+ import math
16
+ from einops import rearrange
17
+
18
+
19
+ def create_sin_embedding(
20
+ length: int, dim: int, shift: int = 0, device="cpu", max_period=10000
21
+ ):
22
+ # We aim for TBC format
23
+ assert dim % 2 == 0
24
+ pos = shift + torch.arange(length, device=device).view(-1, 1, 1)
25
+ half_dim = dim // 2
26
+ adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
27
+ phase = pos / (max_period ** (adim / (half_dim - 1)))
28
+ return torch.cat(
29
+ [
30
+ torch.cos(phase),
31
+ torch.sin(phase),
32
+ ],
33
+ dim=-1,
34
+ )
35
+
36
+
37
+ def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
38
+ """
39
+ :param d_model: dimension of the model
40
+ :param height: height of the positions
41
+ :param width: width of the positions
42
+ :return: d_model*height*width position matrix
43
+ """
44
+ if d_model % 4 != 0:
45
+ raise ValueError(
46
+ "Cannot use sin/cos positional encoding with "
47
+ "odd dimension (got dim={:d})".format(d_model)
48
+ )
49
+ pe = torch.zeros(d_model, height, width)
50
+ # Each dimension uses half of d_model
51
+ d_model = int(d_model / 2)
52
+ div_term = torch.exp(
53
+ torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model)
54
+ )
55
+ pos_w = torch.arange(0.0, width).unsqueeze(1)
56
+ pos_h = torch.arange(0.0, height).unsqueeze(1)
57
+ pe[0:d_model:2, :, :] = (
58
+ torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
59
+ )
60
+ pe[1:d_model:2, :, :] = (
61
+ torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
62
+ )
63
+ pe[d_model::2, :, :] = (
64
+ torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
65
+ )
66
+ pe[d_model + 1:: 2, :, :] = (
67
+ torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
68
+ )
69
+
70
+ return pe[None, :].to(device)
71
+
72
+
73
+ def create_sin_embedding_cape(
74
+ length: int,
75
+ dim: int,
76
+ batch_size: int,
77
+ mean_normalize: bool,
78
+ augment: bool, # True during training
79
+ max_global_shift: float = 0.0, # delta max
80
+ max_local_shift: float = 0.0, # epsilon max
81
+ max_scale: float = 1.0,
82
+ device: str = "cpu",
83
+ max_period: float = 10000.0,
84
+ ):
85
+ # We aim for TBC format
86
+ assert dim % 2 == 0
87
+ pos = 1.0 * torch.arange(length).view(-1, 1, 1) # (length, 1, 1)
88
+ pos = pos.repeat(1, batch_size, 1) # (length, batch_size, 1)
89
+ if mean_normalize:
90
+ pos -= torch.nanmean(pos, dim=0, keepdim=True)
91
+
92
+ if augment:
93
+ delta = np.random.uniform(
94
+ -max_global_shift, +max_global_shift, size=[1, batch_size, 1]
95
+ )
96
+ delta_local = np.random.uniform(
97
+ -max_local_shift, +max_local_shift, size=[length, batch_size, 1]
98
+ )
99
+ log_lambdas = np.random.uniform(
100
+ -np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1]
101
+ )
102
+ pos = (pos + delta + delta_local) * np.exp(log_lambdas)
103
+
104
+ pos = pos.to(device)
105
+
106
+ half_dim = dim // 2
107
+ adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
108
+ phase = pos / (max_period ** (adim / (half_dim - 1)))
109
+ return torch.cat(
110
+ [
111
+ torch.cos(phase),
112
+ torch.sin(phase),
113
+ ],
114
+ dim=-1,
115
+ ).float()
116
+
117
+
118
+ def get_causal_mask(length):
119
+ pos = torch.arange(length)
120
+ return pos > pos[:, None]
121
+
122
+
123
+ def get_elementary_mask(
124
+ T1,
125
+ T2,
126
+ mask_type,
127
+ sparse_attn_window,
128
+ global_window,
129
+ mask_random_seed,
130
+ sparsity,
131
+ device,
132
+ ):
133
+ """
134
+ When the input of the Decoder has length T1 and the output T2
135
+ The mask matrix has shape (T2, T1)
136
+ """
137
+ assert mask_type in ["diag", "jmask", "random", "global"]
138
+
139
+ if mask_type == "global":
140
+ mask = torch.zeros(T2, T1, dtype=torch.bool)
141
+ mask[:, :global_window] = True
142
+ line_window = int(global_window * T2 / T1)
143
+ mask[:line_window, :] = True
144
+
145
+ if mask_type == "diag":
146
+
147
+ mask = torch.zeros(T2, T1, dtype=torch.bool)
148
+ rows = torch.arange(T2)[:, None]
149
+ cols = (
150
+ (T1 / T2 * rows + torch.arange(-sparse_attn_window, sparse_attn_window + 1))
151
+ .long()
152
+ .clamp(0, T1 - 1)
153
+ )
154
+ mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
155
+
156
+ elif mask_type == "jmask":
157
+ mask = torch.zeros(T2 + 2, T1 + 2, dtype=torch.bool)
158
+ rows = torch.arange(T2 + 2)[:, None]
159
+ t = torch.arange(0, int((2 * T1) ** 0.5 + 1))
160
+ t = (t * (t + 1) / 2).int()
161
+ t = torch.cat([-t.flip(0)[:-1], t])
162
+ cols = (T1 / T2 * rows + t).long().clamp(0, T1 + 1)
163
+ mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
164
+ mask = mask[1:-1, 1:-1]
165
+
166
+ elif mask_type == "random":
167
+ gene = torch.Generator(device=device)
168
+ gene.manual_seed(mask_random_seed)
169
+ mask = (
170
+ torch.rand(T1 * T2, generator=gene, device=device).reshape(T2, T1)
171
+ > sparsity
172
+ )
173
+
174
+ mask = mask.to(device)
175
+ return mask
176
+
177
+
178
+ def get_mask(
179
+ T1,
180
+ T2,
181
+ mask_type,
182
+ sparse_attn_window,
183
+ global_window,
184
+ mask_random_seed,
185
+ sparsity,
186
+ device,
187
+ ):
188
+ """
189
+ Return a SparseCSRTensor mask that is a combination of elementary masks
190
+ mask_type can be a combination of multiple masks: for instance "diag_jmask_random"
191
+ """
192
+ from xformers.sparse import SparseCSRTensor
193
+ # create a list
194
+ mask_types = mask_type.split("_")
195
+
196
+ all_masks = [
197
+ get_elementary_mask(
198
+ T1,
199
+ T2,
200
+ mask,
201
+ sparse_attn_window,
202
+ global_window,
203
+ mask_random_seed,
204
+ sparsity,
205
+ device,
206
+ )
207
+ for mask in mask_types
208
+ ]
209
+
210
+ final_mask = torch.stack(all_masks).sum(axis=0) > 0
211
+
212
+ return SparseCSRTensor.from_dense(final_mask[None])
213
+
214
+
215
+ class ScaledEmbedding(nn.Module):
216
+ def __init__(
217
+ self,
218
+ num_embeddings: int,
219
+ embedding_dim: int,
220
+ scale: float = 1.0,
221
+ boost: float = 3.0,
222
+ ):
223
+ super().__init__()
224
+ self.embedding = nn.Embedding(num_embeddings, embedding_dim)
225
+ self.embedding.weight.data *= scale / boost
226
+ self.boost = boost
227
+
228
+ @property
229
+ def weight(self):
230
+ return self.embedding.weight * self.boost
231
+
232
+ def forward(self, x):
233
+ return self.embedding(x) * self.boost
234
+
235
+
236
+ class LayerScale(nn.Module):
237
+ """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
238
+ This diagonally rescales residual outputs close to 0 initially; the scale is then learnt.
239
+ """
240
+
241
+ def __init__(self, channels: int, init: float = 0, channel_last=False):
242
+ """
243
+ channel_last = False corresponds to (B, C, T) tensors
244
+ channel_last = True corresponds to (T, B, C) tensors
245
+ """
246
+ super().__init__()
247
+ self.channel_last = channel_last
248
+ self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
249
+ self.scale.data[:] = init
250
+
251
+ def forward(self, x):
252
+ if self.channel_last:
253
+ return self.scale * x
254
+ else:
255
+ return self.scale[:, None] * x
256
+
257
+
258
+ class MyGroupNorm(nn.GroupNorm):
259
+ def __init__(self, *args, **kwargs):
260
+ super().__init__(*args, **kwargs)
261
+
262
+ def forward(self, x):
263
+ """
264
+ x: (B, T, C)
265
+ if num_groups=1: Normalisation on all T and C together for each B
266
+ """
267
+ x = x.transpose(1, 2)
268
+ return super().forward(x).transpose(1, 2)
269
+
270
+
271
+ class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
272
+ def __init__(
273
+ self,
274
+ d_model,
275
+ nhead,
276
+ dim_feedforward=2048,
277
+ dropout=0.1,
278
+ activation=F.relu,
279
+ group_norm=0,
280
+ norm_first=False,
281
+ norm_out=False,
282
+ layer_norm_eps=1e-5,
283
+ layer_scale=False,
284
+ init_values=1e-4,
285
+ device=None,
286
+ dtype=None,
287
+ sparse=False,
288
+ mask_type="diag",
289
+ mask_random_seed=42,
290
+ sparse_attn_window=500,
291
+ global_window=50,
292
+ auto_sparsity=False,
293
+ sparsity=0.95,
294
+ batch_first=False,
295
+ ):
296
+ factory_kwargs = {"device": device, "dtype": dtype}
297
+ super().__init__(
298
+ d_model=d_model,
299
+ nhead=nhead,
300
+ dim_feedforward=dim_feedforward,
301
+ dropout=dropout,
302
+ activation=activation,
303
+ layer_norm_eps=layer_norm_eps,
304
+ batch_first=batch_first,
305
+ norm_first=norm_first,
306
+ device=device,
307
+ dtype=dtype,
308
+ )
309
+ self.sparse = sparse
310
+ self.auto_sparsity = auto_sparsity
311
+ if sparse:
312
+ if not auto_sparsity:
313
+ self.mask_type = mask_type
314
+ self.sparse_attn_window = sparse_attn_window
315
+ self.global_window = global_window
316
+ self.sparsity = sparsity
317
+ if group_norm:
318
+ self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
319
+ self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
320
+
321
+ self.norm_out = None
322
+ if self.norm_first & norm_out:
323
+ self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
324
+ self.gamma_1 = (
325
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
326
+ )
327
+ self.gamma_2 = (
328
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
329
+ )
330
+
331
+ if sparse:
332
+ self.self_attn = MultiheadAttention(
333
+ d_model, nhead, dropout=dropout, batch_first=batch_first,
334
+ auto_sparsity=sparsity if auto_sparsity else 0,
335
+ )
336
+ self.__setattr__("src_mask", torch.zeros(1, 1))
337
+ self.mask_random_seed = mask_random_seed
338
+
339
+ def forward(self, src, src_mask=None, src_key_padding_mask=None):
340
+ """
341
+ if batch_first = False, src shape is (T, B, C)
342
+ the case where batch_first=True is not covered
343
+ """
344
+ device = src.device
345
+ x = src
346
+ T, B, C = x.shape
347
+ if self.sparse and not self.auto_sparsity:
348
+ assert src_mask is None
349
+ src_mask = self.src_mask
350
+ if src_mask.shape[-1] != T:
351
+ src_mask = get_mask(
352
+ T,
353
+ T,
354
+ self.mask_type,
355
+ self.sparse_attn_window,
356
+ self.global_window,
357
+ self.mask_random_seed,
358
+ self.sparsity,
359
+ device,
360
+ )
361
+ self.__setattr__("src_mask", src_mask)
362
+
363
+ if self.norm_first:
364
+ x = x + self.gamma_1(
365
+ self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
366
+ )
367
+ x = x + self.gamma_2(self._ff_block(self.norm2(x)))
368
+
369
+ if self.norm_out:
370
+ x = self.norm_out(x)
371
+ else:
372
+ x = self.norm1(
373
+ x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask))
374
+ )
375
+ x = self.norm2(x + self.gamma_2(self._ff_block(x)))
376
+
377
+ return x
378
+
379
+
380
+ class CrossTransformerEncoderLayer(nn.Module):
381
+ def __init__(
382
+ self,
383
+ d_model: int,
384
+ nhead: int,
385
+ dim_feedforward: int = 2048,
386
+ dropout: float = 0.1,
387
+ activation=F.relu,
388
+ layer_norm_eps: float = 1e-5,
389
+ layer_scale: bool = False,
390
+ init_values: float = 1e-4,
391
+ norm_first: bool = False,
392
+ group_norm: bool = False,
393
+ norm_out: bool = False,
394
+ sparse=False,
395
+ mask_type="diag",
396
+ mask_random_seed=42,
397
+ sparse_attn_window=500,
398
+ global_window=50,
399
+ sparsity=0.95,
400
+ auto_sparsity=None,
401
+ device=None,
402
+ dtype=None,
403
+ batch_first=False,
404
+ ):
405
+ factory_kwargs = {"device": device, "dtype": dtype}
406
+ super().__init__()
407
+
408
+ self.sparse = sparse
409
+ self.auto_sparsity = auto_sparsity
410
+ if sparse:
411
+ if not auto_sparsity:
412
+ self.mask_type = mask_type
413
+ self.sparse_attn_window = sparse_attn_window
414
+ self.global_window = global_window
415
+ self.sparsity = sparsity
416
+
417
+ self.cross_attn: nn.Module
418
+ self.cross_attn = nn.MultiheadAttention(
419
+ d_model, nhead, dropout=dropout, batch_first=batch_first)
420
+ # Implementation of Feedforward model
421
+ self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
422
+ self.dropout = nn.Dropout(dropout)
423
+ self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
424
+
425
+ self.norm_first = norm_first
426
+ self.norm1: nn.Module
427
+ self.norm2: nn.Module
428
+ self.norm3: nn.Module
429
+ if group_norm:
430
+ self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
431
+ self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
432
+ self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
433
+ else:
434
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
435
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
436
+ self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
437
+
438
+ self.norm_out = None
439
+ if self.norm_first & norm_out:
440
+ self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
441
+
442
+ self.gamma_1 = (
443
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
444
+ )
445
+ self.gamma_2 = (
446
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
447
+ )
448
+
449
+ self.dropout1 = nn.Dropout(dropout)
450
+ self.dropout2 = nn.Dropout(dropout)
451
+
452
+ # Legacy string support for activation function.
453
+ if isinstance(activation, str):
454
+ self.activation = self._get_activation_fn(activation)
455
+ else:
456
+ self.activation = activation
457
+
458
+ if sparse:
459
+ self.cross_attn = MultiheadAttention(
460
+ d_model, nhead, dropout=dropout, batch_first=batch_first,
461
+ auto_sparsity=sparsity if auto_sparsity else 0)
462
+ if not auto_sparsity:
463
+ self.__setattr__("mask", torch.zeros(1, 1))
464
+ self.mask_random_seed = mask_random_seed
465
+
466
+ def forward(self, q, k, mask=None):
467
+ """
468
+ Args:
469
+ q: tensor of shape (T, B, C)
470
+ k: tensor of shape (S, B, C)
471
+ mask: tensor of shape (T, S)
472
+
473
+ """
474
+ device = q.device
475
+ T, B, C = q.shape
476
+ S, B, C = k.shape
477
+ if self.sparse and not self.auto_sparsity:
478
+ assert mask is None
479
+ mask = self.mask
480
+ if mask.shape[-1] != S or mask.shape[-2] != T:
481
+ mask = get_mask(
482
+ S,
483
+ T,
484
+ self.mask_type,
485
+ self.sparse_attn_window,
486
+ self.global_window,
487
+ self.mask_random_seed,
488
+ self.sparsity,
489
+ device,
490
+ )
491
+ self.__setattr__("mask", mask)
492
+
493
+ if self.norm_first:
494
+ x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
495
+ x = x + self.gamma_2(self._ff_block(self.norm3(x)))
496
+ if self.norm_out:
497
+ x = self.norm_out(x)
498
+ else:
499
+ x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
500
+ x = self.norm2(x + self.gamma_2(self._ff_block(x)))
501
+
502
+ return x
503
+
504
+ # cross-attention block
505
+ def _ca_block(self, q, k, attn_mask=None):
506
+ x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
507
+ return self.dropout1(x)
508
+
509
+ # feed forward block
510
+ def _ff_block(self, x):
511
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
512
+ return self.dropout2(x)
513
+
514
+ def _get_activation_fn(self, activation):
515
+ if activation == "relu":
516
+ return F.relu
517
+ elif activation == "gelu":
518
+ return F.gelu
519
+
520
+ raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
521
+
522
+
523
+ # ----------------- MULTI-BLOCKS MODELS: -----------------------
524
+
525
+
526
+ class CrossTransformerEncoder(nn.Module):
527
+ def __init__(
528
+ self,
529
+ dim: int,
530
+ emb: str = "sin",
531
+ hidden_scale: float = 4.0,
532
+ num_heads: int = 8,
533
+ num_layers: int = 6,
534
+ cross_first: bool = False,
535
+ dropout: float = 0.0,
536
+ max_positions: int = 1000,
537
+ norm_in: bool = True,
538
+ norm_in_group: bool = False,
539
+ group_norm: int = False,
540
+ norm_first: bool = False,
541
+ norm_out: bool = False,
542
+ max_period: float = 10000.0,
543
+ weight_decay: float = 0.0,
544
+ lr: tp.Optional[float] = None,
545
+ layer_scale: bool = False,
546
+ gelu: bool = True,
547
+ sin_random_shift: int = 0,
548
+ weight_pos_embed: float = 1.0,
549
+ cape_mean_normalize: bool = True,
550
+ cape_augment: bool = True,
551
+ cape_glob_loc_scale: list = [5000.0, 1.0, 1.4],
552
+ sparse_self_attn: bool = False,
553
+ sparse_cross_attn: bool = False,
554
+ mask_type: str = "diag",
555
+ mask_random_seed: int = 42,
556
+ sparse_attn_window: int = 500,
557
+ global_window: int = 50,
558
+ auto_sparsity: bool = False,
559
+ sparsity: float = 0.95,
560
+ ):
561
+ super().__init__()
562
+ """
563
+ """
564
+ assert dim % num_heads == 0
565
+
566
+ hidden_dim = int(dim * hidden_scale)
567
+
568
+ self.num_layers = num_layers
569
+ # classic parity = 1 means that if idx%2 == 1 there is a
570
+ # classical encoder else there is a cross encoder
571
+ self.classic_parity = 1 if cross_first else 0
572
+ self.emb = emb
573
+ self.max_period = max_period
574
+ self.weight_decay = weight_decay
575
+ self.weight_pos_embed = weight_pos_embed
576
+ self.sin_random_shift = sin_random_shift
577
+ if emb == "cape":
578
+ self.cape_mean_normalize = cape_mean_normalize
579
+ self.cape_augment = cape_augment
580
+ self.cape_glob_loc_scale = cape_glob_loc_scale
581
+ if emb == "scaled":
582
+ self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)
583
+
584
+ self.lr = lr
585
+
586
+ activation: tp.Any = F.gelu if gelu else F.relu
587
+
588
+ self.norm_in: nn.Module
589
+ self.norm_in_t: nn.Module
590
+ if norm_in:
591
+ self.norm_in = nn.LayerNorm(dim)
592
+ self.norm_in_t = nn.LayerNorm(dim)
593
+ elif norm_in_group:
594
+ self.norm_in = MyGroupNorm(int(norm_in_group), dim)
595
+ self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
596
+ else:
597
+ self.norm_in = nn.Identity()
598
+ self.norm_in_t = nn.Identity()
599
+
600
+ # spectrogram layers
601
+ self.layers = nn.ModuleList()
602
+ # temporal layers
603
+ self.layers_t = nn.ModuleList()
604
+
605
+ kwargs_common = {
606
+ "d_model": dim,
607
+ "nhead": num_heads,
608
+ "dim_feedforward": hidden_dim,
609
+ "dropout": dropout,
610
+ "activation": activation,
611
+ "group_norm": group_norm,
612
+ "norm_first": norm_first,
613
+ "norm_out": norm_out,
614
+ "layer_scale": layer_scale,
615
+ "mask_type": mask_type,
616
+ "mask_random_seed": mask_random_seed,
617
+ "sparse_attn_window": sparse_attn_window,
618
+ "global_window": global_window,
619
+ "sparsity": sparsity,
620
+ "auto_sparsity": auto_sparsity,
621
+ "batch_first": True,
622
+ }
623
+
624
+ kwargs_classic_encoder = dict(kwargs_common)
625
+ kwargs_classic_encoder.update({
626
+ "sparse": sparse_self_attn,
627
+ })
628
+ kwargs_cross_encoder = dict(kwargs_common)
629
+ kwargs_cross_encoder.update({
630
+ "sparse": sparse_cross_attn,
631
+ })
632
+
633
+ for idx in range(num_layers):
634
+ if idx % 2 == self.classic_parity:
635
+
636
+ self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
637
+ self.layers_t.append(
638
+ MyTransformerEncoderLayer(**kwargs_classic_encoder)
639
+ )
640
+
641
+ else:
642
+ self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
643
+
644
+ self.layers_t.append(
645
+ CrossTransformerEncoderLayer(**kwargs_cross_encoder)
646
+ )
647
+
648
+ def forward(self, x, xt):
649
+ B, C, Fr, T1 = x.shape
650
+ pos_emb_2d = create_2d_sin_embedding(
651
+ C, Fr, T1, x.device, self.max_period
652
+ ) # (1, C, Fr, T1)
653
+ pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")
654
+ x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
655
+ x = self.norm_in(x)
656
+ x = x + self.weight_pos_embed * pos_emb_2d
657
+
658
+ B, C, T2 = xt.shape
659
+ xt = rearrange(xt, "b c t2 -> b t2 c") # now T2, B, C
660
+ pos_emb = self._get_pos_embedding(T2, B, C, x.device)
661
+ pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")
662
+ xt = self.norm_in_t(xt)
663
+ xt = xt + self.weight_pos_embed * pos_emb
664
+
665
+ for idx in range(self.num_layers):
666
+ if idx % 2 == self.classic_parity:
667
+ x = self.layers[idx](x)
668
+ xt = self.layers_t[idx](xt)
669
+ else:
670
+ old_x = x
671
+ x = self.layers[idx](x, xt)
672
+ xt = self.layers_t[idx](xt, old_x)
673
+
674
+ x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1)
675
+ xt = rearrange(xt, "b t2 c -> b c t2")
676
+ return x, xt
677
+
678
+ def _get_pos_embedding(self, T, B, C, device):
679
+ if self.emb == "sin":
680
+ shift = random.randrange(self.sin_random_shift + 1)
681
+ pos_emb = create_sin_embedding(
682
+ T, C, shift=shift, device=device, max_period=self.max_period
683
+ )
684
+ elif self.emb == "cape":
685
+ if self.training:
686
+ pos_emb = create_sin_embedding_cape(
687
+ T,
688
+ C,
689
+ B,
690
+ device=device,
691
+ max_period=self.max_period,
692
+ mean_normalize=self.cape_mean_normalize,
693
+ augment=self.cape_augment,
694
+ max_global_shift=self.cape_glob_loc_scale[0],
695
+ max_local_shift=self.cape_glob_loc_scale[1],
696
+ max_scale=self.cape_glob_loc_scale[2],
697
+ )
698
+ else:
699
+ pos_emb = create_sin_embedding_cape(
700
+ T,
701
+ C,
702
+ B,
703
+ device=device,
704
+ max_period=self.max_period,
705
+ mean_normalize=self.cape_mean_normalize,
706
+ augment=False,
707
+ )
708
+
709
+ elif self.emb == "scaled":
710
+ pos = torch.arange(T, device=device)
711
+ pos_emb = self.position_embeddings(pos)[:, None]
712
+
713
+ return pos_emb
714
+
715
+ def make_optim_group(self):
716
+ group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
717
+ if self.lr is not None:
718
+ group["lr"] = self.lr
719
+ return group
720
+
721
+
722
+ # Attention Modules
723
+
724
+
725
+ class MultiheadAttention(nn.Module):
726
+ def __init__(
727
+ self,
728
+ embed_dim,
729
+ num_heads,
730
+ dropout=0.0,
731
+ bias=True,
732
+ add_bias_kv=False,
733
+ add_zero_attn=False,
734
+ kdim=None,
735
+ vdim=None,
736
+ batch_first=False,
737
+ auto_sparsity=None,
738
+ ):
739
+ super().__init__()
740
+ assert auto_sparsity is not None, "sanity check"
741
+ self.num_heads = num_heads
742
+ self.q = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
743
+ self.k = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
744
+ self.v = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
745
+ self.attn_drop = torch.nn.Dropout(dropout)
746
+ self.proj = torch.nn.Linear(embed_dim, embed_dim, bias)
747
+ self.proj_drop = torch.nn.Dropout(dropout)
748
+ self.batch_first = batch_first
749
+ self.auto_sparsity = auto_sparsity
750
+
751
+ def forward(
752
+ self,
753
+ query,
754
+ key,
755
+ value,
756
+ key_padding_mask=None,
757
+ need_weights=True,
758
+ attn_mask=None,
759
+ average_attn_weights=True,
760
+ ):
761
+
762
+ if not self.batch_first: # N, B, C
763
+ query = query.permute(1, 0, 2) # B, N_q, C
764
+ key = key.permute(1, 0, 2) # B, N_k, C
765
+ value = value.permute(1, 0, 2) # B, N_k, C
766
+ B, N_q, C = query.shape
767
+ B, N_k, C = key.shape
768
+
769
+ q = (
770
+ self.q(query)
771
+ .reshape(B, N_q, self.num_heads, C // self.num_heads)
772
+ .permute(0, 2, 1, 3)
773
+ )
774
+ q = q.flatten(0, 1)
775
+ k = (
776
+ self.k(key)
777
+ .reshape(B, N_k, self.num_heads, C // self.num_heads)
778
+ .permute(0, 2, 1, 3)
779
+ )
780
+ k = k.flatten(0, 1)
781
+ v = (
782
+ self.v(value)
783
+ .reshape(B, N_k, self.num_heads, C // self.num_heads)
784
+ .permute(0, 2, 1, 3)
785
+ )
786
+ v = v.flatten(0, 1)
787
+
788
+ if self.auto_sparsity:
789
+ assert attn_mask is None
790
+ x = dynamic_sparse_attention(q, k, v, sparsity=self.auto_sparsity)
791
+ else:
792
+ x = scaled_dot_product_attention(q, k, v, attn_mask, dropout=self.attn_drop)
793
+ x = x.reshape(B, self.num_heads, N_q, C // self.num_heads)
794
+
795
+ x = x.transpose(1, 2).reshape(B, N_q, C)
796
+ x = self.proj(x)
797
+ x = self.proj_drop(x)
798
+ if not self.batch_first:
799
+ x = x.permute(1, 0, 2)
800
+ return x, None
801
+
802
+
803
+ def scaled_query_key_softmax(q, k, att_mask):
804
+ from xformers.ops import masked_matmul
805
+ q = q / (k.size(-1)) ** 0.5
806
+ att = masked_matmul(q, k.transpose(-2, -1), att_mask)
807
+ att = torch.nn.functional.softmax(att, -1)
808
+ return att
809
+
810
+
811
+ def scaled_dot_product_attention(q, k, v, att_mask, dropout):
812
+ att = scaled_query_key_softmax(q, k, att_mask=att_mask)
813
+ att = dropout(att)
814
+ y = att @ v
815
+ return y
816
+
817
+
818
+ def _compute_buckets(x, R):
819
+ qq = torch.einsum('btf,bfhi->bhti', x, R)
820
+ qq = torch.cat([qq, -qq], dim=-1)
821
+ buckets = qq.argmax(dim=-1)
822
+
823
+ return buckets.permute(0, 2, 1).byte().contiguous()
824
+
825
+
826
+ def dynamic_sparse_attention(query, key, value, sparsity, infer_sparsity=True, attn_bias=None):
827
+ # assert False, "The code for the custom sparse kernel is not ready for release yet."
828
+ from xformers.ops import find_locations, sparse_memory_efficient_attention
829
+ n_hashes = 32
830
+ proj_size = 4
831
+ query, key, value = [x.contiguous() for x in [query, key, value]]
832
+ with torch.no_grad():
833
+ R = torch.randn(1, query.shape[-1], n_hashes, proj_size // 2, device=query.device)
834
+ bucket_query = _compute_buckets(query, R)
835
+ bucket_key = _compute_buckets(key, R)
836
+ row_offsets, column_indices = find_locations(
837
+ bucket_query, bucket_key, sparsity, infer_sparsity)
838
+ return sparse_memory_efficient_attention(
839
+ query, key, value, row_offsets, column_indices, attn_bias)
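The sparse path above relies on private xformers ops (`find_locations`, `sparse_memory_efficient_attention`) that may not be available in a regular install. The dense branch reduces to standard scaled dot-product attention; a minimal plain-PyTorch sketch (hypothetical toy shapes, no mask, no dropout) of that same math:

```
import torch

def plain_scaled_dot_product_attention(q, k, v):
    # same computation as scaled_dot_product_attention above when att_mask is None
    att = torch.softmax(q @ k.transpose(-2, -1) / k.size(-1) ** 0.5, dim=-1)
    return att @ v

# shapes follow MultiheadAttention after flatten(0, 1): (batch * heads, tokens, head_dim)
q = torch.randn(16, 32, 64)
k = torch.randn(16, 32, 64)
v = torch.randn(16, 32, 64)
print(plain_scaled_dot_product_attention(q, k, v).shape)  # torch.Size([16, 32, 64])
```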
demucs3/utils.py ADDED
@@ -0,0 +1,141 @@
1
+ # Copyright (c) Meta, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from collections import defaultdict
8
+ from contextlib import contextmanager
9
+ import math
10
+ import os
11
+ import tempfile
12
+ import typing as tp
13
+
14
+ import torch
15
+ from torch.nn import functional as F
16
+ from torch.utils.data import Subset
17
+
18
+
19
+ def unfold(a, kernel_size, stride):
20
+ """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
21
+ with K the kernel size, by extracting frames with the given stride.
22
+
23
+     This will pad the input so that `F = ceil(T / stride)`.
24
+
25
+ see https://github.com/pytorch/pytorch/issues/60466
26
+ """
27
+ *shape, length = a.shape
28
+ n_frames = math.ceil(length / stride)
29
+ tgt_length = (n_frames - 1) * stride + kernel_size
30
+ a = F.pad(a, (0, tgt_length - length))
31
+ strides = list(a.stride())
32
+ assert strides[-1] == 1, 'data should be contiguous'
33
+ strides = strides[:-1] + [stride, 1]
34
+ return a.as_strided([*shape, n_frames, kernel_size], strides)
35
+
36
+
37
+ def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
38
+ """
39
+ Center trim `tensor` with respect to `reference`, along the last dimension.
40
+ `reference` can also be a number, representing the length to trim to.
41
+ If the size difference != 0 mod 2, the extra sample is removed on the right side.
42
+ """
43
+ ref_size: int
44
+ if isinstance(reference, torch.Tensor):
45
+ ref_size = reference.size(-1)
46
+ else:
47
+ ref_size = reference
48
+ delta = tensor.size(-1) - ref_size
49
+ if delta < 0:
50
+ raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
51
+ if delta:
52
+ tensor = tensor[..., delta // 2:-(delta - delta // 2)]
53
+ return tensor
54
+
55
+
56
+ def pull_metric(history: tp.List[dict], name: str):
57
+ out = []
58
+ for metrics in history:
59
+ metric = metrics
60
+ for part in name.split("."):
61
+ metric = metric[part]
62
+ out.append(metric)
63
+ return out
64
+
65
+
66
+ def EMA(beta: float = 1):
67
+ """
68
+ Exponential Moving Average callback.
69
+ Returns a single function that can be called to repeatidly update the EMA
70
+ with a dict of metrics. The callback will return
71
+ the new averaged dict of metrics.
72
+
73
+ Note that for `beta=1`, this is just plain averaging.
74
+ """
75
+ fix: tp.Dict[str, float] = defaultdict(float)
76
+ total: tp.Dict[str, float] = defaultdict(float)
77
+
78
+ def _update(metrics: dict, weight: float = 1) -> dict:
79
+ nonlocal total, fix
80
+ for key, value in metrics.items():
81
+ total[key] = total[key] * beta + weight * float(value)
82
+ fix[key] = fix[key] * beta + weight
83
+ return {key: tot / fix[key] for key, tot in total.items()}
84
+ return _update
85
+
86
+
87
+ def sizeof_fmt(num: float, suffix: str = 'B'):
88
+ """
89
+ Given `num` bytes, return human readable size.
90
+ Taken from https://stackoverflow.com/a/1094933
91
+ """
92
+ for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
93
+ if abs(num) < 1024.0:
94
+ return "%3.1f%s%s" % (num, unit, suffix)
95
+ num /= 1024.0
96
+ return "%.1f%s%s" % (num, 'Yi', suffix)
97
+
98
+
99
+ @contextmanager
100
+ def temp_filenames(count: int, delete=True):
101
+ names = []
102
+ try:
103
+ for _ in range(count):
104
+ names.append(tempfile.NamedTemporaryFile(delete=False).name)
105
+ yield names
106
+ finally:
107
+ if delete:
108
+ for name in names:
109
+ os.unlink(name)
110
+
111
+
112
+ def random_subset(dataset, max_samples: int, seed: int = 42):
113
+ if max_samples >= len(dataset):
114
+ return dataset
115
+
116
+ generator = torch.Generator().manual_seed(seed)
117
+ perm = torch.randperm(len(dataset), generator=generator)
118
+ return Subset(dataset, perm[:max_samples].tolist())
119
+
120
+
121
+ class DummyPoolExecutor:
122
+ class DummyResult:
123
+ def __init__(self, func, *args, **kwargs):
124
+ self.func = func
125
+ self.args = args
126
+ self.kwargs = kwargs
127
+
128
+ def result(self):
129
+ return self.func(*self.args, **self.kwargs)
130
+
131
+ def __init__(self, workers=0):
132
+ pass
133
+
134
+ def submit(self, func, *args, **kwargs):
135
+ return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
136
+
137
+ def __enter__(self):
138
+ return self
139
+
140
+ def __exit__(self, exc_type, exc_value, exc_tb):
141
+ return
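A few of the helpers above are easiest to understand on toy tensors. A hedged sketch, assuming the repository root is on `PYTHONPATH` so that `demucs3.utils` is importable:

```
import torch
from demucs3.utils import unfold, center_trim, EMA

x = torch.arange(10.).view(1, 10)                  # [*OT, T] with T = 10
print(unfold(x, kernel_size=4, stride=2).shape)    # torch.Size([1, 5, 4]): ceil(10 / 2) frames of size 4

y = torch.randn(1, 2, 100)
print(center_trim(y, 80).shape)                    # torch.Size([1, 2, 80]): 10 samples trimmed on each side

ema = EMA(beta=1.0)                                # beta=1 is a plain running average
print(ema({"loss": 2.0}), ema({"loss": 4.0}))      # {'loss': 2.0} {'loss': 3.0}
```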
demucs4/demucs.py ADDED
@@ -0,0 +1,447 @@
1
+ # Copyright (c) Meta, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import typing as tp
9
+
10
+ import julius
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+
15
+ from .states import capture_init
16
+ from .utils import center_trim, unfold
17
+ from .transformer import LayerScale
18
+
19
+
20
+ class BLSTM(nn.Module):
21
+ """
22
+ BiLSTM with same hidden units as input dim.
23
+     If `max_steps` is not None, the input will be split into overlapping
24
+ chunks and the LSTM applied separately on each chunk.
25
+ """
26
+ def __init__(self, dim, layers=1, max_steps=None, skip=False):
27
+ super().__init__()
28
+ assert max_steps is None or max_steps % 4 == 0
29
+ self.max_steps = max_steps
30
+ self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
31
+ self.linear = nn.Linear(2 * dim, dim)
32
+ self.skip = skip
33
+
34
+ def forward(self, x):
35
+ B, C, T = x.shape
36
+ y = x
37
+ framed = False
38
+ if self.max_steps is not None and T > self.max_steps:
39
+ width = self.max_steps
40
+ stride = width // 2
41
+ frames = unfold(x, width, stride)
42
+ nframes = frames.shape[2]
43
+ framed = True
44
+ x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)
45
+
46
+ x = x.permute(2, 0, 1)
47
+
48
+ x = self.lstm(x)[0]
49
+ x = self.linear(x)
50
+ x = x.permute(1, 2, 0)
51
+ if framed:
52
+ out = []
53
+ frames = x.reshape(B, -1, C, width)
54
+ limit = stride // 2
55
+ for k in range(nframes):
56
+ if k == 0:
57
+ out.append(frames[:, k, :, :-limit])
58
+ elif k == nframes - 1:
59
+ out.append(frames[:, k, :, limit:])
60
+ else:
61
+ out.append(frames[:, k, :, limit:-limit])
62
+ out = torch.cat(out, -1)
63
+ out = out[..., :T]
64
+ x = out
65
+ if self.skip:
66
+ x = x + y
67
+ return x
68
+
69
+
70
+ def rescale_conv(conv, reference):
71
+ """Rescale initial weight scale. It is unclear why it helps but it certainly does.
72
+ """
73
+ std = conv.weight.std().detach()
74
+ scale = (std / reference)**0.5
75
+ conv.weight.data /= scale
76
+ if conv.bias is not None:
77
+ conv.bias.data /= scale
78
+
79
+
80
+ def rescale_module(module, reference):
81
+ for sub in module.modules():
82
+ if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)):
83
+ rescale_conv(sub, reference)
84
+
85
+
86
+ class DConv(nn.Module):
87
+ """
88
+ New residual branches in each encoder layer.
89
+ This alternates dilated convolutions, potentially with LSTMs and attention.
90
+ Also before entering each residual branch, dimension is projected on a smaller subspace,
91
+ e.g. of dim `channels // compress`.
92
+ """
93
+ def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
94
+ norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True,
95
+ kernel=3, dilate=True):
96
+ """
97
+ Args:
98
+ channels: input/output channels for residual branch.
99
+ compress: amount of channel compression inside the branch.
100
+ depth: number of layers in the residual branch. Each layer has its own
101
+ projection, and potentially LSTM and attention.
102
+ init: initial scale for LayerNorm.
103
+ norm: use GroupNorm.
104
+ attn: use LocalAttention.
105
+ heads: number of heads for the LocalAttention.
106
+ ndecay: number of decay controls in the LocalAttention.
107
+ lstm: use LSTM.
108
+ gelu: Use GELU activation.
109
+ kernel: kernel size for the (dilated) convolutions.
110
+ dilate: if true, use dilation, increasing with the depth.
111
+ """
112
+
113
+ super().__init__()
114
+ assert kernel % 2 == 1
115
+ self.channels = channels
116
+ self.compress = compress
117
+ self.depth = abs(depth)
118
+ dilate = depth > 0
119
+
120
+ norm_fn: tp.Callable[[int], nn.Module]
121
+ norm_fn = lambda d: nn.Identity() # noqa
122
+ if norm:
123
+ norm_fn = lambda d: nn.GroupNorm(1, d) # noqa
124
+
125
+ hidden = int(channels / compress)
126
+
127
+ act: tp.Type[nn.Module]
128
+ if gelu:
129
+ act = nn.GELU
130
+ else:
131
+ act = nn.ReLU
132
+
133
+ self.layers = nn.ModuleList([])
134
+ for d in range(self.depth):
135
+ dilation = 2 ** d if dilate else 1
136
+ padding = dilation * (kernel // 2)
137
+ mods = [
138
+ nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
139
+ norm_fn(hidden), act(),
140
+ nn.Conv1d(hidden, 2 * channels, 1),
141
+ norm_fn(2 * channels), nn.GLU(1),
142
+ LayerScale(channels, init),
143
+ ]
144
+ if attn:
145
+ mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
146
+ if lstm:
147
+ mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
148
+ layer = nn.Sequential(*mods)
149
+ self.layers.append(layer)
150
+
151
+ def forward(self, x):
152
+ for layer in self.layers:
153
+ x = x + layer(x)
154
+ return x
155
+
156
+
157
+ class LocalState(nn.Module):
158
+     """Local state allows attention based only on the data (no positional embedding),
159
+ but while setting a constraint on the time window (e.g. decaying penalty term).
160
+
161
+     Also includes a failed experiment at providing some frequency-based attention.
162
+ """
163
+ def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
164
+ super().__init__()
165
+ assert channels % heads == 0, (channels, heads)
166
+ self.heads = heads
167
+ self.nfreqs = nfreqs
168
+ self.ndecay = ndecay
169
+ self.content = nn.Conv1d(channels, channels, 1)
170
+ self.query = nn.Conv1d(channels, channels, 1)
171
+ self.key = nn.Conv1d(channels, channels, 1)
172
+ if nfreqs:
173
+ self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
174
+ if ndecay:
175
+ self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
176
+ # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
177
+ self.query_decay.weight.data *= 0.01
178
+ assert self.query_decay.bias is not None # stupid type checker
179
+ self.query_decay.bias.data[:] = -2
180
+ self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)
181
+
182
+ def forward(self, x):
183
+ B, C, T = x.shape
184
+ heads = self.heads
185
+ indexes = torch.arange(T, device=x.device, dtype=x.dtype)
186
+ # left index are keys, right index are queries
187
+ delta = indexes[:, None] - indexes[None, :]
188
+
189
+ queries = self.query(x).view(B, heads, -1, T)
190
+ keys = self.key(x).view(B, heads, -1, T)
191
+ # t are keys, s are queries
192
+ dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
193
+ dots /= keys.shape[2]**0.5
194
+ if self.nfreqs:
195
+ periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
196
+ freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
197
+ freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5
198
+ dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
199
+ if self.ndecay:
200
+ decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
201
+ decay_q = self.query_decay(x).view(B, heads, -1, T)
202
+ decay_q = torch.sigmoid(decay_q) / 2
203
+ decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
204
+ dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)
205
+
206
+ # Kill self reference.
207
+ dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
208
+ weights = torch.softmax(dots, dim=2)
209
+
210
+ content = self.content(x).view(B, heads, -1, T)
211
+ result = torch.einsum("bhts,bhct->bhcs", weights, content)
212
+ if self.nfreqs:
213
+ time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
214
+ result = torch.cat([result, time_sig], 2)
215
+ result = result.reshape(B, -1, T)
216
+ return x + self.proj(result)
217
+
218
+
219
+ class Demucs(nn.Module):
220
+ @capture_init
221
+ def __init__(self,
222
+ sources,
223
+ # Channels
224
+ audio_channels=2,
225
+ channels=64,
226
+ growth=2.,
227
+ # Main structure
228
+ depth=6,
229
+ rewrite=True,
230
+ lstm_layers=0,
231
+ # Convolutions
232
+ kernel_size=8,
233
+ stride=4,
234
+ context=1,
235
+ # Activations
236
+ gelu=True,
237
+ glu=True,
238
+ # Normalization
239
+ norm_starts=4,
240
+ norm_groups=4,
241
+ # DConv residual branch
242
+ dconv_mode=1,
243
+ dconv_depth=2,
244
+ dconv_comp=4,
245
+ dconv_attn=4,
246
+ dconv_lstm=4,
247
+ dconv_init=1e-4,
248
+ # Pre/post processing
249
+ normalize=True,
250
+ resample=True,
251
+ # Weight init
252
+ rescale=0.1,
253
+ # Metadata
254
+ samplerate=44100,
255
+ segment=4 * 10):
256
+ """
257
+ Args:
258
+ sources (list[str]): list of source names
259
+ audio_channels (int): stereo or mono
260
+ channels (int): first convolution channels
261
+ depth (int): number of encoder/decoder layers
262
+ growth (float): multiply (resp divide) number of channels by that
263
+ for each layer of the encoder (resp decoder)
264
+ depth (int): number of layers in the encoder and in the decoder.
265
+ rewrite (bool): add 1x1 convolution to each layer.
266
+ lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated
267
+             by default, as this is now replaced by the smaller and faster LSTMs
268
+ in the DConv branches.
269
+ kernel_size (int): kernel size for convolutions
270
+ stride (int): stride for convolutions
271
+ context (int): kernel size of the convolution in the
272
+ decoder before the transposed convolution. If > 1,
273
+ will provide some context from neighboring time steps.
274
+ gelu: use GELU activation function.
275
+ glu (bool): use glu instead of ReLU for the 1x1 rewrite conv.
276
+ norm_starts: layer at which group norm starts being used.
277
+ decoder layers are numbered in reverse order.
278
+ norm_groups: number of groups for group norm.
279
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
280
+ dconv_depth: depth of residual DConv branch.
281
+ dconv_comp: compression of DConv branch.
282
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
283
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
284
+ dconv_init: initial scale for the DConv branch LayerScale.
285
+ normalize (bool): normalizes the input audio on the fly, and scales back
286
+ the output by the same amount.
287
+ resample (bool): upsample x2 the input and downsample /2 the output.
288
+ rescale (int): rescale initial weights of convolutions
289
+ to get their standard deviation closer to `rescale`.
290
+ samplerate (int): stored as meta information for easing
291
+ future evaluations of the model.
292
+ segment (float): duration of the chunks of audio to ideally evaluate the model on.
293
+ This is used by `demucs.apply.apply_model`.
294
+ """
295
+
296
+ super().__init__()
297
+ self.audio_channels = audio_channels
298
+ self.sources = sources
299
+ self.kernel_size = kernel_size
300
+ self.context = context
301
+ self.stride = stride
302
+ self.depth = depth
303
+ self.resample = resample
304
+ self.channels = channels
305
+ self.normalize = normalize
306
+ self.samplerate = samplerate
307
+ self.segment = segment
308
+ self.encoder = nn.ModuleList()
309
+ self.decoder = nn.ModuleList()
310
+ self.skip_scales = nn.ModuleList()
311
+
312
+ if glu:
313
+ activation = nn.GLU(dim=1)
314
+ ch_scale = 2
315
+ else:
316
+ activation = nn.ReLU()
317
+ ch_scale = 1
318
+ if gelu:
319
+ act2 = nn.GELU
320
+ else:
321
+ act2 = nn.ReLU
322
+
323
+ in_channels = audio_channels
324
+ padding = 0
325
+ for index in range(depth):
326
+ norm_fn = lambda d: nn.Identity() # noqa
327
+ if index >= norm_starts:
328
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
329
+
330
+ encode = []
331
+ encode += [
332
+ nn.Conv1d(in_channels, channels, kernel_size, stride),
333
+ norm_fn(channels),
334
+ act2(),
335
+ ]
336
+ attn = index >= dconv_attn
337
+ lstm = index >= dconv_lstm
338
+ if dconv_mode & 1:
339
+ encode += [DConv(channels, depth=dconv_depth, init=dconv_init,
340
+ compress=dconv_comp, attn=attn, lstm=lstm)]
341
+ if rewrite:
342
+ encode += [
343
+ nn.Conv1d(channels, ch_scale * channels, 1),
344
+ norm_fn(ch_scale * channels), activation]
345
+ self.encoder.append(nn.Sequential(*encode))
346
+
347
+ decode = []
348
+ if index > 0:
349
+ out_channels = in_channels
350
+ else:
351
+ out_channels = len(self.sources) * audio_channels
352
+ if rewrite:
353
+ decode += [
354
+ nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context),
355
+ norm_fn(ch_scale * channels), activation]
356
+ if dconv_mode & 2:
357
+ decode += [DConv(channels, depth=dconv_depth, init=dconv_init,
358
+ compress=dconv_comp, attn=attn, lstm=lstm)]
359
+ decode += [nn.ConvTranspose1d(channels, out_channels,
360
+ kernel_size, stride, padding=padding)]
361
+ if index > 0:
362
+ decode += [norm_fn(out_channels), act2()]
363
+ self.decoder.insert(0, nn.Sequential(*decode))
364
+ in_channels = channels
365
+ channels = int(growth * channels)
366
+
367
+ channels = in_channels
368
+ if lstm_layers:
369
+ self.lstm = BLSTM(channels, lstm_layers)
370
+ else:
371
+ self.lstm = None
372
+
373
+ if rescale:
374
+ rescale_module(self, reference=rescale)
375
+
376
+ def valid_length(self, length):
377
+ """
378
+ Return the nearest valid length to use with the model so that
379
+         there are no time steps left over in a convolution, e.g. for all
380
+ layers, size of the input - kernel_size % stride = 0.
381
+
382
+         Note that inputs are automatically padded if necessary to ensure that the output
383
+ has the same length as the input.
384
+ """
385
+ if self.resample:
386
+ length *= 2
387
+
388
+ for _ in range(self.depth):
389
+ length = math.ceil((length - self.kernel_size) / self.stride) + 1
390
+ length = max(1, length)
391
+
392
+ for idx in range(self.depth):
393
+ length = (length - 1) * self.stride + self.kernel_size
394
+
395
+ if self.resample:
396
+ length = math.ceil(length / 2)
397
+ return int(length)
398
+
399
+ def forward(self, mix):
400
+ x = mix
401
+ length = x.shape[-1]
402
+
403
+ if self.normalize:
404
+ mono = mix.mean(dim=1, keepdim=True)
405
+ mean = mono.mean(dim=-1, keepdim=True)
406
+ std = mono.std(dim=-1, keepdim=True)
407
+ x = (x - mean) / (1e-5 + std)
408
+ else:
409
+ mean = 0
410
+ std = 1
411
+
412
+ delta = self.valid_length(length) - length
413
+ x = F.pad(x, (delta // 2, delta - delta // 2))
414
+
415
+ if self.resample:
416
+ x = julius.resample_frac(x, 1, 2)
417
+
418
+ saved = []
419
+ for encode in self.encoder:
420
+ x = encode(x)
421
+ saved.append(x)
422
+
423
+ if self.lstm:
424
+ x = self.lstm(x)
425
+
426
+ for decode in self.decoder:
427
+ skip = saved.pop(-1)
428
+ skip = center_trim(skip, x)
429
+ x = decode(x + skip)
430
+
431
+ if self.resample:
432
+ x = julius.resample_frac(x, 2, 1)
433
+ x = x * std + mean
434
+ x = center_trim(x, length)
435
+ x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
436
+ return x
437
+
438
+ def load_state_dict(self, state, strict=True):
439
+ # fix a mismatch with previous generation Demucs models.
440
+ for idx in range(self.depth):
441
+ for a in ['encoder', 'decoder']:
442
+ for b in ['bias', 'weight']:
443
+ new = f'{a}.{idx}.3.{b}'
444
+ old = f'{a}.{idx}.2.{b}'
445
+ if old in state and new not in state:
446
+ state[new] = state.pop(old)
447
+ super().load_state_dict(state, strict=strict)
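As a quick smoke test of the waveform model above, a hedged sketch (assumes `julius` is installed and the `demucs4` package is importable; the small channel/depth values are arbitrary toy settings, not a trained configuration):

```
import torch
from demucs4.demucs import Demucs

model = Demucs(sources=["vocals", "other"], channels=8, depth=3, lstm_layers=0)
mix = torch.randn(1, 2, 44100)   # one second of stereo audio at 44.1 kHz
with torch.no_grad():
    out = model(mix)
print(out.shape)                 # torch.Size([1, 2, 2, 44100]): (batch, sources, channels, time)
```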
demucs4/hdemucs.py ADDED
@@ -0,0 +1,782 @@
1
+ # Copyright (c) Meta, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ This code contains the spectrogram and Hybrid version of Demucs.
8
+ """
9
+ from copy import deepcopy
10
+ import math
11
+ import typing as tp
12
+
13
+ from openunmix.filtering import wiener
14
+ import torch
15
+ from torch import nn
16
+ from torch.nn import functional as F
17
+
18
+ from .demucs import DConv, rescale_module
19
+ from .states import capture_init
20
+ from .spec import spectro, ispectro
21
+
22
+
23
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
24
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
25
+     If this is the case, we insert extra 0 padding to the right before the reflection happens."""
26
+ x0 = x
27
+ length = x.shape[-1]
28
+ padding_left, padding_right = paddings
29
+ if mode == 'reflect':
30
+ max_pad = max(padding_left, padding_right)
31
+ if length <= max_pad:
32
+ extra_pad = max_pad - length + 1
33
+ extra_pad_right = min(padding_right, extra_pad)
34
+ extra_pad_left = extra_pad - extra_pad_right
35
+ paddings = (padding_left - extra_pad_left, padding_right - extra_pad_right)
36
+ x = F.pad(x, (extra_pad_left, extra_pad_right))
37
+ out = F.pad(x, paddings, mode, value)
38
+ assert out.shape[-1] == length + padding_left + padding_right
39
+ assert (out[..., padding_left: padding_left + length] == x0).all()
40
+ return out
41
+
42
+
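Why `pad1d` exists: plain `F.pad(..., mode='reflect')` fails when the requested padding is at least as long as the signal, which can happen on very short segments. A hedged sketch of the difference (toy sizes, assuming `demucs4.hdemucs` and its dependencies import cleanly):

```
import torch
from demucs4.hdemucs import pad1d

x = torch.randn(1, 1, 3)                         # only 3 samples
print(pad1d(x, (5, 5), mode="reflect").shape)    # torch.Size([1, 1, 13])
# F.pad(x, (5, 5), mode="reflect") would raise: reflect padding must be smaller than the input
```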
43
+ class ScaledEmbedding(nn.Module):
44
+ """
45
+ Boost learning rate for embeddings (with `scale`).
46
+ Also, can make embeddings continuous with `smooth`.
47
+ """
48
+ def __init__(self, num_embeddings: int, embedding_dim: int,
49
+ scale: float = 10., smooth=False):
50
+ super().__init__()
51
+ self.embedding = nn.Embedding(num_embeddings, embedding_dim)
52
+ if smooth:
53
+ weight = torch.cumsum(self.embedding.weight.data, dim=0)
54
+ # when summing gaussian, overscale raises as sqrt(n), so we nornalize by that.
55
+ weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
56
+ self.embedding.weight.data[:] = weight
57
+ self.embedding.weight.data /= scale
58
+ self.scale = scale
59
+
60
+ @property
61
+ def weight(self):
62
+ return self.embedding.weight * self.scale
63
+
64
+ def forward(self, x):
65
+ out = self.embedding(x) * self.scale
66
+ return out
67
+
68
+
69
+ class HEncLayer(nn.Module):
70
+ def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
71
+ freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True,
72
+ rewrite=True):
73
+         """Encoder layer. This is used by both the time and the frequency branches.
74
+
75
+ Args:
76
+ chin: number of input channels.
77
+ chout: number of output channels.
78
+ norm_groups: number of groups for group norm.
79
+ empty: used to make a layer with just the first conv. this is used
80
+ before merging the time and freq. branches.
81
+ freq: this is acting on frequencies.
82
+ dconv: insert DConv residual branches.
83
+ norm: use GroupNorm.
84
+ context: context size for the 1x1 conv.
85
+ dconv_kw: list of kwargs for the DConv class.
86
+ pad: pad the input. Padding is done so that the output size is
87
+ always the input size / stride.
88
+ rewrite: add 1x1 conv at the end of the layer.
89
+ """
90
+ super().__init__()
91
+ norm_fn = lambda d: nn.Identity() # noqa
92
+ if norm:
93
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
94
+ if pad:
95
+ pad = kernel_size // 4
96
+ else:
97
+ pad = 0
98
+ klass = nn.Conv1d
99
+ self.freq = freq
100
+ self.kernel_size = kernel_size
101
+ self.stride = stride
102
+ self.empty = empty
103
+ self.norm = norm
104
+ self.pad = pad
105
+ if freq:
106
+ kernel_size = [kernel_size, 1]
107
+ stride = [stride, 1]
108
+ pad = [pad, 0]
109
+ klass = nn.Conv2d
110
+ self.conv = klass(chin, chout, kernel_size, stride, pad)
111
+ if self.empty:
112
+ return
113
+ self.norm1 = norm_fn(chout)
114
+ self.rewrite = None
115
+ if rewrite:
116
+ self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
117
+ self.norm2 = norm_fn(2 * chout)
118
+
119
+ self.dconv = None
120
+ if dconv:
121
+ self.dconv = DConv(chout, **dconv_kw)
122
+
123
+ def forward(self, x, inject=None):
124
+ """
125
+ `inject` is used to inject the result from the time branch into the frequency branch,
126
+ when both have the same stride.
127
+ """
128
+ if not self.freq and x.dim() == 4:
129
+ B, C, Fr, T = x.shape
130
+ x = x.view(B, -1, T)
131
+
132
+ if not self.freq:
133
+ le = x.shape[-1]
134
+ if not le % self.stride == 0:
135
+ x = F.pad(x, (0, self.stride - (le % self.stride)))
136
+ y = self.conv(x)
137
+ if self.empty:
138
+ return y
139
+ if inject is not None:
140
+ assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
141
+ if inject.dim() == 3 and y.dim() == 4:
142
+ inject = inject[:, :, None]
143
+ y = y + inject
144
+ y = F.gelu(self.norm1(y))
145
+ if self.dconv:
146
+ if self.freq:
147
+ B, C, Fr, T = y.shape
148
+ y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
149
+ y = self.dconv(y)
150
+ if self.freq:
151
+ y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
152
+ if self.rewrite:
153
+ z = self.norm2(self.rewrite(y))
154
+ z = F.glu(z, dim=1)
155
+ else:
156
+ z = y
157
+ return z
158
+
159
+
160
+ class MultiWrap(nn.Module):
161
+ """
162
+     Takes one layer and replicates it N times. Each replica will act
163
+     on a frequency band. All is done so that if the N replicas have the same weights,
164
+ then this is exactly equivalent to applying the original module on all frequencies.
165
+
166
+ This is a bit over-engineered to avoid edge artifacts when splitting
167
+ the frequency bands, but it is possible the naive implementation would work as well...
168
+ """
169
+ def __init__(self, layer, split_ratios):
170
+ """
171
+ Args:
172
+ layer: module to clone, must be either HEncLayer or HDecLayer.
173
+ split_ratios: list of float indicating which ratio to keep for each band.
174
+ """
175
+ super().__init__()
176
+ self.split_ratios = split_ratios
177
+ self.layers = nn.ModuleList()
178
+ self.conv = isinstance(layer, HEncLayer)
179
+ assert not layer.norm
180
+ assert layer.freq
181
+ assert layer.pad
182
+ if not self.conv:
183
+ assert not layer.context_freq
184
+ for k in range(len(split_ratios) + 1):
185
+ lay = deepcopy(layer)
186
+ if self.conv:
187
+ lay.conv.padding = (0, 0)
188
+ else:
189
+ lay.pad = False
190
+ for m in lay.modules():
191
+ if hasattr(m, 'reset_parameters'):
192
+ m.reset_parameters()
193
+ self.layers.append(lay)
194
+
195
+ def forward(self, x, skip=None, length=None):
196
+ B, C, Fr, T = x.shape
197
+
198
+ ratios = list(self.split_ratios) + [1]
199
+ start = 0
200
+ outs = []
201
+ for ratio, layer in zip(ratios, self.layers):
202
+ if self.conv:
203
+ pad = layer.kernel_size // 4
204
+ if ratio == 1:
205
+ limit = Fr
206
+ frames = -1
207
+ else:
208
+ limit = int(round(Fr * ratio))
209
+ le = limit - start
210
+ if start == 0:
211
+ le += pad
212
+ frames = round((le - layer.kernel_size) / layer.stride + 1)
213
+ limit = start + (frames - 1) * layer.stride + layer.kernel_size
214
+ if start == 0:
215
+ limit -= pad
216
+ assert limit - start > 0, (limit, start)
217
+ assert limit <= Fr, (limit, Fr)
218
+ y = x[:, :, start:limit, :]
219
+ if start == 0:
220
+ y = F.pad(y, (0, 0, pad, 0))
221
+ if ratio == 1:
222
+ y = F.pad(y, (0, 0, 0, pad))
223
+ outs.append(layer(y))
224
+ start = limit - layer.kernel_size + layer.stride
225
+ else:
226
+ if ratio == 1:
227
+ limit = Fr
228
+ else:
229
+ limit = int(round(Fr * ratio))
230
+ last = layer.last
231
+ layer.last = True
232
+
233
+ y = x[:, :, start:limit]
234
+ s = skip[:, :, start:limit]
235
+ out, _ = layer(y, s, None)
236
+ if outs:
237
+ outs[-1][:, :, -layer.stride:] += (
238
+ out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1))
239
+ out = out[:, :, layer.stride:]
240
+ if ratio == 1:
241
+ out = out[:, :, :-layer.stride // 2, :]
242
+ if start == 0:
243
+ out = out[:, :, layer.stride // 2:, :]
244
+ outs.append(out)
245
+ layer.last = last
246
+ start = limit
247
+ out = torch.cat(outs, dim=2)
248
+ if not self.conv and not last:
249
+ out = F.gelu(out)
250
+ if self.conv:
251
+ return out
252
+ else:
253
+ return out, None
254
+
255
+
256
+ class HDecLayer(nn.Module):
257
+ def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,
258
+ freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True,
259
+ context_freq=True, rewrite=True):
260
+ """
261
+ Same as HEncLayer but for decoder. See `HEncLayer` for documentation.
262
+ """
263
+ super().__init__()
264
+ norm_fn = lambda d: nn.Identity() # noqa
265
+ if norm:
266
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
267
+ if pad:
268
+ pad = kernel_size // 4
269
+ else:
270
+ pad = 0
271
+ self.pad = pad
272
+ self.last = last
273
+ self.freq = freq
274
+ self.chin = chin
275
+ self.empty = empty
276
+ self.stride = stride
277
+ self.kernel_size = kernel_size
278
+ self.norm = norm
279
+ self.context_freq = context_freq
280
+ klass = nn.Conv1d
281
+ klass_tr = nn.ConvTranspose1d
282
+ if freq:
283
+ kernel_size = [kernel_size, 1]
284
+ stride = [stride, 1]
285
+ klass = nn.Conv2d
286
+ klass_tr = nn.ConvTranspose2d
287
+ self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
288
+ self.norm2 = norm_fn(chout)
289
+ if self.empty:
290
+ return
291
+ self.rewrite = None
292
+ if rewrite:
293
+ if context_freq:
294
+ self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
295
+ else:
296
+ self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,
297
+ [0, context])
298
+ self.norm1 = norm_fn(2 * chin)
299
+
300
+ self.dconv = None
301
+ if dconv:
302
+ self.dconv = DConv(chin, **dconv_kw)
303
+
304
+ def forward(self, x, skip, length):
305
+ if self.freq and x.dim() == 3:
306
+ B, C, T = x.shape
307
+ x = x.view(B, self.chin, -1, T)
308
+
309
+ if not self.empty:
310
+ x = x + skip
311
+
312
+ if self.rewrite:
313
+ y = F.glu(self.norm1(self.rewrite(x)), dim=1)
314
+ else:
315
+ y = x
316
+ if self.dconv:
317
+ if self.freq:
318
+ B, C, Fr, T = y.shape
319
+ y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
320
+ y = self.dconv(y)
321
+ if self.freq:
322
+ y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
323
+ else:
324
+ y = x
325
+ assert skip is None
326
+ z = self.norm2(self.conv_tr(y))
327
+ if self.freq:
328
+ if self.pad:
329
+ z = z[..., self.pad:-self.pad, :]
330
+ else:
331
+ z = z[..., self.pad:self.pad + length]
332
+ assert z.shape[-1] == length, (z.shape[-1], length)
333
+ if not self.last:
334
+ z = F.gelu(z)
335
+ return z, y
336
+
337
+
338
+ class HDemucs(nn.Module):
339
+ """
340
+ Spectrogram and hybrid Demucs model.
341
+ The spectrogram model has the same structure as Demucs, except the first few layers are over the
342
+ frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
343
+ Frequency layers can still access information across time steps thanks to the DConv residual.
344
+
345
+     Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
346
+ as the frequency branch and then the two are combined. The opposite happens in the decoder.
347
+
348
+     Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
349
+ or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
350
+ Open Unmix implementation [Stoter et al. 2019].
351
+
352
+ The loss is always on the temporal domain, by backpropagating through the above
353
+     output methods and iSTFT. This allows hybrid models to be defined nicely. However, this slightly breaks
354
+     Wiener filtering, as doing more iterations at test time will change the spectrogram
355
+ contribution, without changing the one from the waveform, which will lead to worse performance.
356
+ I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
357
+ CaC on the other hand provides similar performance for hybrid, and works naturally with
358
+ hybrid models.
359
+
360
+     This model also uses frequency embeddings to improve the efficiency of convolutions
361
+ over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
362
+
363
+ Unlike classic Demucs, there is no resampling here, and normalization is always applied.
364
+ """
365
+ @capture_init
366
+ def __init__(self,
367
+ sources,
368
+ # Channels
369
+ audio_channels=2,
370
+ channels=48,
371
+ channels_time=None,
372
+ growth=2,
373
+ # STFT
374
+ nfft=4096,
375
+ wiener_iters=0,
376
+ end_iters=0,
377
+ wiener_residual=False,
378
+ cac=True,
379
+ # Main structure
380
+ depth=6,
381
+ rewrite=True,
382
+ hybrid=True,
383
+ hybrid_old=False,
384
+ # Frequency branch
385
+ multi_freqs=None,
386
+ multi_freqs_depth=2,
387
+ freq_emb=0.2,
388
+ emb_scale=10,
389
+ emb_smooth=True,
390
+ # Convolutions
391
+ kernel_size=8,
392
+ time_stride=2,
393
+ stride=4,
394
+ context=1,
395
+ context_enc=0,
396
+ # Normalization
397
+ norm_starts=4,
398
+ norm_groups=4,
399
+ # DConv residual branch
400
+ dconv_mode=1,
401
+ dconv_depth=2,
402
+ dconv_comp=4,
403
+ dconv_attn=4,
404
+ dconv_lstm=4,
405
+ dconv_init=1e-4,
406
+ # Weight init
407
+ rescale=0.1,
408
+ # Metadata
409
+ samplerate=44100,
410
+ segment=4 * 10):
411
+ """
412
+ Args:
413
+ sources (list[str]): list of source names.
414
+ audio_channels (int): input/output audio channels.
415
+ channels (int): initial number of hidden channels.
416
+ channels_time: if not None, use a different `channels` value for the time branch.
417
+ growth: increase the number of hidden channels by this factor at each layer.
418
+ nfft: number of fft bins. Note that changing this require careful computation of
419
+ various shape parameters and will not work out of the box for hybrid models.
420
+ wiener_iters: when using Wiener filtering, number of iterations at test time.
421
+ end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
422
+ wiener_residual: add residual source before wiener filtering.
423
+ cac: uses complex as channels, i.e. complex numbers are 2 channels each
424
+ in input and output. no further processing is done before ISTFT.
425
+ depth (int): number of layers in the encoder and in the decoder.
426
+ rewrite (bool): add 1x1 convolution to each layer.
427
+ hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only.
428
+ hybrid_old: some models trained for MDX had a padding bug. This replicates
429
+ this bug to avoid retraining them.
430
+ multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
431
+ multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
432
+ layers will be wrapped.
433
+ freq_emb: add frequency embedding after the first frequency layer if > 0,
434
+ the actual value controls the weight of the embedding.
435
+ emb_scale: equivalent to scaling the embedding learning rate
436
+ emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
437
+ kernel_size: kernel_size for encoder and decoder layers.
438
+ stride: stride for encoder and decoder layers.
439
+ time_stride: stride for the final time layer, after the merge.
440
+ context: context for 1x1 conv in the decoder.
441
+ context_enc: context for 1x1 conv in the encoder.
442
+ norm_starts: layer at which group norm starts being used.
443
+ decoder layers are numbered in reverse order.
444
+ norm_groups: number of groups for group norm.
445
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
446
+ dconv_depth: depth of residual DConv branch.
447
+ dconv_comp: compression of DConv branch.
448
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
449
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
450
+ dconv_init: initial scale for the DConv branch LayerScale.
451
+             rescale: weight rescaling trick
452
+
453
+ """
454
+ super().__init__()
455
+ self.cac = cac
456
+ self.wiener_residual = wiener_residual
457
+ self.audio_channels = audio_channels
458
+ self.sources = sources
459
+ self.kernel_size = kernel_size
460
+ self.context = context
461
+ self.stride = stride
462
+ self.depth = depth
463
+ self.channels = channels
464
+ self.samplerate = samplerate
465
+ self.segment = segment
466
+
467
+ self.nfft = nfft
468
+ self.hop_length = nfft // 4
469
+ self.wiener_iters = wiener_iters
470
+ self.end_iters = end_iters
471
+ self.freq_emb = None
472
+ self.hybrid = hybrid
473
+ self.hybrid_old = hybrid_old
474
+ if hybrid_old:
475
+ assert hybrid, "hybrid_old must come with hybrid=True"
476
+ if hybrid:
477
+ assert wiener_iters == end_iters
478
+
479
+ self.encoder = nn.ModuleList()
480
+ self.decoder = nn.ModuleList()
481
+
482
+ if hybrid:
483
+ self.tencoder = nn.ModuleList()
484
+ self.tdecoder = nn.ModuleList()
485
+
486
+ chin = audio_channels
487
+ chin_z = chin # number of channels for the freq branch
488
+ if self.cac:
489
+ chin_z *= 2
490
+ chout = channels_time or channels
491
+ chout_z = channels
492
+ freqs = nfft // 2
493
+
494
+ for index in range(depth):
495
+ lstm = index >= dconv_lstm
496
+ attn = index >= dconv_attn
497
+ norm = index >= norm_starts
498
+ freq = freqs > 1
499
+ stri = stride
500
+ ker = kernel_size
501
+ if not freq:
502
+ assert freqs == 1
503
+ ker = time_stride * 2
504
+ stri = time_stride
505
+
506
+ pad = True
507
+ last_freq = False
508
+ if freq and freqs <= kernel_size:
509
+ ker = freqs
510
+ pad = False
511
+ last_freq = True
512
+
513
+ kw = {
514
+ 'kernel_size': ker,
515
+ 'stride': stri,
516
+ 'freq': freq,
517
+ 'pad': pad,
518
+ 'norm': norm,
519
+ 'rewrite': rewrite,
520
+ 'norm_groups': norm_groups,
521
+ 'dconv_kw': {
522
+ 'lstm': lstm,
523
+ 'attn': attn,
524
+ 'depth': dconv_depth,
525
+ 'compress': dconv_comp,
526
+ 'init': dconv_init,
527
+ 'gelu': True,
528
+ }
529
+ }
530
+ kwt = dict(kw)
531
+ kwt['freq'] = 0
532
+ kwt['kernel_size'] = kernel_size
533
+ kwt['stride'] = stride
534
+ kwt['pad'] = True
535
+ kw_dec = dict(kw)
536
+ multi = False
537
+ if multi_freqs and index < multi_freqs_depth:
538
+ multi = True
539
+ kw_dec['context_freq'] = False
540
+
541
+ if last_freq:
542
+ chout_z = max(chout, chout_z)
543
+ chout = chout_z
544
+
545
+ enc = HEncLayer(chin_z, chout_z,
546
+ dconv=dconv_mode & 1, context=context_enc, **kw)
547
+ if hybrid and freq:
548
+ tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc,
549
+ empty=last_freq, **kwt)
550
+ self.tencoder.append(tenc)
551
+
552
+ if multi:
553
+ enc = MultiWrap(enc, multi_freqs)
554
+ self.encoder.append(enc)
555
+ if index == 0:
556
+ chin = self.audio_channels * len(self.sources)
557
+ chin_z = chin
558
+ if self.cac:
559
+ chin_z *= 2
560
+ dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2,
561
+ last=index == 0, context=context, **kw_dec)
562
+ if multi:
563
+ dec = MultiWrap(dec, multi_freqs)
564
+ if hybrid and freq:
565
+ tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq,
566
+ last=index == 0, context=context, **kwt)
567
+ self.tdecoder.insert(0, tdec)
568
+ self.decoder.insert(0, dec)
569
+
570
+ chin = chout
571
+ chin_z = chout_z
572
+ chout = int(growth * chout)
573
+ chout_z = int(growth * chout_z)
574
+ if freq:
575
+ if freqs <= kernel_size:
576
+ freqs = 1
577
+ else:
578
+ freqs //= stride
579
+ if index == 0 and freq_emb:
580
+ self.freq_emb = ScaledEmbedding(
581
+ freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
582
+ self.freq_emb_scale = freq_emb
583
+
584
+ if rescale:
585
+ rescale_module(self, reference=rescale)
586
+
587
+ def _spec(self, x):
588
+ hl = self.hop_length
589
+ nfft = self.nfft
590
+ x0 = x # noqa
591
+
592
+ if self.hybrid:
593
+ # We re-pad the signal in order to keep the property
594
+ # that the size of the output is exactly the size of the input
595
+ # divided by the stride (here hop_length), when divisible.
596
+ # This is achieved by padding by 1/4th of the kernel size (here nfft).
597
+ # which is not supported by torch.stft.
598
+ # Having all convolution operations follow this convention allow to easily
599
+ # align the time and frequency branches later on.
600
+ assert hl == nfft // 4
601
+ le = int(math.ceil(x.shape[-1] / hl))
602
+ pad = hl // 2 * 3
603
+ if not self.hybrid_old:
604
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode='reflect')
605
+ else:
606
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]))
607
+
608
+ z = spectro(x, nfft, hl)[..., :-1, :]
609
+ if self.hybrid:
610
+ assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
611
+ z = z[..., 2:2+le]
612
+ return z
613
+
614
+ def _ispec(self, z, length=None, scale=0):
615
+ hl = self.hop_length // (4 ** scale)
616
+ z = F.pad(z, (0, 0, 0, 1))
617
+ if self.hybrid:
618
+ z = F.pad(z, (2, 2))
619
+ pad = hl // 2 * 3
620
+ if not self.hybrid_old:
621
+ le = hl * int(math.ceil(length / hl)) + 2 * pad
622
+ else:
623
+ le = hl * int(math.ceil(length / hl))
624
+ x = ispectro(z, hl, length=le)
625
+ if not self.hybrid_old:
626
+ x = x[..., pad:pad + length]
627
+ else:
628
+ x = x[..., :length]
629
+ else:
630
+ x = ispectro(z, hl, length)
631
+ return x
632
+
633
+ def _magnitude(self, z):
634
+ # return the magnitude of the spectrogram, except when cac is True,
635
+ # in which case we just move the complex dimension to the channel one.
636
+ if self.cac:
637
+ B, C, Fr, T = z.shape
638
+ m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
639
+ m = m.reshape(B, C * 2, Fr, T)
640
+ else:
641
+ m = z.abs()
642
+ return m
643
+
644
+ def _mask(self, z, m):
645
+ # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
646
+ # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
647
+ niters = self.wiener_iters
648
+ if self.cac:
649
+ B, S, C, Fr, T = m.shape
650
+ out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
651
+ out = torch.view_as_complex(out.contiguous())
652
+ return out
653
+ if self.training:
654
+ niters = self.end_iters
655
+ if niters < 0:
656
+ z = z[:, None]
657
+ return z / (1e-8 + z.abs()) * m
658
+ else:
659
+ return self._wiener(m, z, niters)
660
+
661
+ def _wiener(self, mag_out, mix_stft, niters):
662
+ # apply wiener filtering from OpenUnmix.
663
+ init = mix_stft.dtype
664
+ wiener_win_len = 300
665
+ residual = self.wiener_residual
666
+
667
+ B, S, C, Fq, T = mag_out.shape
668
+ mag_out = mag_out.permute(0, 4, 3, 2, 1)
669
+ mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
670
+
671
+ outs = []
672
+ for sample in range(B):
673
+ pos = 0
674
+ out = []
675
+ for pos in range(0, T, wiener_win_len):
676
+ frame = slice(pos, pos + wiener_win_len)
677
+ z_out = wiener(
678
+ mag_out[sample, frame], mix_stft[sample, frame], niters,
679
+ residual=residual)
680
+ out.append(z_out.transpose(-1, -2))
681
+ outs.append(torch.cat(out, dim=0))
682
+ out = torch.view_as_complex(torch.stack(outs, 0))
683
+ out = out.permute(0, 4, 3, 2, 1).contiguous()
684
+ if residual:
685
+ out = out[:, :-1]
686
+ assert list(out.shape) == [B, S, C, Fq, T]
687
+ return out.to(init)
688
+
689
+ def forward(self, mix):
690
+ x = mix
691
+ length = x.shape[-1]
692
+
693
+ z = self._spec(mix)
694
+ mag = self._magnitude(z)
695
+ x = mag
696
+
697
+ B, C, Fq, T = x.shape
698
+
699
+ # unlike previous Demucs, we always normalize because it is easier.
700
+ mean = x.mean(dim=(1, 2, 3), keepdim=True)
701
+ std = x.std(dim=(1, 2, 3), keepdim=True)
702
+ x = (x - mean) / (1e-5 + std)
703
+ # x will be the freq. branch input.
704
+
705
+ if self.hybrid:
706
+ # Prepare the time branch input.
707
+ xt = mix
708
+ meant = xt.mean(dim=(1, 2), keepdim=True)
709
+ stdt = xt.std(dim=(1, 2), keepdim=True)
710
+ xt = (xt - meant) / (1e-5 + stdt)
711
+
712
+ # okay, this is a giant mess I know...
713
+ saved = [] # skip connections, freq.
714
+ saved_t = [] # skip connections, time.
715
+ lengths = [] # saved lengths to properly remove padding, freq branch.
716
+ lengths_t = [] # saved lengths for time branch.
717
+ for idx, encode in enumerate(self.encoder):
718
+ lengths.append(x.shape[-1])
719
+ inject = None
720
+ if self.hybrid and idx < len(self.tencoder):
721
+ # we have not yet merged branches.
722
+ lengths_t.append(xt.shape[-1])
723
+ tenc = self.tencoder[idx]
724
+ xt = tenc(xt)
725
+ if not tenc.empty:
726
+ # save for skip connection
727
+ saved_t.append(xt)
728
+ else:
729
+ # tenc contains just the first conv., so that now time and freq.
730
+ # branches have the same shape and can be merged.
731
+ inject = xt
732
+ x = encode(x, inject)
733
+ if idx == 0 and self.freq_emb is not None:
734
+ # add frequency embedding to allow for non equivariant convolutions
735
+ # over the frequency axis.
736
+ frs = torch.arange(x.shape[-2], device=x.device)
737
+ emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
738
+ x = x + self.freq_emb_scale * emb
739
+
740
+ saved.append(x)
741
+
742
+ x = torch.zeros_like(x)
743
+ if self.hybrid:
744
+ xt = torch.zeros_like(x)
745
+ # initialize everything to zero (signal will go through u-net skips).
746
+
747
+ for idx, decode in enumerate(self.decoder):
748
+ skip = saved.pop(-1)
749
+ x, pre = decode(x, skip, lengths.pop(-1))
750
+ # `pre` contains the output just before final transposed convolution,
751
+ # which is used when the freq. and time branch separate.
752
+
753
+ if self.hybrid:
754
+ offset = self.depth - len(self.tdecoder)
755
+ if self.hybrid and idx >= offset:
756
+ tdec = self.tdecoder[idx - offset]
757
+ length_t = lengths_t.pop(-1)
758
+ if tdec.empty:
759
+ assert pre.shape[2] == 1, pre.shape
760
+ pre = pre[:, :, 0]
761
+ xt, _ = tdec(pre, None, length_t)
762
+ else:
763
+ skip = saved_t.pop(-1)
764
+ xt, _ = tdec(xt, skip, length_t)
765
+
766
+ # Let's make sure we used all stored skip connections.
767
+ assert len(saved) == 0
768
+ assert len(lengths_t) == 0
769
+ assert len(saved_t) == 0
770
+
771
+ S = len(self.sources)
772
+ x = x.view(B, S, -1, Fq, T)
773
+ x = x * std[:, None] + mean[:, None]
774
+
775
+ zout = self._mask(z, x)
776
+ x = self._ispec(zout, length)
777
+
778
+ if self.hybrid:
779
+ xt = xt.view(B, S, -1, length)
780
+ xt = xt * stdt[:, None] + meant[:, None]
781
+ x = xt + x
782
+ return x
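The `cac` (complex-as-channels) trick used by `_magnitude` and `_mask` above is a lossless reshape between a complex spectrogram and real channels. A hedged, self-contained sketch of that round trip on toy shapes (single source):

```
import torch

B, C, Fr, T, S = 1, 2, 4, 5, 1
z = torch.randn(B, C, Fr, T, dtype=torch.complex64)

# _magnitude with cac=True: (B, C, F, T) complex -> (B, 2*C, F, T) real
m = torch.view_as_real(z).permute(0, 1, 4, 2, 3).reshape(B, C * 2, Fr, T)

# _mask with cac=True: (B, S, 2*C, F, T) real -> (B, S, C, F, T) complex
m = m.view(B, S, C * 2, Fr, T)
out = torch.view_as_complex(
    m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3).contiguous())

print(torch.equal(torch.view_as_real(out[:, 0]), torch.view_as_real(z)))  # True
```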
demucs4/htdemucs.py ADDED
@@ -0,0 +1,648 @@
1
+ # Copyright (c) Meta, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # First author is Simon Rouard.
7
+ """
8
+ This code contains the spectrogram and Hybrid version of Demucs.
9
+ """
10
+ import math
11
+
12
+ from openunmix.filtering import wiener
13
+ import torch
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+ from fractions import Fraction
17
+ from einops import rearrange
18
+
19
+ from .transformer import CrossTransformerEncoder
20
+
21
+ from .demucs import rescale_module
22
+ from .states import capture_init
23
+ from .spec import spectro, ispectro
24
+ from .hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer
25
+
26
+
27
+ class HTDemucs(nn.Module):
28
+ """
29
+ Spectrogram and hybrid Demucs model.
30
+ The spectrogram model has the same structure as Demucs, except the first few layers are over the
31
+ frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
32
+ Frequency layers can still access information across time steps thanks to the DConv residual.
33
+
34
+     Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
35
+ as the frequency branch and then the two are combined. The opposite happens in the decoder.
36
+
37
+     Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
38
+ or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
39
+ Open Unmix implementation [Stoter et al. 2019].
40
+
41
+ The loss is always on the temporal domain, by backpropagating through the above
42
+     output methods and iSTFT. This allows hybrid models to be defined nicely. However, this slightly breaks
43
+     Wiener filtering, as doing more iterations at test time will change the spectrogram
44
+ contribution, without changing the one from the waveform, which will lead to worse performance.
45
+ I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
46
+ CaC on the other hand provides similar performance for hybrid, and works naturally with
47
+ hybrid models.
48
+
49
+     This model also uses frequency embeddings to improve the efficiency of convolutions
50
+ over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
51
+
52
+ Unlike classic Demucs, there is no resampling here, and normalization is always applied.
53
+ """
54
+
55
+ @capture_init
56
+ def __init__(
57
+ self,
58
+ sources,
59
+ # Channels
60
+ audio_channels=2,
61
+ channels=48,
62
+ channels_time=None,
63
+ growth=2,
64
+ # STFT
65
+ nfft=4096,
66
+ wiener_iters=0,
67
+ end_iters=0,
68
+ wiener_residual=False,
69
+ cac=True,
70
+ # Main structure
71
+ depth=4,
72
+ rewrite=True,
73
+ # Frequency branch
74
+ multi_freqs=None,
75
+ multi_freqs_depth=3,
76
+ freq_emb=0.2,
77
+ emb_scale=10,
78
+ emb_smooth=True,
79
+ # Convolutions
80
+ kernel_size=8,
81
+ time_stride=2,
82
+ stride=4,
83
+ context=1,
84
+ context_enc=0,
85
+ # Normalization
86
+ norm_starts=4,
87
+ norm_groups=4,
88
+ # DConv residual branch
89
+ dconv_mode=1,
90
+ dconv_depth=2,
91
+ dconv_comp=8,
92
+ dconv_init=1e-3,
93
+ # Before the Transformer
94
+ bottom_channels=0,
95
+ # Transformer
96
+ t_layers=5,
97
+ t_emb="sin",
98
+ t_hidden_scale=4.0,
99
+ t_heads=8,
100
+ t_dropout=0.0,
101
+ t_max_positions=10000,
102
+ t_norm_in=True,
103
+ t_norm_in_group=False,
104
+ t_group_norm=False,
105
+ t_norm_first=True,
106
+ t_norm_out=True,
107
+ t_max_period=10000.0,
108
+ t_weight_decay=0.0,
109
+ t_lr=None,
110
+ t_layer_scale=True,
111
+ t_gelu=True,
112
+ t_weight_pos_embed=1.0,
113
+ t_sin_random_shift=0,
114
+ t_cape_mean_normalize=True,
115
+ t_cape_augment=True,
116
+ t_cape_glob_loc_scale=[5000.0, 1.0, 1.4],
117
+ t_sparse_self_attn=False,
118
+ t_sparse_cross_attn=False,
119
+ t_mask_type="diag",
120
+ t_mask_random_seed=42,
121
+ t_sparse_attn_window=500,
122
+ t_global_window=100,
123
+ t_sparsity=0.95,
124
+ t_auto_sparsity=False,
125
+ # ------ Particuliar parameters
126
+ t_cross_first=False,
127
+ # Weight init
128
+ rescale=0.1,
129
+ # Metadata
130
+ samplerate=44100,
131
+ segment=10,
132
+ use_train_segment=True,
133
+ ):
134
+ """
135
+ Args:
136
+ sources (list[str]): list of source names.
137
+ audio_channels (int): input/output audio channels.
138
+ channels (int): initial number of hidden channels.
139
+ channels_time: if not None, use a different `channels` value for the time branch.
140
+ growth: increase the number of hidden channels by this factor at each layer.
141
+ nfft: number of fft bins. Note that changing this require careful computation of
142
+ various shape parameters and will not work out of the box for hybrid models.
143
+ wiener_iters: when using Wiener filtering, number of iterations at test time.
144
+ end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
145
+ wiener_residual: add residual source before wiener filtering.
146
+ cac: uses complex as channels, i.e. complex numbers are 2 channels each
147
+ in input and output. no further processing is done before ISTFT.
148
+ depth (int): number of layers in the encoder and in the decoder.
149
+ rewrite (bool): add 1x1 convolution to each layer.
150
+ multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
151
+ multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
152
+ layers will be wrapped.
153
+ freq_emb: add frequency embedding after the first frequency layer if > 0,
154
+ the actual value controls the weight of the embedding.
155
+ emb_scale: equivalent to scaling the embedding learning rate
156
+ emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
157
+ kernel_size: kernel_size for encoder and decoder layers.
158
+ stride: stride for encoder and decoder layers.
159
+ time_stride: stride for the final time layer, after the merge.
160
+ context: context for 1x1 conv in the decoder.
161
+ context_enc: context for 1x1 conv in the encoder.
162
+ norm_starts: layer at which group norm starts being used.
163
+ decoder layers are numbered in reverse order.
164
+ norm_groups: number of groups for group norm.
165
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
166
+ dconv_depth: depth of residual DConv branch.
167
+ dconv_comp: compression of DConv branch.
168
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
169
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
170
+ dconv_init: initial scale for the DConv branch LayerScale.
171
+ bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the
172
+ transformer in order to change the number of channels
173
+ t_layers: number of layers in each branch (waveform and spec) of the transformer
174
+ t_emb: "sin", "cape" or "scaled"
175
+ t_hidden_scale: the hidden scale of the Feedforward parts of the transformer
176
+ for instance if C = 384 (the number of channels in the transformer) and
177
+ t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
178
+ 384 * 4 = 1536
179
+ t_heads: number of heads for the transformer
180
+ t_dropout: dropout in the transformer
181
+ t_max_positions: max_positions for the "scaled" positional embedding, only
182
+ useful if t_emb="scaled"
183
+ t_norm_in: (bool) norm before adding positional embedding and getting into the
184
+ transformer layers
185
+ t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the
186
+ timesteps (GroupNorm with group=1)
187
+ t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the
188
+ timesteps (GroupNorm with group=1)
189
+ t_norm_first: (bool) if True the norm is before the attention and before the FFN
190
+ t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
191
+ t_max_period: (float) denominator in the sinusoidal embedding expression
192
+ t_weight_decay: (float) weight decay for the transformer
193
+ t_lr: (float) specific learning rate for the transformer
194
+ t_layer_scale: (bool) Layer Scale for the transformer
195
+ t_gelu: (bool) activations of the transformer are GeLU if True, ReLU otherwise
196
+ t_weight_pos_embed: (float) weighting of the positional embedding
197
+ t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings
198
+ see: https://arxiv.org/abs/2106.03143
199
+ t_cape_augment: (bool) if t_emb="cape", must be True during training and False
200
+ during the inference, see: https://arxiv.org/abs/2106.03143
201
+ t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
202
+ see: https://arxiv.org/abs/2106.03143
203
+ t_sparse_self_attn: (bool) if True, the self attentions are sparse
204
+ t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
205
+ unless you designed really specific masks)
206
+ t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
207
+ with '_' between: i.e. "diag_jmask_random" (note that this is permutation
208
+ invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
209
+ t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
210
+ that generated the random part of the mask
211
+ t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and
212
+ a key (j), the mask is True if |i-j|<=t_sparse_attn_window
213
+ t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
214
+ and mask[:, :t_global_window] will be True
215
+ t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
216
+ level of the random part of the mask.
217
+ t_cross_first: (bool) if True cross attention is the first layer of the
218
+ transformer (False seems to be better)
219
+ rescale: weight rescaling trick
220
+ use_train_segment: (bool) if True, the actual size that is used during the
221
+ training is used during inference.
222
+ """
223
+ super().__init__()
224
+ self.cac = cac
225
+ self.wiener_residual = wiener_residual
226
+ self.audio_channels = audio_channels
227
+ self.sources = sources
228
+ self.kernel_size = kernel_size
229
+ self.context = context
230
+ self.stride = stride
231
+ self.depth = depth
232
+ self.bottom_channels = bottom_channels
233
+ self.channels = channels
234
+ self.samplerate = samplerate
235
+ self.segment = segment
236
+ self.use_train_segment = use_train_segment
237
+ self.nfft = nfft
238
+ self.hop_length = nfft // 4
239
+ self.wiener_iters = wiener_iters
240
+ self.end_iters = end_iters
241
+ self.freq_emb = None
242
+ assert wiener_iters == end_iters
243
+
244
+ self.encoder = nn.ModuleList()
245
+ self.decoder = nn.ModuleList()
246
+
247
+ self.tencoder = nn.ModuleList()
248
+ self.tdecoder = nn.ModuleList()
249
+
250
+ chin = audio_channels
251
+ chin_z = chin # number of channels for the freq branch
252
+ if self.cac:
253
+ chin_z *= 2
254
+ chout = channels_time or channels
255
+ chout_z = channels
256
+ freqs = nfft // 2
257
+
258
+ for index in range(depth):
259
+ norm = index >= norm_starts
260
+ freq = freqs > 1
261
+ stri = stride
262
+ ker = kernel_size
263
+ if not freq:
264
+ assert freqs == 1
265
+ ker = time_stride * 2
266
+ stri = time_stride
267
+
268
+ pad = True
269
+ last_freq = False
270
+ if freq and freqs <= kernel_size:
271
+ ker = freqs
272
+ pad = False
273
+ last_freq = True
274
+
275
+ kw = {
276
+ "kernel_size": ker,
277
+ "stride": stri,
278
+ "freq": freq,
279
+ "pad": pad,
280
+ "norm": norm,
281
+ "rewrite": rewrite,
282
+ "norm_groups": norm_groups,
283
+ "dconv_kw": {
284
+ "depth": dconv_depth,
285
+ "compress": dconv_comp,
286
+ "init": dconv_init,
287
+ "gelu": True,
288
+ },
289
+ }
290
+ kwt = dict(kw)
291
+ kwt["freq"] = 0
292
+ kwt["kernel_size"] = kernel_size
293
+ kwt["stride"] = stride
294
+ kwt["pad"] = True
295
+ kw_dec = dict(kw)
296
+ multi = False
297
+ if multi_freqs and index < multi_freqs_depth:
298
+ multi = True
299
+ kw_dec["context_freq"] = False
300
+
301
+ if last_freq:
302
+ chout_z = max(chout, chout_z)
303
+ chout = chout_z
304
+
305
+ enc = HEncLayer(
306
+ chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw
307
+ )
308
+ if freq:
309
+ tenc = HEncLayer(
310
+ chin,
311
+ chout,
312
+ dconv=dconv_mode & 1,
313
+ context=context_enc,
314
+ empty=last_freq,
315
+ **kwt
316
+ )
317
+ self.tencoder.append(tenc)
318
+
319
+ if multi:
320
+ enc = MultiWrap(enc, multi_freqs)
321
+ self.encoder.append(enc)
322
+ if index == 0:
323
+ chin = self.audio_channels * len(self.sources)
324
+ chin_z = chin
325
+ if self.cac:
326
+ chin_z *= 2
327
+ dec = HDecLayer(
328
+ chout_z,
329
+ chin_z,
330
+ dconv=dconv_mode & 2,
331
+ last=index == 0,
332
+ context=context,
333
+ **kw_dec
334
+ )
335
+ if multi:
336
+ dec = MultiWrap(dec, multi_freqs)
337
+ if freq:
338
+ tdec = HDecLayer(
339
+ chout,
340
+ chin,
341
+ dconv=dconv_mode & 2,
342
+ empty=last_freq,
343
+ last=index == 0,
344
+ context=context,
345
+ **kwt
346
+ )
347
+ self.tdecoder.insert(0, tdec)
348
+ self.decoder.insert(0, dec)
349
+
350
+ chin = chout
351
+ chin_z = chout_z
352
+ chout = int(growth * chout)
353
+ chout_z = int(growth * chout_z)
354
+ if freq:
355
+ if freqs <= kernel_size:
356
+ freqs = 1
357
+ else:
358
+ freqs //= stride
359
+ if index == 0 and freq_emb:
360
+ self.freq_emb = ScaledEmbedding(
361
+ freqs, chin_z, smooth=emb_smooth, scale=emb_scale
362
+ )
363
+ self.freq_emb_scale = freq_emb
364
+
365
+ if rescale:
366
+ rescale_module(self, reference=rescale)
367
+
368
+ transformer_channels = channels * growth ** (depth - 1)
369
+ if bottom_channels:
370
+ self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
371
+ self.channel_downsampler = nn.Conv1d(
372
+ bottom_channels, transformer_channels, 1
373
+ )
374
+ self.channel_upsampler_t = nn.Conv1d(
375
+ transformer_channels, bottom_channels, 1
376
+ )
377
+ self.channel_downsampler_t = nn.Conv1d(
378
+ bottom_channels, transformer_channels, 1
379
+ )
380
+
381
+ transformer_channels = bottom_channels
382
+
383
+ if t_layers > 0:
384
+ self.crosstransformer = CrossTransformerEncoder(
385
+ dim=transformer_channels,
386
+ emb=t_emb,
387
+ hidden_scale=t_hidden_scale,
388
+ num_heads=t_heads,
389
+ num_layers=t_layers,
390
+ cross_first=t_cross_first,
391
+ dropout=t_dropout,
392
+ max_positions=t_max_positions,
393
+ norm_in=t_norm_in,
394
+ norm_in_group=t_norm_in_group,
395
+ group_norm=t_group_norm,
396
+ norm_first=t_norm_first,
397
+ norm_out=t_norm_out,
398
+ max_period=t_max_period,
399
+ weight_decay=t_weight_decay,
400
+ lr=t_lr,
401
+ layer_scale=t_layer_scale,
402
+ gelu=t_gelu,
403
+ sin_random_shift=t_sin_random_shift,
404
+ weight_pos_embed=t_weight_pos_embed,
405
+ cape_mean_normalize=t_cape_mean_normalize,
406
+ cape_augment=t_cape_augment,
407
+ cape_glob_loc_scale=t_cape_glob_loc_scale,
408
+ sparse_self_attn=t_sparse_self_attn,
409
+ sparse_cross_attn=t_sparse_cross_attn,
410
+ mask_type=t_mask_type,
411
+ mask_random_seed=t_mask_random_seed,
412
+ sparse_attn_window=t_sparse_attn_window,
413
+ global_window=t_global_window,
414
+ sparsity=t_sparsity,
415
+ auto_sparsity=t_auto_sparsity,
416
+ )
417
+ else:
418
+ self.crosstransformer = None
419
+
420
+ def _spec(self, x):
421
+ hl = self.hop_length
422
+ nfft = self.nfft
423
+ x0 = x # noqa
424
+
425
+ # We re-pad the signal in order to keep the property
426
+ # that the size of the output is exactly the size of the input
427
+ # divided by the stride (here hop_length), when divisible.
428
+ # This is achieved by padding by 1/4th of the kernel size (here nfft),
429
+ # which is not supported by torch.stft.
430
+ # Having all convolution operations follow this convention allows us to easily
431
+ # align the time and frequency branches later on.
432
+ assert hl == nfft // 4
433
+ le = int(math.ceil(x.shape[-1] / hl))
434
+ pad = hl // 2 * 3
435
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")
436
+
437
+ z = spectro(x, nfft, hl)[..., :-1, :]
438
+ assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
439
+ z = z[..., 2: 2 + le]
440
+ return z
441
+
442
+ def _ispec(self, z, length=None, scale=0):
443
+ hl = self.hop_length // (4**scale)
444
+ z = F.pad(z, (0, 0, 0, 1))
445
+ z = F.pad(z, (2, 2))
446
+ pad = hl // 2 * 3
447
+ le = hl * int(math.ceil(length / hl)) + 2 * pad
448
+ x = ispectro(z, hl, length=le)
449
+ x = x[..., pad: pad + length]
450
+ return x
451
+
452
+ def _magnitude(self, z):
453
+ # return the magnitude of the spectrogram, except when cac is True,
454
+ # in which case we just move the complex dimension to the channel one.
455
+ if self.cac:
456
+ B, C, Fr, T = z.shape
457
+ m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
458
+ m = m.reshape(B, C * 2, Fr, T)
459
+ else:
460
+ m = z.abs()
461
+ return m
462
+
463
+ def _mask(self, z, m):
464
+ # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
465
+ # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
466
+ niters = self.wiener_iters
467
+ if self.cac:
468
+ B, S, C, Fr, T = m.shape
469
+ out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
470
+ out = torch.view_as_complex(out.contiguous())
471
+ return out
472
+ if self.training:
473
+ niters = self.end_iters
474
+ if niters < 0:
475
+ z = z[:, None]
476
+ return z / (1e-8 + z.abs()) * m
477
+ else:
478
+ return self._wiener(m, z, niters)
479
+
480
+ def _wiener(self, mag_out, mix_stft, niters):
481
+ # apply wiener filtering from OpenUnmix.
482
+ init = mix_stft.dtype
483
+ wiener_win_len = 300
484
+ residual = self.wiener_residual
485
+
486
+ B, S, C, Fq, T = mag_out.shape
487
+ mag_out = mag_out.permute(0, 4, 3, 2, 1)
488
+ mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
489
+
490
+ outs = []
491
+ for sample in range(B):
492
+ pos = 0
493
+ out = []
494
+ for pos in range(0, T, wiener_win_len):
495
+ frame = slice(pos, pos + wiener_win_len)
496
+ z_out = wiener(
497
+ mag_out[sample, frame],
498
+ mix_stft[sample, frame],
499
+ niters,
500
+ residual=residual,
501
+ )
502
+ out.append(z_out.transpose(-1, -2))
503
+ outs.append(torch.cat(out, dim=0))
504
+ out = torch.view_as_complex(torch.stack(outs, 0))
505
+ out = out.permute(0, 4, 3, 2, 1).contiguous()
506
+ if residual:
507
+ out = out[:, :-1]
508
+ assert list(out.shape) == [B, S, C, Fq, T]
509
+ return out.to(init)
510
+
511
+ def valid_length(self, length: int):
512
+ """
513
+ Return a length that is appropriate for evaluation.
514
+ In our case, always return the training length, unless
515
+ it is smaller than the given length, in which case this
516
+ raises an error.
517
+ """
518
+ if not self.use_train_segment:
519
+ return length
520
+ training_length = int(self.segment * self.samplerate)
521
+ if training_length < length:
522
+ raise ValueError(
523
+ f"Given length {length} is longer than "
524
+ f"training length {training_length}")
525
+ return training_length
526
+
527
+ def forward(self, mix):
528
+ length = mix.shape[-1]
529
+ length_pre_pad = None
530
+ if self.use_train_segment:
531
+ if self.training:
532
+ self.segment = Fraction(mix.shape[-1], self.samplerate)
533
+ else:
534
+ training_length = int(self.segment * self.samplerate)
535
+ if mix.shape[-1] < training_length:
536
+ length_pre_pad = mix.shape[-1]
537
+ mix = F.pad(mix, (0, training_length - length_pre_pad))
538
+ z = self._spec(mix)
539
+ mag = self._magnitude(z)
540
+ x = mag
541
+
542
+ B, C, Fq, T = x.shape
543
+
544
+ # unlike previous Demucs, we always normalize because it is easier.
545
+ mean = x.mean(dim=(1, 2, 3), keepdim=True)
546
+ std = x.std(dim=(1, 2, 3), keepdim=True)
547
+ x = (x - mean) / (1e-5 + std)
548
+ # x will be the freq. branch input.
549
+
550
+ # Prepare the time branch input.
551
+ xt = mix
552
+ meant = xt.mean(dim=(1, 2), keepdim=True)
553
+ stdt = xt.std(dim=(1, 2), keepdim=True)
554
+ xt = (xt - meant) / (1e-5 + stdt)
555
+
556
+ # okay, this is a giant mess I know...
557
+ saved = [] # skip connections, freq.
558
+ saved_t = [] # skip connections, time.
559
+ lengths = [] # saved lengths to properly remove padding, freq branch.
560
+ lengths_t = [] # saved lengths for time branch.
561
+ for idx, encode in enumerate(self.encoder):
562
+ lengths.append(x.shape[-1])
563
+ inject = None
564
+ if idx < len(self.tencoder):
565
+ # we have not yet merged branches.
566
+ lengths_t.append(xt.shape[-1])
567
+ tenc = self.tencoder[idx]
568
+ xt = tenc(xt)
569
+ if not tenc.empty:
570
+ # save for skip connection
571
+ saved_t.append(xt)
572
+ else:
573
+ # tenc contains just the first conv., so that now time and freq.
574
+ # branches have the same shape and can be merged.
575
+ inject = xt
576
+ x = encode(x, inject)
577
+ if idx == 0 and self.freq_emb is not None:
578
+ # add frequency embedding to allow for non equivariant convolutions
579
+ # over the frequency axis.
580
+ frs = torch.arange(x.shape[-2], device=x.device)
581
+ emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
582
+ x = x + self.freq_emb_scale * emb
583
+
584
+ saved.append(x)
585
+ if self.crosstransformer:
586
+ if self.bottom_channels:
587
+ b, c, f, t = x.shape
588
+ x = rearrange(x, "b c f t-> b c (f t)")
589
+ x = self.channel_upsampler(x)
590
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
591
+ xt = self.channel_upsampler_t(xt)
592
+
593
+ x, xt = self.crosstransformer(x, xt)
594
+
595
+ if self.bottom_channels:
596
+ x = rearrange(x, "b c f t-> b c (f t)")
597
+ x = self.channel_downsampler(x)
598
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
599
+ xt = self.channel_downsampler_t(xt)
600
+
601
+ for idx, decode in enumerate(self.decoder):
602
+ skip = saved.pop(-1)
603
+ x, pre = decode(x, skip, lengths.pop(-1))
604
+ # `pre` contains the output just before final transposed convolution,
605
+ # which is used when the freq. and time branch separate.
606
+
607
+ offset = self.depth - len(self.tdecoder)
608
+ if idx >= offset:
609
+ tdec = self.tdecoder[idx - offset]
610
+ length_t = lengths_t.pop(-1)
611
+ if tdec.empty:
612
+ assert pre.shape[2] == 1, pre.shape
613
+ pre = pre[:, :, 0]
614
+ xt, _ = tdec(pre, None, length_t)
615
+ else:
616
+ skip = saved_t.pop(-1)
617
+ xt, _ = tdec(xt, skip, length_t)
618
+
619
+ # Let's make sure we used all stored skip connections.
620
+ assert len(saved) == 0
621
+ assert len(lengths_t) == 0
622
+ assert len(saved_t) == 0
623
+
624
+ S = len(self.sources)
625
+ x = x.view(B, S, -1, Fq, T)
626
+ x = x * std[:, None] + mean[:, None]
627
+
628
+ zout = self._mask(z, x)
629
+ if self.use_train_segment:
630
+ if self.training:
631
+ x = self._ispec(zout, length)
632
+ else:
633
+ x = self._ispec(zout, training_length)
634
+ else:
635
+ x = self._ispec(zout, length)
636
+
637
+ if self.use_train_segment:
638
+ if self.training:
639
+ xt = xt.view(B, S, -1, length)
640
+ else:
641
+ xt = xt.view(B, S, -1, training_length)
642
+ else:
643
+ xt = xt.view(B, S, -1, length)
644
+ xt = xt * stdt[:, None] + meant[:, None]
645
+ x = xt + x
646
+ if length_pre_pad:
647
+ x = x[..., :length_pre_pad]
648
+ return x
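A minimal usage sketch for the hybrid transformer model defined above. The class name `HTDemucs` and module path `demucs4.htdemucs` are assumed from the commit layout, and the model here has random weights; this is not part of the committed code.

```
import torch
from demucs4.htdemucs import HTDemucs  # assumed module path for the file above

# Randomly initialised 4-stem model with the default configuration.
model = HTDemucs(sources=["bass", "drums", "other", "vocals"])
model.eval()

# 7 s of stereo audio at 44.1 kHz: shorter than the 10 s training segment,
# so forward() pads internally and trims the output back to the input length.
mix = torch.randn(1, 2, 7 * 44100)
with torch.no_grad():
    out = model(mix)
print(out.shape)  # expected: (batch, n_sources, channels, samples) = (1, 4, 2, 7 * 44100)
```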
demucs4/spec.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Meta, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Conveniance wrapper to perform STFT and iSTFT"""
7
+
8
+ import torch as th
9
+
10
+
11
+ def spectro(x, n_fft=512, hop_length=None, pad=0):
12
+ *other, length = x.shape
13
+ x = x.reshape(-1, length)
14
+ z = th.stft(x,
15
+ n_fft * (1 + pad),
16
+ hop_length or n_fft // 4,
17
+ window=th.hann_window(n_fft).to(x),
18
+ win_length=n_fft,
19
+ normalized=True,
20
+ center=True,
21
+ return_complex=True,
22
+ pad_mode='reflect')
23
+ _, freqs, frame = z.shape
24
+ return z.view(*other, freqs, frame)
25
+
26
+
27
+ def ispectro(z, hop_length=None, length=None, pad=0):
28
+ *other, freqs, frames = z.shape
29
+ n_fft = 2 * freqs - 2
30
+ z = z.view(-1, freqs, frames)
31
+ win_length = n_fft // (1 + pad)
32
+ x = th.istft(z,
33
+ n_fft,
34
+ hop_length,
35
+ window=th.hann_window(win_length).to(z.real),
36
+ win_length=win_length,
37
+ normalized=True,
38
+ length=length,
39
+ center=True)
40
+ _, length = x.shape
41
+ return x.view(*other, length)
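A quick round-trip check of the two helpers above (the module path `demucs4.spec` is assumed from this commit's layout):

```
import torch
from demucs4.spec import spectro, ispectro

x = torch.randn(2, 44100)                 # 1 s of stereo audio
z = spectro(x, n_fft=4096)                # complex STFT, shape (2, 2049, 44)
y = ispectro(z, hop_length=1024, length=x.shape[-1])
print(z.shape, y.shape)
print((x - y).abs().max())                # near-zero reconstruction error
```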
demucs4/states.py ADDED
@@ -0,0 +1,148 @@
1
+ # Copyright (c) Meta, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Utilities to save and load models.
8
+ """
9
+ from contextlib import contextmanager
10
+
11
+ import functools
12
+ import hashlib
13
+ import inspect
14
+ import io
15
+ from pathlib import Path
16
+ import warnings
17
+
18
+ from omegaconf import OmegaConf
19
+ from diffq import DiffQuantizer, UniformQuantizer, restore_quantized_state
20
+ import torch
21
+
22
+
23
+ def get_quantizer(model, args, optimizer=None):
24
+ """Return the quantizer given the XP quantization args."""
25
+ quantizer = None
26
+ if args.diffq:
27
+ quantizer = DiffQuantizer(
28
+ model, min_size=args.min_size, group_size=args.group_size)
29
+ if optimizer is not None:
30
+ quantizer.setup_optimizer(optimizer)
31
+ elif args.qat:
32
+ quantizer = UniformQuantizer(
33
+ model, bits=args.qat, min_size=args.min_size)
34
+ return quantizer
35
+
36
+
37
+ def load_model(path_or_package, strict=False):
38
+ """Load a model from the given serialized model, either given as a dict (already loaded)
39
+ or a path to a file on disk."""
40
+ if isinstance(path_or_package, dict):
41
+ package = path_or_package
42
+ elif isinstance(path_or_package, (str, Path)):
43
+ with warnings.catch_warnings():
44
+ warnings.simplefilter("ignore")
45
+ path = path_or_package
46
+ package = torch.load(path, 'cpu')
47
+ else:
48
+ raise ValueError(f"Invalid type for {path_or_package}.")
49
+
50
+ klass = package["klass"]
51
+ args = package["args"]
52
+ kwargs = package["kwargs"]
53
+
54
+ if strict:
55
+ model = klass(*args, **kwargs)
56
+ else:
57
+ sig = inspect.signature(klass)
58
+ for key in list(kwargs):
59
+ if key not in sig.parameters:
60
+ warnings.warn("Dropping inexistant parameter " + key)
61
+ del kwargs[key]
62
+ model = klass(*args, **kwargs)
63
+
64
+ state = package["state"]
65
+
66
+ set_state(model, state)
67
+ return model
68
+
69
+
70
+ def get_state(model, quantizer, half=False):
71
+ """Get the state from a model, potentially with quantization applied.
72
+ If `half` is True, the model state is stored in half precision, which shouldn't impact performance
73
+ but halves the state size."""
74
+ if quantizer is None:
75
+ dtype = torch.half if half else None
76
+ state = {k: p.data.to(device='cpu', dtype=dtype) for k, p in model.state_dict().items()}
77
+ else:
78
+ state = quantizer.get_quantized_state()
79
+ state['__quantized'] = True
80
+ return state
81
+
82
+
83
+ def set_state(model, state, quantizer=None):
84
+ """Set the state on a given model."""
85
+ if state.get('__quantized'):
86
+ if quantizer is not None:
87
+ quantizer.restore_quantized_state(model, state['quantized'])
88
+ else:
89
+ restore_quantized_state(model, state)
90
+ else:
91
+ model.load_state_dict(state)
92
+ return state
93
+
94
+
95
+ def save_with_checksum(content, path):
96
+ """Save the given value on disk, along with a sha256 hash.
97
+ Should be used with the output of either `serialize_model` or `get_state`."""
98
+ buf = io.BytesIO()
99
+ torch.save(content, buf)
100
+ sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]
101
+
102
+ path = path.parent / (path.stem + "-" + sig + path.suffix)
103
+ path.write_bytes(buf.getvalue())
104
+
105
+
106
+ def serialize_model(model, training_args, quantizer=None, half=True):
107
+ args, kwargs = model._init_args_kwargs
108
+ klass = model.__class__
109
+
110
+ state = get_state(model, quantizer, half)
111
+ return {
112
+ 'klass': klass,
113
+ 'args': args,
114
+ 'kwargs': kwargs,
115
+ 'state': state,
116
+ 'training_args': OmegaConf.to_container(training_args, resolve=True),
117
+ }
118
+
119
+
120
+ def copy_state(state):
121
+ return {k: v.cpu().clone() for k, v in state.items()}
122
+
123
+
124
+ @contextmanager
125
+ def swap_state(model, state):
126
+ """
127
+ Context manager that swaps the state of a model, e.g:
128
+
129
+ # model is in old state
130
+ with swap_state(model, new_state):
131
+ # model in new state
132
+ # model back to old state
133
+ """
134
+ old_state = copy_state(model.state_dict())
135
+ model.load_state_dict(state, strict=False)
136
+ try:
137
+ yield
138
+ finally:
139
+ model.load_state_dict(old_state)
140
+
141
+
142
+ def capture_init(init):
143
+ @functools.wraps(init)
144
+ def __init__(self, *args, **kwargs):
145
+ self._init_args_kwargs = (args, kwargs)
146
+ init(self, *args, **kwargs)
147
+
148
+ return __init__
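A hedged sketch of how `capture_init`, `get_state` and `swap_state` above fit together; the tiny `Demo` module is purely illustrative and not part of the repository.

```
import torch
import torch.nn as nn
from demucs4.states import capture_init, get_state, set_state, copy_state, swap_state

class Demo(nn.Module):
    @capture_init                      # records (args, kwargs) for later serialization
    def __init__(self, hidden=16):
        super().__init__()
        self.linear = nn.Linear(hidden, hidden)

model = Demo(hidden=16)
print(model._init_args_kwargs)         # ((), {'hidden': 16}) as recorded by capture_init

state = get_state(model, quantizer=None)          # plain CPU state dict
zeros = copy_state({k: torch.zeros_like(v) for k, v in state.items()})
with swap_state(model, zeros):
    pass                               # model temporarily holds zeroed weights
set_state(model, state)                # original parameters restored
```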
demucs4/transformer.py ADDED
@@ -0,0 +1,839 @@
1
+ # Copyright (c) 2019-present, Meta, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # First author is Simon Rouard.
7
+
8
+ import random
9
+ import typing as tp
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import numpy as np
15
+ import math
16
+ from einops import rearrange
17
+
18
+
19
+ def create_sin_embedding(
20
+ length: int, dim: int, shift: int = 0, device="cpu", max_period=10000
21
+ ):
22
+ # We aim for TBC format
23
+ assert dim % 2 == 0
24
+ pos = shift + torch.arange(length, device=device).view(-1, 1, 1)
25
+ half_dim = dim // 2
26
+ adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
27
+ phase = pos / (max_period ** (adim / (half_dim - 1)))
28
+ return torch.cat(
29
+ [
30
+ torch.cos(phase),
31
+ torch.sin(phase),
32
+ ],
33
+ dim=-1,
34
+ )
35
+
36
+
37
+ def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
38
+ """
39
+ :param d_model: dimension of the model
40
+ :param height: height of the positions
41
+ :param width: width of the positions
42
+ :return: d_model*height*width position matrix
43
+ """
44
+ if d_model % 4 != 0:
45
+ raise ValueError(
46
+ "Cannot use sin/cos positional encoding with "
47
+ "odd dimension (got dim={:d})".format(d_model)
48
+ )
49
+ pe = torch.zeros(d_model, height, width)
50
+ # Each dimension use half of d_model
51
+ d_model = int(d_model / 2)
52
+ div_term = torch.exp(
53
+ torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model)
54
+ )
55
+ pos_w = torch.arange(0.0, width).unsqueeze(1)
56
+ pos_h = torch.arange(0.0, height).unsqueeze(1)
57
+ pe[0:d_model:2, :, :] = (
58
+ torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
59
+ )
60
+ pe[1:d_model:2, :, :] = (
61
+ torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
62
+ )
63
+ pe[d_model::2, :, :] = (
64
+ torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
65
+ )
66
+ pe[d_model + 1:: 2, :, :] = (
67
+ torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
68
+ )
69
+
70
+ return pe[None, :].to(device)
71
+
72
+
73
+ def create_sin_embedding_cape(
74
+ length: int,
75
+ dim: int,
76
+ batch_size: int,
77
+ mean_normalize: bool,
78
+ augment: bool, # True during training
79
+ max_global_shift: float = 0.0, # delta max
80
+ max_local_shift: float = 0.0, # epsilon max
81
+ max_scale: float = 1.0,
82
+ device: str = "cpu",
83
+ max_period: float = 10000.0,
84
+ ):
85
+ # We aim for TBC format
86
+ assert dim % 2 == 0
87
+ pos = 1.0 * torch.arange(length).view(-1, 1, 1) # (length, 1, 1)
88
+ pos = pos.repeat(1, batch_size, 1) # (length, batch_size, 1)
89
+ if mean_normalize:
90
+ pos -= torch.nanmean(pos, dim=0, keepdim=True)
91
+
92
+ if augment:
93
+ delta = np.random.uniform(
94
+ -max_global_shift, +max_global_shift, size=[1, batch_size, 1]
95
+ )
96
+ delta_local = np.random.uniform(
97
+ -max_local_shift, +max_local_shift, size=[length, batch_size, 1]
98
+ )
99
+ log_lambdas = np.random.uniform(
100
+ -np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1]
101
+ )
102
+ pos = (pos + delta + delta_local) * np.exp(log_lambdas)
103
+
104
+ pos = pos.to(device)
105
+
106
+ half_dim = dim // 2
107
+ adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
108
+ phase = pos / (max_period ** (adim / (half_dim - 1)))
109
+ return torch.cat(
110
+ [
111
+ torch.cos(phase),
112
+ torch.sin(phase),
113
+ ],
114
+ dim=-1,
115
+ ).float()
116
+
117
+
118
+ def get_causal_mask(length):
119
+ pos = torch.arange(length)
120
+ return pos > pos[:, None]
121
+
122
+
123
+ def get_elementary_mask(
124
+ T1,
125
+ T2,
126
+ mask_type,
127
+ sparse_attn_window,
128
+ global_window,
129
+ mask_random_seed,
130
+ sparsity,
131
+ device,
132
+ ):
133
+ """
134
+ When the input of the Decoder has length T1 and the output T2
135
+ The mask matrix has shape (T2, T1)
136
+ """
137
+ assert mask_type in ["diag", "jmask", "random", "global"]
138
+
139
+ if mask_type == "global":
140
+ mask = torch.zeros(T2, T1, dtype=torch.bool)
141
+ mask[:, :global_window] = True
142
+ line_window = int(global_window * T2 / T1)
143
+ mask[:line_window, :] = True
144
+
145
+ if mask_type == "diag":
146
+
147
+ mask = torch.zeros(T2, T1, dtype=torch.bool)
148
+ rows = torch.arange(T2)[:, None]
149
+ cols = (
150
+ (T1 / T2 * rows + torch.arange(-sparse_attn_window, sparse_attn_window + 1))
151
+ .long()
152
+ .clamp(0, T1 - 1)
153
+ )
154
+ mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
155
+
156
+ elif mask_type == "jmask":
157
+ mask = torch.zeros(T2 + 2, T1 + 2, dtype=torch.bool)
158
+ rows = torch.arange(T2 + 2)[:, None]
159
+ t = torch.arange(0, int((2 * T1) ** 0.5 + 1))
160
+ t = (t * (t + 1) / 2).int()
161
+ t = torch.cat([-t.flip(0)[:-1], t])
162
+ cols = (T1 / T2 * rows + t).long().clamp(0, T1 + 1)
163
+ mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
164
+ mask = mask[1:-1, 1:-1]
165
+
166
+ elif mask_type == "random":
167
+ gene = torch.Generator(device=device)
168
+ gene.manual_seed(mask_random_seed)
169
+ mask = (
170
+ torch.rand(T1 * T2, generator=gene, device=device).reshape(T2, T1)
171
+ > sparsity
172
+ )
173
+
174
+ mask = mask.to(device)
175
+ return mask
176
+
177
+
178
+ def get_mask(
179
+ T1,
180
+ T2,
181
+ mask_type,
182
+ sparse_attn_window,
183
+ global_window,
184
+ mask_random_seed,
185
+ sparsity,
186
+ device,
187
+ ):
188
+ """
189
+ Return a SparseCSRTensor mask that is a combination of elementary masks
190
+ mask_type can be a combination of multiple masks: for instance "diag_jmask_random"
191
+ """
192
+ from xformers.sparse import SparseCSRTensor
193
+ # create a list
194
+ mask_types = mask_type.split("_")
195
+
196
+ all_masks = [
197
+ get_elementary_mask(
198
+ T1,
199
+ T2,
200
+ mask,
201
+ sparse_attn_window,
202
+ global_window,
203
+ mask_random_seed,
204
+ sparsity,
205
+ device,
206
+ )
207
+ for mask in mask_types
208
+ ]
209
+
210
+ final_mask = torch.stack(all_masks).sum(axis=0) > 0
211
+
212
+ return SparseCSRTensor.from_dense(final_mask[None])
213
+
214
+
215
+ class ScaledEmbedding(nn.Module):
216
+ def __init__(
217
+ self,
218
+ num_embeddings: int,
219
+ embedding_dim: int,
220
+ scale: float = 1.0,
221
+ boost: float = 3.0,
222
+ ):
223
+ super().__init__()
224
+ self.embedding = nn.Embedding(num_embeddings, embedding_dim)
225
+ self.embedding.weight.data *= scale / boost
226
+ self.boost = boost
227
+
228
+ @property
229
+ def weight(self):
230
+ return self.embedding.weight * self.boost
231
+
232
+ def forward(self, x):
233
+ return self.embedding(x) * self.boost
234
+
235
+
236
+ class LayerScale(nn.Module):
237
+ """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
238
+ This rescales residual outputs diagonally, initialized close to 0 and then learnt.
239
+ """
240
+
241
+ def __init__(self, channels: int, init: float = 0, channel_last=False):
242
+ """
243
+ channel_last = False corresponds to (B, C, T) tensors
244
+ channel_last = True corresponds to (T, B, C) tensors
245
+ """
246
+ super().__init__()
247
+ self.channel_last = channel_last
248
+ self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
249
+ self.scale.data[:] = init
250
+
251
+ def forward(self, x):
252
+ if self.channel_last:
253
+ return self.scale * x
254
+ else:
255
+ return self.scale[:, None] * x
256
+
257
+
258
+ class MyGroupNorm(nn.GroupNorm):
259
+ def __init__(self, *args, **kwargs):
260
+ super().__init__(*args, **kwargs)
261
+
262
+ def forward(self, x):
263
+ """
264
+ x: (B, T, C)
265
+ if num_groups=1: Normalisation on all T and C together for each B
266
+ """
267
+ x = x.transpose(1, 2)
268
+ return super().forward(x).transpose(1, 2)
269
+
270
+
271
+ class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
272
+ def __init__(
273
+ self,
274
+ d_model,
275
+ nhead,
276
+ dim_feedforward=2048,
277
+ dropout=0.1,
278
+ activation=F.relu,
279
+ group_norm=0,
280
+ norm_first=False,
281
+ norm_out=False,
282
+ layer_norm_eps=1e-5,
283
+ layer_scale=False,
284
+ init_values=1e-4,
285
+ device=None,
286
+ dtype=None,
287
+ sparse=False,
288
+ mask_type="diag",
289
+ mask_random_seed=42,
290
+ sparse_attn_window=500,
291
+ global_window=50,
292
+ auto_sparsity=False,
293
+ sparsity=0.95,
294
+ batch_first=False,
295
+ ):
296
+ factory_kwargs = {"device": device, "dtype": dtype}
297
+ super().__init__(
298
+ d_model=d_model,
299
+ nhead=nhead,
300
+ dim_feedforward=dim_feedforward,
301
+ dropout=dropout,
302
+ activation=activation,
303
+ layer_norm_eps=layer_norm_eps,
304
+ batch_first=batch_first,
305
+ norm_first=norm_first,
306
+ device=device,
307
+ dtype=dtype,
308
+ )
309
+ self.sparse = sparse
310
+ self.auto_sparsity = auto_sparsity
311
+ if sparse:
312
+ if not auto_sparsity:
313
+ self.mask_type = mask_type
314
+ self.sparse_attn_window = sparse_attn_window
315
+ self.global_window = global_window
316
+ self.sparsity = sparsity
317
+ if group_norm:
318
+ self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
319
+ self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
320
+
321
+ self.norm_out = None
322
+ if self.norm_first & norm_out:
323
+ self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
324
+ self.gamma_1 = (
325
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
326
+ )
327
+ self.gamma_2 = (
328
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
329
+ )
330
+
331
+ if sparse:
332
+ self.self_attn = MultiheadAttention(
333
+ d_model, nhead, dropout=dropout, batch_first=batch_first,
334
+ auto_sparsity=sparsity if auto_sparsity else 0,
335
+ )
336
+ self.__setattr__("src_mask", torch.zeros(1, 1))
337
+ self.mask_random_seed = mask_random_seed
338
+
339
+ def forward(self, src, src_mask=None, src_key_padding_mask=None):
340
+ """
341
+ if batch_first = False, src shape is (T, B, C)
342
+ the case where batch_first=True is not covered
343
+ """
344
+ device = src.device
345
+ x = src
346
+ T, B, C = x.shape
347
+ if self.sparse and not self.auto_sparsity:
348
+ assert src_mask is None
349
+ src_mask = self.src_mask
350
+ if src_mask.shape[-1] != T:
351
+ src_mask = get_mask(
352
+ T,
353
+ T,
354
+ self.mask_type,
355
+ self.sparse_attn_window,
356
+ self.global_window,
357
+ self.mask_random_seed,
358
+ self.sparsity,
359
+ device,
360
+ )
361
+ self.__setattr__("src_mask", src_mask)
362
+
363
+ if self.norm_first:
364
+ x = x + self.gamma_1(
365
+ self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
366
+ )
367
+ x = x + self.gamma_2(self._ff_block(self.norm2(x)))
368
+
369
+ if self.norm_out:
370
+ x = self.norm_out(x)
371
+ else:
372
+ x = self.norm1(
373
+ x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask))
374
+ )
375
+ x = self.norm2(x + self.gamma_2(self._ff_block(x)))
376
+
377
+ return x
378
+
379
+
380
+ class CrossTransformerEncoderLayer(nn.Module):
381
+ def __init__(
382
+ self,
383
+ d_model: int,
384
+ nhead: int,
385
+ dim_feedforward: int = 2048,
386
+ dropout: float = 0.1,
387
+ activation=F.relu,
388
+ layer_norm_eps: float = 1e-5,
389
+ layer_scale: bool = False,
390
+ init_values: float = 1e-4,
391
+ norm_first: bool = False,
392
+ group_norm: bool = False,
393
+ norm_out: bool = False,
394
+ sparse=False,
395
+ mask_type="diag",
396
+ mask_random_seed=42,
397
+ sparse_attn_window=500,
398
+ global_window=50,
399
+ sparsity=0.95,
400
+ auto_sparsity=None,
401
+ device=None,
402
+ dtype=None,
403
+ batch_first=False,
404
+ ):
405
+ factory_kwargs = {"device": device, "dtype": dtype}
406
+ super().__init__()
407
+
408
+ self.sparse = sparse
409
+ self.auto_sparsity = auto_sparsity
410
+ if sparse:
411
+ if not auto_sparsity:
412
+ self.mask_type = mask_type
413
+ self.sparse_attn_window = sparse_attn_window
414
+ self.global_window = global_window
415
+ self.sparsity = sparsity
416
+
417
+ self.cross_attn: nn.Module
418
+ self.cross_attn = nn.MultiheadAttention(
419
+ d_model, nhead, dropout=dropout, batch_first=batch_first)
420
+ # Implementation of Feedforward model
421
+ self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
422
+ self.dropout = nn.Dropout(dropout)
423
+ self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
424
+
425
+ self.norm_first = norm_first
426
+ self.norm1: nn.Module
427
+ self.norm2: nn.Module
428
+ self.norm3: nn.Module
429
+ if group_norm:
430
+ self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
431
+ self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
432
+ self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
433
+ else:
434
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
435
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
436
+ self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
437
+
438
+ self.norm_out = None
439
+ if self.norm_first & norm_out:
440
+ self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
441
+
442
+ self.gamma_1 = (
443
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
444
+ )
445
+ self.gamma_2 = (
446
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
447
+ )
448
+
449
+ self.dropout1 = nn.Dropout(dropout)
450
+ self.dropout2 = nn.Dropout(dropout)
451
+
452
+ # Legacy string support for activation function.
453
+ if isinstance(activation, str):
454
+ self.activation = self._get_activation_fn(activation)
455
+ else:
456
+ self.activation = activation
457
+
458
+ if sparse:
459
+ self.cross_attn = MultiheadAttention(
460
+ d_model, nhead, dropout=dropout, batch_first=batch_first,
461
+ auto_sparsity=sparsity if auto_sparsity else 0)
462
+ if not auto_sparsity:
463
+ self.__setattr__("mask", torch.zeros(1, 1))
464
+ self.mask_random_seed = mask_random_seed
465
+
466
+ def forward(self, q, k, mask=None):
467
+ """
468
+ Args:
469
+ q: tensor of shape (T, B, C)
470
+ k: tensor of shape (S, B, C)
471
+ mask: tensor of shape (T, S)
472
+
473
+ """
474
+ device = q.device
475
+ T, B, C = q.shape
476
+ S, B, C = k.shape
477
+ if self.sparse and not self.auto_sparsity:
478
+ assert mask is None
479
+ mask = self.mask
480
+ if mask.shape[-1] != S or mask.shape[-2] != T:
481
+ mask = get_mask(
482
+ S,
483
+ T,
484
+ self.mask_type,
485
+ self.sparse_attn_window,
486
+ self.global_window,
487
+ self.mask_random_seed,
488
+ self.sparsity,
489
+ device,
490
+ )
491
+ self.__setattr__("mask", mask)
492
+
493
+ if self.norm_first:
494
+ x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
495
+ x = x + self.gamma_2(self._ff_block(self.norm3(x)))
496
+ if self.norm_out:
497
+ x = self.norm_out(x)
498
+ else:
499
+ x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
500
+ x = self.norm2(x + self.gamma_2(self._ff_block(x)))
501
+
502
+ return x
503
+
504
+ # cross-attention block
505
+ def _ca_block(self, q, k, attn_mask=None):
506
+ x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
507
+ return self.dropout1(x)
508
+
509
+ # feed forward block
510
+ def _ff_block(self, x):
511
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
512
+ return self.dropout2(x)
513
+
514
+ def _get_activation_fn(self, activation):
515
+ if activation == "relu":
516
+ return F.relu
517
+ elif activation == "gelu":
518
+ return F.gelu
519
+
520
+ raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
521
+
522
+
523
+ # ----------------- MULTI-BLOCKS MODELS: -----------------------
524
+
525
+
526
+ class CrossTransformerEncoder(nn.Module):
527
+ def __init__(
528
+ self,
529
+ dim: int,
530
+ emb: str = "sin",
531
+ hidden_scale: float = 4.0,
532
+ num_heads: int = 8,
533
+ num_layers: int = 6,
534
+ cross_first: bool = False,
535
+ dropout: float = 0.0,
536
+ max_positions: int = 1000,
537
+ norm_in: bool = True,
538
+ norm_in_group: bool = False,
539
+ group_norm: int = False,
540
+ norm_first: bool = False,
541
+ norm_out: bool = False,
542
+ max_period: float = 10000.0,
543
+ weight_decay: float = 0.0,
544
+ lr: tp.Optional[float] = None,
545
+ layer_scale: bool = False,
546
+ gelu: bool = True,
547
+ sin_random_shift: int = 0,
548
+ weight_pos_embed: float = 1.0,
549
+ cape_mean_normalize: bool = True,
550
+ cape_augment: bool = True,
551
+ cape_glob_loc_scale: list = [5000.0, 1.0, 1.4],
552
+ sparse_self_attn: bool = False,
553
+ sparse_cross_attn: bool = False,
554
+ mask_type: str = "diag",
555
+ mask_random_seed: int = 42,
556
+ sparse_attn_window: int = 500,
557
+ global_window: int = 50,
558
+ auto_sparsity: bool = False,
559
+ sparsity: float = 0.95,
560
+ ):
561
+ super().__init__()
562
+ """
563
+ """
564
+ assert dim % num_heads == 0
565
+
566
+ hidden_dim = int(dim * hidden_scale)
567
+
568
+ self.num_layers = num_layers
569
+ # classic parity = 1 means that if idx%2 == 1 there is a
570
+ # classical encoder, otherwise a cross encoder
571
+ self.classic_parity = 1 if cross_first else 0
572
+ self.emb = emb
573
+ self.max_period = max_period
574
+ self.weight_decay = weight_decay
575
+ self.weight_pos_embed = weight_pos_embed
576
+ self.sin_random_shift = sin_random_shift
577
+ if emb == "cape":
578
+ self.cape_mean_normalize = cape_mean_normalize
579
+ self.cape_augment = cape_augment
580
+ self.cape_glob_loc_scale = cape_glob_loc_scale
581
+ if emb == "scaled":
582
+ self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)
583
+
584
+ self.lr = lr
585
+
586
+ activation: tp.Any = F.gelu if gelu else F.relu
587
+
588
+ self.norm_in: nn.Module
589
+ self.norm_in_t: nn.Module
590
+ if norm_in:
591
+ self.norm_in = nn.LayerNorm(dim)
592
+ self.norm_in_t = nn.LayerNorm(dim)
593
+ elif norm_in_group:
594
+ self.norm_in = MyGroupNorm(int(norm_in_group), dim)
595
+ self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
596
+ else:
597
+ self.norm_in = nn.Identity()
598
+ self.norm_in_t = nn.Identity()
599
+
600
+ # spectrogram layers
601
+ self.layers = nn.ModuleList()
602
+ # temporal layers
603
+ self.layers_t = nn.ModuleList()
604
+
605
+ kwargs_common = {
606
+ "d_model": dim,
607
+ "nhead": num_heads,
608
+ "dim_feedforward": hidden_dim,
609
+ "dropout": dropout,
610
+ "activation": activation,
611
+ "group_norm": group_norm,
612
+ "norm_first": norm_first,
613
+ "norm_out": norm_out,
614
+ "layer_scale": layer_scale,
615
+ "mask_type": mask_type,
616
+ "mask_random_seed": mask_random_seed,
617
+ "sparse_attn_window": sparse_attn_window,
618
+ "global_window": global_window,
619
+ "sparsity": sparsity,
620
+ "auto_sparsity": auto_sparsity,
621
+ "batch_first": True,
622
+ }
623
+
624
+ kwargs_classic_encoder = dict(kwargs_common)
625
+ kwargs_classic_encoder.update({
626
+ "sparse": sparse_self_attn,
627
+ })
628
+ kwargs_cross_encoder = dict(kwargs_common)
629
+ kwargs_cross_encoder.update({
630
+ "sparse": sparse_cross_attn,
631
+ })
632
+
633
+ for idx in range(num_layers):
634
+ if idx % 2 == self.classic_parity:
635
+
636
+ self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
637
+ self.layers_t.append(
638
+ MyTransformerEncoderLayer(**kwargs_classic_encoder)
639
+ )
640
+
641
+ else:
642
+ self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
643
+
644
+ self.layers_t.append(
645
+ CrossTransformerEncoderLayer(**kwargs_cross_encoder)
646
+ )
647
+
648
+ def forward(self, x, xt):
649
+ B, C, Fr, T1 = x.shape
650
+ pos_emb_2d = create_2d_sin_embedding(
651
+ C, Fr, T1, x.device, self.max_period
652
+ ) # (1, C, Fr, T1)
653
+ pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")
654
+ x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
655
+ x = self.norm_in(x)
656
+ x = x + self.weight_pos_embed * pos_emb_2d
657
+
658
+ B, C, T2 = xt.shape
659
+ xt = rearrange(xt, "b c t2 -> b t2 c") # now T2, B, C
660
+ pos_emb = self._get_pos_embedding(T2, B, C, x.device)
661
+ pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")
662
+ xt = self.norm_in_t(xt)
663
+ xt = xt + self.weight_pos_embed * pos_emb
664
+
665
+ for idx in range(self.num_layers):
666
+ if idx % 2 == self.classic_parity:
667
+ x = self.layers[idx](x)
668
+ xt = self.layers_t[idx](xt)
669
+ else:
670
+ old_x = x
671
+ x = self.layers[idx](x, xt)
672
+ xt = self.layers_t[idx](xt, old_x)
673
+
674
+ x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1)
675
+ xt = rearrange(xt, "b t2 c -> b c t2")
676
+ return x, xt
677
+
678
+ def _get_pos_embedding(self, T, B, C, device):
679
+ if self.emb == "sin":
680
+ shift = random.randrange(self.sin_random_shift + 1)
681
+ pos_emb = create_sin_embedding(
682
+ T, C, shift=shift, device=device, max_period=self.max_period
683
+ )
684
+ elif self.emb == "cape":
685
+ if self.training:
686
+ pos_emb = create_sin_embedding_cape(
687
+ T,
688
+ C,
689
+ B,
690
+ device=device,
691
+ max_period=self.max_period,
692
+ mean_normalize=self.cape_mean_normalize,
693
+ augment=self.cape_augment,
694
+ max_global_shift=self.cape_glob_loc_scale[0],
695
+ max_local_shift=self.cape_glob_loc_scale[1],
696
+ max_scale=self.cape_glob_loc_scale[2],
697
+ )
698
+ else:
699
+ pos_emb = create_sin_embedding_cape(
700
+ T,
701
+ C,
702
+ B,
703
+ device=device,
704
+ max_period=self.max_period,
705
+ mean_normalize=self.cape_mean_normalize,
706
+ augment=False,
707
+ )
708
+
709
+ elif self.emb == "scaled":
710
+ pos = torch.arange(T, device=device)
711
+ pos_emb = self.position_embeddings(pos)[:, None]
712
+
713
+ return pos_emb
714
+
715
+ def make_optim_group(self):
716
+ group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
717
+ if self.lr is not None:
718
+ group["lr"] = self.lr
719
+ return group
720
+
721
+
722
+ # Attention Modules
723
+
724
+
725
+ class MultiheadAttention(nn.Module):
726
+ def __init__(
727
+ self,
728
+ embed_dim,
729
+ num_heads,
730
+ dropout=0.0,
731
+ bias=True,
732
+ add_bias_kv=False,
733
+ add_zero_attn=False,
734
+ kdim=None,
735
+ vdim=None,
736
+ batch_first=False,
737
+ auto_sparsity=None,
738
+ ):
739
+ super().__init__()
740
+ assert auto_sparsity is not None, "sanity check"
741
+ self.num_heads = num_heads
742
+ self.q = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
743
+ self.k = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
744
+ self.v = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
745
+ self.attn_drop = torch.nn.Dropout(dropout)
746
+ self.proj = torch.nn.Linear(embed_dim, embed_dim, bias)
747
+ self.proj_drop = torch.nn.Dropout(dropout)
748
+ self.batch_first = batch_first
749
+ self.auto_sparsity = auto_sparsity
750
+
751
+ def forward(
752
+ self,
753
+ query,
754
+ key,
755
+ value,
756
+ key_padding_mask=None,
757
+ need_weights=True,
758
+ attn_mask=None,
759
+ average_attn_weights=True,
760
+ ):
761
+
762
+ if not self.batch_first: # N, B, C
763
+ query = query.permute(1, 0, 2) # B, N_q, C
764
+ key = key.permute(1, 0, 2) # B, N_k, C
765
+ value = value.permute(1, 0, 2) # B, N_k, C
766
+ B, N_q, C = query.shape
767
+ B, N_k, C = key.shape
768
+
769
+ q = (
770
+ self.q(query)
771
+ .reshape(B, N_q, self.num_heads, C // self.num_heads)
772
+ .permute(0, 2, 1, 3)
773
+ )
774
+ q = q.flatten(0, 1)
775
+ k = (
776
+ self.k(key)
777
+ .reshape(B, N_k, self.num_heads, C // self.num_heads)
778
+ .permute(0, 2, 1, 3)
779
+ )
780
+ k = k.flatten(0, 1)
781
+ v = (
782
+ self.v(value)
783
+ .reshape(B, N_k, self.num_heads, C // self.num_heads)
784
+ .permute(0, 2, 1, 3)
785
+ )
786
+ v = v.flatten(0, 1)
787
+
788
+ if self.auto_sparsity:
789
+ assert attn_mask is None
790
+ x = dynamic_sparse_attention(q, k, v, sparsity=self.auto_sparsity)
791
+ else:
792
+ x = scaled_dot_product_attention(q, k, v, attn_mask, dropout=self.attn_drop)
793
+ x = x.reshape(B, self.num_heads, N_q, C // self.num_heads)
794
+
795
+ x = x.transpose(1, 2).reshape(B, N_q, C)
796
+ x = self.proj(x)
797
+ x = self.proj_drop(x)
798
+ if not self.batch_first:
799
+ x = x.permute(1, 0, 2)
800
+ return x, None
801
+
802
+
803
+ def scaled_query_key_softmax(q, k, att_mask):
804
+ from xformers.ops import masked_matmul
805
+ q = q / (k.size(-1)) ** 0.5
806
+ att = masked_matmul(q, k.transpose(-2, -1), att_mask)
807
+ att = torch.nn.functional.softmax(att, -1)
808
+ return att
809
+
810
+
811
+ def scaled_dot_product_attention(q, k, v, att_mask, dropout):
812
+ att = scaled_query_key_softmax(q, k, att_mask=att_mask)
813
+ att = dropout(att)
814
+ y = att @ v
815
+ return y
816
+
817
+
818
+ def _compute_buckets(x, R):
819
+ qq = torch.einsum('btf,bfhi->bhti', x, R)
820
+ qq = torch.cat([qq, -qq], dim=-1)
821
+ buckets = qq.argmax(dim=-1)
822
+
823
+ return buckets.permute(0, 2, 1).byte().contiguous()
824
+
825
+
826
+ def dynamic_sparse_attention(query, key, value, sparsity, infer_sparsity=True, attn_bias=None):
827
+ # assert False, "The code for the custom sparse kernel is not ready for release yet."
828
+ from xformers.ops import find_locations, sparse_memory_efficient_attention
829
+ n_hashes = 32
830
+ proj_size = 4
831
+ query, key, value = [x.contiguous() for x in [query, key, value]]
832
+ with torch.no_grad():
833
+ R = torch.randn(1, query.shape[-1], n_hashes, proj_size // 2, device=query.device)
834
+ bucket_query = _compute_buckets(query, R)
835
+ bucket_key = _compute_buckets(key, R)
836
+ row_offsets, column_indices = find_locations(
837
+ bucket_query, bucket_key, sparsity, infer_sparsity)
838
+ return sparse_memory_efficient_attention(
839
+ query, key, value, row_offsets, column_indices, attn_bias)
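A shape-level sketch of the cross-domain transformer defined above, using dense attention only so the optional xformers kernels are never touched; the sizes are illustrative, not the values used by the model.

```
import torch
from demucs4.transformer import CrossTransformerEncoder

enc = CrossTransformerEncoder(dim=64, num_heads=8, num_layers=2)
enc.eval()

spec = torch.randn(1, 64, 32, 100)     # (B, C, Fr, T1) spectrogram branch
wave = torch.randn(1, 64, 400)         # (B, C, T2) waveform branch
with torch.no_grad():
    spec_out, wave_out = enc(spec, wave)
print(spec_out.shape, wave_out.shape)  # both branches keep their input shapes
```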
demucs4/utils.py ADDED
@@ -0,0 +1,141 @@
1
+ # Copyright (c) Meta, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from collections import defaultdict
8
+ from contextlib import contextmanager
9
+ import math
10
+ import os
11
+ import tempfile
12
+ import typing as tp
13
+
14
+ import torch
15
+ from torch.nn import functional as F
16
+ from torch.utils.data import Subset
17
+
18
+
19
+ def unfold(a, kernel_size, stride):
20
+ """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
21
+ with K the kernel size, by extracting frames with the given stride.
22
+
23
+ This will pad the input so that `F = ceil(T / stride)`.
24
+
25
+ see https://github.com/pytorch/pytorch/issues/60466
26
+ """
27
+ *shape, length = a.shape
28
+ n_frames = math.ceil(length / stride)
29
+ tgt_length = (n_frames - 1) * stride + kernel_size
30
+ a = F.pad(a, (0, tgt_length - length))
31
+ strides = list(a.stride())
32
+ assert strides[-1] == 1, 'data should be contiguous'
33
+ strides = strides[:-1] + [stride, 1]
34
+ return a.as_strided([*shape, n_frames, kernel_size], strides)
35
+
36
+
37
+ def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
38
+ """
39
+ Center trim `tensor` with respect to `reference`, along the last dimension.
40
+ `reference` can also be a number, representing the length to trim to.
41
+ If the size difference != 0 mod 2, the extra sample is removed on the right side.
42
+ """
43
+ ref_size: int
44
+ if isinstance(reference, torch.Tensor):
45
+ ref_size = reference.size(-1)
46
+ else:
47
+ ref_size = reference
48
+ delta = tensor.size(-1) - ref_size
49
+ if delta < 0:
50
+ raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
51
+ if delta:
52
+ tensor = tensor[..., delta // 2:-(delta - delta // 2)]
53
+ return tensor
54
+
55
+
56
+ def pull_metric(history: tp.List[dict], name: str):
57
+ out = []
58
+ for metrics in history:
59
+ metric = metrics
60
+ for part in name.split("."):
61
+ metric = metric[part]
62
+ out.append(metric)
63
+ return out
64
+
65
+
66
+ def EMA(beta: float = 1):
67
+ """
68
+ Exponential Moving Average callback.
69
+ Returns a single function that can be called to repeatedly update the EMA
70
+ with a dict of metrics. The callback will return
71
+ the new averaged dict of metrics.
72
+
73
+ Note that for `beta=1`, this is just plain averaging.
74
+ """
75
+ fix: tp.Dict[str, float] = defaultdict(float)
76
+ total: tp.Dict[str, float] = defaultdict(float)
77
+
78
+ def _update(metrics: dict, weight: float = 1) -> dict:
79
+ nonlocal total, fix
80
+ for key, value in metrics.items():
81
+ total[key] = total[key] * beta + weight * float(value)
82
+ fix[key] = fix[key] * beta + weight
83
+ return {key: tot / fix[key] for key, tot in total.items()}
84
+ return _update
85
+
86
+
87
+ def sizeof_fmt(num: float, suffix: str = 'B'):
88
+ """
89
+ Given `num` bytes, return human readable size.
90
+ Taken from https://stackoverflow.com/a/1094933
91
+ """
92
+ for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
93
+ if abs(num) < 1024.0:
94
+ return "%3.1f%s%s" % (num, unit, suffix)
95
+ num /= 1024.0
96
+ return "%.1f%s%s" % (num, 'Yi', suffix)
97
+
98
+
99
+ @contextmanager
100
+ def temp_filenames(count: int, delete=True):
101
+ names = []
102
+ try:
103
+ for _ in range(count):
104
+ names.append(tempfile.NamedTemporaryFile(delete=False).name)
105
+ yield names
106
+ finally:
107
+ if delete:
108
+ for name in names:
109
+ os.unlink(name)
110
+
111
+
112
+ def random_subset(dataset, max_samples: int, seed: int = 42):
113
+ if max_samples >= len(dataset):
114
+ return dataset
115
+
116
+ generator = torch.Generator().manual_seed(seed)
117
+ perm = torch.randperm(len(dataset), generator=generator)
118
+ return Subset(dataset, perm[:max_samples].tolist())
119
+
120
+
121
+ class DummyPoolExecutor:
122
+ class DummyResult:
123
+ def __init__(self, func, *args, **kwargs):
124
+ self.func = func
125
+ self.args = args
126
+ self.kwargs = kwargs
127
+
128
+ def result(self):
129
+ return self.func(*self.args, **self.kwargs)
130
+
131
+ def __init__(self, workers=0):
132
+ pass
133
+
134
+ def submit(self, func, *args, **kwargs):
135
+ return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
136
+
137
+ def __enter__(self):
138
+ return self
139
+
140
+ def __exit__(self, exc_type, exc_value, exc_tb):
141
+ return
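Small sanity checks for the helpers above, assuming they are importable as `demucs4.utils` in this layout:

```
import torch
from demucs4.utils import unfold, center_trim, EMA, sizeof_fmt

a = torch.arange(10.0)
print(unfold(a, kernel_size=4, stride=2).shape)   # torch.Size([5, 4]); input padded to 12

print(center_trim(a, torch.zeros(6)).shape)       # torch.Size([6]); 2 samples trimmed per side

update = EMA(beta=1)                              # beta=1 is plain averaging
print(update({"loss": 4.0}))                      # {'loss': 4.0}
print(update({"loss": 2.0}))                      # {'loss': 3.0}

print(sizeof_fmt(3 * 1024 ** 3))                  # '3.0GiB'
```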
gui.py ADDED
@@ -0,0 +1,411 @@
1
+ # coding: utf-8
2
+ __author__ = 'Roman Solovyev (ZFTurbo), IPPM RAS'
3
+
4
+ if __name__ == '__main__':
5
+ import os
6
+
7
+ gpu_use = "0"
8
+ print('GPU use: {}'.format(gpu_use))
9
+ os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(gpu_use)
10
+
11
+ import time
12
+ import os
13
+ import numpy as np
14
+ from PyQt5.QtCore import *
15
+ from PyQt5 import QtCore
16
+ from PyQt5.QtWidgets import *
17
+ from PyQt5.QtGui import *
18
+ import sys
19
+ from inference import predict_with_model, __VERSION__
20
+ import torch
21
+
22
+
23
+ root = dict()
24
+
25
+
26
+ class Worker(QObject):
27
+ finished = pyqtSignal()
28
+ progress = pyqtSignal(int)
29
+
30
+ def __init__(self, options):
31
+ super().__init__()
32
+ self.options = options
33
+
34
+ def run(self):
35
+ global root
36
+ # Pass the bound update_progress method itself (not called here); predict_with_model calls it to report progress
37
+ self.options['update_percent_func'] = self.update_progress
38
+ predict_with_model(self.options)
39
+ root['button_start'].setDisabled(False)
40
+ root['button_finish'].setDisabled(True)
41
+ root['start_proc'] = False
42
+ self.finished.emit()
43
+
44
+ def update_progress(self, percent):
45
+ self.progress.emit(percent)
46
+
47
+
48
+ class Ui_Dialog(object):
49
+ def setupUi(self, Dialog):
50
+ global root
51
+
52
+ Dialog.setObjectName("Settings")
53
+ Dialog.resize(370, 320)
54
+
55
+ self.checkbox_cpu = QCheckBox("Use CPU instead of GPU?", Dialog)
56
+ self.checkbox_cpu.move(30, 10)
57
+ self.checkbox_cpu.resize(320, 40)
58
+ if root['cpu']:
59
+ self.checkbox_cpu.setChecked(True)
60
+
61
+ self.checkbox_single_onnx = QCheckBox("Use single ONNX?", Dialog)
62
+ self.checkbox_single_onnx.move(30, 40)
63
+ self.checkbox_single_onnx.resize(320, 40)
64
+ if root['single_onnx']:
65
+ self.checkbox_single_onnx.setChecked(True)
66
+
67
+ self.checkbox_large_gpu = QCheckBox("Use large GPU?", Dialog)
68
+ self.checkbox_large_gpu.move(30, 70)
69
+ self.checkbox_large_gpu.resize(320, 40)
70
+ if root['large_gpu']:
71
+ self.checkbox_large_gpu.setChecked(True)
72
+
73
+ self.checkbox_kim_1 = QCheckBox("Use old Kim Vocal model?", Dialog)
74
+ self.checkbox_kim_1.move(30, 100)
75
+ self.checkbox_kim_1.resize(320, 40)
76
+ if root['use_kim_model_1']:
77
+ self.checkbox_kim_1.setChecked(True)
78
+
79
+ self.checkbox_only_vocals = QCheckBox("Generate only vocals/instrumental?", Dialog)
80
+ self.checkbox_only_vocals.move(30, 130)
81
+ self.checkbox_only_vocals.resize(320, 40)
82
+ if root['only_vocals']:
83
+ self.checkbox_only_vocals.setChecked(True)
84
+
85
+ self.chunk_size_label = QLabel(Dialog)
86
+ self.chunk_size_label.setText('Chunk size')
87
+ self.chunk_size_label.move(30, 160)
88
+ self.chunk_size_label.resize(320, 40)
89
+
90
+ self.chunk_size_valid = QIntValidator(bottom=100000, top=10000000)
91
+ self.chunk_size = QLineEdit(Dialog)
92
+ self.chunk_size.setFixedWidth(140)
93
+ self.chunk_size.move(130, 170)
94
+ self.chunk_size.setValidator(self.chunk_size_valid)
95
+ self.chunk_size.setText(str(root['chunk_size']))
96
+
97
+ self.overlap_large_label = QLabel(Dialog)
98
+ self.overlap_large_label.setText('Overlap large')
99
+ self.overlap_large_label.move(30, 190)
100
+ self.overlap_large_label.resize(320, 40)
101
+
102
+ self.overlap_large_valid = QDoubleValidator(bottom=0.001, top=0.999, decimals=10)
103
+ self.overlap_large_valid.setNotation(QDoubleValidator.Notation.StandardNotation)
104
+ self.overlap_large = QLineEdit(Dialog)
105
+ self.overlap_large.setFixedWidth(140)
106
+ self.overlap_large.move(130, 200)
107
+ self.overlap_large.setValidator(self.overlap_large_valid)
108
+ self.overlap_large.setText(str(root['overlap_large']))
109
+
110
+ self.overlap_small_label = QLabel(Dialog)
111
+ self.overlap_small_label.setText('Overlap small')
112
+ self.overlap_small_label.move(30, 220)
113
+ self.overlap_small_label.resize(320, 40)
114
+
115
+ self.overlap_small_valid = QDoubleValidator(0.001, 0.999, 10)
116
+ self.overlap_small_valid.setNotation(QDoubleValidator.Notation.StandardNotation)
117
+ self.overlap_small = QLineEdit(Dialog)
118
+ self.overlap_small.setFixedWidth(140)
119
+ self.overlap_small.move(130, 230)
120
+ self.overlap_small.setValidator(self.overlap_small_valid)
121
+ self.overlap_small.setText(str(root['overlap_small']))
122
+
123
+ self.pushButton_save = QPushButton(Dialog)
124
+ self.pushButton_save.setObjectName("pushButton_save")
125
+ self.pushButton_save.move(30, 280)
126
+ self.pushButton_save.resize(150, 35)
127
+
128
+ self.pushButton_cancel = QPushButton(Dialog)
129
+ self.pushButton_cancel.setObjectName("pushButton_cancel")
130
+ self.pushButton_cancel.move(190, 280)
131
+ self.pushButton_cancel.resize(150, 35)
132
+
133
+ self.retranslateUi(Dialog)
134
+ QtCore.QMetaObject.connectSlotsByName(Dialog)
135
+ self.Dialog = Dialog
136
+
137
+ # connect the two functions
138
+ self.pushButton_save.clicked.connect(self.return_save)
139
+ self.pushButton_cancel.clicked.connect(self.return_cancel)
140
+
141
+ def retranslateUi(self, Dialog):
142
+ _translate = QtCore.QCoreApplication.translate
143
+ Dialog.setWindowTitle(_translate("Settings", "Settings"))
144
+ self.pushButton_cancel.setText(_translate("Settings", "Cancel"))
145
+ self.pushButton_save.setText(_translate("Settings", "Save settings"))
146
+
147
+ def return_save(self):
148
+ global root
149
+ # print("save")
150
+ root['cpu'] = self.checkbox_cpu.isChecked()
151
+ root['single_onnx'] = self.checkbox_single_onnx.isChecked()
152
+ root['large_gpu'] = self.checkbox_large_gpu.isChecked()
153
+ root['use_kim_model_1'] = self.checkbox_kim_1.isChecked()
154
+ root['only_vocals'] = self.checkbox_only_vocals.isChecked()
155
+
156
+ chunk_size_text = self.chunk_size.text()
157
+ state = self.chunk_size_valid.validate(chunk_size_text, 0)
158
+ if state[0] == QValidator.State.Acceptable:
159
+ root['chunk_size'] = chunk_size_text
160
+
161
+ overlap_large_text = self.overlap_large.text()
162
+ # locale problems... it wants comma instead of dot
163
+ if 0:
164
+ state = self.overlap_large_valid.validate(overlap_large_text, 0)
165
+ if state[0] == QValidator.State.Acceptable:
166
+ root['overlap_large'] = float(overlap_large_text)
167
+ else:
168
+ root['overlap_large'] = float(overlap_large_text)
169
+
170
+ overlap_small_text = self.overlap_small.text()
171
+ if 0:
172
+ state = self.overlap_small_valid.validate(overlap_small_text, 0)
173
+ if state[0] == QValidator.State.Acceptable:
174
+ root['overlap_small'] = float(overlap_small_text)
175
+ else:
176
+ root['overlap_small'] = float(overlap_small_text)
177
+
178
+ self.Dialog.close()
179
+
180
+ def return_cancel(self):
181
+ global root
182
+ # print("cancel")
183
+ self.Dialog.close()
184
+
185
+
186
+ class MyWidget(QWidget):
187
+ def __init__(self):
188
+ super().__init__()
189
+ self.initUI()
190
+
191
+ def initUI(self):
192
+ self.resize(560, 360)
193
+ self.move(300, 300)
194
+ self.setWindowTitle('MVSEP music separation model')
195
+ self.setAcceptDrops(True)
196
+
197
+ def dragEnterEvent(self, event):
198
+ if event.mimeData().hasUrls():
199
+ event.accept()
200
+ else:
201
+ event.ignore()
202
+
203
+ def dropEvent(self, event):
204
+ global root
205
+ files = [u.toLocalFile() for u in event.mimeData().urls()]
206
+ txt = ''
207
+ root['input_files'] = []
208
+ for f in files:
209
+ root['input_files'].append(f)
210
+ txt += f + '\n'
211
+ root['input_files_list_text_area'].insertPlainText(txt)
212
+ root['progress_bar'].setValue(0)
213
+
214
+ def execute_long_task(self):
215
+ global root
216
+
217
+ if len(root['input_files']) == 0:
218
+ QMessageBox.about(root['w'], "Error", "No input files specified!")
219
+ return
220
+
221
+ root['progress_bar'].show()
222
+ root['button_start'].setDisabled(True)
223
+ root['button_finish'].setDisabled(False)
224
+ root['start_proc'] = True
225
+
226
+ options = {
227
+ 'input_audio': root['input_files'],
228
+ 'output_folder': root['output_folder'],
229
+ 'cpu': root['cpu'],
230
+ 'single_onnx': root['single_onnx'],
231
+ 'large_gpu': root['large_gpu'],
232
+ 'chunk_size': root['chunk_size'],
233
+ 'overlap_large': root['overlap_large'],
234
+ 'overlap_small': root['overlap_small'],
235
+ 'use_kim_model_1': root['use_kim_model_1'],
236
+ 'only_vocals': root['only_vocals'],
237
+ }
238
+
239
+ self.update_progress(0)
240
+ self.thread = QThread()
241
+ self.worker = Worker(options)
242
+ self.worker.moveToThread(self.thread)
243
+
244
+ self.thread.started.connect(self.worker.run)
245
+ self.worker.finished.connect(self.thread.quit)
246
+ self.worker.finished.connect(self.worker.deleteLater)
247
+ self.thread.finished.connect(self.thread.deleteLater)
248
+ self.worker.progress.connect(self.update_progress)
249
+
250
+ self.thread.start()
251
+
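+ # Note: this is the standard PyQt worker pattern - the Worker lives in its own
+ # QThread so predict_with_model() does not block the GUI event loop, and
+ # progress updates come back through the pyqtSignal connected above.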
252
+ def stop_separation(self):
253
+ global root
254
+ self.thread.terminate()
255
+ root['button_start'].setDisabled(False)
256
+ root['button_finish'].setDisabled(True)
257
+ root['start_proc'] = False
258
+ root['progress_bar'].hide()
259
+
260
+ def update_progress(self, progress):
261
+ global root
262
+ root['progress_bar'].setValue(progress)
263
+
264
+ def open_settings(self):
265
+ global root
266
+ dialog = QDialog()
267
+ dialog.ui = Ui_Dialog()
268
+ dialog.ui.setupUi(dialog)
269
+ dialog.exec_()
270
+
271
+
272
+ def dialog_select_input_files():
273
+ global root
274
+ files, _ = QFileDialog.getOpenFileNames(
275
+ None,
276
+ "QFileDialog.getOpenFileNames()",
277
+ "",
278
+ "All Files (*);;Audio Files (*.wav, *.mp3, *.flac)",
279
+ )
280
+ if files:
281
+ txt = ''
282
+ root['input_files'] = []
283
+ for f in files:
284
+ root['input_files'].append(f)
285
+ txt += f + '\n'
286
+ root['input_files_list_text_area'].insertPlainText(txt)
287
+ root['progress_bar'].setValue(0)
288
+ return files
289
+
290
+
291
+ def dialog_select_output_folder():
292
+ global root
293
+ foldername = QFileDialog.getExistingDirectory(
294
+ None,
295
+ "Select Directory"
296
+ )
297
+ root['output_folder'] = foldername + '/'
298
+ root['output_folder_line_edit'].setText(root['output_folder'])
299
+ return foldername
300
+
301
+
302
+ def create_dialog():
303
+ global root
304
+ app = QApplication(sys.argv)
305
+
306
+ w = MyWidget()
307
+
308
+ root['input_files'] = []
309
+ root['output_folder'] = os.path.dirname(os.path.abspath(__file__)) + '/results/'
310
+ root['cpu'] = False
311
+ root['large_gpu'] = False
312
+ root['single_onnx'] = False
313
+ root['chunk_size'] = 1000000
314
+ root['overlap_large'] = 0.6
315
+ root['overlap_small'] = 0.5
316
+ root['use_kim_model_1'] = False
317
+ root['only_vocals'] = False
318
+
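+ # Auto-tune defaults from the GPU (assumes a CUDA device is visible; otherwise
+ # torch.cuda.get_device_properties() below will fail): more than 11.5 GB enables
+ # the fast large-GPU mode, less than 8 GB falls back to a single ONNX model and smaller chunks.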
319
+ t = torch.cuda.get_device_properties(0).total_memory / (1024 * 1024 * 1024)
320
+ if t > 11.5:
321
+ print('You have enough GPU memory ({:.2f} GB), so we set fast GPU mode. You can change in settings!'.format(t))
322
+ root['large_gpu'] = True
323
+ root['single_onnx'] = False
324
+ elif t < 8:
325
+ root['large_gpu'] = False
326
+ root['single_onnx'] = True
327
+ root['chunk_size'] = 500000
328
+
329
+ button_select_input_files = QPushButton(w)
330
+ button_select_input_files.setText("Input audio files")
331
+ button_select_input_files.clicked.connect(dialog_select_input_files)
332
+ button_select_input_files.setFixedHeight(35)
333
+ button_select_input_files.setFixedWidth(150)
334
+ button_select_input_files.move(30, 20)
335
+
336
+ input_files_list_text_area = QTextEdit(w)
337
+ input_files_list_text_area.setReadOnly(True)
338
+ input_files_list_text_area.setLineWrapMode(QTextEdit.NoWrap)
339
+ font = input_files_list_text_area.font()
340
+ font.setFamily("Courier")
341
+ font.setPointSize(10)
342
+ input_files_list_text_area.move(30, 60)
343
+ input_files_list_text_area.resize(500, 100)
344
+
345
+ button_select_output_folder = QPushButton(w)
346
+ button_select_output_folder.setText("Output folder")
347
+ button_select_output_folder.setFixedHeight(35)
348
+ button_select_output_folder.setFixedWidth(150)
349
+ button_select_output_folder.clicked.connect(dialog_select_output_folder)
350
+ button_select_output_folder.move(30, 180)
351
+
352
+ output_folder_line_edit = QLineEdit(w)
353
+ output_folder_line_edit.setReadOnly(True)
354
+ font = output_folder_line_edit.font()
355
+ font.setFamily("Courier")
356
+ font.setPointSize(10)
357
+ output_folder_line_edit.move(30, 220)
358
+ output_folder_line_edit.setFixedWidth(500)
359
+ output_folder_line_edit.setText(root['output_folder'])
360
+
361
+ progress_bar = QProgressBar(w)
362
+ # progress_bar.move(30, 310)
363
+ progress_bar.setValue(0)
364
+ progress_bar.setGeometry(30, 310, 500, 35)
365
+ progress_bar.setAlignment(QtCore.Qt.AlignCenter)
366
+ progress_bar.hide()
367
+ root['progress_bar'] = progress_bar
368
+
369
+ button_start = QPushButton('Start separation', w)
370
+ button_start.clicked.connect(w.execute_long_task)
371
+ button_start.setFixedHeight(35)
372
+ button_start.setFixedWidth(150)
373
+ button_start.move(30, 270)
374
+
375
+ button_finish = QPushButton('Stop separation', w)
376
+ button_finish.clicked.connect(w.stop_separation)
377
+ button_finish.setFixedHeight(35)
378
+ button_finish.setFixedWidth(150)
379
+ button_finish.move(200, 270)
380
+ button_finish.setDisabled(True)
381
+
382
+ button_settings = QPushButton('⚙', w)
383
+ button_settings.clicked.connect(w.open_settings)
384
+ button_settings.setFixedHeight(35)
385
+ button_settings.setFixedWidth(35)
386
+ button_settings.move(495, 270)
387
+ button_settings.setDisabled(False)
388
+
389
+ mvsep_link = QLabel(w)
390
+ mvsep_link.setOpenExternalLinks(True)
391
+ font = mvsep_link.font()
392
+ font.setFamily("Courier")
393
+ font.setPointSize(10)
394
+ mvsep_link.move(415, 30)
395
+ mvsep_link.setText('Powered by <a href="https://mvsep.com">MVSep.com</a>')
396
+
397
+ root['w'] = w
398
+ root['input_files_list_text_area'] = input_files_list_text_area
399
+ root['output_folder_line_edit'] = output_folder_line_edit
400
+ root['button_start'] = button_start
401
+ root['button_finish'] = button_finish
402
+ root['button_settings'] = button_settings
403
+
404
+ # w.showMaximized()
405
+ w.show()
406
+ sys.exit(app.exec_())
407
+
408
+
409
+ if __name__ == '__main__':
410
+ print('Version: {}'.format(__VERSION__))
411
+ create_dialog()
images/MVSep-Window.png ADDED
inference.py ADDED
@@ -0,0 +1,920 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ __author__ = 'https://github.com/ZFTurbo/'
3
+
4
+ if __name__ == '__main__':
5
+ import os
6
+
7
+ gpu_use = "0"
8
+ print('GPU use: {}'.format(gpu_use))
9
+ os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(gpu_use)
10
+
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import os
16
+ import argparse
17
+ import soundfile as sf
18
+
19
+ from demucs.states import load_model
20
+ from demucs import pretrained
21
+ from demucs.apply import apply_model
22
+ import onnxruntime as ort
23
+ from time import time
24
+ import librosa
25
+ import hashlib
26
+
27
+
28
+ __VERSION__ = '1.0.1'
29
+
30
+
31
+ class Conv_TDF_net_trim_model(nn.Module):
32
+ def __init__(self, device, target_name, L, n_fft, hop=1024):
33
+
34
+ super(Conv_TDF_net_trim_model, self).__init__()
35
+
36
+ self.dim_c = 4
37
+ self.dim_f, self.dim_t = 3072, 256
38
+ self.n_fft = n_fft
39
+ self.hop = hop
40
+ self.n_bins = self.n_fft // 2 + 1
41
+ self.chunk_size = hop * (self.dim_t - 1)
42
+ self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(device)
43
+ self.target_name = target_name
44
+
45
+ out_c = self.dim_c * 4 if target_name == '*' else self.dim_c
46
+ self.freq_pad = torch.zeros([1, out_c, self.n_bins - self.dim_f, self.dim_t]).to(device)
47
+
48
+ self.n = L // 2
49
+
50
+ def stft(self, x):
51
+ x = x.reshape([-1, self.chunk_size])
52
+ x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True, return_complex=True)
53
+ x = torch.view_as_real(x)
54
+ x = x.permute([0, 3, 1, 2])
55
+ x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, self.dim_c, self.n_bins, self.dim_t])
56
+ return x[:, :, :self.dim_f]
57
+
58
+ def istft(self, x, freq_pad=None):
59
+ freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1]) if freq_pad is None else freq_pad
60
+ x = torch.cat([x, freq_pad], -2)
61
+ x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, 2, self.n_bins, self.dim_t])
62
+ x = x.permute([0, 2, 3, 1])
63
+ x = x.contiguous()
64
+ x = torch.view_as_complex(x)
65
+ x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
66
+ return x.reshape([-1, 2, self.chunk_size])
67
+
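+ # Note on sizes: with the defaults used above (hop=1024, dim_t=256),
+ # self.chunk_size = 1024 * 255 = 261120 samples, roughly 5.9 s at 44.1 kHz;
+ # stft()/istft() operate on windows of exactly this length.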
68
+ def forward(self, x):
69
+ x = self.first_conv(x)
70
+ x = x.transpose(-1, -2)
71
+
72
+ ds_outputs = []
73
+ for i in range(self.n):
74
+ x = self.ds_dense[i](x)
75
+ ds_outputs.append(x)
76
+ x = self.ds[i](x)
77
+
78
+ x = self.mid_dense(x)
79
+ for i in range(self.n):
80
+ x = self.us[i](x)
81
+ x *= ds_outputs[-i - 1]
82
+ x = self.us_dense[i](x)
83
+
84
+ x = x.transpose(-1, -2)
85
+ x = self.final_conv(x)
86
+ return x
87
+
88
+
89
+ def get_models(name, device, load=True, vocals_model_type=0):
90
+ if vocals_model_type == 2:
91
+ model_vocals = Conv_TDF_net_trim_model(
92
+ device=device,
93
+ target_name='vocals',
94
+ L=11,
95
+ n_fft=7680
96
+ )
97
+ elif vocals_model_type == 3:
98
+ model_vocals = Conv_TDF_net_trim_model(
99
+ device=device,
100
+ target_name='vocals',
101
+ L=11,
102
+ n_fft=6144
103
+ )
104
+
105
+ return [model_vocals]
106
+
107
+
108
+ def demix_base(mix, device, models, infer_session):
109
+ start_time = time()
110
+ sources = []
111
+ n_sample = mix.shape[1]
112
+ for model in models:
113
+ trim = model.n_fft // 2
114
+ gen_size = model.chunk_size - 2 * trim
115
+ pad = gen_size - n_sample % gen_size
116
+ mix_p = np.concatenate(
117
+ (
118
+ np.zeros((2, trim)),
119
+ mix,
120
+ np.zeros((2, pad)),
121
+ np.zeros((2, trim))
122
+ ), 1
123
+ )
124
+
125
+ mix_waves = []
126
+ i = 0
127
+ while i < n_sample + pad:
128
+ waves = np.array(mix_p[:, i:i + model.chunk_size])
129
+ mix_waves.append(waves)
130
+ i += gen_size
131
+ mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(device)
132
+
133
+ with torch.no_grad():
134
+ _ort = infer_session
135
+ stft_res = model.stft(mix_waves)
136
+ res = _ort.run(None, {'input': stft_res.cpu().numpy()})[0]
137
+ ten = torch.tensor(res)
138
+ tar_waves = model.istft(ten.to(device))
139
+ tar_waves = tar_waves.cpu()
140
+ tar_signal = tar_waves[:, :, trim:-trim].transpose(0, 1).reshape(2, -1).numpy()[:, :-pad]
141
+
142
+ sources.append(tar_signal)
143
+ # print('Time demix base: {:.2f} sec'.format(time() - start_time))
144
+ return np.array(sources)
145
+
146
+
147
+ def demix_full(mix, device, chunk_size, models, infer_session, overlap=0.75):
148
+ start_time = time()
149
+
150
+ step = int(chunk_size * (1 - overlap))
151
+ # print('Initial shape: {} Chunk size: {} Step: {} Device: {}'.format(mix.shape, chunk_size, step, device))
152
+ result = np.zeros((1, 2, mix.shape[-1]), dtype=np.float32)
153
+ divider = np.zeros((1, 2, mix.shape[-1]), dtype=np.float32)
154
+
155
+ total = 0
156
+ for i in range(0, mix.shape[-1], step):
157
+ total += 1
158
+
159
+ start = i
160
+ end = min(i + chunk_size, mix.shape[-1])
161
+ # print('Chunk: {} Start: {} End: {}'.format(total, start, end))
162
+ mix_part = mix[:, start:end]
163
+ sources = demix_base(mix_part, device, models, infer_session)
164
+ # print(sources.shape)
165
+ result[..., start:end] += sources
166
+ divider[..., start:end] += 1
167
+ sources = result / divider
168
+ # print('Final shape: {} Overall time: {:.2f}'.format(sources.shape, time() - start_time))
169
+ return sources
170
+
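+ # Illustrative numbers: with chunk_size=1_000_000 and overlap=0.75 the step is
+ # 250_000 samples, so every sample is covered by about 4 chunks and the final
+ # prediction is the per-sample average (result / divider).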
171
+
172
+ class EnsembleDemucsMDXMusicSeparationModel:
173
+ """
174
+ Ensemble of Demucs and MDX models that separates a mixture into bass, drums, other and vocals. All models are loaded once in __init__ (faster for multiple files, higher GPU memory use).
175
+ """
176
+ def __init__(self, options):
177
+ """
178
+ options - user options
179
+ """
180
+ # print(options)
181
+
182
+ if torch.cuda.is_available():
183
+ device = 'cuda:0'
184
+ else:
185
+ device = 'cpu'
186
+ if 'cpu' in options:
187
+ if options['cpu']:
188
+ device = 'cpu'
189
+ print('Use device: {}'.format(device))
190
+ self.single_onnx = False
191
+ if 'single_onnx' in options:
192
+ if options['single_onnx']:
193
+ self.single_onnx = True
194
+ print('Use single vocal ONNX')
195
+
196
+ self.kim_model_1 = False
197
+ if 'use_kim_model_1' in options:
198
+ if options['use_kim_model_1']:
199
+ self.kim_model_1 = True
200
+ if self.kim_model_1:
201
+ print('Use Kim model 1')
202
+ else:
203
+ print('Use Kim model 2')
204
+
205
+ self.overlap_large = float(options['overlap_large'])
206
+ self.overlap_small = float(options['overlap_small'])
207
+ if self.overlap_large > 0.99:
208
+ self.overlap_large = 0.99
209
+ if self.overlap_large < 0.0:
210
+ self.overlap_large = 0.0
211
+ if self.overlap_small > 0.99:
212
+ self.overlap_small = 0.99
213
+ if self.overlap_small < 0.0:
214
+ self.overlap_small = 0.0
215
+
216
+ model_folder = os.path.dirname(os.path.realpath(__file__)) + '/models/'
217
+ remote_url = 'https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/04573f0d-f3cf25b2.th'
218
+ model_path = model_folder + '04573f0d-f3cf25b2.th'
219
+ if not os.path.isfile(model_path):
220
+ torch.hub.download_url_to_file(remote_url, model_folder + '04573f0d-f3cf25b2.th')
221
+ model_vocals = load_model(model_path)
222
+ model_vocals.to(device)
223
+ self.model_vocals_only = model_vocals
224
+
225
+ self.models = []
226
+ self.weights_vocals = np.array([10, 1, 8, 9])
227
+ self.weights_bass = np.array([19, 4, 5, 8])
228
+ self.weights_drums = np.array([18, 2, 4, 9])
229
+ self.weights_other = np.array([14, 2, 5, 10])
230
+
231
+ model1 = pretrained.get_model('htdemucs_ft')
232
+ model1.to(device)
233
+ self.models.append(model1)
234
+
235
+ model2 = pretrained.get_model('htdemucs')
236
+ model2.to(device)
237
+ self.models.append(model2)
238
+
239
+ model3 = pretrained.get_model('htdemucs_6s')
240
+ model3.to(device)
241
+ self.models.append(model3)
242
+
243
+ model4 = pretrained.get_model('hdemucs_mmi')
244
+ model4.to(device)
245
+ self.models.append(model4)
246
+
247
+ if 0:
248
+ for model in self.models:
249
+ print(model.sources)
250
+ '''
251
+ ['drums', 'bass', 'other', 'vocals']
252
+ ['drums', 'bass', 'other', 'vocals']
253
+ ['drums', 'bass', 'other', 'vocals', 'guitar', 'piano']
254
+ ['drums', 'bass', 'other', 'vocals']
255
+ '''
256
+
257
+ if device == 'cpu':
258
+ chunk_size = 200000000
259
+ providers = ["CPUExecutionProvider"]
260
+ else:
261
+ chunk_size = 1000000
262
+ providers = ["CUDAExecutionProvider"]
263
+ if 'chunk_size' in options:
264
+ chunk_size = int(options['chunk_size'])
265
+
266
+ # MDX-B model 1 initialization
267
+ self.chunk_size = chunk_size
268
+ self.mdx_models1 = get_models('tdf_extra', load=False, device=device, vocals_model_type=2)
269
+ if self.kim_model_1:
270
+ model_path_onnx1 = model_folder + 'Kim_Vocal_1.onnx'
271
+ remote_url_onnx1 = 'https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/Kim_Vocal_1.onnx'
272
+ else:
273
+ model_path_onnx1 = model_folder + 'Kim_Vocal_2.onnx'
274
+ remote_url_onnx1 = 'https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/Kim_Vocal_2.onnx'
275
+ if not os.path.isfile(model_path_onnx1):
276
+ torch.hub.download_url_to_file(remote_url_onnx1, model_path_onnx1)
277
+ print('Model path: {}'.format(model_path_onnx1))
278
+ print('Device: {} Chunk size: {}'.format(device, chunk_size))
279
+ self.infer_session1 = ort.InferenceSession(
280
+ model_path_onnx1,
281
+ providers=providers,
282
+ provider_options=[{"device_id": 0}],
283
+ )
284
+
285
+ if self.single_onnx is False:
286
+ # MDX-B model 2 initialization
287
+ self.chunk_size = chunk_size
288
+ self.mdx_models2 = get_models('tdf_extra', load=False, device=device, vocals_model_type=2)
289
+ root_path = os.path.dirname(os.path.realpath(__file__)) + '/'
290
+ model_path_onnx2 = model_folder + 'Kim_Inst.onnx'
291
+ remote_url_onnx2 = 'https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/Kim_Inst.onnx'
292
+ if not os.path.isfile(model_path_onnx2):
293
+ torch.hub.download_url_to_file(remote_url_onnx2, model_path_onnx2)
294
+ print('Model path: {}'.format(model_path_onnx2))
295
+ print('Device: {} Chunk size: {}'.format(device, chunk_size))
296
+ self.infer_session2 = ort.InferenceSession(
297
+ model_path_onnx2,
298
+ providers=providers,
299
+ provider_options=[{"device_id": 0}],
300
+ )
301
+
302
+ self.device = device
303
+ pass
304
+
305
+ @property
306
+ def instruments(self):
307
+ """ DO NOT CHANGE """
308
+ return ['bass', 'drums', 'other', 'vocals']
309
+
310
+ def raise_aicrowd_error(self, msg):
311
+ """ Will be used by the evaluator to provide logs, DO NOT CHANGE """
312
+ raise NameError(msg)
313
+
314
+ def separate_music_file(
315
+ self,
316
+ mixed_sound_array,
317
+ sample_rate,
318
+ update_percent_func=None,
319
+ current_file_number=0,
320
+ total_files=0,
321
+ only_vocals=False,
322
+ ):
323
+ """
324
+ Implements the sound separation for a single sound file
325
+ Inputs: Outputs from soundfile.read('mixture.wav')
326
+ mixed_sound_array
327
+ sample_rate
328
+
329
+ Outputs:
330
+ separated_music_arrays: Dictionary of numpy arrays, one per separated instrument
331
+ output_sample_rates: Dictionary of sample rates, one per separated instrument
332
+ """
333
+
334
+ # print('Update percent func: {}'.format(update_percent_func))
335
+
336
+ separated_music_arrays = {}
337
+ output_sample_rates = {}
338
+
339
+ audio = np.expand_dims(mixed_sound_array.T, axis=0)
340
+ audio = torch.from_numpy(audio).type('torch.FloatTensor').to(self.device)
341
+
342
+ overlap_large = self.overlap_large
343
+ overlap_small = self.overlap_small
344
+
345
+ # Get Demucs vocals-only estimate
346
+ model = self.model_vocals_only
347
+ shifts = 1
348
+ overlap = overlap_large
349
+ vocals_demucs = 0.5 * apply_model(model, audio, shifts=shifts, overlap=overlap)[0][3].cpu().numpy()
350
+
351
+ if update_percent_func is not None:
352
+ val = 100 * (current_file_number + 0.10) / total_files
353
+ update_percent_func(int(val))
354
+
355
+ vocals_demucs += 0.5 * -apply_model(model, -audio, shifts=shifts, overlap=overlap)[0][3].cpu().numpy()
356
+
357
+ if update_percent_func is not None:
358
+ val = 100 * (current_file_number + 0.20) / total_files
359
+ update_percent_func(int(val))
360
+
361
+ overlap = overlap_large
362
+ sources1 = demix_full(
363
+ mixed_sound_array.T,
364
+ self.device,
365
+ self.chunk_size,
366
+ self.mdx_models1,
367
+ self.infer_session1,
368
+ overlap=overlap
369
+ )[0]
370
+
371
+ vocals_mdxb1 = sources1
372
+
373
+ if update_percent_func is not None:
374
+ val = 100 * (current_file_number + 0.30) / total_files
375
+ update_percent_func(int(val))
376
+
377
+ if self.single_onnx is False:
378
+ sources2 = -demix_full(
379
+ -mixed_sound_array.T,
380
+ self.device,
381
+ self.chunk_size,
382
+ self.mdx_models2,
383
+ self.infer_session2,
384
+ overlap=overlap
385
+ )[0]
386
+
387
+ # the model predicts the instrumental, so subtract it from the mixture to get vocals
388
+ instrum_mdxb2 = sources2
389
+ vocals_mdxb2 = mixed_sound_array.T - instrum_mdxb2
390
+
391
+ if update_percent_func is not None:
392
+ val = 100 * (current_file_number + 0.40) / total_files
393
+ update_percent_func(int(val))
394
+
395
+ # Ensemble vocals for MDX and Demucs
396
+ if self.single_onnx is False:
397
+ weights = np.array([12, 8, 3])
398
+ vocals = (weights[0] * vocals_mdxb1.T + weights[1] * vocals_mdxb2.T + weights[2] * vocals_demucs.T) / weights.sum()
399
+ else:
400
+ weights = np.array([6, 1])
401
+ vocals = (weights[0] * vocals_mdxb1.T + weights[1] * vocals_demucs.T) / weights.sum()
402
+
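+ # The vocal ensemble is a plain weighted average: when both ONNX models are used,
+ # the default weights [12, 8, 3] give the two MDX models ~52% and ~35% of the
+ # final vocals and Demucs ~13%.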
403
+ # vocals
404
+ separated_music_arrays['vocals'] = vocals
405
+ output_sample_rates['vocals'] = sample_rate
406
+
407
+ if not only_vocals:
408
+ # Generate instrumental
409
+ instrum = mixed_sound_array - vocals
410
+
411
+ audio = np.expand_dims(instrum.T, axis=0)
412
+ audio = torch.from_numpy(audio).type('torch.FloatTensor').to(self.device)
413
+
414
+ all_outs = []
415
+ for i, model in enumerate(self.models):
416
+ if i == 0:
417
+ overlap = overlap_small
418
+ elif i > 0:
419
+ overlap = overlap_large
420
+ out = 0.5 * apply_model(model, audio, shifts=shifts, overlap=overlap)[0].cpu().numpy() \
421
+ + 0.5 * -apply_model(model, -audio, shifts=shifts, overlap=overlap)[0].cpu().numpy()
422
+
423
+ if update_percent_func is not None:
424
+ val = 100 * (current_file_number + 0.50 + i * 0.10) / total_files
425
+ update_percent_func(int(val))
426
+
427
+ if i == 2:
428
+ # ['drums', 'bass', 'other', 'vocals', 'guitar', 'piano']
429
+ out[2] = out[2] + out[4] + out[5]
430
+ out = out[:4]
431
+
432
+ out[0] = self.weights_drums[i] * out[0]
433
+ out[1] = self.weights_bass[i] * out[1]
434
+ out[2] = self.weights_other[i] * out[2]
435
+ out[3] = self.weights_vocals[i] * out[3]
436
+
437
+ all_outs.append(out)
438
+ out = np.array(all_outs).sum(axis=0)
439
+ out[0] = out[0] / self.weights_drums.sum()
440
+ out[1] = out[1] / self.weights_bass.sum()
441
+ out[2] = out[2] / self.weights_other.sum()
442
+ out[3] = out[3] / self.weights_vocals.sum()
443
+
444
+ # other
445
+ res = mixed_sound_array - vocals - out[0].T - out[1].T
446
+ res = np.clip(res, -1, 1)
447
+ separated_music_arrays['other'] = (2 * res + out[2].T) / 3.0
448
+ output_sample_rates['other'] = sample_rate
449
+
450
+ # drums
451
+ res = mixed_sound_array - vocals - out[1].T - out[2].T
452
+ res = np.clip(res, -1, 1)
453
+ separated_music_arrays['drums'] = (res + 2 * out[0].T.copy()) / 3.0
454
+ output_sample_rates['drums'] = sample_rate
455
+
456
+ # bass
457
+ res = mixed_sound_array - vocals - out[0].T - out[2].T
458
+ res = np.clip(res, -1, 1)
459
+ separated_music_arrays['bass'] = (res + 2 * out[1].T) / 3.0
460
+ output_sample_rates['bass'] = sample_rate
461
+
462
+ bass = separated_music_arrays['bass']
463
+ drums = separated_music_arrays['drums']
464
+ other = separated_music_arrays['other']
465
+
466
+ separated_music_arrays['other'] = mixed_sound_array - vocals - bass - drums
467
+ separated_music_arrays['drums'] = mixed_sound_array - vocals - bass - other
468
+ separated_music_arrays['bass'] = mixed_sound_array - vocals - drums - other
469
+
470
+ if update_percent_func is not None:
471
+ val = 100 * (current_file_number + 0.95) / total_files
472
+ update_percent_func(int(val))
473
+
474
+ return separated_music_arrays, output_sample_rates
475
+
476
+
477
+ class EnsembleDemucsMDXMusicSeparationModelLowGPU:
478
+ """
479
+ Same Demucs/MDX ensemble, but models are loaded one at a time to reduce GPU memory use (slower when processing multiple files).
480
+ """
481
+
482
+ def __init__(self, options):
483
+ """
484
+ options - user options
485
+ """
486
+ # print(options)
487
+
488
+ if torch.cuda.is_available():
489
+ device = 'cuda:0'
490
+ else:
491
+ device = 'cpu'
492
+ if 'cpu' in options:
493
+ if options['cpu']:
494
+ device = 'cpu'
495
+ print('Use device: {}'.format(device))
496
+ self.single_onnx = False
497
+ if 'single_onnx' in options:
498
+ if options['single_onnx']:
499
+ self.single_onnx = True
500
+ print('Use single vocal ONNX')
501
+
502
+ self.kim_model_1 = False
503
+ if 'use_kim_model_1' in options:
504
+ if options['use_kim_model_1']:
505
+ self.kim_model_1 = True
506
+ if self.kim_model_1:
507
+ print('Use Kim model 1')
508
+ else:
509
+ print('Use Kim model 2')
510
+
511
+ self.overlap_large = float(options['overlap_large'])
512
+ self.overlap_small = float(options['overlap_small'])
513
+ if self.overlap_large > 0.99:
514
+ self.overlap_large = 0.99
515
+ if self.overlap_large < 0.0:
516
+ self.overlap_large = 0.0
517
+ if self.overlap_small > 0.99:
518
+ self.overlap_small = 0.99
519
+ if self.overlap_small < 0.0:
520
+ self.overlap_small = 0.0
521
+
522
+ self.weights_vocals = np.array([10, 1, 8, 9])
523
+ self.weights_bass = np.array([19, 4, 5, 8])
524
+ self.weights_drums = np.array([18, 2, 4, 9])
525
+ self.weights_other = np.array([14, 2, 5, 10])
526
+
527
+ if device == 'cpu':
528
+ chunk_size = 200000000
529
+ self.providers = ["CPUExecutionProvider"]
530
+ else:
531
+ chunk_size = 1000000
532
+ self.providers = ["CUDAExecutionProvider"]
533
+ if 'chunk_size' in options:
534
+ chunk_size = int(options['chunk_size'])
535
+ self.chunk_size = chunk_size
536
+ self.device = device
537
+ pass
538
+
539
+ @property
540
+ def instruments(self):
541
+ """ DO NOT CHANGE """
542
+ return ['bass', 'drums', 'other', 'vocals']
543
+
544
+ def raise_aicrowd_error(self, msg):
545
+ """ Will be used by the evaluator to provide logs, DO NOT CHANGE """
546
+ raise NameError(msg)
547
+
548
+ def separate_music_file(
549
+ self,
550
+ mixed_sound_array,
551
+ sample_rate,
552
+ update_percent_func=None,
553
+ current_file_number=0,
554
+ total_files=0,
555
+ only_vocals=False
556
+ ):
557
+ """
558
+ Implements the sound separation for a single sound file
559
+ Inputs: Outputs from soundfile.read('mixture.wav')
560
+ mixed_sound_array
561
+ sample_rate
562
+
563
+ Outputs:
564
+ separated_music_arrays: Dictionary of numpy arrays, one per separated instrument
565
+ output_sample_rates: Dictionary of sample rates, one per separated instrument
566
+ """
567
+
568
+ # print('Update percent func: {}'.format(update_percent_func))
569
+
570
+ separated_music_arrays = {}
571
+ output_sample_rates = {}
572
+
573
+ audio = np.expand_dims(mixed_sound_array.T, axis=0)
574
+ audio = torch.from_numpy(audio).type('torch.FloatTensor').to(self.device)
575
+
576
+ overlap_large = self.overlap_large
577
+ overlap_small = self.overlap_small
578
+
579
+ # Get Demucs vocals-only estimate
580
+ model_folder = os.path.dirname(os.path.realpath(__file__)) + '/models/'
581
+ remote_url = 'https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/04573f0d-f3cf25b2.th'
582
+ model_path = model_folder + '04573f0d-f3cf25b2.th'
583
+ if not os.path.isfile(model_path):
584
+ torch.hub.download_url_to_file(remote_url, model_folder + '04573f0d-f3cf25b2.th')
585
+ model_vocals = load_model(model_path)
586
+ model_vocals.to(self.device)
587
+ shifts = 1
588
+ overlap = overlap_large
589
+ vocals_demucs = 0.5 * apply_model(model_vocals, audio, shifts=shifts, overlap=overlap)[0][3].cpu().numpy()
590
+
591
+ if update_percent_func is not None:
592
+ val = 100 * (current_file_number + 0.10) / total_files
593
+ update_percent_func(int(val))
594
+
595
+ vocals_demucs += 0.5 * -apply_model(model_vocals, -audio, shifts=shifts, overlap=overlap)[0][3].cpu().numpy()
596
+ model_vocals = model_vocals.cpu()
597
+ del model_vocals
598
+
599
+ if update_percent_func is not None:
600
+ val = 100 * (current_file_number + 0.20) / total_files
601
+ update_percent_func(int(val))
602
+
603
+ # MDX-B model 1 initialization
604
+ mdx_models1 = get_models('tdf_extra', load=False, device=self.device, vocals_model_type=2)
605
+ if self.kim_model_1:
606
+ model_path_onnx1 = model_folder + 'Kim_Vocal_1.onnx'
607
+ remote_url_onnx1 = 'https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/Kim_Vocal_1.onnx'
608
+ else:
609
+ model_path_onnx1 = model_folder + 'Kim_Vocal_2.onnx'
610
+ remote_url_onnx1 = 'https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/Kim_Vocal_2.onnx'
611
+ if not os.path.isfile(model_path_onnx1):
612
+ torch.hub.download_url_to_file(remote_url_onnx1, model_path_onnx1)
613
+ print('Model path: {}'.format(model_path_onnx1))
614
+ print('Device: {} Chunk size: {}'.format(self.device, self.chunk_size))
615
+ infer_session1 = ort.InferenceSession(
616
+ model_path_onnx1,
617
+ providers=self.providers,
618
+ provider_options=[{"device_id": 0}],
619
+ )
620
+ overlap = overlap_large
621
+ sources1 = demix_full(
622
+ mixed_sound_array.T,
623
+ self.device,
624
+ self.chunk_size,
625
+ mdx_models1,
626
+ infer_session1,
627
+ overlap=overlap
628
+ )[0]
629
+ vocals_mdxb1 = sources1
630
+ del infer_session1
631
+ del mdx_models1
632
+
633
+ if update_percent_func is not None:
634
+ val = 100 * (current_file_number + 0.30) / total_files
635
+ update_percent_func(int(val))
636
+
637
+ if self.single_onnx is False:
638
+ # MDX-B model 2 initialization
639
+ mdx_models2 = get_models('tdf_extra', load=False, device=self.device, vocals_model_type=2)
640
+ root_path = os.path.dirname(os.path.realpath(__file__)) + '/'
641
+ model_path_onnx2 = model_folder + 'Kim_Inst.onnx'
642
+ remote_url_onnx2 = 'https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/Kim_Inst.onnx'
643
+ if not os.path.isfile(model_path_onnx2):
644
+ torch.hub.download_url_to_file(remote_url_onnx2, model_path_onnx2)
645
+ print('Model path: {}'.format(model_path_onnx2))
646
+ print('Device: {} Chunk size: {}'.format(self.device, self.chunk_size))
647
+ infer_session2 = ort.InferenceSession(
648
+ model_path_onnx2,
649
+ providers=self.providers,
650
+ provider_options=[{"device_id": 0}],
651
+ )
652
+
653
+ overlap = overlap_large
654
+ sources2 = -demix_full(
655
+ -mixed_sound_array.T,
656
+ self.device,
657
+ self.chunk_size,
658
+ mdx_models2,
659
+ infer_session2,
660
+ overlap=overlap
661
+ )[0]
662
+
663
+ # the model predicts the instrumental, so subtract it from the mixture to get vocals
664
+ instrum_mdxb2 = sources2
665
+ vocals_mdxb2 = mixed_sound_array.T - instrum_mdxb2
666
+ del infer_session2
667
+ del mdx_models2
668
+
669
+ if update_percent_func is not None:
670
+ val = 100 * (current_file_number + 0.40) / total_files
671
+ update_percent_func(int(val))
672
+
673
+ # Ensemble vocals for MDX and Demucs
674
+ if self.single_onnx is False:
675
+ weights = np.array([12, 8, 3])
676
+ vocals = (weights[0] * vocals_mdxb1.T + weights[1] * vocals_mdxb2.T + weights[2] * vocals_demucs.T) / weights.sum()
677
+ else:
678
+ weights = np.array([6, 1])
679
+ vocals = (weights[0] * vocals_mdxb1.T + weights[1] * vocals_demucs.T) / weights.sum()
680
+
681
+ # Generate instrumental
682
+ instrum = mixed_sound_array - vocals
683
+
684
+ audio = np.expand_dims(instrum.T, axis=0)
685
+ audio = torch.from_numpy(audio).type('torch.FloatTensor').to(self.device)
686
+
687
+ all_outs = []
688
+
689
+ i = 0
690
+ overlap = overlap_small
691
+ model = pretrained.get_model('htdemucs_ft')
692
+ model.to(self.device)
693
+ out = 0.5 * apply_model(model, audio, shifts=shifts, overlap=overlap)[0].cpu().numpy() \
694
+ + 0.5 * -apply_model(model, -audio, shifts=shifts, overlap=overlap)[0].cpu().numpy()
695
+
696
+ if update_percent_func is not None:
697
+ val = 100 * (current_file_number + 0.50 + i * 0.10) / total_files
698
+ update_percent_func(int(val))
699
+
700
+ out[0] = self.weights_drums[i] * out[0]
701
+ out[1] = self.weights_bass[i] * out[1]
702
+ out[2] = self.weights_other[i] * out[2]
703
+ out[3] = self.weights_vocals[i] * out[3]
704
+ all_outs.append(out)
705
+ model = model.cpu()
706
+ del model
707
+
708
+ i = 1
709
+ overlap = overlap_large
710
+ model = pretrained.get_model('htdemucs')
711
+ model.to(self.device)
712
+ out = 0.5 * apply_model(model, audio, shifts=shifts, overlap=overlap)[0].cpu().numpy() \
713
+ + 0.5 * -apply_model(model, -audio, shifts=shifts, overlap=overlap)[0].cpu().numpy()
714
+
715
+ if update_percent_func is not None:
716
+ val = 100 * (current_file_number + 0.50 + i * 0.10) / total_files
717
+ update_percent_func(int(val))
718
+
719
+ out[0] = self.weights_drums[i] * out[0]
720
+ out[1] = self.weights_bass[i] * out[1]
721
+ out[2] = self.weights_other[i] * out[2]
722
+ out[3] = self.weights_vocals[i] * out[3]
723
+ all_outs.append(out)
724
+ model = model.cpu()
725
+ del model
726
+
727
+ i = 2
728
+ overlap = overlap_large
729
+ model = pretrained.get_model('htdemucs_6s')
730
+ model.to(self.device)
731
+ out = 0.5 * apply_model(model, audio, shifts=shifts, overlap=overlap)[0].cpu().numpy() \
732
+ + 0.5 * -apply_model(model, -audio, shifts=shifts, overlap=overlap)[0].cpu().numpy()
733
+
734
+ if update_percent_func is not None:
735
+ val = 100 * (current_file_number + 0.50 + i * 0.10) / total_files
736
+ update_percent_func(int(val))
737
+
738
+ # 6-stem model: fold the extra guitar and piano stems into 'other'
739
+ out[2] = out[2] + out[4] + out[5]
740
+ out = out[:4]
741
+ out[0] = self.weights_drums[i] * out[0]
742
+ out[1] = self.weights_bass[i] * out[1]
743
+ out[2] = self.weights_other[i] * out[2]
744
+ out[3] = self.weights_vocals[i] * out[3]
745
+ all_outs.append(out)
746
+ model = model.cpu()
747
+ del model
748
+
749
+ i = 3
750
+ model = pretrained.get_model('hdemucs_mmi')
751
+ model.to(self.device)
752
+ out = 0.5 * apply_model(model, audio, shifts=shifts, overlap=overlap)[0].cpu().numpy() \
753
+ + 0.5 * -apply_model(model, -audio, shifts=shifts, overlap=overlap)[0].cpu().numpy()
754
+
755
+ if update_percent_func is not None:
756
+ val = 100 * (current_file_number + 0.50 + i * 0.10) / total_files
757
+ update_percent_func(int(val))
758
+
759
+ out[0] = self.weights_drums[i] * out[0]
760
+ out[1] = self.weights_bass[i] * out[1]
761
+ out[2] = self.weights_other[i] * out[2]
762
+ out[3] = self.weights_vocals[i] * out[3]
763
+ all_outs.append(out)
764
+ model = model.cpu()
765
+ del model
766
+
767
+ out = np.array(all_outs).sum(axis=0)
768
+ out[0] = out[0] / self.weights_drums.sum()
769
+ out[1] = out[1] / self.weights_bass.sum()
770
+ out[2] = out[2] / self.weights_other.sum()
771
+ out[3] = out[3] / self.weights_vocals.sum()
772
+
773
+ # vocals
774
+ separated_music_arrays['vocals'] = vocals
775
+ output_sample_rates['vocals'] = sample_rate
776
+
777
+ # other
778
+ res = mixed_sound_array - vocals - out[0].T - out[1].T
779
+ res = np.clip(res, -1, 1)
780
+ separated_music_arrays['other'] = (2 * res + out[2].T) / 3.0
781
+ output_sample_rates['other'] = sample_rate
782
+
783
+ # drums
784
+ res = mixed_sound_array - vocals - out[1].T - out[2].T
785
+ res = np.clip(res, -1, 1)
786
+ separated_music_arrays['drums'] = (res + 2 * out[0].T.copy()) / 3.0
787
+ output_sample_rates['drums'] = sample_rate
788
+
789
+ # bass
790
+ res = mixed_sound_array - vocals - out[0].T - out[2].T
791
+ res = np.clip(res, -1, 1)
792
+ separated_music_arrays['bass'] = (res + 2 * out[1].T) / 3.0
793
+ output_sample_rates['bass'] = sample_rate
794
+
795
+ bass = separated_music_arrays['bass']
796
+ drums = separated_music_arrays['drums']
797
+ other = separated_music_arrays['other']
798
+
799
+ separated_music_arrays['other'] = mixed_sound_array - vocals - bass - drums
800
+ separated_music_arrays['drums'] = mixed_sound_array - vocals - bass - other
801
+ separated_music_arrays['bass'] = mixed_sound_array - vocals - drums - other
802
+
803
+ if update_percent_func is not None:
804
+ val = 100 * (current_file_number + 0.95) / total_files
805
+ update_percent_func(int(val))
806
+
807
+ return separated_music_arrays, output_sample_rates
808
+
809
+
810
+ def predict_with_model(options):
811
+ for input_audio in options['input_audio']:
812
+ if not os.path.isfile(input_audio):
813
+ print('Error. No such file: {}. Please check path!'.format(input_audio))
814
+ return
815
+ output_folder = options['output_folder']
816
+ if not os.path.isdir(output_folder):
817
+ os.mkdir(output_folder)
818
+
819
+ only_vocals = False
820
+ if 'only_vocals' in options:
821
+ if options['only_vocals'] is True:
822
+ print('Generate only vocals and instrumental')
823
+ only_vocals = True
824
+
825
+ model = None
826
+ if 'large_gpu' in options:
827
+ if options['large_gpu'] is True:
828
+ print('Use fast large GPU memory version of code')
829
+ model = EnsembleDemucsMDXMusicSeparationModel(options)
830
+ if model is None:
831
+ print('Use low GPU memory version of code')
832
+ model = EnsembleDemucsMDXMusicSeparationModelLowGPU(options)
833
+
834
+ update_percent_func = None
835
+ if 'update_percent_func' in options:
836
+ update_percent_func = options['update_percent_func']
837
+
838
+ for i, input_audio in enumerate(options['input_audio']):
839
+ print('Go for: {}'.format(input_audio))
840
+ audio, sr = librosa.load(input_audio, mono=False, sr=44100)
841
+ if len(audio.shape) == 1:
842
+ audio = np.stack([audio, audio], axis=0)
843
+ print("Input audio: {} Sample rate: {}".format(audio.shape, sr))
844
+ result, sample_rates = model.separate_music_file(
845
+ audio.T,
846
+ sr,
847
+ update_percent_func,
848
+ i,
849
+ len(options['input_audio']),
850
+ only_vocals,
851
+ )
852
+ all_instrum = model.instruments
853
+ if only_vocals:
854
+ all_instrum = ['vocals']
855
+ for instrum in all_instrum:
856
+ output_name = os.path.splitext(os.path.basename(input_audio))[0] + '_{}.wav'.format(instrum)
857
+ sf.write(output_folder + '/' + output_name, result[instrum], sample_rates[instrum], subtype='FLOAT')
858
+ print('File created: {}'.format(output_folder + '/' + output_name))
859
+
860
+ # instrumental part 1
861
+ inst = audio.T - result['vocals']
862
+ output_name = os.path.splitext(os.path.basename(input_audio))[0] + '_{}.wav'.format('instrum')
863
+ sf.write(output_folder + '/' + output_name, inst, sr, subtype='FLOAT')
864
+ print('File created: {}'.format(output_folder + '/' + output_name))
865
+
866
+ if not only_vocals:
867
+ # instrumental part 2
868
+ inst2 = result['bass'] + result['drums'] + result['other']
869
+ output_name = os.path.splitext(os.path.basename(input_audio))[0] + '_{}.wav'.format('instrum2')
870
+ sf.write(output_folder + '/' + output_name, inst2, sr, subtype='FLOAT')
871
+ print('File created: {}'.format(output_folder + '/' + output_name))
872
+
873
+ if update_percent_func is not None:
874
+ val = 100
875
+ update_percent_func(int(val))
876
+
877
+
878
+ def md5(fname):
879
+ hash_md5 = hashlib.md5()
880
+ with open(fname, "rb") as f:
881
+ for chunk in iter(lambda: f.read(4096), b""):
882
+ hash_md5.update(chunk)
883
+ return hash_md5.hexdigest()
884
+
885
+
886
+ if __name__ == '__main__':
887
+ start_time = time()
888
+
889
+ print("Version: {}".format(__VERSION__))
890
+ m = argparse.ArgumentParser()
891
+ m.add_argument("--input_audio", "-i", nargs='+', type=str, help="Input audio location. You can provide multiple files at once", required=True)
892
+ m.add_argument("--output_folder", "-r", type=str, help="Output audio folder", required=True)
893
+ m.add_argument("--cpu", action='store_true', help="Choose CPU instead of GPU for processing. Can be very slow.")
894
+ m.add_argument("--overlap_large", "-ol", type=float, help="Overlap of splited audio for light models. Closer to 1.0 - slower", required=False, default=0.6)
895
+ m.add_argument("--overlap_small", "-os", type=float, help="Overlap of splited audio for heavy models. Closer to 1.0 - slower", required=False, default=0.5)
896
+ m.add_argument("--single_onnx", action='store_true', help="Only use single ONNX model for vocals. Can be useful if you have not enough GPU memory.")
897
+ m.add_argument("--chunk_size", "-cz", type=int, help="Chunk size for ONNX models. Set lower to reduce GPU memory consumption. Default: 1000000", required=False, default=1000000)
898
+ m.add_argument("--large_gpu", action='store_true', help="It will store all models on GPU for faster processing of multiple audio files. Requires 11 and more GB of free GPU memory.")
899
+ m.add_argument("--use_kim_model_1", action='store_true', help="Use first version of Kim model (as it was on contest).")
900
+ m.add_argument("--only_vocals", action='store_true', help="Only create vocals and instrumental. Skip bass, drums, other")
901
+
902
+ options = m.parse_args().__dict__
903
+ print("Options: ".format(options))
904
+ for el in options:
905
+ print('{}: {}'.format(el, options[el]))
906
+ predict_with_model(options)
907
+ print('Time: {:.0f} sec'.format(time() - start_time))
908
+ print('Presented by https://mvsep.com')
909
+
910
+
911
+ """
912
+ Example:
913
+ python inference.py
914
+ --input_audio mixture.wav mixture1.wav
915
+ --output_folder ./results/
916
+ --cpu
917
+ --overlap_large 0.25
918
+ --overlap_small 0.25
919
+ --chunk_size 500000
920
+ """
models/.gitkeep ADDED
@@ -0,0 +1 @@
 
 
1
+
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy
2
+ soundfile
3
+ scipy
4
+ torch>=1.8.1
5
+ tqdm
6
+ librosa
7
+ demucs
8
+ onnxruntime-gpu
9
+ PyQt5
10
+ gradio
11
+ moviepy
12
+ pytube