LPX55's picture
fix: structruring
f7aaef1
raw
history blame
3.38 kB
import os
import spaces
import gradio as gr
import torch
import logging
from diffusers import DiffusionPipeline
from transformer_hidream_image import HiDreamImageTransformer2DModel
from pipeline_hidream_image import HiDreamImagePipeline
import subprocess
try:
print(subprocess.check_output(["nvcc", "--version"]).decode("utf-8"))
except:
print("nvcc version check error")
# subprocess.run('python -m pip install flash-attn --no-build-isolation', shell=True)
from nf4 import *
# Resolution options
RESOLUTION_OPTIONS = [
"1024 × 1024 (Square)",
"768 × 1360 (Portrait)",
"1360 × 768 (Landscape)",
"880 × 1168 (Portrait)",
"1168 × 880 (Landscape)",
"1248 × 832 (Landscape)",
"832 × 1248 (Portrait)"
]
# Parse resolution string to get height and width
def parse_resolution(resolution_str):
return tuple(map(int, resolution_str.split("(")[0].strip().split(" × ")))
@spaces.GPU()
def gen_img_helper(model, prompt, res, seed):
global pipe, current_model
# 1. Check if the model matches loaded model, load the model if not
if model != current_model:
print(f"Unloading model {current_model}...")
del pipe
torch.cuda.empty_cache()
print(f"Loading model {model}...")
pipe, _ = load_models(model)
current_model = model
print("Model loaded successfully!")
# 2. Generate image
res = parse_resolution(res)
return generate_image(pipe, model, prompt, res, seed)
if __name__ == "__main__":
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
# Initialize with default model
print("Loading default model (fast)...")
current_model = "fast"
pipe, _ = load_models(current_model)
print("Model loaded successfully!")
# Create Gradio interface
with gr.Blocks(title="HiDream-I1-nf4 Dashboard") as demo:
gr.Markdown("# HiDream-I1-nf4 Dashboard")
with gr.Row():
with gr.Column():
model_type = gr.Radio(
choices=list(MODEL_CONFIGS.keys()),
value="fast",
label="Model Type",
info="Select model variant"
)
prompt = gr.Textbox(
label="Prompt",
placeholder="A cat holding a sign that says \"Hi-Dreams.ai\".",
lines=3
)
resolution = gr.Radio(
choices=RESOLUTION_OPTIONS,
value=RESOLUTION_OPTIONS[0],
label="Resolution",
info="Select image resolution"
)
seed = gr.Number(
label="Seed (use -1 for random)",
value=-1,
precision=0
)
generate_btn = gr.Button("Generate Image")
seed_used = gr.Number(label="Seed Used", interactive=False)
with gr.Column():
output_image = gr.Image(label="Generated Image", type="pil")
generate_btn.click(
fn=gen_img_helper,
inputs=[model_type, prompt, resolution, seed],
outputs=[output_image, seed_used]
)
demo.launch()