import os
import spaces
import gradio as gr
import torch
import logging
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast, BitsAndBytesConfig
from transformer_hidream_image import HiDreamImageTransformer2DModel
from pipeline_hidream_image import HiDreamImagePipeline
from schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler
import subprocess

print(f"Is CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
try:
    print(subprocess.check_output(["nvcc", "--version"]).decode("utf-8"))
except Exception:
    print("nvcc version check error")

# subprocess.run('python -m pip install flash-attn --no-build-isolation', shell=True)


def log_vram(msg: str):
    print(f"{msg} (used {torch.cuda.memory_allocated() / 1024**2:.2f} MB VRAM)\n")

# from nf4 import *

# Resolution options
RESOLUTION_OPTIONS = [
    "1024 × 1024 (Square)",
    "768 × 1360 (Portrait)",
    "1360 × 768 (Landscape)",
    "880 × 1168 (Portrait)",
    "1168 × 880 (Landscape)",
    "1248 × 832 (Landscape)",
    "832 × 1248 (Portrait)"
]

# quantization_config = BitsAndBytesConfig(load_in_4bit=True)

MODEL_PREFIX = "azaneko"
LLAMA_MODEL_NAME = "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"

FAST_CONFIG = {
    "path": "azaneko/HiDream-I1-Fast-nf4",
    "guidance_scale": 0.0,
    "num_inference_steps": 16,
    "shift": 3.0,
    "scheduler": FlashFlowMatchEulerDiscreteScheduler
}

# Load the Llama text encoder and the fast nf4 checkpoint once at import time.
tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(LLAMA_MODEL_NAME)
log_vram("✅ Tokenizer loaded!")

text_encoder_4 = LlamaForCausalLM.from_pretrained(
    LLAMA_MODEL_NAME,
    output_hidden_states=True,
    output_attentions=True,
    return_dict_in_generate=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
log_vram("✅ Text encoder loaded!")

transformer = HiDreamImageTransformer2DModel.from_pretrained(
    "azaneko/HiDream-I1-Fast-nf4",
    subfolder="transformer",
    torch_dtype=torch.bfloat16
)
log_vram("✅ Transformer loaded!")

pipe = HiDreamImagePipeline.from_pretrained(
    "azaneko/HiDream-I1-Fast-nf4",
    scheduler=FlowUniPCMultistepScheduler(num_train_timesteps=1000, shift=3.0, use_dynamic_shifting=False),
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
    # quantization_config=quantization_config
)
pipe.transformer = transformer
log_vram("✅ Pipeline loaded!")

# Offload pipeline submodules to CPU between forward passes to keep peak VRAM low.
pipe.enable_sequential_cpu_offload()

# Model configurations
MODEL_CONFIGS = {
    "dev": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Dev-nf4",
        "guidance_scale": 0.0,
        "num_inference_steps": 28,
        "shift": 6.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler
    },
    "full": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Full-nf4",
        "guidance_scale": 5.0,
        "num_inference_steps": 50,
        "shift": 3.0,
        "scheduler": FlowUniPCMultistepScheduler
    },
    "fast": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Fast-nf4",
        "guidance_scale": 0.0,
        "num_inference_steps": 16,
        "shift": 3.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler
    }
}


# Parse a resolution string into a (width, height) tuple
def parse_resolution(resolution_str):
    return tuple(map(int, resolution_str.split("(")[0].strip().split(" × ")))
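
# Example (illustrative only, not executed at import time):
#     parse_resolution("768 × 1360 (Portrait)")  ->  (768, 1360)
# generate_image() below unpacks this tuple as (width, height).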
# def load_models(model_type: str):
#     config = MODEL_CONFIGS[model_type]
#     tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(LLAMA_MODEL_NAME)
#     log_vram("✅ Tokenizer loaded!")
#     text_encoder_4 = LlamaForCausalLM.from_pretrained(
#         LLAMA_MODEL_NAME,
#         output_hidden_states=True,
#         output_attentions=True,
#         return_dict_in_generate=True,
#         torch_dtype=torch.bfloat16,
#         device_map="auto",
#     )
#     log_vram("✅ Text encoder loaded!")
#     transformer = HiDreamImageTransformer2DModel.from_pretrained(
#         config["path"],
#         subfolder="transformer",
#         torch_dtype=torch.bfloat16
#     )
#     log_vram("✅ Transformer loaded!")
#     pipe = HiDreamImagePipeline.from_pretrained(
#         config["path"],
#         scheduler=FlowUniPCMultistepScheduler(num_train_timesteps=1000, shift=config["shift"], use_dynamic_shifting=False),
#         tokenizer_4=tokenizer_4,
#         text_encoder_4=text_encoder_4,
#         torch_dtype=torch.bfloat16,
#     )
#     pipe.transformer = transformer
#     log_vram("✅ Pipeline loaded!")
#     pipe.enable_sequential_cpu_offload()
#     return pipe, config


# @torch.inference_mode()
@spaces.GPU()
def generate_image(pipe: HiDreamImagePipeline, model_type: str, prompt: str, resolution: tuple[int, int], seed: int):
    # Get configuration for current model
    # config = MODEL_CONFIGS[model_type]
    # Hardcoded to the "fast" settings while per-model loading is disabled
    guidance_scale = 0.0
    num_inference_steps = 16

    # Parse resolution
    width, height = resolution

    # Handle seed
    if seed == -1:
        seed = torch.randint(0, 1000000, (1,)).item()
    generator = torch.Generator("cuda").manual_seed(seed)

    images = pipe(
        prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=1,
        generator=generator
    ).images

    return images[0], seed


@spaces.GPU()
def gen_img_helper(prompt, res, seed):
    global pipe, current_model

    # 1. Check if the model matches loaded model, load the model if not
    # if model != current_model:
    #     print(f"Unloading model {current_model}...")
    #     del pipe
    #     torch.cuda.empty_cache()
    #     print(f"Loading model {model}...")
    #     pipe, _ = load_models(model)
    #     current_model = model
    #     print("Model loaded successfully!")

    # 2. Generate image
    res = parse_resolution(res)
    # Only the "fast" nf4 pipeline is loaded at module level, so always generate with it
    return generate_image(pipe, "fast", prompt, res, seed)


if __name__ == "__main__":
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

    # Initialize with default model
    # print("Loading default model (fast)...")
    # current_model = "fast"
    # pipe, _ = load_models(current_model)
    # print("Model loaded successfully!")

    # Create Gradio interface
    with gr.Blocks(title="HiDream-I1-nf4 Dashboard") as demo:
        gr.Markdown("# HiDream-I1-nf4 Dashboard")

        with gr.Row():
            with gr.Column():
                model_type = gr.Radio(
                    choices=list(MODEL_CONFIGS.keys()),
                    value="fast",
                    label="Model Type",
                    info="Select model variant"
                )

                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="A cat holding a sign that says \"Hi-Dreams.ai\".",
                    lines=3
                )

                resolution = gr.Radio(
                    choices=RESOLUTION_OPTIONS,
                    value=RESOLUTION_OPTIONS[0],
                    label="Resolution",
                    info="Select image resolution"
                )

                seed = gr.Number(
                    label="Seed (use -1 for random)",
                    value=-1,
                    precision=0
                )

                generate_btn = gr.Button("Generate Image")
                seed_used = gr.Number(label="Seed Used", interactive=False)

            with gr.Column():
                output_image = gr.Image(label="Generated Image", type="pil")

        generate_btn.click(
            fn=gen_img_helper,
            inputs=[prompt, resolution, seed],
            outputs=[output_image, seed_used]
        )

    demo.launch()
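
# Usage sketch (assumption: this script is saved as app.py with the bundled
# transformer/pipeline/scheduler modules importable from the same directory):
#     python app.py
# demo.launch() serves the dashboard with Gradio's defaults (http://127.0.0.1:7860 locally);
# on Hugging Face Spaces the @spaces.GPU() decorators request a GPU worker for each call.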