PhoenixStormJr committed
Commit 2e5d76b · verified · 1 Parent(s): 8a061f8

Upload data_utils.py

Files changed (1)
  1. data_utils.py +517 -0
data_utils.py ADDED
@@ -0,0 +1,517 @@
+ import logging
+ import os
+ import traceback
+
+ import numpy as np
+ import torch
+ import torch.utils.data
+
+ from infer.lib.train.mel_processing import spectrogram_torch
+ from infer.lib.train.utils import load_filepaths_and_text, load_wav_to_torch
+
+ logger = logging.getLogger(__name__)
+
+
+ class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
+     """
+     1) loads audio/label path tuples
+     2) loads phone features, pitch and speaker id from disk
+     3) computes spectrograms from the audio files
+     """
+
+     def __init__(self, audiopaths_and_text, hparams):
+         self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
+         self.max_wav_value = hparams.max_wav_value
+         self.sampling_rate = hparams.sampling_rate
+         self.filter_length = hparams.filter_length
+         self.hop_length = hparams.hop_length
+         self.win_length = hparams.win_length
+         self.min_text_len = getattr(hparams, "min_text_len", 1)
+         self.max_text_len = getattr(hparams, "max_text_len", 5000)
+         self._filter()
+
+     def _filter(self):
+         """
+         Filter entries by text length & store spec lengths.
+         """
+         # Approximate spectrogram lengths for DistributedBucketSampler:
+         # wav_length ~= file_size / (wav_channels * bytes_per_sample) = file_size / (1 * 2)
+         # spec_length = wav_length // hop_length
+         audiopaths_and_text_new = []
+         lengths = []
+         for audiopath, text, pitch, pitchf, dv in self.audiopaths_and_text:
+             if self.min_text_len <= len(text) <= self.max_text_len:
+                 audiopaths_and_text_new.append([audiopath, text, pitch, pitchf, dv])
+                 lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
+         self.audiopaths_and_text = audiopaths_and_text_new
+         self.lengths = lengths
+
+     def get_sid(self, sid):
+         return torch.LongTensor([int(sid)])
+
+     def get_audio_text_pair(self, audiopath_and_text):
+         # split out the wav path, label paths and speaker id
+         file = audiopath_and_text[0]
+         phone = audiopath_and_text[1]
+         pitch = audiopath_and_text[2]
+         pitchf = audiopath_and_text[3]
+         dv = audiopath_and_text[4]
+
+         phone, pitch, pitchf = self.get_labels(phone, pitch, pitchf)
+         spec, wav = self.get_audio(file)
+         dv = self.get_sid(dv)
+
+         len_phone = phone.size()[0]
+         len_spec = spec.size()[-1]
+         if len_phone != len_spec:
+             # trim everything to the shorter of the two lengths
+             len_min = min(len_phone, len_spec)
+             len_wav = len_min * self.hop_length
+
+             spec = spec[:, :len_min]
+             wav = wav[:, :len_wav]
+
+             phone = phone[:len_min, :]
+             pitch = pitch[:len_min]
+             pitchf = pitchf[:len_min]
+
+         return (spec, wav, phone, pitch, pitchf, dv)
+
+     def get_labels(self, phone, pitch, pitchf):
+         phone = np.load(phone, allow_pickle=True)
+         # duplicate each feature frame to match the spectrogram frame rate
+         phone = np.repeat(phone, 2, axis=0)
+         pitch = np.load(pitch, allow_pickle=True)
+         pitchf = np.load(pitchf, allow_pickle=True)
+         n_num = min(phone.shape[0], 900)  # length cap for DistributedBucketSampler
+         phone = phone[:n_num, :]
+         pitch = pitch[:n_num]
+         pitchf = pitchf[:n_num]
+         phone = torch.FloatTensor(phone)
+         pitch = torch.LongTensor(pitch)
+         pitchf = torch.FloatTensor(pitchf)
+         return phone, pitch, pitchf
+
+     def get_audio(self, filename):
+         audio, sampling_rate = load_wav_to_torch(filename)
+         if sampling_rate != self.sampling_rate:
+             raise ValueError(
+                 "{} SR doesn't match target {} SR".format(
+                     sampling_rate, self.sampling_rate
+                 )
+             )
+         audio_norm = audio
+         # audio_norm = audio / self.max_wav_value
+         # audio_norm = audio / np.abs(audio).max()
+         audio_norm = audio_norm.unsqueeze(0)
+
+         # Load the cached spectrogram if present; otherwise (or if the cache
+         # is unreadable) recompute it and refresh the cache.
+         spec_filename = filename.replace(".wav", ".spec.pt")
+         spec = None
+         if os.path.exists(spec_filename):
+             try:
+                 spec = torch.load(spec_filename)
+             except Exception:
+                 logger.warning("%s %s", spec_filename, traceback.format_exc())
+         if spec is None:
+             spec = spectrogram_torch(
+                 audio_norm,
+                 self.filter_length,
+                 self.sampling_rate,
+                 self.hop_length,
+                 self.win_length,
+                 center=False,
+             )
+             spec = torch.squeeze(spec, 0)
+             torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
+         return spec, audio_norm
+
+     def __getitem__(self, index):
+         return self.get_audio_text_pair(self.audiopaths_and_text[index])
+
+     def __len__(self):
+         return len(self.audiopaths_and_text)
+
+
+ class TextAudioCollateMultiNSFsid:
+     """Zero-pads model inputs and targets"""
+
+     def __init__(self, return_ids=False):
+         self.return_ids = return_ids
+
+     def __call__(self, batch):
+         """Collates a training batch.
+         PARAMS
+         ------
+         batch: list of (spec, wav, phone, pitch, pitchf, sid) tuples
+         """
+         # sort the batch by spectrogram length, descending
+         _, ids_sorted_decreasing = torch.sort(
+             torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True
+         )
+
+         max_spec_len = max([x[0].size(1) for x in batch])
+         max_wave_len = max([x[1].size(1) for x in batch])
+         spec_lengths = torch.LongTensor(len(batch))
+         wave_lengths = torch.LongTensor(len(batch))
+         spec_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len)
+         wave_padded = torch.FloatTensor(len(batch), 1, max_wave_len)
+         spec_padded.zero_()
+         wave_padded.zero_()
+
+         max_phone_len = max([x[2].size(0) for x in batch])
+         phone_lengths = torch.LongTensor(len(batch))
+         phone_padded = torch.FloatTensor(
+             len(batch), max_phone_len, batch[0][2].shape[1]
+         )
+         pitch_padded = torch.LongTensor(len(batch), max_phone_len)
+         pitchf_padded = torch.FloatTensor(len(batch), max_phone_len)
+         phone_padded.zero_()
+         pitch_padded.zero_()
+         pitchf_padded.zero_()
+         # dv = torch.FloatTensor(len(batch), 256)  # gin=256
+         sid = torch.LongTensor(len(batch))
+
+         for i in range(len(ids_sorted_decreasing)):
+             row = batch[ids_sorted_decreasing[i]]
+
+             spec = row[0]
+             spec_padded[i, :, : spec.size(1)] = spec
+             spec_lengths[i] = spec.size(1)
+
+             wave = row[1]
+             wave_padded[i, :, : wave.size(1)] = wave
+             wave_lengths[i] = wave.size(1)
+
+             phone = row[2]
+             phone_padded[i, : phone.size(0), :] = phone
+             phone_lengths[i] = phone.size(0)
+
+             pitch = row[3]
+             pitch_padded[i, : pitch.size(0)] = pitch
+             pitchf = row[4]
+             pitchf_padded[i, : pitchf.size(0)] = pitchf
+
+             # dv[i] = row[5]
+             sid[i] = row[5]
+
+         return (
+             phone_padded,
+             phone_lengths,
+             pitch_padded,
+             pitchf_padded,
+             spec_padded,
+             spec_lengths,
+             wave_padded,
+             wave_lengths,
+             # dv,
+             sid,
+         )
+
+
+ class TextAudioLoader(torch.utils.data.Dataset):
+     """
+     1) loads audio/label path tuples
+     2) loads phone features and speaker id from disk
+     3) computes spectrograms from the audio files
+     """
+
+     def __init__(self, audiopaths_and_text, hparams):
+         self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
+         self.max_wav_value = hparams.max_wav_value
+         self.sampling_rate = hparams.sampling_rate
+         self.filter_length = hparams.filter_length
+         self.hop_length = hparams.hop_length
+         self.win_length = hparams.win_length
+         self.min_text_len = getattr(hparams, "min_text_len", 1)
+         self.max_text_len = getattr(hparams, "max_text_len", 5000)
+         self._filter()
+
+     def _filter(self):
+         """
+         Filter entries by text length & store spec lengths.
+         """
+         # Approximate spectrogram lengths for DistributedBucketSampler:
+         # wav_length ~= file_size / (wav_channels * bytes_per_sample) = file_size / (1 * 2)
+         # spec_length = wav_length // hop_length
+         audiopaths_and_text_new = []
+         lengths = []
+         for audiopath, text, dv in self.audiopaths_and_text:
+             if self.min_text_len <= len(text) <= self.max_text_len:
+                 audiopaths_and_text_new.append([audiopath, text, dv])
+                 lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
+         self.audiopaths_and_text = audiopaths_and_text_new
+         self.lengths = lengths
+
+     def get_sid(self, sid):
+         return torch.LongTensor([int(sid)])
+
+     def get_audio_text_pair(self, audiopath_and_text):
+         # split out the wav path, label path and speaker id
+         file = audiopath_and_text[0]
+         phone = audiopath_and_text[1]
+         dv = audiopath_and_text[2]
+
+         phone = self.get_labels(phone)
+         spec, wav = self.get_audio(file)
+         dv = self.get_sid(dv)
+
+         len_phone = phone.size()[0]
+         len_spec = spec.size()[-1]
+         if len_phone != len_spec:
+             # trim everything to the shorter of the two lengths
+             len_min = min(len_phone, len_spec)
+             len_wav = len_min * self.hop_length
+             spec = spec[:, :len_min]
+             wav = wav[:, :len_wav]
+             phone = phone[:len_min, :]
+         return (spec, wav, phone, dv)
+
+     def get_labels(self, phone):
+         phone = np.load(phone, allow_pickle=True)
+         # duplicate each feature frame to match the spectrogram frame rate
+         phone = np.repeat(phone, 2, axis=0)
+         n_num = min(phone.shape[0], 900)  # length cap for DistributedBucketSampler
+         phone = phone[:n_num, :]
+         phone = torch.FloatTensor(phone)
+         return phone
+
+     def get_audio(self, filename):
+         audio, sampling_rate = load_wav_to_torch(filename)
+         if sampling_rate != self.sampling_rate:
+             raise ValueError(
+                 "{} SR doesn't match target {} SR".format(
+                     sampling_rate, self.sampling_rate
+                 )
+             )
+         audio_norm = audio
+         # audio_norm = audio / self.max_wav_value
+         # audio_norm = audio / np.abs(audio).max()
+         audio_norm = audio_norm.unsqueeze(0)
+
+         # Load the cached spectrogram if present; otherwise (or if the cache
+         # is unreadable) recompute it and refresh the cache.
+         spec_filename = filename.replace(".wav", ".spec.pt")
+         spec = None
+         if os.path.exists(spec_filename):
+             try:
+                 spec = torch.load(spec_filename)
+             except Exception:
+                 logger.warning("%s %s", spec_filename, traceback.format_exc())
+         if spec is None:
+             spec = spectrogram_torch(
+                 audio_norm,
+                 self.filter_length,
+                 self.sampling_rate,
+                 self.hop_length,
+                 self.win_length,
+                 center=False,
+             )
+             spec = torch.squeeze(spec, 0)
+             torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
+         return spec, audio_norm
+
+     def __getitem__(self, index):
+         return self.get_audio_text_pair(self.audiopaths_and_text[index])
+
+     def __len__(self):
+         return len(self.audiopaths_and_text)
+
+
+ class TextAudioCollate:
+     """Zero-pads model inputs and targets"""
+
+     def __init__(self, return_ids=False):
+         self.return_ids = return_ids
+
+     def __call__(self, batch):
+         """Collates a training batch.
+         PARAMS
+         ------
+         batch: list of (spec, wav, phone, sid) tuples
+         """
+         # sort the batch by spectrogram length, descending
+         _, ids_sorted_decreasing = torch.sort(
+             torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True
+         )
+
+         max_spec_len = max([x[0].size(1) for x in batch])
+         max_wave_len = max([x[1].size(1) for x in batch])
+         spec_lengths = torch.LongTensor(len(batch))
+         wave_lengths = torch.LongTensor(len(batch))
+         spec_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len)
+         wave_padded = torch.FloatTensor(len(batch), 1, max_wave_len)
+         spec_padded.zero_()
+         wave_padded.zero_()
+
+         max_phone_len = max([x[2].size(0) for x in batch])
+         phone_lengths = torch.LongTensor(len(batch))
+         phone_padded = torch.FloatTensor(
+             len(batch), max_phone_len, batch[0][2].shape[1]
+         )
+         phone_padded.zero_()
+         sid = torch.LongTensor(len(batch))
+
+         for i in range(len(ids_sorted_decreasing)):
+             row = batch[ids_sorted_decreasing[i]]
+
+             spec = row[0]
+             spec_padded[i, :, : spec.size(1)] = spec
+             spec_lengths[i] = spec.size(1)
+
+             wave = row[1]
+             wave_padded[i, :, : wave.size(1)] = wave
+             wave_lengths[i] = wave.size(1)
+
+             phone = row[2]
+             phone_padded[i, : phone.size(0), :] = phone
+             phone_lengths[i] = phone.size(0)
+
+             sid[i] = row[3]
+
+         return (
+             phone_padded,
+             phone_lengths,
+             spec_padded,
+             spec_lengths,
+             wave_padded,
+             wave_lengths,
+             sid,
+         )
+
+
+ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
+     """
+     Maintains similar input lengths within a batch.
+     Length groups are specified by boundaries.
+     Ex) boundaries = [b1, b2, b3] -> every batch draws only from one group,
+     either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.
+
+     Samples whose lengths fall outside the boundaries are removed.
+     Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 is discarded.
+     """
+
+     def __init__(
+         self,
+         dataset,
+         batch_size,
+         boundaries,
+         num_replicas=None,
+         rank=None,
+         shuffle=True,
+     ):
+         super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
+         self.lengths = dataset.lengths
+         self.batch_size = batch_size
+         self.boundaries = boundaries
+
+         self.buckets, self.num_samples_per_bucket = self._create_buckets()
+         self.total_size = sum(self.num_samples_per_bucket)
+         self.num_samples = self.total_size // self.num_replicas
+
+     def _create_buckets(self):
+         buckets = [[] for _ in range(len(self.boundaries) - 1)]
+         for i in range(len(self.lengths)):
+             length = self.lengths[i]
+             idx_bucket = self._bisect(length)
+             if idx_bucket != -1:
+                 buckets[idx_bucket].append(i)
+
+         # drop empty buckets (iterate in reverse so indices stay valid)
+         for i in range(len(buckets) - 1, -1, -1):
+             if len(buckets[i]) == 0:
+                 buckets.pop(i)
+                 self.boundaries.pop(i + 1)
+
+         # pad each bucket so it splits evenly across replicas and batches
+         num_samples_per_bucket = []
+         for i in range(len(buckets)):
+             len_bucket = len(buckets[i])
+             total_batch_size = self.num_replicas * self.batch_size
+             rem = (
+                 total_batch_size - (len_bucket % total_batch_size)
+             ) % total_batch_size
+             num_samples_per_bucket.append(len_bucket + rem)
+         return buckets, num_samples_per_bucket
+
+     def __iter__(self):
+         # deterministically shuffle based on epoch
+         g = torch.Generator()
+         g.manual_seed(self.epoch)
+
+         indices = []
+         if self.shuffle:
+             for bucket in self.buckets:
+                 indices.append(torch.randperm(len(bucket), generator=g).tolist())
+         else:
+             for bucket in self.buckets:
+                 indices.append(list(range(len(bucket))))
+
+         batches = []
+         for i in range(len(self.buckets)):
+             bucket = self.buckets[i]
+             len_bucket = len(bucket)
+             ids_bucket = indices[i]
+             num_samples_bucket = self.num_samples_per_bucket[i]
+
+             # add extra samples to make it evenly divisible
+             rem = num_samples_bucket - len_bucket
+             ids_bucket = (
+                 ids_bucket
+                 + ids_bucket * (rem // len_bucket)
+                 + ids_bucket[: (rem % len_bucket)]
+             )
+
+             # subsample for this replica
+             ids_bucket = ids_bucket[self.rank :: self.num_replicas]
+
+             # batching
+             for j in range(len(ids_bucket) // self.batch_size):
+                 batch = [
+                     bucket[idx]
+                     for idx in ids_bucket[
+                         j * self.batch_size : (j + 1) * self.batch_size
+                     ]
+                 ]
+                 batches.append(batch)
+
+         if self.shuffle:
+             batch_ids = torch.randperm(len(batches), generator=g).tolist()
+             batches = [batches[i] for i in batch_ids]
+         self.batches = batches
+
+         assert len(self.batches) * self.batch_size == self.num_samples
+         return iter(self.batches)
+
+     def _bisect(self, x, lo=0, hi=None):
+         if hi is None:
+             hi = len(self.boundaries) - 1
+
+         if hi > lo:
+             mid = (hi + lo) // 2
+             if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
+                 return mid
+             elif x <= self.boundaries[mid]:
+                 return self._bisect(x, lo, mid)
+             else:
+                 return self._bisect(x, mid + 1, hi)
+         else:
+             return -1
+
+     def __len__(self):
+         return self.num_samples // self.batch_size
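
For context, here is a minimal sketch of how the dataset, bucket sampler, and collate function above are typically wired into a training DataLoader. This is not part of the commit: the hparams values, the filelist name, the import path, and the bucket boundaries are illustrative assumptions, and the filelist row format is inferred from the 5-way unpack in _filter.

from types import SimpleNamespace

import torch.utils.data

from data_utils import (  # assumed import path; adjust to where this file lives
    DistributedBucketSampler,
    TextAudioCollateMultiNSFsid,
    TextAudioLoaderMultiNSFsid,
)

# Hypothetical hyperparameters covering the fields the loader reads.
hps = SimpleNamespace(
    max_wav_value=32768.0,
    sampling_rate=40000,
    filter_length=2048,
    hop_length=400,
    win_length=2048,
)

# Assumed filelist row format (pipe-separated):
#   path.wav|phone.npy|pitch.npy|pitchf.npy|speaker_id
dataset = TextAudioLoaderMultiNSFsid("train_filelist.txt", hps)

# Group items of similar spec length; passing explicit num_replicas/rank
# avoids needing torch.distributed to be initialized for a single process.
sampler = DistributedBucketSampler(
    dataset,
    batch_size=4,
    boundaries=[100, 200, 300, 400, 500, 600, 700, 800, 900],
    num_replicas=1,
    rank=0,
    shuffle=True,
)

loader = torch.utils.data.DataLoader(
    dataset,
    batch_sampler=sampler,  # yields one list of indices per batch
    collate_fn=TextAudioCollateMultiNSFsid(),
    num_workers=2,
)

for (phone, phone_lengths, pitch, pitchf,
     spec, spec_lengths, wave, wave_lengths, sid) in loader:
    pass  # training step goes here

Note that the sampler pads each bucket so it divides evenly across replicas and batches, so a single epoch may repeat a few items.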