Zoro-chi committed
Commit 34d21a4 · Parent: bbce178

Add support for google/flan-t5-small with proper sequence-to-sequence model handling
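For context, a minimal standalone sketch of the seq2seq loading path this commit enables (assumes `transformers` and `torch` are installed; the prompt text and sampling values are illustrative, and the "enhance:" prefix simply mirrors the diff below):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

model_id = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

# T5 is an encoder-decoder model, so it uses the text2text-generation task
# rather than text-generation (which is for causal LMs).
pipe = pipeline(
    "text2text-generation", model=model, tokenizer=tokenizer, framework="pt"
)

# Prompt and sampling values are illustrative only.
outputs = pipe(
    "enhance: rewrite this prompt so it is clearer and more specific",
    max_length=128,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
)
print(outputs[0]["generated_text"])
```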

Files changed (2):
  1. app.py +7 -2
  2. app/llm/model.py +98 -36
app.py CHANGED
@@ -28,10 +28,15 @@ os.environ["HF_SPACES"] = "1"  # Flag to indicate we're running in Spaces
 
 # Set model environment variables explicitly for Hugging Face Spaces
 # These will override any variables loaded from .env.spaces
-os.environ["MODEL_ID"] = "microsoft/phi-2"  # Use phi-1.5 model
+os.environ["MODEL_ID"] = "google/flan-t5-small"  # Use flan-t5-small model
 os.environ["USE_LOCAL_MODEL"] = "true"
 os.environ["MODEL_TYPE"] = "transformers"
-os.environ["MODEL_QUANTIZED"] = "false"
+os.environ["MODEL_QUANTIZED"] = (
+    "false"  # Disable quantization to avoid bitsandbytes dependency
+)
+os.environ["MODEL_ARCHITECTURE"] = (
+    "seq2seq"  # T5 models are sequence-to-sequence, not causal LM
+)
 
 # Import UI module directly
 try:
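The app.py hunk above only sets configuration; as a rough sketch of how those variables are read back on the consumer side (the defaults for MODEL_ARCHITECTURE and MODEL_TYPE mirror the model.py diff below, while parsing MODEL_QUANTIZED as a "true"/"false" string is an assumption, since that parsing is not part of this commit):

```python
import os

# Defaults for MODEL_ARCHITECTURE and MODEL_TYPE mirror the model.py diff below.
model_id = os.environ.get("MODEL_ID")
model_type = os.environ.get("MODEL_TYPE", "transformers").lower()
model_architecture = os.environ.get("MODEL_ARCHITECTURE", "causal").lower()

# Assumption: MODEL_QUANTIZED is interpreted as a "true"/"false" string.
use_quantization = os.environ.get("MODEL_QUANTIZED", "false").lower() == "true"

print(model_id, model_type, model_architecture, use_quantization)
```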
app/llm/model.py CHANGED
@@ -10,7 +10,13 @@ logger = logging.getLogger(__name__)
 
 # Try to import transformers and ctransformers
 try:
-    from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoConfig
+    from transformers import (
+        AutoTokenizer,
+        AutoModelForCausalLM,
+        AutoModelForSeq2SeqLM,
+        pipeline,
+        AutoConfig,
+    )
 
     HAS_TRANSFORMERS = True
 except ImportError:
@@ -40,6 +46,7 @@ class LocalLLM:
         model_path: str = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
         model_file: str = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
         model_type: str = "gguf",
+        model_architecture: str = "causal",
         device_map: str = "auto",
         torch_dtype=None,
         use_quantization: bool = False,
@@ -51,6 +58,7 @@ class LocalLLM:
             model_path: Path to model or HuggingFace model ID
             model_file: Specific model file to load (for GGUF models)
            model_type: Type of model ('transformers' or 'gguf')
+            model_architecture: Architecture type ('causal' or 'seq2seq')
             device_map: Device mapping strategy (default: "auto")
             torch_dtype: Torch data type (default: float16)
             use_quantization: Whether to use 8-bit quantization to reduce memory usage
@@ -58,6 +66,7 @@ class LocalLLM:
         self.model_path = model_path
         self.model_file = model_file
         self.model_type = model_type.lower()
+        self.model_architecture = model_architecture.lower()
         self.device_map = device_map
         self.use_quantization = use_quantization
         self.pipe = None
@@ -71,7 +80,9 @@ class LocalLLM:
         self.torch_dtype = torch_dtype
 
         logger.info(f"Loading LLM from {model_path}")
-        logger.info(f"Model type: {model_type}, model file: {model_file}")
+        logger.info(
+            f"Model type: {model_type}, architecture: {model_architecture}, model file: {model_file}"
+        )
 
         # Various loading strategies based on model type
         if self.model_type == "gguf":
@@ -184,50 +195,64 @@ class LocalLLM:
             load_kwargs.update(
                 {
                     "low_cpu_mem_usage": True,
-                    "offload_folder": "/tmp/offload",
-                    "offload_state_dict": True,
                 }
             )
 
-            # Skip the custom config handling for Spaces mode or small models
-            if (
-                spaces_mode
-                or "phi" in self.model_path.lower()
-                or "tiny" in self.model_path.lower()
-            ):
-                model = AutoModelForCausalLM.from_pretrained(
+            # Load the tokenizer first - common to both architectures
+            tokenizer = AutoTokenizer.from_pretrained(self.model_path)
+
+            # Load the model based on architecture
+            if self.model_architecture == "seq2seq":
+                logger.info("Loading sequence-to-sequence model architecture")
+                model = AutoModelForSeq2SeqLM.from_pretrained(
                     self.model_path, **load_kwargs
                 )
-                tokenizer = AutoTokenizer.from_pretrained(self.model_path)
+                self.pipe = pipeline(
+                    "text2text-generation",
+                    model=model,
+                    tokenizer=tokenizer,
+                    framework="pt",
+                )
             else:
-                # Standard local loading with our custom config handling
-                config = AutoConfig.from_pretrained(self.model_path)
-
-                # Fix the rope_scaling issue for Llama models
-                if hasattr(config, "rope_scaling") and isinstance(
-                    config.rope_scaling, dict
-                ):
-                    config.rope_scaling["type"] = "linear"
-                    logger.info("Fixed rope_scaling configuration with type=linear")
-                elif (
-                    not hasattr(config, "rope_scaling")
-                    and "llama" in self.model_path.lower()
+                # Standard causal language model
+                logger.info("Loading causal language model architecture")
+                # Skip the custom config handling for Spaces mode or small models
+                if (
+                    spaces_mode
+                    or "phi" in self.model_path.lower()
+                    or "tiny" in self.model_path.lower()
                 ):
-                    config.rope_scaling = {"type": "linear", "factor": 1.0}
-                    logger.info("Added default rope_scaling configuration")
-
-                # Load the tokenizer
-                tokenizer = AutoTokenizer.from_pretrained(self.model_path)
+                    model = AutoModelForCausalLM.from_pretrained(
+                        self.model_path, **load_kwargs
+                    )
+                else:
+                    # Standard local loading with our custom config handling
+                    config = AutoConfig.from_pretrained(self.model_path)
+
+                    # Fix the rope_scaling issue for Llama models
+                    if hasattr(config, "rope_scaling") and isinstance(
+                        config.rope_scaling, dict
+                    ):
+                        config.rope_scaling["type"] = "linear"
+                        logger.info("Fixed rope_scaling configuration with type=linear")
+                    elif (
+                        not hasattr(config, "rope_scaling")
+                        and "llama" in self.model_path.lower()
+                    ):
+                        config.rope_scaling = {"type": "linear", "factor": 1.0}
+                        logger.info("Added default rope_scaling configuration")
+
+                    # Load the model with our fixed config
+                    model = AutoModelForCausalLM.from_pretrained(
+                        self.model_path, config=config, **load_kwargs
+                    )
 
-                # Load the model with our fixed config
-                model = AutoModelForCausalLM.from_pretrained(
-                    self.model_path, config=config, **load_kwargs
+                # Create text generation pipeline for causal LM
+                self.pipe = pipeline(
+                    "text-generation", model=model, tokenizer=tokenizer, framework="pt"
                 )
 
-            # Create the pipeline with our pre-loaded model and tokenizer
-            self.pipe = pipeline(
-                "text-generation", model=model, tokenizer=tokenizer, framework="pt"
-            )
+            # Store the model and tokenizer reference
             self.model = model
             self.tokenizer = tokenizer
 
@@ -320,6 +345,39 @@ class LocalLLM:
     ) -> str:
         """Generate text using transformers pipeline"""
         try:
+            # Handle seq2seq models (like T5)
+            if self.model_architecture == "seq2seq":
+                logger.debug(f"Generating with seq2seq model: {self.model_path}")
+
+                # Format prompt for seq2seq models
+                formatted_prompt = prompt
+                if system_prompt:
+                    formatted_prompt = f"{system_prompt}\n\nQuery: {prompt}"
+
+                # T5 models work best with specific task prefixes
+                if (
+                    "flan" in self.model_path.lower()
+                    and not formatted_prompt.startswith("enhance:")
+                ):
+                    formatted_prompt = f"enhance: {formatted_prompt}"
+
+                # Generate with seq2seq model
+                outputs = self.pipe(
+                    formatted_prompt,
+                    max_length=max_tokens,
+                    temperature=temperature,
+                    top_p=top_p,
+                    do_sample=True,
+                )
+
+                # Extract the generated text
+                if isinstance(outputs, list) and len(outputs) > 0:
+                    if "generated_text" in outputs[0]:
+                        return outputs[0]["generated_text"].strip()
+
+                # Fallback extraction
+                return str(outputs).strip()
+
             # Check if the model can handle chat templates
             has_chat_template = (
                 hasattr(self.tokenizer, "chat_template")
@@ -443,6 +501,9 @@ def get_llm_instance(model_path: Optional[str] = None) -> Optional[LocalLLM]:
     # Get model file for GGUF models
    model_file = os.environ.get("MODEL_FILENAME")
 
+    # Check model architecture - T5 models use seq2seq, others use causal LM
+    model_architecture = os.environ.get("MODEL_ARCHITECTURE", "causal").lower()
+
     # Check model type - prefer GGUF for speed in resource-constrained environments
     model_type = os.environ.get("MODEL_TYPE", "transformers").lower()
 
@@ -472,6 +533,7 @@ def get_llm_instance(model_path: Optional[str] = None) -> Optional[LocalLLM]:
         model_path=model_path,
         model_file=model_file,
         model_type=model_type,
+        model_architecture=model_architecture,
         device_map=device_map,
         torch_dtype=torch_dtype,
         use_quantization=use_quantization,