Add support for google/flan-t5-small with proper sequence-to-sequence model handling
Files changed:
- app.py (+7 -2)
- app/llm/model.py (+98 -36)
app.py CHANGED

@@ -28,10 +28,15 @@ os.environ["HF_SPACES"] = "1"  # Flag to indicate we're running in Spaces
 
 # Set model environment variables explicitly for Hugging Face Spaces
 # These will override any variables loaded from .env.spaces
-os.environ["MODEL_ID"] = "
+os.environ["MODEL_ID"] = "google/flan-t5-small"  # Use flan-t5-small model
 os.environ["USE_LOCAL_MODEL"] = "true"
 os.environ["MODEL_TYPE"] = "transformers"
-os.environ["MODEL_QUANTIZED"] =
+os.environ["MODEL_QUANTIZED"] = (
+    "false"  # Disable quantization to avoid bitsandbytes dependency
+)
+os.environ["MODEL_ARCHITECTURE"] = (
+    "seq2seq"  # T5 models are sequence-to-sequence, not causal LM
+)
 
 # Import UI module directly
 try:
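The block above only sets plain environment variables; everything downstream reads them back with `os.environ.get` and a default. A minimal sketch of that read-back, assuming the same keys (the boolean parsing of `MODEL_QUANTIZED` is illustrative, not copied from the repository):

```python
import os

# Values written by app.py for the Spaces deployment.
os.environ["MODEL_ID"] = "google/flan-t5-small"
os.environ["MODEL_ARCHITECTURE"] = "seq2seq"
os.environ["MODEL_QUANTIZED"] = "false"

# Read-back with defaults, in the same style get_llm_instance() uses below.
model_id = os.environ.get("MODEL_ID", "")
architecture = os.environ.get("MODEL_ARCHITECTURE", "causal").lower()
quantized = os.environ.get("MODEL_QUANTIZED", "false").lower() == "true"  # illustrative parsing

print(model_id, architecture, quantized)  # google/flan-t5-small seq2seq False
```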
app/llm/model.py CHANGED

@@ -10,7 +10,13 @@ logger = logging.getLogger(__name__)
 
 # Try to import transformers and ctransformers
 try:
-    from transformers import
+    from transformers import (
+        AutoTokenizer,
+        AutoModelForCausalLM,
+        AutoModelForSeq2SeqLM,
+        pipeline,
+        AutoConfig,
+    )
 
     HAS_TRANSFORMERS = True
 except ImportError:
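The new imports pull in both auto classes so the loader can pick one at runtime. As a side note, the same causal/seq2seq split could be detected from the checkpoint itself; the helper below is hypothetical (not part of this commit) and relies on the standard `is_encoder_decoder` flag, which T5-family configs such as flan-t5-small set to `True`:

```python
from transformers import AutoConfig

def infer_architecture(model_id: str) -> str:
    """Hypothetical helper: map a model ID to 'seq2seq' or 'causal'."""
    config = AutoConfig.from_pretrained(model_id)  # fetches only the config JSON
    return "seq2seq" if getattr(config, "is_encoder_decoder", False) else "causal"

print(infer_architecture("google/flan-t5-small"))  # seq2seq
```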
@@ -40,6 +46,7 @@ class LocalLLM:
         model_path: str = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
         model_file: str = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
         model_type: str = "gguf",
+        model_architecture: str = "causal",
         device_map: str = "auto",
         torch_dtype=None,
         use_quantization: bool = False,
@@ -51,6 +58,7 @@ class LocalLLM:
             model_path: Path to model or HuggingFace model ID
             model_file: Specific model file to load (for GGUF models)
             model_type: Type of model ('transformers' or 'gguf')
+            model_architecture: Architecture type ('causal' or 'seq2seq')
             device_map: Device mapping strategy (default: "auto")
             torch_dtype: Torch data type (default: float16)
             use_quantization: Whether to use 8-bit quantization to reduce memory usage
@@ -58,6 +66,7 @@ class LocalLLM:
         self.model_path = model_path
         self.model_file = model_file
         self.model_type = model_type.lower()
+        self.model_architecture = model_architecture.lower()
         self.device_map = device_map
         self.use_quantization = use_quantization
         self.pipe = None
@@ -71,7 +80,9 @@ class LocalLLM:
         self.torch_dtype = torch_dtype
 
         logger.info(f"Loading LLM from {model_path}")
-        logger.info(
+        logger.info(
+            f"Model type: {model_type}, architecture: {model_architecture}, model file: {model_file}"
+        )
 
         # Various loading strategies based on model type
         if self.model_type == "gguf":
@@ -184,50 +195,64 @@ class LocalLLM:
         load_kwargs.update(
             {
                 "low_cpu_mem_usage": True,
-                "offload_folder": "/tmp/offload",
-                "offload_state_dict": True,
             }
         )
 
-        #
-
-
-
-
-
-        model =
+        # Load the tokenizer first - common to both architectures
+        tokenizer = AutoTokenizer.from_pretrained(self.model_path)
+
+        # Load the model based on architecture
+        if self.model_architecture == "seq2seq":
+            logger.info("Loading sequence-to-sequence model architecture")
+            model = AutoModelForSeq2SeqLM.from_pretrained(
                 self.model_path, **load_kwargs
             )
-
+            self.pipe = pipeline(
+                "text2text-generation",
+                model=model,
+                tokenizer=tokenizer,
+                framework="pt",
+            )
         else:
-            # Standard
-
-
-
-
-
-
-                config.rope_scaling["type"] = "linear"
-                logger.info("Fixed rope_scaling configuration with type=linear")
-            elif (
-                not hasattr(config, "rope_scaling")
-                and "llama" in self.model_path.lower()
+            # Standard causal language model
+            logger.info("Loading causal language model architecture")
+            # Skip the custom config handling for Spaces mode or small models
+            if (
+                spaces_mode
+                or "phi" in self.model_path.lower()
+                or "tiny" in self.model_path.lower()
             ):
-
-
-
-
-
+                model = AutoModelForCausalLM.from_pretrained(
+                    self.model_path, **load_kwargs
+                )
+            else:
+                # Standard local loading with our custom config handling
+                config = AutoConfig.from_pretrained(self.model_path)
+
+                # Fix the rope_scaling issue for Llama models
+                if hasattr(config, "rope_scaling") and isinstance(
+                    config.rope_scaling, dict
+                ):
+                    config.rope_scaling["type"] = "linear"
+                    logger.info("Fixed rope_scaling configuration with type=linear")
+                elif (
+                    not hasattr(config, "rope_scaling")
+                    and "llama" in self.model_path.lower()
+                ):
+                    config.rope_scaling = {"type": "linear", "factor": 1.0}
+                    logger.info("Added default rope_scaling configuration")
+
+                # Load the model with our fixed config
+                model = AutoModelForCausalLM.from_pretrained(
+                    self.model_path, config=config, **load_kwargs
+                )
 
-        #
-
-
+            # Create text generation pipeline for causal LM
+            self.pipe = pipeline(
+                "text-generation", model=model, tokenizer=tokenizer, framework="pt"
             )
 
-        #
-        self.pipe = pipeline(
-            "text-generation", model=model, tokenizer=tokenizer, framework="pt"
-        )
+        # Store the model and tokenizer reference
         self.model = model
         self.tokenizer = tokenizer
 
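Stripped of the class plumbing, the seq2seq branch in the hunk above is the standard Transformers pattern for T5-style checkpoints. A minimal sketch (default CPU placement, no quantization; the extra `load_kwargs` such as `low_cpu_mem_usage` are omitted):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

model_id = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

# Encoder-decoder models use the text2text-generation task, not text-generation.
pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, framework="pt")

print(pipe("enhance: make this error message clearer", max_length=64))
```

Loading the same checkpoint through `AutoModelForCausalLM` raises an error because the T5 config is not registered for that auto class, which is exactly the failure the new `model_architecture` switch avoids.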
@@ -320,6 +345,39 @@
     ) -> str:
         """Generate text using transformers pipeline"""
         try:
+            # Handle seq2seq models (like T5)
+            if self.model_architecture == "seq2seq":
+                logger.debug(f"Generating with seq2seq model: {self.model_path}")
+
+                # Format prompt for seq2seq models
+                formatted_prompt = prompt
+                if system_prompt:
+                    formatted_prompt = f"{system_prompt}\n\nQuery: {prompt}"
+
+                # T5 models work best with specific task prefixes
+                if (
+                    "flan" in self.model_path.lower()
+                    and not formatted_prompt.startswith("enhance:")
+                ):
+                    formatted_prompt = f"enhance: {formatted_prompt}"
+
+                # Generate with seq2seq model
+                outputs = self.pipe(
+                    formatted_prompt,
+                    max_length=max_tokens,
+                    temperature=temperature,
+                    top_p=top_p,
+                    do_sample=True,
+                )
+
+                # Extract the generated text
+                if isinstance(outputs, list) and len(outputs) > 0:
+                    if "generated_text" in outputs[0]:
+                        return outputs[0]["generated_text"].strip()
+
+                # Fallback extraction
+                return str(outputs).strip()
+
             # Check if the model can handle chat templates
             has_chat_template = (
                 hasattr(self.tokenizer, "chat_template")
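The prompt shaping in the seq2seq path is plain string assembly; a short sketch with illustrative strings (the `enhance:` prefix and the `Query:` layout come from the hunk above):

```python
system_prompt = "Rewrite the user's request so it is specific and unambiguous."
prompt = "make my API faster"

# Combine system and user prompt, then add the task prefix for FLAN models.
formatted_prompt = f"{system_prompt}\n\nQuery: {prompt}" if system_prompt else prompt
if not formatted_prompt.startswith("enhance:"):
    formatted_prompt = f"enhance: {formatted_prompt}"

print(formatted_prompt)
```

The extraction step that follows relies on the text2text-generation pipeline returning a list of dicts keyed by `generated_text`, with `str(outputs)` kept only as a fallback.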
@@ -443,6 +501,9 @@ def get_llm_instance(model_path: Optional[str] = None) -> Optional[LocalLLM]:
     # Get model file for GGUF models
     model_file = os.environ.get("MODEL_FILENAME")
 
+    # Check model architecture - T5 models use seq2seq, others use causal LM
+    model_architecture = os.environ.get("MODEL_ARCHITECTURE", "causal").lower()
+
     # Check model type - prefer GGUF for speed in resource-constrained environments
     model_type = os.environ.get("MODEL_TYPE", "transformers").lower()
 
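Because the lookup defaults to `"causal"`, the new variable is opt-in: deployments that never set `MODEL_ARCHITECTURE` keep the previous behaviour. A two-line check of that default:

```python
import os

os.environ.pop("MODEL_ARCHITECTURE", None)  # simulate a deployment without the new variable
assert os.environ.get("MODEL_ARCHITECTURE", "causal").lower() == "causal"
```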
@@ -472,6 +533,7 @@ def get_llm_instance(model_path: Optional[str] = None) -> Optional[LocalLLM]:
         model_path=model_path,
         model_file=model_file,
         model_type=model_type,
+        model_architecture=model_architecture,
         device_map=device_map,
         torch_dtype=torch_dtype,
         use_quantization=use_quantization,
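Putting the two files together, a hypothetical end-to-end check; the `app.llm.model` import path and the assumption that constructing the instance loads the checkpoint immediately are inferred from the diff, not verified:

```python
import os

os.environ.update(
    {
        "MODEL_ID": "google/flan-t5-small",
        "USE_LOCAL_MODEL": "true",
        "MODEL_TYPE": "transformers",
        "MODEL_QUANTIZED": "false",
        "MODEL_ARCHITECTURE": "seq2seq",
    }
)

from app.llm.model import get_llm_instance  # assumed import path

llm = get_llm_instance()
assert llm is not None and llm.model_architecture == "seq2seq"
assert llm.pipe is not None and llm.pipe.task == "text2text-generation"
```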