import spaces import os import json import time import torch from PIL import Image from tqdm import tqdm import gradio as gr from safetensors.torch import save_file from src.pipeline import FluxPipeline from src.transformer_flux import FluxTransformer2DModel from src.lora_helper import set_single_lora, set_multi_lora, unset_lora # Initialize the image processor base_path = "black-forest-labs/FLUX.1-dev" lora_base_path = "./models" # System prompt that will be hidden from users but automatically added to their input SYSTEM_PROMPT = "Ghibli Studio style, Charming hand-drawn anime-style illustration" pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16) transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16) pipe.transformer = transformer pipe.to("cuda") def clear_cache(transformer): for name, attn_processor in transformer.attn_processors.items(): attn_processor.bank_kv.clear() # Define the Gradio interface @spaces.GPU() def single_condition_generate_image(user_prompt, spatial_img, height, width, seed): # Combine the system prompt with user prompt full_prompt = f"{SYSTEM_PROMPT}, {user_prompt}" if user_prompt else SYSTEM_PROMPT # Set the Ghibli LoRA lora_path = os.path.join(lora_base_path, "Ghibli.safetensors") set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512) # Process the image spatial_imgs = [spatial_img] if spatial_img else [] image = pipe( full_prompt, height=int(height), width=int(width), guidance_scale=3.5, num_inference_steps=25, max_sequence_length=512, generator=torch.Generator("cpu").manual_seed(seed), subject_images=[], spatial_images=spatial_imgs, cond_size=512, ).images[0] clear_cache(pipe.transformer) return image # Load example images def load_examples(): examples = [] test_img_dir = "./test_imgs" example_prompts = [ "a cat sitting by the window", "a peaceful mountain village", "a young girl with flowers in her hair", "a magical forest with spirits", "a flying castle in the clouds", "a serene river with boats", "a cozy cottage in the countryside", "a bustling market in a small town" ] for i, filename in enumerate(["00.png", "02.png", "03.png", "04.png", "06.png", "07.png", "08.png", "09.png"]): img_path = os.path.join(test_img_dir, filename) if os.path.exists(img_path): # Use dimensions from original code for each specific example if filename == "00.png": height, width = 680, 1024 elif filename == "02.png": height, width = 560, 1024 elif filename == "03.png": height, width = 568, 1024 elif filename == "04.png": height, width = 768, 672 elif filename == "06.png": height, width = 896, 1024 elif filename == "07.png": height, width = 528, 800 elif filename == "08.png": height, width = 696, 1024 elif filename == "09.png": height, width = 896, 1024 else: height, width = 768, 768 examples.append([ example_prompts[i % len(example_prompts)], # User prompt (without system prompt) Image.open(img_path), # Reference image height, # Height width, # Width i + 1 # Seed ]) return examples # CSS for improved UI css = """ :root { --primary-color: #4a6670; --accent-color: #ff8a65; --background-color: #f5f5f5; --card-background: #ffffff; --text-color: #333333; --border-radius: 10px; --shadow: 0 4px 6px rgba(0,0,0,0.1); } body { background-color: var(--background-color); color: var(--text-color); font-family: 'Helvetica Neue', Arial, sans-serif; } .container { max-width: 1200px; margin: 0 auto; padding: 20px; } .gr-header { background: linear-gradient(135deg, #668796 0%, #4a6670 100%); padding: 24px; border-radius: var(--border-radius); margin-bottom: 24px; box-shadow: var(--shadow); text-align: center; } .gr-header h1 { color: white; font-size: 2.5rem; margin: 0; font-weight: 700; } .gr-header p { color: rgba(255, 255, 255, 0.9); font-size: 1.1rem; margin-top: 8px; } .gr-panel { background-color: var(--card-background); border-radius: var(--border-radius); padding: 16px; box-shadow: var(--shadow); } .gr-button { background-color: var(--accent-color); border: none; color: white; padding: 10px 20px; border-radius: 5px; font-size: 16px; font-weight: bold; cursor: pointer; transition: transform 0.1s, background-color 0.3s; } .gr-button:hover { background-color: #ff7043; transform: translateY(-2px); } .gr-input, .gr-select { border-radius: 5px; border: 1px solid #ddd; padding: 10px; width: 100%; } .gr-form { display: grid; gap: 16px; } .gr-box { background-color: var(--card-background); border-radius: var(--border-radius); padding: 20px; box-shadow: var(--shadow); margin-bottom: 20px; } .gr-gallery { display: grid; grid-template-columns: repeat(auto-fill, minmax(200px, 1fr)); gap: 16px; } .gr-gallery-item { overflow: hidden; border-radius: var(--border-radius); box-shadow: var(--shadow); transition: transform 0.3s; } .gr-gallery-item:hover { transform: scale(1.02); } .gr-image { width: 100%; height: auto; object-fit: cover; } .gr-footer { text-align: center; margin-top: 40px; padding: 20px; color: #666; font-size: 14px; } .gr-examples-gallery { margin-top: 20px; } /* Responsive adjustments */ @media (max-width: 768px) { .gr-header h1 { font-size: 1.8rem; } .gr-panel { padding: 12px; } } /* Ghibli-inspired accent colors */ .gr-accent-1 { background-color: #95ccd9; } .gr-accent-2 { background-color: #74ad8c; } .gr-accent-3 { background-color: #f9c06b; } """ # Create the Gradio Blocks interface with gr.Blocks(css=css) as demo: gr.HTML("""
Transform your ideas into magical Ghibli-inspired artwork
Describe what you want to see in your Ghibli-inspired image
Your Ghibli-inspired artwork will appear here
Click on any example to try it out