# HiDream-ai-full / app-full.py
# Author: blanchon — commit 113fc70 ("up")
import gradio as gr
import PIL
import spaces
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from hi_diffusers import HiDreamImagePipeline, HiDreamImageTransformer2DModel
from hi_diffusers.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
)
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
# Constants
# Model repository ids used with ``from_pretrained`` further down.
# NOTE(review): MODEL_PREFIX does not appear to be referenced anywhere in this
# file — confirm before relying on it or removing it.
MODEL_PREFIX: str = "HiDream-ai"
LLAMA_MODEL_NAME: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"
MODEL_PATH: str = "HiDream-ai/HiDream-I1-full"
# Sampling defaults for the "full" HiDream variant. "scheduler" holds the
# scheduler *class*; it is instantiated (with "shift") when the pipeline is built.
MODEL_CONFIGS = dict(
    guidance_scale=5.0,
    num_inference_steps=50,
    shift=3.0,
    scheduler=FlowUniPCMultistepScheduler,
)
# Supported image sizes, each about one megapixel. Every entry is rendered as
# "<first> x <second>"; generate_image parses the first number as the height
# and the second as the width.
RESOLUTION_OPTIONS: list[str] = [
    f"{first} x {second}"
    for first, second in (
        (1024, 1024),
        (768, 1360),
        (1360, 768),
        (880, 1168),
        (1168, 880),
        (1248, 832),
        (832, 1248),
    )
]
# --- Model loading ----------------------------------------------------------
# Everything below runs at import time, so all weights are loaded before the
# Gradio UI further down is constructed.

