WIP: Save current work before CHORUS rebrand
- Agent roles integration progress - Various backend and frontend updates - Storybook cache cleanup 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
522
backend/app/services/cluster_registration_service.py
Normal file
522
backend/app/services/cluster_registration_service.py
Normal file
@@ -0,0 +1,522 @@
|
||||
"""
|
||||
Cluster Registration Service
|
||||
Handles registration-based cluster management for Hive-Bzzz integration.
|
||||
"""
|
||||
import asyncpg
|
||||
import secrets
|
||||
import json
|
||||
import socket
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
from ipaddress import IPv4Network, IPv6Network, ip_address
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
class ClusterToken:
    """A row from the cluster_tokens table: a shared secret that authorizes
    node registration, with optional expiry, usage cap, and IP allow-list."""

    id: int
    token: str  # secret value ("hive_cluster_..." — see generate_cluster_token)
    description: str
    created_at: datetime
    expires_at: Optional[datetime]  # None = never expires
    is_active: bool  # cleared by revoke_token
    max_registrations: Optional[int]  # None = unlimited registrations
    current_registrations: int
    allowed_ip_ranges: Optional[List[str]]  # CIDR strings; None = any IP
|
||||
@dataclass
class ClusterNode:
    """A row from the cluster_nodes table describing one registered node.

    The *_info / services / capabilities / ports / registration_metadata
    fields are stored as JSON text in the database and decoded to dicts
    when read back (see ClusterRegistrationService.get_registered_nodes).
    """

    id: int
    node_id: str  # caller-chosen unique identifier (upsert key)
    hostname: str
    ip_address: str
    registration_token: str  # the cluster token used at registration time
    cpu_info: Optional[Dict[str, Any]]
    memory_info: Optional[Dict[str, Any]]
    gpu_info: Optional[Dict[str, Any]]
    disk_info: Optional[Dict[str, Any]]
    os_info: Optional[Dict[str, Any]]
    platform_info: Optional[Dict[str, Any]]
    status: str  # e.g. 'online' / 'offline' / 'maintenance'
    last_heartbeat: datetime
    first_registered: datetime
    services: Optional[Dict[str, Any]]
    capabilities: Optional[Dict[str, Any]]
    ports: Optional[Dict[str, Any]]
    client_version: Optional[str]
    registration_metadata: Optional[Dict[str, Any]]
|
||||
@dataclass
class RegistrationRequest:
    """Payload a node submits to register itself with the cluster.

    system_info is expected to contain optional 'cpu', 'memory', 'gpu',
    'disk', 'os', and 'platform' sub-dicts (see register_node).
    """

    token: str  # cluster registration token (validated before registering)
    node_id: str  # unique node identifier; re-registering updates the row
    hostname: str
    ip_address: str
    system_info: Dict[str, Any]
    client_version: Optional[str] = None
    services: Optional[Dict[str, Any]] = None
    capabilities: Optional[Dict[str, Any]] = None
    ports: Optional[Dict[str, Any]] = None
    metadata: Optional[Dict[str, Any]] = None  # free-form; logged with attempts
|
||||
@dataclass
class HeartbeatRequest:
    """Periodic liveness + metrics payload sent by a registered node."""

    node_id: str
    status: str = "online"  # written to cluster_nodes.status as-is
    cpu_usage: Optional[float] = None
    memory_usage: Optional[float] = None
    disk_usage: Optional[float] = None
    gpu_usage: Optional[float] = None
    services_status: Optional[Dict[str, Any]] = None
    network_metrics: Optional[Dict[str, Any]] = None
    custom_metrics: Optional[Dict[str, Any]] = None
