import os
import urllib.request


def download_models():
    """Download the pretrained GenConViT ED and VAE checkpoints if missing.

    Each checkpoint is fetched from Hugging Face into ``./pretrained_models``
    only when the target file does not already exist. Downloads are written
    to a temporary ``<path>.part`` file and moved into place only after they
    complete. An interrupted or short download therefore never leaves a
    truncated checkpoint that later runs would mistake for a finished one.
    """
    ED_MODEL_URL = "https://huggingface.co/Deressa/GenConViT/resolve/main/genconvit_ed_inference.pth"
    VAE_MODEL_URL = "https://huggingface.co/Deressa/GenConViT/resolve/main/genconvit_vae_inference.pth"
    ED_MODEL_PATH = "./pretrained_models/genconvit_ed_inference.pth"
    VAE_MODEL_PATH = "./pretrained_models/genconvit_vae_inference.pth"

    os.makedirs("pretrained_models", exist_ok=True)

    def progress(block_num, block_size, total_size):
        # urlretrieve reports the count of blocks read so far. The final block
        # is usually partial, so clamp to avoid reporting more than 100%.
        progress_amount = block_num * block_size
        if total_size > 0:
            percent = min(progress_amount / total_size, 1.0) * 100
            print(f"Downloading... {percent:.2f}%")

    def fetch(url, path):
        # Download into a side file, then atomically rename it. urlretrieve
        # leaves partial data on disk when it fails (e.g. ContentTooShortError
        # or an interrupt). Writing it straight to `path` would make the
        # isfile() check below skip the download forever.
        tmp_path = path + ".part"
        try:
            urllib.request.urlretrieve(url, tmp_path, reporthook=progress)
        except BaseException:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        os.replace(tmp_path, path)

    if not os.path.isfile(ED_MODEL_PATH):
        print("Downloading ED model")
        fetch(ED_MODEL_URL, ED_MODEL_PATH)

    if not os.path.isfile(VAE_MODEL_PATH):
        print("Downloading VAE model")
        fetch(VAE_MODEL_URL, VAE_MODEL_PATH)


download_models()