import requests
import logging
import os
import sys
from typing import Optional, Dict, Any
from pathlib import Path
logger = logging.getLogger(__name__)
# For Hugging Face Spaces, we need better import handling
MODEL_SUPPORT = False
try:
    # First try direct import
    from app.llm.model import get_llm_instance

    MODEL_SUPPORT = True
except ImportError:
    try:
        # Then try relative import
        from ..llm.model import get_llm_instance

        MODEL_SUPPORT = True
    except ImportError:
        try:
            # Then try path-based import as a fallback
            current_dir = Path(__file__).parent
            sys.path.append(str(current_dir.parent))
            from llm.model import get_llm_instance

            MODEL_SUPPORT = True
        except ImportError:
            logger.warning(
                "Failed to import local LLM model - direct model usage disabled"
            )
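
# Illustrative sketch only (an assumption, not the actual definition in
# app/llm/model.py): the object returned by get_llm_instance() is used below
# as if it satisfied the interface sketched here, inferred from the calls
# made in this module.
from typing import Protocol


class LocalLLMProtocol(Protocol):
    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
    ) -> str:
        """Generate text for a prompt (see LLMClient.generate below)."""
        ...

    def expand_creative_prompt(self, prompt: str) -> str:
        """Expand a creative prompt with rich details."""
        ...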

class LLMClient:
    """
    Client for interacting with the LLM service or direct model access.

    Provides methods to generate text and expand creative prompts.
    Uses the TinyLlama model for efficient prompt expansion.
    """

    def __init__(self, base_url: Optional[str] = None):
        """
        Initialize the LLM client.

        Args:
            base_url: Base URL of the LLM service (optional)
        """
        self.base_url = base_url
        self.session = requests.Session()
        self.local_model = None
        self.spaces_mode = os.environ.get("HF_SPACES", "0") == "1"

        # For Hugging Face Spaces, we use the model directly instead of a service
        if self.spaces_mode or not self.base_url:
            if MODEL_SUPPORT:
                try:
                    logger.info(
                        "Running in Spaces mode, initializing TinyLlama model..."
                    )
                    self.local_model = get_llm_instance()
                    logger.info("Local model initialized successfully")
                except Exception as e:
                    logger.error(f"Failed to initialize local model: {str(e)}")
            else:
                logger.warning(
                    "No LLM service URL provided and direct model access disabled"
                )
    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
    ) -> str:
        """
        Generate text based on a prompt.

        Args:
            prompt: The user prompt to generate from
            system_prompt: Optional system prompt to guide the generation
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (higher = more creative)
            top_p: Top-p sampling parameter

        Returns:
            The generated text, or the original prompt if generation fails
        """
        # Use local model if available (Spaces mode)
        if self.local_model:
            try:
                return self.local_model.generate(
                    prompt=prompt,
                    system_prompt=system_prompt,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                )
            except Exception as e:
                logger.error(f"Local model generation failed: {str(e)}")
                return prompt

        # Fall back to the service if a base_url is provided
        if not self.base_url:
            logger.warning("No LLM service URL and no local model available")
            return prompt

        payload = {
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }
        if system_prompt:
            payload["system_prompt"] = system_prompt

        try:
            response = self.session.post(f"{self.base_url}/generate", json=payload)
            response.raise_for_status()
            # Response format matches service.py: {"result": ...}
            return response.json()["result"]
        except requests.RequestException as e:
            logger.error(f"Failed to generate text: {str(e)}")
            return prompt
    def expand_prompt(self, prompt: str) -> str:
        """
        Expand a creative prompt with rich details using TinyLlama.

        Args:
            prompt: The user's original prompt

        Returns:
            An expanded, detailed creative prompt
        """
        # Use local model if available (Spaces mode)
        if self.local_model:
            try:
                return self.local_model.expand_creative_prompt(prompt)
            except Exception as e:
                logger.error(f"Local model prompt expansion failed: {str(e)}")
                return prompt

        # Fall back to the service if a base_url is provided
        if not self.base_url:
            logger.warning("No LLM service URL and no local model available")
            return prompt

        try:
            # Endpoint and response format match service.py
            response = self.session.post(
                f"{self.base_url}/expand-prompt",
                json={"prompt": prompt},
            )
            response.raise_for_status()
            return response.json()["expanded_prompt"]
        except requests.RequestException as e:
            logger.error(f"Failed to expand prompt: {str(e)}")
            return prompt
    def health_check(self) -> Dict[str, Any]:
        """
        Check if the LLM service is healthy.

        Returns:
            Health status information
        """
        if self.local_model:
            return {
                "status": "healthy",
                "mode": "direct_model",
                "model": "TinyLlama-1.1B-Chat-v1.0",
            }

        if not self.base_url:
            return {"status": "unavailable", "reason": "no_service_url"}

        try:
            response = self.session.get(f"{self.base_url}/health")
            response.raise_for_status()
            return response.json()
        except requests.RequestException as e:
            logger.error(f"Health check failed: {str(e)}")
            return {"status": "unhealthy", "reason": str(e)}
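
# Minimal usage sketch (illustrative only). It assumes either that the local
# TinyLlama model can be imported via the fallbacks above, or that an LLM
# service exposing /generate, /expand-prompt, and /health is reachable; the
# example URL in the comment below is a placeholder, not a value taken from
# this project.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Without a base_url the client tries the local model (Spaces mode);
    # pass e.g. LLMClient(base_url="http://localhost:8000") to use a service.
    client = LLMClient()

    print(client.health_check())
    print(client.expand_prompt("a castle in the clouds"))
    print(
        client.generate(
            prompt="Describe a futuristic city in two sentences.",
            max_tokens=128,
            temperature=0.7,
        )
    )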