Spaces:
Running
on
Zero
Running
on
Zero
import spaces | |
import os | |
import json | |
import time | |
import torch | |
from PIL import Image | |
from tqdm import tqdm | |
import gradio as gr | |
from safetensors.torch import save_file | |
from src.pipeline import FluxPipeline | |
from src.transformer_flux import FluxTransformer2DModel | |
from src.lora_helper import set_single_lora, set_multi_lora, unset_lora | |
# Model identifier handed to both `from_pretrained` calls below, and the
# directory that LoRA weight files are resolved against.
base_path = "black-forest-labs/FLUX.1-dev"
lora_base_path = "./models"

# Style prefix that is prepended to every user prompt; it is never shown in the UI.
SYSTEM_PROMPT = "Ghibli Studio style, Charming hand-drawn anime-style illustration"

# Build the pipeline in bfloat16, then replace its transformer with the
# project's own FluxTransformer2DModel loaded from the same checkpoint's
# "transformer" subfolder, and finally move everything to the GPU.
pipe = FluxPipeline.from_pretrained(
    base_path,
    torch_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    base_path,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
pipe.transformer = transformer
pipe.to("cuda")
def clear_cache(transformer):
    """Empty the ``bank_kv`` store of every attention processor on *transformer*.

    Args:
        transformer: Object exposing an ``attn_processors`` mapping. Each value
            that carries a ``bank_kv`` attribute has ``bank_kv.clear()`` called
            on it.

    Returns:
        None. The processors are modified in place.

    Processors without a ``bank_kv`` attribute have nothing to clear and are
    skipped rather than raising ``AttributeError``.
    """
    # Only the processor objects matter here; the mapping keys are not used.
    for attn_processor in transformer.attn_processors.values():
        bank_kv = getattr(attn_processor, "bank_kv", None)
        if bank_kv is not None:
            bank_kv.clear()
# Generation callback used by the "Generate" button and the examples gallery
def single_condition_generate_image(user_prompt, spatial_img, height, width, seed):
    """Generate one image from a text prompt and an optional reference image.

    Args:
        user_prompt: Free-text description from the user. When empty or None,
            only ``SYSTEM_PROMPT`` is used.
        spatial_img: Optional reference image, forwarded to the pipeline as the
            single entry of ``spatial_images``. ``None`` means no reference.
        height: Output height; converted with ``int()`` before use.
        width: Output width; converted with ``int()`` before use.
        seed: Seed for the CPU ``torch.Generator``; converted with ``int()``
            because ``manual_seed`` rejects non-integer values.

    Returns:
        The first image of the pipeline result (``.images[0]``).

    Raises:
        Whatever ``set_single_lora`` or the pipeline call raises; the
        attention-processor cache is still cleared if the pipeline call fails.
    """
    # NOTE(review): the file imports `spaces` and appears to target a ZeroGPU
    # Space, yet no `@spaces.GPU` decorator is applied here — confirm whether
    # one is required for this deployment.

    # Prepend the hidden style prompt to whatever the user typed.
    full_prompt = f"{SYSTEM_PROMPT}, {user_prompt}" if user_prompt else SYSTEM_PROMPT

    # The Ghibli LoRA is (re)applied to the transformer on every call.
    lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
    set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)

    # Explicit None check: the reference image is optional.
    spatial_imgs = [spatial_img] if spatial_img is not None else []

    try:
        image = pipe(
            full_prompt,
            height=int(height),
            width=int(width),
            guidance_scale=3.5,
            num_inference_steps=25,
            max_sequence_length=512,
            generator=torch.Generator("cpu").manual_seed(int(seed)),
            subject_images=[],
            spatial_images=spatial_imgs,
            cond_size=512,
        ).images[0]
    finally:
        # Always clear the processors' bank_kv, even when inference fails, so
        # cached state is not carried into the next request.
        clear_cache(pipe.transformer)
    return image
