Ubuntu committed on
Commit 2eefbf0 · 1 Parent(s): b2000f8
Files changed (1)
  1. model_handler.py +40 -16
model_handler.py CHANGED
@@ -14,23 +14,47 @@ def load_model_and_tokenizer():
     offload_dir = "offload_dir"
     os.makedirs(offload_dir, exist_ok=True)
 
-    # Load base model with 8-bit quantization to reduce memory usage
-    base_model = AutoModelForCausalLM.from_pretrained(
-        base_model_name,
-        torch_dtype=torch.float32,  # Use float32 for CPU
-        device_map="auto",
-        offload_folder=offload_dir,  # Add offload directory
-        load_in_8bit=True,  # Use 8-bit quantization
-        low_cpu_mem_usage=True  # Optimize for low memory
-    )
-
-    # Load adapter weights
-    model = PeftModel.from_pretrained(
-        base_model,
-        "phi2-grpo-qlora-final",
-        device_map="auto",
-        offload_folder=offload_dir  # Add offload directory
-    )
+    # Check if CUDA is available
+    use_cuda = torch.cuda.is_available()
+
+    try:
+        # First try loading with quantization if CUDA is available
+        if use_cuda:
+            base_model = AutoModelForCausalLM.from_pretrained(
+                base_model_name,
+                torch_dtype=torch.float16,
+                device_map="auto",
+                offload_folder=offload_dir,
+                load_in_8bit=True,
+                low_cpu_mem_usage=True
+            )
+        else:
+            # CPU-only loading without quantization
+            base_model = AutoModelForCausalLM.from_pretrained(
+                base_model_name,
+                torch_dtype=torch.float32,
+                device_map="auto",
+                offload_folder=offload_dir,
+                low_cpu_mem_usage=True
+            )
+
+        # Load adapter weights
+        model = PeftModel.from_pretrained(
+            base_model,
+            "phi2-grpo-qlora-final",
+            device_map="auto",
+            offload_folder=offload_dir
+        )
+    except Exception as e:
+        print(f"Error loading with adapter: {e}")
+        print("Falling back to base model only...")
+        # Fallback to just the base model if adapter loading fails
+        model = AutoModelForCausalLM.from_pretrained(
+            base_model_name,
+            torch_dtype=torch.float32,
+            device_map="auto",
+            low_cpu_mem_usage=True
+        )
 
     # Set to evaluation mode
     model.eval()
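
A side note on the quantized branch, not part of the commit: on recent transformers releases the bare load_in_8bit=True kwarg is deprecated in favor of an explicit BitsAndBytesConfig, and 8-bit loading requires bitsandbytes with a CUDA GPU in either spelling. A minimal sketch of the equivalent load follows; "microsoft/phi-2" is an assumption inferred from the adapter name "phi2-grpo-qlora-final", since the diff never shows the value of base_model_name.

    # Sketch only: the same 8-bit load expressed via BitsAndBytesConfig
    # instead of the deprecated bare load_in_8bit=True kwarg.
    # "microsoft/phi-2" is assumed, not shown in this diff.
    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    base_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2",
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        torch_dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True,
    )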
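
To exercise the new try/except fallback path end to end, a hypothetical smoke test such as the one below could be used. It is not part of the commit and assumes load_model_and_tokenizer() returns a (model, tokenizer) pair, as its name suggests; the hunk only shows the model-loading half of the function.

    # Hypothetical smoke test for model_handler.load_model_and_tokenizer().
    # Assumes the function returns (model, tokenizer); the diff does not show this.
    import torch
    from model_handler import load_model_and_tokenizer

    model, tokenizer = load_model_and_tokenizer()
    inputs = tokenizer("Hello, world:", return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))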