import base64
import io
import os
import pickle
import time
from io import BytesIO
from typing import List

import gdown
import numpy as np
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.routing import APIRouter
from fastapi.staticfiles import StaticFiles
from PIL import Image

# One RGB color per valid-object slot. Slots beyond the end of this list wrap
# around (see return_state_data), so colors repeat after 12 objects.
segmentationColors = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 255),
    (255, 165, 0),
    (128, 0, 128),
    (255, 192, 203),
    (50, 205, 50),
    (0, 128, 128),
    (139, 69, 19),
]

# ---------------------------------------------------------------------------
# Dataset loading (runs once, at import time)
# ---------------------------------------------------------------------------
data_path = os.getenv('DATA_PATH')
if not data_path:
    raise RuntimeError("DATA_PATH environment variable must be set.")

if not os.path.exists(data_path):
    url = os.getenv('DATA_URL')
    if not url:
        raise RuntimeError(
            f"Data file {data_path} does not exist and DATA_URL is not set."
        )
    # Assumes a Google Drive share link of the form
    # https://drive.google.com/file/d/<FILE_ID>/view..., i.e. the file ID is
    # the second-to-last path segment.
    file_id = url.split('/')[-2]
    direct_link = f"https://drive.google.com/uc?id={file_id}"
    gdown.download(direct_link, data_path, quiet=False)

# SECURITY: pickle.load can execute arbitrary code embedded in the file, and
# the file may have just been downloaded from DATA_URL. Only point DATA_PATH /
# DATA_URL at sources you fully trust.
try:
    with open(data_path, 'rb') as f:
        data = pickle.load(f)
except Exception as e:
    raise RuntimeError(f"Failed to load data from {data_path}: {e}") from e

app = FastAPI()

# NOTE(review): wildcard origins combined with credentials is very permissive;
# restrict allow_origins to the real frontend domain in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins or specify your frontend's domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Create a router for the API endpoints
router = APIRouter()


def overlay_mask(base_image, mask_image, color_idx):
    """Blend one segmentation color into ``base_image`` wherever the mask is set.

    Args:
        base_image: PIL image (or array). Non-RGB PIL images (e.g. "L", "P",
            "RGBA") are converted to RGB first so the 3-channel color can be
            broadcast over the masked pixels.
        mask_image: PIL image or array with the same height/width as
            ``base_image``; every truthy pixel is treated as inside the mask.
        color_idx: Index into ``segmentationColors``.

    Returns:
        A new PIL image in which masked pixels are 40% original + 60% color.

    Raises:
        ValueError: If the image and mask sizes differ, or ``color_idx`` is
            outside ``segmentationColors``.
    """
    if isinstance(base_image, Image.Image) and base_image.mode != "RGB":
        base_image = base_image.convert("RGB")
    overlay = np.array(base_image, dtype=np.uint8)
    mask = np.array(mask_image).astype(bool)
    if overlay.shape[:2] != mask.shape:
        raise ValueError("Base image and mask must have the same dimensions.")
    if not (0 <= color_idx < len(segmentationColors)):
        raise ValueError(f"Color index {color_idx} is out of bounds.")
    color = np.array(segmentationColors[color_idx], dtype=np.uint8)
    # The blend is computed in float and truncated back to uint8.
    overlay[mask] = (overlay[mask] * 0.4 + color * 0.6).astype(np.uint8)
    return Image.fromarray(overlay)


def convert_to_pil(image):
    """Return ``image`` as a PIL image, converting from a NumPy array if needed."""
    if isinstance(image, np.ndarray):
        return Image.fromarray(image)
    return image


def _encode_png_base64(image):
    """Encode a PIL image as PNG and return it as a base64 ASCII string."""
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def _parse_object_list(object_list):
    """Parse the ``object_list`` query value into a list of ints.

    ``'None'`` or an empty/blank string means "no objects selected"; otherwise
    the value must be comma-separated integers such as ``"0,2,5"``.

    Raises:
        ValueError: If any comma-separated part is not an integer.
    """
    if object_list == 'None' or not object_list.strip():
        return []
    return [int(x) for x in object_list.split(",")]


