Zoro-chi committed
Commit 647c8f0 · 1 Parent(s): c0b9c7a

cleaned code

Files changed (3):
  1. app/llm/client.py +18 -6
  2. app/llm/model.py +79 -602
  3. app/llm/service.py +62 -151
app/llm/client.py CHANGED
@@ -38,6 +38,7 @@ class LLMClient:
     """
     Client for interacting with the LLM service or direct model access.
     Provides methods to generate text and expand creative prompts.
+    Uses TinyLlama model for efficient prompt expansion.
     """
 
     def __init__(self, base_url: str = None):
@@ -56,7 +57,9 @@ class LLMClient:
         if self.spaces_mode or not self.base_url:
             if MODEL_SUPPORT:
                 try:
-                    logger.info("Running in Spaces mode, initializing local model...")
+                    logger.info(
+                        "Running in Spaces mode, initializing TinyLlama model..."
+                    )
                     self.local_model = get_llm_instance()
                     logger.info(f"Local model initialized successfully")
                 except Exception as e:
@@ -122,14 +125,16 @@ class LLMClient:
         try:
             response = self.session.post(f"{self.base_url}/generate", json=payload)
             response.raise_for_status()
-            return response.json()["text"]
+            return response.json()[
+                "result"
+            ]  # Updated to match service.py response format
         except requests.RequestException as e:
             logger.error(f"Failed to generate text: {str(e)}")
             return prompt
 
     def expand_prompt(self, prompt: str) -> str:
         """
-        Expand a creative prompt with rich details.
+        Expand a creative prompt with rich details using TinyLlama.
 
         Args:
             prompt: The user's original prompt
@@ -152,10 +157,13 @@ class LLMClient:
 
         try:
             response = self.session.post(
-                f"{self.base_url}/expand", json={"prompt": prompt}
+                f"{self.base_url}/expand-prompt",
+                json={"prompt": prompt},  # Updated endpoint to match service.py
             )
             response.raise_for_status()
-            return response.json()["text"]
+            return response.json()[
+                "expanded_prompt"
+            ]  # Updated to match service.py response format
         except requests.RequestException as e:
             logger.error(f"Failed to expand prompt: {str(e)}")
             return prompt
@@ -168,7 +176,11 @@ class LLMClient:
            Health status information
        """
        if self.local_model:
-            return {"status": "healthy", "mode": "direct_model"}
+            return {
+                "status": "healthy",
+                "mode": "direct_model",
+                "model": "TinyLlama-1.1B-Chat-v1.0",
+            }
 
        if not self.base_url:
            return {"status": "unavailable", "reason": "no_service_url"}
app/llm/model.py CHANGED
@@ -1,251 +1,62 @@
 import os
 import logging
-import torch
-import re
-from typing import Dict, List, Optional, Union
 from pathlib import Path
-import json
-import tempfile
+import torch
+from typing import Optional, Dict, Any, Union, List
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 
 logger = logging.getLogger(__name__)
 
-# Try to import transformers and ctransformers
-try:
-    from transformers import (
-        AutoTokenizer,
-        AutoModelForCausalLM,
-        AutoModelForSeq2SeqLM,
-        pipeline,
-        AutoConfig,
-    )
+# Constants for TinyLlama model
+MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 
-    HAS_TRANSFORMERS = True
-except ImportError:
-    HAS_TRANSFORMERS = False
-    logger.warning(
-        "Transformers library not found. Standard models won't be available."
-    )
+# Default system prompt for creative expansion
+DEFAULT_CREATIVE_SYSTEM_PROMPT = """You are an expert assistant that helps expand creative prompts for image generation.
+Your goal is to enrich the original prompt with vivid visual details, artistic style suggestions, and composition elements.
+Focus on visual enhancement only. Keep your response concise (under 100 words) and focus entirely on the expanded prompt.
+Do not include explanations or comments - only return the enhanced prompt text."""
 
-# Try to import ctransformers for GGUF support
-try:
-    from ctransformers import AutoModelForCausalLM as CTAutoModelForCausalLM
+# Default user template for creative expansion
+DEFAULT_CREATIVE_USER_TEMPLATE = """Original prompt: {original_prompt}
 
-    HAS_CTRANSFORMERS = True
-except ImportError:
-    HAS_CTRANSFORMERS = False
-    logger.warning("CTransformers library not found. GGUF models won't be available.")
+Please expand this into a detailed, vivid prompt for image generation with rich visual elements, mood, and style."""
 
 
 class LocalLLM:
     """
-    A wrapper for running local LLMs using either Hugging Face Transformers or CTransformers.
-    Optimized for creative prompt expansion and interpretation.
+    A local LLM implementation using TinyLlama-1.1B-Chat.
+    Provides methods to generate text and expand creative prompts.
     """
 
-    def __init__(
-        self,
-        model_path: str = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
-        model_file: str = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
-        model_type: str = "gguf",
-        model_architecture: str = "causal",
-        device_map: str = "auto",
-        torch_dtype=None,
-        use_quantization: bool = False,
-    ):
+    def __init__(self, model_id: str = MODEL_ID):
        """
        Initialize the local LLM.
 
        Args:
-            model_path: Path to model or HuggingFace model ID
-            model_file: Specific model file to load (for GGUF models)
-            model_type: Type of model ('transformers' or 'gguf')
-            model_architecture: Architecture type ('causal' or 'seq2seq')
-            device_map: Device mapping strategy (default: "auto")
-            torch_dtype: Torch data type (default: float16)
-            use_quantization: Whether to use 8-bit quantization to reduce memory usage
+            model_id: The model ID to load
        """
-        self.model_path = model_path
-        self.model_file = model_file
-        self.model_type = model_type.lower()
-        self.model_architecture = model_architecture.lower()
-        self.device_map = device_map
-        self.use_quantization = use_quantization
-        self.pipe = None
-        self.model = None
-        self.tokenizer = None
-
-        # Set torch dtype if using transformers models
-        if torch_dtype is None and self.model_type != "gguf":
-            self.torch_dtype = torch.float16
-        else:
-            self.torch_dtype = torch_dtype
-
-        logger.info(f"Loading LLM from {model_path}")
-        logger.info(
-            f"Model type: {model_type}, architecture: {model_architecture}, model file: {model_file}"
-        )
-
-        # Various loading strategies based on model type
-        if self.model_type == "gguf":
-            self._load_gguf_model()
-        else:
-            self._load_transformers_model()
-
-    def _load_gguf_model(self):
-        """Load a GGUF model using CTransformers"""
-        if not HAS_CTRANSFORMERS:
-            raise ImportError(
-                "CTransformers library not found but required for GGUF models"
+        self.model_id = model_id
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        # Configure quantization for efficient memory usage
+        quantization_config = None
+        if self.device == "cuda":
+            quantization_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=torch.float16,
            )
 
-        try:
-            # Handle spaces and CPU constraints
-            spaces_mode = os.environ.get("HF_SPACES", "0") == "1"
-            threads = int(os.environ.get("MODEL_CPU_THREADS", "4"))
-
-            # Very optimized settings for spaces
-            if spaces_mode:
-                logger.info("Using optimized settings for Spaces environment")
-                context_length = 512  # Use shorter context for speed
-            else:
-                context_length = 2048  # Standard context length
-
-            logger.info(f"Using context length: {context_length}, threads: {threads}")
-
-            # If we have a model file specified, use it directly
-            if self.model_file and "/" in self.model_path:
-                logger.info(
-                    f"Loading GGUF model from Hugging Face: {self.model_path}/{self.model_file}"
-                )
-
-                # Using the exact pattern from the example
-                self.model = CTAutoModelForCausalLM.from_pretrained(
-                    self.model_path,
-                    model_file=self.model_file,
-                    model_type="llama",  # required for Llama/TinyLlama models
-                    max_new_tokens=256,
-                    context_length=context_length,
-                    temperature=0.7,
-                    top_p=0.95,
-                    repetition_penalty=1.1,
-                    threads=threads,  # CPU threads
-                )
-            else:
-                # Local path with model
-                logger.info(f"Loading local GGUF model: {self.model_path}")
-                self.model = CTAutoModelForCausalLM.from_pretrained(
-                    self.model_path,
-                    model_type="llama",
-                )
-
-            logger.info("GGUF model loaded successfully")
-
-        except Exception as e:
-            logger.error(f"Failed to load GGUF model: {str(e)}")
-            raise
-
-    def _load_transformers_model(self):
-        """Load a model using Hugging Face transformers"""
-        if not HAS_TRANSFORMERS:
-            raise ImportError(
-                "Transformers library not found but required for standard models"
-            )
-
-        try:
-            # When running in Spaces, we need more conservative settings
-            spaces_mode = os.environ.get("HF_SPACES", "0") == "1"
-
-            # Prepare model loading arguments
-            load_kwargs = {
-                "torch_dtype": self.torch_dtype,
-            }
-
-            # Add quantization for memory savings if requested
-            if self.use_quantization:
-                logger.info("Using 8-bit quantization for memory efficiency")
-                load_kwargs.update(
-                    {
-                        "load_in_8bit": True,
-                        "device_map": "auto",  # Force auto when using quantization
-                    }
-                )
-            else:
-                load_kwargs["device_map"] = self.device_map
-
-            # In Spaces, use more conservative loading options
-            if spaces_mode:
-                logger.info(
-                    "Running in Hugging Face Spaces, using minimal memory settings"
-                )
-                load_kwargs.update(
-                    {
-                        "low_cpu_mem_usage": True,
-                    }
-                )
-
-            # Load the tokenizer first - common to both architectures
-            tokenizer = AutoTokenizer.from_pretrained(self.model_path)
-
-            # Load the model based on architecture
-            if self.model_architecture == "seq2seq":
-                logger.info("Loading sequence-to-sequence model architecture")
-                model = AutoModelForSeq2SeqLM.from_pretrained(
-                    self.model_path, **load_kwargs
-                )
-                self.pipe = pipeline(
-                    "text2text-generation",
-                    model=model,
-                    tokenizer=tokenizer,
-                    framework="pt",
-                )
-            else:
-                # Standard causal language model
-                logger.info("Loading causal language model architecture")
-                # Skip the custom config handling for Spaces mode or small models
-                if (
-                    spaces_mode
-                    or "phi" in self.model_path.lower()
-                    or "tiny" in self.model_path.lower()
-                ):
-                    model = AutoModelForCausalLM.from_pretrained(
-                        self.model_path, **load_kwargs
-                    )
-                else:
-                    # Standard local loading with our custom config handling
-                    config = AutoConfig.from_pretrained(self.model_path)
-
-                    # Fix the rope_scaling issue for Llama models
-                    if hasattr(config, "rope_scaling") and isinstance(
-                        config.rope_scaling, dict
-                    ):
-                        config.rope_scaling["type"] = "linear"
-                        logger.info("Fixed rope_scaling configuration with type=linear")
-                    elif (
-                        not hasattr(config, "rope_scaling")
-                        and "llama" in self.model_path.lower()
-                    ):
-                        config.rope_scaling = {"type": "linear", "factor": 1.0}
-                        logger.info("Added default rope_scaling configuration")
-
-                    # Load the model with our fixed config
-                    model = AutoModelForCausalLM.from_pretrained(
-                        self.model_path, config=config, **load_kwargs
-                    )
-
-                # Create text generation pipeline for causal LM
-                self.pipe = pipeline(
-                    "text-generation", model=model, tokenizer=tokenizer, framework="pt"
-                )
-
-            # Store the model and tokenizer reference
-            self.model = model
-            self.tokenizer = tokenizer
-
-            logger.info("Transformers model loaded successfully")
+        # Load model and tokenizer
+        logger.info(f"Loading TinyLlama model on {self.device}...")
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_id,
+            quantization_config=quantization_config,
+            device_map="auto" if self.device == "cuda" else None,
+            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
+        )
 
-        except Exception as e:
-            logger.error(f"Failed to load transformers model: {str(e)}")
-            raise
+        logger.info(f"TinyLlama model loaded successfully")
 
     def generate(
         self,
@@ -256,7 +67,7 @@ class LocalLLM:
         top_p: float = 0.9,
     ) -> str:
         """
-        Generate text based on a prompt with the local LLM.
+        Generate text based on a prompt.
 
         Args:
             prompt: The user prompt to generate from
@@ -268,400 +79,66 @@ class LocalLLM:
        Returns:
            The generated text
        """
-        # Different handling based on model type
-        if self.model_type == "gguf":
-            return self._generate_with_gguf(
-                prompt, system_prompt, max_tokens, temperature, top_p
-            )
-        else:
-            return self._generate_with_transformers(
-                prompt, system_prompt, max_tokens, temperature, top_p
-            )
-
-    def _generate_with_gguf(
-        self,
-        prompt: str,
-        system_prompt: Optional[str] = None,
-        max_tokens: int = 512,
-        temperature: float = 0.7,
-        top_p: float = 0.9,
-    ) -> str:
-        """Generate text using GGUF model"""
-        try:
-            # Format prompt for chat completion
-            formatted_prompt = prompt
-            if system_prompt:
-                # Format system and user prompts for chat
-                formatted_prompt = (
-                    f"<|system|>\n{system_prompt}\n<|user|>\n{prompt}\n<|assistant|>\n"
-                )
-
-            # Generate from the GGUF model
-            # Use a slightly more conservative max_new_tokens for spaces
-            spaces_mode = os.environ.get("HF_SPACES", "0") == "1"
-            if spaces_mode:
-                max_tokens = min(max_tokens, 256)  # Cap at 256 for faster responses
-
-            start_time = os.times().user
-            response = self.model(
-                formatted_prompt,
-                max_new_tokens=max_tokens,
-                temperature=temperature,
-                top_p=top_p,
-                stop=["<|user|>", "<|system|>", "<|end|>"],
-            )
-            end_time = os.times().user
-            generation_time = end_time - start_time
-            logger.info(f"GGUF generation completed in {generation_time:.2f}s")
-
-            return response
-
-        except Exception as e:
-            logger.error(f"Error during GGUF generation: {str(e)}")
-            return ""
-
-    def _generate_with_transformers(
-        self,
-        prompt: str,
-        system_prompt: Optional[str] = None,
-        max_tokens: int = 512,
-        temperature: float = 0.7,
-        top_p: float = 0.9,
-    ) -> str:
-        """Generate text using transformers pipeline"""
-        try:
-            # Handle seq2seq models (like T5)
-            if self.model_architecture == "seq2seq":
-                logger.debug(f"Generating with seq2seq model: {self.model_path}")
-
-                # Format prompt for seq2seq models
-                formatted_prompt = prompt
-                if system_prompt:
-                    formatted_prompt = f"{system_prompt}\n\nQuery: {prompt}"
-
-                # T5 models work best with specific task prefixes
-                if (
-                    "flan" in self.model_path.lower()
-                    and not formatted_prompt.startswith("enhance:")
-                ):
-                    formatted_prompt = f"enhance: {formatted_prompt}"
+        messages = []
 
-                # Generate with seq2seq model
-                outputs = self.pipe(
-                    formatted_prompt,
-                    max_length=max_tokens,
-                    temperature=temperature,
-                    top_p=top_p,
-                    do_sample=True,
-                )
+        # Add system message if provided
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
 
-                # Extract the generated text
-                if isinstance(outputs, list) and len(outputs) > 0:
-                    if "generated_text" in outputs[0]:
-                        return outputs[0]["generated_text"].strip()
+        # Add user message
+        messages.append({"role": "user", "content": prompt})
 
-                # Fallback extraction
-                return str(outputs).strip()
-
-            # Check if the model can handle chat templates
-            has_chat_template = (
-                hasattr(self.tokenizer, "chat_template")
-                and self.tokenizer.chat_template is not None
-            )
-
-            # For models that support chat templates
-            if has_chat_template:
-                # Format messages for chat-style models
-                messages = []
-                # Add system prompt if provided
-                if system_prompt:
-                    messages.append({"role": "system", "content": system_prompt})
-                # Add user prompt
-                messages.append({"role": "user", "content": prompt})
-
-                logger.debug(f"Generating with chat messages: {prompt[:100]}...")
-
-                # Generate response using the pipeline
-                outputs = self.pipe(
-                    messages,
-                    max_new_tokens=max_tokens,
-                    temperature=temperature,
-                    top_p=top_p,
-                    do_sample=True,
-                )
-
-                # Extract the assistant's response
-                response = outputs[0]["generated_text"][-1]["content"]
-                return response
-
-            # For non-chat models (like DistilGPT2)
-            else:
-                logger.debug(f"Using non-chat model format for: {self.model_path}")
-
-                # Format prompt directly for non-chat models
-                formatted_prompt = prompt
-                if system_prompt:
-                    formatted_prompt = (
-                        f"{system_prompt}\n\nUser: {prompt}\n\nAssistant:"
-                    )
-
-                outputs = self.pipe(
-                    formatted_prompt,
-                    max_new_tokens=max_tokens,
-                    temperature=temperature,
-                    top_p=top_p,
-                    do_sample=True,
-                    return_full_text=False,  # Only return the generated text, not the prompt
-                )
+        # Format messages for the model
+        prompt_text = self.tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
 
-                # Extract just the generated text
-                if isinstance(outputs, list) and len(outputs) > 0:
-                    if "generated_text" in outputs[0]:
-                        return outputs[0]["generated_text"].strip()
+        # Generate response
+        inputs = self.tokenizer(prompt_text, return_tensors="pt").to(self.device)
+        outputs = self.model.generate(
+            **inputs,
+            max_new_tokens=max_tokens,
+            do_sample=True,
+            temperature=temperature,
+            top_p=top_p,
+        )
 
-                # Fallback - return whatever we got
-                return str(outputs).strip()
+        # Decode the response and extract the assistant's message
+        full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-        except Exception as e:
-            logger.error(f"Error during transformers generation: {str(e)}")
-            # Return the original prompt on error as a fallback (non-empty result)
-            return prompt
+        # Extract just the assistant's response
+        assistant_response = full_output[len(prompt_text) :].strip()
+        return assistant_response
 
-    def expand_creative_prompt(self, prompt: str) -> str:
+    def expand_creative_prompt(self, original_prompt: str) -> str:
        """
-        Specifically designed to expand a user prompt into a more detailed,
-        creative description suitable for image generation.
+        Expand a creative prompt with rich details for better image generation.
 
        Args:
-            prompt: The user's original prompt
+            original_prompt: The user's original prompt
 
        Returns:
            An expanded, detailed creative prompt
        """
-        # For seq2seq models like T5, use a template-first approach
-        if self.model_architecture == "seq2seq":
-            # Get some standard enhancement phrases
-            expansions = [
-                "cinematic lighting",
-                "professional photography",
-                "8k resolution",
-                "dramatic angle",
-                "photorealistic",
-                "highly detailed",
-                "vivid colors",
-            ]
-            import random
-
-            # Select 2-3 random expansions
-            selected = random.sample(expansions, k=min(3, len(expansions)))
-            return f"{prompt}, {', '.join(selected)}"
-
-        # For GGUF models like TinyLlama, use a very specific chat format
-        if self.model_type == "gguf":
-            # This system prompt is now much more direct and explicit
-            system_prompt = "You enhance image generation prompts by adding style and quality descriptors."
-
-            user_prompt = f'Transform: "{prompt}" into "{prompt}, [artistic style], [quality details]". Max 40 words. No explanations.'
-
-            # Generate the expanded prompt
-            expanded = self.generate(
-                prompt=user_prompt,
-                system_prompt=system_prompt,
-                max_tokens=100,
-                temperature=0.7,
-            )
-
-            # Post-process the response
-            expanded = self._clean_expansion(prompt, expanded)
-            return expanded
-
-        # For Transformers models
-        else:
-            # Standard approach for causal LMs like GPT-2 or Llama
-            system_prompt = (
-                "You are a prompt engineer that enhances image generation prompts."
-            )
-
-            user_prompt = f'Enhance this prompt for image generation: "{prompt}"\n\nOutput format: "{prompt}, [style], [quality]"\n\nKeep it under 40 words.'
-
-            # Generate the expanded prompt
-            expanded = self.generate(
-                prompt=user_prompt,
-                system_prompt=system_prompt,
-                max_tokens=100,
-                temperature=0.7,
-            )
-
-            # Post-process the response
-            expanded = self._clean_expansion(prompt, expanded)
-            return expanded
-
-    def _clean_expansion(self, original_prompt: str, expanded_text: str) -> str:
-        """
-        Clean up the expanded prompt text to ensure proper formatting.
-
-        Args:
-            original_prompt: The original prompt for reference
-            expanded_text: The raw expanded text from the model
-
-        Returns:
-            Cleaned and properly formatted prompt
-        """
-        import re
-
-        # First, handle the common case where TinyLlama outputs multiple variations
-        # Split by instances of the original prompt
-        if expanded_text.lower().count(original_prompt.lower()) > 1:
-            # Multiple variations detected - just use the first one
-            parts = expanded_text.lower().split(original_prompt.lower(), 1)
-            if len(parts) > 1:
-                # Take just the first expansion
-                expanded_text = original_prompt + parts[1]
-                # Find the next occurrence of the prompt and cut everything after it
-                next_prompt_pos = expanded_text.lower().find(
-                    original_prompt.lower(), len(original_prompt)
-                )
-                if next_prompt_pos > 0:
-                    expanded_text = expanded_text[:next_prompt_pos].strip()
-
-        # First pass: remove obvious instruction text
-        patterns_to_remove = [
-            r"(?i)^\s*(?:output|enhanced prompt|result):\s*",  # Remove prefixes like "Output:" or "Enhanced prompt:"
-            r"(?i)\b(?:original prompt|start with|add|use|format|rule|follow|example)\b.*$",  # Remove instructions
-            r"^\s*\d+\.?\s*",  # Remove numbered list markers
-            r'^["\'](.*)["\']$',  # Remove quotes surrounding the entire text
-        ]
-
-        for pattern in patterns_to_remove:
-            expanded_text = re.sub(pattern, "", expanded_text, flags=re.MULTILINE)
-
-        # Normalize whitespace
-        expanded_text = " ".join(expanded_text.split())
-
-        # If the expansion doesn't start with the original prompt, add it
-        if not expanded_text.lower().startswith(original_prompt.lower()):
-            if "," in expanded_text and not expanded_text.startswith(","):
-                # Try to find where the original prompt might appear
-                parts = expanded_text.split(",", 1)
-                if original_prompt.lower() in parts[0].lower():
-                    # The first part contains the original prompt but with modifications
-                    expanded_text = f"{original_prompt}, {parts[1].strip()}"
-                else:
-                    expanded_text = f"{original_prompt}, {expanded_text}"
-            else:
-                expanded_text = f"{original_prompt}, {expanded_text}"
-
-        # Remove any duplicated commas
-        expanded_text = re.sub(r",\s*,", ",", expanded_text)
-
-        # Strict length control - limit expansion to approximately 40 words
-        # Count words in the expansion
-        words = expanded_text.split()
-        if len(words) > 40:
-            # Keep original prompt and just enough words to stay under 40
-            prompt_words = len(original_prompt.split())
-            # We need to keep the original prompt and stay under 40 total words
-            allowed_extra_words = 40 - prompt_words
-            # Join the original prompt with the allowed number of additional words
-            expanded_text = " ".join(words[: prompt_words + allowed_extra_words])
-
-        # Check if the expansion still contains instruction-like text or is too repetitive
-        instruction_indicators = [
-            "original prompt",
-            "add only",
-            "rule",
-            "format as",
-            "example",
-            "enhancement:",
-        ]
-        if any(
-            indicator in expanded_text.lower() for indicator in instruction_indicators
-        ):
-            # Emergency fallback - use hardcoded expansion phrases
-            expansions = [
-                "cinematic lighting",
-                "professional photography",
-                "8k resolution",
-                "dramatic angle",
-                "photorealistic",
-                "highly detailed",
-                "vivid colors",
-                "stunning detail",
-                "artistically composed",
-                "sharp focus",
-            ]
-            import random
-
-            # Select 2-3 random expansions
-            selected = random.sample(expansions, k=min(3, len(expansions)))
-            expanded_text = f"{original_prompt}, {', '.join(selected)}"
+        # Format the prompt for creative expansion
+        prompt = DEFAULT_CREATIVE_USER_TEMPLATE.format(original_prompt=original_prompt)
+
+        # Generate expanded prompt
+        expanded = self.generate(
+            prompt=prompt,
+            system_prompt=DEFAULT_CREATIVE_SYSTEM_PROMPT,
+            max_tokens=150,  # Limit to ensure concise responses
+            temperature=0.7,  # Balanced creativity
+        )
 
-        logger.info(f"Expanded prompt: {expanded_text[:100]}...")
-        return expanded_text
+        return expanded
 
 
-def get_llm_instance(model_path: Optional[str] = None) -> Optional[LocalLLM]:
+def get_llm_instance() -> LocalLLM:
    """
-    Factory function to get a LocalLLM instance with default settings.
-    Returns None if model loading fails, allowing graceful fallback.
-
-    Args:
-        model_path: Optional path to model or HuggingFace model ID
+    Get an instance of the local LLM.
 
    Returns:
-        A LocalLLM instance or None if model loading fails
+        An initialized LocalLLM instance
    """
-    use_local_model = os.environ.get("USE_LOCAL_MODEL", "true").lower() != "false"
-    if not use_local_model:
-        logger.info("Local model usage is disabled by environment setting")
-        return None
-
-    # Default to environment settings with fallbacks
-    if not model_path:
-        model_path = os.environ.get("MODEL_PATH") or os.environ.get(
-            "MODEL_ID", "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
-        )
-
-    # Get model file for GGUF models
-    model_file = os.environ.get("MODEL_FILENAME")
-
-    # Check model architecture - T5 models use seq2seq, others use causal LM
-    model_architecture = os.environ.get("MODEL_ARCHITECTURE", "causal").lower()
-
-    # Check model type - prefer GGUF for speed in resource-constrained environments
-    model_type = os.environ.get("MODEL_TYPE", "transformers").lower()
-
-    # Check if quantization is enabled
-    use_quantization = os.environ.get("MODEL_QUANTIZED", "false").lower() == "true"
-
-    try:
-        # Check if the provided path is a local directory
-        if os.path.isdir(model_path):
-            logger.info(f"Using local model directory: {model_path}")
-        else:
-            logger.info(f"Using model ID from Hugging Face: {model_path}")
-
-        # Check available device backends
-        device_map = "auto"
-        torch_dtype = None
-
-        # For Hugging Face Spaces, be more careful about memory usage
-        spaces_mode = os.environ.get("HF_SPACES", "0") == "1"
-        if spaces_mode and model_type != "gguf":
-            logger.info("Running in Hugging Face Spaces, using CPU for stability")
-            # Force CPU for Spaces with transformers models
-            device_map = "cpu" if not use_quantization else "auto"
-
-        # Create the LLM instance with appropriate settings
-        return LocalLLM(
-            model_path=model_path,
-            model_file=model_file,
-            model_type=model_type,
-            model_architecture=model_architecture,
-            device_map=device_map,
-            torch_dtype=torch_dtype,
-            use_quantization=use_quantization,
-        )
-    except Exception as e:
-        logger.error(f"Failed to create LLM instance: {e}")
-        return None
+    return LocalLLM(model_id=MODEL_ID)
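Usage note: a short sketch (not part of this commit) of how the refactored module is expected to be used; the import path and example prompt are assumptions.

# Hypothetical direct use of the refactored LocalLLM (illustrative only).
from app.llm.model import get_llm_instance  # assumed import path

llm = get_llm_instance()  # loads TinyLlama/TinyLlama-1.1B-Chat-v1.0 per MODEL_ID
expanded = llm.expand_creative_prompt("a castle floating above the clouds")  # example prompt
print(expanded)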
app/llm/service.py CHANGED
@@ -1,139 +1,78 @@
-import os
 import logging
-import time
+import os
 import sys
-from fastapi import FastAPI, HTTPException, Request
+from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
-from typing import Optional
-import psutil
-import uvicorn
-from dotenv import load_dotenv
+from typing import Dict, Any, List, Optional
 from pathlib import Path
-from model import LocalLLM, get_llm_instance
+import time
 
-# Configure logging first
+# Configure logging
 logging.basicConfig(
     level=logging.INFO,
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
     handlers=[
         logging.StreamHandler(sys.stdout),
-        logging.FileHandler(os.path.join(os.path.dirname(__file__), "llm_service.log")),
+        logging.FileHandler("llm_service.log", mode="a"),
     ],
 )
-logger = logging.getLogger("llm_service")
+logger = logging.getLogger(__name__)
 
-# Try to load .env file from project root
-env_path = Path(__file__).parents[2] / ".env"
-if env_path.exists():
-    load_dotenv(dotenv_path=env_path)
-    logger.info(f"Loaded environment variables from {env_path}")
+# Import the model
+from .model import get_llm_instance
 
+# Initialize model
+llm = get_llm_instance()
 
-# Initialize FastAPI app
-app = FastAPI(title="Local LLM Service", description="API for local LLM interaction")
+# Create FastAPI app
+app = FastAPI(
+    title="LLM Service API",
+    description="API for interacting with the local LLM",
+    version="1.0.0",
+)
 
-# Add CORS middleware
+# Configure CORS
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],
+    allow_origins=["*"],  # In production, specify actual origins
     allow_credentials=True,
     allow_methods=["*"],
    allow_headers=["*"],
 )
 
 
-# Request timing middleware
-@app.middleware("http")
-async def log_requests(request: Request, call_next):
-    start_time = time.time()
-    logger.info(f"Request started: {request.method} {request.url.path}")
-
-    response = await call_next(request)
-
-    process_time = time.time() - start_time
-    logger.info(
-        f"Request completed: {request.method} {request.url.path} - Status: {response.status_code} - Duration: {process_time:.4f}s"
-    )
-
-    return response
-
-
-# Model request and response classes
-class PromptRequest(BaseModel):
+class GenerateRequest(BaseModel):
     prompt: str
     system_prompt: Optional[str] = None
-    max_tokens: int = 512
-    temperature: float = 0.7
-    top_p: float = 0.9
+    max_tokens: Optional[int] = 512
+    temperature: Optional[float] = 0.7
+    top_p: Optional[float] = 0.9
 
 
-class ExpandRequest(BaseModel):
+class ExpandPromptRequest(BaseModel):
     prompt: str
 
 
-class LLMResponse(BaseModel):
-    text: str
-
+@app.get("/")
+def read_root():
+    return {"status": "ok", "message": "LLM Service is running"}
 
-# Global LLM instance
-llm = None
 
+@app.get("/health")
+def health_check():
+    """Health check endpoint"""
+    return {"status": "healthy", "model": "TinyLlama-1.1B-Chat-v1.0"}
 
-@app.on_event("startup")
-async def startup_event():
-    """Initialize the LLM on startup"""
-    global llm
-    logger.info("Starting LLM service initialization...")
-
-    # First check for MODEL_PATH (local model), then fall back to MODEL_ID
-    model_path = os.environ.get("MODEL_PATH")
-    if model_path and os.path.isdir(model_path):
-        logger.info(f"Using local model from MODEL_PATH: {model_path}")
-    else:
-        # Fall back to MODEL_ID if MODEL_PATH isn't set or doesn't exist
-        model_path = os.environ.get("MODEL_ID", "meta-llama/Llama-3.2-3B-Instruct")
-        logger.info(f"Using model ID from Hugging Face: {model_path}")
-
-    try:
-        start_time = time.time()
-        llm = get_llm_instance(model_path)
-        init_time = time.time() - start_time
-
-        logger.info(
-            f"LLM initialized successfully with model: {model_path} in {init_time:.2f} seconds"
-        )
-
-        memory = psutil.virtual_memory()
-        logger.info(
-            f"System memory: {memory.percent}% used ({memory.used / (1024**3):.1f}GB / {memory.total / (1024**3):.1f}GB)"
-        )
-
-    except Exception as e:
-        logger.error(f"Failed to initialize LLM: {str(e)}", exc_info=True)
-        raise
-
-
-@app.post("/generate", response_model=LLMResponse)
-async def generate_text(request: PromptRequest):
-    """Generate text based on a prompt"""
-    logger.info(
-        f"Received text generation request, prompt length: {len(request.prompt)} chars"
-    )
-    logger.debug(f"Prompt: {request.prompt[:50]}...")
-
-    if not llm:
-        logger.error("LLM service not initialized when generate endpoint was called")
-        raise HTTPException(status_code=503, detail="LLM service not initialized")
 
+@app.post("/generate")
+def generate_text(request: GenerateRequest):
+    """Generate text from a prompt"""
     try:
         start_time = time.time()
+        logger.info(f"Generating text for prompt: {request.prompt[:50]}...")
 
-        logger.info(
-            f"Generation parameters: max_tokens={request.max_tokens}, temperature={request.temperature}, top_p={request.top_p}"
-        )
-
-        response = llm.generate(
+        result = llm.generate(
             prompt=request.prompt,
             system_prompt=request.system_prompt,
             max_tokens=request.max_tokens,
@@ -141,73 +80,45 @@ async def generate_text(request: PromptRequest):
            top_p=request.top_p,
        )
 
-        generation_time = time.time() - start_time
-        response_length = len(response)
+        elapsed = time.time() - start_time
+        logger.info(f"Generation completed in {elapsed:.2f} seconds")
 
-        logger.info(
-            f"Text generation completed in {generation_time:.2f}s, response length: {response_length} chars"
-        )
-        logger.debug(f"Generated response: {response[:50]}...")
-
-        return LLMResponse(text=response)
+        return {
+            "result": result,
+            "elapsed_seconds": elapsed,
+        }
     except Exception as e:
-        logger.error(f"Error generating text: {str(e)}", exc_info=True)
+        logger.error(f"Error during text generation: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
 
 
-@app.post("/expand", response_model=LLMResponse)
-async def expand_prompt(request: ExpandRequest):
-    """Expand a creative prompt with rich details"""
-    logger.info(f"Received prompt expansion request, prompt: '{request.prompt}'")
-
-    if not llm:
-        logger.error("LLM service not initialized when expand endpoint was called")
-        raise HTTPException(status_code=503, detail="LLM service not initialized")
-
+@app.post("/expand-prompt")
+def expand_prompt(request: ExpandPromptRequest):
+    """Expand a creative prompt with more detail"""
     try:
        start_time = time.time()
+        logger.info(f"Expanding prompt: {request.prompt[:50]}...")
 
-        expanded = llm.expand_creative_prompt(request.prompt)
+        expanded_prompt = llm.expand_creative_prompt(request.prompt)
 
-        expansion_time = time.time() - start_time
-        expanded_length = len(expanded)
+        elapsed = time.time() - start_time
+        logger.info(f"Prompt expansion completed in {elapsed:.2f} seconds")
 
-        logger.info(
-            f"Prompt expansion completed in {expansion_time:.2f}s, original length: {len(request.prompt)}, expanded length: {expanded_length}"
-        )
-        logger.debug(f"Original: '{request.prompt}'")
-        logger.debug(f"Expanded: '{expanded}'")
-
-        return LLMResponse(text=expanded)
+        return {
+            "original_prompt": request.prompt,
+            "expanded_prompt": expanded_prompt,
+            "elapsed_seconds": elapsed,
+        }
     except Exception as e:
-        logger.error(f"Error expanding prompt: {str(e)}", exc_info=True)
+        logger.error(f"Error during prompt expansion: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
 
 
-@app.get("/health")
-async def health_check():
-    """Health check endpoint"""
-    logger.debug("Health check endpoint called")
-
-    if llm:
-        logger.info(f"Health check: LLM service is healthy, model: {llm.model_path}")
-        return {"status": "healthy", "model": llm.model_path}
-
-    logger.warning("Health check: LLM service is still initializing")
-    return {"status": "initializing"}
-
-
-# Start the service if run directly
+# Run the service when executed directly
 if __name__ == "__main__":
+    import uvicorn
 
-    # Check for psutil dependency
-    try:
-        import psutil
-    except ImportError:
-        logger.warning(
-            "psutil not installed. Some system resource metrics will not be available."
-        )
-        logger.warning("Install with: pip install psutil")
-
-    logger.info("Starting LLM service server")
-    uvicorn.run(app, host="0.0.0.0", port=8001)
+    port = int(os.environ.get("PORT", 8000))
+    host = os.environ.get("HOST", "0.0.0.0")
+    logger.info(f"Starting LLM service on {host}:{port}")
+    uvicorn.run(app, host=host, port=port)
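Usage note: an example (not part of this commit) of exercising the refactored service endpoints; host and port follow the defaults in the __main__ block above, and the prompts are illustrative.

# Hypothetical HTTP calls against the running service (illustrative only).
import requests

base = "http://localhost:8000"  # default HOST/PORT from the __main__ block

print(requests.get(f"{base}/health").json())
# expected shape: {"status": "healthy", "model": "TinyLlama-1.1B-Chat-v1.0"}

resp = requests.post(f"{base}/expand-prompt", json={"prompt": "a lighthouse at dusk"})
print(resp.json()["expanded_prompt"])

resp = requests.post(f"{base}/generate", json={"prompt": "Describe a misty forest.", "max_tokens": 128})
print(resp.json()["result"])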