multimodalart HF Staff committed on
Commit
faa0ef1
·
verified ·
1 Parent(s): 9acd353

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -3
app.py CHANGED
@@ -113,6 +113,22 @@ def grounded_segmentation(
113
  return np.array(image), detections
114
 
115
  def segment_image(image, object_name, detector, segmentator, seg_processor):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  image_array, detections = grounded_segmentation(detector, segmentator, seg_processor, image, [object_name])
117
  if not detections or detections[0].mask is None:
118
  raise gr.Error(f"Could not segment the subject '{object_name}' in the image. Please try a clearer image or a more specific subject name.")
@@ -122,6 +138,15 @@ def segment_image(image, object_name, detector, segmentator, seg_processor):
122
  return Image.fromarray(segment_result.astype(np.uint8))
123
 
124
  def make_diptych(image):
 
 
 
 
 
 
 
 
 
125
  ref_image_np = np.array(image)
126
  diptych_np = np.concatenate([ref_image_np, np.zeros_like(ref_image_np)], axis=1)
127
  return Image.fromarray(diptych_np)
@@ -227,6 +252,29 @@ def get_duration(
227
  randomize_seed: bool,
228
  progress=gr.Progress(track_tqdm=True)
229
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  if width > 768 or height > 768:
231
  return 210
232
  else:
@@ -250,6 +298,37 @@ def run_diptych_prompting(
250
  randomize_seed: bool,
251
  progress=gr.Progress(track_tqdm=True)
252
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  if randomize_seed:
254
  actual_seed = random.randint(0, 9223372036854775807)
255
  else:
@@ -362,14 +441,33 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
362
  # --- UI Event Handlers ---
363
 
364
  def toggle_mode_visibility(mode_choice):
365
- """Hides/shows the relevant input textboxes based on mode."""
 
 
 
 
 
 
 
 
366
  if mode_choice == "Subject-Driven":
367
  return gr.update(visible=True), gr.update(visible=False)
368
  else:
369
  return gr.update(visible=False), gr.update(visible=True)
370
 
371
  def update_derived_fields(mode_choice, subject, style_desc, target):
372
- """Updates the full prompt and segmentation checkbox based on other inputs."""
 
 
 
 
 
 
 
 
 
 
 
373
  if mode_choice == "Subject-Driven":
374
  prompt = f"A diptych with two side-by-side images of same {subject}. On the left, a photo of {subject}. On the right, replicate this {subject} exactly but as {target}"
375
  return gr.update(value=prompt), gr.update(value=True)
@@ -406,6 +504,17 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
406
  outputs=[output_image, processed_ref_image, full_diptych_image, final_prompt_used, seed]
407
  )
408
  def run_subject_driven_example(input_image, subject_name, target_prompt):
 
 
 
 
 
 
 
 
 
 
 
409
  # Construct the full prompt for subject-driven mode
410
  full_prompt = f"A diptych with two side-by-side images of same {subject_name}. On the left, a photo of {subject_name}. On the right, replicate this {subject_name} exactly but as {target_prompt}"
411
 
@@ -439,4 +548,4 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
439
  )
440
 
441
  if __name__ == "__main__":
442
- demo.launch(share=True, debug=True)
 
113
  return np.array(image), detections
114
 
115
  def segment_image(image, object_name, detector, segmentator, seg_processor):
116
+ """
117
+ Segments a specific object from an image and returns the segmented object on a white background.
118
+
119
+ Args:
120
+ image (PIL.Image.Image): The input image.
121
+ object_name (str): The name of the object to segment.
122
+ detector: The object detection pipeline.
123
+ segmentator: The mask generation model.
124
+ seg_processor: The processor for the mask generation model.
125
+
126
+ Returns:
127
+ PIL.Image.Image: The image with the segmented object on a white background.
128
+
129
+ Raises:
130
+ gr.Error: If the object cannot be segmented.
131
+ """
132
  image_array, detections = grounded_segmentation(detector, segmentator, seg_processor, image, [object_name])
133
  if not detections or detections[0].mask is None:
134
  raise gr.Error(f"Could not segment the subject '{object_name}' in the image. Please try a clearer image or a more specific subject name.")
 
138
  return Image.fromarray(segment_result.astype(np.uint8))
139
 
140
def make_diptych(image):
    """
    Create a diptych by placing the input image on the left and an
    all-black panel of identical shape on the right.

    Args:
        image (PIL.Image.Image): The input image used as the left panel.

    Returns:
        PIL.Image.Image: The side-by-side diptych image, twice as wide
        as the input.
    """
    left_panel = np.array(image)
    # Right half starts black; the generation pipeline fills it in later.
    right_panel = np.zeros_like(left_panel)
    return Image.fromarray(np.concatenate([left_panel, right_panel], axis=1))
 
252
  randomize_seed: bool,
253
  progress=gr.Progress(track_tqdm=True)
254
  ):