# NOTE: the coroutines below do CPU-bound image work without awaiting anything,
# so they block the event loop while they run.
async def return_thumbnails():
    """Return a thumbnail (fits within 256x256) of every image in ``data``."""
    thumbnails = []
    for item in data:
        # Image.thumbnail resizes in place, so work on a copy to leave the
        # dataset image untouched.
        thumb_img = convert_to_pil(item['image']).copy()
        thumb_img.thumbnail((256, 256))
        thumbnails.append(thumb_img)
    return thumbnails


def rgb_to_hex(rgb):
    """Format an (r, g, b) tuple of 0-255 ints as a lowercase ``#rrggbb`` string."""
    return "#{:02x}{:02x}{:02x}".format(rgb[0], rgb[1], rgb[2])


async def return_state_data(state):
    """Build the overlay response for one image / detail level / selection.

    Args:
        state: dict with keys ``image_index`` (index into ``data``),
            ``detail_level`` (key into the image's ``mask_data``) and
            ``object_list`` (indices, counted among *valid* objects only, of
            the masks to draw).

    Returns:
        dict with ``mask_overlayed_image`` (base64 PNG), 
        ``valid_object_color_tuples`` (list of ``(object_type, "#rrggbb")``)
        and ``invalid_objects`` (list of object types whose mask is invalid).
    """
    image_data = data[state['image_index']]
    overlayed = convert_to_pil(image_data['image'])
    selected = set(state['object_list'])
    valid_object_color_tuples = []
    invalid_objects = []

    mask_data = image_data['mask_data'].get(state['detail_level'], {})
    for object_type, mask_info in mask_data.items():
        if not mask_info['valid']:
            invalid_objects.append(object_type)
            continue
        # Valid objects are numbered in iteration order; this number is what
        # clients send back in object_list.
        idx = len(valid_object_color_tuples)
        # Wrap so more objects than colors no longer raises IndexError.
        color_idx = idx % len(segmentationColors)
        if idx in selected:
            overlayed = overlay_mask(overlayed, mask_info['mask'], color_idx)
        valid_object_color_tuples.append(
            (object_type, rgb_to_hex(segmentationColors[color_idx]))
        )

    return {
        'mask_overlayed_image': _encode_png_base64(overlayed),
        'valid_object_color_tuples': valid_object_color_tuples,
        'invalid_objects': invalid_objects,
    }


@router.get("/return_thumbnails")
async def return_thumbnails_endpoint():
    """Return ``{"thumbnails": [base64 PNG, ...]}`` for every dataset image."""
    thumbnails = await return_thumbnails()
    encoded_images = [_encode_png_base64(thumbnail) for thumbnail in thumbnails]
    return JSONResponse(content={"thumbnails": encoded_images})


@router.get("/return_state_data")
async def return_state_data_endpoint(
    image_index: int = Query(...),
    detail_level: int = Query(...),
    object_list: str = Query(...),
):
    """Return the overlay data for the requested image and selection.

    ``object_list`` is ``'None'`` (or empty) for no selection, otherwise a
    comma-separated list of valid-object indices.

    Raises:
        HTTPException: 404 if ``image_index`` is not in ``[0, len(data))``;
            422 if ``object_list`` is not parseable.
    """
    if not 0 <= image_index < len(data):
        raise HTTPException(
            status_code=404,
            detail=f"image_index {image_index} is out of range.",
        )
    try:
        parsed_objects = _parse_object_list(object_list)
    except ValueError:
        raise HTTPException(
            status_code=422,
            detail="object_list must be 'None' or comma-separated integers.",
        ) from None

    state = {
        "image_index": image_index,
        "detail_level": detail_level,
        "object_list": parsed_objects,
    }
    return await return_state_data(state)


# Include the router with a prefix, making endpoints accessible under /api
app.include_router(router, prefix="/api")

# Serve the React frontend if available. This mount must come after the API
# router: a mount at "/" matches every path not already handled.
frontend_path = "/app/frontend/build"
if os.path.exists(frontend_path):
    app.mount("/", StaticFiles(directory=frontend_path, html=True), name="frontend")
else:
    print(f"Warning: Frontend build directory '{frontend_path}' does not exist.")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)