# BiomedParse Gradio demo (Hugging Face Space).
# NOTE(review): web-page scrape metadata (Space status, file size, commit hash,
# and a line-number gutter) was removed from the top of this file.
import os
import gradio as gr
import torch
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
from modeling.BaseModel import BaseModel
from modeling import build_model
from utilities.distributed import init_distributed
from utilities.arguments import load_opt_from_config_files
from utilities.constants import BIOMED_CLASSES
from inference_utils.inference import interactive_infer_image
def overlay_masks(image, masks, colors):
    """Blend binary masks onto *image*, one color per mask.

    Pixels where a mask is positive become a 40/60 mix of the original
    pixel value and the mask's RGB color; all other pixels are untouched.

    Args:
        image: PIL image (or array-like) convertible to a uint8 RGB array.
        masks: iterable of arrays the same height/width as *image*; values
            > 0 mark the region to tint.
        colors: iterable of RGB tuples (0-255), paired with *masks*.

    Returns:
        A PIL ``Image`` with the tinted overlay applied.
    """
    canvas = np.array(image.copy(), dtype=np.uint8)
    for mask, rgb in zip(masks, colors):
        hit = mask > 0
        blended = canvas[hit] * 0.4 + np.array(rgb) * 0.6
        canvas[hit] = blended.astype(np.uint8)
    return Image.fromarray(canvas)
def generate_colors(n):
    """Return *n* RGB tuples (0-255) sampled from matplotlib's ``tab10`` map.

    Colors repeat after 10 entries, since tab10 cycles through ten hues.
    """
    palette = plt.get_cmap("tab10")
    return [
        tuple(int(255 * channel) for channel in palette(idx)[:3])
        for idx in range(n)
    ]
def init_model():
    """Download the BiomedParse checkpoint and build the model on GPU.

    Fetches the weights from the Hugging Face Hub (using ``HF_TOKEN`` from
    the environment for gated access), constructs the model from the
    inference config, and pre-computes text embeddings for the evaluation
    class list so per-request inference can skip that step.

    Returns:
        The initialized model in eval mode on CUDA.
    """
    checkpoint = hf_hub_download(
        repo_id="microsoft/BiomedParse",
        filename="biomedparse_v1.pt",
        token=os.getenv("HF_TOKEN"),
    )
    opt = init_distributed(
        load_opt_from_config_files(["configs/biomedparse_inference.yaml"])
    )
    model = BaseModel(opt, build_model(opt)).from_pretrained(checkpoint).eval().cuda()
    # Cache the class-name text embeddings once, up front.
    with torch.no_grad():
        model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(
            BIOMED_CLASSES + ["background"], is_eval=True
        )
    return model
def predict(image, prompts):
    """Segment *image* for each comma-separated text prompt.

    Args:
        image: PIL image from the Gradio input; may be ``None`` when the
            user submits without uploading.
        prompts: comma-separated prompt string, e.g.
            ``"neoplastic cells, inflammatory cells"``.

    Returns:
        A PIL image with the predicted masks blended over the input, or
        ``None`` when there is no image or no usable prompt.
    """
    # Guard: Gradio passes None when no image was uploaded; accessing
    # image.mode would crash.
    if image is None or not prompts:
        return None
    # Split on commas and drop empty entries (e.g. from a trailing comma),
    # which would otherwise send an empty-string prompt to the model.
    prompts = [p.strip() for p in prompts.split(",") if p.strip()]
    if not prompts:
        return None
    # Model expects RGB input.
    if image.mode != "RGB":
        image = image.convert("RGB")
    # Get per-prompt soft masks from the model.
    pred_mask = interactive_infer_image(model, image, prompts)
    # Threshold soft masks at 0.5 and blend each onto the image.
    colors = generate_colors(len(prompts))
    return overlay_masks(
        image, [1 * (pred_mask[i] > 0.5) for i in range(len(prompts))], colors
    )
def run():
    """Build the Gradio demo UI, load the model, and serve on port 7860."""
    # The model is stored in a module-level global so predict() can reach it.
    global model
    model = init_model()
    demo = gr.Interface(
        fn=predict,
        inputs=[
            gr.Image(type="pil", label="Input Image"),
            gr.Textbox(
                label="Prompts",
                placeholder="Enter prompts separated by commas (e.g., neoplastic cells, inflammatory cells)",
            ),
        ],
        outputs=gr.Image(type="pil", label="Prediction"),
        title="BiomedParse Demo",
        description="Upload a biomedical image and enter prompts (separated by commas) to detect specific features.",
        examples=[
            [
                "examples/Part_1_516_pathology_breast.png",
                "neoplastic cells, inflammatory cells",
            ]
        ],
    )
    # Bind to all interfaces on 7860 — the standard port for HF Spaces.
    demo.launch(server_name="0.0.0.0", server_port=7860)
if __name__ == "__main__":
    run()