"""
Cluster Registration API endpoints
Handles registration-based cluster management for WHOOSH-Bzzz integration.
"""

from fastapi import APIRouter, HTTPException, Request, Depends, Query
from pydantic import BaseModel, Field
from typing import Dict, Any, List, Optional
import logging
import os

from ..services.cluster_registration_service import (
    ClusterRegistrationService,
    RegistrationRequest,
    HeartbeatRequest,
)

logger = logging.getLogger(__name__)
router = APIRouter()

# Initialize service
# NOTE(review): the fallback URL embeds development credentials; production
# deployments should always provide DATABASE_URL explicitly.
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://whoosh:whooshpass@localhost:5432/whoosh")
cluster_registration_service = ClusterRegistrationService(DATABASE_URL)

# Number of leading token characters echoed in API responses and log lines;
# the remainder is replaced by "..." so full tokens are not exposed.
_TOKEN_PREVIEW_CHARS = 20


# Pydantic models for API
class NodeRegistrationRequest(BaseModel):
    token: str = Field(..., description="Cluster registration token")
    node_id: str = Field(..., description="Unique node identifier")
    hostname: str = Field(..., description="Node hostname")
    system_info: Dict[str, Any] = Field(..., description="System hardware and OS information")
    client_version: Optional[str] = Field(None, description="Bzzz client version")
    services: Optional[Dict[str, Any]] = Field(None, description="Available services")
    capabilities: Optional[Dict[str, Any]] = Field(None, description="Node capabilities")
    ports: Optional[Dict[str, Any]] = Field(None, description="Service ports")
    metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")


class NodeHeartbeatRequest(BaseModel):
    node_id: str = Field(..., description="Node identifier")
    status: str = Field("online", description="Node status")
    cpu_usage: Optional[float] = Field(None, ge=0, le=100, description="CPU usage percentage")
    memory_usage: Optional[float] = Field(None, ge=0, le=100, description="Memory usage percentage")
    disk_usage: Optional[float] = Field(None, ge=0, le=100, description="Disk usage percentage")
    gpu_usage: Optional[float] = Field(None, ge=0, le=100, description="GPU usage percentage")
    services_status: Optional[Dict[str, Any]] = Field(None, description="Service status information")
    network_metrics: Optional[Dict[str, Any]] = Field(None, description="Network metrics")
    custom_metrics: Optional[Dict[str, Any]] = Field(None, description="Custom node metrics")


class TokenCreateRequest(BaseModel):
    description: str = Field(..., description="Token description")
    expires_in_days: Optional[int] = Field(None, gt=0, description="Token expiration in days")
    max_registrations: Optional[int] = Field(None, gt=0, description="Maximum number of registrations")
    allowed_ip_ranges: Optional[List[str]] = Field(None, description="Allowed IP CIDR ranges")


# Helper function to get client IP
def get_client_ip(request: Request) -> str:
    """Extract client IP address from request.

    Precedence: first entry of ``X-Forwarded-For``, then ``X-Real-IP``, then
    the direct socket peer, and finally ``"unknown"``. Header values that are
    empty or whitespace-only are skipped instead of being returned.
    """
    # NOTE(review): X-Forwarded-For / X-Real-IP are client-controllable unless a
    # trusted reverse proxy overwrites them. The value returned here is handed
    # to the registration service as the node's IP (and presumably checked
    # against a token's allowed_ip_ranges) -- confirm the deployment sanitizes
    # these headers.
    forwarded_for = request.headers.get("X-Forwarded-For")
    if forwarded_for:
        # Take the first IP in the chain (original client)
        first_hop = forwarded_for.split(",")[0].strip()
        if first_hop:
            return first_hop

    # Check for X-Real-IP header (nginx)
    real_ip = request.headers.get("X-Real-IP")
    if real_ip and real_ip.strip():
        return real_ip.strip()

    # Fall back to direct connection IP
    return request.client.host if request.client else "unknown"


