"""
|
|
Cluster Registration Service
|
|
Handles registration-based cluster management for Hive-Bzzz integration.
|
|
"""
|
|
import asyncpg
|
|
import secrets
|
|
import json
|
|
import socket
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Optional, Any
|
|
from dataclasses import dataclass
|
|
from ipaddress import IPv4Network, IPv6Network, ip_address
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)


@dataclass
class ClusterToken:
    id: int
    token: str
    description: str
    created_at: datetime
    expires_at: Optional[datetime]
    is_active: bool
    max_registrations: Optional[int]
    current_registrations: int
    allowed_ip_ranges: Optional[List[str]]


@dataclass
class ClusterNode:
    id: int
    node_id: str
    hostname: str
    ip_address: str
    registration_token: str
    cpu_info: Optional[Dict[str, Any]]
    memory_info: Optional[Dict[str, Any]]
    gpu_info: Optional[Dict[str, Any]]
    disk_info: Optional[Dict[str, Any]]
    os_info: Optional[Dict[str, Any]]
    platform_info: Optional[Dict[str, Any]]
    status: str
    last_heartbeat: datetime
    first_registered: datetime
    services: Optional[Dict[str, Any]]
    capabilities: Optional[Dict[str, Any]]
    ports: Optional[Dict[str, Any]]
    client_version: Optional[str]
    registration_metadata: Optional[Dict[str, Any]]


@dataclass
class RegistrationRequest:
    token: str
    node_id: str
    hostname: str
    ip_address: str
    system_info: Dict[str, Any]
    client_version: Optional[str] = None
    services: Optional[Dict[str, Any]] = None
    capabilities: Optional[Dict[str, Any]] = None
    ports: Optional[Dict[str, Any]] = None
    metadata: Optional[Dict[str, Any]] = None


@dataclass
class HeartbeatRequest:
    node_id: str
    status: str = "online"
    cpu_usage: Optional[float] = None
    memory_usage: Optional[float] = None
    disk_usage: Optional[float] = None
    gpu_usage: Optional[float] = None
    services_status: Optional[Dict[str, Any]] = None
    network_metrics: Optional[Dict[str, Any]] = None
    custom_metrics: Optional[Dict[str, Any]] = None
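

# Example (hypothetical values): a minimal registration payload as a client
# might construct it. The nested system_info keys ('cpu', 'memory', 'gpu',
# 'disk', 'os', 'platform') are the ones register_node() extracts below.
#
#   RegistrationRequest(
#       token="hive_cluster_...",
#       node_id="node-01",
#       hostname="worker-01.local",
#       ip_address="192.168.1.10",
#       system_info={"cpu": {"cores": 8}, "memory": {"total_gb": 32}},
#   )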


class ClusterRegistrationService:
    """Registration-based cluster management backed by PostgreSQL via asyncpg."""

    def __init__(self, database_url: str):
        self.database_url = database_url
        self._conn_cache = None

    async def get_connection(self) -> asyncpg.Connection:
        """Get database connection with caching."""
        # A single connection is cached and reused; see the pooling note
        # below for a concurrent-workload alternative.
        if not self._conn_cache or self._conn_cache.is_closed():
            try:
                self._conn_cache = await asyncpg.connect(self.database_url)
            except Exception as e:
                logger.error(f"Failed to connect to database: {e}")
                raise
        return self._conn_cache

    async def close_connection(self):
        """Close database connection."""
        if self._conn_cache and not self._conn_cache.is_closed():
            await self._conn_cache.close()
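
    # Pooling note (a sketch, not part of the current implementation): for
    # concurrent request handling, asyncpg's built-in pool would usually
    # replace the single cached connection, e.g.
    #
    #   self._pool = await asyncpg.create_pool(self.database_url)
    #   async with self._pool.acquire() as conn:
    #       await conn.fetchrow(...)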

    # Token Management
    async def generate_cluster_token(
        self,
        description: str,
        created_by_user_id: str,
        expires_in_days: Optional[int] = None,
        max_registrations: Optional[int] = None,
        allowed_ip_ranges: Optional[List[str]] = None
    ) -> ClusterToken:
        """Generate a new cluster registration token."""
        conn = await self.get_connection()

        # Generate secure token
        token = f"hive_cluster_{secrets.token_urlsafe(32)}"
        # Naive local time is used throughout this module; this assumes the
        # timestamp columns are likewise naive (TIMESTAMP WITHOUT TIME ZONE).
        expires_at = datetime.now() + timedelta(days=expires_in_days) if expires_in_days else None

        try:
            result = await conn.fetchrow("""
                INSERT INTO cluster_tokens (
                    token, description, created_by, expires_at,
                    max_registrations, allowed_ip_ranges
                ) VALUES ($1, $2, $3, $4, $5, $6)
                RETURNING id, token, description, created_at, expires_at,
                          is_active, max_registrations, current_registrations, allowed_ip_ranges
            """, token, description, created_by_user_id, expires_at, max_registrations, allowed_ip_ranges)

            return ClusterToken(**dict(result))
        except Exception as e:
            logger.error(f"Failed to generate cluster token: {e}")
            raise

    async def validate_token(self, token: str, client_ip: str) -> Optional[ClusterToken]:
        """Validate a cluster registration token."""
        conn = await self.get_connection()

        try:
            result = await conn.fetchrow("""
                SELECT id, token, description, created_at, expires_at,
                       is_active, max_registrations, current_registrations, allowed_ip_ranges
                FROM cluster_tokens
                WHERE token = $1 AND is_active = true
            """, token)

            if not result:
                return None

            cluster_token = ClusterToken(**dict(result))

            # Check expiration
            if cluster_token.expires_at and datetime.now() > cluster_token.expires_at:
                logger.warning(f"Token {token[:20]}... has expired")
                return None

            # Check registration limit
            if (cluster_token.max_registrations and
                    cluster_token.current_registrations >= cluster_token.max_registrations):
                logger.warning(f"Token {token[:20]}... has reached registration limit")
                return None

            # Check IP restrictions
            if cluster_token.allowed_ip_ranges:
                client_ip_obj = ip_address(client_ip)
                allowed = False
                for ip_range in cluster_token.allowed_ip_ranges:
                    try:
                        network = IPv4Network(ip_range, strict=False) if ':' not in ip_range else IPv6Network(ip_range, strict=False)
                        if client_ip_obj in network:
                            allowed = True
                            break
                    except Exception as e:
                        logger.warning(f"Invalid IP range {ip_range}: {e}")

                if not allowed:
                    logger.warning(f"IP {client_ip} not allowed for token {token[:20]}...")
                    return None

            return cluster_token

        except Exception as e:
            logger.error(f"Failed to validate token: {e}")
            return None
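
    # Example (hypothetical values): allowed_ip_ranges entries are CIDR
    # strings, IPv4 or IPv6, matched with strict=False above, e.g.
    #
    #   ["192.168.1.0/24", "10.0.0.0/8", "fd00::/8"]
    #
    # A token with no ranges configured accepts registration from any address.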

    async def list_tokens(self) -> List[ClusterToken]:
        """List all cluster tokens."""
        conn = await self.get_connection()

        try:
            results = await conn.fetch("""
                SELECT id, token, description, created_at, expires_at,
                       is_active, max_registrations, current_registrations, allowed_ip_ranges
                FROM cluster_tokens
                ORDER BY created_at DESC
            """)

            return [ClusterToken(**dict(result)) for result in results]

        except Exception as e:
            logger.error(f"Failed to list tokens: {e}")
            raise

    async def revoke_token(self, token: str) -> bool:
        """Revoke a cluster token."""
        conn = await self.get_connection()

        try:
            result = await conn.execute("""
                UPDATE cluster_tokens
                SET is_active = false
                WHERE token = $1
            """, token)

            # asyncpg returns a command status string such as "UPDATE 1"
            return result != "UPDATE 0"

        except Exception as e:
            logger.error(f"Failed to revoke token: {e}")
            return False

    # Node Registration
    async def register_node(self, request: RegistrationRequest, client_ip: str) -> Dict[str, Any]:
        """Register a new cluster node."""
        conn = await self.get_connection()

        try:
            # Validate token; a failed attempt is recorded once, by the
            # exception handler below
            token_info = await self.validate_token(request.token, client_ip)
            if not token_info:
                raise ValueError("Invalid or expired registration token")

            # Extract system info components
            system_info = request.system_info or {}
            cpu_info = system_info.get('cpu', {})
            memory_info = system_info.get('memory', {})
            gpu_info = system_info.get('gpu', {})
            disk_info = system_info.get('disk', {})
            os_info = system_info.get('os', {})
            platform_info = system_info.get('platform', {})

            # Register or update node: re-registration of a known node_id
            # refreshes its details rather than creating a duplicate row
            result = await conn.fetchrow("""
                INSERT INTO cluster_nodes (
                    node_id, hostname, ip_address, registration_token,
                    cpu_info, memory_info, gpu_info, disk_info, os_info, platform_info,
                    services, capabilities, ports, client_version, registration_metadata
                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
                ON CONFLICT (node_id) DO UPDATE SET
                    hostname = EXCLUDED.hostname,
                    ip_address = EXCLUDED.ip_address,
                    cpu_info = EXCLUDED.cpu_info,
                    memory_info = EXCLUDED.memory_info,
                    gpu_info = EXCLUDED.gpu_info,
                    disk_info = EXCLUDED.disk_info,
                    os_info = EXCLUDED.os_info,
                    platform_info = EXCLUDED.platform_info,
                    services = EXCLUDED.services,
                    capabilities = EXCLUDED.capabilities,
                    ports = EXCLUDED.ports,
                    client_version = EXCLUDED.client_version,
                    registration_metadata = EXCLUDED.registration_metadata,
                    status = 'online',
                    last_heartbeat = NOW()
                RETURNING id, node_id, hostname, ip_address, first_registered
            """,
                request.node_id, request.hostname, request.ip_address, request.token,
                json.dumps(cpu_info) if cpu_info else None,
                json.dumps(memory_info) if memory_info else None,
                json.dumps(gpu_info) if gpu_info else None,
                json.dumps(disk_info) if disk_info else None,
                json.dumps(os_info) if os_info else None,
                json.dumps(platform_info) if platform_info else None,
                json.dumps(request.services) if request.services else None,
                json.dumps(request.capabilities) if request.capabilities else None,
                json.dumps(request.ports) if request.ports else None,
                request.client_version,
                json.dumps(request.metadata) if request.metadata else None
            )

            # Record the successful attempt only after the upsert completes.
            # Note: cluster_tokens.current_registrations is assumed to be
            # maintained by the database (e.g. a trigger on cluster_nodes);
            # it is not incremented here.
            await self._log_registration_attempt(
                client_ip, request.token, request.node_id,
                request.hostname, True, None, request.metadata
            )

            logger.info(f"Node {request.node_id} registered successfully from {client_ip}")

            return {
                "node_id": result["node_id"],
                "registration_status": "success",
                "heartbeat_interval": 30,  # seconds
                "registered_at": result["first_registered"].isoformat(),
                "cluster_info": {
                    "coordinator_version": "1.0.0",
                    "features": ["heartbeat", "dynamic_scaling", "service_discovery"]
                }
            }

        except Exception as e:
            logger.error(f"Failed to register node {request.node_id}: {e}")
            await self._log_registration_attempt(
                client_ip, request.token, request.node_id,
                request.hostname, False, str(e), request.metadata
            )
            raise

    async def update_heartbeat(self, request: HeartbeatRequest) -> Dict[str, Any]:
        """Update node heartbeat and metrics."""
        conn = await self.get_connection()

        try:
            # Update node status and heartbeat
            result = await conn.fetchrow("""
                UPDATE cluster_nodes
                SET status = $2, last_heartbeat = NOW()
                WHERE node_id = $1
                RETURNING node_id, status, last_heartbeat
            """, request.node_id, request.status)

            if not result:
                raise ValueError(f"Node {request.node_id} not found")

            # Record heartbeat metrics
            await conn.execute("""
                INSERT INTO node_heartbeats (
                    node_id, cpu_usage, memory_usage, disk_usage, gpu_usage,
                    services_status, network_metrics, custom_metrics
                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
            """,
                request.node_id, request.cpu_usage, request.memory_usage,
                request.disk_usage, request.gpu_usage,
                json.dumps(request.services_status) if request.services_status else None,
                json.dumps(request.network_metrics) if request.network_metrics else None,
                json.dumps(request.custom_metrics) if request.custom_metrics else None
            )

            return {
                "node_id": result["node_id"],
                "status": result["status"],
                "heartbeat_received": result["last_heartbeat"].isoformat(),
                "next_heartbeat_in": 30,  # seconds
                "commands": []  # Future: cluster management commands
            }

        except Exception as e:
            logger.error(f"Failed to update heartbeat for {request.node_id}: {e}")
            raise
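
    # Client-side sketch (hypothetical): a registered node would send a
    # heartbeat on the interval advertised at registration, e.g.
    #
    #   while True:
    #       await service.update_heartbeat(
    #           HeartbeatRequest(node_id="node-01", cpu_usage=12.5)
    #       )
    #       await asyncio.sleep(30)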

    async def get_registered_nodes(self, include_offline: bool = True) -> List[ClusterNode]:
        """Get all registered cluster nodes."""
        conn = await self.get_connection()

        try:
            query = """
                SELECT id, node_id, hostname, ip_address, registration_token,
                       cpu_info, memory_info, gpu_info, disk_info, os_info, platform_info,
                       status, last_heartbeat, first_registered,
                       services, capabilities, ports, client_version, registration_metadata
                FROM cluster_nodes
            """

            if not include_offline:
                query += " WHERE status != 'offline'"

            query += " ORDER BY first_registered DESC"

            results = await conn.fetch(query)

            nodes = []
            for result in results:
                node_dict = dict(result)
                # Parse JSON fields (asyncpg returns json/jsonb columns as
                # strings unless a custom codec is registered; skip values
                # that are NULL or already decoded)
                for json_field in ['cpu_info', 'memory_info', 'gpu_info', 'disk_info',
                                   'os_info', 'platform_info', 'services', 'capabilities',
                                   'ports', 'registration_metadata']:
                    if isinstance(node_dict[json_field], str):
                        try:
                            node_dict[json_field] = json.loads(node_dict[json_field])
                        except json.JSONDecodeError:
                            node_dict[json_field] = None

                nodes.append(ClusterNode(**node_dict))

            return nodes

        except Exception as e:
            logger.error(f"Failed to get registered nodes: {e}")
            raise

    async def get_node_details(self, node_id: str) -> Optional[ClusterNode]:
        """Get detailed information about a specific node."""
        # Fetches the full node list and filters in Python; fine for small
        # clusters, though a targeted WHERE node_id = $1 query would scale better
        nodes = await self.get_registered_nodes()
        return next((node for node in nodes if node.node_id == node_id), None)

    async def remove_node(self, node_id: str) -> bool:
        """Remove a node from the cluster."""
        conn = await self.get_connection()

        try:
            result = await conn.execute("""
                DELETE FROM cluster_nodes WHERE node_id = $1
            """, node_id)

            if result != "DELETE 0":
                logger.info(f"Node {node_id} removed from cluster")
                return True
            return False

        except Exception as e:
            logger.error(f"Failed to remove node {node_id}: {e}")
            return False

    # Maintenance and Monitoring
    async def cleanup_offline_nodes(self, offline_threshold_minutes: int = 10) -> int:
        """Mark nodes as offline if they haven't sent heartbeats."""
        conn = await self.get_connection()

        try:
            # Pass the threshold as a query parameter instead of interpolating
            # it into the SQL string; integer * INTERVAL is valid PostgreSQL
            result = await conn.execute("""
                UPDATE cluster_nodes
                SET status = 'offline'
                WHERE status = 'online'
                AND last_heartbeat < NOW() - ($1 * INTERVAL '1 minute')
            """, offline_threshold_minutes)

            # Extract number from result like "UPDATE 3"
            count = int(result.split()[-1]) if result.split()[-1].isdigit() else 0
            if count > 0:
                logger.info(f"Marked {count} nodes as offline due to missing heartbeats")

            return count

        except Exception as e:
            logger.error(f"Failed to cleanup offline nodes: {e}")
            return 0

    async def cleanup_old_heartbeats(self, retention_days: int = 30) -> int:
        """Remove old heartbeat data for storage management."""
        conn = await self.get_connection()

        try:
            result = await conn.execute("""
                DELETE FROM node_heartbeats
                WHERE heartbeat_time < NOW() - ($1 * INTERVAL '1 day')
            """, retention_days)

            count = int(result.split()[-1]) if result.split()[-1].isdigit() else 0
            if count > 0:
                logger.info(f"Cleaned up {count} old heartbeat records")

            return count

        except Exception as e:
            logger.error(f"Failed to cleanup old heartbeats: {e}")
            return 0
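
    # Scheduling sketch (hypothetical; nothing in this module wires it up):
    # the two cleanup methods lend themselves to a periodic background task,
    # e.g.
    #
    #   async def maintenance_loop(service: ClusterRegistrationService):
    #       while True:
    #           await service.cleanup_offline_nodes(offline_threshold_minutes=10)
    #           await service.cleanup_old_heartbeats(retention_days=30)
    #           await asyncio.sleep(300)  # every 5 minutes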

    async def _log_registration_attempt(
        self,
        ip_address: str,
        token: str,
        node_id: str,
        hostname: str,
        success: bool,
        failure_reason: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Log registration attempts for security monitoring."""
        conn = await self.get_connection()

        try:
            await conn.execute("""
                INSERT INTO node_registration_attempts (
                    ip_address, token_used, node_id, hostname,
                    success, failure_reason, request_metadata
                ) VALUES ($1, $2, $3, $4, $5, $6, $7)
            """, ip_address, token, node_id, hostname, success, failure_reason,
                json.dumps(metadata) if metadata else None)
        except Exception as e:
            logger.error(f"Failed to log registration attempt: {e}")

    async def get_cluster_statistics(self) -> Dict[str, Any]:
        """Get cluster statistics and health metrics."""
        conn = await self.get_connection()

        try:
            # Node statistics
            node_stats = await conn.fetchrow("""
                SELECT
                    COUNT(*) as total_nodes,
                    COUNT(*) FILTER (WHERE status = 'online') as online_nodes,
                    COUNT(*) FILTER (WHERE status = 'offline') as offline_nodes,
                    COUNT(*) FILTER (WHERE status = 'maintenance') as maintenance_nodes
                FROM cluster_nodes
            """)

            # Token statistics
            token_stats = await conn.fetchrow("""
                SELECT
                    COUNT(*) as total_tokens,
                    COUNT(*) FILTER (WHERE is_active = true) as active_tokens,
                    COUNT(*) FILTER (WHERE expires_at IS NOT NULL AND expires_at < NOW()) as expired_tokens
                FROM cluster_tokens
            """)

            return {
                "cluster_health": {
                    "total_nodes": node_stats["total_nodes"],
                    "online_nodes": node_stats["online_nodes"],
                    "offline_nodes": node_stats["offline_nodes"],
                    "maintenance_nodes": node_stats["maintenance_nodes"],
                    "health_percentage": (node_stats["online_nodes"] / max(node_stats["total_nodes"], 1)) * 100
                },
                "token_management": {
                    "total_tokens": token_stats["total_tokens"],
                    "active_tokens": token_stats["active_tokens"],
                    "expired_tokens": token_stats["expired_tokens"]
                },
                "last_updated": datetime.now().isoformat()
            }

        except Exception as e:
            logger.error(f"Failed to get cluster statistics: {e}")
            return {
                "error": str(e),
                "last_updated": datetime.now().isoformat()
            }
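

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions (not part of the module above): the
# connection string comes from a DATABASE_URL environment variable, the
# cluster_tokens/cluster_nodes schema already exists, and "admin" is a
# placeholder user id.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio
    import os

    async def _demo():
        service = ClusterRegistrationService(os.environ["DATABASE_URL"])
        try:
            # Issue a short-lived token limited to a private subnet
            token = await service.generate_cluster_token(
                description="demo token",
                created_by_user_id="admin",
                expires_in_days=7,
                max_registrations=5,
                allowed_ip_ranges=["192.168.1.0/24"],
            )
            print(f"token: {token.token}")

            # Inspect cluster health
            print(await service.get_cluster_statistics())
        finally:
            await service.close_connection()

    asyncio.run(_demo())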