File size: 2,971 Bytes
6ba63c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import gradio as gr
import torch
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
from modeling.BaseModel import BaseModel
from modeling import build_model
from utilities.distributed import init_distributed
from utilities.arguments import load_opt_from_config_files
from utilities.constants import BIOMED_CLASSES
from inference_utils.inference import interactive_infer_image


def overlay_masks(image, masks, colors):
    """Blend colored mask regions onto a copy of *image*.

    Each mask is applied where its values are > 0, mixing 40% of the
    original pixel with 60% of the mask's RGB color.

    Args:
        image: PIL image (or array-like) to draw on; not modified.
        masks: iterable of 2-D arrays, one per color, nonzero = selected.
        colors: iterable of RGB tuples (0-255), paired with masks in order.

    Returns:
        A new PIL Image with the colored regions blended in.
    """
    canvas = np.array(image.copy(), dtype=np.uint8)
    for region, rgb in zip(masks, colors):
        selected = region > 0
        blended = canvas[selected] * 0.4 + np.array(rgb) * 0.6
        canvas[selected] = blended.astype(np.uint8)
    return Image.fromarray(canvas)


def generate_colors(n):
    """Return *n* RGB tuples (0-255 per channel) drawn from matplotlib's tab10 colormap."""
    palette = plt.get_cmap("tab10")
    return [
        tuple(int(255 * channel) for channel in palette(idx)[:3])
        for idx in range(n)
    ]


def init_model():
    """Download the BiomedParse checkpoint and build the ready-to-use model.

    Fetches the weights from the Hugging Face Hub (authenticating with the
    HF_TOKEN environment variable, since the repo may be gated), builds the
    model from the inference config, moves it to the GPU in eval mode, and
    pre-computes text embeddings for the known biomedical classes.

    Returns:
        The initialized model, on CUDA, in eval mode.
    """
    # Download model
    model_file = hf_hub_download(
        repo_id="microsoft/BiomedParse",
        filename="biomedparse_v1.pt",
        token=os.getenv("HF_TOKEN"),
    )

    # Initialize model
    conf_files = "configs/biomedparse_inference.yaml"
    opt = load_opt_from_config_files([conf_files])
    opt = init_distributed(opt)

    # NOTE(review): .cuda() assumes a GPU is available — this will fail on
    # CPU-only hosts; confirm the deployment target has CUDA.
    model = BaseModel(opt, build_model(opt)).from_pretrained(model_file).eval().cuda()
    # Cache text embeddings once up front (including a "background" class)
    # so per-request inference doesn't re-encode the class names.
    with torch.no_grad():
        model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(
            BIOMED_CLASSES + ["background"], is_eval=True
        )

    return model


def predict(image, prompts):
    """Segment *image* according to comma-separated text *prompts*.

    Args:
        image: PIL image to analyze; converted to RGB if needed. Gradio
            passes None when the user submits without uploading an image.
        prompts: comma-separated prompt string, e.g.
            "neoplastic cells, inflammatory cells".

    Returns:
        A PIL image with each predicted mask color-blended over the input,
        or None when there is no image or no usable prompt.
    """
    # Guard against a missing upload: Gradio sends image=None, and the
    # original code would crash on `image.mode` with an AttributeError.
    if image is None or not prompts:
        return None

    # Split on commas and drop empty entries (e.g. from a trailing comma
    # or ", ,") so the model never receives an empty-string prompt.
    prompts = [p.strip() for p in prompts.split(",") if p.strip()]
    if not prompts:
        return None

    # The model expects RGB input.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # One predicted mask per prompt; `model` is the global set by run().
    pred_mask = interactive_infer_image(model, image, prompts)

    # Binarize each mask at 0.5 and blend each with a distinct color.
    colors = generate_colors(len(prompts))
    pred_overlay = overlay_masks(
        image, [1 * (pred_mask[i] > 0.5) for i in range(len(prompts))], colors
    )

    return pred_overlay


def run():
    """Initialize the global model and serve the Gradio demo on port 7860."""
    # predict() reads this module-level name at request time.
    global model
    model = init_model()

    # Build the UI components up front, then wire them into the Interface.
    image_input = gr.Image(type="pil", label="Input Image")
    prompt_input = gr.Textbox(
        label="Prompts",
        placeholder="Enter prompts separated by commas (e.g., neoplastic cells, inflammatory cells)",
    )
    prediction_output = gr.Image(type="pil", label="Prediction")
    example_rows = [
        [
            "examples/Part_1_516_pathology_breast.png",
            "neoplastic cells, inflammatory cells",
        ]
    ]

    demo = gr.Interface(
        fn=predict,
        inputs=[image_input, prompt_input],
        outputs=prediction_output,
        title="BiomedParse Demo",
        description="Upload a biomedical image and enter prompts (separated by commas) to detect specific features.",
        examples=example_rows,
    )

    # Bind to all interfaces so the demo is reachable from outside the container.
    demo.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    # Script entry point: load the model and launch the Gradio app.
    run()