PhoenixStormJr commited on
Commit
8a061f8
·
verified ·
1 Parent(s): fa7f626

Upload modules-cpu.py

Browse files
Files changed (1) hide show
  1. modules-cpu.py +305 -0
modules-cpu.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import traceback
2
+ import logging
3
+
4
+ logger = logging.getLogger(__name__)
5
+
6
+ import numpy as np
7
+ import soundfile as sf
8
+ import torch
9
+ from io import BytesIO
10
+
11
+ from infer.lib.audio import load_audio, wav2
12
+ from infer.lib.infer_pack.models import (
13
+ SynthesizerTrnMs256NSFsid,
14
+ SynthesizerTrnMs256NSFsid_nono,
15
+ SynthesizerTrnMs768NSFsid,
16
+ SynthesizerTrnMs768NSFsid_nono,
17
+ )
18
+ from infer.modules.vc.pipeline import Pipeline
19
+ from infer.modules.vc.utils import *
20
+
21
+
22
+ class VC:
23
+ def __init__(self, config):
24
+ self.n_spk = None
25
+ self.tgt_sr = None
26
+ self.net_g = None
27
+ self.pipeline = None
28
+ self.cpt = None
29
+ self.version = None
30
+ self.if_f0 = None
31
+ self.version = None
32
+ self.hubert_model = None
33
+ self.config = config
34
+ self.config.device = "cpu"
35
+
36
+
37
+ def get_vc(self, sid, *to_return_protect):
38
+ logger.info("Get sid: " + sid)
39
+
40
+ to_return_protect0 = {
41
+ "visible": self.if_f0 != 0,
42
+ "value": (
43
+ to_return_protect[0] if self.if_f0 != 0 and to_return_protect else 0.5
44
+ ),
45
+ "__type__": "update",
46
+ }
47
+ to_return_protect1 = {
48
+ "visible": self.if_f0 != 0,
49
+ "value": (
50
+ to_return_protect[1] if self.if_f0 != 0 and to_return_protect else 0.33
51
+ ),
52
+ "__type__": "update",
53
+ }
54
+
55
+ if sid == "" or sid == []:
56
+ if (
57
+ self.hubert_model is not None
58
+ ): # 考虑到轮询, 需要加个判断看是否 sid 是由有模型切换到无模型的
59
+ logger.info("Clean model cache")
60
+ del (self.net_g, self.n_spk, self.hubert_model, self.tgt_sr) # ,cpt
61
+ self.hubert_model = self.net_g = self.n_spk = self.hubert_model = (
62
+ self.tgt_sr
63
+ ) = None
64
+ if torch.cuda.is_available():
65
+ torch.cuda.empty_cache()
66
+ ###楼下不这么折腾清理不干净
67
+ self.if_f0 = self.cpt.get("f0", 1)
68
+ self.version = self.cpt.get("version", "v1")
69
+ if self.version == "v1":
70
+ if self.if_f0 == 1:
71
+ self.net_g = SynthesizerTrnMs256NSFsid(
72
+ *self.cpt["config"], is_half=self.config.is_half
73
+ )
74
+ else:
75
+ self.net_g = SynthesizerTrnMs256NSFsid_nono(*self.cpt["config"])
76
+ elif self.version == "v2":
77
+ if self.if_f0 == 1:
78
+ self.net_g = SynthesizerTrnMs768NSFsid(
79
+ *self.cpt["config"], is_half=self.config.is_half
80
+ )
81
+ else:
82
+ self.net_g = SynthesizerTrnMs768NSFsid_nono(*self.cpt["config"])
83
+ del self.net_g, self.cpt
84
+ if torch.cuda.is_available():
85
+ torch.cuda.empty_cache()
86
+ return (
87
+ {"visible": False, "__type__": "update"},
88
+ {
89
+ "visible": True,
90
+ "value": to_return_protect0,
91
+ "__type__": "update",
92
+ },
93
+ {
94
+ "visible": True,
95
+ "value": to_return_protect1,
96
+ "__type__": "update",
97
+ },
98
+ "",
99
+ "",
100
+ )
101
+ person = f'{os.getenv("weight_root")}/{sid}'
102
+ logger.info(f"Loading: {person}")
103
+
104
+ self.cpt = torch.load(person, map_location="cpu")
105
+ self.tgt_sr = self.cpt["config"][-1]
106
+ self.cpt["config"][-3] = self.cpt["weight"]["emb_g.weight"].shape[0] # n_spk
107
+ self.if_f0 = self.cpt.get("f0", 1)
108
+ self.version = self.cpt.get("version", "v1")
109
+
110
+ synthesizer_class = {
111
+ ("v1", 1): SynthesizerTrnMs256NSFsid,
112
+ ("v1", 0): SynthesizerTrnMs256NSFsid_nono,
113
+ ("v2", 1): SynthesizerTrnMs768NSFsid,
114
+ ("v2", 0): SynthesizerTrnMs768NSFsid_nono,
115
+ }
116
+
117
+ self.net_g = synthesizer_class.get(
118
+ (self.version, self.if_f0), SynthesizerTrnMs256NSFsid
119
+ )(*self.cpt["config"], is_half=self.config.is_half)
120
+
121
+ del self.net_g.enc_q
122
+
123
+ self.net_g.load_state_dict(self.cpt["weight"], strict=False)
124
+ self.net_g.eval().to("cpu")
125
+ if self.config.is_half:
126
+ self.net_g = self.net_g.half()
127
+ else:
128
+ self.net_g = self.net_g.float()
129
+
130
+ self.pipeline = Pipeline(self.tgt_sr, self.config)
131
+ n_spk = self.cpt["config"][-3]
132
+ index = {"value": get_index_path_from_model(sid), "__type__": "update"}
133
+ logger.info("Select index: " + index["value"])
134
+
135
+ return (
136
+ (
137
+ {"visible": True, "maximum": n_spk, "__type__": "update"},
138
+ to_return_protect0,
139
+ to_return_protect1,
140
+ index,
141
+ index,
142
+ )
143
+ if to_return_protect
144
+ else {"visible": True, "maximum": n_spk, "__type__": "update"}
145
+ )
146
+
147
+ def vc_single(
148
+ self,
149
+ sid,
150
+ input_audio_path,
151
+ f0_up_key,
152
+ f0_file,
153
+ f0_method,
154
+ file_index,
155
+ file_index2,
156
+ index_rate,
157
+ filter_radius,
158
+ resample_sr,
159
+ rms_mix_rate,
160
+ protect,
161
+ ):
162
+ if input_audio_path is None:
163
+ return "You need to upload an audio", None
164
+ f0_up_key = int(f0_up_key)
165
+ try:
166
+ audio = load_audio(input_audio_path, 16000)
167
+ audio_max = np.abs(audio).max() / 0.95
168
+ if audio_max > 1:
169
+ audio /= audio_max
170
+ times = [0, 0, 0]
171
+
172
+ if self.hubert_model is None:
173
+ self.hubert_model = load_hubert(self.config)
174
+
175
+ if file_index:
176
+ file_index = (
177
+ file_index.strip(" ")
178
+ .strip('"')
179
+ .strip("\n")
180
+ .strip('"')
181
+ .strip(" ")
182
+ .replace("trained", "added")
183
+ )
184
+ elif file_index2:
185
+ file_index = file_index2
186
+ else:
187
+ file_index = "" # 防止小白写错,自动帮他替换掉
188
+
189
+ audio_opt = self.pipeline.pipeline(
190
+ self.hubert_model,
191
+ self.net_g,
192
+ sid,
193
+ audio,
194
+ input_audio_path,
195
+ times,
196
+ f0_up_key,
197
+ f0_method,
198
+ file_index,
199
+ index_rate,
200
+ self.if_f0,
201
+ filter_radius,
202
+ self.tgt_sr,
203
+ resample_sr,
204
+ rms_mix_rate,
205
+ self.version,
206
+ protect,
207
+ f0_file,
208
+ )
209
+ if self.tgt_sr != resample_sr >= 16000:
210
+ tgt_sr = resample_sr
211
+ else:
212
+ tgt_sr = self.tgt_sr
213
+ index_info = (
214
+ "Index:\n%s." % file_index
215
+ if os.path.exists(file_index)
216
+ else "Index not used."
217
+ )
218
+ return (
219
+ "Success.\n%s\nTime:\nnpy: %.2fs, f0: %.2fs, infer: %.2fs."
220
+ % (index_info, *times),
221
+ (tgt_sr, audio_opt),
222
+ )
223
+ except:
224
+ info = traceback.format_exc()
225
+ logger.warning(info)
226
+ return info, (None, None)
227
+
228
+ def vc_multi(
229
+ self,
230
+ sid,
231
+ dir_path,
232
+ opt_root,
233
+ paths,
234
+ f0_up_key,
235
+ f0_method,
236
+ file_index,
237
+ file_index2,
238
+ index_rate,
239
+ filter_radius,
240
+ resample_sr,
241
+ rms_mix_rate,
242
+ protect,
243
+ format1,
244
+ ):
245
+ try:
246
+ dir_path = (
247
+ dir_path.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
248
+ ) # 防止小白拷路径头尾带了空格和"和回车
249
+ opt_root = opt_root.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
250
+ os.makedirs(opt_root, exist_ok=True)
251
+ try:
252
+ if dir_path != "":
253
+ paths = [
254
+ os.path.join(dir_path, name) for name in os.listdir(dir_path)
255
+ ]
256
+ else:
257
+ paths = [path.name for path in paths]
258
+ except:
259
+ traceback.print_exc()
260
+ paths = [path.name for path in paths]
261
+ infos = []
262
+ for path in paths:
263
+ info, opt = self.vc_single(
264
+ sid,
265
+ path,
266
+ f0_up_key,
267
+ None,
268
+ f0_method,
269
+ file_index,
270
+ file_index2,
271
+ # file_big_npy,
272
+ index_rate,
273
+ filter_radius,
274
+ resample_sr,
275
+ rms_mix_rate,
276
+ protect,
277
+ )
278
+ if "Success" in info:
279
+ try:
280
+ tgt_sr, audio_opt = opt
281
+ if format1 in ["wav", "flac"]:
282
+ sf.write(
283
+ "%s/%s.%s"
284
+ % (opt_root, os.path.basename(path), format1),
285
+ audio_opt,
286
+ tgt_sr,
287
+ )
288
+ else:
289
+ path = "%s/%s.%s" % (
290
+ opt_root,
291
+ os.path.basename(path),
292
+ format1,
293
+ )
294
+ with BytesIO() as wavf:
295
+ sf.write(wavf, audio_opt, tgt_sr, format="wav")
296
+ wavf.seek(0, 0)
297
+ with open(path, "wb") as outf:
298
+ wav2(wavf, outf, format1)
299
+ except:
300
+ info += traceback.format_exc()
301
+ infos.append("%s->%s" % (os.path.basename(path), info))
302
+ yield "\n".join(infos)
303
+ yield "\n".join(infos)
304
+ except:
305
+ yield traceback.format_exc()