# Load example images
def load_examples():
    """Build the rows displayed in the examples gallery.

    Returns:
        A list of ``[prompt, reference image, height, width, seed]`` rows, in
        the same order as the inputs of ``single_condition_generate_image``.
        Example files that do not exist under ``./test_imgs`` are skipped, so
        the list may be empty.
    """
    test_img_dir = "./test_imgs"

    # One entry per example: (file name, user prompt, (height, width)).
    # Keeping these together replaces the former parallel prompt list and
    # per-file if/elif chain, whose fallback branch could never be reached.
    example_specs = [
        ("00.png", "a cat sitting by the window", (680, 1024)),
        ("02.png", "a peaceful mountain village", (560, 1024)),
        ("03.png", "a young girl with flowers in her hair", (568, 1024)),
        ("04.png", "a magical forest with spirits", (768, 672)),
        ("06.png", "a flying castle in the clouds", (896, 1024)),
        ("07.png", "a serene river with boats", (528, 800)),
        ("08.png", "a cozy cottage in the countryside", (696, 1024)),
        ("09.png", "a bustling market in a small town", (896, 1024)),
    ]

    examples = []
    # The seed is the 1-based position in the full table, so it stays the same
    # for a given example even when earlier files are missing.
    for seed, (filename, prompt, (height, width)) in enumerate(example_specs, start=1):
        img_path = os.path.join(test_img_dir, filename)
        if not os.path.exists(img_path):
            continue
        examples.append([
            prompt,                # User prompt (without system prompt)
            Image.open(img_path),  # Reference image
            height,                # Height
            width,                 # Width
            seed,                  # Seed
        ])
    return examples
