#!/usr/bin/env python

import os

import gradio as gr
import numpy as np
import PIL.Image
import spaces
import torch
from transformers import VitMatteForImageMatting, VitMatteImageProcessor

DESCRIPTION = """\
# [ViTMatte](https://github.com/hustvl/ViTMatte)

This is a demo of [ViTMatte](https://github.com/hustvl/ViTMatte), an image matting method that uses Vision Transformers (ViT) to accurately extract the foreground from an image.
It predicts a soft alpha matte to help separate the subject from the background — even tricky areas like hair and fur!

You've got two ways to get started:

### 🖼️ Option 1: Upload Image & Trimap
- Upload your original image.
- Upload a **trimap**: a helper image that labels regions as **foreground (white)**, **background (black)**, and **unknown (gray)**.
- The trimap must be a **grayscale image** containing only three pixel values:
  - `0` for **background**
  - `128` for **unknown**
  - `255` for **foreground**
- The model will use this trimap to generate the alpha matte and extract the foreground.

### ✏️ Option 2: Draw Your Own Trimap
- Upload just your image.
- Go to the **"Draw Trimap"** tab to start drawing masks.
- Use the tools to mark:
  - **Foreground** (e.g. the subject),
  - **Unknown** (areas where the boundary is unclear).
- Once you're done, click the **"Generate Trimap"** button to generate the trimap from your drawing.

### ✨ Optional: Replace Background
Want to swap the background? Just check the **"Replace Background"** option and choose a new background image.
The app will blend your extracted subject with the new background seamlessly!
"""

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1500"))
MODEL_ID = os.getenv("MODEL_ID", "hustvl/vitmatte-small-distinctions-646")

processor = VitMatteImageProcessor.from_pretrained(MODEL_ID)
model = VitMatteForImageMatting.from_pretrained(MODEL_ID).to(device)
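
# Minimal usage sketch of the matting pipeline (illustrative only; the full flow with
# input validation and compositing lives in `run` below, and these file names are
# placeholders).
#
#   image = PIL.Image.open("image.png").convert("RGB")
#   trimap = PIL.Image.open("trimap.png").convert("L")
#   inputs = processor(images=image, trimaps=trimap, return_tensors="pt").to(device)
#   with torch.inference_mode():
#       alphas = model(**inputs).alphas  # (1, 1, H_pad, W_pad); the processor pads the input
#   alpha = alphas[0, 0, : image.height, : image.width].cpu().numpy()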


def resize_input_image(image: PIL.Image.Image | None) -> PIL.Image.Image | None:
    if image is None:
        return None
    if max(image.size) > MAX_IMAGE_SIZE:
        w, h = image.size
        scale = MAX_IMAGE_SIZE / max(w, h)
        new_w = int(w * scale)
        new_h = int(h * scale)
        gr.Info(
            f"The uploaded image exceeded the maximum resolution limit of {MAX_IMAGE_SIZE}px. It has been resized to {new_w}x{new_h}."
        )
        return image.resize((new_w, new_h))
    return image


def binarize_mask(mask: np.ndarray) -> np.ndarray:
    mask[mask > 0] = 1
    return mask


def update_trimap(foreground_mask_editor: dict, unknown_mask_editor: dict) -> np.ndarray:
    foreground = foreground_mask_editor["layers"][0]
    foreground = binarize_mask(foreground)

    unknown = unknown_mask_editor["layers"][0]
    unknown = binarize_mask(unknown)

    # Foreground takes precedence over unknown where the two drawn masks overlap.
    trimap = np.zeros_like(foreground)
    trimap[unknown > 0] = 128
    trimap[foreground > 0] = 255
    return trimap
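

# Illustrative alternative to drawing a trimap by hand (not wired into the UI): derive a
# trimap from a binary segmentation mask. The helper name, the default band width, and the
# scipy dependency are assumptions for illustration only.
def make_trimap_from_mask(mask: np.ndarray, band_width: int = 10) -> np.ndarray:
    """Build a 0/128/255 trimap by eroding/dilating a binary foreground mask and
    marking the band in between as unknown."""
    from scipy import ndimage  # imported lazily so the app itself does not require scipy

    mask = mask > 0
    sure_foreground = ndimage.binary_erosion(mask, iterations=band_width)
    possible_foreground = ndimage.binary_dilation(mask, iterations=band_width)

    trimap = np.zeros(mask.shape, dtype=np.uint8)
    trimap[possible_foreground] = 128  # unknown band around the object boundary
    trimap[sure_foreground] = 255  # confident foreground core
    return trimap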


def adjust_background_image(background_image: PIL.Image.Image, target_size: tuple[int, int]) -> PIL.Image.Image:
    target_w, target_h = target_size
    bg_w, bg_h = background_image.size

    # Scale the background so that it fully covers the target, then center-crop to the target size.
    scale = max(target_w / bg_w, target_h / bg_h)
    new_bg_w = int(bg_w * scale)
    new_bg_h = int(bg_h * scale)
    background_image = background_image.resize((new_bg_w, new_bg_h))

    left = (new_bg_w - target_w) // 2
    top = (new_bg_h - target_h) // 2
    right = left + target_w
    bottom = top + target_h
    return background_image.crop((left, top, right, bottom))


def replace_background(
    image: PIL.Image.Image, alpha: np.ndarray, background_image: PIL.Image.Image | None
) -> np.ndarray | None:
    if background_image is None:
        return None
    if image.mode != "RGB":
        raise gr.Error("Image must be RGB.")

    background_image = background_image.convert("RGB")
    background_image = adjust_background_image(background_image, image.size)

    # Alpha-composite the subject over the new background: out = alpha * image + (1 - alpha) * background.
    image = np.array(image).astype(float) / 255
    background_image = np.array(background_image).astype(float) / 255
    result = image * alpha[:, :, None] + background_image * (1 - alpha[:, :, None])
    return (result * 255).astype(np.uint8)


@spaces.GPU
@torch.inference_mode()
def run(
    image: PIL.Image.Image,
    trimap: PIL.Image.Image,
    apply_background_replacement: bool,
    background_image: PIL.Image.Image | None,
) -> tuple[np.ndarray, PIL.Image.Image, np.ndarray | None]:
    if image.size != trimap.size:
        raise gr.Error("Image and trimap must have the same size.")
    if max(image.size) > MAX_IMAGE_SIZE:
        error_message = f"Image size is too large. Max image size is {MAX_IMAGE_SIZE} pixels."
        raise gr.Error(error_message)
    if image.mode != "RGB":
        raise gr.Error("Image must be RGB.")
    if trimap.mode != "L":
        raise gr.Error("Trimap must be grayscale.")

    pixel_values = processor(images=image, trimaps=trimap, return_tensors="pt").to(device).pixel_values
    out = model(pixel_values=pixel_values)
    alpha = out.alphas[0, 0].to("cpu").numpy()

    # The processor pads the input, so crop the predicted alpha back to the original image size.
    w, h = image.size
    alpha = alpha[:h, :w]

    # Composite the extracted subject over white using the predicted alpha.
    foreground = np.array(image).astype(float) / 255 * alpha[:, :, None] + (1 - alpha[:, :, None])
    foreground = (foreground * 255).astype(np.uint8)
    foreground = PIL.Image.fromarray(foreground)

    res_bg_replacement = replace_background(image, alpha, background_image) if apply_background_replacement else None

    return alpha, foreground, res_bg_replacement


with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            with gr.Group():
                image = gr.Image(label="Input image", type="pil")
                with gr.Tabs():
                    with gr.Tab(label="Trimap"):
                        trimap = gr.Image(label="Trimap", type="pil", image_mode="L")
                    with gr.Tab(label="Draw trimap"):
                        foreground_mask = gr.ImageEditor(
                            label="Foreground",
                            type="numpy",
                            sources=("upload",),
                            transforms=(),
                            image_mode="L",
                            height=500,
                            brush=gr.Brush(default_color=("#00ff00", 0.6)),
                            layers=gr.LayerOptions(allow_additional_layers=False, layers=["Foreground mask"]),
                        )
                        unknown_mask = gr.ImageEditor(
                            label="Unknown",
                            type="numpy",
                            sources=("upload",),
                            transforms=(),
                            image_mode="L",
                            height=500,
                            brush=gr.Brush(default_color=("#00ff00", 0.6)),
                            layers=gr.LayerOptions(allow_additional_layers=False, layers=["Unknown mask"]),
                        )
                        generate_trimap_button = gr.Button("Generate trimap")
                apply_background_replacement = gr.Checkbox(label="Replace background", value=False)
                background_image = gr.Image(label="Background image", type="pil", visible=False)
            run_button = gr.Button("Run")
        with gr.Column():
            with gr.Group():
                out_alpha = gr.Image(label="Alpha")
                out_foreground = gr.Image(label="Foreground")
                out_background_replacement = gr.Image(label="Background replacement", visible=False)

    inputs = [
        image,
        trimap,
        apply_background_replacement,
        background_image,
    ]
    outputs = [
        out_alpha,
        out_foreground,
        out_background_replacement,
    ]

    gr.Examples(
        examples=[
            ["assets/retriever_rgb.png", "assets/retriever_trimap.png", False, None],
            ["assets/bulb_rgb.png", "assets/bulb_trimap.png", True, "assets/new_bg.jpg"],
        ],
        inputs=inputs,
        outputs=outputs,
        fn=run,
        cache_examples=False,
    )

    image.input(
        fn=resize_input_image,
        inputs=image,
        outputs=image,
        api_name=False,
    ).then(
        fn=lambda image: (image, image),
        inputs=image,
        outputs=[foreground_mask, unknown_mask],
        api_name=False,
    )
    generate_trimap_button.click(
        fn=update_trimap,
        inputs=[foreground_mask, unknown_mask],
        outputs=trimap,
        api_name=False,
    )
    apply_background_replacement.change(
        fn=lambda checked: (gr.Image(visible=checked), gr.Image(visible=checked)),
        inputs=apply_background_replacement,
        outputs=[background_image, out_background_replacement],
        api_name=False,
    )
    run_button.click(
        fn=run,
        inputs=inputs,
        outputs=outputs,
    )

if __name__ == "__main__":
    demo.launch()
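
# Offline usage sketch (illustrative; not part of the app): `run` can also be called
# directly with the bundled example assets instead of going through the UI.
#
#   img = PIL.Image.open("assets/retriever_rgb.png").convert("RGB")
#   tri = PIL.Image.open("assets/retriever_trimap.png").convert("L")
#   alpha, foreground, _ = run(img, tri, apply_background_replacement=False, background_image=None)
#   foreground.save("retriever_foreground.png")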