import gradio as gr
import spaces
import torch
from hi_diffusers import HiDreamImagePipeline, HiDreamImageTransformer2DModel
from hi_diffusers.schedulers.flash_flow_match import (
    FlashFlowMatchEulerDiscreteScheduler,
)
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast

# Constants
MODEL_PREFIX: str = "HiDream-ai"
LLAMA_MODEL_NAME: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"
MODEL_PATH: str = f"{MODEL_PREFIX}/HiDream-I1-Dev"

# Model configurations. Only the "dev" variant is loaded below;
# generate_image() looks its settings up under the key selected in the UI.
MODEL_CONFIGS: dict[str, dict] = {
    "dev": {
        "guidance_scale": 0.0,
        "num_inference_steps": 28,
        "shift": 6.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler,
    },
    # Other variants, kept for reference ("full" additionally requires
    # FlowUniPCMultistepScheduler and its own checkpoint):
    # "full": {
    #     "path": f"{MODEL_PREFIX}/HiDream-I1-Full",
    #     "guidance_scale": 5.0,
    #     "num_inference_steps": 50,
    #     "shift": 3.0,
    #     "scheduler": FlowUniPCMultistepScheduler,
    # },
    # "fast": {
    #     "path": f"{MODEL_PREFIX}/HiDream-I1-Fast",
    #     "guidance_scale": 0.0,
    #     "num_inference_steps": 16,
    #     "shift": 3.0,
    #     "scheduler": FlashFlowMatchEulerDiscreteScheduler,
    # },
}

# Supported image sizes, listed as "width x height (label)"
RESOLUTION_OPTIONS: list[str] = [
    "1024 x 1024 (Square)",
    "768 x 1360 (Portrait)",
    "1360 x 768 (Landscape)",
    "880 x 1168 (Portrait)",
    "1168 x 880 (Landscape)",
    "1248 x 832 (Landscape)",
    "832 x 1248 (Portrait)",
]


def parse_resolution(res_str: str) -> tuple[int, int]:
    """Parse e.g. "1024 x 1024 (Square)" into (1024, 1024), dropping the label."""
    dims = res_str.split("(")[0]
    return tuple(map(int, dims.replace(" ", "").split("x")))


# PreTrainedTokenizerFast is already the fast tokenizer, so no use_fast flag.
tokenizer = PreTrainedTokenizerFast.from_pretrained(LLAMA_MODEL_NAME)
text_encoder = LlamaForCausalLM.from_pretrained(
    LLAMA_MODEL_NAME,
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
).to("cuda")

transformer = HiDreamImageTransformer2DModel.from_pretrained(
    MODEL_PATH,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
).to("cuda")

scheduler = MODEL_CONFIGS["dev"]["scheduler"](
    num_train_timesteps=1000,
    shift=MODEL_CONFIGS["dev"]["shift"],
    use_dynamic_shifting=False,
)

pipe = HiDreamImagePipeline.from_pretrained(
    MODEL_PATH,
    scheduler=scheduler,
    tokenizer_4=tokenizer,
    text_encoder_4=text_encoder,
    torch_dtype=torch.bfloat16,
).to("cuda", torch.bfloat16)
pipe.transformer = transformer


@spaces.GPU(duration=90)
def generate_image(
    model_type: str,
    prompt: str,
    resolution: str,
    seed: int,
) -> tuple[object, int]:
    config = MODEL_CONFIGS[model_type]
    if seed == -1:
        seed = torch.randint(0, 1_000_000, (1,)).item()

    # Resolution strings are "width x height"; see RESOLUTION_OPTIONS above.
    width, height = parse_resolution(resolution)
    generator = torch.Generator("cuda").manual_seed(seed)

    image = pipe(
        prompt=prompt,
        height=height,
        width=width,
        guidance_scale=config["guidance_scale"],
        num_inference_steps=config["num_inference_steps"],
        generator=generator,
    ).images[0]

    torch.cuda.empty_cache()
    return image, seed


# Gradio UI
with gr.Blocks(title="HiDream Image Generator") as demo:
    gr.Markdown("## 🌈 HiDream Image Generator")

    with gr.Row():
        with gr.Column():
            model_type = gr.Radio(
                choices=list(MODEL_CONFIGS.keys()),
                value="dev",
                label="Model Type",
                info="This demo loads the HiDream-I1-Dev checkpoint",
            )
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="e.g. A futuristic city with floating cars at sunset",
                lines=3,
            )
            resolution = gr.Radio(
                choices=RESOLUTION_OPTIONS,
                value=RESOLUTION_OPTIONS[0],
                label="Resolution",
            )
            seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
            generate_btn = gr.Button("Generate Image", variant="primary")
            seed_used = gr.Number(label="Seed Used", interactive=False)
        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="pil")

    generate_btn.click(
        fn=generate_image,
        inputs=[model_type, prompt, resolution, seed],
        outputs=[output_image, seed_used],
    )

if __name__ == "__main__":
    demo.launch()
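# Minimal sketch of calling the generator directly, bypassing the UI. This
# assumes the models above loaded successfully on a CUDA device; the prompt
# and output file name are illustrative, not part of the app:
#
#   image, used_seed = generate_image(
#       model_type="dev",
#       prompt="A futuristic city with floating cars at sunset",
#       resolution=RESOLUTION_OPTIONS[0],
#       seed=42,
#   )
#   image.save("output.png")  # generate_image returns a PIL image and the seed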