gnumanth committed
Commit a5da8f2 · verified · 1 Parent(s): 184ee6c

chore: device optimization

Files changed (1):
  1. app.py +109 -10
app.py CHANGED
@@ -24,13 +24,21 @@ class MedGemmaSymptomAnalyzer:
         logger.info("Initializing MedGemma Symptom Analyzer...")

     def load_model(self):
-        """Load MedGemma model with optimizations for deployment"""
+        """Load MedGemma model with optimizations for deployment and CPU compatibility"""
         if self.model_loaded:
             return True

         model_name = "google/medgemma-4b-it"
         logger.info(f"Loading model: {model_name}")

+        # Detect available device and log system info
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"Device detected: {device}")
+        if device == "cpu":
+            logger.info(f"CPU threads available: {torch.get_num_threads()}")
+        else:
+            logger.info(f"CUDA device: {torch.cuda.get_device_name()}")
+
         try:
             # Get HF token from environment (set in Hugging Face Spaces secrets)
             hf_token = os.getenv("HF_TOKEN")
@@ -39,33 +47,124 @@ class MedGemmaSymptomAnalyzer:
             else:
                 logger.warning("HF_TOKEN not found in environment variables")

-            # First try without quantization for CPU compatibility
+            # Optimize settings based on device
+            if device == "cpu":
+                logger.info("Configuring for CPU-optimized loading...")
+                torch_dtype = torch.float32
+                device_map = "cpu"
+                # Set optimal number of threads for CPU inference
+                torch.set_num_threads(max(1, torch.get_num_threads() // 2))
+
+                # Additional CPU optimizations
+                import psutil
+                available_memory_gb = psutil.virtual_memory().available / (1024**3)
+                logger.info(f"Available memory: {available_memory_gb:.1f} GB")
+
+                # Enable memory-efficient loading for low-memory systems
+                cpu_loading_kwargs = {
+                    "low_cpu_mem_usage": True,
+                    "torch_dtype": torch_dtype,
+                    "device_map": device_map
+                }
+
+                # Use offloading for very low memory systems (< 8GB available)
+                if available_memory_gb < 8:
+                    logger.warning("Low memory detected, enabling aggressive memory optimizations")
+                    cpu_loading_kwargs.update({
+                        "offload_folder": "/tmp/model_offload",
+                        "offload_state_dict": True
+                    })
+            else:
+                logger.info("Configuring for GPU loading...")
+                torch_dtype = torch.float16
+                device_map = "auto"
+                cpu_loading_kwargs = {
+                    "torch_dtype": torch_dtype,
+                    "device_map": device_map,
+                    "low_cpu_mem_usage": True
+                }
+
             logger.info("Loading tokenizer...")
             self.tokenizer = AutoTokenizer.from_pretrained(
                 model_name,
-                token=hf_token
+                token=hf_token,
+                use_fast=True  # Use fast tokenizer for better performance
             )

-            logger.info("Loading model...")
-            # Simplified loading for CPU/compatibility
+            logger.info(f"Loading model with dtype={torch_dtype}, device_map={device_map}...")
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_name,
-                torch_dtype=torch.float32,  # Use float32 for CPU
-                device_map="cpu",  # Force CPU for compatibility
-                low_cpu_mem_usage=True,
-                token=hf_token
+                token=hf_token,
+                trust_remote_code=False,  # Security best practice
+                **cpu_loading_kwargs
             )

+            # Ensure pad token is set
             if self.tokenizer.pad_token is None:
                 self.tokenizer.pad_token = self.tokenizer.eos_token
+
+            # Move model to appropriate device if needed
+            if device == "cpu" and hasattr(self.model, 'to'):
+                self.model = self.model.to('cpu')
+                logger.info("Model moved to CPU")

             self.model_loaded = True
-            logger.info("Model loaded successfully!")
+            logger.info(f"Model loaded successfully on {device}!")
             return True

+        except torch.cuda.OutOfMemoryError as e:
+            logger.error(f"GPU out of memory: {str(e)}")
+            logger.info("Attempting CPU fallback due to GPU memory constraints...")
+            try:
+                # Force CPU loading if GPU fails
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    model_name,
+                    token=hf_token,
+                    trust_remote_code=False,
+                    torch_dtype=torch.float32,
+                    device_map="cpu",
+                    low_cpu_mem_usage=True
+                )
+                self.model_loaded = True
+                logger.info("Model loaded successfully on CPU after GPU failure!")
+                return True
+            except Exception as fallback_e:
+                logger.error(f"CPU fallback also failed: {str(fallback_e)}")
+                self.model = None
+                self.tokenizer = None
+                self.model_loaded = False
+                return False
+        except ImportError as e:
+            logger.error(f"Missing dependency for model loading: {str(e)}")
+            logger.info("Please ensure all required packages are installed: pip install -r requirements.txt")
+            self.model = None
+            self.tokenizer = None
+            self.model_loaded = False
+            return False
+        except OSError as e:
+            if "disk quota exceeded" in str(e).lower() or "no space left" in str(e).lower():
+                logger.error("Insufficient disk space for model loading")
+                logger.info("Please free up disk space and try again")
+            elif "connection" in str(e).lower() or "timeout" in str(e).lower():
+                logger.error("Network connection issue during model download")
+                logger.info("Please check your internet connection and try again")
+            else:
+                logger.error(f"OS error during model loading: {str(e)}")
+            self.model = None
+            self.tokenizer = None
+            self.model_loaded = False
+            return False
         except Exception as e:
             logger.error(f"Failed to load model {model_name}: {str(e)}", exc_info=True)
             logger.warning("Falling back to demo mode due to model loading failure")
+
+            # Provide helpful troubleshooting info
+            if device == "cpu":
+                logger.info("CPU loading troubleshooting tips:")
+                logger.info("- Ensure sufficient RAM (minimum 8GB recommended)")
+                logger.info("- Check that PyTorch CPU version is installed")
+                logger.info("- Verify HuggingFace token is valid")
+
             self.model = None
             self.tokenizer = None
             self.model_loaded = False
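
For reference, the device- and memory-dependent selection added in this commit can be sanity-checked in isolation. The sketch below mirrors that branching in a standalone function; the helper name select_loading_kwargs is hypothetical (the commit inlines this logic in load_model()), and it assumes torch and psutil are installed. The float32-on-CPU / float16-on-GPU dtypes and the 8 GB offload threshold are taken directly from the diff above.

    # Hypothetical standalone mirror of the commit's device/memory branching;
    # not part of app.py.
    import torch
    import psutil

    def select_loading_kwargs(memory_threshold_gb: float = 8.0) -> dict:
        """Build from_pretrained() kwargs the way the new load_model() does."""
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if device == "cpu":
            kwargs = {
                "low_cpu_mem_usage": True,
                "torch_dtype": torch.float32,  # float32 on CPU, as in the commit
                "device_map": "cpu",
            }
            available_gb = psutil.virtual_memory().available / (1024 ** 3)
            if available_gb < memory_threshold_gb:
                # Low-memory path from the commit: offload weights to disk
                kwargs.update({
                    "offload_folder": "/tmp/model_offload",
                    "offload_state_dict": True,
                })
        else:
            kwargs = {
                "low_cpu_mem_usage": True,
                "torch_dtype": torch.float16,  # half precision on GPU
                "device_map": "auto",
            }
        return kwargs

    if __name__ == "__main__":
        print(select_loading_kwargs())

One practical note: the commit imports psutil lazily inside load_model(), so psutil must be listed in the Space's requirements.txt; if it is missing, the new except ImportError branch catches the failure and the analyzer drops to demo mode instead of crashing.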