# Custom stylesheet for the page; it is handed to `gr.Blocks(css=css)` below.
# The `gr-header`, `gr-box`, `gr-examples-gallery` and `gr-footer` classes are
# used by the raw HTML snippets in the layout, and `gr-button` is attached to
# the generate button through `elem_classes`.
css = """
:root {
--primary-color: #4a6670;
--accent-color: #ff8a65;
--background-color: #f5f5f5;
--card-background: #ffffff;
--text-color: #333333;
--border-radius: 10px;
--shadow: 0 4px 6px rgba(0,0,0,0.1);
}
body {
background-color: var(--background-color);
color: var(--text-color);
font-family: 'Helvetica Neue', Arial, sans-serif;
}
.container {
max-width: 1200px;
margin: 0 auto;
padding: 20px;
}
.gr-header {
background: linear-gradient(135deg, #668796 0%, #4a6670 100%);
padding: 24px;
border-radius: var(--border-radius);
margin-bottom: 24px;
box-shadow: var(--shadow);
text-align: center;
}
.gr-header h1 {
color: white;
font-size: 2.5rem;
margin: 0;
font-weight: 700;
}
.gr-header p {
color: rgba(255, 255, 255, 0.9);
font-size: 1.1rem;
margin-top: 8px;
}
.gr-panel {
background-color: var(--card-background);
border-radius: var(--border-radius);
padding: 16px;
box-shadow: var(--shadow);
}
.gr-button {
background-color: var(--accent-color);
border: none;
color: white;
padding: 10px 20px;
border-radius: 5px;
font-size: 16px;
font-weight: bold;
cursor: pointer;
transition: transform 0.1s, background-color 0.3s;
}
.gr-button:hover {
background-color: #ff7043;
transform: translateY(-2px);
}
.gr-input, .gr-select {
border-radius: 5px;
border: 1px solid #ddd;
padding: 10px;
width: 100%;
}
.gr-form {
display: grid;
gap: 16px;
}
.gr-box {
background-color: var(--card-background);
border-radius: var(--border-radius);
padding: 20px;
box-shadow: var(--shadow);
margin-bottom: 20px;
}
.gr-gallery {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
gap: 16px;
}
.gr-gallery-item {
overflow: hidden;
border-radius: var(--border-radius);
box-shadow: var(--shadow);
transition: transform 0.3s;
}
.gr-gallery-item:hover {
transform: scale(1.02);
}
.gr-image {
width: 100%;
height: auto;
object-fit: cover;
}
.gr-footer {
text-align: center;
margin-top: 40px;
padding: 20px;
color: #666;
font-size: 14px;
}
.gr-examples-gallery {
margin-top: 20px;
}
/* Responsive adjustments */
@media (max-width: 768px) {
.gr-header h1 {
font-size: 1.8rem;
}
.gr-panel {
padding: 12px;
}
}
/* Ghibli-inspired accent colors */
.gr-accent-1 {
background-color: #95ccd9;
}
.gr-accent-2 {
background-color: #74ad8c;
}
.gr-accent-3 {
background-color: #f9c06b;
}
"""
# Create the Gradio Blocks interface
# NOTE(review): this section arrived without indentation, so the nesting of the
# layout containers below is a reconstruction — confirm it against the intended
# page layout. The contents of the HTML string literals are kept as received.
with gr.Blocks(css=css) as demo:
    # Page banner.
    gr.HTML("""
<div class="gr-header">
<h1>✨ Ghibli Art Generator ✨</h1>
<p>Transform your ideas into magical Ghibli-inspired artwork</p>
</div>
""")
    with gr.Tab("Create Ghibli Art"):
        with gr.Row():
            # Left column: everything the user provides.
            with gr.Column(scale=1):
                gr.HTML("""
<div class="gr-box">
<h3>🎨 Your Creative Input</h3>
<p>Describe what you want to see in your Ghibli-inspired image</p>
</div>
""")
                user_prompt = gr.Textbox(
                    label="Your description",
                    placeholder="Describe what you want to see (e.g., a cat sitting by the window)",
                    lines=2
                )
                # type="pil" means the callback receives a PIL image (or None).
                spatial_img = gr.Image(
                    label="Reference Image (Optional)",
                    type="pil",
                    elem_classes="gr-image-upload"
                )
                with gr.Group():
                    with gr.Row():
                        height = gr.Slider(minimum=256, maximum=1024, step=64, label="Height", value=768)
                        width = gr.Slider(minimum=256, maximum=1024, step=64, label="Width", value=768)
                    seed = gr.Slider(minimum=1, maximum=9999, step=1, label="Seed", value=42,
                                     info="Change for different variations")
                generate_btn = gr.Button("✨ Generate Ghibli Art", elem_classes="gr-button")
            # Right column: the generated result.
            with gr.Column(scale=1):
                gr.HTML("""
<div class="gr-box">
<h3>✨ Your Magical Creation</h3>
<p>Your Ghibli-inspired artwork will appear here</p>
</div>
""")
                output_image = gr.Image(label="Generated Image", elem_classes="gr-output-image")
        gr.HTML("""
<div class="gr-box gr-examples-gallery">
<h3>✨ Inspiration Gallery</h3>
<p>Click on any example to try it out</p>
</div>
""")
        # Example rows follow the same order as `inputs` below:
        # prompt, reference image, height, width, seed.
        examples = load_examples()
        # cache_examples=False: example outputs are not pre-computed.
        gr.Examples(
            examples=examples,
            inputs=[user_prompt, spatial_img, height, width, seed],
            outputs=output_image,
            fn=single_condition_generate_image,
            cache_examples=False,
            examples_per_page=4
        )
    gr.HTML("""
<div class="gr-footer">
<p>Powered by FLUX.1 and Ghibli LoRA • Created with ❤️</p>
</div>
""")
    # Clicking the button runs the generation callback on the current input
    # values and shows the returned image in `output_image`.
    generate_btn.click(
        single_condition_generate_image,
        inputs=[user_prompt, spatial_img, height, width, seed],
        outputs=output_image
    )
# Enable request queuing and start the Gradio server.
demo.queue().launch()