#!/usr/bin/env python3
# Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the LICENSE file
# located at the root directory.
import os
import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
import tempfile
import gc
from datetime import datetime

from sam2.sam2_image_predictor import SAM2ImagePredictor
from addit_flux_pipeline import AdditFluxPipeline
from addit_flux_transformer import AdditFluxTransformer2DModel
from addit_scheduler import AdditFlowMatchEulerDiscreteScheduler
from addit_methods import add_object_generated, add_object_real
# Global variables for the model
pipe = None
device = None
original_image_size = None
# Initialize the model at startup
print("Initializing ADDIT model...")
try:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load transformer
    my_transformer = AdditFluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        torch_dtype=torch.bfloat16
    )

    # Load pipeline
    pipe = AdditFluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=my_transformer,
        torch_dtype=torch.bfloat16
    ).to(device)

    # Set scheduler
    pipe.scheduler = AdditFlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
    print("Model initialized successfully!")
print("Initialization SAM model:") | |
sam = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large") | |
except Exception as e: | |
print(f"Error initializing model: {str(e)}") | |
print("The application will start but model functionality will be unavailable.") | |

def validate_inputs(prompt_source, prompt_target, subject_token):
    """Validate user inputs."""
    if not prompt_source.strip():
        return "Source prompt cannot be empty"
    if not prompt_target.strip():
        return "Target prompt cannot be empty"
    if not subject_token.strip():
        return "Subject token cannot be empty"
    if subject_token not in prompt_target:
        return f"Subject token '{subject_token}' must appear in the target prompt"
    return None
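
# Example: validate_inputs("A cat", "A cat wearing a hat", "dog") returns
# "Subject token 'dog' must appear in the target prompt"; valid inputs return None.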

def resize_and_crop_image(image):
    """
    Resize and center crop an image to 1024x1024.
    Returns the processed image, a message about what was done, and the original size.
    """
    if image is None:
        return None, "", None

    original_width, original_height = image.size
    original_size = (original_width, original_height)

    # If already 1024x1024, no processing is needed
    if original_width == 1024 and original_height == 1024:
        return image, "", original_size

    # Scale so that the smaller dimension becomes 1024
    scale = 1024 / min(original_width, original_height)
    new_width = int(original_width * scale)
    new_height = int(original_height * scale)

    # Resize image
    resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

    # Center crop to 1024x1024
    left = (new_width - 1024) // 2
    top = (new_height - 1024) // 2
    right = left + 1024
    bottom = top + 1024
    cropped_image = resized_image.crop((left, top, right, bottom))

    # Create status message
    if new_width == 1024 and new_height == 1024:
        message = "<div style='background-color: #e8f5e8; border: 1px solid #4caf50; border-radius: 5px; padding: 8px; margin-bottom: 10px;'><span style='color: #2e7d32; font-weight: bold;'>✅ Image resized to 1024×1024</span></div>"
    else:
        message = "<div style='background-color: #e8f5e8; border: 1px solid #4caf50; border-radius: 5px; padding: 8px; margin-bottom: 10px;'><span style='color: #2e7d32; font-weight: bold;'>✅ Image resized and center cropped to 1024×1024</span></div>"

    return cropped_image, message, original_size
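
# Worked example (hypothetical upload size): a 1600x1200 image scales by
# 1024/1200 ≈ 0.853 to 1365x1024, then center crops with left = (1365-1024)//2 = 170,
# i.e. crop box (170, 0, 1194, 1024).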

def handle_image_upload(image):
    """
    Handle image upload and preprocessing for the Gradio interface.

    This function is called when a user uploads an image to the real images tab.
    It stores the original image size globally and processes the image to the
    required dimensions.

    Args:
        image: PIL.Image object uploaded by the user, or None if no image is uploaded.

    Returns:
        Tuple containing:
        - processed_image: PIL.Image object resized and cropped to 1024x1024, or None if no image.
        - message: HTML-formatted string indicating the processing status, or empty string.
    """
    global original_image_size

    if image is None:
        original_image_size = None
        return None, ""

    # Store original size
    original_image_size = image.size

    # Process image
    processed_image, message, _ = resize_and_crop_image(image)
    return processed_image, message

# ZeroGPU entry point: `spaces` is imported above and the Space runs on Zero,
# so the GPU-bound handler is assumed to carry this decorator.
@spaces.GPU
def process_generated_image(
    prompt_source,
    prompt_target,
    subject_token,
    seed_src,
    seed_obj,
    extended_scale,
    structure_transfer_step,
    blend_steps,
    localization_model,
    progress=gr.Progress(track_tqdm=True)
):
""" | |
Process and generate images using ADDIT for the generated images workflow. | |
This function generates a source image from a text prompt and then adds an object to it | |
based on the target prompt and subject token using the ADDIT pipeline. | |
Args: | |
prompt_source: String describing the source scene without the object to be added. | |
prompt_target: String describing the target scene including the object to be added. | |
subject_token: String token representing the object to add (must appear in target prompt). | |
seed_src: Integer seed for generating the source image. | |
seed_obj: Integer seed for generating the object. | |
extended_scale: Float value (1.0-1.3) controlling the extended attention scale. | |
structure_transfer_step: Integer (0-10) controlling structure transfer strength. | |
blend_steps: String of comma-separated integers for blending steps, or empty string. | |
localization_model: String specifying the localization model to use. | |
progress: Gradio progress tracker for displaying progress updates. | |
Returns: | |
Tuple containing: | |
- src_image: PIL.Image of the generated source image, or None if error. | |
- edited_image: PIL.Image with the added object, or None if error. | |
- status_message: String describing the result or error message. | |
""" | |
global pipe | |
if pipe is None: | |
return None, None, "Model not initialized. Please restart the application." | |
# Validate inputs | |
error_msg = validate_inputs(prompt_source, prompt_target, subject_token) | |
if error_msg: | |
return None, None, error_msg | |
# Print current time and input information | |
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") | |
print(f"\n[{current_time}] Starting Generated Image Processing") | |
print(f"Source Prompt: '{prompt_source}'") | |
print(f"Target Prompt: '{prompt_target}'") | |
print(f"Subject Token: '{subject_token}'") | |
print(f"Source Seed: {seed_src}, Object Seed: {seed_obj}") | |
print(f"Extended Scale: {extended_scale}, Structure Transfer Step: {structure_transfer_step}") | |
print(f"Blend Steps: '{blend_steps}', Localization Model: '{localization_model}'") | |
try: | |
# Parse blend steps | |
if blend_steps.strip(): | |
blend_steps_list = [int(x.strip()) for x in blend_steps.split(',') if x.strip()] | |
else: | |
blend_steps_list = [] | |
# Generate images | |
src_image, edited_image = add_object_generated( | |
pipe=pipe, | |
prompt_source=prompt_source, | |
prompt_object=prompt_target, | |
subject_token=subject_token, | |
seed_src=seed_src, | |
seed_obj=seed_obj, | |
show_attention=False, | |
extended_scale=extended_scale, | |
structure_transfer_step=structure_transfer_step, | |
blend_steps=blend_steps_list, | |
localization_model=localization_model, | |
display_output=False | |
) | |
return src_image, edited_image, "Images generated successfully!" | |
except Exception as e: | |
error_msg = f"Error generating images: {str(e)}" | |
print(error_msg) | |
return None, None, error_msg | |
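
# Programmatic use (hypothetical call mirroring the UI defaults; runs the full
# FLUX pipeline, so it needs a GPU):
#   src, edited, status = process_generated_image(
#       "A photo of a cat sitting on the couch",
#       "A photo of a cat wearing a blue hat sitting on the couch",
#       "hat", seed_src=1, seed_obj=42, extended_scale=1.05,
#       structure_transfer_step=2, blend_steps="15",
#       localization_model="attention_points_sam")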

# Same assumed ZeroGPU decorator as process_generated_image.
@spaces.GPU
def process_real_image(
    source_image,
    prompt_source,
    prompt_target,
    subject_token,
    seed_src,
    seed_obj,
    extended_scale,
    structure_transfer_step,
    blend_steps,
    localization_model,
    use_offset,
    disable_inversion,
    progress=gr.Progress(track_tqdm=True)
):
""" | |
Process and edit a real uploaded image using ADDIT to add objects. | |
This function takes an uploaded image and adds an object to it based on the target prompt | |
and subject token using the ADDIT pipeline with optional inversion and offset techniques. | |
Args: | |
source_image: PIL.Image object of the uploaded source image to edit. | |
prompt_source: String describing the source image content. | |
prompt_target: String describing the desired result including the object to add. | |
subject_token: String token representing the object to add (must appear in target prompt). | |
seed_src: Integer seed for source image processing. | |
seed_obj: Integer seed for object generation. | |
extended_scale: Float value (1.0-1.3) controlling the extended attention scale. | |
structure_transfer_step: Integer (0-10) controlling structure transfer strength. | |
blend_steps: String of comma-separated integers for blending steps, or empty string. | |
localization_model: String specifying the localization model to use. | |
use_offset: Boolean indicating whether to use offset technique. | |
disable_inversion: Boolean indicating whether to disable DDIM inversion. | |
progress: Gradio progress tracker for displaying progress updates. | |
Returns: | |
Tuple containing: | |
- src_image: PIL.Image of the processed source image, or None if error. | |
- edited_image: PIL.Image with the added object, or None if error. | |
- status_message: String describing the result or error message. | |
""" | |
global pipe | |
if pipe is None: | |
return None, None, "Model not initialized. Please restart the application." | |
if source_image is None: | |
return None, None, "Please upload a source image" | |
# Validate inputs | |
error_msg = validate_inputs(prompt_source, prompt_target, subject_token) | |
if error_msg: | |
return None, None, error_msg | |
# Print current time and input information | |
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") | |
print(f"\n[{current_time}] Starting Real Image Processing") | |
if original_image_size: | |
print(f"Original uploaded image size: {original_image_size[0]}×{original_image_size[1]}") | |
print(f"Source Image Size: {source_image.size}") | |
print(f"Source Prompt: '{prompt_source}'") | |
print(f"Target Prompt: '{prompt_target}'") | |
print(f"Subject Token: '{subject_token}'") | |
print(f"Source Seed: {seed_src}, Object Seed: {seed_obj}") | |
print(f"Extended Scale: {extended_scale}, Structure Transfer Step: {structure_transfer_step}") | |
print(f"Blend Steps: '{blend_steps}', Localization Model: '{localization_model}'") | |
print(f"Use Offset: {use_offset}, Disable Inversion: {disable_inversion}") | |
try: | |
# Resize source image | |
source_image = source_image.resize((1024, 1024)) | |
# Parse blend steps | |
if blend_steps.strip(): | |
blend_steps_list = [int(x.strip()) for x in blend_steps.split(',') if x.strip()] | |
else: | |
blend_steps_list = [] | |
# Process image | |
src_image, edited_image = add_object_real( | |
pipe=pipe, | |
source_image=source_image, | |
prompt_source=prompt_source, | |
prompt_object=prompt_target, | |
subject_token=subject_token, | |
seed_src=seed_src, | |
seed_obj=seed_obj, | |
extended_scale=extended_scale, | |
structure_transfer_step=structure_transfer_step, | |
blend_steps=blend_steps_list, | |
localization_model=localization_model, | |
use_offset=use_offset, | |
show_attention=False, | |
use_inversion=not disable_inversion, | |
display_output=False | |
) | |
return src_image, edited_image, "Image edited successfully!" | |
except Exception as e: | |
error_msg = f"Error processing image: {str(e)}" | |
print(error_msg) | |
return None, None, error_msg | |
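
# Programmatic use (hypothetical call; assumes a local example image exists):
#   img = Image.open("images/bed_dark_room.jpg")
#   src, edited, status = process_real_image(
#       img, "A photo of a bed in a dark room",
#       "A photo of a dog lying on a bed in a dark room", "dog",
#       seed_src=1, seed_obj=0, extended_scale=1.1, structure_transfer_step=4,
#       blend_steps="18", localization_model="attention",
#       use_offset=False, disable_inversion=False)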

def create_interface():
    """Create the Gradio interface."""
    # Show model status in the interface
    model_status = "Model ready!" if pipe is not None else "Model initialization failed - functionality unavailable"

    with gr.Blocks(title="🎨 Add-it: Training-Free Object Insertion in Images With Pretrained Diffusion Models", theme=gr.themes.Soft()) as demo:
        gr.HTML(f"""
        <div style="text-align: center; margin-bottom: 20px;">
            <h1>🎨 Add-it: Training-Free Object Insertion</h1>
            <p>Add objects to images using pretrained diffusion models</p>
            <p><a href="https://research.nvidia.com/labs/par/addit/" target="_blank">🌐 Project Website</a> |
               <a href="https://arxiv.org/abs/2411.07232" target="_blank">📄 Paper</a> |
               <a href="https://github.com/NVlabs/addit" target="_blank">💻 Code</a></p>
            <p style="color: {'green' if pipe is not None else 'red'}; font-weight: bold;">Status: {model_status}</p>
        </div>
        """)

        # Main interface
        with gr.Tabs():
            # Generated Images Tab
            with gr.TabItem("🎭 Generated Images"):
                gr.Markdown("### Generate a base image and add objects to it")

                with gr.Row():
                    with gr.Column(scale=1):
                        gen_prompt_source = gr.Textbox(
                            label="Source Prompt",
                            placeholder="A photo of a cat sitting on the couch",
                            value="A photo of a cat sitting on the couch"
                        )
                        gen_prompt_target = gr.Textbox(
                            label="Target Prompt",
                            placeholder="A photo of a cat wearing a blue hat sitting on the couch",
                            value="A photo of a cat wearing a blue hat sitting on the couch"
                        )
                        gen_subject_token = gr.Textbox(
                            label="Subject Token",
                            placeholder="hat",
                            value="hat",
                            info="Single token representing the object to add **(must appear in target prompt)**"
                        )

                        with gr.Accordion("Advanced Settings", open=False):
                            gen_seed_src = gr.Number(label="Source Seed", value=1, precision=0)
                            gen_seed_obj = gr.Number(label="Object Seed", value=42, precision=0)
                            gen_extended_scale = gr.Slider(
                                label="Extended Scale",
                                minimum=1.0,
                                maximum=1.3,
                                value=1.05,
                                step=0.01
                            )
                            gen_structure_transfer_step = gr.Slider(
                                label="Structure Transfer Step",
                                minimum=0,
                                maximum=10,
                                value=2,
                                step=1
                            )
                            gen_blend_steps = gr.Textbox(
                                label="Blend Steps",
                                value="15",
                                info="Comma-separated list of steps (e.g., '15,20') or empty for no blending"
                            )
                            gen_localization_model = gr.Dropdown(
                                label="Localization Model",
                                choices=[
                                    "attention_points_sam",
                                    "attention",
                                    "attention_box_sam",
                                    "attention_mask_sam",
                                    "grounding_sam"
                                ],
                                value="attention_points_sam"
                            )

                        gen_submit_btn = gr.Button("🎨 Generate & Edit", variant="primary")

                    with gr.Column(scale=2):
                        with gr.Row():
                            gen_src_output = gr.Image(label="Generated Source Image", type="pil")
                            gen_edited_output = gr.Image(label="Edited Image", type="pil")
                        gen_status = gr.Textbox(label="Status", interactive=False)

                gen_submit_btn.click(
                    fn=process_generated_image,
                    inputs=[
                        gen_prompt_source, gen_prompt_target, gen_subject_token,
                        gen_seed_src, gen_seed_obj, gen_extended_scale,
                        gen_structure_transfer_step, gen_blend_steps,
                        gen_localization_model
                    ],
                    outputs=[gen_src_output, gen_edited_output, gen_status]
                )

                # Examples for generated images
                gr.Examples(
                    examples=[
                        ["An empty throne", "A king sitting on a throne", "king"],
                        ["A photo of a man sitting on a bench", "A photo of a man sitting on a bench with a dog", "dog"],
                        ["A photo of a cat sitting on the couch", "A photo of a cat wearing a blue hat sitting on the couch", "hat"],
                        ["A car driving through an empty street", "A pink car driving through an empty street", "car"]
                    ],
                    inputs=[
                        gen_prompt_source, gen_prompt_target, gen_subject_token
                    ],
                    label="Example Prompts"
                )

            # Real Images Tab
            with gr.TabItem("📸 Real Images"):
                gr.Markdown("### Upload an image and add objects to it")
                gr.HTML("<p style='color: orange; font-weight: bold; margin: -15px -10px;'>Note: Images will be automatically resized and center cropped to 1024×1024 pixels.</p>")

                with gr.Row():
                    with gr.Column(scale=1):
                        real_image_status = gr.HTML(visible=False)
                        real_source_image = gr.Image(label="Source Image", type="pil")
                        real_prompt_source = gr.Textbox(
                            label="Source Prompt",
                            placeholder="A photo of a bed in a dark room",
                            value="A photo of a bed in a dark room"
                        )
                        real_prompt_target = gr.Textbox(
                            label="Target Prompt",
                            placeholder="A photo of a dog lying on a bed in a dark room",
                            value="A photo of a dog lying on a bed in a dark room"
                        )
                        real_subject_token = gr.Textbox(
                            label="Subject Token",
                            placeholder="dog",
                            value="dog",
                            info="Single token representing the object to add **(must appear in target prompt)**"
                        )

                        with gr.Accordion("Advanced Settings", open=False):
                            real_seed_src = gr.Number(label="Source Seed", value=1, precision=0)
                            real_seed_obj = gr.Number(label="Object Seed", value=0, precision=0)
                            real_extended_scale = gr.Slider(
                                label="Extended Scale",
                                minimum=1.0,
                                maximum=1.3,
                                value=1.1,
                                step=0.01
                            )
                            real_structure_transfer_step = gr.Slider(
                                label="Structure Transfer Step",
                                minimum=0,
                                maximum=10,
                                value=4,
                                step=1
                            )
                            real_blend_steps = gr.Textbox(
                                label="Blend Steps",
                                value="18",
                                info="Comma-separated list of steps (e.g., '15,20') or empty for no blending"
                            )
                            real_localization_model = gr.Dropdown(
                                label="Localization Model",
                                choices=[
                                    "attention",
                                    "attention_points_sam",
                                    "attention_box_sam",
                                    "attention_mask_sam",
                                    "grounding_sam"
                                ],
                                value="attention"
                            )
                            real_use_offset = gr.Checkbox(label="Use Offset", value=False)
                            real_disable_inversion = gr.Checkbox(label="Disable Inversion", value=False)

                        real_submit_btn = gr.Button("🎨 Edit Image", variant="primary")

                    with gr.Column(scale=2):
                        with gr.Row():
                            real_src_output = gr.Image(label="Source Image", type="pil")
                            real_edited_output = gr.Image(label="Edited Image", type="pil")
                        real_status = gr.Textbox(label="Status", interactive=False)

                # Handle image upload and preprocessing
                real_source_image.upload(
                    fn=handle_image_upload,
                    inputs=[real_source_image],
                    outputs=[real_source_image, real_image_status]
                ).then(
                    fn=lambda status: gr.update(visible=bool(status.strip()), value=status),
                    inputs=[real_image_status],
                    outputs=[real_image_status]
                )

                real_submit_btn.click(
                    fn=process_real_image,
                    inputs=[
                        real_source_image, real_prompt_source, real_prompt_target, real_subject_token,
                        real_seed_src, real_seed_obj, real_extended_scale,
                        real_structure_transfer_step, real_blend_steps,
                        real_localization_model, real_use_offset,
                        real_disable_inversion
                    ],
                    outputs=[real_src_output, real_edited_output, real_status]
                )

                # Examples for real images
                gr.Examples(
                    examples=[
                        [
                            "images/bed_dark_room.jpg",
                            "A photo of a bed in a dark room",
                            "A photo of a dog lying on a bed in a dark room",
                            "dog"
                        ],
                        [
                            "images/flower.jpg",
                            "A photo of a flower",
                            "A bee standing on a flower",
                            "bee"
                        ]
                    ],
                    inputs=[
                        real_source_image, real_prompt_source, real_prompt_target, real_subject_token
                    ],
                    label="Example Images & Prompts"
                )

        # Tips
        with gr.Accordion("💡 Tips for Better Results", open=False):
            gr.Markdown("""
            - **Prompt Design**: The Target Prompt should be similar to the Source Prompt, but include a description of the new object to insert
            - **Seed Variation**: Try different values for the Object Seed - some prompts may need a few attempts to produce satisfying results
            - **Localization Models**: The most effective options are `attention_points_sam` and `attention` (the `show_attention` flag in the code can visualize localization, though it is not exposed in this UI)
            - **Object Placement Issues**: If the object is not added to the image:
                - Try **decreasing** the Structure Transfer Step
                - Try **increasing** the Extended Scale
            - **Flexibility**: To allow more flexibility in modifying the source image, leave Blend Steps empty (no blending is applied)
            """)

    return demo

demo = create_interface()

# demo.launch(
#     server_name="0.0.0.0",
#     server_port=7860,
#     share=True,
#     mcp_server=False
# )
demo.launch(mcp_server=True)
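
# If request queuing is needed (typical for long-running ZeroGPU jobs), one
# option is to enable Gradio's queue before launching, e.g.:
#   demo.queue(max_size=20).launch(mcp_server=True)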