# Source: Hugging Face Spaces app page (Space status at capture time: Paused)
import gradio as gr
# NOTE(review): `spaces` is not referenced in the code visible here; presumably
# it is imported for its import-time effect on Spaces hardware — confirm before
# removing or reordering it relative to `torch`.
import spaces
import torch
from hi_diffusers import HiDreamImagePipeline, HiDreamImageTransformer2DModel
from hi_diffusers.schedulers.flash_flow_match import (
    FlashFlowMatchEulerDiscreteScheduler,
)
from hi_diffusers.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
# Constants
# Hub namespace prepended to every HiDream checkpoint path in MODEL_CONFIGS.
MODEL_PREFIX: str = "HiDream-ai"
# Hub id of the Llama checkpoint loaded as the pipeline's tokenizer_4 /
# text_encoder_4 components.
LLAMA_MODEL_NAME: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# Model configurations
# One entry per selectable variant. Keys of each entry:
#   path                -- hub id of the checkpoint
#   guidance_scale      -- forwarded to the pipeline call
#   num_inference_steps -- forwarded to the pipeline call
#   shift               -- forwarded to the scheduler constructor
#   scheduler           -- scheduler class instantiated when the variant loads
MODEL_CONFIGS: dict[str, dict] = {
    "dev": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Dev",
        "guidance_scale": 0.0,
        "num_inference_steps": 28,
        "shift": 6.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler,
    },
    "full": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Full",
        "guidance_scale": 5.0,
        "num_inference_steps": 50,
        "shift": 3.0,
        "scheduler": FlowUniPCMultistepScheduler,
    },
    "fast": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Fast",
        "guidance_scale": 0.0,
        "num_inference_steps": 16,
        "shift": 3.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler,
    },
}
# Supported image sizes
# Labels shown verbatim in the UI: two integers separated by " × ", followed by
# an orientation tag in parentheses. The tags (e.g. "768 × 1360 (Portrait)")
# imply the numbers are written width × height.
RESOLUTION_OPTIONS: list[str] = [
    "1024 × 1024 (Square)",
    "768 × 1360 (Portrait)",
    "1360 × 768 (Landscape)",
    "880 × 1168 (Portrait)",
    "1168 × 880 (Landscape)",
    "1248 × 832 (Landscape)",
    "832 × 1248 (Portrait)",
]
# Model cache
# Maps a MODEL_CONFIGS key to its already-built pipeline so each variant is
# constructed at most once per process.
loaded_models: dict[str, HiDreamImagePipeline] = {}
def parse_resolution(res_str: str) -> tuple[int, int]:
    """Parse a resolution label into its two integer dimensions.

    Accepts the bare form ``'1024 × 1024'`` as well as labels that carry a
    trailing parenthesised tag such as ``'768 × 1360 (Portrait)'``. The
    separator may be ``×``, ``x`` or ``X``; spaces are ignored.

    Args:
        res_str: Resolution label to parse.

    Returns:
        The two dimensions in the order they are written in the label.

    Raises:
        ValueError: If the label does not contain exactly two integers.
    """
    # Drop the "(Square)" / "(Portrait)" style tag; feeding it to int() is
    # what made every UI label fail to parse.
    dims_text = res_str.split("(", 1)[0]
    compact = dims_text.replace("×", "x").replace(" ", "").lower()
    pieces = compact.split("x")
    if len(pieces) != 2:
        raise ValueError(f"Cannot parse resolution: {res_str!r}")
    try:
        first, second = int(pieces[0]), int(pieces[1])
    except ValueError:
        raise ValueError(f"Cannot parse resolution: {res_str!r}") from None
    return first, second
# Tokenizer / text-encoder pairs that have already been loaded, keyed by hub
# id. Every variant uses the same Llama checkpoint, so the pair is shared
# rather than loaded onto the GPU again for each pipeline.
_llama_components: dict[str, tuple] = {}