# 4-bit bitsandbytes quantization config (transformers flavor) for the Llama
# text encoder.
# NOTE(review): bnb_4bit_compute_dtype is not set, so the library default
# compute dtype applies — confirm whether bfloat16 compute was intended.
quant_config = TransformersBitsAndBytesConfig(
    load_in_4bit=True,
)
# use_fast=False requests the slow (Python) tokenizer implementation.
# NOTE(review): the reason for avoiding the fast tokenizer is not visible here.
tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_NAME, use_fast=False)
# Llama 3.1 8B Instruct loaded in 4-bit. Hidden states (and attentions) are
# requested in the model outputs; presumably the HiDream pipeline consumes the
# hidden states as prompt embeddings — confirm against hi_diffusers.
text_encoder = AutoModelForCausalLM.from_pretrained(
    LLAMA_MODEL_NAME,
    output_hidden_states=True,
    output_attentions=True,
    low_cpu_mem_usage=True,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
# Rebinds quant_config to the diffusers flavor of the same 4-bit config for the
# image transformer. The transformers config above is no longer reachable by
# this name, so statement order matters here.
quant_config = DiffusersBitsAndBytesConfig(
    load_in_4bit=True,
)
# The HiDream image transformer, loaded from the "transformer" subfolder of
# MODEL_PATH in 4-bit.
transformer = HiDreamImageTransformer2DModel.from_pretrained(
    MODEL_PATH,
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
# Instantiate the scheduler class named in MODEL_CONFIGS with its shift value.
# use_dynamic_shifting=False presumably makes the fixed "shift" apply at every
# resolution — confirm against hi_diffusers' FlowUniPCMultistepScheduler.
scheduler = MODEL_CONFIGS["scheduler"](
    num_train_timesteps=1000,
    shift=MODEL_CONFIGS["shift"],
    use_dynamic_shifting=False,
)
# Assemble the pipeline from MODEL_PATH. The pre-loaded quantized transformer,
# the custom scheduler and the Llama tokenizer/encoder are passed in explicitly
# (the Llama pair goes into the pipeline's "_4" tokenizer/text-encoder slots).
# device_map="balanced" asks diffusers to spread the pipeline's components
# across the available accelerators.
pipe = HiDreamImagePipeline.from_pretrained(
    MODEL_PATH,
    transformer=transformer,
    scheduler=scheduler,
    tokenizer_4=tokenizer,
    text_encoder_4=text_encoder,
    device_map="balanced",
    torch_dtype=torch.bfloat16,
)
@spaces.GPU(duration=120)
def generate_image(
    prompt: str,
    resolution: str,
    seed: int,
    progress=gr.Progress(track_tqdm=True),  # noqa: ARG001, B008
) -> tuple[PIL.Image.Image, int]:
    """Generate a single image for *prompt* with the module-level HiDream pipeline.

    Args:
        prompt: Text prompt passed straight to the pipeline.
        resolution: A ``"<a> x <b>"`` string such as ``"768 x 1360"``; the
            first number is used as the height and the second as the width.
        seed: Generator seed. ``-1`` requests a random seed drawn from
            ``[0, 1_000_000)``. ``None`` (which a cleared Number field may
            produce) is treated the same way. Non-integer numbers are
            truncated with ``int()``.
        progress: Gradio progress tracker created with ``track_tqdm=True``.
            It is not referenced in the body.

    Returns:
        ``(image, seed)``: the first image produced by the pipeline and the
        seed that was actually used, so the UI can show it for reproduction.

    Raises:
        ValueError: if *resolution* does not split into exactly two integers.
    """
    gr.Info(
        "This Spaces is an unofficial quantized version of HiDream-ai-full. It is not as good as the full version, but it is faster and uses less memory."
    )
    # Normalise the seed. A missing value is treated like the -1 "random"
    # sentinel, and the value is coerced to int because Generator.manual_seed()
    # expects an integer seed.
    seed = -1 if seed is None else int(seed)
    if seed == -1:
        seed = int(torch.randint(0, 1_000_000, (1,)).item())
    # NOTE(review): parsed as (height, width). Conventional "W x H" notation
    # reads these strings the other way round — confirm the intended orientation.
    height, width = map(int, resolution.replace(" ", "").split("x"))
    generator = torch.Generator("cuda").manual_seed(seed)
    try:
        image = pipe(
            prompt=prompt,
            height=height,
            width=width,
            guidance_scale=MODEL_CONFIGS["guidance_scale"],
            num_inference_steps=MODEL_CONFIGS["num_inference_steps"],
            generator=generator,
        ).images[0]
    finally:
        # Release cached CUDA memory even when the pipeline raises (e.g. an
        # out-of-memory error), not only after a successful generation.
        torch.cuda.empty_cache()
    return image, seed
# Gradio UI
# Two-column layout: prompt/resolution/seed inputs, the button and the
# "Seed Used" readout on the left; the generated image on the right. The order
# in which components are created inside each context manager is their
# on-screen order.
with gr.Blocks(title="HiDream Image Generator Full") as demo:
    gr.Markdown("## 🌈 HiDream Image Generator Full")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="e.g. A futuristic city with floating cars at sunset",
                lines=3,
            )
            # Single-choice selector over RESOLUTION_OPTIONS, defaulting to the
            # first entry ("1024 x 1024").
            resolution = gr.Radio(
                choices=RESOLUTION_OPTIONS,
                value=RESOLUTION_OPTIONS[0],
                label="Resolution",
            )
            # precision=0 rounds the entered value to an integer; -1 is the
            # "random seed" sentinel checked in generate_image.
            seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
            generate_btn = gr.Button("Generate Image", variant="primary")
            # Read-only display of the seed generate_image actually used.
            seed_used = gr.Number(label="Seed Used", interactive=False)
        with gr.Column():
            # type="pil" matches the PIL image returned by generate_image.
            output_image = gr.Image(label="Generated Image", type="pil")
    # Wire the button: inputs map positionally onto generate_image's
    # (prompt, resolution, seed) parameters, and its (image, seed) return
    # tuple maps onto (output_image, seed_used).
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, resolution, seed],
        outputs=[output_image, seed_used],
    )
# Start the Gradio server only when run as a script.
if __name__ == "__main__":
    demo.launch()