def _serialize_node(node: Any) -> Dict[str, Any]:
    """Convert a registered-node record from the service into the API response dict.

    Optional JSON-ish attributes default to ``{}``; ``first_registered`` and
    ``last_heartbeat`` are rendered with ``isoformat()``.
    """
    return {
        "id": node.id,
        "node_id": node.node_id,
        "hostname": node.hostname,
        "ip_address": node.ip_address,
        "status": node.status,
        "hardware": {
            "cpu": node.cpu_info or {},
            "memory": node.memory_info or {},
            "gpu": node.gpu_info or {},
            "disk": node.disk_info or {},
            "os": node.os_info or {},
            "platform": node.platform_info or {},
        },
        "services": node.services or {},
        "capabilities": node.capabilities or {},
        "ports": node.ports or {},
        "client_version": node.client_version,
        "first_registered": node.first_registered.isoformat(),
        "last_heartbeat": node.last_heartbeat.isoformat(),
        "registration_metadata": node.registration_metadata or {},
    }


def _mask_token(token: str) -> str:
    """Return the first ``_TOKEN_PREVIEW_CHARS`` characters of *token*, plus "..." if truncated."""
    if len(token) > _TOKEN_PREVIEW_CHARS:
        return token[:_TOKEN_PREVIEW_CHARS] + "..."
    return token


# Registration endpoints
@router.post("/cluster/register")
async def register_node(
    registration: NodeRegistrationRequest,
    request: Request
) -> Dict[str, Any]:
    """
    Register a new node in the cluster.

    This endpoint allows Bzzz clients to register themselves with the WHOOSH coordinator
    using a valid cluster token. Similar to `docker swarm join`.
    """
    try:
        client_ip = get_client_ip(request)
        logger.info("Node registration attempt: %s from %s", registration.node_id, client_ip)

        # Convert to service request
        reg_request = RegistrationRequest(
            token=registration.token,
            node_id=registration.node_id,
            hostname=registration.hostname,
            ip_address=client_ip,
            system_info=registration.system_info,
            client_version=registration.client_version,
            services=registration.services,
            capabilities=registration.capabilities,
            ports=registration.ports,
            metadata=registration.metadata,
        )

        result = await cluster_registration_service.register_node(reg_request, client_ip)
        logger.info("Node %s registered successfully", registration.node_id)
        return result

    except ValueError as e:
        # The service reports rejected registrations via ValueError; its message
        # is returned to the client as a 400.
        logger.warning("Registration failed for %s: %s", registration.node_id, e)
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # logger.exception keeps the traceback; the client only sees a generic message.
        logger.exception("Registration error for %s: %s", registration.node_id, e)
        raise HTTPException(status_code=500, detail="Registration failed") from e


@router.post("/cluster/heartbeat")
async def node_heartbeat(heartbeat: NodeHeartbeatRequest) -> Dict[str, Any]:
    """
    Update node heartbeat and status.

    Registered nodes should call this endpoint periodically (every 30 seconds)
    to maintain their registration and report current status/metrics.
    """
    try:
        heartbeat_request = HeartbeatRequest(
            node_id=heartbeat.node_id,
            status=heartbeat.status,
            cpu_usage=heartbeat.cpu_usage,
            memory_usage=heartbeat.memory_usage,
            disk_usage=heartbeat.disk_usage,
            gpu_usage=heartbeat.gpu_usage,
            services_status=heartbeat.services_status,
            network_metrics=heartbeat.network_metrics,
            custom_metrics=heartbeat.custom_metrics,
        )

        result = await cluster_registration_service.update_heartbeat(heartbeat_request)
        return result

    except ValueError as e:
        # ValueError from the service is mapped to 404 -- presumably raised for
        # an unknown/unregistered node_id (confirm in the service).
        logger.warning("Heartbeat failed for %s: %s", heartbeat.node_id, e)
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        logger.exception("Heartbeat error for %s: %s", heartbeat.node_id, e)
        raise HTTPException(status_code=500, detail="Heartbeat update failed") from e


