# Source: set-detector / app.py
# Uploaded by Oamitai in commit 02667ce ("Upload 10 files", verified).
import base64
import io
import logging

import torch
import cv2
import numpy as np
from PIL import Image
from ultralytics import YOLO
from transformers import AutoImageProcessor
from datasets import Image as HFImage
def overlay_boxes(image, boxes, scores, labels, class_names, conf_threshold=0.5):
    """Draw bounding boxes on the image with labels and scores.

    Args:
        image: A ``PIL.Image.Image`` or an array-like accepted by ``np.array``.
            PIL images are converted to RGB first, so greyscale, palette and
            RGBA inputs do not break the RGB->BGR conversion below (or a later
            JPEG encode by the caller). Arrays are assumed to be HxWx3.
        boxes: Iterable of ``(x1, y1, x2, y2)`` pixel coordinates.
        scores: Confidence score per box, parallel to ``boxes``.
        labels: Class index per box, parallel to ``boxes``.
        class_names: Sequence mapping a class index to its display name;
            indices outside it are shown as the bare number.
        conf_threshold: Boxes whose score is not ``>= conf_threshold`` are
            not drawn.

    Returns:
        A new ``PIL.Image.Image``; the input ``image`` is not modified.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.9
    thickness = 2
    colors = {0: (0, 255, 0)}  # Green for card class
    default_color = (255, 0, 0)  # drawn in BGR space, so this renders blue

    if isinstance(image, Image.Image):
        rgb = np.array(image.convert("RGB"))
    else:
        rgb = np.array(image)  # np.array copies, leaving the caller's data intact

    # Convert to OpenCV's BGR order once instead of round-tripping the whole
    # image twice per box; channel swaps are exact, so the pixels are identical.
    canvas = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

    for box, score, label in zip(boxes, scores, labels):
        # Written as "not >=" so NaN scores stay undrawn, as before.
        if not score >= conf_threshold:
            continue
        x1, y1, x2, y2 = map(int, box)
        class_idx = int(label)
        if 0 <= class_idx < len(class_names):
            name = class_names[class_idx]
        else:
            name = str(class_idx)
        label_text = f"{name}: {score:.2f}"
        color = colors.get(class_idx, default_color)

        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, thickness)
        # Keep the caption on-canvas for boxes that touch the top edge.
        (_, text_height), _ = cv2.getTextSize(label_text, font, font_scale, thickness)
        text_y = max(y1 - 10, text_height)
        cv2.putText(canvas, label_text, (x1, text_y), font, font_scale, color, thickness)

    return Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
class YOLOv9CardDetector:
    """Callable card detector wrapping an Ultralytics YOLO checkpoint.

    Calling an instance returns a dict of detections and, when the input
    resolved to an image that can be drawn on, a JPEG of the annotated image.
    """

    def __init__(self):
        # Weights are read from "best.pt" relative to the working directory.
        self.model = YOLO("best.pt")
        self.config = {
            "class_names": ["card"],
            "conf_threshold": 0.5,
        }
        # NOTE(review): not used anywhere in this class; kept so existing
        # attribute access keeps working. Loads a processor config from ".".
        self.image_processor = AutoImageProcessor.from_pretrained(".")

    @staticmethod
    def _load_image(inputs):
        """Resolve the supported input formats to something the model accepts.

        * ``PIL.Image.Image`` -> converted to RGB.
        * ``dict`` with an ``"image"`` key holding a PIL image, raw image bytes,
          or a base64 string (an optional ``data:...;base64,`` prefix is
          stripped) -> decoded and converted to RGB.
        * anything else -> returned unchanged for the YOLO model to handle
          (e.g. a file path).
        """
        if isinstance(inputs, Image.Image):
            return inputs.convert("RGB")
        if not (isinstance(inputs, dict) and "image" in inputs):
            return inputs

        # Handle API input format
        image = inputs["image"]
        if isinstance(image, str):
            # Tolerate data URIs such as "data:image/png;base64,<payload>".
            if image.startswith("data:") and "," in image:
                image = image.split(",", 1)[1]
            image = base64.b64decode(image)
        if isinstance(image, (bytes, bytearray)):
            image = Image.open(io.BytesIO(image))
        if isinstance(image, Image.Image):
            # Alpha/palette/greyscale images would otherwise break the
            # RGB drawing code and JPEG encoding in _annotate.
            image = image.convert("RGB")
        return image

    def _annotate(self, image, boxes, scores, labels):
        """Return JPEG bytes of ``image`` with detections drawn, or ``None``.

        Best-effort: skipped for inputs that were not resolved to a PIL image
        or array (e.g. a file path passed straight to YOLO), and drawing or
        encoding failures are logged instead of failing the whole request.
        """
        if not isinstance(image, (Image.Image, np.ndarray)):
            return None
        try:
            annotated_image = overlay_boxes(
                image,
                boxes,
                scores,
                labels,
                self.config["class_names"],
                self.config["conf_threshold"],
            )
            buffered = io.BytesIO()
            annotated_image.save(buffered, format="JPEG")
        except (cv2.error, OSError, ValueError):
            logging.getLogger(__name__).warning(
                "Skipping annotated image", exc_info=True
            )
            return None
        # NOTE(review): raw bytes are not JSON-serialisable; if this dict is
        # sent over an HTTP API it may need base64 encoding — confirm.
        return buffered.getvalue()

    def __call__(self, inputs):
        """Process input for the Hugging Face inference API.

        Returns:
            dict with "boxes" (xyxy coordinate lists), "scores", "labels"
            (class indices as produced by the model), "class_names", and,
            when annotation succeeded, "annotated_image" (JPEG bytes).
        """
        image = self._load_image(inputs)

        # Get predictions from YOLOv9 model
        with torch.no_grad():
            results = self.model(image)

        # A single input image yields a single result.
        result = results[0]
        boxes = result.boxes.xyxy.cpu().numpy()
        scores = result.boxes.conf.cpu().numpy()
        labels = result.boxes.cls.cpu().numpy()

        output = {
            "boxes": boxes.tolist(),
            "scores": scores.tolist(),
            "labels": labels.tolist(),
            "class_names": self.config["class_names"],
        }

        annotated = self._annotate(image, boxes, scores, labels)
        if annotated is not None:
            output["annotated_image"] = annotated
        return output
# One shared detector is built at import time, so every run_inference() call
# reuses the already-loaded model instead of reloading weights per request.
detector: YOLOv9CardDetector = YOLOv9CardDetector()
def run_inference(inputs):
    """Model entry point: hand ``inputs`` to the shared module-level detector.

    Returns whatever ``detector(inputs)`` produces, unchanged.
    """
    prediction = detector(inputs)
    return prediction