import os
from typing import List, Tuple, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles

# Suppress warnings
import warnings
warnings.filterwarnings("ignore")

# Ensure models directory exists
MODEL_DIR = "./models"
os.makedirs(MODEL_DIR, exist_ok=True)

# Model info for download
MODELS_INFO = [
    {
        "repo_id": "bartowski/Dolphin3.0-Llama3.2-1B-GGUF",
        "filename": "Dolphin3.0-Llama3.2-1B-Q4_K_M.gguf"
    },
    {
        "repo_id": "bartowski/Dolphin3.0-Qwen2.5-0.5B-GGUF",
        "filename": "Dolphin3.0-Qwen2.5-0.5B-Q6_K.gguf"
    },
    {
        "repo_id": "bartowski/Qwen2.5-Coder-14B-Instruct-GGUF",
        "filename": "Qwen2.5-Coder-14B-Instruct-Q6_K.gguf"
    }
]

# Download all models if not present
for model_info in MODELS_INFO:
    model_path = os.path.join(MODEL_DIR, model_info["filename"])
    if not os.path.exists(model_path):
        print(f"Downloading {model_info['filename']} from {model_info['repo_id']}...")
        try:
            hf_hub_download(
                repo_id=model_info["repo_id"],
                filename=model_info["filename"],
                local_dir=MODEL_DIR
            )
            print(f"Downloaded {model_info['filename']}")
        except Exception as e:
            print(f"Error downloading {model_info['filename']}: {e}")

# Available model keys (used in API)
AVAILABLE_MODELS = {
    "qwen": "Dolphin3.0-Qwen2.5-0.5B-Q6_K.gguf",
    "llama": "Dolphin3.0-Llama3.2-1B-Q4_K_M.gguf",
    "coder": "Qwen2.5-Coder-14B-Instruct-Q6_K.gguf"
}

# Global LLM instance, reused across requests
llm = None
llm_model = None


def load_model(model_key: str):
    """Load the requested GGUF model, reusing the cached instance if it is already loaded."""
    global llm, llm_model

    model_name = AVAILABLE_MODELS.get(model_key)
    if not model_name:
        raise ValueError(f"Invalid model key: {model_key}")

    model_path = os.path.join(MODEL_DIR, model_name)
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at {model_path}")

    if llm is None or llm_model != model_name:
        llm = Llama(
            model_path=model_path,
            flash_attn=False,
            n_gpu_layers=0,
            n_batch=8,
            n_ctx=2048,
            n_threads=8,
            n_threads_batch=8,
        )
        llm_model = model_name
    return llm


class ChatRequest(BaseModel):
    message: str                                    # Required
    history: Optional[List[Tuple[str, str]]] = []   # Default: empty list
    model: Optional[str] = "qwen"                   # Default model key
    system_prompt: Optional[str] = "You are Dolphin, a helpful AI assistant."
    max_tokens: Optional[int] = 1024
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    top_k: Optional[int] = 40
    repeat_penalty: Optional[float] = 1.1


class ChatResponse(BaseModel):
    response: str


class ModelInfoResponse(BaseModel):
    models: List[str]


app = FastAPI(
    title="Dolphin 3.0 LLM API",
    description="REST API for Dolphin 3.0 models using Llama.cpp backend.",
    version="1.0",
    docs_url="/docs",   # Only Swagger docs
    redoc_url=None      # Disable ReDoc
)


@app.get("/models", response_model=ModelInfoResponse)
def get_available_models():
    """Returns the list of supported models."""
    return {"models": list(AVAILABLE_MODELS.keys())}


@app.post("/chat", response_model=ChatResponse)
def chat(request: ChatRequest):
    try:
        # Load model
        load_model(request.model)
        provider = LlamaCppPythonProvider(llm)
        agent = LlamaCppAgent(
            provider,
            system_prompt=request.system_prompt,
            predefined_messages_formatter_type=MessagesFormatterType.CHATML,
        )

        # Apply sampling settings from the request
        settings = provider.get_provider_default_settings()
        settings.temperature = request.temperature
        settings.top_k = request.top_k
        settings.top_p = request.top_p
        settings.max_tokens = request.max_tokens
        settings.repeat_penalty = request.repeat_penalty

        messages = BasicChatHistory()
        # Add history
        for user_msg, assistant_msg in request.history:
            messages.add_message({"role": Roles.user, "content": user_msg})
            messages.add_message({"role": Roles.assistant, "content": assistant_msg})

        # Get response
        response = agent.get_chat_response(
            request.message,
            llm_sampling_settings=settings,
            chat_history=messages,
            print_output=False,
        )
        return {"response": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
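
# Example requests against the routes defined above (a sketch, assuming the
# server is running locally on the default port 8000; adjust host/port to match
# your deployment). The JSON fields correspond to the ChatRequest model.
#
#   curl http://localhost:8000/models
#
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello!", "model": "qwen", "max_tokens": 256}'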