import gradio as gr
import spaces
import torch
from hi_diffusers import HiDreamImagePipeline, HiDreamImageTransformer2DModel
from hi_diffusers.schedulers.flash_flow_match import (
    FlashFlowMatchEulerDiscreteScheduler,
)
from hi_diffusers.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast

# Constants
MODEL_PREFIX: str = "HiDream-ai"
LLAMA_MODEL_NAME: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Model configurations: the "dev" and "fast" variants run without
# classifier-free guidance (guidance_scale 0.0) and with fewer steps,
# while "full" uses CFG and 50 steps.
MODEL_CONFIGS: dict[str, dict] = {
    "dev": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Dev",
        "guidance_scale": 0.0,
        "num_inference_steps": 28,
        "shift": 6.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler,
    },
    "full": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Full",
        "guidance_scale": 5.0,
        "num_inference_steps": 50,
        "shift": 3.0,
        "scheduler": FlowUniPCMultistepScheduler,
    },
    "fast": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Fast",
        "guidance_scale": 0.0,
        "num_inference_steps": 16,
        "shift": 3.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler,
    },
}

# Supported image sizes, written as "width × height (orientation)"
RESOLUTION_OPTIONS: list[str] = [
    "1024 × 1024 (Square)",
    "768 × 1360 (Portrait)",
    "1360 × 768 (Landscape)",
    "880 × 1168 (Portrait)",
    "1168 × 880 (Landscape)",
    "1248 × 832 (Landscape)",
    "832 × 1248 (Portrait)",
]

# Model cache: pipelines are loaded once and reused across requests
loaded_models: dict[str, HiDreamImagePipeline] = {}


def parse_resolution(res_str: str) -> tuple[int, int]:
    """Parse a label like '1024 × 1024 (Square)' into (width, height).

    The orientation suffix must be stripped before splitting; otherwise
    int() would fail on a token like '1024(Square)'.
    """
    dims = res_str.split("(")[0].replace("×", "x").replace(" ", "").split("x")
    return tuple(map(int, dims))


def load_models(model_type: str) -> HiDreamImagePipeline:
    """Load and initialize the HiDream model pipeline for a given model type."""
    config = MODEL_CONFIGS[model_type]
    pretrained_model = config["path"]

    # PreTrainedTokenizerFast is already the fast implementation, so the
    # contradictory `use_fast=False` flag has been dropped.
    tokenizer = PreTrainedTokenizerFast.from_pretrained(LLAMA_MODEL_NAME)

    # Llama 3.1 acts as the pipeline's fourth text encoder; its hidden
    # states are used as conditioning, hence output_hidden_states=True.
    text_encoder = LlamaForCausalLM.from_pretrained(
        LLAMA_MODEL_NAME,
        output_hidden_states=True,
        output_attentions=True,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    transformer = HiDreamImageTransformer2DModel.from_pretrained(
        pretrained_model,
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    scheduler = config["scheduler"](
        num_train_timesteps=1000,
        shift=config["shift"],
        use_dynamic_shifting=False,
    )

    pipe = HiDreamImagePipeline.from_pretrained(
        pretrained_model,
        scheduler=scheduler,
        tokenizer_4=tokenizer,
        text_encoder_4=text_encoder,
        torch_dtype=torch.bfloat16,
    ).to("cuda", torch.bfloat16)
    pipe.transformer = transformer
    return pipe


# Preload the default model so the first request doesn't pay the load cost
print("🔧 Preloading default model (full)...")
loaded_models["full"] = load_models("full")
print("✅ Model loaded.")


# On Hugging Face ZeroGPU Spaces, @spaces.GPU reserves a GPU for up to 90 s per call
@spaces.GPU(duration=90)
def generate_image(
    model_type: str,
    prompt: str,
    resolution: str,
    seed: int,
) -> tuple[object, int]:
    """Generate an image using the HiDream pipeline; return it with the seed used."""
    # Lazy-load any variant that wasn't preloaded at startup
    if model_type not in loaded_models:
        print(f"📦 Lazy-loading model {model_type}...")
        loaded_models[model_type] = load_models(model_type)

    pipe: HiDreamImagePipeline = loaded_models[model_type]
    config = MODEL_CONFIGS[model_type]

    # -1 means "pick a random seed"; the chosen seed is returned for reproducibility
    if seed == -1:
        seed = torch.randint(0, 1_000_000, (1,)).item()

    # The resolution labels read "width × height" (e.g. 768 × 1360 is Portrait)
    width, height = parse_resolution(resolution)

    generator = torch.Generator("cuda").manual_seed(seed)
    image = pipe(
        prompt=prompt,
        height=height,
        width=width,
        guidance_scale=config["guidance_scale"],
        num_inference_steps=config["num_inference_steps"],
        generator=generator,
    ).images[0]
    torch.cuda.empty_cache()
    return image, seed


# Gradio UI
with gr.Blocks(title="HiDream Image Generator") as demo:
    gr.Markdown("## 🌈 HiDream Image Generator")

    with gr.Row():
        with gr.Column():
            model_type = gr.Radio(
                choices=list(MODEL_CONFIGS.keys()),
                value="full",
                label="Model Type",
                info="Choose between the full, fast, or dev variants",
            )
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="e.g. A futuristic city with floating cars at sunset",
                lines=3,
            )
            resolution = gr.Radio(
                choices=RESOLUTION_OPTIONS,
                value=RESOLUTION_OPTIONS[0],
                label="Resolution",
            )
            seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
            generate_btn = gr.Button("Generate Image", variant="primary")
            seed_used = gr.Number(label="Seed Used", interactive=False)
        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="pil")

    generate_btn.click(
        fn=generate_image,
        inputs=[model_type, prompt, resolution, seed],
        outputs=[output_image, seed_used],
    )

if __name__ == "__main__":
    demo.launch()
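
# How to run (a sketch, not part of the original script; commands are
# illustrative and assume this file is saved as app.py):
# - hi_diffusers ships with the HiDream-I1 repository
#   (https://github.com/HiDream-ai/HiDream-I1); clone or install it so the
#   package is importable.
# - meta-llama/Meta-Llama-3.1-8B-Instruct is a gated checkpoint: accept its
#   license on Hugging Face and authenticate, e.g. via `huggingface-cli login`.
# - The `spaces` package targets ZeroGPU Spaces; outside Spaces the
#   @spaces.GPU decorator should act as a no-op, but the package must still
#   be installed, and a CUDA-capable GPU is required by the .to("cuda") calls.
#
#   pip install gradio spaces torch transformers
#   python app.py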