# Node management endpoints
@router.get("/cluster/nodes/registered")
async def get_registered_nodes(include_offline: bool = True) -> Dict[str, Any]:
    """
    Get all registered cluster nodes.

    Returns detailed information about all nodes that have registered with the cluster,
    including their hardware specs and current status.
    """
    try:
        nodes = await cluster_registration_service.get_registered_nodes(include_offline)
        nodes_data = [_serialize_node(node) for node in nodes]

        return {
            "nodes": nodes_data,
            "total_count": len(nodes_data),
            "online_count": sum(1 for n in nodes if n.status == "online"),
            "offline_count": sum(1 for n in nodes if n.status == "offline"),
        }

    except Exception as e:
        logger.exception("Failed to get registered nodes: %s", e)
        raise HTTPException(status_code=500, detail="Failed to retrieve registered nodes") from e


@router.get("/cluster/nodes/{node_id}")
async def get_node_details(node_id: str) -> Dict[str, Any]:
    """Get detailed information about a specific registered node."""
    try:
        node = await cluster_registration_service.get_node_details(node_id)
        if not node:
            raise HTTPException(status_code=404, detail="Node not found")

        return _serialize_node(node)

    except HTTPException:
        # Let the intentional 404 above pass through unchanged.
        raise
    except Exception as e:
        logger.exception("Failed to get node details for %s: %s", node_id, e)
        raise HTTPException(status_code=500, detail="Failed to retrieve node details") from e


