import requests
import logging
import os
import sys
from typing import Optional, Dict, Any
from pathlib import Path
logger = logging.getLogger(__name__)
# For Hugging Face Spaces, we need better import handling
MODEL_SUPPORT = False
try:
    # First try direct import
    from app.llm.model import get_llm_instance

    MODEL_SUPPORT = True
except ImportError:
    try:
        # Then try relative import
        from ..llm.model import get_llm_instance

        MODEL_SUPPORT = True
    except ImportError:
        try:
            # Then try path-based import as a fallback
            current_dir = Path(__file__).parent
            sys.path.append(str(current_dir.parent))
            from llm.model import get_llm_instance

            MODEL_SUPPORT = True
        except ImportError:
            logger.warning(
                "Failed to import local LLM model - direct model usage disabled"
            )
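# NOTE: the three fallbacks above are assumed to cover the common layouts:
# running as an installed package (app.llm.model), running as a sub-package
# of the app (..llm.model), and running with this file's parent directory on
# sys.path (llm.model), e.g. in a flat Hugging Face Spaces checkout.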
class LLMClient:
    """
    Client for interacting with the LLM service or direct model access.

    Provides methods to generate text and expand creative prompts.
    Uses the TinyLlama model for efficient prompt expansion.
    """

    def __init__(self, base_url: Optional[str] = None):
        """
        Initialize the LLM client.

        Args:
            base_url: Base URL of the LLM service (optional)
        """
        self.base_url = base_url
        self.session = requests.Session()
        self.local_model = None
        self.spaces_mode = os.environ.get("HF_SPACES", "0") == "1"

        # For Hugging Face Spaces, we use the model directly instead of a service
        if self.spaces_mode or not self.base_url:
            if MODEL_SUPPORT:
                try:
                    logger.info(
                        "Running in Spaces mode, initializing TinyLlama model..."
                    )
                    self.local_model = get_llm_instance()
                    logger.info("Local model initialized successfully")
                except Exception as e:
                    logger.error(f"Failed to initialize local model: {str(e)}")
            else:
                logger.warning(
                    "No LLM service URL provided and direct model access disabled"
                )
    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
    ) -> str:
        """
        Generate text based on a prompt.

        Args:
            prompt: The user prompt to generate from
            system_prompt: Optional system prompt to guide the generation
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (higher = more creative)
            top_p: Top-p sampling parameter

        Returns:
            The generated text, or the original prompt if generation fails
        """
        # Use the local model if available (Spaces mode)
        if self.local_model:
            try:
                return self.local_model.generate(
                    prompt=prompt,
                    system_prompt=system_prompt,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                )
            except Exception as e:
                logger.error(f"Local model generation failed: {str(e)}")
                return prompt

        # Fall back to the service if a base_url is provided
        if not self.base_url:
            logger.warning("No LLM service URL and no local model available")
            return prompt

        payload = {
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }
        if system_prompt:
            payload["system_prompt"] = system_prompt

        try:
            response = self.session.post(f"{self.base_url}/generate", json=payload)
            response.raise_for_status()
            # The "result" key matches the response format in service.py
            return response.json()["result"]
        except requests.RequestException as e:
            logger.error(f"Failed to generate text: {str(e)}")
            return prompt
    def expand_prompt(self, prompt: str) -> str:
        """
        Expand a creative prompt with rich details using TinyLlama.

        Args:
            prompt: The user's original prompt

        Returns:
            An expanded, detailed creative prompt
        """
        # Use the local model if available (Spaces mode)
        if self.local_model:
            try:
                return self.local_model.expand_creative_prompt(prompt)
            except Exception as e:
                logger.error(f"Local model prompt expansion failed: {str(e)}")
                return prompt

        # Fall back to the service if a base_url is provided
        if not self.base_url:
            logger.warning("No LLM service URL and no local model available")
            return prompt

        try:
            # The /expand-prompt endpoint and "expanded_prompt" key match service.py
            response = self.session.post(
                f"{self.base_url}/expand-prompt",
                json={"prompt": prompt},
            )
            response.raise_for_status()
            return response.json()["expanded_prompt"]
        except requests.RequestException as e:
            logger.error(f"Failed to expand prompt: {str(e)}")
            return prompt
    def health_check(self) -> Dict[str, Any]:
        """
        Check if the LLM service is healthy.

        Returns:
            Health status information
        """
        if self.local_model:
            return {
                "status": "healthy",
                "mode": "direct_model",
                "model": "TinyLlama-1.1B-Chat-v1.0",
            }

        if not self.base_url:
            return {"status": "unavailable", "reason": "no_service_url"}

        try:
            response = self.session.get(f"{self.base_url}/health")
            response.raise_for_status()
            return response.json()
        except requests.RequestException as e:
            logger.error(f"Health check failed: {str(e)}")
            return {"status": "unhealthy", "reason": str(e)}
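# Usage sketch (assumptions: the placeholder URL below is not a real
# deployment; with no base_url, or with HF_SPACES=1, the client uses the
# local TinyLlama model when it is importable and otherwise degrades to
# echoing the original prompt).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Either point at a running LLM service (placeholder URL)...
    # client = LLMClient(base_url="http://localhost:8000")
    # ...or rely on direct model access / graceful degradation.
    client = LLMClient()

    print(client.health_check())
    print(client.expand_prompt("a watercolor painting of a lighthouse at dawn"))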