Hjgugugjhuhjggg committed on
Commit bd4d670 · verified · 1 Parent(s): 7659fc6

Update model_loader.py

Files changed (1)
  1. model_loader.py +1394 -728
model_loader.py CHANGED
@@ -1,728 +1,1394 @@
1
- from tokenxxx import *
2
- from constants import *
3
- from utils import *
4
- import os
5
- import json
6
- import urllib.request
7
- import urllib.parse
8
- import torch
9
- import hashlib
10
- from tqdm import tqdm
11
- from skimage import img_as_ubyte
12
- from torch import nn
13
- import torch.nn.functional as F
14
- import inspect
15
-
16
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
-
18
- def filter_kwargs(cls, kwargs):
19
- sig = inspect.signature(cls.__init__)
20
- accepted = set(sig.parameters.keys()) - {"self"}
21
- return {k: v for k, v in kwargs.items() if k in accepted}
22
-
23
- def sanitize_filename(name, url=None):
24
- for c in '<>:"/\\|?*':
25
- name = name.replace(c, '')
26
- if not name and url is not None:
27
- name = hashlib.md5(url.encode()).hexdigest()
28
- return name
29
-
30
- def download_file(url, filepath):
31
- d = os.path.dirname(filepath)
32
- if d and not os.path.exists(d):
33
- os.makedirs(d, exist_ok=True)
34
- while not os.path.exists(filepath):
35
- try:
36
- def prog(t):
37
- last = [0]
38
- def inner(n, bs, ts):
39
- if ts > 0:
40
- t.total = ts
41
- t.update(n * bs - last[0])
42
- last[0] = n * bs
43
- return inner
44
- with tqdm(unit='B', unit_scale=True, unit_divisor=1024, desc=os.path.basename(filepath)) as t:
45
- urllib.request.urlretrieve(url, filepath, reporthook=prog(t))
46
- except Exception:
47
- continue
48
-
49
- def download_files(folder, files_spec):
50
- if isinstance(files_spec, dict):
51
- for fn, url in files_spec.items():
52
- fn = sanitize_filename(fn, url)
53
- fp = os.path.join(folder, fn)
54
- download_file(url, fp)
55
- elif isinstance(files_spec, list):
56
- for item in files_spec:
57
- if isinstance(item, str):
58
- url = item
59
- parsed = urllib.parse.urlparse(url)
60
- fn = os.path.basename(parsed.path)
61
- if not fn:
62
- fn = hashlib.md5(url.encode()).hexdigest()
63
- fn = sanitize_filename(fn, url)
64
- elif isinstance(item, (list, tuple)) and len(item) == 2:
65
- url, fn = item
66
- fn = sanitize_filename(fn, url)
67
- elif isinstance(item, dict) and "filename" in item and "url" in item:
68
- fn = sanitize_filename(item["filename"], item["url"])
69
- url = item["url"]
70
- else:
71
- raise ValueError("Invalid file specification")
72
- fp = os.path.join(folder, fn)
73
- download_file(url, fp)
74
- else:
75
- raise ValueError("files_spec must be dict or list")
76
-
77
- def read_json(fp):
78
- with open(fp, 'r', encoding='utf-8') as f:
79
- return json.load(f)
80
-
81
- def get_codegen_tokenizer(vocab_path, merges_path):
82
- with open(vocab_path, 'r', encoding='utf-8') as f:
83
- vocab = json.load(f)
84
- with open(merges_path, 'r', encoding='utf-8') as f:
85
- merges = f.read().splitlines()
86
- merge_ranks = {}
87
- for i, merge in enumerate(merges):
88
- parts = merge.strip().split()
89
- if len(parts) == 2:
90
- merge_ranks[tuple(parts)] = i
91
- def bpe(token):
92
- word = list(token)
93
- pairs = [(word[i], word[i+1]) for i in range(len(word)-1)]
94
- while True:
95
- candidate = None
96
- candidate_rank = None
97
- candidate_index = None
98
- for i, pair in enumerate(pairs):
99
- if pair in merge_ranks:
100
- rank = merge_ranks[pair]
101
- if candidate is None or rank < candidate_rank:
102
- candidate = pair
103
- candidate_rank = rank
104
- candidate_index = i
105
- if candidate is None:
106
- break
107
- first, second = candidate
108
- new_word = []
109
- i = 0
110
- while i < len(word):
111
- if i < len(word) - 1 and word[i] == first and word[i+1] == second:
112
- new_word.append(first + second)
113
- i += 2
114
- else:
115
- new_word.append(word[i])
116
- i += 1
117
- word = new_word
118
- if len(word) == 1:
119
- break
120
- pairs = [(word[i], word[i+1]) for i in range(len(word)-1)]
121
- return word
122
- def tokenizer(text):
123
- tokens = []
124
- for token in text.split():
125
- bpe_tokens = bpe(token)
126
- for subtoken in bpe_tokens:
127
- tokens.append(vocab.get(subtoken, 0))
128
- return tokens
129
- return tokenizer
130
-
131
- def simple_tokenizer(text, vocab, max_length=77):
132
- toks = text.split()
133
- ids = [vocab.get(t, 1) for t in toks]
134
- if len(ids) < max_length:
135
- ids = ids + [0] * (max_length - len(ids))
136
- else:
137
- ids = ids[:max_length]
138
- return torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(device)
139
-
140
- def load_state_dict_safe(model, loaded_state_dict):
141
- model_state = model.state_dict()
142
- new_state = {}
143
- for key, value in model_state.items():
144
- if key in loaded_state_dict and loaded_state_dict[key].shape == value.shape:
145
- new_state[key] = loaded_state_dict[key]
146
- else:
147
- new_state[key] = value
148
- model.load_state_dict(new_state, strict=False)
149
-
150
- class GPT2Config:
151
- def __init__(self, vocab_size=50257, **kwargs):
152
- self.vocab_size = vocab_size
153
- self.__dict__.update(kwargs)
154
- @classmethod
155
- def from_dict(cls, d):
156
- return cls(**d)
157
-
158
- class MBartConfig:
159
- def __init__(self, vocab_size=50265, **kwargs):
160
- self.vocab_size = vocab_size
161
- self.__dict__.update(kwargs)
162
- @classmethod
163
- def from_dict(cls, d):
164
- return cls(**d)
165
-
166
- class CodeGenConfig:
167
- def __init__(self, vocab_size=50257, **kwargs):
168
- self.vocab_size = vocab_size
169
- self.__dict__.update(kwargs)
170
- @classmethod
171
- def from_dict(cls, d):
172
- return cls(**d)
173
-
174
- class BartConfig:
175
- def __init__(self, vocab_size=50265, **kwargs):
176
- self.vocab_size = vocab_size
177
- self.__dict__.update(kwargs)
178
- @classmethod
179
- def from_dict(cls, d):
180
- return cls(**d)
181
-
182
- class AutoencoderKLConfig:
183
- def __init__(self, **kwargs):
184
- self.__dict__.update(kwargs)
185
- @classmethod
186
- def from_dict(cls, d):
187
- return cls(**d)
188
-
189
- class OpenLRMConfig:
190
- def __init__(self, **kwargs):
191
- self.__dict__.update(kwargs)
192
- @classmethod
193
- def from_dict(cls, d):
194
- return cls(**d)
195
-
196
- class UNet2DConditionModelConfig:
197
- def __init__(self, **kwargs):
198
- self.__dict__.update(kwargs)
199
- @classmethod
200
- def from_dict(cls, d):
201
- return cls(**d)
202
-
203
- class MusicGenConfig:
204
- def __init__(self, **kwargs):
205
- self.__dict__.update(kwargs)
206
- @classmethod
207
- def from_dict(cls, d):
208
- return cls(**d)
209
-
210
- class GPT2LMHeadModel(nn.Module):
211
- def __init__(self, config):
212
- super().__init__()
213
- layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
214
- self.transformer = nn.TransformerEncoder(layer, num_layers=12)
215
- self.lm_head = nn.Linear(768, config.vocab_size)
216
- def forward(self, x):
217
- return self.lm_head(self.transformer(x))
218
-
219
- class MBartForConditionalGeneration(nn.Module):
220
- def __init__(self, config):
221
- super().__init__()
222
- self.config = config
223
- layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
224
- self.encoder = nn.TransformerEncoder(layer, num_layers=6)
225
- dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12)
226
- self.decoder = nn.TransformerDecoder(dlayer, num_layers=6)
227
- self.output_layer = nn.Linear(768, config.vocab_size)
228
- def forward(self, src, tgt):
229
- return self.output_layer(self.decoder(tgt, self.encoder(src)))
230
-
231
- class CodeGenForCausalLM(nn.Module):
232
- def __init__(self, config):
233
- super().__init__()
234
- d_model = getattr(config, "d_model", 1024)
235
- n_head = getattr(config, "n_head", 16)
236
- num_layers = getattr(config, "num_layers", 12)
237
- dlayer = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_head)
238
- self.transformer_decoder = nn.TransformerDecoder(dlayer, num_layers=num_layers)
239
- self.lm_head = nn.Linear(d_model, config.vocab_size)
240
- def forward(self, tgt, memory=None):
241
- if memory is None:
242
- memory = torch.zeros_like(tgt)
243
- return self.lm_head(self.transformer_decoder(tgt, memory))
244
-
245
- class BartForConditionalGeneration(nn.Module):
246
- def __init__(self, config):
247
- super().__init__()
248
- layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
249
- self.encoder = nn.TransformerEncoder(layer, num_layers=6)
250
- dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12)
251
- self.decoder = nn.TransformerDecoder(dlayer, num_layers=6)
252
- self.output_layer = nn.Linear(768, config.vocab_size)
253
- def forward(self, src, tgt):
254
- return self.output_layer(self.decoder(tgt, self.encoder(src)))
255
-
256
- class ResnetBlock(nn.Module):
257
- def __init__(self, in_ch, out_ch):
258
- super().__init__()
259
- self.norm1 = nn.GroupNorm(32, in_ch)
260
- self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
261
- self.norm2 = nn.GroupNorm(32, out_ch)
262
- self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
263
- self.conv_shortcut = nn.Conv2d(in_ch, out_ch, 1)
264
- def forward(self, x):
265
- sc = self.conv_shortcut(x)
266
- h = F.silu(self.norm1(x))
267
- h = self.conv1(h)
268
- h = F.silu(self.norm2(h))
269
- h = self.conv2(h)
270
- return h + sc
271
-
272
- class Downsample(nn.Module):
273
- def __init__(self, in_ch, out_ch):
274
- super().__init__()
275
- self.conv = nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1)
276
- def forward(self, x):
277
- return self.conv(x)
278
-
279
- class DownBlock(nn.Module):
280
- def __init__(self, in_ch, out_ch, num_res):
281
- super().__init__()
282
- self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)])
283
- self.downsamplers = nn.ModuleList([Downsample(out_ch, out_ch)])
284
- def forward(self, x):
285
- for r in self.resnets:
286
- x = r(x)
287
- for ds in self.downsamplers:
288
- x = ds(x)
289
- return x
290
-
291
- class Upsample(nn.Module):
292
- def __init__(self, in_ch, out_ch):
293
- super().__init__()
294
- self.conv = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
295
- def forward(self, x):
296
- return self.conv(x)
297
-
298
- class UpBlock(nn.Module):
299
- def __init__(self, in_ch, out_ch, num_res):
300
- super().__init__()
301
- self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)])
302
- self.upsampler = Upsample(out_ch, out_ch)
303
- def forward(self, x):
304
- for r in self.resnets:
305
- x = r(x)
306
- return self.upsampler(x)
307
-
308
- class AttentionBlock(nn.Module):
309
- def __init__(self, ch):
310
- super().__init__()
311
- self.norm = nn.GroupNorm(32, ch)
312
- self.query = nn.Conv2d(ch, ch, 1)
313
- self.key = nn.Conv2d(ch, ch, 1)
314
- self.value = nn.Conv2d(ch, ch, 1)
315
- self.proj_attn = nn.Conv2d(ch, ch, 1)
316
- def forward(self, x):
317
- b, c, h, w = x.shape
318
- xn = self.norm(x)
319
- q = self.query(xn).view(b, c, -1).permute(0, 2, 1)
320
- k = self.key(xn).view(b, c, -1)
321
- v = self.value(xn).view(b, c, -1).permute(0, 2, 1)
322
- attn = torch.softmax(torch.bmm(q, k) / (c ** 0.5), dim=-1)
323
- out = torch.bmm(attn, v).permute(0, 2, 1).view(b, c, h, w)
324
- return x + self.proj_attn(out)
325
-
326
- class Encoder(nn.Module):
327
- def __init__(self, in_ch=3, base_ch=128, latent_ch=4):
328
- super().__init__()
329
- self.conv_in = nn.Conv2d(in_ch, base_ch, 3, padding=1)
330
- self.down_blocks = nn.ModuleList([
331
- DownBlock(base_ch, base_ch, 2),
332
- DownBlock(base_ch, base_ch * 2, 2),
333
- DownBlock(base_ch * 2, base_ch * 4, 2),
334
- DownBlock(base_ch * 4, base_ch * 4, 2)
335
- ])
336
- self.mid_block = nn.ModuleList([
337
- ResnetBlock(base_ch * 4, base_ch * 4),
338
- AttentionBlock(base_ch * 4),
339
- ResnetBlock(base_ch * 4, base_ch * 4)
340
- ])
341
- self.conv_norm_out = nn.GroupNorm(32, base_ch * 4)
342
- self.conv_out = nn.Conv2d(base_ch * 4, latent_ch * 2, 3, padding=1)
343
- self.quant_conv = nn.Conv2d(latent_ch * 2, latent_ch, 1)
344
- def forward(self, x):
345
- x = self.conv_in(x)
346
- for blk in self.down_blocks:
347
- x = blk(x)
348
- for m in self.mid_block:
349
- x = m(x)
350
- x = self.conv_norm_out(x)
351
- x = self.conv_out(x)
352
- return self.quant_conv(x)
353
-
354
- class Decoder(nn.Module):
355
- def __init__(self, out_ch=3, base_ch=128, latent_ch=4):
356
- super().__init__()
357
- self.post_quant_conv = nn.Conv2d(latent_ch, latent_ch * 2, 1)
358
- self.conv_in = nn.Conv2d(latent_ch, base_ch * 4, 3, padding=1)
359
- self.mid_block = nn.ModuleList([
360
- ResnetBlock(base_ch * 4, base_ch * 4),
361
- AttentionBlock(base_ch * 4),
362
- ResnetBlock(base_ch * 4, base_ch * 4)
363
- ])
364
- self.up_blocks = nn.ModuleList([
365
- UpBlock(base_ch * 4, base_ch * 4, 3),
366
- UpBlock(base_ch * 4, base_ch * 2, 3),
367
- UpBlock(base_ch * 2, base_ch, 3),
368
- UpBlock(base_ch, base_ch, 3)
369
- ])
370
- self.conv_norm_out = nn.GroupNorm(32, base_ch)
371
- self.conv_out = nn.Conv2d(base_ch, out_ch, 3, padding=1)
372
- def forward(self, x):
373
- x = self.post_quant_conv(x)
374
- x = self.conv_in(x)
375
- for m in self.mid_block:
376
- x = m(x)
377
- for up in self.up_blocks:
378
- x = up(x)
379
- x = self.conv_norm_out(x)
380
- return self.conv_out(x)
381
-
382
- class AutoencoderKL(nn.Module):
383
- def __init__(self, config):
384
- super().__init__()
385
- in_ch = config.get("in_channels", 3) if isinstance(config, dict) else config.__dict__.get("in_channels", 3)
386
- out_ch = config.get("out_channels", 3) if isinstance(config, dict) else config.__dict__.get("out_channels", 3)
387
- base_ch = config.get("base_channels", 128) if isinstance(config, dict) else config.__dict__.get("base_channels", 128)
388
- latent_ch = config.get("latent_channels", 4) if isinstance(config, dict) else config.__dict__.get("latent_channels", 4)
389
- self.encoder = Encoder(in_ch, base_ch, latent_ch)
390
- self.decoder = Decoder(out_ch, base_ch, latent_ch)
391
- def forward(self, x):
392
- return self.decoder(self.encoder(x))
393
- def decode(self, x):
394
- return self.decoder(x)
395
-
396
- class TransformerBlock(nn.Module):
397
- def __init__(self, embed_dim, num_heads):
398
- super().__init__()
399
- self.norm1 = nn.LayerNorm(embed_dim)
400
- self.attn = nn.MultiheadAttention(embed_dim, num_heads)
401
- self.norm2 = nn.LayerNorm(embed_dim)
402
- hidden_dim = embed_dim * 4
403
- self.mlp = nn.Sequential(
404
- nn.Linear(embed_dim, hidden_dim),
405
- nn.GELU(),
406
- nn.Linear(hidden_dim, embed_dim)
407
- )
408
- def forward(self, x):
409
- res = x
410
- x = self.norm1(x)
411
- x = x.transpose(0, 1)
412
- attn, _ = self.attn(x, x, x)
413
- x = attn.transpose(0, 1)
414
- x = res + x
415
- return x + self.mlp(self.norm2(x))
416
-
417
- class VisionTransformer(nn.Module):
418
- def __init__(self, config):
419
- super().__init__()
420
- if isinstance(config, dict):
421
- self.img_size = config.get("img_size", 592)
422
- self.patch_size = config.get("patch_size", 16)
423
- self.embed_dim = config.get("hidden_size", 768)
424
- depth = config.get("depth", 12)
425
- num_heads = config.get("num_heads", 12)
426
- else:
427
- self.img_size = config.__dict__.get("img_size", 592)
428
- self.patch_size = config.__dict__.get("patch_size", 16)
429
- self.embed_dim = config.__dict__.get("hidden_size", 768)
430
- depth = config.__dict__.get("depth", 12)
431
- num_heads = config.__dict__.get("num_heads", 12)
432
- num_patches = (self.img_size // self.patch_size) ** 2
433
- self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
434
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
435
- self.patch_embed = nn.Conv2d(3, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
436
- self.blocks = nn.ModuleList([TransformerBlock(self.embed_dim, num_heads) for _ in range(depth)])
437
- self.norm = nn.LayerNorm(self.embed_dim)
438
- self.register_tokens = nn.Parameter(torch.zeros(1, 4, self.embed_dim))
439
- self._init_weights()
440
- def _init_weights(self):
441
- nn.init.normal_(self.cls_token, std=0.02)
442
- nn.init.normal_(self.pos_embed, std=0.02)
443
- def forward(self, x):
444
- x = self.patch_embed(x)
445
- x = x.flatten(2).transpose(1, 2)
446
- cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
447
- x = torch.cat((cls_tokens, x), dim=1)
448
- x = x + self.pos_embed
449
- for blk in self.blocks:
450
- x = blk(x)
451
- return self.norm(x)[:, 0]
452
-
453
- class OpenLRM(nn.Module):
454
- def __init__(self, config):
455
- super().__init__()
456
- self.encoder = nn.ModuleDict({"model": VisionTransformer(config)})
457
- hidden = config.get("hidden_size", 768) if isinstance(config, dict) else config.__dict__.get("hidden_size", 768)
458
- self.linear = nn.Linear(hidden, hidden)
459
- def forward(self, x):
460
- return self.linear(self.encoder["model"](x))
461
-
462
- class VideoUNet(nn.Module):
463
- def __init__(self, in_ch=4, out_ch=4, features=None):
464
- super().__init__()
465
- if features is None:
466
- features = [64, 128, 256]
467
- self.encoder = nn.ModuleList()
468
- self.pool = nn.MaxPool3d(2, 2)
469
- self.decoder = nn.ModuleList()
470
- for f in features:
471
- self.encoder.append(nn.Sequential(
472
- nn.Conv3d(in_ch, f, 3, padding=1),
473
- nn.ReLU(inplace=True),
474
- nn.Conv3d(f, f, 3, padding=1),
475
- nn.ReLU(inplace=True)
476
- ))
477
- in_ch = f
478
- for f in reversed(features):
479
- self.decoder.append(nn.Sequential(
480
- nn.Conv3d(f * 2, f, 3, padding=1),
481
- nn.ReLU(inplace=True),
482
- nn.Conv3d(f, f, 3, padding=1),
483
- nn.ReLU(inplace=True)
484
- ))
485
- self.final_conv = nn.Conv3d(features[0], out_ch, 1)
486
- def forward(self, x, t, encoder_hidden_states):
487
- skips = []
488
- for enc in self.encoder:
489
- x = enc(x)
490
- skips.append(x)
491
- x = self.pool(x)
492
- for dec in self.decoder:
493
- skip = skips.pop()
494
- x = F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=False)
495
- x = torch.cat([x, skip], dim=1)
496
- x = dec(x)
497
- return self.final_conv(x)
498
-
499
- class SentimentClassifierModel(nn.Module):
500
- def __init__(self, config):
501
- super().__init__()
502
- self.classifier = nn.Sequential(
503
- nn.Linear(768, 256),
504
- nn.ReLU(),
505
- nn.Linear(256, 2)
506
- )
507
- def forward(self, x):
508
- return self.classifier(x)
509
-
510
- class STTModel(nn.Module):
511
- def __init__(self, config):
512
- super().__init__()
513
- self.net = nn.Sequential(
514
- nn.Linear(768, 512),
515
- nn.ReLU(),
516
- nn.Linear(512, 768)
517
- )
518
- def forward(self, x):
519
- return self.net(x)
520
-
521
- class TTSModel(nn.Module):
522
- def __init__(self, config):
523
- super().__init__()
524
- self.net = nn.Sequential(
525
- nn.Linear(768, 512),
526
- nn.ReLU(),
527
- nn.Linear(512, 768)
528
- )
529
- def forward(self, x):
530
- return self.net(x)
531
-
532
- class MusicGenModel(nn.Module):
533
- def __init__(self, config):
534
- super().__init__()
535
- layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
536
- self.transformer = nn.TransformerEncoder(layer, num_layers=12)
537
- self.linear = nn.Linear(768, 768)
538
- def forward(self, x):
539
- return self.linear(self.transformer(x))
540
-
541
- class SimpleTextEncoder(nn.Module):
542
- def __init__(self, vocab_size=10000, embed_dim=768, max_length=77):
543
- super().__init__()
544
- self.embedding = nn.Embedding(vocab_size, embed_dim)
545
- self.max_length = max_length
546
- def forward(self, text_tokens):
547
- return self.embedding(text_tokens)
548
-
549
- class DiffusionScheduler:
550
- def __init__(self, steps):
551
- self.steps = steps
552
- self.betas = torch.linspace(0.1, 0.001, steps=steps).to(device)
553
- self.alphas = 1 - self.betas
554
- self.alpha_bars = torch.cumprod(self.alphas, dim=0)
555
- def step(self, noise, t, sample):
556
- alpha_bar = self.alpha_bars[t]
557
- if t > 0:
558
- alpha_bar_prev = self.alpha_bars[t-1]
559
- else:
560
- alpha_bar_prev = torch.tensor(1.0, device=sample.device)
561
- x0 = (sample - torch.sqrt(1 - alpha_bar) * noise) / torch.sqrt(alpha_bar)
562
- new_sample = torch.sqrt(alpha_bar_prev) * x0 + torch.sqrt(1 - alpha_bar_prev) * noise
563
- return new_sample
564
-
565
- class VideoOutput:
566
- def __init__(self, frames):
567
- self.frames = [img_as_ubyte(frame) for frame in frames[0]]
568
-
569
- class VideoPipeline(nn.Module):
570
- def __init__(self, unet, vae, text_encoder, vocab):
571
- super().__init__()
572
- self.unet = unet
573
- self.vae = vae
574
- self.text_encoder = text_encoder
575
- self.vocab = vocab
576
- def forward(self, prompt: str, steps: int = 25, num_frames: int = 24):
577
- token_ids = simple_tokenizer(prompt, self.vocab)
578
- text_emb = self.text_encoder(token_ids)
579
- latent = torch.randn((1, 4, num_frames, 64, 64), device=device).half()
580
- sched = DiffusionScheduler(steps)
581
- for t in range(steps):
582
- noise = self.unet(latent, t, text_emb)
583
- latent = sched.step(noise, t, latent)
584
- frames = self.vae.decode(latent / 0.18215)
585
- frames = frames.clamp(0, 1).float().cpu().permute(0, 2, 3, 4, 1).numpy()
586
- return VideoOutput(frames)
587
-
588
- def initialize_gpt2_model(folder, files):
589
- download_files(folder, files)
590
- config = GPT2Config()
591
- model = GPT2LMHeadModel(config).to(device)
592
- sd = torch.load(os.path.join(folder, sanitize_filename("gpt2-pytorch_model.bin")), map_location=device)
593
- load_state_dict_safe(model, sd)
594
- model.eval()
595
- enc = read_json(os.path.join(folder, sanitize_filename("encoder.json")))
596
- return model, enc
597
-
598
- def initialize_translation_model(folder, files):
599
- download_files(folder, files)
600
- config = MBartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
601
- model = MBartForConditionalGeneration(config).to(device)
602
- sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
603
- load_state_dict_safe(model, sd)
604
- model.eval()
605
- vp = os.path.join(folder, "vocab.json")
606
- if os.path.exists(vp):
607
- vocab = read_json(vp)
608
- model.tokenizer = lambda txt: [vocab.get(t, 0) for t in txt.split()]
609
- else:
610
- model.tokenizer = lambda txt: txt
611
- model.config.lang_code_to_id = {'en_XX': 0, 'es_XX': 1}
612
- return model
613
-
614
- def initialize_codegen_model(folder, files):
615
- download_files(folder, files)
616
- config = CodeGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
617
- model = CodeGenForCausalLM(config).to(device)
618
- sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
619
- load_state_dict_safe(model, sd)
620
- model.eval()
621
- tok = get_codegen_tokenizer(os.path.join(folder, "vocab.json"), os.path.join(folder, "merges.txt"))
622
- vocab = read_json(os.path.join(folder, "vocab.json"))
623
- idx2w = {v: k for k, v in vocab.items()}
624
- model.tokenizer = tok
625
- return model, tok, vocab, idx2w, vocab
626
-
627
- def initialize_summarization_model(folder, files):
628
- download_files(folder, files)
629
- config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
630
- model = BartForConditionalGeneration(config).to(device)
631
- sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
632
- load_state_dict_safe(model, sd)
633
- model.eval()
634
- vp = os.path.join(folder, "vocab.json")
635
- if os.path.exists(vp):
636
- vocab_json = read_json(vp)
637
- vocab = set(vocab_json.keys())
638
- return model, vocab, vocab_json, {v: k for k, v in vocab_json.items()}
639
- return model, None, None, None
640
-
641
- def initialize_imagegen_model(folder, files):
642
- download_files(folder, files)
643
- config = AutoencoderKLConfig.from_dict(read_json(os.path.join(folder, "config.json")))
644
- vae = AutoencoderKL(config).to(device)
645
- sd = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device)
646
- load_state_dict_safe(vae, sd)
647
- vae.eval()
648
- return vae
649
-
650
- def initialize_image_to_3d_model(folder, files):
651
- download_files(folder, files)
652
- config = OpenLRMConfig.from_dict(read_json(os.path.join(folder, "config.json")))
653
- model3d = OpenLRM(config).to(device)
654
- sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
655
- load_state_dict_safe(model3d, sd)
656
- model3d.eval()
657
- return model3d
658
-
659
- def initialize_text_to_video_model(folder, files):
660
- download_files(folder, files)
661
- unet_cfg = read_json(os.path.join(folder, "config.json"))
662
- unet_cfg = filter_kwargs(VideoUNet, unet_cfg)
663
- unet = VideoUNet(**unet_cfg).half().to(device)
664
- sd_unet = torch.load(os.path.join(folder, "diffusion_pytorch_model.fp16.bin"), map_location=device)
665
- load_state_dict_safe(unet, sd_unet)
666
- unet.eval()
667
- vae_cfg = read_json(os.path.join(folder, "config.json"))
668
- vae_cfg = filter_kwargs(AutoencoderKL, vae_cfg)
669
- vae = AutoencoderKL(vae_cfg).half().to(device)
670
- sd_vae = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device)
671
- load_state_dict_safe(vae, sd_vae)
672
- vae.eval()
673
- vp = os.path.join(folder, "vocab.json")
674
- text_vocab = read_json(vp) if os.path.exists(vp) else {}
675
- te_path = os.path.join(folder, "text_encoder.bin")
676
- if os.path.exists(te_path):
677
- text_encoder = SimpleTextEncoder(vocab_size=(max(text_vocab.values())+1) if text_vocab else 10000, embed_dim=768, max_length=77).to(device)
678
- sd_te = torch.load(te_path, map_location=device)
679
- load_state_dict_safe(text_encoder, sd_te)
680
- else:
681
- text_encoder = SimpleTextEncoder(vocab_size=(max(text_vocab.values())+1) if text_vocab else 10000, embed_dim=768, max_length=77).to(device)
682
- text_encoder.eval()
683
- return VideoPipeline(unet, vae, text_encoder, text_vocab)
684
-
685
- def initialize_sentiment_model(folder, files):
686
- download_files(folder, files)
687
- config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
688
- model = SentimentClassifierModel(config).to(device)
689
- sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
690
- load_state_dict_safe(model, sd)
691
- model.eval()
692
- vp = os.path.join(folder, "vocab.json")
693
- if os.path.exists(vp):
694
- read_json(vp)
695
- return model
696
-
697
- def initialize_stt_model(folder, files):
698
- download_files(folder, files)
699
- config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
700
- model = STTModel(config).to(device)
701
- sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
702
- load_state_dict_safe(model, sd)
703
- model.eval()
704
- vp = os.path.join(folder, "vocab.json")
705
- if os.path.exists(vp):
706
- read_json(vp)
707
- return model
708
-
709
- def initialize_tts_model(folder, files):
710
- download_files(folder, files)
711
- config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
712
- model = TTSModel(config).to(device)
713
- sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
714
- load_state_dict_safe(model, sd)
715
- model.eval()
716
- vp = os.path.join(folder, "vocab.json")
717
- if os.path.exists(vp):
718
- read_json(vp)
719
- return model
720
-
721
- def initialize_musicgen_model(folder, files):
722
- download_files(folder, files)
723
- config = MusicGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
724
- model = MusicGenModel(config).to(device)
725
- sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
726
- load_state_dict_safe(model, sd)
727
- model.eval()
728
- return model
1
+ from tokenxxx import *
2
+ from constants import *
3
+ from utils import *
4
+ import os
5
+ import json
6
+ import urllib.request
7
+ import urllib.parse
8
+ import torch
9
+ import hashlib
10
+ from tqdm import tqdm
11
+ from skimage import img_as_ubyte
12
+ from torch import nn
13
+ import torch.nn.functional as F
14
+ import inspect
15
+
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+
18
+ def filter_kwargs(cls, kwargs):
19
+ sig = inspect.signature(cls.__init__)
20
+ accepted = set(sig.parameters.keys()) - {"self"}
21
+ return {k: v for k, v in kwargs.items() if k in accepted}
22
+
23
+ def sanitize_filename(name, url=None):
24
+ for c in '<>:"/\\|?*':
25
+ name = name.replace(c, '')
26
+ if not name and url is not None:
27
+ name = hashlib.md5(url.encode()).hexdigest()
28
+ return name
29
+
30
+ def download_file(url, filepath):
31
+ d = os.path.dirname(filepath)
32
+ if d and not os.path.exists(d):
33
+ os.makedirs(d, exist_ok=True)
34
+ while not os.path.exists(filepath):
35
+ try:
36
+ def prog(t):
37
+ last = [0]
38
+ def inner(n, bs, ts):
39
+ if ts > 0:
40
+ t.total = ts
41
+ t.update(n * bs - last[0])
42
+ last[0] = n * bs
43
+ return inner
44
+ with tqdm(unit='B', unit_scale=True, unit_divisor=1024, desc=os.path.basename(filepath)) as t:
45
+ urllib.request.urlretrieve(url, filepath, reporthook=prog(t))
46
+ except Exception:
47
+ continue
48
+
49
+ def download_files(folder, files_spec):
50
+ if isinstance(files_spec, dict):
51
+ for fn, url in files_spec.items():
52
+ fn = sanitize_filename(fn, url)
53
+ fp = os.path.join(folder, fn)
54
+ download_file(url, fp)
55
+ elif isinstance(files_spec, list):
56
+ for item in files_spec:
57
+ if isinstance(item, str):
58
+ url = item
59
+ parsed = urllib.parse.urlparse(url)
60
+ fn = os.path.basename(parsed.path)
61
+ if not fn:
62
+ fn = hashlib.md5(url.encode()).hexdigest()
63
+ fn = sanitize_filename(fn, url)
64
+ elif isinstance(item, (list, tuple)) and len(item) == 2:
65
+ url, fn = item
66
+ fn = sanitize_filename(fn, url)
67
+ elif isinstance(item, dict) and "filename" in item and "url" in item:
68
+ fn = sanitize_filename(item["filename"], item["url"])
69
+ url = item["url"]
70
+ else:
71
+ raise ValueError("Invalid file specification")
72
+ fp = os.path.join(folder, fn)
73
+ download_file(url, fp)
74
+ else:
75
+ raise ValueError("files_spec must be dict or list")
76
+
77
+ def read_json(fp):
78
+ with open(fp, 'r', encoding='utf-8') as f:
79
+ return json.load(f)
80
+
81
+ def get_codegen_tokenizer(vocab_path, merges_path):
82
+ with open(vocab_path, 'r', encoding='utf-8') as f:
83
+ vocab = json.load(f)
84
+ with open(merges_path, 'r', encoding='utf-8') as f:
85
+ merges = f.read().splitlines()
86
+ merge_ranks = {}
87
+ for i, merge in enumerate(merges):
88
+ parts = merge.strip().split()
89
+ if len(parts) == 2:
90
+ merge_ranks[tuple(parts)] = i
91
+ def bpe(token):
92
+ word = list(token)
93
+ pairs = [(word[i], word[i+1]) for i in range(len(word)-1)]
94
+ while True:
95
+ candidate = None
96
+ candidate_rank = None
97
+ candidate_index = None
98
+ for i, pair in enumerate(pairs):
99
+ if pair in merge_ranks:
100
+ rank = merge_ranks[pair]
101
+ if candidate is None or rank < candidate_rank:
102
+ candidate = pair
103
+ candidate_rank = rank
104
+ candidate_index = i
105
+ if candidate is None:
106
+ break
107
+ first, second = candidate
108
+ new_word = []
109
+ i = 0
110
+ while i < len(word):
111
+ if i < len(word) - 1 and word[i] == first and word[i+1] == second:
112
+ new_word.append(first + second)
113
+ i += 2
114
+ else:
115
+ new_word.append(word[i])
116
+ i += 1
117
+ word = new_word
118
+ if len(word) == 1:
119
+ break
120
+ pairs = [(word[i], word[i+1]) for i in range(len(word)-1)]
121
+ return word
122
+ def tokenizer(text):
123
+ tokens = []
124
+ for token in text.split():
125
+ bpe_tokens = bpe(token)
126
+ for subtoken in bpe_tokens:
127
+ tokens.append(vocab.get(subtoken, 0))
128
+ return tokens
129
+ return tokenizer
130
+
131
+ def simple_tokenizer(text, vocab, max_length=77):
132
+ toks = text.split()
133
+ ids = [vocab.get(t, 1) for t in toks]
134
+ if len(ids) < max_length:
135
+ ids = ids + [0] * (max_length - len(ids))
136
+ else:
137
+ ids = ids[:max_length]
138
+ return torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(device)
139
+
140
+ def load_state_dict_safe(model, loaded_state_dict):
141
+ model_state = model.state_dict()
142
+ new_state = {}
143
+ for key, value in model_state.items():
144
+ if key in loaded_state_dict and loaded_state_dict[key].shape == value.shape:
145
+ new_state[key] = loaded_state_dict[key]
146
+ else:
147
+ new_state[key] = value
148
+ model.load_state_dict(new_state, strict=False)
149
+
150
+ class GPT2Config:
151
+ def __init__(self, vocab_size=50257, **kwargs):
152
+ self.vocab_size = vocab_size
153
+ self.__dict__.update(kwargs)
154
+ @classmethod
155
+ def from_dict(cls, d):
156
+ return cls(**d)
157
+
158
+ class MBartConfig:
159
+ def __init__(self, vocab_size=50265, **kwargs):
160
+ self.vocab_size = vocab_size
161
+ self.__dict__.update(kwargs)
162
+ @classmethod
163
+ def from_dict(cls, d):
164
+ return cls(**d)
165
+
166
+ class CodeGenConfig:
167
+ def __init__(self, vocab_size=50257, **kwargs):
168
+ self.vocab_size = vocab_size
169
+ self.__dict__.update(kwargs)
170
+ @classmethod
171
+ def from_dict(cls, d):
172
+ return cls(**d)
173
+
174
+ class BartConfig:
175
+ def __init__(self, vocab_size=50265, **kwargs):
176
+ self.vocab_size = vocab_size
177
+ self.__dict__.update(kwargs)
178
+ @classmethod
179
+ def from_dict(cls, d):
180
+ return cls(**d)
181
+
182
+ class AutoencoderKLConfig:
183
+ def __init__(self, **kwargs):
184
+ self.__dict__.update(kwargs)
185
+ @classmethod
186
+ def from_dict(cls, d):
187
+ return cls(**d)
188
+
189
+ class OpenLRMConfig:
190
+ def __init__(self, **kwargs):
191
+ self.__dict__.update(kwargs)
192
+ @classmethod
193
+ def from_dict(cls, d):
194
+ return cls(**d)
195
+
196
+ class UNet2DConditionModelConfig:
197
+ def __init__(self, **kwargs):
198
+ self.__dict__.update(kwargs)
199
+ @classmethod
200
+ def from_dict(cls, d):
201
+ return cls(**d)
202
+
203
+ class MusicGenConfig:
204
+ def __init__(self, **kwargs):
205
+ self.__dict__.update(kwargs)
206
+ @classmethod
207
+ def from_dict(cls, d):
208
+ return cls(**d)
209
+
210
+ class GPT2LMHeadModel(nn.Module):
211
+ def __init__(self, config):
212
+ super().__init__()
213
+ layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
214
+ self.transformer = nn.TransformerEncoder(layer, num_layers=12)
215
+ self.lm_head = nn.Linear(768, config.vocab_size)
216
+ def forward(self, x):
217
+ return self.lm_head(self.transformer(x))
218
+
219
+ class MBartForConditionalGeneration(nn.Module):
220
+ def __init__(self, config):
221
+ super().__init__()
222
+ self.config = config
223
+ layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
224
+ self.encoder = nn.TransformerEncoder(layer, num_layers=6)
225
+ dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12)
226
+ self.decoder = nn.TransformerDecoder(dlayer, num_layers=6)
227
+ self.output_layer = nn.Linear(768, config.vocab_size)
228
+ def forward(self, src, tgt):
229
+ return self.output_layer(self.decoder(tgt, self.encoder(src)))
230
+
231
+ class CodeGenForCausalLM(nn.Module):
232
+ def __init__(self, config):
233
+ super().__init__()
234
+ d_model = getattr(config, "d_model", 1024)
235
+ n_head = getattr(config, "n_head", 16)
236
+ num_layers = getattr(config, "num_layers", 12)
237
+ dlayer = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_head)
238
+ self.transformer_decoder = nn.TransformerDecoder(dlayer, num_layers=num_layers)
239
+ self.lm_head = nn.Linear(d_model, config.vocab_size)
240
+ def forward(self, tgt, memory=None):
241
+ if memory is None:
242
+ memory = torch.zeros_like(tgt)
243
+ return self.lm_head(self.transformer_decoder(tgt, memory))
244
+
245
+ class BartForConditionalGeneration(nn.Module):
246
+ def __init__(self, config):
247
+ super().__init__()
248
+ layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
249
+ self.encoder = nn.TransformerEncoder(layer, num_layers=6)
250
+ dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12)
251
+ self.decoder = nn.TransformerDecoder(dlayer, num_layers=6)
252
+ self.output_layer = nn.Linear(768, config.vocab_size)
253
+ def forward(self, src, tgt):
254
+ return self.output_layer(self.decoder(tgt, self.encoder(src)))
255
+
256
+ class ResnetBlock(nn.Module):
257
+ def __init__(self, in_ch, out_ch):
258
+ super().__init__()
259
+ self.norm1 = nn.GroupNorm(32, in_ch)
260
+ self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
261
+ self.norm2 = nn.GroupNorm(32, out_ch)
262
+ self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
263
+ self.conv_shortcut = nn.Conv2d(in_ch, out_ch, 1)
264
+ def forward(self, x):
265
+ sc = self.conv_shortcut(x)
266
+ h = F.silu(self.norm1(x))
267
+ h = self.conv1(h)
268
+ h = F.silu(self.norm2(h))
269
+ h = self.conv2(h)
270
+ return h + sc
271
+
272
+ class Downsample(nn.Module):
273
+ def __init__(self, in_ch, out_ch):
274
+ super().__init__()
275
+ self.conv = nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1)
276
+ def forward(self, x):
277
+ return self.conv(x)
278
+
279
+ class DownBlock(nn.Module):
280
+ def __init__(self, in_ch, out_ch, num_res):
281
+ super().__init__()
282
+ self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)])
283
+ self.downsamplers = nn.ModuleList([Downsample(out_ch, out_ch)])
284
+ def forward(self, x):
285
+ for r in self.resnets:
286
+ x = r(x)
287
+ for ds in self.downsamplers:
288
+ x = ds(x)
289
+ return x
290
+
291
+ class Upsample(nn.Module):
292
+ def __init__(self, in_ch, out_ch):
293
+ super().__init__()
294
+ self.conv = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
295
+ def forward(self, x):
296
+ return self.conv(x)
297
+
298
+ class UpBlock(nn.Module):
299
+ def __init__(self, in_ch, out_ch, num_res):
300
+ super().__init__()
301
+ self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)])
302
+ self.upsampler = Upsample(out_ch, out_ch)
303
+ def forward(self, x):
304
+ for r in self.resnets:
305
+ x = r(x)
306
+ return self.upsampler(x)
307
+
308
+ class AttentionBlock(nn.Module):
309
+ def __init__(self, ch):
310
+ super().__init__()
311
+ self.norm = nn.GroupNorm(32, ch)
312
+ self.query = nn.Conv2d(ch, ch, 1)
313
+ self.key = nn.Conv2d(ch, ch, 1)
314
+ self.value = nn.Conv2d(ch, ch, 1)
315
+ self.proj_attn = nn.Conv2d(ch, ch, 1)
316
+ def forward(self, x):
317
+ b, c, h, w = x.shape
318
+ xn = self.norm(x)
319
+ q = self.query(xn).view(b, c, -1).permute(0, 2, 1)
320
+ k = self.key(xn).view(b, c, -1)
321
+ v = self.value(xn).view(b, c, -1).permute(0, 2, 1)
322
+ attn = torch.softmax(torch.bmm(q, k) / (c ** 0.5), dim=-1)
323
+ out = torch.bmm(attn, v).permute(0, 2, 1).view(b, c, h, w)
324
+ return x + self.proj_attn(out)
325
+
326
+ class Encoder(nn.Module):
327
+ def __init__(self, in_ch=3, base_ch=128, latent_ch=4):
328
+ super().__init__()
329
+ self.conv_in = nn.Conv2d(in_ch, base_ch, 3, padding=1)
330
+ self.down_blocks = nn.ModuleList([
331
+ DownBlock(base_ch, base_ch, 2),
332
+ DownBlock(base_ch, base_ch * 2, 2),
333
+ DownBlock(base_ch * 2, base_ch * 4, 2),
334
+ DownBlock(base_ch * 4, base_ch * 4, 2)
335
+ ])
336
+ self.mid_block = nn.ModuleList([
337
+ ResnetBlock(base_ch * 4, base_ch * 4),
338
+ AttentionBlock(base_ch * 4),
339
+ ResnetBlock(base_ch * 4, base_ch * 4)
340
+ ])
341
+ self.conv_norm_out = nn.GroupNorm(32, base_ch * 4)
342
+ self.conv_out = nn.Conv2d(base_ch * 4, latent_ch * 2, 3, padding=1)
343
+ self.quant_conv = nn.Conv2d(latent_ch * 2, latent_ch, 1)
344
+ def forward(self, x):
345
+ x = self.conv_in(x)
346
+ for blk in self.down_blocks:
347
+ x = blk(x)
348
+ for m in self.mid_block:
349
+ x = m(x)
350
+ x = self.conv_norm_out(x)
351
+ x = self.conv_out(x)
352
+ return self.quant_conv(x)
353
+
354
+ class Decoder(nn.Module):
355
+ def __init__(self, out_ch=3, base_ch=128, latent_ch=4):
356
+ super().__init__()
357
+ self.post_quant_conv = nn.Conv2d(latent_ch, latent_ch * 2, 1)
358
+ self.conv_in = nn.Conv2d(latent_ch, base_ch * 4, 3, padding=1)
359
+ self.mid_block = nn.ModuleList([
360
+ ResnetBlock(base_ch * 4, base_ch * 4),
361
+ AttentionBlock(base_ch * 4),
362
+ ResnetBlock(base_ch * 4, base_ch * 4)
363
+ ])
364
+ self.up_blocks = nn.ModuleList([
365
+ UpBlock(base_ch * 4, base_ch * 4, 3),
366
+ UpBlock(base_ch * 4, base_ch * 2, 3),
367
+ UpBlock(base_ch * 2, base_ch, 3),
368
+ UpBlock(base_ch, base_ch, 3)
369
+ ])
370
+ self.conv_norm_out = nn.GroupNorm(32, base_ch)
371
+ self.conv_out = nn.Conv2d(base_ch, out_ch, 3, padding=1)
372
+ def forward(self, x):
373
+ x = self.post_quant_conv(x)
374
+ x = self.conv_in(x)
375
+ for m in self.mid_block:
376
+ x = m(x)
377
+ for up in self.up_blocks:
378
+ x = up(x)
379
+ x = self.conv_norm_out(x)
380
+ return self.conv_out(x)
381
+
382
+ class AutoencoderKL(nn.Module):
383
+ def __init__(self, config):
384
+ super().__init__()
385
+ in_ch = config.get("in_channels", 3) if isinstance(config, dict) else config.__dict__.get("in_channels", 3)
386
+ out_ch = config.get("out_channels", 3) if isinstance(config, dict) else config.__dict__.get("out_channels", 3)
387
+ base_ch = config.get("base_channels", 128) if isinstance(config, dict) else config.__dict__.get("base_channels", 128)
388
+ latent_ch = config.get("latent_channels", 4) if isinstance(config, dict) else config.__dict__.get("latent_channels", 4)
389
+ self.encoder = Encoder(in_ch, base_ch, latent_ch)
390
+ self.decoder = Decoder(out_ch, base_ch, latent_ch)
391
+ def forward(self, x):
392
+ return self.decoder(self.encoder(x))
393
+ def decode(self, x):
394
+ return self.decoder(x)
395
+
396
+ class TransformerBlock(nn.Module):
397
+ def __init__(self, embed_dim, num_heads):
398
+ super().__init__()
399
+ self.norm1 = nn.LayerNorm(embed_dim)
400
+ self.attn = nn.MultiheadAttention(embed_dim, num_heads)
401
+ self.norm2 = nn.LayerNorm(embed_dim)
402
+ hidden_dim = embed_dim * 4
403
+ self.mlp = nn.Sequential(
404
+ nn.Linear(embed_dim, hidden_dim),
405
+ nn.GELU(),
406
+ nn.Linear(hidden_dim, embed_dim)
407
+ )
408
+ def forward(self, x):
409
+ res = x
410
+ x = self.norm1(x)
411
+ x = x.transpose(0, 1)
412
+ attn, _ = self.attn(x, x, x)
413
+ x = attn.transpose(0, 1)
414
+ x = res + x
415
+ return x + self.mlp(self.norm2(x))
416
+
417
+ class VisionTransformer(nn.Module):
418
+ def __init__(self, config):
419
+ super().__init__()
420
+ if isinstance(config, dict):
421
+ self.img_size = config.get("img_size", 592)
422
+ self.patch_size = config.get("patch_size", 16)
423
+ self.embed_dim = config.get("hidden_size", 768)
424
+ depth = config.get("depth", 12)
425
+ num_heads = config.get("num_heads", 12)
426
+ else:
427
+ self.img_size = config.__dict__.get("img_size", 592)
428
+ self.patch_size = config.__dict__.get("patch_size", 16)
429
+ self.embed_dim = config.__dict__.get("hidden_size", 768)
430
+ depth = config.__dict__.get("depth", 12)
431
+ num_heads = config.__dict__.get("num_heads", 12)
432
+ num_patches = (self.img_size // self.patch_size) ** 2
433
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
434
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
435
+ self.patch_embed = nn.Conv2d(3, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
436
+ self.blocks = nn.ModuleList([TransformerBlock(self.embed_dim, num_heads) for _ in range(depth)])
437
+ self.norm = nn.LayerNorm(self.embed_dim)
438
+ self.register_tokens = nn.Parameter(torch.zeros(1, 4, self.embed_dim))
439
+ self._init_weights()
440
+ def _init_weights(self):
441
+ nn.init.normal_(self.cls_token, std=0.02)
442
+ nn.init.normal_(self.pos_embed, std=0.02)
443
+ def forward(self, x):
444
+ x = self.patch_embed(x)
445
+ x = x.flatten(2).transpose(1, 2)
446
+ cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
447
+ x = torch.cat((cls_tokens, x), dim=1)
448
+ x = x + self.pos_embed
449
+ for blk in self.blocks:
450
+ x = blk(x)
451
+ return self.norm(x)[:, 0]
452
+
453
+ class OpenLRM(nn.Module):
454
+ def __init__(self, config):
455
+ super().__init__()
456
+ self.encoder = nn.ModuleDict({"model": VisionTransformer(config)})
457
+ hidden = config.get("hidden_size", 768) if isinstance(config, dict) else config.__dict__.get("hidden_size", 768)
458
+ self.linear = nn.Linear(hidden, hidden)
459
+ def forward(self, x):
460
+ return self.linear(self.encoder["model"](x))
461
+
462
+ class VideoUNet(nn.Module):
463
+ def __init__(self, in_ch=4, out_ch=4, features=None):
464
+ super().__init__()
465
+ if features is None:
466
+ features = [64, 128, 256]
467
+ self.encoder = nn.ModuleList()
468
+ self.pool = nn.MaxPool3d(2, 2)
469
+ self.decoder = nn.ModuleList()
470
+ for f in features:
471
+ self.encoder.append(nn.Sequential(
472
+ nn.Conv3d(in_ch, f, 3, padding=1),
473
+ nn.ReLU(inplace=True),
474
+ nn.Conv3d(f, f, 3, padding=1),
475
+ nn.ReLU(inplace=True)
476
+ ))
477
+ in_ch = f
478
+ for f in reversed(features):
479
+ self.decoder.append(nn.Sequential(
480
+ nn.Conv3d(f * 2, f, 3, padding=1),
481
+ nn.ReLU(inplace=True),
482
+ nn.Conv3d(f, f, 3, padding=1),
483
+ nn.ReLU(inplace=True)
484
+ ))
485
+ self.final_conv = nn.Conv3d(features[0], out_ch, 1)
486
+ def forward(self, x, t, encoder_hidden_states):
487
+ skips = []
488
+ for enc in self.encoder:
489
+ x = enc(x)
490
+ skips.append(x)
491
+ x = self.pool(x)
492
+ for dec in self.decoder:
493
+ skip = skips.pop()
494
+ x = F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=False)
495
+ x = torch.cat([x, skip], dim=1)
496
+ x = dec(x)
497
+ return self.final_conv(x)
498
+
499
+ class SentimentClassifierModel(nn.Module):
500
+ def __init__(self, config):
501
+ super().__init__()
502
+ self.classifier = nn.Sequential(
503
+ nn.Linear(768, 256),
504
+ nn.ReLU(),
505
+ nn.Linear(256, 2)
506
+ )
507
+ def forward(self, x):
508
+ return self.classifier(x)
509
+
510
+ class STTModel(nn.Module):
511
+ def __init__(self, config):
512
+ super().__init__()
513
+ self.net = nn.Sequential(
514
+ nn.Linear(768, 512),
515
+ nn.ReLU(),
516
+ nn.Linear(512, 768)
517
+ )
518
+ def forward(self, x):
519
+ return self.net(x)
520
+
521
+ class TTSModel(nn.Module):
522
+ def __init__(self, config):
523
+ super().__init__()
524
+ self.net = nn.Sequential(
525
+ nn.Linear(768, 512),
526
+ nn.ReLU(),
527
+ nn.Linear(512, 768)
528
+ )
529
+ def forward(self, x):
530
+ return self.net(x)
531
+
532
+ class MusicGenModel(nn.Module):
533
+ def __init__(self, config):
534
+ super().__init__()
535
+ layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
536
+ self.transformer = nn.TransformerEncoder(layer, num_layers=12)
537
+ self.linear = nn.Linear(768, 768)
538
+ def forward(self, x):
539
+ return self.linear(self.transformer(x))
540
+
541
+ class SimpleTextEncoder(nn.Module):
542
+ def __init__(self, vocab_size=10000, embed_dim=768, max_length=77):
543
+ super().__init__()
544
+ self.embedding = nn.Embedding(vocab_size, embed_dim)
545
+ self.max_length = max_length
546
+ def forward(self, text_tokens):
547
+ return self.embedding(text_tokens)
548
+
549
+ class DiffusionScheduler:
550
+ def __init__(self, steps):
551
+ self.steps = steps
552
+ self.betas = torch.linspace(0.1, 0.001, steps=steps).to(device)
553
+ self.alphas = 1 - self.betas
554
+ self.alpha_bars = torch.cumprod(self.alphas, dim=0)
555
+ def step(self, noise, t, sample):
556
+ alpha_bar = self.alpha_bars[t]
557
+ alpha_bar_prev = self.alpha_bars[t-1] if t > 0 else torch.tensor(1.0, device=sample.device)
558
+ x0 = (sample - torch.sqrt(1 - alpha_bar) * noise) / torch.sqrt(alpha_bar)
559
+ new_sample = torch.sqrt(alpha_bar_prev) * x0 + torch.sqrt(1 - alpha_bar_prev) * noise
560
+ return new_sample
561
+
562
+ class VideoOutput:
563
+ def __init__(self, frames):
564
+ self.frames = [img_as_ubyte(frame) for frame in frames[0]]
565
+
566
+ class VideoPipeline(nn.Module):
567
+ def __init__(self, unet, vae, text_encoder, vocab):
568
+ super().__init__()
569
+ self.unet = unet
570
+ self.vae = vae
571
+ self.text_encoder = text_encoder
572
+ self.vocab = vocab
573
+ def forward(self, prompt: str, steps: int = 25, num_frames: int = 24):
574
+ token_ids = simple_tokenizer(prompt, self.vocab)
575
+ text_emb = self.text_encoder(token_ids)
576
+ latent = torch.randn((1, 4, num_frames, 64, 64), device=device).half()
577
+ sched = DiffusionScheduler(steps)
578
+ for t in range(steps):
579
+ noise = self.unet(latent, t, text_emb)
580
+ latent = sched.step(noise, t, latent)
581
+ frames = self.vae.decode(latent / 0.18215)
582
+ frames = frames.clamp(0, 1).float().cpu().permute(0, 2, 3, 4, 1).numpy()
583
+ return VideoOutput(frames)
584
+
585
+ def initialize_gpt2_model(folder, files):
586
+ download_files(folder, files)
587
+ config = GPT2Config()
588
+ model = GPT2LMHeadModel(config).to(device)
589
+ sd = torch.load(os.path.join(folder, sanitize_filename("gpt2-pytorch_model.bin")), map_location=device)
590
+ load_state_dict_safe(model, sd)
591
+ model.eval()
592
+ enc = read_json(os.path.join(folder, sanitize_filename("encoder.json")))
593
+ return model, enc
594
+
595
+ def initialize_translation_model(folder, files):
596
+ download_files(folder, files)
597
+ config = MBartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
598
+ model = MBartForConditionalGeneration(config).to(device)
599
+ sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
600
+ load_state_dict_safe(model, sd)
601
+ model.eval()
602
+ vp = os.path.join(folder, "vocab.json")
603
+ if os.path.exists(vp):
604
+ vocab = read_json(vp)
605
+ model.tokenizer = lambda txt: [vocab.get(t, 0) for t in txt.split()]
606
+ else:
607
+ model.tokenizer = lambda txt: txt
608
+ model.config.lang_code_to_id = {'en_XX': 0, 'es_XX': 1}
609
+ return model
610
+
611
+ def initialize_codegen_model(folder, files):
612
+ download_files(folder, files)
613
+ config = CodeGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
614
+ model = CodeGenForCausalLM(config).to(device)
615
+ sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
616
+ load_state_dict_safe(model, sd)
617
+ model.eval()
618
+ tok = get_codegen_tokenizer(os.path.join(folder, "vocab.json"), os.path.join(folder, "merges.txt"))
619
+ vocab = read_json(os.path.join(folder, "vocab.json"))
620
+ idx2w = {v: k for k, v in vocab.items()}
621
+ model.tokenizer = tok
622
+ return model, tok, vocab, idx2w, vocab
623
+
624
+ def initialize_summarization_model(folder, files):
625
+ download_files(folder, files)
626
+ config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
627
+ model = BartForConditionalGeneration(config).to(device)
628
+ sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
629
+ load_state_dict_safe(model, sd)
630
+ model.eval()
631
+ vp = os.path.join(folder, "vocab.json")
632
+ if os.path.exists(vp):
633
+ vocab_json = read_json(vp)
634
+ vocab = set(vocab_json.keys())
635
+ return model, vocab, vocab_json, {v: k for k, v in vocab_json.items()}
636
+ return model, None, None, None
637
+
638
+ def initialize_imagegen_model(folder, files):
639
+ download_files(folder, files)
640
+ config = AutoencoderKLConfig.from_dict(read_json(os.path.join(folder, "config.json")))
641
+ vae = AutoencoderKL(config).to(device)
642
+ sd = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device)
643
+ load_state_dict_safe(vae, sd)
644
+ vae.eval()
645
+ return vae
646
+
647
+ def initialize_image_to_3d_model(folder, files):
648
+ download_files(folder, files)
649
+ config = OpenLRMConfig.from_dict(read_json(os.path.join(folder, "config.json")))
650
+ model3d = OpenLRM(config).to(device)
651
+ sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
652
+ load_state_dict_safe(model3d, sd)
653
+ model3d.eval()
654
+ return model3d
655
+
656
+ def initialize_text_to_video_model(folder, files):
657
+ download_files(folder, files)
658
+ unet_cfg = read_json(os.path.join(folder, "config.json"))
659
+ unet_cfg = filter_kwargs(VideoUNet, unet_cfg)
660
+ unet = VideoUNet(**unet_cfg).half().to(device)
661
+ sd_unet = torch.load(os.path.join(folder, "diffusion_pytorch_model.fp16.bin"), map_location=device)
662
+ load_state_dict_safe(unet, sd_unet)
663
+ unet.eval()
664
+ vae_cfg = read_json(os.path.join(folder, "config.json"))
665
+ vae_cfg = filter_kwargs(AutoencoderKL, vae_cfg)
666
+ vae = AutoencoderKL(vae_cfg).half().to(device)
667
+ sd_vae = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device)
668
+ load_state_dict_safe(vae, sd_vae)
669
+ vae.eval()
670
+ vp = os.path.join(folder, "vocab.json")
671
+ text_vocab = read_json(vp) if os.path.exists(vp) else {}
672
+ te_path = os.path.join(folder, "text_encoder.bin")
673
+ if os.path.exists(te_path):
674
+ text_encoder = SimpleTextEncoder(vocab_size=(max(text_vocab.values())+1) if text_vocab else 10000, embed_dim=768, max_length=77).to(device)
675
+ sd_te = torch.load(te_path, map_location=device)
676
+ load_state_dict_safe(text_encoder, sd_te)
677
+ else:
678
+ text_encoder = SimpleTextEncoder(vocab_size=(max(text_vocab.values())+1) if text_vocab else 10000, embed_dim=768, max_length=77).to(device)
679
+ text_encoder.eval()
680
+ return VideoPipeline(unet, vae, text_encoder, text_vocab)
681
+
682
+ def initialize_sentiment_model(folder, files):
683
+ download_files(folder, files)
684
+ config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
685
+ model = SentimentClassifierModel(config).to(device)
686
+ sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
687
+ load_state_dict_safe(model, sd)
688
+ model.eval()
689
+ vp = os.path.join(folder, "vocab.json")
690
+ if os.path.exists(vp):
691
+ read_json(vp)
692
+ return model
693
+
694
+ def initialize_stt_model(folder, files):
695
+ download_files(folder, files)
696
+ config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
697
+ model = STTModel(config).to(device)
698
+ sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
699
+ load_state_dict_safe(model, sd)
700
+ model.eval()
701
+ vp = os.path.join(folder, "vocab.json")
702
+ if os.path.exists(vp):
703
+ read_json(vp)
704
+ return model
705
+
706
+ def initialize_tts_model(folder, files):
707
+ download_files(folder, files)
708
+ config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
709
+ model = TTSModel(config).to(device)
710
+ sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
711
+ load_state_dict_safe(model, sd)
712
+ model.eval()
713
+ vp = os.path.join(folder, "vocab.json")
714
+ if os.path.exists(vp):
715
+ read_json(vp)
716
+ return model
717
+
718
+ def initialize_musicgen_model(folder, files):
719
+ download_files(folder, files)
720
+ config = MusicGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
721
+ model = MusicGenModel(config).to(device)
722
+ sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
723
+ load_state_dict_safe(model, sd)
724
+ model.eval()
725
+ return model