rrnoa's picture
lcm
c794c9c
raw
history blame
1.62 kB
import diffusers
import torch
from fastapi import FastAPI, UploadFile, HTTPException
from PIL import Image
app = FastAPI()
# Inicializa el pipeline al arrancar el servidor
@app.on_event("startup")
async def startup_event():
global pipe
print("[DEBUG] Cargando modelo Marigold...")
pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
"prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16
).to("cuda")
print("[DEBUG] Modelo Marigold cargado exitosamente.")
@app.post("/predict-depth/")
async def predict_depth(file: UploadFile):
try:
# Verifica si el archivo es una imagen v谩lida
if not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail="El archivo subido no es una imagen.")
# Carga la imagen desde el archivo subido
image = Image.open(file.file).convert("RGB")
# Realiza la predicci贸n de profundidad
print("[DEBUG] Realizando predicci贸n de profundidad...")
depth = pipe(image)
# Visualiza la profundidad
vis = pipe.image_processor.visualize_depth(depth.prediction)
output_path = "predicted_depth.png"
vis[0].save(output_path)
return {"message": "Predicci贸n completada", "output_file": output_path}
except Exception as e:
print(f"[ERROR] {str(e)}")
raise HTTPException(status_code=500, detail="Error procesando la imagen.")
@app.get("/")
async def root():
return {"message": "API de generaci贸n de mapas de profundidad con Marigold"}