# Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the LICENSE file
# located at the root directory.
import torch
import numpy as np
import torch.nn.functional as F
from skimage import filters
import matplotlib.pyplot as plt
from scipy.ndimage import maximum_filter, label, find_objects
def dilate_mask(latents_mask, k, latents_dtype):
    # Reshape the mask to 2D (64x64)
    mask_2d = latents_mask.view(64, 64)
    # Create a square kernel for dilation
    kernel = torch.ones(2 * k + 1, 2 * k + 1, device=mask_2d.device, dtype=mask_2d.dtype)
    # Add batch and channel dimensions to make it compatible with conv2d
    mask_4d = mask_2d.unsqueeze(0).unsqueeze(0)
    # Perform dilation using conv2d
    dilated_mask = F.conv2d(mask_4d, kernel.unsqueeze(0).unsqueeze(0), padding=k)
    # Threshold the result to get a binary mask
    dilated_mask = (dilated_mask > 0).to(mask_2d.dtype)
    # Reshape back to the original shape and convert to the desired dtype
    dilated_mask = dilated_mask.view(4096, 1).to(latents_dtype)
    return dilated_mask
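
# Usage sketch (illustrative only; the 64x64 grid / 4096-token layout is taken from the
# reshapes above, and the tensor values here are made up):
#   latents_mask = (torch.rand(4096, 1) > 0.9).float()              # flat binary latent mask
#   dilated = dilate_mask(latents_mask, k=2, latents_dtype=torch.float16)
#   assert dilated.shape == (4096, 1)                                # mask grown by k cells per side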
def clipseg_predict(model, processor, image, text, device):
    # Run CLIPSeg on the (image, text) pair, moving the inputs to the target device
    inputs = processor(text=text, images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits.unsqueeze(1)
    preds = torch.sigmoid(preds)
    # Binarize the predicted probabilities with Otsu's threshold
    otsu_thr = filters.threshold_otsu(preds.cpu().numpy())
    subject_mask = (preds > otsu_thr).float()
    return subject_mask
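
# Usage sketch (the specific CLIPSeg checkpoint is an assumption; any
# CLIPSegForImageSegmentation model/processor pair from transformers should work):
#   from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
#   processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
#   model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
#   mask = clipseg_predict(model, processor, pil_image, ["a photo of a dog"], device)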
def grounding_sam_predict(model, processor, sam_predictor, image, text, device):
    # Detect boxes for the text prompt with the grounding model
    inputs = processor(images=image, text=text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=0.4,
        text_threshold=0.3,
        target_sizes=[image.size[::-1]]
    )
    input_boxes = results[0]["boxes"].cpu().numpy()
    if input_boxes.shape[0] == 0:
        # No detection: fall back to a full-foreground mask on the latent grid
        return torch.ones((64, 64), device=device)
    # Prompt SAM with the detected boxes
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, scores, logits = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=False,
        )
    subject_mask = torch.tensor(masks[0], device=device)
    return subject_mask
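
# Setup sketch (assumptions: the repo may use different checkpoints or a different SAM
# predictor class; this is one plausible way to build the arguments expected above):
#   from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
#   from sam2.sam2_image_predictor import SAM2ImagePredictor
#   gd_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
#   gd_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny").to(device)
#   sam_predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
#   mask = grounding_sam_predict(gd_model, gd_processor, sam_predictor, pil_image, "a dog.", device)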
def mask_to_box_sam_predict(mask, sam_predictor, image, text, device):
    # PIL's Image.size is (width, height)
    W, H = image.size
    # Resize clipseg mask to image size
    mask = F.interpolate(mask.view(1, 1, mask.shape[-2], mask.shape[-1]), size=(H, W), mode='bilinear').view(H, W)
    # Convert the mask to a tight bounding box (x_min, y_min, x_max, y_max)
    mask_indices = torch.nonzero(mask)
    top_left = mask_indices.min(dim=0)[0]
    bottom_right = mask_indices.max(dim=0)[0]
    # numpy shape [1, 4]
    input_boxes = np.array([[top_left[1].item(), top_left[0].item(), bottom_right[1].item(), bottom_right[0].item()]])
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, scores, logits = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=True,
        )
    # subject_mask = torch.tensor(masks[0], device=device)
    # Union of the multimask outputs
    subject_mask = torch.tensor(np.max(masks, axis=0), device=device)
    return subject_mask, input_boxes[0]
def mask_to_mask_sam_predict(mask, sam_predictor, image, text, device):
    # SAM's low-resolution mask prompt is expected at 256x256
    H, W = (256, 256)
    # Resize clipseg mask to the mask-prompt resolution
    mask = F.interpolate(mask.view(1, 1, mask.shape[-2], mask.shape[-1]), size=(H, W), mode='bilinear').view(1, H, W)
    mask_input = mask.float().cpu().numpy()
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, scores, logits = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            mask_input=mask_input,
            multimask_output=False,
        )
    subject_mask = torch.tensor(masks[0], device=device)
    return subject_mask
def mask_to_points_sam_predict(mask, sam_predictor, image, text, device):
    # PIL's Image.size is (width, height)
    W, H = image.size
    # Resize clipseg mask to image size
    mask = F.interpolate(mask.view(1, 1, mask.shape[-2], mask.shape[-1]), size=(H, W), mode='bilinear').view(H, W)
    mask_indices = torch.nonzero(mask)
    # Randomly sample n_points foreground points from the mask;
    # torch.nonzero gives (y, x) pairs, while SAM expects points as (x, y)
    n_points = 2
    point_coords = mask_indices[torch.randperm(mask_indices.shape[0])[:n_points]][:, [1, 0]].float().cpu().numpy()
    point_labels = torch.ones((n_points,)).float().cpu().numpy()
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, scores, logits = sam_predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=False,
        )
    subject_mask = torch.tensor(masks[0], device=device)
    return subject_mask
def attention_to_points_sam_predict(subject_attention, subject_mask, sam_predictor, image, text, device):
    # PIL's Image.size is (width, height)
    W, H = image.size
    # Resize the attention map and clipseg mask to the image size
    subject_attention = F.interpolate(subject_attention.view(1, 1, subject_attention.shape[-2], subject_attention.shape[-1]), size=(H, W), mode='bilinear').view(H, W)
    subject_mask = F.interpolate(subject_mask.view(1, 1, subject_mask.shape[-2], subject_mask.shape[-1]), size=(H, W), mode='bilinear').view(H, W)
    # Get mask bounding box
    subject_mask_indices = torch.nonzero(subject_mask)
    top_left = subject_mask_indices.min(dim=0)[0]
    bottom_right = subject_mask_indices.max(dim=0)[0]
    box_width = bottom_right[1] - top_left[1]
    box_height = bottom_right[0] - top_left[0]
    # Define the number of points and minimum distance between points
    n_points = 3
    max_thr = 0.35
    max_attention = torch.max(subject_attention)
    min_distance = max(box_width, box_height) // (n_points + 1)  # Adjust this value to control spread
    # min_distance = max(min_distance, 75)
    # Initialize list to store selected points
    selected_points = []
    # Create a copy of the attention map
    remaining_attention = subject_attention.clone()
    # Greedily pick attention peaks, suppressing a window around each pick
    for _ in range(n_points):
        if remaining_attention.max() < max_thr * max_attention:
            break
        # Find the highest attention point
        point = torch.argmax(remaining_attention)
        y, x = torch.unravel_index(point, remaining_attention.shape)
        y, x = y.item(), x.item()
        # Add the point to our list as (x, y)
        selected_points.append((x, y))
        # Zero out the area around the selected point
        y_min = max(0, y - min_distance)
        y_max = min(H, y + min_distance + 1)
        x_min = max(0, x - min_distance)
        x_max = min(W, x + min_distance + 1)
        remaining_attention[y_min:y_max, x_min:x_max] = 0
    # Convert selected points to numpy array
    point_coords = np.array(selected_points)
    point_labels = np.ones(point_coords.shape[0], dtype=int)
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, scores, logits = sam_predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=False,
        )
    subject_mask = torch.tensor(masks[0], device=device)
    return subject_mask, point_coords
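
# Usage sketch (assumed inputs: `subject_attention` is an unnormalized cross-attention
# map for the subject token and `subject_mask` a binary mask, both on a low-resolution
# grid such as 64x64; the tensor names and values are illustrative):
#   attn_map = torch.rand(64, 64, device=device)
#   coarse_mask = (attn_map > attn_map.mean()).float()
#   refined_mask, points = attention_to_points_sam_predict(
#       attn_map, coarse_mask, sam_predictor, pil_image, "a dog", device)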
def sam_refine_step(mask, sam_predictor, image, device):
    # Convert the current mask to a tight bounding box and re-prompt SAM with it
    mask_indices = torch.nonzero(mask)
    top_left = mask_indices.min(dim=0)[0]
    bottom_right = mask_indices.max(dim=0)[0]
    # numpy shape [1, 4]
    input_boxes = np.array([[top_left[1].item(), top_left[0].item(), bottom_right[1].item(), bottom_right[0].item()]])
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, scores, logits = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=True,
        )
    # subject_mask = torch.tensor(masks[0], device=device)
    # Union of the multimask outputs
    subject_mask = torch.tensor(np.max(masks, axis=0), device=device)
    return subject_mask, input_boxes[0]
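
# Composition sketch (illustrative; how these helpers are wired together is up to the
# caller): start from a coarse 2D mask, convert it to a box prompt for SAM, then
# optionally tighten the result once more with sam_refine_step:
#   refined, box = mask_to_box_sam_predict(coarse_mask, sam_predictor, pil_image, "a dog", device)
#   refined, box = sam_refine_step(refined, sam_predictor, pil_image, device)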