import os
import spaces
import gradio as gr
import torch
import logging
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast, BitsAndBytesConfig
from transformer_hidream_image import HiDreamImageTransformer2DModel
from pipeline_hidream_image import HiDreamImagePipeline
from schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler
import subprocess
print(f"Is CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
try:
print(subprocess.check_output(["nvcc", "--version"]).decode("utf-8"))
except:
print("nvcc version check error")
# subprocess.run('python -m pip install flash-attn --no-build-isolation', shell=True)
def log_vram(msg: str):
    print(f"{msg} (used {torch.cuda.memory_allocated() / 1024**2:.2f} MB VRAM)\n")
# from nf4 import *
# Resolution options
RESOLUTION_OPTIONS = [
    "1024 × 1024 (Square)",
    "768 × 1360 (Portrait)",
    "1360 × 768 (Landscape)",
    "880 × 1168 (Portrait)",
    "1168 × 880 (Landscape)",
    "1248 × 832 (Landscape)",
    "832 × 1248 (Portrait)"
]
# quantization_config = BitsAndBytesConfig(load_in_4bit=True)
MODEL_PREFIX = "azaneko"
LLAMA_MODEL_NAME = "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
FAST_CONFIG = {
    "path": "azaneko/HiDream-I1-Fast-nf4",
    "guidance_scale": 0.0,
    "num_inference_steps": 16,
    "shift": 3.0,
    "scheduler": FlashFlowMatchEulerDiscreteScheduler
}
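# Note: FAST_CONFIG mirrors MODEL_CONFIGS["fast"] below; the eager loading code
# underneath hard-codes the same repo and settings rather than reading this dict.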
tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(LLAMA_MODEL_NAME)
log_vram("βœ… Tokenizer loaded!")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    LLAMA_MODEL_NAME,
    output_hidden_states=True,
    output_attentions=True,
    return_dict_in_generate=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
log_vram("✅ Text encoder loaded!")
transformer = HiDreamImageTransformer2DModel.from_pretrained(
    "azaneko/HiDream-I1-Fast-nf4",
    subfolder="transformer",
    torch_dtype=torch.bfloat16
)
log_vram("✅ Transformer loaded!")
pipe = HiDreamImagePipeline.from_pretrained(
    "azaneko/HiDream-I1-Fast-nf4",
    scheduler=FlowUniPCMultistepScheduler(num_train_timesteps=1000, shift=3.0, use_dynamic_shifting=False),
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
    # quantization_config=quantization_config
)
pipe.transformer = transformer
log_vram("✅ Pipeline loaded!")
pipe.enable_sequential_cpu_offload()
# Model configurations
MODEL_CONFIGS = {
    "dev": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Dev-nf4",
        "guidance_scale": 0.0,
        "num_inference_steps": 28,
        "shift": 6.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler
    },
    "full": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Full-nf4",
        "guidance_scale": 5.0,
        "num_inference_steps": 50,
        "shift": 3.0,
        "scheduler": FlowUniPCMultistepScheduler
    },
    "fast": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Fast-nf4",
        "guidance_scale": 0.0,
        "num_inference_steps": 16,
        "shift": 3.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler
    }
}
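# With model switching disabled (see the commented-out load_models below), these
# configs currently only populate the "Model Type" radio in the UI.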
# Parse a resolution string like "1024 × 1024 (Square)" into a (width, height) tuple
def parse_resolution(resolution_str):
    return tuple(map(int, resolution_str.split("(")[0].strip().split(" × ")))
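# e.g. parse_resolution("768 × 1360 (Portrait)") -> (768, 1360)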
# def load_models(model_type: str):
#     config = MODEL_CONFIGS[model_type]
#     tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(LLAMA_MODEL_NAME)
#     log_vram("✅ Tokenizer loaded!")
#     text_encoder_4 = LlamaForCausalLM.from_pretrained(
#         LLAMA_MODEL_NAME,
#         output_hidden_states=True,
#         output_attentions=True,
#         return_dict_in_generate=True,
#         torch_dtype=torch.bfloat16,
#         device_map="auto",
#     )
#     log_vram("✅ Text encoder loaded!")
#     transformer = HiDreamImageTransformer2DModel.from_pretrained(
#         config["path"],
#         subfolder="transformer",
#         torch_dtype=torch.bfloat16
#     )
#     log_vram("✅ Transformer loaded!")
#     pipe = HiDreamImagePipeline.from_pretrained(
#         config["path"],
#         scheduler=FlowUniPCMultistepScheduler(num_train_timesteps=1000, shift=config["shift"], use_dynamic_shifting=False),
#         tokenizer_4=tokenizer_4,
#         text_encoder_4=text_encoder_4,
#         torch_dtype=torch.bfloat16,
#     )
#     pipe.transformer = transformer
#     log_vram("✅ Pipeline loaded!")
#     pipe.enable_sequential_cpu_offload()
#     return pipe, config
#@torch.inference_mode()
@spaces.GPU()
def generate_image(pipe: HiDreamImagePipeline, model_type: str, prompt: str, resolution: tuple[int, int], seed: int):
    # Only the fast pipeline is loaded, so its settings are used regardless of model_type
    # config = MODEL_CONFIGS[model_type]
    guidance_scale = 0.0
    num_inference_steps = 16
    # Parse resolution
    width, height = resolution
    # Handle seed
    if seed == -1:
        seed = torch.randint(0, 1000000, (1,)).item()
    generator = torch.Generator("cuda").manual_seed(seed)
    images = pipe(
        prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=1,
        generator=generator
    ).images
    return images[0], seed
@spaces.GPU()
def gen_img_helper(prompt, res, seed):
    global pipe, current_model
    # 1. Check if the model matches the loaded model, load the model if not
    # if model != current_model:
    #     print(f"Unloading model {current_model}...")
    #     del pipe
    #     torch.cuda.empty_cache()
    #     print(f"Loading model {model}...")
    #     pipe, _ = load_models(model)
    #     current_model = model
    #     print("Model loaded successfully!")
    # 2. Generate image
    res = parse_resolution(res)
    # Model switching is disabled above, so generation always uses the fast pipeline
    return generate_image(pipe, "fast", prompt, res, seed)
if __name__ == "__main__":
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
    # Initialize with default model
    # print("Loading default model (fast)...")
    # current_model = "fast"
    # pipe, _ = load_models(current_model)
    # print("Model loaded successfully!")
    # Create Gradio interface
    with gr.Blocks(title="HiDream-I1-nf4 Dashboard") as demo:
        gr.Markdown("# HiDream-I1-nf4 Dashboard")
        with gr.Row():
            with gr.Column():
                model_type = gr.Radio(
                    choices=list(MODEL_CONFIGS.keys()),
                    value="fast",
                    label="Model Type",
                    info="Select model variant"
                )
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="A cat holding a sign that says \"Hi-Dreams.ai\".",
                    lines=3
                )
                resolution = gr.Radio(
                    choices=RESOLUTION_OPTIONS,
                    value=RESOLUTION_OPTIONS[0],
                    label="Resolution",
                    info="Select image resolution"
                )
                seed = gr.Number(
                    label="Seed (use -1 for random)",
                    value=-1,
                    precision=0
                )
                generate_btn = gr.Button("Generate Image")
                seed_used = gr.Number(label="Seed Used", interactive=False)
            with gr.Column():
                output_image = gr.Image(label="Generated Image", type="pil")
        generate_btn.click(
            fn=gen_img_helper,
            inputs=[prompt, resolution, seed],
            outputs=[output_image, seed_used]
        )
    demo.launch()