def _get_llama_components() -> tuple:
    """Return ``(tokenizer, text_encoder)`` for LLAMA_MODEL_NAME, loading once."""
    if LLAMA_MODEL_NAME not in _llama_components:
        tokenizer = PreTrainedTokenizerFast.from_pretrained(
            LLAMA_MODEL_NAME, use_fast=False
        )
        text_encoder = LlamaForCausalLM.from_pretrained(
            LLAMA_MODEL_NAME,
            output_hidden_states=True,
            output_attentions=True,
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        _llama_components[LLAMA_MODEL_NAME] = (tokenizer, text_encoder)
    return _llama_components[LLAMA_MODEL_NAME]


def load_models(model_type: str) -> HiDreamImagePipeline:
    """Load and initialize the HiDream model pipeline for a given model type.

    Args:
        model_type: Key into MODEL_CONFIGS ("dev", "full" or "fast").

    Returns:
        A pipeline placed on the CUDA device in bfloat16, with the variant's
        scheduler and transformer attached.

    Raises:
        KeyError: If ``model_type`` is not a key of MODEL_CONFIGS.
    """
    config = MODEL_CONFIGS[model_type]
    pretrained_model = config["path"]

    tokenizer, text_encoder = _get_llama_components()

    transformer = HiDreamImageTransformer2DModel.from_pretrained(
        pretrained_model,
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # The scheduler class and its shift value are chosen per variant.
    scheduler = config["scheduler"](
        num_train_timesteps=1000,
        shift=config["shift"],
        use_dynamic_shifting=False,
    )

    pipe = HiDreamImagePipeline.from_pretrained(
        pretrained_model,
        scheduler=scheduler,
        tokenizer_4=tokenizer,
        text_encoder_4=text_encoder,
        torch_dtype=torch.bfloat16,
    ).to("cuda", torch.bfloat16)
    # The transformer is loaded separately above and attached after the
    # pipeline has been constructed.
    pipe.transformer = transformer
    return pipe
# Preload default model
# These statements run at import time, so the "full" pipeline is already in
# the cache before the UI is built.
print("🔧 Preloading default model (full)...")
loaded_models["full"] = load_models("full")
print("✅ Model loaded.")
def generate_image(
    model_type: str,
    prompt: str,
    resolution: str,
    seed: int,
) -> tuple[object, int]:
    """Generate image using HiDream pipeline.

    Args:
        model_type: Key into MODEL_CONFIGS; the pipeline is built on first use
            and cached in ``loaded_models``.
        prompt: Text prompt passed to the pipeline.
        resolution: Resolution label, e.g. one of RESOLUTION_OPTIONS.
        seed: Seed for the CUDA generator. ``-1`` (or an empty field, which
            arrives as ``None``) picks a random seed in ``[0, 1_000_000)``.

    Returns:
        ``(image, seed)`` where ``image`` is the first image the pipeline
        produced and ``seed`` is the seed actually used.
    """
    if model_type not in loaded_models:
        print(f"📦 Lazy-loading model {model_type}...")
        loaded_models[model_type] = load_models(model_type)
    pipe: HiDreamImagePipeline = loaded_models[model_type]
    config = MODEL_CONFIGS[model_type]

    if seed is None or seed == -1:
        seed = torch.randint(0, 1_000_000, (1,)).item()
    # manual_seed needs a real int; the UI may hand over a float-typed value.
    seed = int(seed)

    # Labels are written width × height ("768 × 1360 (Portrait)" is taller
    # than wide), so the first number is the width.
    width, height = parse_resolution(resolution)
    generator = torch.Generator("cuda").manual_seed(seed)
    try:
        image = pipe(
            prompt=prompt,
            height=height,
            width=width,
            guidance_scale=config["guidance_scale"],
            num_inference_steps=config["num_inference_steps"],
            generator=generator,
        ).images[0]
    finally:
        # Release cached CUDA memory even when the pipeline call fails.
        torch.cuda.empty_cache()
    return image, seed
# Gradio UI
# Two-column layout: inputs and the seed read-out on the left, the generated
# image on the right.
with gr.Blocks(title="HiDream Image Generator") as demo:
    gr.Markdown("## 🌈 HiDream Image Generator")
    with gr.Row():
        with gr.Column():
            model_type = gr.Radio(
                choices=list(MODEL_CONFIGS.keys()),
                value="full",
                label="Model Type",
                info="Choose between full, fast or dev variants",
            )
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="e.g. A futuristic city with floating cars at sunset",
                lines=3,
            )
            resolution = gr.Radio(
                choices=RESOLUTION_OPTIONS,
                value=RESOLUTION_OPTIONS[0],
                label="Resolution",
            )
            seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
            generate_btn = gr.Button("Generate Image", variant="primary")
            seed_used = gr.Number(label="Seed Used", interactive=False)
        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="pil")
    # Input order must match generate_image's parameters; output order must
    # match its (image, seed) return value.
    generate_btn.click(
        fn=generate_image,
        inputs=[model_type, prompt, resolution, seed],
        outputs=[output_image, seed_used],
    )
# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()