Hjgugugjhuhjggg committed on
Commit e3e53f6 · verified · 1 Parent(s): bd4d670

Update model_loader.py

Files changed (1)
  1. model_loader.py +0 -669
model_loader.py CHANGED
@@ -723,672 +723,3 @@ def initialize_musicgen_model(folder, files):
     load_state_dict_safe(model, sd)
     model.eval()
     return model
-
-                if not fn:
-                    fn = hashlib.md5(url.encode()).hexdigest()
-                fn = sanitize_filename(fn, url)
-            elif isinstance(item, (list, tuple)) and len(item) == 2:
-                url, fn = item
-                fn = sanitize_filename(fn, url)
-            elif isinstance(item, dict) and "filename" in item and "url" in item:
-                fn = sanitize_filename(item["filename"], item["url"])
-                url = item["url"]
-            else:
-                raise ValueError("Invalid file specification")
-            fp = os.path.join(folder, fn)
-            download_file(url, fp)
-    else:
-        raise ValueError("files_spec must be dict or list")
-
-def read_json(fp):
-    with open(fp, 'r', encoding='utf-8') as f:
-        return json.load(f)
-
-def get_codegen_tokenizer(vocab_path, merges_path):
-    with open(vocab_path, 'r', encoding='utf-8') as f:
-        vocab = json.load(f)
-    with open(merges_path, 'r', encoding='utf-8') as f:
-        merges = f.read().splitlines()
-    merge_ranks = {}
-    for i, merge in enumerate(merges):
-        parts = merge.strip().split()
-        if len(parts) == 2:
-            merge_ranks[tuple(parts)] = i
-    def bpe(token):
-        word = list(token)
-        pairs = [(word[i], word[i+1]) for i in range(len(word)-1)]
-        while True:
-            candidate = None
-            candidate_rank = None
-            candidate_index = None
-            for i, pair in enumerate(pairs):
-                if pair in merge_ranks:
-                    rank = merge_ranks[pair]
-                    if candidate is None or rank < candidate_rank:
-                        candidate = pair
-                        candidate_rank = rank
-                        candidate_index = i
-            if candidate is None:
-                break
-            first, second = candidate
-            new_word = []
-            i = 0
-            while i < len(word):
-                if i < len(word) - 1 and word[i] == first and word[i+1] == second:
-                    new_word.append(first + second)
-                    i += 2
-                else:
-                    new_word.append(word[i])
-                    i += 1
-            word = new_word
-            if len(word) == 1:
-                break
-            pairs = [(word[i], word[i+1]) for i in range(len(word)-1)]
-        return word
-    def tokenizer(text):
-        tokens = []
-        for token in text.split():
-            bpe_tokens = bpe(token)
-            for subtoken in bpe_tokens:
-                tokens.append(vocab.get(subtoken, 0))
-        return tokens
-    return tokenizer
-
-def simple_tokenizer(text, vocab, max_length=77):
-    toks = text.split()
-    ids = [vocab.get(t, 1) for t in toks]
-    if len(ids) < max_length:
-        ids = ids + [0] * (max_length - len(ids))
-    else:
-        ids = ids[:max_length]
-    return torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(device)
-
-def load_state_dict_safe(model, loaded_state_dict):
-    model_state = model.state_dict()
-    new_state = {}
-    for key, value in model_state.items():
-        if key in loaded_state_dict and loaded_state_dict[key].shape == value.shape:
-            new_state[key] = loaded_state_dict[key]
-        else:
-            new_state[key] = value
-    model.load_state_dict(new_state, strict=False)
-
-class GPT2Config:
-    def __init__(self, vocab_size=50257, **kwargs):
-        self.vocab_size = vocab_size
-        self.__dict__.update(kwargs)
-    @classmethod
-    def from_dict(cls, d):
-        return cls(**d)
-
-class MBartConfig:
-    def __init__(self, vocab_size=50265, **kwargs):
-        self.vocab_size = vocab_size
-        self.__dict__.update(kwargs)
-    @classmethod
-    def from_dict(cls, d):
-        return cls(**d)
-
-class CodeGenConfig:
-    def __init__(self, vocab_size=50257, **kwargs):
-        self.vocab_size = vocab_size
-        self.__dict__.update(kwargs)
-    @classmethod
-    def from_dict(cls, d):
-        return cls(**d)
-
-class BartConfig:
-    def __init__(self, vocab_size=50265, **kwargs):
-        self.vocab_size = vocab_size
-        self.__dict__.update(kwargs)
-    @classmethod
-    def from_dict(cls, d):
-        return cls(**d)
-
-class AutoencoderKLConfig:
-    def __init__(self, **kwargs):
-        self.__dict__.update(kwargs)
-    @classmethod
-    def from_dict(cls, d):
-        return cls(**d)
-
-class OpenLRMConfig:
-    def __init__(self, **kwargs):
-        self.__dict__.update(kwargs)
-    @classmethod
-    def from_dict(cls, d):
-        return cls(**d)
-
-class UNet2DConditionModelConfig:
-    def __init__(self, **kwargs):
-        self.__dict__.update(kwargs)
-    @classmethod
-    def from_dict(cls, d):
-        return cls(**d)
-
-class MusicGenConfig:
-    def __init__(self, **kwargs):
-        self.__dict__.update(kwargs)
-    @classmethod
-    def from_dict(cls, d):
-        return cls(**d)
-
-class GPT2LMHeadModel(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
-        self.transformer = nn.TransformerEncoder(layer, num_layers=12)
-        self.lm_head = nn.Linear(768, config.vocab_size)
-    def forward(self, x):
-        return self.lm_head(self.transformer(x))
-
-class MBartForConditionalGeneration(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
-        self.encoder = nn.TransformerEncoder(layer, num_layers=6)
-        dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12)
-        self.decoder = nn.TransformerDecoder(dlayer, num_layers=6)
-        self.output_layer = nn.Linear(768, config.vocab_size)
-    def forward(self, src, tgt):
-        return self.output_layer(self.decoder(tgt, self.encoder(src)))
-
-class CodeGenForCausalLM(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        d_model = getattr(config, "d_model", 1024)
-        n_head = getattr(config, "n_head", 16)
-        num_layers = getattr(config, "num_layers", 12)
-        dlayer = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_head)
-        self.transformer_decoder = nn.TransformerDecoder(dlayer, num_layers=num_layers)
-        self.lm_head = nn.Linear(d_model, config.vocab_size)
-    def forward(self, tgt, memory=None):
-        if memory is None:
-            memory = torch.zeros_like(tgt)
-        return self.lm_head(self.transformer_decoder(tgt, memory))
-
-class BartForConditionalGeneration(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
-        self.encoder = nn.TransformerEncoder(layer, num_layers=6)
-        dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12)
-        self.decoder = nn.TransformerDecoder(dlayer, num_layers=6)
-        self.output_layer = nn.Linear(768, config.vocab_size)
-    def forward(self, src, tgt):
-        return self.output_layer(self.decoder(tgt, self.encoder(src)))
-
-class ResnetBlock(nn.Module):
-    def __init__(self, in_ch, out_ch):
-        super().__init__()
-        self.norm1 = nn.GroupNorm(32, in_ch)
-        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
-        self.norm2 = nn.GroupNorm(32, out_ch)
-        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
-        self.conv_shortcut = nn.Conv2d(in_ch, out_ch, 1)
-    def forward(self, x):
-        sc = self.conv_shortcut(x)
-        h = F.silu(self.norm1(x))
-        h = self.conv1(h)
-        h = F.silu(self.norm2(h))
-        h = self.conv2(h)
-        return h + sc
-
-class Downsample(nn.Module):
-    def __init__(self, in_ch, out_ch):
-        super().__init__()
-        self.conv = nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1)
-    def forward(self, x):
-        return self.conv(x)
-
-class DownBlock(nn.Module):
-    def __init__(self, in_ch, out_ch, num_res):
-        super().__init__()
-        self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)])
-        self.downsamplers = nn.ModuleList([Downsample(out_ch, out_ch)])
-    def forward(self, x):
-        for r in self.resnets:
-            x = r(x)
-        for ds in self.downsamplers:
-            x = ds(x)
-        return x
-
-class Upsample(nn.Module):
-    def __init__(self, in_ch, out_ch):
-        super().__init__()
-        self.conv = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
-    def forward(self, x):
-        return self.conv(x)
-
-class UpBlock(nn.Module):
-    def __init__(self, in_ch, out_ch, num_res):
-        super().__init__()
-        self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)])
-        self.upsampler = Upsample(out_ch, out_ch)
-    def forward(self, x):
-        for r in self.resnets:
-            x = r(x)
-        return self.upsampler(x)
-
-class AttentionBlock(nn.Module):
-    def __init__(self, ch):
-        super().__init__()
-        self.norm = nn.GroupNorm(32, ch)
-        self.query = nn.Conv2d(ch, ch, 1)
-        self.key = nn.Conv2d(ch, ch, 1)
-        self.value = nn.Conv2d(ch, ch, 1)
-        self.proj_attn = nn.Conv2d(ch, ch, 1)
-    def forward(self, x):
-        b, c, h, w = x.shape
-        xn = self.norm(x)
-        q = self.query(xn).view(b, c, -1).permute(0, 2, 1)
-        k = self.key(xn).view(b, c, -1)
-        v = self.value(xn).view(b, c, -1).permute(0, 2, 1)
-        attn = torch.softmax(torch.bmm(q, k) / (c ** 0.5), dim=-1)
-        out = torch.bmm(attn, v).permute(0, 2, 1).view(b, c, h, w)
-        return x + self.proj_attn(out)
-
-class Encoder(nn.Module):
-    def __init__(self, in_ch=3, base_ch=128, latent_ch=4):
-        super().__init__()
-        self.conv_in = nn.Conv2d(in_ch, base_ch, 3, padding=1)
-        self.down_blocks = nn.ModuleList([
-            DownBlock(base_ch, base_ch, 2),
-            DownBlock(base_ch, base_ch * 2, 2),
-            DownBlock(base_ch * 2, base_ch * 4, 2),
-            DownBlock(base_ch * 4, base_ch * 4, 2)
-        ])
-        self.mid_block = nn.ModuleList([
-            ResnetBlock(base_ch * 4, base_ch * 4),
-            AttentionBlock(base_ch * 4),
-            ResnetBlock(base_ch * 4, base_ch * 4)
-        ])
-        self.conv_norm_out = nn.GroupNorm(32, base_ch * 4)
-        self.conv_out = nn.Conv2d(base_ch * 4, latent_ch * 2, 3, padding=1)
-        self.quant_conv = nn.Conv2d(latent_ch * 2, latent_ch, 1)
-    def forward(self, x):
-        x = self.conv_in(x)
-        for blk in self.down_blocks:
-            x = blk(x)
-        for m in self.mid_block:
-            x = m(x)
-        x = self.conv_norm_out(x)
-        x = self.conv_out(x)
-        return self.quant_conv(x)
-
-class Decoder(nn.Module):
-    def __init__(self, out_ch=3, base_ch=128, latent_ch=4):
-        super().__init__()
-        self.post_quant_conv = nn.Conv2d(latent_ch, latent_ch * 2, 1)
-        self.conv_in = nn.Conv2d(latent_ch, base_ch * 4, 3, padding=1)
-        self.mid_block = nn.ModuleList([
-            ResnetBlock(base_ch * 4, base_ch * 4),
-            AttentionBlock(base_ch * 4),
-            ResnetBlock(base_ch * 4, base_ch * 4)
-        ])
-        self.up_blocks = nn.ModuleList([
-            UpBlock(base_ch * 4, base_ch * 4, 3),
-            UpBlock(base_ch * 4, base_ch * 2, 3),
-            UpBlock(base_ch * 2, base_ch, 3),
-            UpBlock(base_ch, base_ch, 3)
-        ])
-        self.conv_norm_out = nn.GroupNorm(32, base_ch)
-        self.conv_out = nn.Conv2d(base_ch, out_ch, 3, padding=1)
-    def forward(self, x):
-        x = self.post_quant_conv(x)
-        x = self.conv_in(x)
-        for m in self.mid_block:
-            x = m(x)
-        for up in self.up_blocks:
-            x = up(x)
-        x = self.conv_norm_out(x)
-        return self.conv_out(x)
-
-class AutoencoderKL(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        in_ch = config.get("in_channels", 3) if isinstance(config, dict) else config.__dict__.get("in_channels", 3)
-        out_ch = config.get("out_channels", 3) if isinstance(config, dict) else config.__dict__.get("out_channels", 3)
-        base_ch = config.get("base_channels", 128) if isinstance(config, dict) else config.__dict__.get("base_channels", 128)
-        latent_ch = config.get("latent_channels", 4) if isinstance(config, dict) else config.__dict__.get("latent_channels", 4)
-        self.encoder = Encoder(in_ch, base_ch, latent_ch)
-        self.decoder = Decoder(out_ch, base_ch, latent_ch)
-    def forward(self, x):
-        return self.decoder(self.encoder(x))
-    def decode(self, x):
-        return self.decoder(x)
-
-class TransformerBlock(nn.Module):
-    def __init__(self, embed_dim, num_heads):
-        super().__init__()
-        self.norm1 = nn.LayerNorm(embed_dim)
-        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
-        self.norm2 = nn.LayerNorm(embed_dim)
-        hidden_dim = embed_dim * 4
-        self.mlp = nn.Sequential(
-            nn.Linear(embed_dim, hidden_dim),
-            nn.GELU(),
-            nn.Linear(hidden_dim, embed_dim)
-        )
-    def forward(self, x):
-        res = x
-        x = self.norm1(x)
-        x = x.transpose(0, 1)
-        attn, _ = self.attn(x, x, x)
-        x = attn.transpose(0, 1)
-        x = res + x
-        return x + self.mlp(self.norm2(x))
-
-class VisionTransformer(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        if isinstance(config, dict):
-            self.img_size = config.get("img_size", 592)
-            self.patch_size = config.get("patch_size", 16)
-            self.embed_dim = config.get("hidden_size", 768)
-            depth = config.get("depth", 12)
-            num_heads = config.get("num_heads", 12)
-        else:
-            self.img_size = config.__dict__.get("img_size", 592)
-            self.patch_size = config.__dict__.get("patch_size", 16)
-            self.embed_dim = config.__dict__.get("hidden_size", 768)
-            depth = config.__dict__.get("depth", 12)
-            num_heads = config.__dict__.get("num_heads", 12)
-        num_patches = (self.img_size // self.patch_size) ** 2
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
-        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
-        self.patch_embed = nn.Conv2d(3, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
-        self.blocks = nn.ModuleList([TransformerBlock(self.embed_dim, num_heads) for _ in range(depth)])
-        self.norm = nn.LayerNorm(self.embed_dim)
-        self.register_tokens = nn.Parameter(torch.zeros(1, 4, self.embed_dim))
-        self._init_weights()
-    def _init_weights(self):
-        nn.init.normal_(self.cls_token, std=0.02)
-        nn.init.normal_(self.pos_embed, std=0.02)
-    def forward(self, x):
-        x = self.patch_embed(x)
-        x = x.flatten(2).transpose(1, 2)
-        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
-        x = torch.cat((cls_tokens, x), dim=1)
-        x = x + self.pos_embed
-        for blk in self.blocks:
-            x = blk(x)
-        return self.norm(x)[:, 0]
-
-class OpenLRM(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.encoder = nn.ModuleDict({"model": VisionTransformer(config)})
-        hidden = config.get("hidden_size", 768) if isinstance(config, dict) else config.__dict__.get("hidden_size", 768)
-        self.linear = nn.Linear(hidden, hidden)
-    def forward(self, x):
-        return self.linear(self.encoder["model"](x))
-
-class VideoUNet(nn.Module):
-    def __init__(self, in_ch=4, out_ch=4, features=None):
-        super().__init__()
-        if features is None:
-            features = [64, 128, 256]
-        self.encoder = nn.ModuleList()
-        self.pool = nn.MaxPool3d(2, 2)
-        self.decoder = nn.ModuleList()
-        for f in features:
-            self.encoder.append(nn.Sequential(
-                nn.Conv3d(in_ch, f, 3, padding=1),
-                nn.ReLU(inplace=True),
-                nn.Conv3d(f, f, 3, padding=1),
-                nn.ReLU(inplace=True)
-            ))
-            in_ch = f
-        for f in reversed(features):
-            self.decoder.append(nn.Sequential(
-                nn.Conv3d(f * 2, f, 3, padding=1),
-                nn.ReLU(inplace=True),
-                nn.Conv3d(f, f, 3, padding=1),
-                nn.ReLU(inplace=True)
-            ))
-        self.final_conv = nn.Conv3d(features[0], out_ch, 1)
-    def forward(self, x, t, encoder_hidden_states):
-        skips = []
-        for enc in self.encoder:
-            x = enc(x)
-            skips.append(x)
-            x = self.pool(x)
-        for dec in self.decoder:
-            skip = skips.pop()
-            x = F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=False)
-            x = torch.cat([x, skip], dim=1)
-            x = dec(x)
-        return self.final_conv(x)
-
-class SentimentClassifierModel(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.classifier = nn.Sequential(
-            nn.Linear(768, 256),
-            nn.ReLU(),
-            nn.Linear(256, 2)
-        )
-    def forward(self, x):
-        return self.classifier(x)
-
-class STTModel(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.net = nn.Sequential(
-            nn.Linear(768, 512),
-            nn.ReLU(),
-            nn.Linear(512, 768)
-        )
-    def forward(self, x):
-        return self.net(x)
-
-class TTSModel(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.net = nn.Sequential(
-            nn.Linear(768, 512),
-            nn.ReLU(),
-            nn.Linear(512, 768)
-        )
-    def forward(self, x):
-        return self.net(x)
-
-class MusicGenModel(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
-        self.transformer = nn.TransformerEncoder(layer, num_layers=12)
-        self.linear = nn.Linear(768, 768)
-    def forward(self, x):
-        return self.linear(self.transformer(x))
-
-class SimpleTextEncoder(nn.Module):
-    def __init__(self, vocab_size=10000, embed_dim=768, max_length=77):
-        super().__init__()
-        self.embedding = nn.Embedding(vocab_size, embed_dim)
-        self.max_length = max_length
-    def forward(self, text_tokens):
-        return self.embedding(text_tokens)
-
-class DiffusionScheduler:
-    def __init__(self, steps):
-        self.steps = steps
-        self.betas = torch.linspace(0.1, 0.001, steps=steps).to(device)
-        self.alphas = 1 - self.betas
-        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
-    def step(self, noise, t, sample):
-        alpha_bar = self.alpha_bars[t]
-        if t > 0:
-            alpha_bar_prev = self.alpha_bars[t-1]
-        else:
-            alpha_bar_prev = torch.tensor(1.0, device=sample.device)
-        x0 = (sample - torch.sqrt(1 - alpha_bar) * noise) / torch.sqrt(alpha_bar)
-        new_sample = torch.sqrt(alpha_bar_prev) * x0 + torch.sqrt(1 - alpha_bar_prev) * noise
-        return new_sample
-
-class VideoOutput:
-    def __init__(self, frames):
-        self.frames = [img_as_ubyte(frame) for frame in frames[0]]
-
-class VideoPipeline(nn.Module):
-    def __init__(self, unet, vae, text_encoder, vocab):
-        super().__init__()
-        self.unet = unet
-        self.vae = vae
-        self.text_encoder = text_encoder
-        self.vocab = vocab
-    def forward(self, prompt: str, steps: int = 25, num_frames: int = 24):
-        token_ids = simple_tokenizer(prompt, self.vocab)
-        text_emb = self.text_encoder(token_ids)
-        latent = torch.randn((1, 4, num_frames, 64, 64), device=device).half()
-        sched = DiffusionScheduler(steps)
-        for t in range(steps):
-            noise = self.unet(latent, t, text_emb)
-            latent = sched.step(noise, t, latent)
-        frames = self.vae.decode(latent / 0.18215)
-        frames = frames.clamp(0, 1).float().cpu().permute(0, 2, 3, 4, 1).numpy()
-        return VideoOutput(frames)
-
-def initialize_gpt2_model(folder, files):
-    download_files(folder, files)
-    config = GPT2Config()
-    model = GPT2LMHeadModel(config).to(device)
-    sd = torch.load(os.path.join(folder, sanitize_filename("gpt2-pytorch_model.bin")), map_location=device)
-    load_state_dict_safe(model, sd)
-    model.eval()
-    enc = read_json(os.path.join(folder, sanitize_filename("encoder.json")))
-    return model, enc
-
-def initialize_translation_model(folder, files):
-    download_files(folder, files)
-    config = MBartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
-    model = MBartForConditionalGeneration(config).to(device)
-    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
-    load_state_dict_safe(model, sd)
-    model.eval()
-    vp = os.path.join(folder, "vocab.json")
-    if os.path.exists(vp):
-        vocab = read_json(vp)
-        model.tokenizer = lambda txt: [vocab.get(t, 0) for t in txt.split()]
-    else:
-        model.tokenizer = lambda txt: txt
-    model.config.lang_code_to_id = {'en_XX': 0, 'es_XX': 1}
-    return model
-
-def initialize_codegen_model(folder, files):
-    download_files(folder, files)
-    config = CodeGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
-    model = CodeGenForCausalLM(config).to(device)
-    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
-    load_state_dict_safe(model, sd)
-    model.eval()
-    tok = get_codegen_tokenizer(os.path.join(folder, "vocab.json"), os.path.join(folder, "merges.txt"))
-    vocab = read_json(os.path.join(folder, "vocab.json"))
-    idx2w = {v: k for k, v in vocab.items()}
-    model.tokenizer = tok
-    return model, tok, vocab, idx2w, vocab
-
-def initialize_summarization_model(folder, files):
-    download_files(folder, files)
-    config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
-    model = BartForConditionalGeneration(config).to(device)
-    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
-    load_state_dict_safe(model, sd)
-    model.eval()
-    vp = os.path.join(folder, "vocab.json")
-    if os.path.exists(vp):
-        vocab_json = read_json(vp)
-        vocab = set(vocab_json.keys())
-        return model, vocab, vocab_json, {v: k for k, v in vocab_json.items()}
-    return model, None, None, None
-
-def initialize_imagegen_model(folder, files):
-    download_files(folder, files)
-    config = AutoencoderKLConfig.from_dict(read_json(os.path.join(folder, "config.json")))
-    vae = AutoencoderKL(config).to(device)
-    sd = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device)
-    load_state_dict_safe(vae, sd)
-    vae.eval()
-    return vae
-
-def initialize_image_to_3d_model(folder, files):
-    download_files(folder, files)
-    config = OpenLRMConfig.from_dict(read_json(os.path.join(folder, "config.json")))
-    model3d = OpenLRM(config).to(device)
-    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
-    load_state_dict_safe(model3d, sd)
-    model3d.eval()
-    return model3d
-
-def initialize_text_to_video_model(folder, files):
-    download_files(folder, files)
-    unet_cfg = read_json(os.path.join(folder, "config.json"))
-    unet_cfg = filter_kwargs(VideoUNet, unet_cfg)
-    unet = VideoUNet(**unet_cfg).half().to(device)
-    sd_unet = torch.load(os.path.join(folder, "diffusion_pytorch_model.fp16.bin"), map_location=device)
-    load_state_dict_safe(unet, sd_unet)
-    unet.eval()
-    vae_cfg = read_json(os.path.join(folder, "config.json"))
-    vae_cfg = filter_kwargs(AutoencoderKL, vae_cfg)
-    vae = AutoencoderKL(vae_cfg).half().to(device)
-    sd_vae = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device)
-    load_state_dict_safe(vae, sd_vae)
-    vae.eval()
-    vp = os.path.join(folder, "vocab.json")
-    text_vocab = read_json(vp) if os.path.exists(vp) else {}
-    te_path = os.path.join(folder, "text_encoder.bin")
-    if os.path.exists(te_path):
-        text_encoder = SimpleTextEncoder(vocab_size=(max(text_vocab.values())+1) if text_vocab else 10000, embed_dim=768, max_length=77).to(device)
-        sd_te = torch.load(te_path, map_location=device)
-        load_state_dict_safe(text_encoder, sd_te)
-    else:
-        text_encoder = SimpleTextEncoder(vocab_size=(max(text_vocab.values())+1) if text_vocab else 10000, embed_dim=768, max_length=77).to(device)
-    text_encoder.eval()
-    return VideoPipeline(unet, vae, text_encoder, text_vocab)
-
-def initialize_sentiment_model(folder, files):
-    download_files(folder, files)
-    config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
-    model = SentimentClassifierModel(config).to(device)
-    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
-    load_state_dict_safe(model, sd)
-    model.eval()
-    vp = os.path.join(folder, "vocab.json")
-    if os.path.exists(vp):
-        read_json(vp)
-    return model
-
-def initialize_stt_model(folder, files):
-    download_files(folder, files)
-    config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
-    model = STTModel(config).to(device)
-    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
-    load_state_dict_safe(model, sd)
-    model.eval()
-    vp = os.path.join(folder, "vocab.json")
-    if os.path.exists(vp):
-        read_json(vp)
-    return model
-
-def initialize_tts_model(folder, files):
-    download_files(folder, files)
-    config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
-    model = TTSModel(config).to(device)
-    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
-    load_state_dict_safe(model, sd)
-    model.eval()
-    vp = os.path.join(folder, "vocab.json")
-    if os.path.exists(vp):
-        read_json(vp)
-    return model
-
-def initialize_musicgen_model(folder, files):
-    download_files(folder, files)
-    config = MusicGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
-    model = MusicGenModel(config).to(device)
-    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
-    load_state_dict_safe(model, sd)
-    model.eval()
-    return model