255
+ """
256
+ Calculates the estimated duration for the Spaces GPU based on image dimensions.
257
+
258
+ Args:
259
+ input_image (PIL.Image.Image): The input reference image.
260
+ subject_name (str): Name of the subject for segmentation.
261
+ do_segmentation (bool): Whether to perform segmentation.
262
+ full_prompt (str): The full text prompt.
263
+ attn_enforce (float): Attention enforcement value.
264
+ ctrl_scale (float): ControlNet conditioning scale.
265
+ width (int): Target width of the generated image.
266
+ height (int): Target height of the generated image.
267
+ pixel_offset (int): Padding offset in pixels.
268
+ num_steps (int): Number of inference steps.
269
+ guidance (float): Distilled guidance scale.
270
+ real_guidance (float): Real guidance scale.
271
+ seed (int): Random seed.
272
+ randomize_seed (bool): Whether to randomize the seed.
273
+ progress (gr.Progress): Gradio progress tracker.
274
+
275
+ Returns:
276
+ int: Estimated duration in seconds.
277
+ """
278
  if width > 768 or height > 768:
279
  return 210
280
  else:
 
298
  randomize_seed: bool,
299
  progress=gr.Progress(track_tqdm=True)
300
  ):
301
+ """
302
+ Runs the diptych prompting image generation process.
303
+
304
+ Args:
305
+ input_image (PIL.Image.Image): The input reference image.
306
+ subject_name (str): The name of the subject for segmentation (if `do_segmentation` is True).
307
+ do_segmentation (bool): If True, the subject will be segmented from the reference image.
308
+ full_prompt (str): The complete text prompt used for image generation.
309
+ attn_enforce (float): Controls the attention enforcement in the custom attention processor.
310
+ ctrl_scale (float): The conditioning scale for ControlNet.
311
+ width (int): The desired width of the final generated image.
312
+ height (int): The desired height of the final generated image.
313
+ pixel_offset (int): Padding added around the image during diptych creation.
314
+ num_steps (int): The number of inference steps for the diffusion process.
315
+ guidance (float): The distilled guidance scale for the diffusion process.
316
+ real_guidance (float): The real guidance scale for the diffusion process.
317
+ seed (int): The random seed for reproducibility.
318
+ randomize_seed (bool): If True, a random seed will be used instead of the provided `seed`.
319
+ progress (gr.Progress): Gradio progress tracker to update UI during execution.
320
+
321
+ Returns:
322
+ tuple: A tuple containing:
323
+ - PIL.Image.Image: The final generated image.
324
+ - PIL.Image.Image: The processed reference image (left panel of the diptych).
325
+ - PIL.Image.Image: The full diptych image generated by the pipeline.
326
+ - str: The final prompt used.
327
+ - int: The actual seed used for generation.
328
+
329
+ Raises:
330
+ gr.Error: If a reference image is not uploaded, prompts are empty, or segmentation fails.
331
+ """
332
  if randomize_seed:
333
  actual_seed = random.randint(0, 9223372036854775807)
334
  else:
 
441
  # --- UI Event Handlers ---
442
 
443
def toggle_mode_visibility(mode_choice):
    """
    Show the input group for the selected generation mode and hide the other.

    Args:
        mode_choice (str): The selected generation mode
            ("Subject-Driven" or "Style-Driven").

    Returns:
        tuple: Gradio update objects — (subject-driven group visibility,
        style-driven group visibility).
    """
    subject_mode = mode_choice == "Subject-Driven"
    return gr.update(visible=subject_mode), gr.update(visible=not subject_mode)
457
 
458
  def update_derived_fields(mode_choice, subject, style_desc, target):
459
+ """
460
+ Updates the full prompt and segmentation checkbox based on other inputs.
461
+
462
+ Args:
463
+ mode_choice (str): The selected generation mode ("Subject-Driven" or "Style-Driven").
464
+ subject (str): The subject name (relevant for "Subject-Driven" mode).
465
+ style_desc (str): The original style description (relevant for "Style-Driven" mode).
466
+ target (str): The target prompt.
467
+
468
+ Returns:
469
+ tuple: Gradio update objects for `full_prompt` value and `do_segmentation` checkbox value.
470
+ """
471
  if mode_choice == "Subject-Driven":
472
  prompt = f"A diptych with two side-by-side images of same {subject}. On the left, a photo of {subject}. On the right, replicate this {subject} exactly but as {target}"
473
  return gr.update(value=prompt), gr.update(value=True)
 
504
  outputs=[output_image, processed_ref_image, full_diptych_image, final_prompt_used, seed]
505
  )
506
  def run_subject_driven_example(input_image, subject_name, target_prompt):
507
+ """
508
+ Helper function to run an example for the subject-driven mode.
509
+
510
+ Args:
511
+ input_image (PIL.Image.Image): The input reference image for the example.
512
+ subject_name (str): The subject name for the example.
513
+ target_prompt (str): The target prompt for the example.
514
+
515
+ Returns:
516
+ tuple: The outputs from `run_diptych_prompting`.
517
+ """
518
  # Construct the full prompt for subject-driven mode
519
  full_prompt = f"A diptych with two side-by-side images of same {subject_name}. On the left, a photo of {subject_name}. On the right, replicate this {subject_name} exactly but as {target_prompt}"
520
 
 
548
  )
549
 
550
if __name__ == "__main__":
    # Launch the Gradio app; mcp_server=True presumably also exposes the
    # app's tools over the Model Context Protocol — confirm against the
    # installed Gradio version's launch() signature.
    demo.launch(share=True, debug=True, mcp_server=True)