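"""Model loading utilities: simplified stand-in architectures plus helpers for
downloading checkpoints, reading configs, and safely loading partial state dicts."""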
import os
import json
import urllib.request
import urllib.parse
import torch
import hashlib
from tqdm import tqdm
from skimage import img_as_ubyte
from torch import nn
import torch.nn.functional as F
import inspect
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def filter_kwargs(cls, kwargs):
sig = inspect.signature(cls.__init__)
accepted = set(sig.parameters.keys()) - {"self"}
return {k: v for k, v in kwargs.items() if k in accepted}
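# Example: filter_kwargs(VideoUNet, {"in_ch": 4, "foo": 1}) keeps only the kwargs
# accepted by VideoUNet.__init__ (here, {"in_ch": 4}).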
def sanitize_filename(name, url=None):
for c in '<>:"/\\|?*':
name = name.replace(c, '')
if not name and url is not None:
name = hashlib.md5(url.encode()).hexdigest()
return name
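# Example: sanitize_filename('a<b>:c.bin') -> 'abc.bin'; an empty result falls back to md5(url).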
def download_file(url, filepath):
d = os.path.dirname(filepath)
if d and not os.path.exists(d):
os.makedirs(d, exist_ok=True)
if not os.path.exists(filepath):
def prog(t):
last = [0]
def inner(n, bs, ts):
if ts > 0:
t.total = ts
t.update(n * bs - last[0])
last[0] = n * bs
return inner
with tqdm(unit='B', unit_scale=True, unit_divisor=1024, desc=os.path.basename(filepath)) as t:
urllib.request.urlretrieve(url, filepath, reporthook=prog(t))
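# download_file is idempotent: existing files are skipped, and new downloads report
# byte-level progress through the tqdm reporthook.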
def download_files(folder, files_spec):
if isinstance(files_spec, dict):
for fn, url in files_spec.items():
fn = sanitize_filename(fn, url)
fp = os.path.join(folder, fn)
download_file(url, fp)
elif isinstance(files_spec, list):
for item in files_spec:
if isinstance(item, str):
url = item
parsed = urllib.parse.urlparse(url)
fn = os.path.basename(parsed.path)
if not fn:
fn = hashlib.md5(url.encode()).hexdigest()
fn = sanitize_filename(fn, url)
elif isinstance(item, (list, tuple)) and len(item) == 2:
url, fn = item
fn = sanitize_filename(fn, url)
elif isinstance(item, dict) and "filename" in item and "url" in item:
fn = sanitize_filename(item["filename"], item["url"])
url = item["url"]
else:
raise ValueError("Invalid file specification")
fp = os.path.join(folder, fn)
download_file(url, fp)
else:
raise ValueError("files_spec must be dict or list")
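# Accepted files_spec shapes (URLs below are placeholders):
#   {"model.bin": "https://example.com/model.bin"}                        dict: filename -> url
#   ["https://example.com/model.bin"]                                     bare url, filename taken from path
#   [("https://example.com/model.bin", "model.bin")]                      (url, filename) pair
#   [{"url": "https://example.com/model.bin", "filename": "model.bin"}]   dict entry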
def read_json(fp):
with open(fp, 'r', encoding='utf-8') as f:
return json.load(f)
def get_codegen_tokenizer(vocab_path, merges_path):
with open(vocab_path, 'r', encoding='utf-8') as f:
vocab = json.load(f)
with open(merges_path, 'r', encoding='utf-8') as f:
        merges = f.read().splitlines()  # loaded for completeness; this simplified tokenizer does not apply BPE merges
def tokenizer(text):
toks = text.split()
return [vocab.get(t, 0) for t in toks]
return tokenizer
def simple_tokenizer(text, vocab, max_length=77):
toks = text.split()
ids = [vocab.get(t, 1) for t in toks]
if len(ids) < max_length:
ids = ids + [0]*(max_length - len(ids))
else:
ids = ids[:max_length]
return torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(device)
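# simple_tokenizer pads or truncates to max_length and returns a (1, max_length) LongTensor
# on `device`; unknown tokens map to id 1 and padding to id 0.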
def load_state_dict_safe(model, loaded_state_dict):
model_state = model.state_dict()
new_state = {}
for key, value in model_state.items():
if key in loaded_state_dict and loaded_state_dict[key].shape == value.shape:
new_state[key] = loaded_state_dict[key]
else:
new_state[key] = value
model.load_state_dict(new_state, strict=False)
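# load_state_dict_safe copies only tensors whose name and shape match the model,
# so partially compatible checkpoints load without raising.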
class GPT2Config:
def __init__(self, vocab_size=50257, **kwargs):
self.vocab_size = vocab_size
self.__dict__.update(kwargs)
@classmethod
def from_dict(cls, d):
return cls(**d)
class MBartConfig:
def __init__(self, vocab_size=50265, **kwargs):
self.vocab_size = vocab_size
self.__dict__.update(kwargs)
@classmethod
def from_dict(cls, d):
return cls(**d)
class CodeGenConfig:
def __init__(self, vocab_size=50257, **kwargs):
self.vocab_size = vocab_size
self.__dict__.update(kwargs)
@classmethod
def from_dict(cls, d):
return cls(**d)
class BartConfig:
def __init__(self, vocab_size=50265, **kwargs):
self.vocab_size = vocab_size
self.__dict__.update(kwargs)
@classmethod
def from_dict(cls, d):
return cls(**d)
class AutoencoderKLConfig:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
@classmethod
def from_dict(cls, d):
return cls(**d)
class OpenLRMConfig:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
@classmethod
def from_dict(cls, d):
return cls(**d)
class UNet2DConditionModelConfig:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
@classmethod
def from_dict(cls, d):
return cls(**d)
class MusicGenConfig:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
@classmethod
def from_dict(cls, d):
return cls(**d)
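# The *Config classes above are thin attribute containers; from_dict(d) simply forwards d as kwargs.
# The model classes below are deliberately simplified stand-ins named after the checkpoints they
# load: load_state_dict_safe keeps only whichever weights happen to match by name and shape.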
class GPT2LMHeadModel(nn.Module):
def __init__(self, config):
super().__init__()
layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
self.transformer = nn.TransformerEncoder(layer, num_layers=12)
self.lm_head = nn.Linear(768, config.vocab_size)
def forward(self, x):
return self.lm_head(self.transformer(x))
class MBartForConditionalGeneration(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
self.encoder = nn.TransformerEncoder(layer, num_layers=6)
dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12)
self.decoder = nn.TransformerDecoder(dlayer, num_layers=6)
self.output_layer = nn.Linear(768, config.vocab_size)
def forward(self, src, tgt):
return self.output_layer(self.decoder(tgt, self.encoder(src)))
class CodeGenForCausalLM(nn.Module):
def __init__(self, config):
super().__init__()
d_model = getattr(config, "d_model", 1024)
n_head = getattr(config, "n_head", 16)
num_layers = getattr(config, "num_layers", 12)
dlayer = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_head)
self.transformer_decoder = nn.TransformerDecoder(dlayer, num_layers=num_layers)
self.lm_head = nn.Linear(d_model, config.vocab_size)
def forward(self, tgt, memory=None):
if memory is None:
memory = torch.zeros_like(tgt)
return self.lm_head(self.transformer_decoder(tgt, memory))
class BartForConditionalGeneration(nn.Module):
def __init__(self, config):
super().__init__()
layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
self.encoder = nn.TransformerEncoder(layer, num_layers=6)
dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12)
self.decoder = nn.TransformerDecoder(dlayer, num_layers=6)
self.output_layer = nn.Linear(768, config.vocab_size)
def forward(self, src, tgt):
return self.output_layer(self.decoder(tgt, self.encoder(src)))
class ResnetBlock(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.norm1 = nn.GroupNorm(32, in_ch)
self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
self.norm2 = nn.GroupNorm(32, out_ch)
self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
self.conv_shortcut = nn.Conv2d(in_ch, out_ch, 1)
def forward(self, x):
sc = self.conv_shortcut(x)
h = F.silu(self.norm1(x))
h = self.conv1(h)
h = F.silu(self.norm2(h))
h = self.conv2(h)
return h + sc
class Downsample(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.conv = nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1)
def forward(self, x):
return self.conv(x)
class DownBlock(nn.Module):
def __init__(self, in_ch, out_ch, num_res):
super().__init__()
self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)])
self.downsamplers = nn.ModuleList([Downsample(out_ch, out_ch)])
def forward(self, x):
for r in self.resnets:
x = r(x)
for ds in self.downsamplers:
x = ds(x)
return x
class Upsample(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.conv = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
def forward(self, x):
return self.conv(x)
class UpBlock(nn.Module):
def __init__(self, in_ch, out_ch, num_res):
super().__init__()
self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)])
self.upsampler = Upsample(out_ch, out_ch)
def forward(self, x):
for r in self.resnets:
x = r(x)
return self.upsampler(x)
class AttentionBlock(nn.Module):
def __init__(self, ch):
super().__init__()
self.norm = nn.GroupNorm(32, ch)
self.query = nn.Conv2d(ch, ch, 1)
self.key = nn.Conv2d(ch, ch, 1)
self.value = nn.Conv2d(ch, ch, 1)
self.proj_attn = nn.Conv2d(ch, ch, 1)
def forward(self, x):
b, c, h, w = x.shape
xn = self.norm(x)
q = self.query(xn).view(b, c, -1).permute(0, 2, 1)
k = self.key(xn).view(b, c, -1)
v = self.value(xn).view(b, c, -1).permute(0, 2, 1)
attn = torch.softmax(torch.bmm(q, k) / (c ** 0.5), dim=-1)
out = torch.bmm(attn, v).permute(0, 2, 1).view(b, c, h, w)
return x + self.proj_attn(out)
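# AttentionBlock: single-head self-attention over spatial positions,
# softmax(Q K^T / sqrt(C)) V with a residual connection back onto the input.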
class Encoder(nn.Module):
def __init__(self, in_ch=3, base_ch=128, latent_ch=4):
super().__init__()
self.conv_in = nn.Conv2d(in_ch, base_ch, 3, padding=1)
self.down_blocks = nn.ModuleList([
DownBlock(base_ch, base_ch, 2),
DownBlock(base_ch, base_ch * 2, 2),
DownBlock(base_ch * 2, base_ch * 4, 2),
DownBlock(base_ch * 4, base_ch * 4, 2)
])
self.mid_block = nn.ModuleList([
ResnetBlock(base_ch * 4, base_ch * 4),
AttentionBlock(base_ch * 4),
ResnetBlock(base_ch * 4, base_ch * 4)
])
self.conv_norm_out = nn.GroupNorm(32, base_ch * 4)
self.conv_out = nn.Conv2d(base_ch * 4, latent_ch * 2, 3, padding=1)
self.quant_conv = nn.Conv2d(latent_ch * 2, latent_ch, 1)
def forward(self, x):
x = self.conv_in(x)
for blk in self.down_blocks:
x = blk(x)
for m in self.mid_block:
x = m(x)
x = self.conv_norm_out(x)
x = self.conv_out(x)
return self.quant_conv(x)
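# Encoder halves the resolution in each of its four DownBlocks (16x total downsampling)
# and projects to latent_ch channels via conv_out + quant_conv.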
class Decoder(nn.Module):
def __init__(self, out_ch=3, base_ch=128, latent_ch=4):
super().__init__()
        self.post_quant_conv = nn.Conv2d(latent_ch, latent_ch, 1)  # output must match conv_in's latent_ch input
self.conv_in = nn.Conv2d(latent_ch, base_ch * 4, 3, padding=1)
self.mid_block = nn.ModuleList([
ResnetBlock(base_ch * 4, base_ch * 4),
AttentionBlock(base_ch * 4),
ResnetBlock(base_ch * 4, base_ch * 4)
])
self.up_blocks = nn.ModuleList([
UpBlock(base_ch * 4, base_ch * 4, 3),
UpBlock(base_ch * 4, base_ch * 2, 3),
UpBlock(base_ch * 2, base_ch, 3),
UpBlock(base_ch, base_ch, 3)
])
self.conv_norm_out = nn.GroupNorm(32, base_ch)
self.conv_out = nn.Conv2d(base_ch, out_ch, 3, padding=1)
def forward(self, x):
x = self.post_quant_conv(x)
x = self.conv_in(x)
for m in self.mid_block:
x = m(x)
for up in self.up_blocks:
x = up(x)
x = self.conv_norm_out(x)
return self.conv_out(x)
class AutoencoderKL(nn.Module):
def __init__(self, config):
super().__init__()
        cfg = config if isinstance(config, dict) else config.__dict__
        in_ch = cfg.get("in_channels", 3)
        out_ch = cfg.get("out_channels", 3)
        base_ch = cfg.get("base_channels", 128)
        latent_ch = cfg.get("latent_channels", 4)
self.encoder = Encoder(in_ch, base_ch, latent_ch)
self.decoder = Decoder(out_ch, base_ch, latent_ch)
def forward(self, x):
return self.decoder(self.encoder(x))
def decode(self, x):
return self.decoder(x)
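# Minimal usage sketch (assumes the default 4-channel latent configuration):
#   vae = AutoencoderKL({"latent_channels": 4})
#   img = vae.decode(torch.randn(1, 4, 64, 64))  # -> (1, 3, 1024, 1024)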
class TransformerBlock(nn.Module):
def __init__(self, embed_dim, num_heads):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.norm2 = nn.LayerNorm(embed_dim)
hidden_dim = embed_dim * 4
self.mlp = nn.Sequential(
nn.Linear(embed_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, embed_dim)
)
def forward(self, x):
res = x
x = self.norm1(x)
x = x.transpose(0, 1)
attn, _ = self.attn(x, x, x)
x = attn.transpose(0, 1)
x = res + x
return x + self.mlp(self.norm2(x))
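# TransformerBlock is pre-norm; nn.MultiheadAttention defaults to (seq, batch, dim) inputs,
# hence the transposes around the attention call.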
class VisionTransformer(nn.Module):
def __init__(self, config):
super().__init__()
        cfg = config if isinstance(config, dict) else config.__dict__
        self.img_size = cfg.get("img_size", 592)
        self.patch_size = cfg.get("patch_size", 16)
        self.embed_dim = cfg.get("hidden_size", 768)
        depth = cfg.get("depth", 12)
        num_heads = cfg.get("num_heads", 12)
num_patches = (self.img_size // self.patch_size) ** 2
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
self.patch_embed = nn.Conv2d(3, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
self.blocks = nn.ModuleList([TransformerBlock(self.embed_dim, num_heads) for _ in range(depth)])
self.norm = nn.LayerNorm(self.embed_dim)
self.register_tokens = nn.Parameter(torch.zeros(1, 4, self.embed_dim))
self._init_weights()
def _init_weights(self):
nn.init.normal_(self.cls_token, std=0.02)
nn.init.normal_(self.pos_embed, std=0.02)
def forward(self, x):
x = self.patch_embed(x)
x = x.flatten(2).transpose(1, 2)
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
for blk in self.blocks:
x = blk(x)
return self.norm(x)[:, 0]
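# VisionTransformer returns the [CLS] token embedding; register_tokens are declared but not used in forward.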
class OpenLRM(nn.Module):
def __init__(self, config):
super().__init__()
self.encoder = nn.ModuleDict({"model": VisionTransformer(config)})
hidden = config.get("hidden_size", 768) if isinstance(config, dict) else config.__dict__.get("hidden_size", 768)
self.linear = nn.Linear(hidden, hidden)
def forward(self, x):
return self.linear(self.encoder["model"](x))
class VideoUNet(nn.Module):
def __init__(self, in_ch=4, out_ch=4, features=None):
super().__init__()
if features is None:
features = [64, 128, 256]
self.encoder = nn.ModuleList()
self.pool = nn.MaxPool3d(2, 2)
self.decoder = nn.ModuleList()
for f in features:
self.encoder.append(nn.Sequential(
nn.Conv3d(in_ch, f, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv3d(f, f, 3, padding=1),
nn.ReLU(inplace=True)
))
in_ch = f
        prev = features[-1]
        for f in reversed(features):
            # After upsampling, x (prev channels) is concatenated with a skip of f channels.
            self.decoder.append(nn.Sequential(
                nn.Conv3d(prev + f, f, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv3d(f, f, 3, padding=1),
                nn.ReLU(inplace=True)
            ))
            prev = f
self.final_conv = nn.Conv3d(features[0], out_ch, 1)
def forward(self, x, t, encoder_hidden_states):
skips = []
for enc in self.encoder:
x = enc(x)
skips.append(x)
x = self.pool(x)
for dec in self.decoder:
skip = skips.pop()
x = F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=False)
x = torch.cat([x, skip], dim=1)
x = dec(x)
return self.final_conv(x)
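# VideoUNet is a small 3D U-Net; the `t` and `encoder_hidden_states` arguments are accepted
# for pipeline compatibility but ignored by this simplified implementation.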
class SentimentClassifierModel(nn.Module):
def __init__(self, config):
super().__init__()
self.classifier = nn.Sequential(
nn.Linear(768, 256),
nn.ReLU(),
nn.Linear(256, 2)
)
def forward(self, x):
return self.classifier(x)
class STTModel(nn.Module):
def __init__(self, config):
super().__init__()
self.net = nn.Sequential(
nn.Linear(768, 512),
nn.ReLU(),
nn.Linear(512, 768)
)
def forward(self, x):
return self.net(x)
class TTSModel(nn.Module):
def __init__(self, config):
super().__init__()
self.net = nn.Sequential(
nn.Linear(768, 512),
nn.ReLU(),
nn.Linear(512, 768)
)
def forward(self, x):
return self.net(x)
class MusicGenModel(nn.Module):
def __init__(self, config):
super().__init__()
layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
self.transformer = nn.TransformerEncoder(layer, num_layers=12)
self.linear = nn.Linear(768, 768)
def forward(self, x):
return self.linear(self.transformer(x))
class SimpleTextEncoder(nn.Module):
def __init__(self, vocab_size=10000, embed_dim=768, max_length=77):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.max_length = max_length
def forward(self, text_tokens):
return self.embedding(text_tokens)
class DiffusionScheduler:
def __init__(self, steps):
self.steps = steps
self.betas = torch.linspace(0.1, 0.001, steps=steps).to(device)
    def step(self, noise, t, sample):
        # Cast beta to the sample dtype so fp16 latents are not promoted to fp32.
        beta = self.betas[t].to(dtype=sample.dtype)
        return sample - beta * noise
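# DiffusionScheduler applies a crude linear update (sample - beta_t * noise per step);
# it is not a faithful DDPM/DDIM sampler.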
class VideoOutput:
def __init__(self, frames):
self.frames = [img_as_ubyte(frame) for frame in frames[0]]
class VideoPipeline(nn.Module):
def __init__(self, unet, vae, text_encoder, vocab):
super().__init__()
self.unet = unet
self.vae = vae
self.text_encoder = text_encoder
self.vocab = vocab
def forward(self, prompt: str, steps: int = 25, num_frames: int = 24):
token_ids = simple_tokenizer(prompt, self.vocab)
text_emb = self.text_encoder(token_ids)
latent = torch.randn((1, 4, num_frames, 64, 64), device=device).half()
sched = DiffusionScheduler(steps)
for t in range(steps):
            noise = self.unet(latent, t, text_emb)
            latent = sched.step(noise, t, latent)
        # The VAE is 2D, so fold the frame axis into the batch before decoding, then restore it.
        b, c, f, h, w = latent.shape
        decoded = self.vae.decode(latent.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w) / 0.18215)
        frames = decoded.reshape(b, f, *decoded.shape[1:]).permute(0, 2, 1, 3, 4)
        frames = frames.clamp(0, 1).float().cpu().permute(0, 2, 3, 4, 1).numpy()
return VideoOutput(frames)
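# Minimal usage sketch (assumes a pipeline built by initialize_text_to_video_model below,
# with `files_spec` standing in for the actual download specification):
#   pipe = initialize_text_to_video_model("weights/t2v", files_spec)
#   video = pipe("a dog running", steps=25, num_frames=24)
#   video.frames  # list of uint8 HxWx3 arrays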
def initialize_gpt2_model(folder, files):
download_files(folder, files)
config = GPT2Config()
model = GPT2LMHeadModel(config).to(device)
sd = torch.load(os.path.join(folder, sanitize_filename("gpt2-pytorch_model.bin")), map_location=device)
load_state_dict_safe(model, sd)
model.eval()
enc = read_json(os.path.join(folder, sanitize_filename("encoder.json")))
return model, enc
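# The initialize_* functions all follow the same pattern: download the listed files,
# build the stand-in architecture, load whatever checkpoint weights match via
# load_state_dict_safe, and return the model(s) in eval mode.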
def initialize_translation_model(folder, files):
download_files(folder, files)
config = MBartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
model = MBartForConditionalGeneration(config).to(device)
sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
load_state_dict_safe(model, sd)
model.eval()
vp = os.path.join(folder, "vocab.json")
if os.path.exists(vp):
vocab = read_json(vp)
model.tokenizer = lambda txt: [vocab.get(t, 0) for t in txt.split()]
else:
model.tokenizer = lambda txt: txt
model.config.lang_code_to_id = {'en_XX': 0, 'es_XX': 1}
return model
def initialize_codegen_model(folder, files):
download_files(folder, files)
config = CodeGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
model = CodeGenForCausalLM(config).to(device)
sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
load_state_dict_safe(model, sd)
model.eval()
tok = get_codegen_tokenizer(os.path.join(folder, "vocab.json"), os.path.join(folder, "merges.txt"))
vocab = read_json(os.path.join(folder, "vocab.json"))
idx2w = {v: k for k, v in vocab.items()}
model.tokenizer = tok
return model, tok, vocab, idx2w, vocab
def initialize_summarization_model(folder, files):
download_files(folder, files)
config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
model = BartForConditionalGeneration(config).to(device)
sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
load_state_dict_safe(model, sd)
model.eval()
vp = os.path.join(folder, "vocab.json")
if os.path.exists(vp):
vocab_json = read_json(vp)
vocab = set(vocab_json.keys())
return model, vocab, vocab_json, {v: k for k, v in vocab_json.items()}
return model, None, None, None
def initialize_imagegen_model(folder, files):
download_files(folder, files)
config = AutoencoderKLConfig.from_dict(read_json(os.path.join(folder, "config.json")))
vae = AutoencoderKL(config).to(device)
sd = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device)
load_state_dict_safe(vae, sd)
vae.eval()
return vae
def initialize_image_to_3d_model(folder, files):
download_files(folder, files)
config = OpenLRMConfig.from_dict(read_json(os.path.join(folder, "config.json")))
model3d = OpenLRM(config).to(device)
sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
load_state_dict_safe(model3d, sd)
model3d.eval()
return model3d
def initialize_text_to_video_model(folder, files):
download_files(folder, files)
unet_cfg = read_json(os.path.join(folder, "config.json"))
unet_cfg = filter_kwargs(VideoUNet, unet_cfg)
unet = VideoUNet(**unet_cfg).half().to(device)
sd_unet = torch.load(os.path.join(folder, "diffusion_pytorch_model.fp16.bin"), map_location=device)
load_state_dict_safe(unet, sd_unet)
unet.eval()
    vae_cfg = read_json(os.path.join(folder, "config.json"))
    # AutoencoderKL takes the whole config dict, so it is passed through without filter_kwargs.
    vae = AutoencoderKL(vae_cfg).half().to(device)
sd_vae = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device)
load_state_dict_safe(vae, sd_vae)
vae.eval()
vp = os.path.join(folder, "vocab.json")
text_vocab = read_json(vp) if os.path.exists(vp) else {}
    te_path = os.path.join(folder, "text_encoder.bin")
    text_encoder = SimpleTextEncoder(vocab_size=(max(text_vocab.values()) + 1) if text_vocab else 10000, embed_dim=768, max_length=77).to(device)
    if os.path.exists(te_path):
        sd_te = torch.load(te_path, map_location=device)
        load_state_dict_safe(text_encoder, sd_te)
    text_encoder.eval()
return VideoPipeline(unet, vae, text_encoder, text_vocab)
def initialize_sentiment_model(folder, files):
download_files(folder, files)
config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
model = SentimentClassifierModel(config).to(device)
sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
load_state_dict_safe(model, sd)
model.eval()
vp = os.path.join(folder, "vocab.json")
if os.path.exists(vp):
read_json(vp)
return model
def initialize_stt_model(folder, files):
download_files(folder, files)
config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
model = STTModel(config).to(device)
sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
load_state_dict_safe(model, sd)
model.eval()
vp = os.path.join(folder, "vocab.json")
if os.path.exists(vp):
read_json(vp)
return model
def initialize_tts_model(folder, files):
download_files(folder, files)
config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
model = TTSModel(config).to(device)
sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
load_state_dict_safe(model, sd)
model.eval()
vp = os.path.join(folder, "vocab.json")
if os.path.exists(vp):
read_json(vp)
return model
def initialize_musicgen_model(folder, files):
download_files(folder, files)
config = MusicGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
model = MusicGenModel(config).to(device)
sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
load_state_dict_safe(model, sd)
model.eval()
return model