import logging
import os
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# Configure logging to stdout and an append-only log file
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler("llm_service.log", mode="a"),
    ],
)

logger = logging.getLogger(__name__)
# Import the model
from .model import get_llm_instance

# Initialize the model once at import time so it is shared across requests
llm = get_llm_instance()
# Create the FastAPI app
app = FastAPI(
    title="LLM Service API",
    description="API for interacting with the local LLM",
    version="1.0.0",
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify actual origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
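# Example of locking CORS down for a deployment (hypothetical origin; replace
# with the real frontend URL(s) that should be allowed to call this service):
#   allow_origins=["https://my-frontend.example.com"]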
class GenerateRequest(BaseModel):
    prompt: str
    system_prompt: Optional[str] = None
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9


class ExpandPromptRequest(BaseModel):
    prompt: str
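# Illustrative request payloads (route paths as registered below; any field
# omitted from the JSON body falls back to the defaults declared above):
#   POST /generate        {"prompt": "A castle at dusk", "max_tokens": 256}
#   POST /expand-prompt   {"prompt": "A castle at dusk"}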
@app.get("/")
def read_root():
    """Root endpoint confirming the service is up"""
    return {"status": "ok", "message": "LLM Service is running"}


@app.get("/health")
def health_check():
    """Health check endpoint"""
    return {"status": "healthy", "model": "TinyLlama-1.1B-Chat-v1.0"}
@app.post("/generate")
def generate_text(request: GenerateRequest):
    """Generate text from a prompt"""
    try:
        start_time = time.time()
        logger.info(f"Generating text for prompt: {request.prompt[:50]}...")

        result = llm.generate(
            prompt=request.prompt,
            system_prompt=request.system_prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
        )

        elapsed = time.time() - start_time
        logger.info(f"Generation completed in {elapsed:.2f} seconds")

        return {
            "result": result,
            "elapsed_seconds": elapsed,
        }
    except Exception as e:
        logger.error(f"Error during text generation: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/expand-prompt")
def expand_prompt(request: ExpandPromptRequest):
    """Expand a creative prompt with more detail"""
    try:
        start_time = time.time()
        logger.info(f"Expanding prompt: {request.prompt[:50]}...")

        expanded_prompt = llm.expand_creative_prompt(request.prompt)

        elapsed = time.time() - start_time
        logger.info(f"Prompt expansion completed in {elapsed:.2f} seconds")

        return {
            "original_prompt": request.prompt,
            "expanded_prompt": expanded_prompt,
            "elapsed_seconds": elapsed,
        }
    except Exception as e:
        logger.error(f"Error during prompt expansion: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
# Run the service when executed directly
if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 8000))
    host = os.environ.get("HOST", "0.0.0.0")
    logger.info(f"Starting LLM service on {host}:{port}")
    uvicorn.run(app, host=host, port=port)
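# Example calls once the service is running (adjust host/port to your environment):
#   curl http://localhost:8000/health
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Write a haiku about the sea"}'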