|
||||
class ClusterRegistrationService:
    """Registration-based cluster management for Hive-Bzzz integration.

    Backed by PostgreSQL via asyncpg. Responsibilities:
      * issue / validate / revoke cluster registration tokens
      * register nodes and record their heartbeats
      * maintenance (offline detection, heartbeat retention) and statistics
    """

    # JSON-encoded text columns on cluster_nodes that are decoded on read.
    _NODE_JSON_FIELDS = (
        'cpu_info', 'memory_info', 'gpu_info', 'disk_info',
        'os_info', 'platform_info', 'services', 'capabilities',
        'ports', 'registration_metadata',
    )

    def __init__(self, database_url: str):
        self.database_url = database_url
        self._conn_cache = None  # lazily created cached connection

    async def get_connection(self) -> asyncpg.Connection:
        """Get database connection with caching.

        NOTE(review): a single cached connection is not safe for concurrent
        use from multiple tasks — consider asyncpg.create_pool() if this
        service instance is shared.
        """
        if not self._conn_cache or self._conn_cache.is_closed():
            try:
                self._conn_cache = await asyncpg.connect(self.database_url)
            except Exception as e:
                logger.error(f"Failed to connect to database: {e}")
                raise
        return self._conn_cache

    async def close_connection(self):
        """Close the cached database connection, if open."""
        if self._conn_cache and not self._conn_cache.is_closed():
            await self._conn_cache.close()

    @staticmethod
    def _affected_rows(command_tag: str) -> int:
        """Parse the row count from an asyncpg command tag like 'UPDATE 3'."""
        last = command_tag.split()[-1]
        return int(last) if last.isdigit() else 0

    # Token Management
    async def generate_cluster_token(
        self,
        description: str,
        created_by_user_id: str,
        expires_in_days: Optional[int] = None,
        max_registrations: Optional[int] = None,
        allowed_ip_ranges: Optional[List[str]] = None
    ) -> ClusterToken:
        """Generate a new cluster registration token.

        Args:
            description: Human-readable purpose of the token.
            created_by_user_id: Id of the user creating the token.
            expires_in_days: Optional lifetime; None means no expiry.
            max_registrations: Optional usage cap; None means unlimited.
            allowed_ip_ranges: Optional CIDR ranges allowed to register.

        Raises:
            Exception: any database error is logged and re-raised.
        """
        conn = await self.get_connection()

        # Cryptographically secure token value (secrets, not random).
        token = f"hive_cluster_{secrets.token_urlsafe(32)}"
        # NOTE(review): naive local time — confirm cluster_tokens.expires_at
        # is a naive timestamp column; otherwise switch to aware datetimes.
        expires_at = datetime.now() + timedelta(days=expires_in_days) if expires_in_days else None

        try:
            result = await conn.fetchrow("""
                INSERT INTO cluster_tokens (
                    token, description, created_by, expires_at,
                    max_registrations, allowed_ip_ranges
                ) VALUES ($1, $2, $3, $4, $5, $6)
                RETURNING id, token, description, created_at, expires_at,
                          is_active, max_registrations, current_registrations, allowed_ip_ranges
            """, token, description, created_by_user_id, expires_at, max_registrations, allowed_ip_ranges)

            return ClusterToken(**dict(result))
        except Exception as e:
            logger.error(f"Failed to generate cluster token: {e}")
            raise

    async def validate_token(self, token: str, client_ip: str) -> Optional[ClusterToken]:
        """Validate a cluster registration token.

        Returns the token row if it is active, unexpired, under its
        registration limit, and client_ip passes the IP allow-list;
        otherwise returns None (errors are logged, never raised).
        """
        conn = await self.get_connection()

        try:
            result = await conn.fetchrow("""
                SELECT id, token, description, created_at, expires_at,
                       is_active, max_registrations, current_registrations, allowed_ip_ranges
                FROM cluster_tokens
                WHERE token = $1 AND is_active = true
            """, token)

            if not result:
                return None

            cluster_token = ClusterToken(**dict(result))

            # Check expiration.
            if cluster_token.expires_at and datetime.now() > cluster_token.expires_at:
                logger.warning(f"Token {token[:20]}... has expired")
                return None

            # Check registration limit.
            if (cluster_token.max_registrations and
                    cluster_token.current_registrations >= cluster_token.max_registrations):
                logger.warning(f"Token {token[:20]}... has reached registration limit")
                return None

            # Check IP restrictions (deny unless some configured range matches).
            if cluster_token.allowed_ip_ranges:
                client_ip_obj = ip_address(client_ip)
                allowed = False
                for ip_range in cluster_token.allowed_ip_ranges:
                    try:
                        network = IPv4Network(ip_range, strict=False) if ':' not in ip_range else IPv6Network(ip_range, strict=False)
                        if client_ip_obj in network:
                            allowed = True
                            break
                    except Exception as e:
                        # Bad range stored in the DB: skip it rather than
                        # fail the whole validation.
                        logger.warning(f"Invalid IP range {ip_range}: {e}")

                if not allowed:
                    logger.warning(f"IP {client_ip} not allowed for token {token[:20]}...")
                    return None

            return cluster_token

        except Exception as e:
            logger.error(f"Failed to validate token: {e}")
            return None

    async def list_tokens(self) -> List[ClusterToken]:
        """List all cluster tokens, newest first."""
        conn = await self.get_connection()

        try:
            results = await conn.fetch("""
                SELECT id, token, description, created_at, expires_at,
                       is_active, max_registrations, current_registrations, allowed_ip_ranges
                FROM cluster_tokens
                ORDER BY created_at DESC
            """)

            return [ClusterToken(**dict(result)) for result in results]

        except Exception as e:
            logger.error(f"Failed to list tokens: {e}")
            raise

    async def revoke_token(self, token: str) -> bool:
        """Revoke a cluster token. Returns True if a row was deactivated."""
        conn = await self.get_connection()

        try:
            result = await conn.execute("""
                UPDATE cluster_tokens
                SET is_active = false
                WHERE token = $1
            """, token)

            return result != "UPDATE 0"

        except Exception as e:
            logger.error(f"Failed to revoke token: {e}")
            return False

    # Node Registration
    async def register_node(self, request: RegistrationRequest, client_ip: str) -> Dict[str, Any]:
        """Register a new cluster node (or re-register/update an existing one).

        Every attempt — successful or not — is recorded exactly once via
        _log_registration_attempt for security monitoring. (Previously a
        success=True row was written before validation even ran, producing a
        false audit entry and a duplicate row on failure.)

        Raises:
            ValueError: if the registration token is invalid or expired.
            Exception: any database error is logged and re-raised.
        """
        conn = await self.get_connection()

        try:
            # Validate token first; an invalid token is both logged and raised.
            token_info = await self.validate_token(request.token, client_ip)
            if not token_info:
                await self._log_registration_attempt(
                    client_ip, request.token, request.node_id,
                    request.hostname, False, "Invalid or expired token", request.metadata
                )
                raise ValueError("Invalid or expired registration token")

            # Extract system info components (each sub-dict is optional).
            system_info = request.system_info or {}
            cpu_info = system_info.get('cpu', {})
            memory_info = system_info.get('memory', {})
            gpu_info = system_info.get('gpu', {})
            disk_info = system_info.get('disk', {})
            os_info = system_info.get('os', {})
            platform_info = system_info.get('platform', {})

            # Register or update node (upsert on node_id).
            result = await conn.fetchrow("""
                INSERT INTO cluster_nodes (
                    node_id, hostname, ip_address, registration_token,
                    cpu_info, memory_info, gpu_info, disk_info, os_info, platform_info,
                    services, capabilities, ports, client_version, registration_metadata
                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
                ON CONFLICT (node_id) DO UPDATE SET
                    hostname = EXCLUDED.hostname,
                    ip_address = EXCLUDED.ip_address,
                    cpu_info = EXCLUDED.cpu_info,
                    memory_info = EXCLUDED.memory_info,
                    gpu_info = EXCLUDED.gpu_info,
                    disk_info = EXCLUDED.disk_info,
                    os_info = EXCLUDED.os_info,
                    platform_info = EXCLUDED.platform_info,
                    services = EXCLUDED.services,
                    capabilities = EXCLUDED.capabilities,
                    ports = EXCLUDED.ports,
                    client_version = EXCLUDED.client_version,
                    registration_metadata = EXCLUDED.registration_metadata,
                    status = 'online',
                    last_heartbeat = NOW()
                RETURNING id, node_id, hostname, ip_address, first_registered
            """,
            request.node_id, request.hostname, request.ip_address, request.token,
            json.dumps(cpu_info) if cpu_info else None,
            json.dumps(memory_info) if memory_info else None,
            json.dumps(gpu_info) if gpu_info else None,
            json.dumps(disk_info) if disk_info else None,
            json.dumps(os_info) if os_info else None,
            json.dumps(platform_info) if platform_info else None,
            json.dumps(request.services) if request.services else None,
            json.dumps(request.capabilities) if request.capabilities else None,
            json.dumps(request.ports) if request.ports else None,
            request.client_version,
            json.dumps(request.metadata) if request.metadata else None
            )

            # Record the successful attempt only after the upsert completed.
            await self._log_registration_attempt(
                client_ip, request.token, request.node_id,
                request.hostname, True, None, request.metadata
            )

            logger.info(f"Node {request.node_id} registered successfully from {client_ip}")

            return {
                "node_id": result["node_id"],
                "registration_status": "success",
                "heartbeat_interval": 30,  # seconds
                "registered_at": result["first_registered"].isoformat(),
                "cluster_info": {
                    "coordinator_version": "1.0.0",
                    "features": ["heartbeat", "dynamic_scaling", "service_discovery"]
                }
            }

        except Exception as e:
            if not isinstance(e, ValueError):
                # ValueError (bad token) was already logged above.
                logger.error(f"Failed to register node {request.node_id}: {e}")
                await self._log_registration_attempt(
                    client_ip, request.token, request.node_id,
                    request.hostname, False, str(e), request.metadata
                )
            raise

    async def update_heartbeat(self, request: HeartbeatRequest) -> Dict[str, Any]:
        """Update node heartbeat and record its metrics.

        Raises:
            ValueError: if the node_id is not registered.
        """
        conn = await self.get_connection()

        try:
            # Update node status and heartbeat timestamp.
            result = await conn.fetchrow("""
                UPDATE cluster_nodes
                SET status = $2, last_heartbeat = NOW()
                WHERE node_id = $1
                RETURNING node_id, status, last_heartbeat
            """, request.node_id, request.status)

            if not result:
                raise ValueError(f"Node {request.node_id} not found")

            # Record heartbeat metrics for history/monitoring.
            await conn.execute("""
                INSERT INTO node_heartbeats (
                    node_id, cpu_usage, memory_usage, disk_usage, gpu_usage,
                    services_status, network_metrics, custom_metrics
                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
            """,
            request.node_id, request.cpu_usage, request.memory_usage,
            request.disk_usage, request.gpu_usage,
            json.dumps(request.services_status) if request.services_status else None,
            json.dumps(request.network_metrics) if request.network_metrics else None,
            json.dumps(request.custom_metrics) if request.custom_metrics else None
            )

            return {
                "node_id": result["node_id"],
                "status": result["status"],
                "heartbeat_received": result["last_heartbeat"].isoformat(),
                "next_heartbeat_in": 30,  # seconds
                "commands": []  # Future: cluster management commands
            }

        except Exception as e:
            logger.error(f"Failed to update heartbeat for {request.node_id}: {e}")
            raise

    def _row_to_node(self, row) -> ClusterNode:
        """Convert a cluster_nodes row to a ClusterNode, decoding JSON columns."""
        node_dict = dict(row)
        for json_field in self._NODE_JSON_FIELDS:
            if node_dict[json_field]:
                try:
                    node_dict[json_field] = json.loads(node_dict[json_field])
                except json.JSONDecodeError:
                    # Corrupt JSON in the DB: surface as missing, don't crash.
                    node_dict[json_field] = None
        return ClusterNode(**node_dict)

    async def get_registered_nodes(self, include_offline: bool = True) -> List[ClusterNode]:
        """Get all registered cluster nodes, newest registration first.

        Args:
            include_offline: when False, rows with status 'offline' are excluded.
        """
        conn = await self.get_connection()

        try:
            query = """
                SELECT id, node_id, hostname, ip_address, registration_token,
                       cpu_info, memory_info, gpu_info, disk_info, os_info, platform_info,
                       status, last_heartbeat, first_registered,
                       services, capabilities, ports, client_version, registration_metadata
                FROM cluster_nodes
            """

            if not include_offline:
                query += " WHERE status != 'offline'"

            query += " ORDER BY first_registered DESC"

            results = await conn.fetch(query)

            return [self._row_to_node(result) for result in results]

        except Exception as e:
            logger.error(f"Failed to get registered nodes: {e}")
            raise

    async def get_node_details(self, node_id: str) -> Optional[ClusterNode]:
        """Get detailed information about a specific node.

        Queries the single row directly instead of loading every node.
        """
        conn = await self.get_connection()

        try:
            row = await conn.fetchrow("""
                SELECT id, node_id, hostname, ip_address, registration_token,
                       cpu_info, memory_info, gpu_info, disk_info, os_info, platform_info,
                       status, last_heartbeat, first_registered,
                       services, capabilities, ports, client_version, registration_metadata
                FROM cluster_nodes
                WHERE node_id = $1
            """, node_id)

            return self._row_to_node(row) if row else None

        except Exception as e:
            logger.error(f"Failed to get node details for {node_id}: {e}")
            raise

    async def remove_node(self, node_id: str) -> bool:
        """Remove a node from the cluster. Returns True if a row was deleted."""
        conn = await self.get_connection()

        try:
            result = await conn.execute("""
                DELETE FROM cluster_nodes WHERE node_id = $1
            """, node_id)

            if result != "DELETE 0":
                logger.info(f"Node {node_id} removed from cluster")
                return True
            return False

        except Exception as e:
            logger.error(f"Failed to remove node {node_id}: {e}")
            return False

    # Maintenance and Monitoring
    async def cleanup_offline_nodes(self, offline_threshold_minutes: int = 10) -> int:
        """Mark nodes as offline if they haven't sent heartbeats.

        Returns the number of nodes transitioned to 'offline'.
        """
        conn = await self.get_connection()

        try:
            # Interval passed as a bind parameter via make_interval() —
            # never format values into SQL text.
            result = await conn.execute("""
                UPDATE cluster_nodes
                SET status = 'offline'
                WHERE status = 'online'
                AND last_heartbeat < NOW() - make_interval(mins => $1)
            """, offline_threshold_minutes)

            count = self._affected_rows(result)
            if count > 0:
                logger.info(f"Marked {count} nodes as offline due to missing heartbeats")

            return count

        except Exception as e:
            logger.error(f"Failed to cleanup offline nodes: {e}")
            return 0

    async def cleanup_old_heartbeats(self, retention_days: int = 30) -> int:
        """Remove old heartbeat data for storage management.

        Returns the number of heartbeat rows deleted.
        """
        conn = await self.get_connection()

        try:
            # Parameterized retention window (no string-built SQL).
            result = await conn.execute("""
                DELETE FROM node_heartbeats
                WHERE heartbeat_time < NOW() - make_interval(days => $1)
            """, retention_days)

            count = self._affected_rows(result)
            if count > 0:
                logger.info(f"Cleaned up {count} old heartbeat records")

            return count

        except Exception as e:
            logger.error(f"Failed to cleanup old heartbeats: {e}")
            return 0

    async def _log_registration_attempt(
        self,
        ip_address: str,
        token: str,
        node_id: str,
        hostname: str,
        success: bool,
        failure_reason: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Log registration attempts for security monitoring.

        Best-effort: logging failures are recorded but never raised, so an
        audit problem cannot block a registration.
        """
        conn = await self.get_connection()

        try:
            await conn.execute("""
                INSERT INTO node_registration_attempts (
                    ip_address, token_used, node_id, hostname,
                    success, failure_reason, request_metadata
                ) VALUES ($1, $2, $3, $4, $5, $6, $7)
            """, ip_address, token, node_id, hostname, success, failure_reason,
            json.dumps(metadata) if metadata else None)
        except Exception as e:
            logger.error(f"Failed to log registration attempt: {e}")

    async def get_cluster_statistics(self) -> Dict[str, Any]:
        """Get cluster statistics and health metrics.

        On error returns {"error": ..., "last_updated": ...} instead of raising,
        so dashboards keep rendering.
        """
        conn = await self.get_connection()

        try:
            # Node statistics.
            node_stats = await conn.fetchrow("""
                SELECT
                    COUNT(*) as total_nodes,
                    COUNT(*) FILTER (WHERE status = 'online') as online_nodes,
                    COUNT(*) FILTER (WHERE status = 'offline') as offline_nodes,
                    COUNT(*) FILTER (WHERE status = 'maintenance') as maintenance_nodes
                FROM cluster_nodes
            """)

            # Token statistics.
            token_stats = await conn.fetchrow("""
                SELECT
                    COUNT(*) as total_tokens,
                    COUNT(*) FILTER (WHERE is_active = true) as active_tokens,
                    COUNT(*) FILTER (WHERE expires_at IS NOT NULL AND expires_at < NOW()) as expired_tokens
                FROM cluster_tokens
            """)

            return {
                "cluster_health": {
                    "total_nodes": node_stats["total_nodes"],
                    "online_nodes": node_stats["online_nodes"],
                    "offline_nodes": node_stats["offline_nodes"],
                    "maintenance_nodes": node_stats["maintenance_nodes"],
                    # max(..., 1) guards divide-by-zero on an empty cluster.
                    "health_percentage": (node_stats["online_nodes"] / max(node_stats["total_nodes"], 1)) * 100
                },
                "token_management": {
                    "total_tokens": token_stats["total_tokens"],
                    "active_tokens": token_stats["active_tokens"],
                    "expired_tokens": token_stats["expired_tokens"]
                },
                "last_updated": datetime.now().isoformat()
            }

        except Exception as e:
            logger.error(f"Failed to get cluster statistics: {e}")
            return {
                "error": str(e),
                "last_updated": datetime.now().isoformat()
            }
|
||||
Reference in New Issue
Block a user