""" Cluster Registration Service Handles registration-based cluster management for Hive-Bzzz integration. """ import asyncpg import secrets import json import socket from datetime import datetime, timedelta from typing import Dict, List, Optional, Any from dataclasses import dataclass from ipaddress import IPv4Network, IPv6Network, ip_address import logging logger = logging.getLogger(__name__) @dataclass class ClusterToken: id: int token: str description: str created_at: datetime expires_at: Optional[datetime] is_active: bool max_registrations: Optional[int] current_registrations: int allowed_ip_ranges: Optional[List[str]] @dataclass class ClusterNode: id: int node_id: str hostname: str ip_address: str registration_token: str cpu_info: Optional[Dict[str, Any]] memory_info: Optional[Dict[str, Any]] gpu_info: Optional[Dict[str, Any]] disk_info: Optional[Dict[str, Any]] os_info: Optional[Dict[str, Any]] platform_info: Optional[Dict[str, Any]] status: str last_heartbeat: datetime first_registered: datetime services: Optional[Dict[str, Any]] capabilities: Optional[Dict[str, Any]] ports: Optional[Dict[str, Any]] client_version: Optional[str] registration_metadata: Optional[Dict[str, Any]] @dataclass class RegistrationRequest: token: str node_id: str hostname: str ip_address: str system_info: Dict[str, Any] client_version: Optional[str] = None services: Optional[Dict[str, Any]] = None capabilities: Optional[Dict[str, Any]] = None ports: Optional[Dict[str, Any]] = None metadata: Optional[Dict[str, Any]] = None @dataclass class HeartbeatRequest: node_id: str status: str = "online" cpu_usage: Optional[float] = None memory_usage: Optional[float] = None disk_usage: Optional[float] = None gpu_usage: Optional[float] = None services_status: Optional[Dict[str, Any]] = None network_metrics: Optional[Dict[str, Any]] = None custom_metrics: Optional[Dict[str, Any]] = None class ClusterRegistrationService: def __init__(self, database_url: str): self.database_url = database_url self._conn_cache = None async def get_connection(self) -> asyncpg.Connection: """Get database connection with caching.""" if not self._conn_cache or self._conn_cache.is_closed(): try: self._conn_cache = await asyncpg.connect(self.database_url) except Exception as e: logger.error(f"Failed to connect to database: {e}") raise return self._conn_cache async def close_connection(self): """Close database connection.""" if self._conn_cache and not self._conn_cache.is_closed(): await self._conn_cache.close() # Token Management async def generate_cluster_token( self, description: str, created_by_user_id: str, expires_in_days: Optional[int] = None, max_registrations: Optional[int] = None, allowed_ip_ranges: Optional[List[str]] = None ) -> ClusterToken: """Generate a new cluster registration token.""" conn = await self.get_connection() # Generate secure token token = f"hive_cluster_{secrets.token_urlsafe(32)}" expires_at = datetime.now() + timedelta(days=expires_in_days) if expires_in_days else None try: result = await conn.fetchrow(""" INSERT INTO cluster_tokens ( token, description, created_by, expires_at, max_registrations, allowed_ip_ranges ) VALUES ($1, $2, $3, $4, $5, $6) RETURNING id, token, description, created_at, expires_at, is_active, max_registrations, current_registrations, allowed_ip_ranges """, token, description, created_by_user_id, expires_at, max_registrations, allowed_ip_ranges) return ClusterToken(**dict(result)) except Exception as e: logger.error(f"Failed to generate cluster token: {e}") raise async def validate_token(self, token: str, client_ip: str) -> Optional[ClusterToken]: """Validate a cluster registration token.""" conn = await self.get_connection() try: result = await conn.fetchrow(""" SELECT id, token, description, created_at, expires_at, is_active, max_registrations, current_registrations, allowed_ip_ranges FROM cluster_tokens WHERE token = $1 AND is_active = true """, token) if not result: return None cluster_token = ClusterToken(**dict(result)) # Check expiration if cluster_token.expires_at and datetime.now() > cluster_token.expires_at: logger.warning(f"Token {token[:20]}... has expired") return None # Check registration limit if (cluster_token.max_registrations and cluster_token.current_registrations >= cluster_token.max_registrations): logger.warning(f"Token {token[:20]}... has reached registration limit") return None # Check IP restrictions if cluster_token.allowed_ip_ranges: client_ip_obj = ip_address(client_ip) allowed = False for ip_range in cluster_token.allowed_ip_ranges: try: network = IPv4Network(ip_range, strict=False) if ':' not in ip_range else IPv6Network(ip_range, strict=False) if client_ip_obj in network: allowed = True break except Exception as e: logger.warning(f"Invalid IP range {ip_range}: {e}") if not allowed: logger.warning(f"IP {client_ip} not allowed for token {token[:20]}...") return None return cluster_token except Exception as e: logger.error(f"Failed to validate token: {e}") return None async def list_tokens(self) -> List[ClusterToken]: """List all cluster tokens.""" conn = await self.get_connection() try: results = await conn.fetch(""" SELECT id, token, description, created_at, expires_at, is_active, max_registrations, current_registrations, allowed_ip_ranges FROM cluster_tokens ORDER BY created_at DESC """) return [ClusterToken(**dict(result)) for result in results] except Exception as e: logger.error(f"Failed to list tokens: {e}") raise async def revoke_token(self, token: str) -> bool: """Revoke a cluster token.""" conn = await self.get_connection() try: result = await conn.execute(""" UPDATE cluster_tokens SET is_active = false WHERE token = $1 """, token) return result != "UPDATE 0" except Exception as e: logger.error(f"Failed to revoke token: {e}") return False # Node Registration async def register_node(self, request: RegistrationRequest, client_ip: str) -> Dict[str, Any]: """Register a new cluster node.""" conn = await self.get_connection() # Log registration attempt await self._log_registration_attempt( client_ip, request.token, request.node_id, request.hostname, True, None, request.metadata ) try: # Validate token token_info = await self.validate_token(request.token, client_ip) if not token_info: await self._log_registration_attempt( client_ip, request.token, request.node_id, request.hostname, False, "Invalid or expired token", request.metadata ) raise ValueError("Invalid or expired registration token") # Extract system info components system_info = request.system_info or {} cpu_info = system_info.get('cpu', {}) memory_info = system_info.get('memory', {}) gpu_info = system_info.get('gpu', {}) disk_info = system_info.get('disk', {}) os_info = system_info.get('os', {}) platform_info = system_info.get('platform', {}) # Register or update node result = await conn.fetchrow(""" INSERT INTO cluster_nodes ( node_id, hostname, ip_address, registration_token, cpu_info, memory_info, gpu_info, disk_info, os_info, platform_info, services, capabilities, ports, client_version, registration_metadata ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) ON CONFLICT (node_id) DO UPDATE SET hostname = EXCLUDED.hostname, ip_address = EXCLUDED.ip_address, cpu_info = EXCLUDED.cpu_info, memory_info = EXCLUDED.memory_info, gpu_info = EXCLUDED.gpu_info, disk_info = EXCLUDED.disk_info, os_info = EXCLUDED.os_info, platform_info = EXCLUDED.platform_info, services = EXCLUDED.services, capabilities = EXCLUDED.capabilities, ports = EXCLUDED.ports, client_version = EXCLUDED.client_version, registration_metadata = EXCLUDED.registration_metadata, status = 'online', last_heartbeat = NOW() RETURNING id, node_id, hostname, ip_address, first_registered """, request.node_id, request.hostname, request.ip_address, request.token, json.dumps(cpu_info) if cpu_info else None, json.dumps(memory_info) if memory_info else None, json.dumps(gpu_info) if gpu_info else None, json.dumps(disk_info) if disk_info else None, json.dumps(os_info) if os_info else None, json.dumps(platform_info) if platform_info else None, json.dumps(request.services) if request.services else None, json.dumps(request.capabilities) if request.capabilities else None, json.dumps(request.ports) if request.ports else None, request.client_version, json.dumps(request.metadata) if request.metadata else None ) logger.info(f"Node {request.node_id} registered successfully from {client_ip}") return { "node_id": result["node_id"], "registration_status": "success", "heartbeat_interval": 30, # seconds "registered_at": result["first_registered"].isoformat(), "cluster_info": { "coordinator_version": "1.0.0", "features": ["heartbeat", "dynamic_scaling", "service_discovery"] } } except Exception as e: logger.error(f"Failed to register node {request.node_id}: {e}") await self._log_registration_attempt( client_ip, request.token, request.node_id, request.hostname, False, str(e), request.metadata ) raise async def update_heartbeat(self, request: HeartbeatRequest) -> Dict[str, Any]: """Update node heartbeat and metrics.""" conn = await self.get_connection() try: # Update node status and heartbeat result = await conn.fetchrow(""" UPDATE cluster_nodes SET status = $2, last_heartbeat = NOW() WHERE node_id = $1 RETURNING node_id, status, last_heartbeat """, request.node_id, request.status) if not result: raise ValueError(f"Node {request.node_id} not found") # Record heartbeat metrics await conn.execute(""" INSERT INTO node_heartbeats ( node_id, cpu_usage, memory_usage, disk_usage, gpu_usage, services_status, network_metrics, custom_metrics ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) """, request.node_id, request.cpu_usage, request.memory_usage, request.disk_usage, request.gpu_usage, json.dumps(request.services_status) if request.services_status else None, json.dumps(request.network_metrics) if request.network_metrics else None, json.dumps(request.custom_metrics) if request.custom_metrics else None ) return { "node_id": result["node_id"], "status": result["status"], "heartbeat_received": result["last_heartbeat"].isoformat(), "next_heartbeat_in": 30, # seconds "commands": [] # Future: cluster management commands } except Exception as e: logger.error(f"Failed to update heartbeat for {request.node_id}: {e}") raise async def get_registered_nodes(self, include_offline: bool = True) -> List[ClusterNode]: """Get all registered cluster nodes.""" conn = await self.get_connection() try: query = """ SELECT id, node_id, hostname, ip_address, registration_token, cpu_info, memory_info, gpu_info, disk_info, os_info, platform_info, status, last_heartbeat, first_registered, services, capabilities, ports, client_version, registration_metadata FROM cluster_nodes """ if not include_offline: query += " WHERE status != 'offline'" query += " ORDER BY first_registered DESC" results = await conn.fetch(query) nodes = [] for result in results: node_dict = dict(result) # Parse JSON fields for json_field in ['cpu_info', 'memory_info', 'gpu_info', 'disk_info', 'os_info', 'platform_info', 'services', 'capabilities', 'ports', 'registration_metadata']: if node_dict[json_field]: try: node_dict[json_field] = json.loads(node_dict[json_field]) except json.JSONDecodeError: node_dict[json_field] = None nodes.append(ClusterNode(**node_dict)) return nodes except Exception as e: logger.error(f"Failed to get registered nodes: {e}") raise async def get_node_details(self, node_id: str) -> Optional[ClusterNode]: """Get detailed information about a specific node.""" nodes = await self.get_registered_nodes() return next((node for node in nodes if node.node_id == node_id), None) async def remove_node(self, node_id: str) -> bool: """Remove a node from the cluster.""" conn = await self.get_connection() try: result = await conn.execute(""" DELETE FROM cluster_nodes WHERE node_id = $1 """, node_id) if result != "DELETE 0": logger.info(f"Node {node_id} removed from cluster") return True return False except Exception as e: logger.error(f"Failed to remove node {node_id}: {e}") return False # Maintenance and Monitoring async def cleanup_offline_nodes(self, offline_threshold_minutes: int = 10) -> int: """Mark nodes as offline if they haven't sent heartbeats.""" conn = await self.get_connection() try: result = await conn.execute(""" UPDATE cluster_nodes SET status = 'offline' WHERE status = 'online' AND last_heartbeat < NOW() - INTERVAL '%s minutes' """ % offline_threshold_minutes) # Extract number from result like "UPDATE 3" count = int(result.split()[-1]) if result.split()[-1].isdigit() else 0 if count > 0: logger.info(f"Marked {count} nodes as offline due to missing heartbeats") return count except Exception as e: logger.error(f"Failed to cleanup offline nodes: {e}") return 0 async def cleanup_old_heartbeats(self, retention_days: int = 30) -> int: """Remove old heartbeat data for storage management.""" conn = await self.get_connection() try: result = await conn.execute(""" DELETE FROM node_heartbeats WHERE heartbeat_time < NOW() - INTERVAL '%s days' """ % retention_days) count = int(result.split()[-1]) if result.split()[-1].isdigit() else 0 if count > 0: logger.info(f"Cleaned up {count} old heartbeat records") return count except Exception as e: logger.error(f"Failed to cleanup old heartbeats: {e}") return 0 async def _log_registration_attempt( self, ip_address: str, token: str, node_id: str, hostname: str, success: bool, failure_reason: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None ): """Log registration attempts for security monitoring.""" conn = await self.get_connection() try: await conn.execute(""" INSERT INTO node_registration_attempts ( ip_address, token_used, node_id, hostname, success, failure_reason, request_metadata ) VALUES ($1, $2, $3, $4, $5, $6, $7) """, ip_address, token, node_id, hostname, success, failure_reason, json.dumps(metadata) if metadata else None) except Exception as e: logger.error(f"Failed to log registration attempt: {e}") async def get_cluster_statistics(self) -> Dict[str, Any]: """Get cluster statistics and health metrics.""" conn = await self.get_connection() try: # Node statistics node_stats = await conn.fetchrow(""" SELECT COUNT(*) as total_nodes, COUNT(*) FILTER (WHERE status = 'online') as online_nodes, COUNT(*) FILTER (WHERE status = 'offline') as offline_nodes, COUNT(*) FILTER (WHERE status = 'maintenance') as maintenance_nodes FROM cluster_nodes """) # Token statistics token_stats = await conn.fetchrow(""" SELECT COUNT(*) as total_tokens, COUNT(*) FILTER (WHERE is_active = true) as active_tokens, COUNT(*) FILTER (WHERE expires_at IS NOT NULL AND expires_at < NOW()) as expired_tokens FROM cluster_tokens """) return { "cluster_health": { "total_nodes": node_stats["total_nodes"], "online_nodes": node_stats["online_nodes"], "offline_nodes": node_stats["offline_nodes"], "maintenance_nodes": node_stats["maintenance_nodes"], "health_percentage": (node_stats["online_nodes"] / max(node_stats["total_nodes"], 1)) * 100 }, "token_management": { "total_tokens": token_stats["total_tokens"], "active_tokens": token_stats["active_tokens"], "expired_tokens": token_stats["expired_tokens"] }, "last_updated": datetime.now().isoformat() } except Exception as e: logger.error(f"Failed to get cluster statistics: {e}") return { "error": str(e), "last_updated": datetime.now().isoformat() }