from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel
from collections import defaultdict
from PIL import Image
import torch, numpy as np, os, math
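# Embed Kvasir-VQA images with CLIP and export vectors, per-image Q/A metadata,
# and a thumbnail sprite sheet for TensorBoard's Embedding Projector.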

# Short display labels for each question, used in the projector metadata.
rename_qsn = {
    "Are there any abnormalities in the image? Check all that are present.": "🧬 Abnorm",
    "Are there any anatomical landmarks in the image? Check all that are present.": "📍 Landmark",
    "Are there any instruments in the image? Check all that are present.": "🛠️ Instrum",
    "Have all polyps been removed?": "✅ Polyps_Removed",
    "Is this finding easy to detect?": "🔍 Easy_Detect",
    "Is there a green/black box artefact?": "🟩 Box_Artifact",
    "Is there text?": "🔤 Has_Text",
    "What type of polyp is present?": "🔬 Polyp_Type",
    "What type of procedure is the image taken from?": "🏥 Proc_Type",
    "What is the size of the polyp?": "📏 Polyp_Size",
    "How many findings are present?": "🧾 Find_Count",
    "How many polyps are in the image?": "🔢 Polyp_Count",
    "Where in the image is the instrument?": "📍 Instrum_Loc",
    "Where in the image is the abnormality?": "📍 Abnorm_Loc",
    "Where in the image is the anatomical landmark?": "📍 Landmark_Loc",
    "How many instrumnets are in the image?": "🔢 Instrum_Count",  # sic: key must match the dataset's question string
    "What color is the abnormality? If more than one separate with ;": "🎨 Abnorm_Color",
    "What color is the anatomical landmark? If more than one separate with ;": "🎨 Landmark_Color",
    "Does this image contain any finding?": "📸 Has_Finding",
    "none": "🚫 Nan",
}

# Load the "raw" split and aggregate every Q/A pair per image.
ds = load_dataset("SimulaMet-HOST/Kvasir-VQA")["raw"]
qas = defaultdict(dict)
for q, a, img_id in zip(ds["question"], ds["answer"], ds["img_id"]):
    qas[img_id][rename_qsn[q]] = a
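# Sanity check (illustrative values): each img_id now maps to {short label: answer},
# e.g. {"📸 Has_Finding": "polyps", "🔢 Polyp_Count": "1", ...}
# print(next(iter(qas.items())))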

log_dir = "logs/projector"
os.makedirs(log_dir, exist_ok=True)

def create_sprite_image(dataset, save_path='sprite.png', image_column='image', size=(100, 100), max_images=6500):
    """Tile up to max_images thumbnails into one square sprite sheet for the projector."""
    imgs = []
    for i, x in enumerate(dataset):
        if i >= max_images:
            break
        img = x[image_column].resize(size).convert('RGB')
        imgs.append(np.asarray(img) / 255.0)

    imgs = np.array(imgs)
    n = math.ceil(math.sqrt(len(imgs)))  # the sprite is an n x n grid of tiles
    pad = ((0, n**2 - len(imgs)), (0, 0), (0, 0), (0, 0))
    imgs = np.pad(imgs, pad, constant_values=1)  # pad with white tiles to complete the grid
    # (n, n, H, W, 3) -> (n*H, n*W, 3): row-major layout, matching vector order.
    imgs = imgs.reshape((n, n, size[1], size[0], 3)).transpose(0, 2, 1, 3, 4).reshape(n * size[1], n * size[0], 3)
    Image.fromarray((imgs * 255).astype(np.uint8)).save(save_path)

# Keep one row per unique img_id: the dict comprehension keeps the last index
# seen for each id, and select() pulls exactly those rows.
dsx = ds.select({v: k for k, v in enumerate(ds['img_id'])}.values())
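# The sprite helper above is never called in the original flow; assuming the
# projector should show image thumbnails, generate the sprite into log_dir.
create_sprite_image(dsx, save_path=os.path.join(log_dir, "sprite.png"), max_images=len(dsx))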

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def get_emb(batch):
    """Return L2-normalized CLIP image embeddings for a batch of images."""
    inputs = processor(images=batch["image"], return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        feats = model.get_image_features(**inputs)
    return {"emb": (feats / feats.norm(p=2, dim=-1, keepdim=True)).cpu().numpy()}
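# Unit-norm vectors make dot product equal cosine similarity, which matches the
# projector's cosine-distance view; CLIP ViT-B/32 image features are 512-dim.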

dsx = dsx.map(get_emb, batched=True, batch_size=512)

# Persist embeddings with aligned metadata for reuse, and write the projector's vectors.tsv.
np.savez_compressed("all_embeddings.npz",
                    embeddings=np.array(dsx["emb"]),
                    metadata=np.array(list(zip(dsx["img_id"], dsx["source"], dsx["question"], dsx["answer"]))))
np.savetxt(os.path.join(log_dir, "vectors.tsv"), np.array(dsx["emb"]), delimiter="\t")

import tensorflow as tf

# Save the embeddings as a TF2 checkpoint that the projector plugin can read.
embeddings_np = np.array(dsx["emb"])
embedding_tensor = tf.Variable(embeddings_np, name="image_embeddings")
checkpoint = tf.train.Checkpoint(embedding=embedding_tensor)
checkpoint.save(os.path.join(log_dir, "embedding.ckpt"))

# Write projector metadata; row order must match the embedding order above.
metadata_path = os.path.join(log_dir, "metadata.tsv")
with open(metadata_path, "w", encoding="utf-8") as f:
    f.write("source\tQ/A\timg_hash\n")
    for img_id, source in zip(dsx["img_id"], dsx["source"]):
        img_hash = str(img_id).replace("\t", " ").replace("\n", " ")
        qa = " | ".join(f"{k}: {v}" for k, v in qas.get(img_id, {}).items())
        qa = qa.replace("\t", " ").replace("\n", " ")
        source = str(source).replace("\t", " ").replace("\n", " ")
        f.write(f"{source}\t{qa}\t{img_hash}\n")
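# A metadata row then looks like (illustrative values, tabs shown as <TAB>):
#   <source><TAB>🔤 Has_Text: no | 🔢 Polyp_Count: 1 | ...<TAB><img_id>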

from tensorboard.plugins import projector

config = projector.ProjectorConfig()
embedding = config.embeddings.add()
# tf.train.Checkpoint stores the variable under this key rather than under
# embedding_tensor.name, so point the projector at the checkpoint slot.
embedding.tensor_name = "embedding/.ATTRIBUTES/VARIABLE_VALUE"
embedding.metadata_path = "metadata.tsv"
embedding.sprite.image_path = "sprite.png"  # sprite generated above; tile size matches below
embedding.sprite.single_image_dim.extend([100, 100])
projector.visualize_embeddings(log_dir, config)
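# Later reuse (illustrative): reload the saved arrays without recomputing.
# data = np.load("all_embeddings.npz")
# embs, meta = data["embeddings"], data["metadata"]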

print("✅ All done! Launch TensorBoard using:")
print(f"tensorboard --logdir={log_dir}")