---
license: mit
datasets:
- WinterSchool/MedificsDataset
language:
- en
metrics:
- accuracy
base_model:
- microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224
tags:
- medical
- clip
- fine-tuned
- zero-shot
---
# OpenCLIP-BiomedCLIP-Finetuned

This repository contains a fine-tuned version of BiomedCLIP (specifically the PubMedBERT_256-vit_base_patch16_224 variant) using OpenCLIP. The model is trained to recognize and classify various medical images (e.g., chest X-rays, histopathology slides) in a zero-shot manner. It was further adapted on a subset of medical data (e.g., from the WinterSchool/MedificsDataset) to enhance performance on specific image classes.

## Model Details

- **Architecture:** Vision Transformer (ViT-B/16) image encoder + PubMedBERT-based text encoder, loaded through `open_clip`.
- **Training Objective:** CLIP-style contrastive learning to align medical text prompts with images (a minimal sketch follows this list).
- **Fine-Tuned On:** Selected medical image-text pairs, including X-rays, histopathology images, etc.
- **Intended Use:**
  - Zero-shot classification of medical images (e.g., “This is a photo of a chest X-ray”).
  - Exploratory research or educational demos showcasing multi-modal (image-text) alignment in the medical domain.
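
For intuition, here is a minimal sketch of the CLIP-style contrastive objective referenced above. It is illustrative only: the actual fine-tuning script, batch size, and hyperparameters are not part of this repository.

```python
import torch
import torch.nn.functional as F

def clip_contrastive_loss(image_features, text_features, logit_scale):
    """Symmetric contrastive (InfoNCE) loss for a batch of paired embeddings.

    Assumes image_features and text_features have shape (batch, dim) and that
    row i of each tensor comes from the same image-text pair.
    """
    # Normalize so the dot product is a cosine similarity
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)

    # Pairwise similarity matrix, scaled by the learned temperature
    logits_per_image = logit_scale * image_features @ text_features.t()
    logits_per_text = logits_per_image.t()

    # The matching text for image i sits at index i (and vice versa)
    targets = torch.arange(image_features.size(0), device=image_features.device)

    # Average the image-to-text and text-to-image cross-entropy terms
    return (F.cross_entropy(logits_per_image, targets)
            + F.cross_entropy(logits_per_text, targets)) / 2
```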

## Usage

Below is a minimal Python snippet using OpenCLIP (the `open_clip_torch` package). Adjust the labels and text prompts as needed:

```python
import torch
import open_clip
from PIL import Image
# 1) Load the fine-tuned model
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    "hf-hub:mgbam/OpenCLIP-BiomedCLIP-Finetuned",
    pretrained=None
)
tokenizer = open_clip.get_tokenizer("hf-hub:mgbam/OpenCLIP-BiomedCLIP-Finetuned")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
# 2) Example labels
labels = [
    "chest X-ray",
    "brain MRI",
    "bone X-ray",
    "squamous cell carcinoma histopathology",
    "adenocarcinoma histopathology",
    "immunohistochemistry histopathology",
]
# 3) Load and preprocess an image
image_path = "path/to/your_image.jpg"
image = Image.open(image_path).convert("RGB")
image_tensor = preprocess_val(image).unsqueeze(0).to(device)
# 4) Create text prompts & tokenize
text_prompts = [f"This is a photo of a {label}" for label in labels]
tokens = tokenizer(text_prompts).to(device)
# 5) Forward pass
with torch.no_grad():
    image_features = model.encode_image(image_tensor)
    text_features = model.encode_text(tokens)
    # Normalize features so the scaled dot product is a cosine similarity
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    logit_scale = model.logit_scale.exp()
    logits = (logit_scale * image_features @ text_features.t()).softmax(dim=-1)
# 6) Get predictions
probs = logits[0].cpu().tolist()
for label, prob in zip(labels, probs):
    print(f"{label}: {prob:.4f}")
```
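
When working with a longer label list, it is often more convenient to look only at the top matches. This short follow-up (which assumes the `logits` and `labels` variables from the snippet above) uses `torch.topk`:

```python
# Rank labels by probability and print the top matches
top_probs, top_idx = logits[0].topk(k=min(3, len(labels)))
for prob, idx in zip(top_probs.tolist(), top_idx.tolist()):
    print(f"{labels[idx]}: {prob:.4f}")
```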

## Example Gradio App

You can also deploy a simple Gradio demo:

```python
import gradio as gr
import torch
import open_clip
from PIL import Image
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    "hf-hub:mgbam/OpenCLIP-BiomedCLIP-Finetuned",
    pretrained=None
)
tokenizer = open_clip.get_tokenizer("hf-hub:mgbam/OpenCLIP-BiomedCLIP-Finetuned")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
labels = ["chest X-ray", "brain MRI", "histopathology"]  # adjust to your own label set
def classify_image(img):
    if img is None:
        return {}
    image_tensor = preprocess_val(img).unsqueeze(0).to(device)
    prompts = [f"This is a photo of a {label}" for label in labels]
    tokens = tokenizer(prompts).to(device)
    with torch.no_grad():
        image_feats = model.encode_image(image_tensor)
        text_feats = model.encode_text(tokens)
        # Normalize features so the scaled dot product is a cosine similarity
        image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)
        text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
        logit_scale = model.logit_scale.exp()
        logits = (logit_scale * image_feats @ text_feats.T).softmax(dim=-1)
    probs = logits.squeeze(0).cpu().numpy().tolist()
    return {label: float(prob) for label, prob in zip(labels, probs)}
demo = gr.Interface(fn=classify_image, inputs=gr.Image(type="pil"), outputs="label")
demo.launch()
```
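
To sanity-check the classifier without the UI, you can call the function directly on a PIL image before (or instead of) `demo.launch()`. The path below is a placeholder, and the snippet assumes the variables defined in the block above:

```python
# Quick check of classify_image outside Gradio (placeholder path)
img = Image.open("path/to/your_image.jpg").convert("RGB")
print(classify_image(img))
```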

## Performance

- **Accuracy:** Varies with your dataset. The model can classify medical images such as chest X-rays or histopathology slides, but performance depends heavily on how well the fine-tuning data covers your target classes (a minimal evaluation sketch follows this list).
- **Potential limitations:**
  - Ultrasound, CT, MRI, or other modalities may not be recognized if they were not included in the training data.
  - The model may incorrectly label images that fall outside its known categories.
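
As a starting point for measuring accuracy on your own labeled data, here is a minimal evaluation sketch. The directory layout (one sub-folder per class, containing only image files) and the `data_root` path are assumptions; the snippet reuses `model`, `preprocess_val`, `tokenizer`, and `device` from the Usage example above.

```python
from pathlib import Path

from PIL import Image
import torch

# Assumed layout: data_root/<class name>/<image files>; folder names double as labels
data_root = Path("path/to/eval_data")
class_names = sorted(d.name for d in data_root.iterdir() if d.is_dir())

# Encode one prompt per class once, up front
with torch.no_grad():
    text_feats = model.encode_text(
        tokenizer([f"This is a photo of a {c}" for c in class_names]).to(device)
    )
    text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)

correct = total = 0
for class_idx, class_name in enumerate(class_names):
    for img_path in sorted((data_root / class_name).glob("*")):
        image = preprocess_val(Image.open(img_path).convert("RGB")).unsqueeze(0).to(device)
        with torch.no_grad():
            img_feats = model.encode_image(image)
            img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True)
            pred = (img_feats @ text_feats.t()).argmax(dim=-1).item()
        correct += int(pred == class_idx)
        total += 1

print(f"Zero-shot accuracy: {correct / total:.3f} over {total} images")
```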

## Limitations & Caveats

- **Not a medical device:** This model is not FDA-approved or clinically validated. It is intended for research and educational purposes only.
- **Data bias:** If the training dataset lacked certain pathologies or modalities, the model may systematically misclassify them.
- **Security:** This model uses standard PyTorch and `open_clip`. Be mindful of potential vulnerabilities when loading models or code from untrusted sources.
- **Privacy:** If you use patient data, comply with applicable regulations (HIPAA, GDPR, etc.).

## Citation & Acknowledgements

- **Base model:** [BiomedCLIP](https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224) by Microsoft
- **OpenCLIP:** [open_clip](https://github.com/mlfoundations/open_clip)
- **Fine-tuning dataset:** [WinterSchool/MedificsDataset](https://huggingface.co/datasets/WinterSchool/MedificsDataset)

If you use this model in your research or demos, please cite the above works accordingly.

## License

MIT (see the `license` field in the metadata above).

**Note:** When deploying or sharing this model, always include a disclaimer that it is not a substitute for professional medical advice and that it may not generalize to all imaging modalities or patient populations.