from genconvit.genconvit_ed import GenConViTED
import torch
import torch.nn as nn
from transformers import AutoModel
from torchvision import transforms
import os
from huggingface_hub import hf_hub_download

device = "cuda" if torch.cuda.is_available() else "cpu"

# Shared on-disk cache directory for torch.hub and Hugging Face downloads.
CACHE_DIR = "./cache"

# NOTE: PYTHONOPTIMIZE is read at interpreter start-up, so assigning it here
# does not change the running interpreter; it is only inherited by child
# processes spawned afterwards.
os.environ['PYTHONOPTIMIZE'] = '0'
torch.hub.set_dir(CACHE_DIR)
# huggingface_hub resolves its cache location when it is imported (above), so
# this assignment alone does not redirect hf_hub_download. CACHE_DIR is
# therefore also passed explicitly in GenConViT._read_checkpoint.
os.environ["HUGGINGFACE_HUB_CACHE"] = CACHE_DIR


class GenConViT(nn.Module):
    """Wrapper that loads pretrained GenConViT sub-networks and runs them.

    Only the encoder-decoder (``GenConViTED``) branch is currently enabled.
    The VAE branch (``GenConViTVAE``, Hugging Face repo
    ``vivek-metaphy/genconvit-vae``) is disabled in this version.

    Args:
        ed: Checkpoint base name (without ``.pth``) for the ED model.
        vae: Checkpoint base name for the VAE model; currently unused.
        net: ``'ed'`` to run only the ED branch; any other value runs the
            combined path (see ``forward``).
        fp16: If true, loaded models are converted to half precision.
    """

    def __init__(self, ed, vae, net, fp16):
        super().__init__()
        self.net = net
        self.fp16 = fp16  # must be set before _load_model, which reads it

        # Both branches currently load only the ED model; the combined path
        # used to also load the VAE model, which is disabled for now.
        if self.net == 'ed':
            self.model_ed = self._load_model(ed, GenConViTED, 'vivek-metaphy/genconvit')
        else:
            self.model_ed = self._load_model(ed, GenConViTED, 'vivek-metaphy/genconvit')

    def _load_model(self, model_name, model_class, hf_model_name):
        """Instantiate ``model_class`` and load its pretrained weights.

        Weights come from ``pretrained_models/{model_name}.pth`` when that file
        exists; otherwise ``{model_name}.pth`` is downloaded from the Hugging
        Face repo ``hf_model_name``. The model is put in eval mode and, when
        ``self.fp16`` is set, converted to half precision.

        Args:
            model_name: Checkpoint base name (without ``.pth``).
            model_class: Zero-argument callable that builds the model.
            hf_model_name: Hugging Face repo id used when no local file exists.

        Returns:
            The loaded model on ``device``.

        Raises:
            RuntimeError: If building, downloading or loading fails; the
                original exception is chained as ``__cause__``.
        """
        try:
            model = model_class().to(device)
            checkpoint = self._read_checkpoint(model_name, hf_model_name)
            model.load_state_dict(self._extract_state_dict(checkpoint))
            model.eval()
            if self.fp16:
                model.half()
            return model
        except Exception as e:
            raise RuntimeError(f"Error loading model: {e}") from e

    @staticmethod
    def _read_checkpoint(model_name, hf_model_name):
        """Load the raw checkpoint object, preferring a local file over the Hub."""
        checkpoint_path = f'pretrained_models/{model_name}.pth'
        if not os.path.exists(checkpoint_path):
            print("Local model not found. Fetching from Hugging Face...")
            # hf_hub_download stores the file in the Hugging Face cache
            # (CACHE_DIR), not in pretrained_models/, and returns its path.
            checkpoint_path = hf_hub_download(
                repo_id=hf_model_name,
                filename=f'{model_name}.pth',
                cache_dir=CACHE_DIR,
            )
        # weights_only=True restricts unpickling to tensors and plain
        # containers, for local and downloaded files alike.
        return torch.load(checkpoint_path, map_location=device, weights_only=True)

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Return ``checkpoint['state_dict']`` if present, else the checkpoint itself."""
        if 'state_dict' in checkpoint:
            return checkpoint['state_dict']
        return checkpoint

    def forward(self, x):
        """Run the enabled sub-network(s) on ``x`` and return the output.

        For ``net == 'ed'`` this is the ED model's output. Otherwise the ED
        output is concatenated with itself along dim 0. That slot previously
        held the VAE output (``torch.cat((x1, x2), dim=0)``), so the output
        has twice the leading size of a single ED pass. This presumably keeps
        the shape downstream code expects — TODO confirm with callers.
        """
        if self.net == 'ed':
            x = self.model_ed(x)
        else:
            x1 = self.model_ed(x)
            x = torch.cat((x1, x1), dim=0)
        return x