Ubuntu committed
Commit 2eefbf0 · 1 Parent(s): b2000f8

CUDA fix

model_handler.py CHANGED (+40 -16)
@@ -14,23 +14,47 @@ def load_model_and_tokenizer():
     offload_dir = "offload_dir"
     os.makedirs(offload_dir, exist_ok=True)
 
-    #
-    base_model = AutoModelForCausalLM.from_pretrained(
-        base_model_name,
-        torch_dtype=torch.float32,  # Use float32 for CPU
-        device_map="auto",
-        offload_folder=offload_dir,  # Add offload directory
-        load_in_8bit=True,  # Use 8-bit quantization
-        low_cpu_mem_usage=True  # Optimize for low memory
-    )
+    # Check if CUDA is available
+    use_cuda = torch.cuda.is_available()
 
-    # Load adapter weights
-    model = PeftModel.from_pretrained(
-        base_model,
-        "phi2-grpo-qlora-final",
-        device_map="auto",
-        offload_folder=offload_dir
-    )
+    try:
+        # First try loading with quantization if CUDA is available
+        if use_cuda:
+            base_model = AutoModelForCausalLM.from_pretrained(
+                base_model_name,
+                torch_dtype=torch.float16,
+                device_map="auto",
+                offload_folder=offload_dir,
+                load_in_8bit=True,
+                low_cpu_mem_usage=True
+            )
+        else:
+            # CPU-only loading without quantization
+            base_model = AutoModelForCausalLM.from_pretrained(
+                base_model_name,
+                torch_dtype=torch.float32,
+                device_map="auto",
+                offload_folder=offload_dir,
+                low_cpu_mem_usage=True
+            )
+
+        # Load adapter weights
+        model = PeftModel.from_pretrained(
+            base_model,
+            "phi2-grpo-qlora-final",
+            device_map="auto",
+            offload_folder=offload_dir
+        )
+    except Exception as e:
+        print(f"Error loading with adapter: {e}")
+        print("Falling back to base model only...")
+        # Fallback to just the base model if adapter loading fails
+        model = AutoModelForCausalLM.from_pretrained(
+            base_model_name,
+            torch_dtype=torch.float32,
+            device_map="auto",
+            low_cpu_mem_usage=True
+        )
 
     # Set to evaluation mode
     model.eval()
|