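# Gradio dashboard for generating images with the nf4-quantized HiDream-I1 pipeline.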
import os
import spaces
import gradio as gr
import torch
import logging
from diffusers import DiffusionPipeline
from transformer_hidream_image import HiDreamImageTransformer2DModel
from pipeline_hidream_image import HiDreamImagePipeline
import subprocess
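# Print the CUDA toolkit version at startup for debugging; the try/except below
# lets the app keep starting even if nvcc is not on PATH.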
try:
    print(subprocess.check_output(["nvcc", "--version"]).decode("utf-8"))
except Exception:
    print("nvcc version check error")

# subprocess.run('python -m pip install flash-attn --no-build-isolation', shell=True)

from nf4 import *
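# The star import above is expected to provide load_models(), generate_image(),
# and MODEL_CONFIGS, which are used below but defined nowhere else in this file.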
# Resolution options
RESOLUTION_OPTIONS = [
    "1024 × 1024 (Square)",
    "768 × 1360 (Portrait)",
    "1360 × 768 (Landscape)",
    "880 × 1168 (Portrait)",
    "1168 × 880 (Landscape)",
    "1248 × 832 (Landscape)",
    "832 × 1248 (Portrait)"
]
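# Each option is formatted as "<width> × <height> (<orientation>)"; only the
# numeric part is parsed and the orientation label is dropped.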
# Parse resolution string to get height and width
def parse_resolution(resolution_str):
    return tuple(map(int, resolution_str.split("(")[0].strip().split(" × ")))
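# e.g. parse_resolution("768 × 1360 (Portrait)") -> (768, 1360)

# On Hugging Face ZeroGPU Spaces, @spaces.GPU() requests a GPU for the duration
# of each decorated call; outside Spaces it is effectively a no-op.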
@spaces.GPU()
def gen_img_helper(model, prompt, res, seed):
    global pipe, current_model

    # 1. If the requested model differs from the loaded one, swap models first
    if model != current_model:
        print(f"Unloading model {current_model}...")
        del pipe
        torch.cuda.empty_cache()

        print(f"Loading model {model}...")
        pipe, _ = load_models(model)
        current_model = model
        print("Model loaded successfully!")

    # 2. Generate the image at the requested resolution
    res = parse_resolution(res)
    return generate_image(pipe, model, prompt, res, seed)
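
# `pipe` and `current_model` are module-level globals that gen_img_helper()
# swaps out when the user picks a different model variant. Assuming
# generate_image() returns (image, seed_used), as the click wiring below implies,
# a direct call would look like (hypothetical, bypassing the UI):
#   img, used_seed = gen_img_helper("fast", "a red fox", "1024 × 1024 (Square)", -1)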
if __name__ == "__main__":
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

    # Initialize with default model
    print("Loading default model (fast)...")
    current_model = "fast"
    pipe, _ = load_models(current_model)
    print("Model loaded successfully!")

    # Create Gradio interface
    with gr.Blocks(title="HiDream-I1-nf4 Dashboard") as demo:
        gr.Markdown("# HiDream-I1-nf4 Dashboard")

        with gr.Row():
            with gr.Column():
                model_type = gr.Radio(
                    choices=list(MODEL_CONFIGS.keys()),
                    value="fast",
                    label="Model Type",
                    info="Select model variant"
                )
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="A cat holding a sign that says \"Hi-Dreams.ai\".",
                    lines=3
                )
                resolution = gr.Radio(
                    choices=RESOLUTION_OPTIONS,
                    value=RESOLUTION_OPTIONS[0],
                    label="Resolution",
                    info="Select image resolution"
                )
                seed = gr.Number(
                    label="Seed (use -1 for random)",
                    value=-1,
                    precision=0
                )
                generate_btn = gr.Button("Generate Image")
                seed_used = gr.Number(label="Seed Used", interactive=False)

            with gr.Column():
                output_image = gr.Image(label="Generated Image", type="pil")

        generate_btn.click(
            fn=gen_img_helper,
            inputs=[model_type, prompt, resolution, seed],
            outputs=[output_image, seed_used]
        )

    demo.launch()