"""coordinator.py module."""
import logging
import threading
import time
from collections import defaultdict
from typing import Any, Dict, List, Optional

import numpy as np
import tensorflow as tf

from .aggregator import FederatedAggregator

logger = logging.getLogger(__name__)

class FederatedCoordinator:
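    """Coordinate federated training rounds.

    Tracks registered clients, serves the current global model, collects
    per-round client updates, and runs FedAvg aggregation once at least
    `min_clients` updates have arrived.
    """
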
def __init__(self, config: Dict):
"""Initialize the federated learning coordinator."""
        logger.debug(f"Initializing FederatedCoordinator with config: {config}")
self.config = config
self.clients = {}
self.client_updates = {} # Store updates for current round
self.global_model_weights = None
self.current_round = 0
self.training_active = False
# Extract federated learning parameters
self.min_clients = config.get('federated', {}).get('min_clients', 2)
self.rounds = config.get('federated', {}).get('rounds', 10)
# Robustly extract aggregation config
agg_config = None
if 'aggregation' in config:
agg_config = config
elif 'server' in config and 'aggregation' in config['server']:
agg_config = config['server']
else:
logger.error(f"No 'aggregation' key found in config for FederatedAggregator: {config}")
raise ValueError("'aggregation' config section is required for FederatedAggregator")
logger.debug(f"Passing aggregation config to FederatedAggregator: {agg_config}")
try:
self.aggregator = FederatedAggregator(agg_config)
except Exception as e:
logger.error(f"Error initializing FederatedAggregator: {e}")
raise
        # Create the lock before any shared state can be touched by API threads
        self.lock = threading.Lock()  # Thread safety for concurrent API calls
        # Initialize global model weights with random values
        self._initialize_global_model()
        logger.info("FederatedCoordinator initialized.")
def _initialize_global_model(self):
"""Initialize global model weights with random values."""
try:
# Build a simple model to get initial weights
input_dim = self.config.get('model', {}).get('input_dim', 32)
hidden_layers = self.config.get('model', {}).get('hidden_layers', [128, 64])
            # Build hidden layers from config so architectures other than the
            # two-layer default still initialize correctly
            model = tf.keras.Sequential(
                [tf.keras.layers.Input(shape=(input_dim,))]
                + [tf.keras.layers.Dense(units, activation='relu') for units in hidden_layers]
                + [tf.keras.layers.Dense(1)]
            )
model.compile(optimizer='adam', loss='mse')
self.global_model_weights = model.get_weights()
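            # NOTE: FedAvg averages weight tensors position-by-position, so
            # clients presumably build this same architecture; otherwise their
            # weight lists would not line up with the global model's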
logger.info(f"Global model initialized with {len(self.global_model_weights)} weight layers")
        except Exception as e:
            logger.error(f"Error initializing global model: {e}")
            # Fallback: random weights shaped to match the configured
            # architecture (kernel + bias per Dense layer)
            input_dim = self.config.get('model', {}).get('input_dim', 32)
            hidden_layers = self.config.get('model', {}).get('hidden_layers', [128, 64])
            dims = [input_dim] + list(hidden_layers) + [1]
            self.global_model_weights = []
            for fan_in, fan_out in zip(dims[:-1], dims[1:]):
                self.global_model_weights.append(
                    np.random.randn(fan_in, fan_out).astype(np.float32))
                self.global_model_weights.append(
                    np.random.randn(fan_out).astype(np.float32))
            logger.info("Using fallback random weights for global model")
    def register_client(self, client_id: str, client_info: Optional[Dict[str, Any]] = None) -> bool:
        """Register a new client; returns True for both new and already-known clients."""
with self.lock:
if client_id in self.clients:
                logger.warning(f"Client {client_id} already registered")
return True
self.clients[client_id] = {
'info': client_info or {},
'last_seen': time.time(),
'metrics': defaultdict(list)
}
            logger.info(f"Client {client_id} registered successfully")
return True
def get_client_config(self) -> Dict[str, Any]:
"""Get configuration to send to clients"""
return {
'model_config': self.config.get('model', {}),
'training_config': self.config.get('training', {}),
'current_round': self.current_round,
'total_rounds': self.rounds
}
def get_global_model(self) -> Optional[List]:
"""Get the current global model weights"""
with self.lock:
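            # NOTE: returns a reference to the live weight list, not a copy;
            # callers are assumed to treat it as read-only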
return self.global_model_weights
def receive_model_update(self, client_id: str, model_weights: List, metrics: Dict[str, Any]):
"""Receive a model update from a client"""
with self.lock:
if client_id not in self.clients:
raise ValueError(f"Client {client_id} not registered")
self.client_updates[client_id] = {
'weights': model_weights,
'metrics': metrics,
'timestamp': time.time()
}
self.clients[client_id]['last_seen'] = time.time()
            logger.info(f"Received update from client {client_id}")
# Check if we have enough updates for aggregation
if len(self.client_updates) >= self.min_clients:
self._aggregate_models()
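
    # Example client-side call (hypothetical names; 'dataset_size' is the only
    # metric the aggregation step reads, defaulting to 100 when absent):
    #
    #   coordinator.receive_model_update(
    #       client_id="client-1",
    #       model_weights=local_model.get_weights(),
    #       metrics={"dataset_size": 512, "loss": 0.031},
    #   )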
def _aggregate_models(self):
"""Aggregate models from all client updates"""
        try:
            logger.info(f"Aggregating models from {len(self.client_updates)} clients")
# Prepare updates for aggregation
updates = []
for client_id, update in self.client_updates.items():
client_size = update['metrics'].get('dataset_size', 100) # Default size
updates.append({
'client_id': client_id,
'weights': update['weights'],
'size': client_size
})
# Aggregate using FedAvg
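            # (w_global = sum_k (n_k / n_total) * w_k, where n_k is the
            # dataset size client k reported)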
self.global_model_weights = self.aggregator.federated_averaging(updates)
# Clear updates for next round
self.client_updates.clear()
self.current_round += 1
logger.info(f"Model aggregation completed for round {self.current_round}")
        except Exception as e:
            logger.error(f"Error during model aggregation: {e}")
def _count_active_clients(self) -> int:
"""Count active clients (seen in last 60 seconds)"""
current_time = time.time()
active_count = sum(1 for client in self.clients.values()
if current_time - client['last_seen'] < 60)
return active_count
def start(self):
"""Start the federated learning process with API server"""
# Print server startup information
logger.info("\n" + "=" * 60)
logger.info(f"{'Federated Learning Server Starting':^60}")
logger.info("=" * 60)
# Print configuration details
logger.info("\nServer Configuration:")
logger.info("-" * 30)
logger.info(f"Minimum clients required: {self.min_clients}")
logger.info(f"Total rounds planned: {self.rounds}")
active_clients_count = self._count_active_clients()
logger.info(f"Current active clients: {active_clients_count}")
logger.info("-" * 30 + "\n")
self.training_active = True
# Import and start API server
try:
from ..api.server import FederatedAPI
api_config = self.config.get('api', {})
host = api_config.get('host', '0.0.0.0')
port = api_config.get('port', 8080)
api_server = FederatedAPI(self, host, port)
api_thread = api_server.run_threaded()
logger.info(f"API server started on {host}:{port}")
# Keep server running
try:
while self.training_active and self.current_round < self.rounds:
time.sleep(1) # Keep main thread alive
# Log progress periodically
active_clients_count = self._count_active_clients()
if active_clients_count > 0:
logger.debug(f"Round {self.current_round}/{self.rounds}, "
f"Active Clients: {active_clients_count}, "
f"Updates: {len(self.client_updates)}")
logger.info("Federated learning completed successfully")
except KeyboardInterrupt:
logger.info("Server shutdown requested")
self.training_active = False
except ImportError as e:
logger.error(f"Failed to start API server: {str(e)}")
# Fallback to original behavior
# ...existing code...
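
# Minimal usage sketch (hypothetical config; the exact keys FederatedAggregator
# expects under 'aggregation' depend on its implementation):
#
#   if __name__ == "__main__":
#       logging.basicConfig(level=logging.INFO)
#       config = {
#           "federated": {"min_clients": 2, "rounds": 5},
#           "model": {"input_dim": 32, "hidden_layers": [128, 64]},
#           "aggregation": {},  # assumption: fill per FederatedAggregator's needs
#           "api": {"host": "0.0.0.0", "port": 8080},
#       }
#       FederatedCoordinator(config).start()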