@router.delete("/cluster/nodes/{node_id}")
async def remove_node(node_id: str) -> Dict[str, Any]:
    """
    Remove a node from the cluster.

    This will unregister the node and stop accepting its heartbeats.
    The node will need to re-register to rejoin the cluster.
    """
    try:
        success = await cluster_registration_service.remove_node(node_id)
        if not success:
            raise HTTPException(status_code=404, detail="Node not found")

        return {
            "node_id": node_id,
            "status": "removed",
            "message": "Node successfully removed from cluster",
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to remove node %s: %s", node_id, e)
        raise HTTPException(status_code=500, detail="Failed to remove node") from e


# Token management endpoints
# NOTE(review): no authentication dependency is declared on these routes in this
# module; confirm admin-only access is enforced where the router is mounted.
@router.post("/cluster/tokens")
async def create_cluster_token(token_request: TokenCreateRequest) -> Dict[str, Any]:
    """
    Create a new cluster registration token.

    Tokens are used by Bzzz clients to authenticate and register with the cluster.
    Only administrators should have access to this endpoint.
    """
    try:
        # For now, use a default admin user ID
        # TODO: Extract from JWT token or session
        admin_user_id = "admin"  # This should come from authentication

        token = await cluster_registration_service.generate_cluster_token(
            description=token_request.description,
            created_by_user_id=admin_user_id,
            expires_in_days=token_request.expires_in_days,
            max_registrations=token_request.max_registrations,
            allowed_ip_ranges=token_request.allowed_ip_ranges,
        )

        # The full token is returned only here, at creation time.
        return {
            "id": token.id,
            "token": token.token,
            "description": token.description,
            "created_at": token.created_at.isoformat(),
            "expires_at": token.expires_at.isoformat() if token.expires_at else None,
            "is_active": token.is_active,
            "max_registrations": token.max_registrations,
            "current_registrations": token.current_registrations,
            "allowed_ip_ranges": token.allowed_ip_ranges,
        }

    except Exception as e:
        logger.exception("Failed to create cluster token: %s", e)
        raise HTTPException(status_code=500, detail="Failed to create token") from e


@router.get("/cluster/tokens")
async def list_cluster_tokens() -> Dict[str, Any]:
    """
    List all cluster registration tokens.

    Returns information about all tokens including their usage statistics.
    Only administrators should have access to this endpoint.
    """
    try:
        tokens = await cluster_registration_service.list_tokens()

        tokens_data = [
            {
                "id": token.id,
                "token": _mask_token(token.token),  # Partial token for security
                "description": token.description,
                "created_at": token.created_at.isoformat(),
                "expires_at": token.expires_at.isoformat() if token.expires_at else None,
                "is_active": token.is_active,
                "max_registrations": token.max_registrations,
                "current_registrations": token.current_registrations,
                "allowed_ip_ranges": token.allowed_ip_ranges,
            }
            for token in tokens
        ]

        return {
            "tokens": tokens_data,
            "total_count": len(tokens_data),
        }

    except Exception as e:
        logger.exception("Failed to list cluster tokens: %s", e)
        raise HTTPException(status_code=500, detail="Failed to list tokens") from e


@router.delete("/cluster/tokens/{token}")
async def revoke_cluster_token(token: str) -> Dict[str, Any]:
    """
    Revoke a cluster registration token.

    This will prevent new registrations using this token, but won't affect
    nodes that are already registered.
    """
    try:
        success = await cluster_registration_service.revoke_token(token)
        if not success:
            raise HTTPException(status_code=404, detail="Token not found")

        return {
            "token": _mask_token(token),
            "status": "revoked",
            "message": "Token successfully revoked",
        }

    except HTTPException:
        raise
    except Exception as e:
        # Log only the masked token: the full value is a registration secret.
        logger.exception("Failed to revoke token %s: %s", _mask_token(token), e)
        raise HTTPException(status_code=500, detail="Failed to revoke token") from e


# Cluster statistics and monitoring
@router.get("/cluster/statistics")
async def get_cluster_statistics() -> Dict[str, Any]:
    """
    Get cluster health and usage statistics.

    Returns information about node counts, token usage, and overall cluster health.
    """
    try:
        stats = await cluster_registration_service.get_cluster_statistics()
        return stats

    except Exception as e:
        logger.exception("Failed to get cluster statistics: %s", e)
        raise HTTPException(status_code=500, detail="Failed to retrieve cluster statistics") from e


# Maintenance endpoints
@router.post("/cluster/maintenance/cleanup-offline")
async def cleanup_offline_nodes(
    offline_threshold_minutes: int = Query(
        10, gt=0, description="Minutes without a heartbeat before a node is marked offline"
    ),
) -> Dict[str, Any]:
    """
    Mark nodes as offline if they haven't sent heartbeats recently.

    This maintenance endpoint should be called periodically to keep
    the cluster status accurate.
    """
    # gt=0 rejects zero/negative thresholds (422) instead of passing them to the service.
    try:
        count = await cluster_registration_service.cleanup_offline_nodes(offline_threshold_minutes)

        return {
            "nodes_marked_offline": count,
            "threshold_minutes": offline_threshold_minutes,
            "message": f"Marked {count} nodes as offline",
        }

    except Exception as e:
        logger.exception("Failed to cleanup offline nodes: %s", e)
        raise HTTPException(status_code=500, detail="Failed to cleanup offline nodes") from e


@router.post("/cluster/maintenance/cleanup-heartbeats")
async def cleanup_old_heartbeats(
    retention_days: int = Query(
        30, gt=0, description="Days of heartbeat history to keep"
    ),
) -> Dict[str, Any]:
    """
    Remove old heartbeat data to manage database size.

    This maintenance endpoint should be called periodically to prevent
    the heartbeat table from growing too large.
    """
    # gt=0 guards against a zero/negative retention, which would presumably make
    # the service delete all heartbeat history.
    try:
        count = await cluster_registration_service.cleanup_old_heartbeats(retention_days)

        return {
            "heartbeats_deleted": count,
            "retention_days": retention_days,
            "message": f"Deleted {count} old heartbeat records",
        }

    except Exception as e:
        logger.exception("Failed to cleanup old heartbeats: %s", e)
        raise HTTPException(status_code=500, detail="Failed to cleanup old heartbeats") from e


# Health check endpoint
@router.get("/cluster/health")
async def cluster_registration_health() -> Dict[str, Any]:
    """
    Health check for the cluster registration system.
    """
    try:
        # Test database connection
        stats = await cluster_registration_service.get_cluster_statistics()

        return {
            "status": "healthy",
            "database_connected": True,
            "cluster_health": stats.get("cluster_health", {}),
            "timestamp": stats.get("last_updated"),
        }

    except Exception as e:
        # Health probes may poll frequently, so log without a traceback.
        # NOTE(review): the raw exception text is returned to the caller; confirm
        # this endpoint is not exposed publicly.
        logger.error("Cluster registration health check failed: %s", e)
        return {
            "status": "unhealthy",
            "database_connected": False,
            "error": str(e),
            "timestamp": None,
        }