import os
import sys
import subprocess
import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import HfApi, login
from huggingface_hub.utils import validate_repo_id, HfHubHTTPError
# ---------------------- DEPENDENCIES ----------------------
def install_dependencies_gradio():
    """Installs the necessary dependencies for the Gradio app. Run this ONCE."""
    try:
        subprocess.check_call([
            sys.executable, "-m", "pip", "install", "-U",
            "torch", "diffusers", "transformers", "accelerate",
            "safetensors", "huggingface_hub", "xformers",
        ])
        print("Dependencies installed successfully.")
    except Exception as e:
        print(f"Error installing dependencies: {e}")
# ---------------------- UTILITY FUNCTIONS ----------------------
def get_save_dtype(save_precision_as):
"""Determines the save dtype based on the user's choice."""
if save_precision_as == "fp16":
return torch.float16
elif save_precision_as == "bf16":
return torch.bfloat16
elif save_precision_as == "float":
return torch.float32 # Using float32 for "float" option
else:
return None
def determine_load_checkpoint(model_to_load):
    """Returns True for a checkpoint file, False for a Diffusers model directory."""
    if model_to_load.endswith(('.ckpt', '.safetensors')):
        return True
    if os.path.isdir(model_to_load) and is_diffusers_model(model_to_load):
        return False
    raise ValueError(f"Unrecognized model format: {model_to_load}")
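# Example: determine_load_checkpoint("model.safetensors") -> True; a directory
# containing unet/, vae/, ... plus model_index.json -> False; anything else raises.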
def increment_filename(filename):
"""Increments the filename to avoid overwriting existing files."""
base, ext = os.path.splitext(filename)
counter = 1
while os.path.exists(filename):
filename = f"{base}({counter}){ext}"
counter += 1
return filename
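# Example: if "model.safetensors" already exists on disk,
# increment_filename("model.safetensors") returns "model(1).safetensors".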
def create_model_repo(api, user, orgs_name, model_name, make_private=False):
    """Creates a Hugging Face model repository if it doesn't already exist."""
    owner = orgs_name if orgs_name else user["name"]
    repo_id = f"{owner}/{model_name.strip()}"
    try:
        validate_repo_id(repo_id)
        api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
        print(f"Model repo '{repo_id}' didn't exist; created it.")
    except HfHubHTTPError:
        print(f"Model repo '{repo_id}' already exists; skipping creation.")
    print(f"Model repo '{repo_id}' link: https://huggingface.co/{repo_id}\n")
    return repo_id
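# Note: recent huggingface_hub releases also accept exist_ok=True on
# create_repo, which would replace the try/except above.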
def is_diffusers_model(model_path):
    """Checks whether a path is a valid SDXL Diffusers model directory."""
    if not os.path.isdir(model_path):
        return False
    required_folders = {"unet", "text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2", "scheduler", "vae"}
    return required_folders.issubset(set(os.listdir(model_path))) and os.path.isfile(os.path.join(model_path, "model_index.json"))
# ---------------------- CONVERSION AND UPLOAD FUNCTIONS ----------------------
def load_sdxl_model(args, is_load_checkpoint, load_dtype):
    """Loads the SDXL model from a checkpoint file or a Diffusers directory."""
    model_load_message = "checkpoint" if is_load_checkpoint else ("Diffusers" + (" as fp16" if args.fp16 else ""))
    print(f"Loading {model_load_message}: {args.model_to_load}")
    if is_load_checkpoint:
        return load_from_sdxl_checkpoint(args, load_dtype)
    return load_sdxl_from_diffusers(args, load_dtype)
def load_from_sdxl_checkpoint(args, load_dtype):
    """Loads the SDXL model components from a .ckpt/.safetensors checkpoint.

    A minimal sketch using diffusers' single-file loader; it assumes a
    diffusers release that provides StableDiffusionXLPipeline.from_single_file.
    """
    pipeline = StableDiffusionXLPipeline.from_single_file(
        args.model_to_load, torch_dtype=load_dtype
    )
    return pipeline.text_encoder, pipeline.text_encoder_2, pipeline.vae, pipeline.unet
def load_sdxl_from_diffusers(args, load_dtype):
"""Loads an SDXL model from a Diffusers model directory."""
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, tokenizer_2=None, scheduler=None
)
text_encoder1 = pipeline.text_encoder
text_encoder2 = pipeline.text_encoder_2
vae = pipeline.vae
unet = pipeline.unet
return text_encoder1, text_encoder2, vae, unet
def convert_and_save_sdxl_model(args, is_save_checkpoint, loaded_model_data, save_dtype):
    """Converts and saves the SDXL model as either a checkpoint or a Diffusers model."""
    text_encoder1, text_encoder2, vae, unet = loaded_model_data
    model_save_message = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_checkpoint else "Diffusers"
    print(f"Converting and saving as {model_save_message}: {args.model_to_save}")
    if is_save_checkpoint:
        save_sdxl_as_checkpoint(args, text_encoder1, text_encoder2, vae, unet, save_dtype)
    else:
        save_sdxl_as_diffusers(args, text_encoder1, text_encoder2, vae, unet, save_dtype)
def save_sdxl_as_checkpoint(args, text_encoder1, text_encoder2, vae, unet, save_dtype):
    """Saves the SDXL model components as a single checkpoint file.

    A minimal sketch: it assumes kohya-ss/sd-scripts is on the Python path so
    that `library.sdxl_model_util` is importable, matching the call that was
    previously left commented out here.
    """
    try:
        from library import sdxl_model_util  # provided by kohya-ss/sd-scripts
    except ImportError:
        print("Saving as Checkpoint requires kohya-ss/sd-scripts (library.sdxl_model_util); skipping.")
        return
    logit_scale = None
    ckpt_info = None
    key_count = sdxl_model_util.save_stable_diffusion_checkpoint(
        args.model_to_save, text_encoder1, text_encoder2, unet,
        args.epoch, args.global_step, ckpt_info, vae, logit_scale, save_dtype
    )
    print(f"Model saved. Total converted state_dict keys: {key_count}")
def save_sdxl_as_diffusers(args, text_encoder1, text_encoder2, vae, unet, save_dtype):
    """Saves the SDXL model as a Diffusers pipeline directory."""
    # Fall back to the SDXL base repo when no reference model is given; this
    # default is an assumption, the original code only printed 'default model'.
    reference_model = args.reference_model if args.reference_model else "stabilityai/stable-diffusion-xl-base-1.0"
    print(f"Copying scheduler/tokenizer config from: {reference_model}")
    # Load only the scheduler and tokenizers from the reference pipeline;
    # passing None for the heavy components should make diffusers skip
    # loading their weights.
    reference = StableDiffusionXLPipeline.from_pretrained(
        reference_model, unet=None, vae=None, text_encoder=None, text_encoder_2=None
    )
    if save_dtype is not None:
        # Cast the weights to the requested save precision before writing.
        for module in (text_encoder1, text_encoder2, vae, unet):
            if module is not None:
                module.to(dtype=save_dtype)
    pipeline = StableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoder1,
        text_encoder_2=text_encoder2,
        unet=unet,
        scheduler=reference.scheduler,
        tokenizer=reference.tokenizer,
        tokenizer_2=reference.tokenizer_2,
    )
    pipeline.save_pretrained(args.model_to_save)
    print(f"Diffusers model saved to {args.model_to_save}.")
def convert_model(model_to_load, save_precision_as, epoch, global_step, reference_model, output_path, fp16):
    """Main conversion function."""
    class Args:  # Lightweight argument container, local to convert_model
        def __init__(self, model_to_load, save_precision_as, epoch, global_step, reference_model, output_path, fp16):
            self.model_to_load = model_to_load
            self.save_precision_as = save_precision_as
            self.epoch = epoch
            self.global_step = global_step
            self.reference_model = reference_model
            self.output_path = output_path
            self.fp16 = fp16
    args = Args(model_to_load, save_precision_as, epoch, global_step, reference_model, output_path, fp16)
    try:
        load_dtype = torch.float16 if fp16 else None
        save_dtype = get_save_dtype(save_precision_as)
        is_load_checkpoint = determine_load_checkpoint(model_to_load)
        is_save_checkpoint = not is_load_checkpoint  # save in the opposite format
        # Write the result into output_path so the upload step can find it.
        os.makedirs(output_path, exist_ok=True)
        if is_save_checkpoint:
            base = os.path.splitext(os.path.basename(os.path.normpath(model_to_load)))[0]
            args.model_to_save = increment_filename(os.path.join(output_path, base + ".safetensors"))
        else:
            args.model_to_save = output_path
        loaded_model_data = load_sdxl_model(args, is_load_checkpoint, load_dtype)
        convert_and_save_sdxl_model(args, is_save_checkpoint, loaded_model_data, save_dtype)
        return f"Conversion complete. Model saved to {args.model_to_save}"
    except Exception as e:
        return f"Conversion failed: {e}"
def upload_to_huggingface(model_path, hf_token, orgs_name, model_name, make_private):
"""Uploads a model to the Hugging Face Hub."""
try:
login(hf_token, add_to_git_credential=True)
api = HfApi()
user = api.whoami(hf_token)
model_repo = create_model_repo(api, user, orgs_name, model_name, make_private)
        # Everything is uploaded to the repo root; change path_in_repo to
        # target a subfolder instead.
        path_in_repo = ""
        trained_model = os.path.basename(model_path)
notification = f"Uploading {trained_model} from {model_path} to https://huggingface.co/{model_repo}"
with output_widget:
print(notification)
if os.path.isdir(model_path):
if is_diffusers_model(model_path):
commit_message = f"Upload diffusers format: {trained_model}"
print("Detected diffusers model. Adjusting upload parameters.")
else:
commit_message = f"Upload checkpoint: {trained_model}"
print("Detected regular model. Adjusting upload parameters.")
api.upload_folder(
folder_path=model_path,
                path_in_repo=path_in_repo,
repo_id=model_repo,
commit_message=commit_message,
ignore_patterns=".ipynb_checkpoints",
)
else:
commit_message = f"Upload file: {trained_model}"
api.upload_file(
path_or_fileobj=model_path,
                path_in_repo=trained_model,  # a single file needs a target filename in the repo
repo_id=model_repo,
commit_message=commit_message,
)
        return f"Model upload complete! Check it out at https://huggingface.co/{model_repo}/tree/main"
    except Exception as e:
        return f"Upload failed: {e}"
# ---------------------- GRADIO INTERFACE ----------------------
def main(model_to_load, save_precision_as, epoch, global_step, reference_model, output_path, fp16, hf_token, orgs_name, model_name, make_private):
    """Main function orchestrating the conversion and the upload."""
    conversion_output = convert_model(model_to_load, save_precision_as, epoch, global_step, reference_model, output_path, fp16)
    upload_output = upload_to_huggingface(output_path, hf_token, orgs_name, model_name, make_private)
    # Return a combined status message for the Gradio output component
    return f"{conversion_output}\n\n{upload_output}"
with gr.Blocks() as demo:
# Add initial warnings (only once)
gr.Markdown("""
## **⚠️ IMPORTANT WARNINGS ⚠️**
This app may violate Google Colab AUP. Use at your own risk. `xformers` may cause issues.
""")
model_to_load = gr.Textbox(label="Model to Load (Checkpoint or Diffusers)", placeholder="Path to model")
with gr.Row():
save_precision_as = gr.Dropdown(
choices=["fp16", "bf16", "float"], value="fp16", label="Save Precision As"
)
fp16 = gr.Checkbox(label="Load as fp16 (Diffusers only)")
with gr.Row():
epoch = gr.Number(value=0, label="Epoch to Write (Checkpoint)")
global_step = gr.Number(value=0, label="Global Step to Write (Checkpoint)")
reference_model = gr.Textbox(label="Reference Diffusers Model",
placeholder="e.g., stabilityai/stable-diffusion-xl-base-1.0")
output_path = gr.Textbox(label="Output Path", value="/content/output")
gr.Markdown("## Hugging Face Hub Configuration")
hf_token = gr.Textbox(label="Hugging Face Token", placeholder="Your Hugging Face write token")
with gr.Row():
orgs_name = gr.Textbox(label="Organization Name (Optional)", placeholder="Your organization name")
model_name = gr.Textbox(label="Model Name", placeholder="The name of your model on Hugging Face")
make_private = gr.Checkbox(label="Make Repository Private", value=False)
convert_button = gr.Button("Convert and Upload")
output = gr.Markdown()
convert_button.click(fn=main,
inputs=[model_to_load, save_precision_as, epoch, global_step, reference_model,
output_path, fp16, hf_token, orgs_name, model_name, make_private],
outputs=output)
demo.launch()
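# To expose the app publicly (e.g. when running from Colab), pass a share link:
# demo.launch(share=True)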