"""Gradio demo for the HiDream-I1 text-to-image models."""

import gradio as gr
import spaces
import torch
from hi_diffusers import HiDreamImagePipeline, HiDreamImageTransformer2DModel
from hi_diffusers.schedulers.flash_flow_match import (
    FlashFlowMatchEulerDiscreteScheduler,
)
from hi_diffusers.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast

# Constants
MODEL_PREFIX: str = "HiDream-ai"
LLAMA_MODEL_NAME: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"
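# Note: the meta-llama checkpoint is gated on the Hugging Face Hub; an
# authenticated token with granted access is needed to download it.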

# Model configurations
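# Each entry pins a checkpoint path and its sampling defaults: the Dev and
# Fast variants sample without classifier-free guidance (guidance_scale=0.0),
# Full uses CFG at scale 5.0, and "shift" is the flow-matching timestep shift
# passed to the scheduler constructed in load_models().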
MODEL_CONFIGS: dict[str, dict] = {
    "dev": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Dev",
        "guidance_scale": 0.0,
        "num_inference_steps": 28,
        "shift": 6.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler,
    },
    "full": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Full",
        "guidance_scale": 5.0,
        "num_inference_steps": 50,
        "shift": 3.0,
        "scheduler": FlowUniPCMultistepScheduler,
    },
    "fast": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Fast",
        "guidance_scale": 0.0,
        "num_inference_steps": 16,
        "shift": 3.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler,
    },
}

# Supported image sizes (labels are "width × height")
RESOLUTION_OPTIONS: list[str] = [
    "1024 × 1024 (Square)",
    "768 × 1360 (Portrait)",
    "1360 × 768 (Landscape)",
    "880 × 1168 (Portrait)",
    "1168 × 880 (Landscape)",
    "1248 × 832 (Landscape)",
    "832 × 1248 (Portrait)",
]

# Model cache
loaded_models: dict[str, HiDreamImagePipeline] = {}
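# Loaded pipelines stay resident on the GPU, so selecting several variants in
# one session raises peak VRAM usage accordingly.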


def parse_resolution(res_str: str) -> tuple[int, int]:
    """Parse a resolution label like '768 × 1360 (Portrait)' into (768, 1360).

    Values are returned in label order (width, height). The trailing
    parenthetical such as "(Portrait)" is stripped first; without that,
    int() would fail on a token like "1360(Portrait)".
    """
    dims = res_str.split("(")[0].replace("×", "x").replace(" ", "").split("x")
    return tuple(map(int, dims))


def load_models(model_type: str) -> HiDreamImagePipeline:
    """Load and initialize the HiDream model pipeline for a given model type."""
    config = MODEL_CONFIGS[model_type]
    pretrained_model = config["path"]
    # PreTrainedTokenizerFast is already the "fast" implementation, so no
    # use_fast flag is needed here.
    tokenizer = PreTrainedTokenizerFast.from_pretrained(LLAMA_MODEL_NAME)
    text_encoder = LlamaForCausalLM.from_pretrained(
        LLAMA_MODEL_NAME,
        output_hidden_states=True,
        output_attentions=True,
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    transformer = HiDreamImageTransformer2DModel.from_pretrained(
        pretrained_model,
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    scheduler = config["scheduler"](
        num_train_timesteps=1000,
        shift=config["shift"],
        use_dynamic_shifting=False,
    )
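    # Each variant gets a fresh scheduler with its own shift value; dynamic
    # (resolution-dependent) shifting is disabled.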
    pipe = HiDreamImagePipeline.from_pretrained(
        pretrained_model,
        scheduler=scheduler,
        tokenizer_4=tokenizer,
        text_encoder_4=text_encoder,
        torch_dtype=torch.bfloat16,
    ).to("cuda", torch.bfloat16)
    pipe.transformer = transformer
    return pipe


# Preload default model
print("🔧 Preloading default model (full)...")
loaded_models["full"] = load_models("full")
print("✅ Model loaded.")
@spaces.GPU(duration=90)
def generate_image(
    model_type: str,
    prompt: str,
    resolution: str,
    seed: int,
) -> tuple[object, int]:
"""Generate image using HiDream pipeline."""
if model_type not in loaded_models:
print(f"📦 Lazy-loading model {model_type}...")
loaded_models[model_type] = load_models(model_type)
pipe: HiDreamImagePipeline = loaded_models[model_type]
config = MODEL_CONFIGS[model_type]
if seed == -1:
seed = torch.randint(0, 1_000_000, (1,)).item()
    # Resolution labels are "width × height", so unpack in that order
    # (e.g. "768 × 1360 (Portrait)" is 768 wide by 1360 tall).
    width, height = parse_resolution(resolution)
    generator = torch.Generator("cuda").manual_seed(seed)
    image = pipe(
        prompt=prompt,
        height=height,
        width=width,
        guidance_scale=config["guidance_scale"],
        num_inference_steps=config["num_inference_steps"],
        generator=generator,
    ).images[0]
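    # Free cached CUDA memory between requests; other pipeline variants may
    # still be resident on the GPU.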
    torch.cuda.empty_cache()
    return image, seed


# Gradio UI
with gr.Blocks(title="HiDream Image Generator") as demo:
    gr.Markdown("## 🌈 HiDream Image Generator")
    with gr.Row():
        with gr.Column():
            model_type = gr.Radio(
                choices=list(MODEL_CONFIGS.keys()),
                value="full",
                label="Model Type",
                info="Choose between the full, fast, or dev variants",
            )
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="e.g. A futuristic city with floating cars at sunset",
                lines=3,
            )
            resolution = gr.Radio(
                choices=RESOLUTION_OPTIONS,
                value=RESOLUTION_OPTIONS[0],
                label="Resolution",
            )
            seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
            generate_btn = gr.Button("Generate Image", variant="primary")
            seed_used = gr.Number(label="Seed Used", interactive=False)
        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="pil")

    generate_btn.click(
        fn=generate_image,
        inputs=[model_type, prompt, resolution, seed],
        outputs=[output_image, seed_used],
    )


if __name__ == "__main__":
    demo.launch()