import requests
import logging
import os
import sys
from typing import Optional, Dict, Any
from pathlib import Path

logger = logging.getLogger(__name__)

# For Hugging Face Spaces, we need better import handling
MODEL_SUPPORT = False
try:
    # First try direct import
    from app.llm.model import get_llm_instance

    MODEL_SUPPORT = True
except ImportError:
    try:
        # Then try relative import
        from ..llm.model import get_llm_instance

        MODEL_SUPPORT = True
    except ImportError:
        try:
            # Then try path-based import as a fallback
            current_dir = Path(__file__).parent
            sys.path.append(str(current_dir.parent))
            from llm.model import get_llm_instance

            MODEL_SUPPORT = True
        except ImportError:
            logger.warning(
                "Failed to import local LLM model - direct model usage disabled"
            )
class LLMClient:
    """
    Client for interacting with the LLM service or direct model access.
    Provides methods to generate text and expand creative prompts.
    Uses the TinyLlama model for efficient prompt expansion.
    """

    def __init__(self, base_url: Optional[str] = None):
        """
        Initialize the LLM client.

        Args:
            base_url: Base URL of the LLM service (optional)
        """
        self.base_url = base_url
        self.session = requests.Session()
        self.local_model = None
        self.spaces_mode = os.environ.get("HF_SPACES", "0") == "1"

        # On Hugging Face Spaces, use the model directly instead of a service
        if self.spaces_mode or not self.base_url:
            if MODEL_SUPPORT:
                try:
                    logger.info(
                        "Running in Spaces mode, initializing TinyLlama model..."
                    )
                    self.local_model = get_llm_instance()
                    logger.info("Local model initialized successfully")
                except Exception as e:
                    logger.error(f"Failed to initialize local model: {str(e)}")
            else:
                logger.warning(
                    "No LLM service URL provided and direct model access disabled"
                )
    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
    ) -> str:
        """
        Generate text based on a prompt.

        Args:
            prompt: The user prompt to generate from
            system_prompt: Optional system prompt to guide the generation
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (higher = more creative)
            top_p: Top-p sampling parameter

        Returns:
            The generated text, or the original prompt if generation fails
        """
        # Use local model if available (Spaces mode)
        if self.local_model:
            try:
                return self.local_model.generate(
                    prompt=prompt,
                    system_prompt=system_prompt,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                )
            except Exception as e:
                logger.error(f"Local model generation failed: {str(e)}")
                return prompt

        # Fall back to service if base_url is provided
        if not self.base_url:
            logger.warning("No LLM service URL and no local model available")
            return prompt

        payload = {
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }
        if system_prompt:
            payload["system_prompt"] = system_prompt

        try:
            response = self.session.post(f"{self.base_url}/generate", json=payload)
            response.raise_for_status()
            # "result" matches the response format of service.py
            return response.json()["result"]
        except requests.RequestException as e:
            logger.error(f"Failed to generate text: {str(e)}")
            return prompt
    def expand_prompt(self, prompt: str) -> str:
        """
        Expand a creative prompt with rich details using TinyLlama.

        Args:
            prompt: The user's original prompt

        Returns:
            An expanded, detailed creative prompt
        """
        # Use local model if available (Spaces mode)
        if self.local_model:
            try:
                return self.local_model.expand_creative_prompt(prompt)
            except Exception as e:
                logger.error(f"Local model prompt expansion failed: {str(e)}")
                return prompt

        # Fall back to service if base_url is provided
        if not self.base_url:
            logger.warning("No LLM service URL and no local model available")
            return prompt

        try:
            # The /expand-prompt endpoint and "expanded_prompt" key match service.py
            response = self.session.post(
                f"{self.base_url}/expand-prompt",
                json={"prompt": prompt},
            )
            response.raise_for_status()
            return response.json()["expanded_prompt"]
        except requests.RequestException as e:
            logger.error(f"Failed to expand prompt: {str(e)}")
            return prompt
    def health_check(self) -> Dict[str, Any]:
        """
        Check whether the LLM service or local model is healthy.

        Returns:
            Health status information
        """
        if self.local_model:
            return {
                "status": "healthy",
                "mode": "direct_model",
                "model": "TinyLlama-1.1B-Chat-v1.0",
            }

        if not self.base_url:
            return {"status": "unavailable", "reason": "no_service_url"}

        try:
            response = self.session.get(f"{self.base_url}/health")
            response.raise_for_status()
            return response.json()
        except requests.RequestException as e:
            logger.error(f"Health check failed: {str(e)}")
            return {"status": "unhealthy", "reason": str(e)}