Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -113,6 +113,22 @@ def grounded_segmentation(
     return np.array(image), detections
 
 def segment_image(image, object_name, detector, segmentator, seg_processor):
+    """
+    Segments a specific object from an image and returns the segmented object on a white background.
+
+    Args:
+        image (PIL.Image.Image): The input image.
+        object_name (str): The name of the object to segment.
+        detector: The object detection pipeline.
+        segmentator: The mask generation model.
+        seg_processor: The processor for the mask generation model.
+
+    Returns:
+        PIL.Image.Image: The image with the segmented object on a white background.
+
+    Raises:
+        gr.Error: If the object cannot be segmented.
+    """
     image_array, detections = grounded_segmentation(detector, segmentator, seg_processor, image, [object_name])
     if not detections or detections[0].mask is None:
         raise gr.Error(f"Could not segment the subject '{object_name}' in the image. Please try a clearer image or a more specific subject name.")
@@ -122,6 +138,15 @@ def segment_image(image, object_name, detector, segmentator, seg_processor):
     return Image.fromarray(segment_result.astype(np.uint8))
 
 def make_diptych(image):
+    """
+    Creates a diptych image by concatenating the input image with a black image of the same size.
+
+    Args:
+        image (PIL.Image.Image): The input image.
+
+    Returns:
+        PIL.Image.Image: The diptych image.
+    """
     ref_image_np = np.array(image)
     diptych_np = np.concatenate([ref_image_np, np.zeros_like(ref_image_np)], axis=1)
     return Image.fromarray(diptych_np)
@@ -227,6 +252,29 @@ def get_duration(
     randomize_seed: bool,
     progress=gr.Progress(track_tqdm=True)
 ):
+    """
+    Calculates the estimated duration for the Spaces GPU based on image dimensions.
+
+    Args:
+        input_image (PIL.Image.Image): The input reference image.
+        subject_name (str): Name of the subject for segmentation.
+        do_segmentation (bool): Whether to perform segmentation.
+        full_prompt (str): The full text prompt.
+        attn_enforce (float): Attention enforcement value.
+        ctrl_scale (float): ControlNet conditioning scale.
+        width (int): Target width of the generated image.
+        height (int): Target height of the generated image.
+        pixel_offset (int): Padding offset in pixels.
+        num_steps (int): Number of inference steps.
+        guidance (float): Distilled guidance scale.
+        real_guidance (float): Real guidance scale.
+        seed (int): Random seed.
+        randomize_seed (bool): Whether to randomize the seed.
+        progress (gr.Progress): Gradio progress tracker.
+
+    Returns:
+        int: Estimated duration in seconds.
+    """
     if width > 768 or height > 768:
         return 210
     else:
@@ -250,6 +298,37 @@ def run_diptych_prompting(
     randomize_seed: bool,
     progress=gr.Progress(track_tqdm=True)
 ):
+    """
+    Runs the diptych prompting image generation process.
+
+    Args:
+        input_image (PIL.Image.Image): The input reference image.
+        subject_name (str): The name of the subject for segmentation (if `do_segmentation` is True).
+        do_segmentation (bool): If True, the subject will be segmented from the reference image.
+        full_prompt (str): The complete text prompt used for image generation.
+        attn_enforce (float): Controls the attention enforcement in the custom attention processor.
+        ctrl_scale (float): The conditioning scale for ControlNet.
+        width (int): The desired width of the final generated image.
+        height (int): The desired height of the final generated image.
+        pixel_offset (int): Padding added around the image during diptych creation.
+        num_steps (int): The number of inference steps for the diffusion process.
+        guidance (float): The distilled guidance scale for the diffusion process.
+        real_guidance (float): The real guidance scale for the diffusion process.
+        seed (int): The random seed for reproducibility.
+        randomize_seed (bool): If True, a random seed will be used instead of the provided `seed`.
+        progress (gr.Progress): Gradio progress tracker to update UI during execution.
+
+    Returns:
+        tuple: A tuple containing:
+            - PIL.Image.Image: The final generated image.
+            - PIL.Image.Image: The processed reference image (left panel of the diptych).
+            - PIL.Image.Image: The full diptych image generated by the pipeline.
+            - str: The final prompt used.
+            - int: The actual seed used for generation.
+
+    Raises:
+        gr.Error: If a reference image is not uploaded, prompts are empty, or segmentation fails.
+    """
     if randomize_seed:
         actual_seed = random.randint(0, 9223372036854775807)
     else:
@@ -362,14 +441,33 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
     # --- UI Event Handlers ---
 
     def toggle_mode_visibility(mode_choice):
-        """
+        """
+        Hides/shows the relevant input textboxes based on the selected mode.
+
+        Args:
+            mode_choice (str): The selected generation mode ("Subject-Driven" or "Style-Driven").
+
+        Returns:
+            tuple: Gradio update objects for `subject_driven_group` and `style_driven_group` visibility.
+        """
         if mode_choice == "Subject-Driven":
             return gr.update(visible=True), gr.update(visible=False)
         else:
             return gr.update(visible=False), gr.update(visible=True)
 
     def update_derived_fields(mode_choice, subject, style_desc, target):
-        """
+        """
+        Updates the full prompt and segmentation checkbox based on other inputs.
+
+        Args:
+            mode_choice (str): The selected generation mode ("Subject-Driven" or "Style-Driven").
+            subject (str): The subject name (relevant for "Subject-Driven" mode).
+            style_desc (str): The original style description (relevant for "Style-Driven" mode).
+            target (str): The target prompt.
+
+        Returns:
+            tuple: Gradio update objects for `full_prompt` value and `do_segmentation` checkbox value.
+        """
         if mode_choice == "Subject-Driven":
             prompt = f"A diptych with two side-by-side images of same {subject}. On the left, a photo of {subject}. On the right, replicate this {subject} exactly but as {target}"
             return gr.update(value=prompt), gr.update(value=True)
@@ -406,6 +504,17 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
         outputs=[output_image, processed_ref_image, full_diptych_image, final_prompt_used, seed]
     )
     def run_subject_driven_example(input_image, subject_name, target_prompt):
+        """
+        Helper function to run an example for the subject-driven mode.
+
+        Args:
+            input_image (PIL.Image.Image): The input reference image for the example.
+            subject_name (str): The subject name for the example.
+            target_prompt (str): The target prompt for the example.
+
+        Returns:
+            tuple: The outputs from `run_diptych_prompting`.
+        """
         # Construct the full prompt for subject-driven mode
         full_prompt = f"A diptych with two side-by-side images of same {subject_name}. On the left, a photo of {subject_name}. On the right, replicate this {subject_name} exactly but as {target_prompt}"
 
@@ -439,4 +548,4 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
     )
 
 if __name__ == "__main__":
-    demo.launch(share=True, debug=True)
+    demo.launch(share=True, debug=True, mcp_server=True)
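The functional change in this commit is the `mcp_server=True` flag added to `demo.launch()`. In recent Gradio releases this flag starts a Model Context Protocol (MCP) server alongside the web UI and exposes the app's API endpoints as MCP tools; Gradio derives each tool's description from the endpoint function's docstring and type hints, which is likely why the commit also adds docstrings throughout app.py. A minimal, self-contained sketch (assuming a Gradio version with MCP support and the `gradio[mcp]` extra installed), not this Space's actual interface:

```python
# Minimal sketch: launching a Gradio app with the MCP server enabled.
# Assumes a recent Gradio release with MCP support: pip install "gradio[mcp]"
import gradio as gr

def echo(text: str) -> str:
    """Return the input text unchanged (the docstring becomes the MCP tool description)."""
    return text

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    inp.submit(echo, inputs=inp, outputs=out)

if __name__ == "__main__":
    # mcp_server=True serves MCP alongside the normal web UI.
    demo.launch(mcp_server=True)
```

On startup Gradio prints the MCP endpoint URL (typically under the app's `/gradio_api/mcp/` path), which MCP clients can then register.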
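The new `get_duration` docstring describes estimating the ZeroGPU slot length from the requested resolution (the Space runs on Zero). The hunks do not show how the function is wired up; the usual ZeroGPU pattern is to pass such a callable as the `duration` argument of the `spaces.GPU` decorator on the heavy function. A sketch under that assumption, with a hypothetical `generate` worker and an illustrative fallback value (the real else-branch value and parameter order are outside this diff):

```python
# Sketch only: dynamic ZeroGPU duration via a callable, not this app's actual wiring.
import spaces  # Hugging Face Spaces SDK (pip install spaces)

def get_duration(width: int, height: int, *args, **kwargs) -> int:
    # Mirrors the branch visible in the hunk above; 120 is an illustrative
    # fallback, since the real else-branch value is outside the diff context.
    return 210 if width > 768 or height > 768 else 120

@spaces.GPU(duration=get_duration)  # duration accepts an int or a callable given the call's arguments
def generate(width: int, height: int, *args, **kwargs):
    ...  # heavy diffusion work runs inside the allocated GPU window
```

Passing a callable instead of a fixed number lets small requests reserve a shorter GPU window, which keeps queue times down on shared ZeroGPU hardware.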