cleaned code

Files changed:
- app/llm/client.py   +18 -6
- app/llm/model.py    +79 -602
- app/llm/service.py  +62 -151

app/llm/client.py  CHANGED
@@ -38,6 +38,7 @@ class LLMClient:
    """
    Client for interacting with the LLM service or direct model access.
    Provides methods to generate text and expand creative prompts.
+    Uses TinyLlama model for efficient prompt expansion.
    """

    def __init__(self, base_url: str = None):
@@ -56,7 +57,9 @@
        if self.spaces_mode or not self.base_url:
            if MODEL_SUPPORT:
                try:
-                    logger.info(
+                    logger.info(
+                        "Running in Spaces mode, initializing TinyLlama model..."
+                    )
                    self.local_model = get_llm_instance()
                    logger.info(f"Local model initialized successfully")
                except Exception as e:
@@ -122,14 +125,16 @@
        try:
            response = self.session.post(f"{self.base_url}/generate", json=payload)
            response.raise_for_status()
-            return response.json()[
+            return response.json()[
+                "result"
+            ]  # Updated to match service.py response format
        except requests.RequestException as e:
            logger.error(f"Failed to generate text: {str(e)}")
            return prompt

    def expand_prompt(self, prompt: str) -> str:
        """
-        Expand a creative prompt with rich details.
+        Expand a creative prompt with rich details using TinyLlama.

        Args:
            prompt: The user's original prompt
@@ -152,10 +157,13 @@

        try:
            response = self.session.post(
-                f"{self.base_url}/expand",
+                f"{self.base_url}/expand-prompt",
+                json={"prompt": prompt},  # Updated endpoint to match service.py
            )
            response.raise_for_status()
-            return response.json()[
+            return response.json()[
+                "expanded_prompt"
+            ]  # Updated to match service.py response format
        except requests.RequestException as e:
            logger.error(f"Failed to expand prompt: {str(e)}")
            return prompt
@@ -168,7 +176,11 @@
            Health status information
        """
        if self.local_model:
-            return {
+            return {
+                "status": "healthy",
+                "mode": "direct_model",
+                "model": "TinyLlama-1.1B-Chat-v1.0",
+            }

        if not self.base_url:
            return {"status": "unavailable", "reason": "no_service_url"}
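After this change the client posts to /expand-prompt and reads the "expanded_prompt" key from the JSON response. A minimal usage sketch, assuming the import path matches the repository layout and using an illustrative base URL (only expand_prompt is named in the visible hunks, so no other method names are assumed):

```python
# Hypothetical usage of the updated LLMClient; base_url value is an example.
from app.llm.client import LLMClient

client = LLMClient(base_url="http://localhost:8000")
expanded = client.expand_prompt("a lighthouse at dusk")
print(expanded)  # falls back to the original prompt if the service call fails
```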
app/llm/model.py  CHANGED
@@ -1,251 +1,62 @@
 import os
 import logging
-import torch
-import re
-from typing import Dict, List, Optional, Union
 from pathlib import Path
-import
-import

 logger = logging.getLogger(__name__)

-#
-
-    from transformers import (
-        AutoTokenizer,
-        AutoModelForCausalLM,
-        AutoModelForSeq2SeqLM,
-        pipeline,
-        AutoConfig,
-    )

-
-
-
-
-
-    )

-#
-
-    from ctransformers import AutoModelForCausalLM as CTAutoModelForCausalLM

-
-except ImportError:
-    HAS_CTRANSFORMERS = False
-    logger.warning("CTransformers library not found. GGUF models won't be available.")


 class LocalLLM:
    """
-    A
-
    """

-    def __init__(
-        self,
-        model_path: str = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
-        model_file: str = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
-        model_type: str = "gguf",
-        model_architecture: str = "causal",
-        device_map: str = "auto",
-        torch_dtype=None,
-        use_quantization: bool = False,
-    ):
        """
        Initialize the local LLM.

        Args:
-
-            model_file: Specific model file to load (for GGUF models)
-            model_type: Type of model ('transformers' or 'gguf')
-            model_architecture: Architecture type ('causal' or 'seq2seq')
-            device_map: Device mapping strategy (default: "auto")
-            torch_dtype: Torch data type (default: float16)
-            use_quantization: Whether to use 8-bit quantization to reduce memory usage
        """
-        self.
-        self.
-
-
-
-        self.
-
-
-
-
-        # Set torch dtype if using transformers models
-        if torch_dtype is None and self.model_type != "gguf":
-            self.torch_dtype = torch.float16
-        else:
-            self.torch_dtype = torch_dtype
-
-        logger.info(f"Loading LLM from {model_path}")
-        logger.info(
-            f"Model type: {model_type}, architecture: {model_architecture}, model file: {model_file}"
-        )
-
-        # Various loading strategies based on model type
-        if self.model_type == "gguf":
-            self._load_gguf_model()
-        else:
-            self._load_transformers_model()
-
-    def _load_gguf_model(self):
-        """Load a GGUF model using CTransformers"""
-        if not HAS_CTRANSFORMERS:
-            raise ImportError(
-                "CTransformers library not found but required for GGUF models"
            )

-
-
-
-
-
-
-        if
-
-
-        else:
-            context_length = 2048  # Standard context length
-
-        logger.info(f"Using context length: {context_length}, threads: {threads}")
-
-        # If we have a model file specified, use it directly
-        if self.model_file and "/" in self.model_path:
-            logger.info(
-                f"Loading GGUF model from Hugging Face: {self.model_path}/{self.model_file}"
-            )
-
-            # Using the exact pattern from the example
-            self.model = CTAutoModelForCausalLM.from_pretrained(
-                self.model_path,
-                model_file=self.model_file,
-                model_type="llama",  # required for Llama/TinyLlama models
-                max_new_tokens=256,
-                context_length=context_length,
-                temperature=0.7,
-                top_p=0.95,
-                repetition_penalty=1.1,
-                threads=threads,  # CPU threads
-            )
-        else:
-            # Local path with model
-            logger.info(f"Loading local GGUF model: {self.model_path}")
-            self.model = CTAutoModelForCausalLM.from_pretrained(
-                self.model_path,
-                model_type="llama",
-            )
-
-            logger.info("GGUF model loaded successfully")
-
-        except Exception as e:
-            logger.error(f"Failed to load GGUF model: {str(e)}")
-            raise
-
-    def _load_transformers_model(self):
-        """Load a model using Hugging Face transformers"""
-        if not HAS_TRANSFORMERS:
-            raise ImportError(
-                "Transformers library not found but required for standard models"
-            )
-
-        try:
-            # When running in Spaces, we need more conservative settings
-            spaces_mode = os.environ.get("HF_SPACES", "0") == "1"
-
-            # Prepare model loading arguments
-            load_kwargs = {
-                "torch_dtype": self.torch_dtype,
-            }
-
-            # Add quantization for memory savings if requested
-            if self.use_quantization:
-                logger.info("Using 8-bit quantization for memory efficiency")
-                load_kwargs.update(
-                    {
-                        "load_in_8bit": True,
-                        "device_map": "auto",  # Force auto when using quantization
-                    }
-                )
-            else:
-                load_kwargs["device_map"] = self.device_map
-
-            # In Spaces, use more conservative loading options
-            if spaces_mode:
-                logger.info(
-                    "Running in Hugging Face Spaces, using minimal memory settings"
-                )
-                load_kwargs.update(
-                    {
-                        "low_cpu_mem_usage": True,
-                    }
-                )
-
-            # Load the tokenizer first - common to both architectures
-            tokenizer = AutoTokenizer.from_pretrained(self.model_path)
-
-            # Load the model based on architecture
-            if self.model_architecture == "seq2seq":
-                logger.info("Loading sequence-to-sequence model architecture")
-                model = AutoModelForSeq2SeqLM.from_pretrained(
-                    self.model_path, **load_kwargs
-                )
-                self.pipe = pipeline(
-                    "text2text-generation",
-                    model=model,
-                    tokenizer=tokenizer,
-                    framework="pt",
-                )
-            else:
-                # Standard causal language model
-                logger.info("Loading causal language model architecture")
-                # Skip the custom config handling for Spaces mode or small models
-                if (
-                    spaces_mode
-                    or "phi" in self.model_path.lower()
-                    or "tiny" in self.model_path.lower()
-                ):
-                    model = AutoModelForCausalLM.from_pretrained(
-                        self.model_path, **load_kwargs
-                    )
-                else:
-                    # Standard local loading with our custom config handling
-                    config = AutoConfig.from_pretrained(self.model_path)
-
-                    # Fix the rope_scaling issue for Llama models
-                    if hasattr(config, "rope_scaling") and isinstance(
-                        config.rope_scaling, dict
-                    ):
-                        config.rope_scaling["type"] = "linear"
-                        logger.info("Fixed rope_scaling configuration with type=linear")
-                    elif (
-                        not hasattr(config, "rope_scaling")
-                        and "llama" in self.model_path.lower()
-                    ):
-                        config.rope_scaling = {"type": "linear", "factor": 1.0}
-                        logger.info("Added default rope_scaling configuration")
-
-                    # Load the model with our fixed config
-                    model = AutoModelForCausalLM.from_pretrained(
-                        self.model_path, config=config, **load_kwargs
-                    )
-
-                # Create text generation pipeline for causal LM
-                self.pipe = pipeline(
-                    "text-generation", model=model, tokenizer=tokenizer, framework="pt"
-                )
-
-            # Store the model and tokenizer reference
-            self.model = model
-            self.tokenizer = tokenizer
-
-            logger.info("Transformers model loaded successfully")

-
-            logger.error(f"Failed to load transformers model: {str(e)}")
-            raise

    def generate(
        self,
@@ -256,7 +67,7 @@ class LocalLLM:
        top_p: float = 0.9,
    ) -> str:
        """
-        Generate text based on a prompt

        Args:
            prompt: The user prompt to generate from
@@ -268,400 +79,66 @@ class LocalLLM:
        Returns:
            The generated text
        """
-
-        if self.model_type == "gguf":
-            return self._generate_with_gguf(
-                prompt, system_prompt, max_tokens, temperature, top_p
-            )
-        else:
-            return self._generate_with_transformers(
-                prompt, system_prompt, max_tokens, temperature, top_p
-            )
-
-    def _generate_with_gguf(
-        self,
-        prompt: str,
-        system_prompt: Optional[str] = None,
-        max_tokens: int = 512,
-        temperature: float = 0.7,
-        top_p: float = 0.9,
-    ) -> str:
-        """Generate text using GGUF model"""
-        try:
-            # Format prompt for chat completion
-            formatted_prompt = prompt
-            if system_prompt:
-                # Format system and user prompts for chat
-                formatted_prompt = (
-                    f"<|system|>\n{system_prompt}\n<|user|>\n{prompt}\n<|assistant|>\n"
-                )
-
-            # Generate from the GGUF model
-            # Use a slightly more conservative max_new_tokens for spaces
-            spaces_mode = os.environ.get("HF_SPACES", "0") == "1"
-            if spaces_mode:
-                max_tokens = min(max_tokens, 256)  # Cap at 256 for faster responses
-
-            start_time = os.times().user
-            response = self.model(
-                formatted_prompt,
-                max_new_tokens=max_tokens,
-                temperature=temperature,
-                top_p=top_p,
-                stop=["<|user|>", "<|system|>", "<|end|>"],
-            )
-            end_time = os.times().user
-            generation_time = end_time - start_time
-            logger.info(f"GGUF generation completed in {generation_time:.2f}s")
-
-            return response
-
-        except Exception as e:
-            logger.error(f"Error during GGUF generation: {str(e)}")
-            return ""
-
-    def _generate_with_transformers(
-        self,
-        prompt: str,
-        system_prompt: Optional[str] = None,
-        max_tokens: int = 512,
-        temperature: float = 0.7,
-        top_p: float = 0.9,
-    ) -> str:
-        """Generate text using transformers pipeline"""
-        try:
-            # Handle seq2seq models (like T5)
-            if self.model_architecture == "seq2seq":
-                logger.debug(f"Generating with seq2seq model: {self.model_path}")
-
-                # Format prompt for seq2seq models
-                formatted_prompt = prompt
-                if system_prompt:
-                    formatted_prompt = f"{system_prompt}\n\nQuery: {prompt}"
-
-                # T5 models work best with specific task prefixes
-                if (
-                    "flan" in self.model_path.lower()
-                    and not formatted_prompt.startswith("enhance:")
-                ):
-                    formatted_prompt = f"enhance: {formatted_prompt}"

-
-
-
-                    max_length=max_tokens,
-                    temperature=temperature,
-                    top_p=top_p,
-                    do_sample=True,
-                )

-
-
-                if "generated_text" in outputs[0]:
-                    return outputs[0]["generated_text"].strip()

-
-
-
-
-            has_chat_template = (
-                hasattr(self.tokenizer, "chat_template")
-                and self.tokenizer.chat_template is not None
-            )
-
-            # For models that support chat templates
-            if has_chat_template:
-                # Format messages for chat-style models
-                messages = []
-                # Add system prompt if provided
-                if system_prompt:
-                    messages.append({"role": "system", "content": system_prompt})
-                # Add user prompt
-                messages.append({"role": "user", "content": prompt})
-
-                logger.debug(f"Generating with chat messages: {prompt[:100]}...")
-
-                # Generate response using the pipeline
-                outputs = self.pipe(
-                    messages,
-                    max_new_tokens=max_tokens,
-                    temperature=temperature,
-                    top_p=top_p,
-                    do_sample=True,
-                )
-
-                # Extract the assistant's response
-                response = outputs[0]["generated_text"][-1]["content"]
-                return response
-
-            # For non-chat models (like DistilGPT2)
-            else:
-                logger.debug(f"Using non-chat model format for: {self.model_path}")
-
-                # Format prompt directly for non-chat models
-                formatted_prompt = prompt
-                if system_prompt:
-                    formatted_prompt = (
-                        f"{system_prompt}\n\nUser: {prompt}\n\nAssistant:"
-                    )
-
-                outputs = self.pipe(
-                    formatted_prompt,
-                    max_new_tokens=max_tokens,
-                    temperature=temperature,
-                    top_p=top_p,
-                    do_sample=True,
-                    return_full_text=False,  # Only return the generated text, not the prompt
-                )

-
-
-
-

-
-

-
-
-
-            return prompt

-    def expand_creative_prompt(self,
        """
-
-        creative description suitable for image generation.

        Args:
-

        Returns:
            An expanded, detailed creative prompt
        """
-        #
-
-
-
-
-
-
-
-
-
-                "vivid colors",
-            ]
-            import random
-
-            # Select 2-3 random expansions
-            selected = random.sample(expansions, k=min(3, len(expansions)))
-            return f"{prompt}, {', '.join(selected)}"
-
-        # For GGUF models like TinyLlama, use a very specific chat format
-        if self.model_type == "gguf":
-            # This system prompt is now much more direct and explicit
-            system_prompt = "You enhance image generation prompts by adding style and quality descriptors."
-
-            user_prompt = f'Transform: "{prompt}" into "{prompt}, [artistic style], [quality details]". Max 40 words. No explanations.'
-
-            # Generate the expanded prompt
-            expanded = self.generate(
-                prompt=user_prompt,
-                system_prompt=system_prompt,
-                max_tokens=100,
-                temperature=0.7,
-            )
-
-            # Post-process the response
-            expanded = self._clean_expansion(prompt, expanded)
-            return expanded
-
-        # For Transformers models
-        else:
-            # Standard approach for causal LMs like GPT-2 or Llama
-            system_prompt = (
-                "You are a prompt engineer that enhances image generation prompts."
-            )
-
-            user_prompt = f'Enhance this prompt for image generation: "{prompt}"\n\nOutput format: "{prompt}, [style], [quality]"\n\nKeep it under 40 words.'
-
-            # Generate the expanded prompt
-            expanded = self.generate(
-                prompt=user_prompt,
-                system_prompt=system_prompt,
-                max_tokens=100,
-                temperature=0.7,
-            )
-
-            # Post-process the response
-            expanded = self._clean_expansion(prompt, expanded)
-            return expanded
-
-    def _clean_expansion(self, original_prompt: str, expanded_text: str) -> str:
-        """
-        Clean up the expanded prompt text to ensure proper formatting.
-
-        Args:
-            original_prompt: The original prompt for reference
-            expanded_text: The raw expanded text from the model
-
-        Returns:
-            Cleaned and properly formatted prompt
-        """
-        import re
-
-        # First, handle the common case where TinyLlama outputs multiple variations
-        # Split by instances of the original prompt
-        if expanded_text.lower().count(original_prompt.lower()) > 1:
-            # Multiple variations detected - just use the first one
-            parts = expanded_text.lower().split(original_prompt.lower(), 1)
-            if len(parts) > 1:
-                # Take just the first expansion
-                expanded_text = original_prompt + parts[1]
-                # Find the next occurrence of the prompt and cut everything after it
-                next_prompt_pos = expanded_text.lower().find(
-                    original_prompt.lower(), len(original_prompt)
-                )
-                if next_prompt_pos > 0:
-                    expanded_text = expanded_text[:next_prompt_pos].strip()
-
-        # First pass: remove obvious instruction text
-        patterns_to_remove = [
-            r"(?i)^\s*(?:output|enhanced prompt|result):\s*",  # Remove prefixes like "Output:" or "Enhanced prompt:"
-            r"(?i)\b(?:original prompt|start with|add|use|format|rule|follow|example)\b.*$",  # Remove instructions
-            r"^\s*\d+\.?\s*",  # Remove numbered list markers
-            r'^["\'](.*)["\']$',  # Remove quotes surrounding the entire text
-        ]
-
-        for pattern in patterns_to_remove:
-            expanded_text = re.sub(pattern, "", expanded_text, flags=re.MULTILINE)
-
-        # Normalize whitespace
-        expanded_text = " ".join(expanded_text.split())
-
-        # If the expansion doesn't start with the original prompt, add it
-        if not expanded_text.lower().startswith(original_prompt.lower()):
-            if "," in expanded_text and not expanded_text.startswith(","):
-                # Try to find where the original prompt might appear
-                parts = expanded_text.split(",", 1)
-                if original_prompt.lower() in parts[0].lower():
-                    # The first part contains the original prompt but with modifications
-                    expanded_text = f"{original_prompt}, {parts[1].strip()}"
-                else:
-                    expanded_text = f"{original_prompt}, {expanded_text}"
-            else:
-                expanded_text = f"{original_prompt}, {expanded_text}"
-
-        # Remove any duplicated commas
-        expanded_text = re.sub(r",\s*,", ",", expanded_text)
-
-        # Strict length control - limit expansion to approximately 40 words
-        # Count words in the expansion
-        words = expanded_text.split()
-        if len(words) > 40:
-            # Keep original prompt and just enough words to stay under 40
-            prompt_words = len(original_prompt.split())
-            # We need to keep the original prompt and stay under 40 total words
-            allowed_extra_words = 40 - prompt_words
-            # Join the original prompt with the allowed number of additional words
-            expanded_text = " ".join(words[: prompt_words + allowed_extra_words])
-
-        # Check if the expansion still contains instruction-like text or is too repetitive
-        instruction_indicators = [
-            "original prompt",
-            "add only",
-            "rule",
-            "format as",
-            "example",
-            "enhancement:",
-        ]
-        if any(
-            indicator in expanded_text.lower() for indicator in instruction_indicators
-        ):
-            # Emergency fallback - use hardcoded expansion phrases
-            expansions = [
-                "cinematic lighting",
-                "professional photography",
-                "8k resolution",
-                "dramatic angle",
-                "photorealistic",
-                "highly detailed",
-                "vivid colors",
-                "stunning detail",
-                "artistically composed",
-                "sharp focus",
-            ]
-            import random
-
-            # Select 2-3 random expansions
-            selected = random.sample(expansions, k=min(3, len(expansions)))
-            expanded_text = f"{original_prompt}, {', '.join(selected)}"

-
-        return expanded_text


-def get_llm_instance(
    """
-
-    Returns None if model loading fails, allowing graceful fallback.
-
-    Args:
-        model_path: Optional path to model or HuggingFace model ID

    Returns:
-
    """
-
-    if not use_local_model:
-        logger.info("Local model usage is disabled by environment setting")
-        return None
-
-    # Default to environment settings with fallbacks
-    if not model_path:
-        model_path = os.environ.get("MODEL_PATH") or os.environ.get(
-            "MODEL_ID", "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
-        )
-
-    # Get model file for GGUF models
-    model_file = os.environ.get("MODEL_FILENAME")
-
-    # Check model architecture - T5 models use seq2seq, others use causal LM
-    model_architecture = os.environ.get("MODEL_ARCHITECTURE", "causal").lower()
-
-    # Check model type - prefer GGUF for speed in resource-constrained environments
-    model_type = os.environ.get("MODEL_TYPE", "transformers").lower()
-
-    # Check if quantization is enabled
-    use_quantization = os.environ.get("MODEL_QUANTIZED", "false").lower() == "true"
-
-    try:
-        # Check if the provided path is a local directory
-        if os.path.isdir(model_path):
-            logger.info(f"Using local model directory: {model_path}")
-        else:
-            logger.info(f"Using model ID from Hugging Face: {model_path}")
-
-        # Check available device backends
-        device_map = "auto"
-        torch_dtype = None
-
-        # For Hugging Face Spaces, be more careful about memory usage
-        spaces_mode = os.environ.get("HF_SPACES", "0") == "1"
-        if spaces_mode and model_type != "gguf":
-            logger.info("Running in Hugging Face Spaces, using CPU for stability")
-            # Force CPU for Spaces with transformers models
-            device_map = "cpu" if not use_quantization else "auto"
-
-        # Create the LLM instance with appropriate settings
-        return LocalLLM(
-            model_path=model_path,
-            model_file=model_file,
-            model_type=model_type,
-            model_architecture=model_architecture,
-            device_map=device_map,
-            torch_dtype=torch_dtype,
-            use_quantization=use_quantization,
-        )
-    except Exception as e:
-        logger.error(f"Failed to create LLM instance: {e}")
-        return None
 import os
 import logging
 from pathlib import Path
+import torch
+from typing import Optional, Dict, Any, Union, List
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

 logger = logging.getLogger(__name__)

+# Constants for TinyLlama model
+MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

+# Default system prompt for creative expansion
+DEFAULT_CREATIVE_SYSTEM_PROMPT = """You are an expert assistant that helps expand creative prompts for image generation.
+Your goal is to enrich the original prompt with vivid visual details, artistic style suggestions, and composition elements.
+Focus on visual enhancement only. Keep your response concise (under 100 words) and focus entirely on the expanded prompt.
+Do not include explanations or comments - only return the enhanced prompt text."""

+# Default user template for creative expansion
+DEFAULT_CREATIVE_USER_TEMPLATE = """Original prompt: {original_prompt}

+Please expand this into a detailed, vivid prompt for image generation with rich visual elements, mood, and style."""


 class LocalLLM:
    """
+    A local LLM implementation using TinyLlama-1.1B-Chat.
+    Provides methods to generate text and expand creative prompts.
    """

+    def __init__(self, model_id: str = MODEL_ID):
        """
        Initialize the local LLM.

        Args:
+            model_id: The model ID to load
        """
+        self.model_id = model_id
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        # Configure quantization for efficient memory usage
+        quantization_config = None
+        if self.device == "cuda":
+            quantization_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=torch.float16,
            )

+        # Load model and tokenizer
+        logger.info(f"Loading TinyLlama model on {self.device}...")
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_id,
+            quantization_config=quantization_config,
+            device_map="auto" if self.device == "cuda" else None,
+            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
+        )

+        logger.info(f"TinyLlama model loaded successfully")

    def generate(
        self,
        top_p: float = 0.9,
    ) -> str:
        """
+        Generate text based on a prompt.

        Args:
            prompt: The user prompt to generate from

        Returns:
            The generated text
        """
+        messages = []

+        # Add system message if provided
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})

+        # Add user message
+        messages.append({"role": "user", "content": prompt})

+        # Format messages for the model
+        prompt_text = self.tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )

+        # Generate response
+        inputs = self.tokenizer(prompt_text, return_tensors="pt").to(self.device)
+        outputs = self.model.generate(
+            **inputs,
+            max_new_tokens=max_tokens,
+            do_sample=True,
+            temperature=temperature,
+            top_p=top_p,
+        )

+        # Decode the response and extract the assistant's message
+        full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

+        # Extract just the assistant's response
+        assistant_response = full_output[len(prompt_text) :].strip()
+        return assistant_response

+    def expand_creative_prompt(self, original_prompt: str) -> str:
        """
+        Expand a creative prompt with rich details for better image generation.

        Args:
+            original_prompt: The user's original prompt

        Returns:
            An expanded, detailed creative prompt
        """
+        # Format the prompt for creative expansion
+        prompt = DEFAULT_CREATIVE_USER_TEMPLATE.format(original_prompt=original_prompt)
+
+        # Generate expanded prompt
+        expanded = self.generate(
+            prompt=prompt,
+            system_prompt=DEFAULT_CREATIVE_SYSTEM_PROMPT,
+            max_tokens=150,  # Limit to ensure concise responses
+            temperature=0.7,  # Balanced creativity
+        )

+        return expanded


+def get_llm_instance() -> LocalLLM:
    """
+    Get an instance of the local LLM.

    Returns:
+        An initialized LocalLLM instance
    """
+    return LocalLLM(model_id=MODEL_ID)
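The rewritten model.py exposes only the TinyLlama-backed wrapper and a parameterless factory. A minimal sketch of exercising it directly, assuming the import path matches the repository layout (running it downloads TinyLlama and requires transformers, torch, and bitsandbytes on CUDA):

```python
# Sketch: using the new TinyLlama wrapper from app/llm/model.py.
from app.llm.model import get_llm_instance

llm = get_llm_instance()  # loads TinyLlama/TinyLlama-1.1B-Chat-v1.0, 4-bit on CUDA
text = llm.generate(
    prompt="Describe a foggy harbor at dawn.",
    system_prompt="You are a concise creative writer.",
    max_tokens=64,
)
expanded = llm.expand_creative_prompt("a foggy harbor at dawn")
print(text)
print(expanded)
```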
app/llm/service.py  CHANGED
@@ -1,139 +1,78 @@
-import os
 import logging
-import
 import sys
-from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
-from typing import Optional
-import psutil
-import uvicorn
-from dotenv import load_dotenv
 from pathlib import Path
-

-# Configure logging
 logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler(sys.stdout),
-        logging.FileHandler(
    ],
 )
-logger = logging.getLogger(

-#
-
-if env_path.exists():
-    load_dotenv(dotenv_path=env_path)
-    logger.info(f"Loaded environment variables from {env_path}")


-#
-app = FastAPI(

-#
 app.add_middleware(
    CORSMiddleware,
-    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
 )


-
-@app.middleware("http")
-async def log_requests(request: Request, call_next):
-    start_time = time.time()
-    logger.info(f"Request started: {request.method} {request.url.path}")
-
-    response = await call_next(request)
-
-    process_time = time.time() - start_time
-    logger.info(
-        f"Request completed: {request.method} {request.url.path} - Status: {response.status_code} - Duration: {process_time:.4f}s"
-    )
-
-    return response
-
-
-# Model request and response classes
-class PromptRequest(BaseModel):
    prompt: str
    system_prompt: Optional[str] = None
-    max_tokens: int = 512
-    temperature: float = 0.7
-    top_p: float = 0.9


-class
    prompt: str


-
-
-

-# Global LLM instance
-llm = None


-@app.on_event("startup")
-async def startup_event():
-    """Initialize the LLM on startup"""
-    global llm
-    logger.info("Starting LLM service initialization...")
-
-    # First check for MODEL_PATH (local model), then fall back to MODEL_ID
-    model_path = os.environ.get("MODEL_PATH")
-    if model_path and os.path.isdir(model_path):
-        logger.info(f"Using local model from MODEL_PATH: {model_path}")
-    else:
-        # Fall back to MODEL_ID if MODEL_PATH isn't set or doesn't exist
-        model_path = os.environ.get("MODEL_ID", "meta-llama/Llama-3.2-3B-Instruct")
-        logger.info(f"Using model ID from Hugging Face: {model_path}")
-
-    try:
-        start_time = time.time()
-        llm = get_llm_instance(model_path)
-        init_time = time.time() - start_time
-
-        logger.info(
-            f"LLM initialized successfully with model: {model_path} in {init_time:.2f} seconds"
-        )
-
-        memory = psutil.virtual_memory()
-        logger.info(
-            f"System memory: {memory.percent}% used ({memory.used / (1024**3):.1f}GB / {memory.total / (1024**3):.1f}GB)"
-        )
-
-    except Exception as e:
-        logger.error(f"Failed to initialize LLM: {str(e)}", exc_info=True)
-        raise
-
-
-@app.post("/generate", response_model=LLMResponse)
-async def generate_text(request: PromptRequest):
-    """Generate text based on a prompt"""
-    logger.info(
-        f"Received text generation request, prompt length: {len(request.prompt)} chars"
-    )
-    logger.debug(f"Prompt: {request.prompt[:50]}...")
-
-    if not llm:
-        logger.error("LLM service not initialized when generate endpoint was called")
-        raise HTTPException(status_code=503, detail="LLM service not initialized")

    try:
        start_time = time.time()

-
-            f"Generation parameters: max_tokens={request.max_tokens}, temperature={request.temperature}, top_p={request.top_p}"
-        )
-
-        response = llm.generate(
            prompt=request.prompt,
            system_prompt=request.system_prompt,
            max_tokens=request.max_tokens,
@@ -141,73 +80,45 @@ async def generate_text(request: PromptRequest):
            top_p=request.top_p,
        )

-
-

-
-
-
-
-
-        return LLMResponse(text=response)
    except Exception as e:
-        logger.error(f"Error
        raise HTTPException(status_code=500, detail=str(e))


-@app.post("/expand"
-
-    """Expand a creative prompt with
-    logger.info(f"Received prompt expansion request, prompt: '{request.prompt}'")
-
-    if not llm:
-        logger.error("LLM service not initialized when expand endpoint was called")
-        raise HTTPException(status_code=503, detail="LLM service not initialized")
-
    try:
        start_time = time.time()

-

-
-

-
-
-
-
-
-
-        return LLMResponse(text=expanded)
    except Exception as e:
-        logger.error(f"Error
        raise HTTPException(status_code=500, detail=str(e))


-
-async def health_check():
-    """Health check endpoint"""
-    logger.debug("Health check endpoint called")
-
-    if llm:
-        logger.info(f"Health check: LLM service is healthy, model: {llm.model_path}")
-        return {"status": "healthy", "model": llm.model_path}
-
-    logger.warning("Health check: LLM service is still initializing")
-    return {"status": "initializing"}
-
-
-# Start the service if run directly
 if __name__ == "__main__":

-
-
-
-
-        logger.warning(
-            "psutil not installed. Some system resource metrics will not be available."
-        )
-        logger.warning("Install with: pip install psutil")
-
-    logger.info("Starting LLM service server")
-    uvicorn.run(app, host="0.0.0.0", port=8001)
 import logging
+import os
 import sys
+from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
+from typing import Dict, Any, List, Optional
 from pathlib import Path
+import time

+# Configure logging
 logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler(sys.stdout),
+        logging.FileHandler("llm_service.log", mode="a"),
    ],
 )
+logger = logging.getLogger(__name__)

+# Import the model
+from .model import get_llm_instance

+# Initialize model
+llm = get_llm_instance()

+# Create FastAPI app
+app = FastAPI(
+    title="LLM Service API",
+    description="API for interacting with the local LLM",
+    version="1.0.0",
+)

+# Configure CORS
 app.add_middleware(
    CORSMiddleware,
+    allow_origins=["*"],  # In production, specify actual origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
 )


+class GenerateRequest(BaseModel):
    prompt: str
    system_prompt: Optional[str] = None
+    max_tokens: Optional[int] = 512
+    temperature: Optional[float] = 0.7
+    top_p: Optional[float] = 0.9


+class ExpandPromptRequest(BaseModel):
    prompt: str


+@app.get("/")
+def read_root():
+    return {"status": "ok", "message": "LLM Service is running"}


+@app.get("/health")
+def health_check():
+    """Health check endpoint"""
+    return {"status": "healthy", "model": "TinyLlama-1.1B-Chat-v1.0"}


+@app.post("/generate")
+def generate_text(request: GenerateRequest):
+    """Generate text from a prompt"""
    try:
        start_time = time.time()
+        logger.info(f"Generating text for prompt: {request.prompt[:50]}...")

+        result = llm.generate(
            prompt=request.prompt,
            system_prompt=request.system_prompt,
            max_tokens=request.max_tokens,
            top_p=request.top_p,
        )

+        elapsed = time.time() - start_time
+        logger.info(f"Generation completed in {elapsed:.2f} seconds")

+        return {
+            "result": result,
+            "elapsed_seconds": elapsed,
+        }
    except Exception as e:
+        logger.error(f"Error during text generation: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


+@app.post("/expand-prompt")
+def expand_prompt(request: ExpandPromptRequest):
+    """Expand a creative prompt with more detail"""
    try:
        start_time = time.time()
+        logger.info(f"Expanding prompt: {request.prompt[:50]}...")

+        expanded_prompt = llm.expand_creative_prompt(request.prompt)

+        elapsed = time.time() - start_time
+        logger.info(f"Prompt expansion completed in {elapsed:.2f} seconds")

+        return {
+            "original_prompt": request.prompt,
+            "expanded_prompt": expanded_prompt,
+            "elapsed_seconds": elapsed,
+        }
    except Exception as e:
+        logger.error(f"Error during prompt expansion: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


+# Run the service when executed directly
 if __name__ == "__main__":
+    import uvicorn

+    port = int(os.environ.get("PORT", 8000))
+    host = os.environ.get("HOST", "0.0.0.0")
+    logger.info(f"Starting LLM service on {host}:{port}")
+    uvicorn.run(app, host=host, port=port)
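A minimal sketch of calling the rewritten service once it is running, assuming the default port 8000 from the code above; the payloads and response keys mirror GenerateRequest, ExpandPromptRequest, and the dictionaries returned by the endpoints:

```python
# Sketch: exercising the FastAPI endpoints defined in app/llm/service.py.
import requests

BASE = "http://localhost:8000"  # default PORT used by the service

print(requests.get(f"{BASE}/health").json())  # {"status": "healthy", "model": ...}

gen = requests.post(
    f"{BASE}/generate",
    json={"prompt": "Write a haiku about rain.", "max_tokens": 64},
).json()
print(gen["result"], gen["elapsed_seconds"])

exp = requests.post(
    f"{BASE}/expand-prompt",
    json={"prompt": "a castle in the clouds"},
).json()
print(exp["expanded_prompt"])
```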