WIP: Save current work before CHORUS rebrand
- Agent roles integration progress
- Various backend and frontend updates
- Storybook cache cleanup

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
@@ -15,6 +15,8 @@ Key Features:
from fastapi import APIRouter, HTTPException, Request, Depends, status
from typing import List, Dict, Any
import time
import logging
from ..models.agent import Agent
from ..models.responses import (
    AgentListResponse,
@@ -29,6 +31,9 @@ router = APIRouter()

from app.core.database import SessionLocal
from app.models.agent import Agent as ORMAgent
from ..services.agent_service import AgentType

logger = logging.getLogger(__name__)


@router.get(
@@ -384,4 +389,244 @@ async def unregister_agent(
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to unregister agent: {str(e)}"
        )


@router.post(
    "/agents/heartbeat",
    status_code=status.HTTP_200_OK,
    summary="Agent heartbeat update",
    description="""
    Update agent status and maintain registration through periodic heartbeat.

    This endpoint allows agents to:
    - Confirm they are still online and responsive
    - Update their current status and metrics
    - Report any capability or configuration changes
    - Maintain their registration in the cluster

    Agents should call this endpoint every 30-60 seconds to maintain
    their active status in the Hive cluster.
    """,
    responses={
        200: {"description": "Heartbeat received successfully"},
        404: {"model": ErrorResponse, "description": "Agent not registered"},
        400: {"model": ErrorResponse, "description": "Invalid heartbeat data"}
    }
)
async def agent_heartbeat(
    heartbeat_data: Dict[str, Any],
    request: Request
):
    """
    Process agent heartbeat to maintain registration.

    Args:
        heartbeat_data: Agent status and metrics data
        request: FastAPI request object

    Returns:
        Success confirmation and any coordinator updates
    """
    agent_id = heartbeat_data.get("agent_id")
    if not agent_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Missing agent_id in heartbeat data"
        )

    # Access coordinator
    hive_coordinator = getattr(request.app.state, 'hive_coordinator', None)
    if not hive_coordinator:
        from ..main import unified_coordinator
        hive_coordinator = unified_coordinator

    if not hive_coordinator:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Coordinator service unavailable"
        )

    try:
        # Update agent heartbeat timestamp
        agent_service = hive_coordinator.agent_service
        if agent_service:
            agent_service.update_agent_heartbeat(agent_id)

        # Update current tasks if provided - use raw SQL to avoid role column
        if "current_tasks" in heartbeat_data:
            current_tasks = heartbeat_data["current_tasks"]
            try:
                with SessionLocal() as db:
                    from sqlalchemy import text
                    db.execute(text(
                        "UPDATE agents SET current_tasks = :current_tasks, last_seen = NOW() WHERE id = :agent_id"
                    ), {
                        "current_tasks": current_tasks,
                        "agent_id": agent_id
                    })
                    db.commit()
            except Exception as e:
                logger.warning(f"Could not update agent tasks: {e}")

        return {
            "status": "success",
            "message": f"Heartbeat received from agent '{agent_id}'",
            "timestamp": time.time()
        }

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to process heartbeat: {str(e)}"
        )
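A minimal sketch of the agent-side loop this endpoint expects. The base URL and the "/api" mount prefix are assumptions; router mounting is not shown in this diff.

# heartbeat_client.py - illustrative sketch, not part of this commit
import asyncio
import httpx

HIVE_URL = "http://localhost:8000"  # assumed coordinator base URL

async def heartbeat_loop(agent_id: str, interval: int = 45):
    """Post a heartbeat every 30-60 seconds, per the endpoint docstring."""
    async with httpx.AsyncClient() as client:
        while True:
            try:
                await client.post(
                    f"{HIVE_URL}/api/agents/heartbeat",  # "/api" prefix assumed
                    json={"agent_id": agent_id, "current_tasks": 0},
                    timeout=5.0,
                )
            except httpx.HTTPError:
                pass  # coordinator may be briefly unavailable; retry next tick
            await asyncio.sleep(interval)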

@router.post(
    "/agents/auto-register",
    response_model=AgentRegistrationResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Automatic agent registration",
    description="""
    Register an agent automatically with capability detection.

    This endpoint is designed for Bzzz agents running as systemd services
    to automatically register themselves with the Hive coordinator.

    Features:
    - Automatic capability detection based on available models
    - Network discovery support
    - Retry-friendly for service startup scenarios
    - Health validation before registration
    """,
    responses={
        201: {"description": "Agent auto-registered successfully"},
        400: {"model": ErrorResponse, "description": "Invalid agent configuration"},
        409: {"model": ErrorResponse, "description": "Agent already registered"},
        503: {"model": ErrorResponse, "description": "Agent endpoint unreachable"}
    }
)
async def auto_register_agent(
    agent_data: Dict[str, Any],
    request: Request
) -> AgentRegistrationResponse:
    """
    Automatically register a Bzzz agent with the Hive coordinator.

    Args:
        agent_data: Agent configuration including endpoint, models, etc.
        request: FastAPI request object

    Returns:
        AgentRegistrationResponse: Registration confirmation
    """
    # Extract required fields
    agent_id = agent_data.get("agent_id")
    endpoint = agent_data.get("endpoint")
    hostname = agent_data.get("hostname")

    if not agent_id or not endpoint:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Missing required fields: agent_id, endpoint"
        )

    # Access coordinator
    hive_coordinator = getattr(request.app.state, 'hive_coordinator', None)
    if not hive_coordinator:
        from ..main import unified_coordinator
        hive_coordinator = unified_coordinator

    if not hive_coordinator:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Coordinator service unavailable"
        )

    try:
        # Check if agent already exists - use basic query to avoid role column
        try:
            with SessionLocal() as db:
                from sqlalchemy import text
                existing_agent = db.execute(text(
                    "SELECT id, endpoint FROM agents WHERE id = :agent_id LIMIT 1"
                ), {"agent_id": agent_id}).fetchone()
                if existing_agent:
                    # Update existing agent
                    db.execute(text(
                        "UPDATE agents SET endpoint = :endpoint, last_seen = NOW() WHERE id = :agent_id"
                    ), {"endpoint": endpoint, "agent_id": agent_id})
                    db.commit()

                    return AgentRegistrationResponse(
                        agent_id=agent_id,
                        endpoint=endpoint,
                        message=f"Agent '{agent_id}' registration updated successfully"
                    )
        except Exception as e:
            logger.warning(f"Could not check existing agent: {e}")

        # Detect capabilities and models
        models = agent_data.get("models", [])
        if not models:
            # Try to detect models from endpoint
            try:
                import aiohttp
                async with aiohttp.ClientSession() as session:
                    async with session.get(f"{endpoint}/api/tags", timeout=aiohttp.ClientTimeout(total=5)) as response:
                        if response.status == 200:
                            tags_data = await response.json()
                            models = [model["name"] for model in tags_data.get("models", [])]
            except Exception as e:
                logger.warning(f"Could not detect models for {agent_id}: {e}")

        # Determine specialty based on models or hostname
        # (hostname is optional, so guard against None before calling .lower())
        specialty = AgentType.GENERAL_AI  # Default
        if "codellama" in str(models).lower() or "code" in (hostname or "").lower():
            specialty = AgentType.KERNEL_DEV
        elif "gemma" in str(models).lower():
            specialty = AgentType.PYTORCH_DEV
        elif any("llama" in model.lower() for model in models):
            specialty = AgentType.GENERAL_AI

        # Insert agent directly into database
        try:
            with SessionLocal() as db:
                from sqlalchemy import text
                # Insert new agent using raw SQL to avoid role column issues
                db.execute(text("""
                    INSERT INTO agents (id, name, endpoint, model, specialty, max_concurrent, current_tasks, status, created_at, last_seen)
                    VALUES (:agent_id, :name, :endpoint, :model, :specialty, :max_concurrent, 0, 'active', NOW(), NOW())
                    ON CONFLICT (id) DO UPDATE SET
                        endpoint = EXCLUDED.endpoint,
                        model = EXCLUDED.model,
                        specialty = EXCLUDED.specialty,
                        max_concurrent = EXCLUDED.max_concurrent,
                        last_seen = NOW()
                """), {
                    "agent_id": agent_id,
                    "name": agent_id,  # Use agent_id as name
                    "endpoint": endpoint,
                    "model": models[0] if models else "unknown",
                    "specialty": specialty.value,
                    "max_concurrent": agent_data.get("max_concurrent", 2)
                })
                db.commit()

            return AgentRegistrationResponse(
                agent_id=agent_id,
                endpoint=endpoint,
                message=f"Agent '{agent_id}' auto-registered successfully with specialty '{specialty.value}'"
            )
        except Exception as e:
            logger.error(f"Database insert failed: {e}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Failed to register agent in database: {str(e)}"
            )

    except HTTPException:
        # Re-raise intentional HTTP errors instead of wrapping them in a 500
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to auto-register agent: {str(e)}"
        )
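A registration payload sketch from the agent side. The base URL and "/api" prefix are assumptions, and the Ollama-style port is inferred only from the /api/tags probe above.

# auto_register.py - illustrative sketch, not part of this commit
import httpx

agent_payload = {
    "agent_id": "walnut-bzzz",
    "endpoint": "http://walnut.local:11434",  # assumed Ollama-style model server
    "hostname": "walnut",
    "models": [],            # left empty so the endpoint probes {endpoint}/api/tags itself
    "max_concurrent": 2,
}
resp = httpx.post("http://localhost:8000/api/agents/auto-register",
                  json=agent_payload, timeout=10.0)
print(resp.status_code, resp.json())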
backend/app/api/bzzz_logs.py (new file, 287 lines)
@@ -0,0 +1,287 @@
"""
Bzzz hypercore/hyperswarm log streaming API endpoints.
Provides real-time access to agent communication logs from the Bzzz network.
"""

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, HTTPException, Query
from fastapi.responses import StreamingResponse
from typing import List, Optional, Dict, Any
import asyncio
import json
import logging
import httpx
import time
from datetime import datetime, timedelta

router = APIRouter()
logger = logging.getLogger(__name__)

# Keep track of active WebSocket connections
active_connections: List[WebSocket] = []


class BzzzLogEntry:
    """Represents a Bzzz hypercore log entry"""
    def __init__(self, data: Dict[str, Any]):
        self.index = data.get("index", 0)
        self.timestamp = data.get("timestamp", "")
        self.author = data.get("author", "")
        self.log_type = data.get("type", "")
        self.message_data = data.get("data", {})
        self.hash_value = data.get("hash", "")
        self.prev_hash = data.get("prev_hash", "")

    def to_chat_message(self) -> Dict[str, Any]:
        """Convert hypercore log entry to chat message format"""
        # Extract message details from the log data
        msg_data = self.message_data

        return {
            "id": f"log-{self.index}",
            "senderId": msg_data.get("from_short", self.author),
            "senderName": msg_data.get("from_short", self.author),
            "content": self._format_message_content(),
            "timestamp": self.timestamp,
            "messageType": self._determine_message_type(),
            "channel": msg_data.get("topic", "unknown"),
            "swarmId": f"swarm-{msg_data.get('topic', 'unknown')}",
            "isDelivered": True,
            "isRead": True,
            "logType": self.log_type,
            "hash": self.hash_value
        }

    def _format_message_content(self) -> str:
        """Format the log entry into a readable message"""
        msg_data = self.message_data
        message_type = msg_data.get("message_type", self.log_type)

        if message_type == "availability_broadcast":
            status = msg_data.get("data", {}).get("status", "unknown")
            current_tasks = msg_data.get("data", {}).get("current_tasks", 0)
            max_tasks = msg_data.get("data", {}).get("max_tasks", 0)
            return f"Status: {status} ({current_tasks}/{max_tasks} tasks)"

        elif message_type == "capability_broadcast":
            capabilities = msg_data.get("data", {}).get("capabilities", [])
            models = msg_data.get("data", {}).get("models", [])
            return f"Updated capabilities: {', '.join(capabilities[:3])}{'...' if len(capabilities) > 3 else ''}"

        elif message_type == "task_announced":
            task_data = msg_data.get("data", {})
            return f"Task announced: {task_data.get('title', 'Unknown task')}"

        elif message_type == "task_claimed":
            task_data = msg_data.get("data", {})
            return f"Task claimed: {task_data.get('title', 'Unknown task')}"

        elif message_type == "role_announcement":
            role = msg_data.get("data", {}).get("role", "unknown")
            return f"Role announcement: {role}"

        elif message_type == "collaboration":
            return f"Collaboration: {msg_data.get('data', {}).get('content', 'Agent discussion')}"

        elif self.log_type == "peer_joined":
            return "Agent joined the network"

        elif self.log_type == "peer_left":
            return "Agent left the network"

        else:
            # Generic fallback
            return f"{message_type}: {json.dumps(msg_data.get('data', {}))[:100]}{'...' if len(str(msg_data.get('data', {}))) > 100 else ''}"

    def _determine_message_type(self) -> str:
        """Determine if this is a sent, received, or system message"""
        msg_data = self.message_data

        # System messages
        if self.log_type in ["peer_joined", "peer_left", "network_event"]:
            return "system"

        # For now, treat all as received since we're monitoring
        # In a real implementation, you'd check if the author is the current node
        return "received"


class BzzzLogStreamer:
    """Manages streaming of Bzzz hypercore logs"""

    def __init__(self):
        self.agent_endpoints = {}
        self.last_indices = {}  # Track last seen index per agent

    async def discover_bzzz_agents(self) -> List[Dict[str, str]]:
        """Discover active Bzzz agents from the Hive agents API"""
        try:
            # This would typically query the actual agents database
            # For now, return known endpoints based on cluster nodes
            return [
                {"agent_id": "acacia-bzzz", "endpoint": "http://acacia.local:8080"},
                {"agent_id": "walnut-bzzz", "endpoint": "http://walnut.local:8080"},
                {"agent_id": "ironwood-bzzz", "endpoint": "http://ironwood.local:8080"},
                {"agent_id": "rosewood-bzzz", "endpoint": "http://rosewood.local:8080"},
            ]
        except Exception as e:
            logger.error(f"Failed to discover Bzzz agents: {e}")
            return []

    async def fetch_agent_logs(self, agent_endpoint: str, since_index: int = 0) -> List[BzzzLogEntry]:
        """Fetch hypercore logs from a specific Bzzz agent"""
        try:
            # This would call the actual Bzzz agent's HTTP API
            # For now, return mock data structure that matches hypercore format
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{agent_endpoint}/api/hypercore/logs",
                    params={"since": since_index},
                    timeout=5.0
                )

                if response.status_code == 200:
                    logs_data = response.json()
                    return [BzzzLogEntry(log) for log in logs_data.get("entries", [])]
                else:
                    logger.warning(f"Failed to fetch logs from {agent_endpoint}: {response.status_code}")
                    return []

        except httpx.ConnectError:
            logger.debug(f"Agent at {agent_endpoint} is not reachable")
            return []
        except Exception as e:
            logger.error(f"Error fetching logs from {agent_endpoint}: {e}")
            return []

    async def get_recent_logs(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Get recent logs from all agents"""
        agents = await self.discover_bzzz_agents()
        all_messages = []

        for agent in agents:
            logs = await self.fetch_agent_logs(agent["endpoint"])
            for log in logs[-limit:]:  # Get recent entries
                message = log.to_chat_message()
                message["agent_id"] = agent["agent_id"]
                all_messages.append(message)

        # Sort by timestamp
        all_messages.sort(key=lambda x: x["timestamp"])
        return all_messages[-limit:]

    async def stream_new_logs(self):
        """Continuously stream new logs from all agents"""
        while True:
            try:
                agents = await self.discover_bzzz_agents()
                new_messages = []

                for agent in agents:
                    agent_id = agent["agent_id"]
                    last_index = self.last_indices.get(agent_id, 0)

                    logs = await self.fetch_agent_logs(agent["endpoint"], last_index)

                    for log in logs:
                        if log.index > last_index:
                            message = log.to_chat_message()
                            message["agent_id"] = agent_id
                            new_messages.append(message)
                            self.last_indices[agent_id] = log.index

                # Send new messages to all connected WebSocket clients
                if new_messages and active_connections:
                    message_data = {
                        "type": "new_messages",
                        "messages": new_messages
                    }

                    # Remove disconnected clients
                    disconnected = []
                    for connection in active_connections:
                        try:
                            await connection.send_text(json.dumps(message_data))
                        except Exception:
                            disconnected.append(connection)

                    for conn in disconnected:
                        active_connections.remove(conn)

                await asyncio.sleep(2)  # Poll every 2 seconds

            except Exception as e:
                logger.error(f"Error in log streaming: {e}")
                await asyncio.sleep(5)


# Global log streamer instance
log_streamer = BzzzLogStreamer()


@router.get("/bzzz/logs")
async def get_bzzz_logs(
    limit: int = Query(default=100, le=1000),
    agent_id: Optional[str] = None
):
    """Get recent Bzzz hypercore logs"""
    try:
        logs = await log_streamer.get_recent_logs(limit)

        if agent_id:
            logs = [log for log in logs if log.get("agent_id") == agent_id]

        return {
            "logs": logs,
            "count": len(logs),
            "timestamp": datetime.utcnow().isoformat()
        }
    except Exception as e:
        logger.error(f"Error fetching Bzzz logs: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/bzzz/agents")
async def get_bzzz_agents():
    """Get list of discovered Bzzz agents"""
    try:
        agents = await log_streamer.discover_bzzz_agents()
        return {"agents": agents}
    except Exception as e:
        logger.error(f"Error discovering Bzzz agents: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.websocket("/bzzz/logs/stream")
async def websocket_bzzz_logs(websocket: WebSocket):
    """WebSocket endpoint for real-time Bzzz log streaming"""
    await websocket.accept()
    active_connections.append(websocket)

    try:
        # Send initial recent logs
        recent_logs = await log_streamer.get_recent_logs(50)
        await websocket.send_text(json.dumps({
            "type": "initial_logs",
            "messages": recent_logs
        }))

        # Keep connection alive and handle client messages
        while True:
            try:
                # Wait for client messages (ping, filters, etc.)
                message = await asyncio.wait_for(websocket.receive_text(), timeout=30)
                client_data = json.loads(message)

                if client_data.get("type") == "ping":
                    await websocket.send_text(json.dumps({"type": "pong"}))

            except asyncio.TimeoutError:
                # Send periodic heartbeat
                await websocket.send_text(json.dumps({"type": "heartbeat"}))

    except WebSocketDisconnect:
        active_connections.remove(websocket)
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
        if websocket in active_connections:
            active_connections.remove(websocket)


# Start the log streaming background task
@router.on_event("startup")
async def start_log_streaming():
    """Start the background log streaming task"""
    asyncio.create_task(log_streamer.stream_new_logs())
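A minimal consumer sketch for the stream above, assuming the backend is reachable at ws://localhost:8000 and using the third-party websockets package (both assumptions, not part of this commit). It follows the message protocol the endpoint defines: initial_logs, new_messages, heartbeat, and ping/pong.

# stream_consumer.py - illustrative sketch only
import asyncio
import json
import websockets  # third-party: pip install websockets

async def consume(url: str = "ws://localhost:8000/bzzz/logs/stream"):
    # Connect, print the initial backlog, then follow new messages.
    async with websockets.connect(url) as ws:
        async for raw in ws:
            event = json.loads(raw)
            if event.get("type") in ("initial_logs", "new_messages"):
                for msg in event["messages"]:
                    print(f"[{msg['agent_id']}] {msg['content']}")
            elif event.get("type") == "heartbeat":
                await ws.send(json.dumps({"type": "ping"}))  # server replies with pong

asyncio.run(consume())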
backend/app/api/cluster_registration.py (new file, 434 lines)
@@ -0,0 +1,434 @@
"""
Cluster Registration API endpoints
Handles registration-based cluster management for Hive-Bzzz integration.
"""
from fastapi import APIRouter, HTTPException, Request, Depends
from pydantic import BaseModel, Field
from typing import Dict, Any, List, Optional
import logging
import os
from ..services.cluster_registration_service import (
    ClusterRegistrationService,
    RegistrationRequest,
    HeartbeatRequest
)

logger = logging.getLogger(__name__)

router = APIRouter()

# Initialize service
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://hive:hivepass@localhost:5432/hive")
cluster_registration_service = ClusterRegistrationService(DATABASE_URL)

# Pydantic models for API
class NodeRegistrationRequest(BaseModel):
    token: str = Field(..., description="Cluster registration token")
    node_id: str = Field(..., description="Unique node identifier")
    hostname: str = Field(..., description="Node hostname")
    system_info: Dict[str, Any] = Field(..., description="System hardware and OS information")
    client_version: Optional[str] = Field(None, description="Bzzz client version")
    services: Optional[Dict[str, Any]] = Field(None, description="Available services")
    capabilities: Optional[Dict[str, Any]] = Field(None, description="Node capabilities")
    ports: Optional[Dict[str, Any]] = Field(None, description="Service ports")
    metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")

class NodeHeartbeatRequest(BaseModel):
    node_id: str = Field(..., description="Node identifier")
    status: str = Field("online", description="Node status")
    cpu_usage: Optional[float] = Field(None, ge=0, le=100, description="CPU usage percentage")
    memory_usage: Optional[float] = Field(None, ge=0, le=100, description="Memory usage percentage")
    disk_usage: Optional[float] = Field(None, ge=0, le=100, description="Disk usage percentage")
    gpu_usage: Optional[float] = Field(None, ge=0, le=100, description="GPU usage percentage")
    services_status: Optional[Dict[str, Any]] = Field(None, description="Service status information")
    network_metrics: Optional[Dict[str, Any]] = Field(None, description="Network metrics")
    custom_metrics: Optional[Dict[str, Any]] = Field(None, description="Custom node metrics")

class TokenCreateRequest(BaseModel):
    description: str = Field(..., description="Token description")
    expires_in_days: Optional[int] = Field(None, gt=0, description="Token expiration in days")
    max_registrations: Optional[int] = Field(None, gt=0, description="Maximum number of registrations")
    allowed_ip_ranges: Optional[List[str]] = Field(None, description="Allowed IP CIDR ranges")

# Helper function to get client IP
def get_client_ip(request: Request) -> str:
    """Extract client IP address from request."""
    # Check for X-Forwarded-For header (proxy/load balancer)
    forwarded_for = request.headers.get("X-Forwarded-For")
    if forwarded_for:
        # Take the first IP in the chain (original client)
        return forwarded_for.split(",")[0].strip()

    # Check for X-Real-IP header (nginx)
    real_ip = request.headers.get("X-Real-IP")
    if real_ip:
        return real_ip.strip()

    # Fall back to direct connection IP
    return request.client.host if request.client else "unknown"

# Registration endpoints
@router.post("/cluster/register")
async def register_node(
    registration: NodeRegistrationRequest,
    request: Request
) -> Dict[str, Any]:
    """
    Register a new node in the cluster.

    This endpoint allows Bzzz clients to register themselves with the Hive coordinator
    using a valid cluster token. Similar to `docker swarm join`.
    """
    try:
        client_ip = get_client_ip(request)
        logger.info(f"Node registration attempt: {registration.node_id} from {client_ip}")

        # Convert to service request
        reg_request = RegistrationRequest(
            token=registration.token,
            node_id=registration.node_id,
            hostname=registration.hostname,
            ip_address=client_ip,
            system_info=registration.system_info,
            client_version=registration.client_version,
            services=registration.services,
            capabilities=registration.capabilities,
            ports=registration.ports,
            metadata=registration.metadata
        )

        result = await cluster_registration_service.register_node(reg_request, client_ip)
        logger.info(f"Node {registration.node_id} registered successfully")

        return result

    except ValueError as e:
        logger.warning(f"Registration failed for {registration.node_id}: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Registration error for {registration.node_id}: {e}")
        raise HTTPException(status_code=500, detail="Registration failed")
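A sketch of the join flow from the node side, in the spirit of `docker swarm join`. The URL and token value are placeholders, and platform.uname() stands in for whatever system probe the real Bzzz client performs.

# register_node.py - illustrative sketch only
import platform
import httpx

payload = {
    "token": "HIVE-EXAMPLE-TOKEN",   # placeholder; issue one via POST /cluster/tokens
    "node_id": "rosewood-bzzz",
    "hostname": platform.node(),
    "system_info": {"os": platform.system(), "arch": platform.machine()},
    "client_version": "0.1.0",       # hypothetical version string
}
resp = httpx.post("http://localhost:8000/cluster/register", json=payload, timeout=10.0)
resp.raise_for_status()
print(resp.json())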
@router.post("/cluster/heartbeat")
async def node_heartbeat(heartbeat: NodeHeartbeatRequest) -> Dict[str, Any]:
    """
    Update node heartbeat and status.

    Registered nodes should call this endpoint periodically (every 30 seconds)
    to maintain their registration and report current status/metrics.
    """
    try:
        heartbeat_request = HeartbeatRequest(
            node_id=heartbeat.node_id,
            status=heartbeat.status,
            cpu_usage=heartbeat.cpu_usage,
            memory_usage=heartbeat.memory_usage,
            disk_usage=heartbeat.disk_usage,
            gpu_usage=heartbeat.gpu_usage,
            services_status=heartbeat.services_status,
            network_metrics=heartbeat.network_metrics,
            custom_metrics=heartbeat.custom_metrics
        )

        result = await cluster_registration_service.update_heartbeat(heartbeat_request)
        return result

    except ValueError as e:
        logger.warning(f"Heartbeat failed for {heartbeat.node_id}: {e}")
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        logger.error(f"Heartbeat error for {heartbeat.node_id}: {e}")
        raise HTTPException(status_code=500, detail="Heartbeat update failed")
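A matching node-side heartbeat sketch; psutil is an assumed dependency used here only to fill the usage fields, and the base URL is a placeholder.

# node_heartbeat.py - illustrative sketch only
import httpx
import psutil  # third-party: pip install psutil

payload = {
    "node_id": "rosewood-bzzz",
    "status": "online",
    "cpu_usage": psutil.cpu_percent(interval=1),
    "memory_usage": psutil.virtual_memory().percent,
    "disk_usage": psutil.disk_usage("/").percent,
}
httpx.post("http://localhost:8000/cluster/heartbeat", json=payload, timeout=5.0)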
# Node management endpoints
@router.get("/cluster/nodes/registered")
async def get_registered_nodes(include_offline: bool = True) -> Dict[str, Any]:
    """
    Get all registered cluster nodes.

    Returns detailed information about all nodes that have registered
    with the cluster, including their hardware specs and current status.
    """
    try:
        nodes = await cluster_registration_service.get_registered_nodes(include_offline)

        # Convert to API response format
        nodes_data = []
        for node in nodes:
            # Convert dataclass to dict and handle datetime serialization
            node_dict = {
                "id": node.id,
                "node_id": node.node_id,
                "hostname": node.hostname,
                "ip_address": node.ip_address,
                "status": node.status,
                "hardware": {
                    "cpu": node.cpu_info or {},
                    "memory": node.memory_info or {},
                    "gpu": node.gpu_info or {},
                    "disk": node.disk_info or {},
                    "os": node.os_info or {},
                    "platform": node.platform_info or {}
                },
                "services": node.services or {},
                "capabilities": node.capabilities or {},
                "ports": node.ports or {},
                "client_version": node.client_version,
                "first_registered": node.first_registered.isoformat(),
                "last_heartbeat": node.last_heartbeat.isoformat(),
                "registration_metadata": node.registration_metadata or {}
            }
            nodes_data.append(node_dict)

        return {
            "nodes": nodes_data,
            "total_count": len(nodes_data),
            "online_count": len([n for n in nodes if n.status == "online"]),
            "offline_count": len([n for n in nodes if n.status == "offline"])
        }

    except Exception as e:
        logger.error(f"Failed to get registered nodes: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve registered nodes")

@router.get("/cluster/nodes/{node_id}")
async def get_node_details(node_id: str) -> Dict[str, Any]:
    """Get detailed information about a specific registered node."""
    try:
        node = await cluster_registration_service.get_node_details(node_id)
        if not node:
            raise HTTPException(status_code=404, detail="Node not found")

        return {
            "id": node.id,
            "node_id": node.node_id,
            "hostname": node.hostname,
            "ip_address": node.ip_address,
            "status": node.status,
            "hardware": {
                "cpu": node.cpu_info or {},
                "memory": node.memory_info or {},
                "gpu": node.gpu_info or {},
                "disk": node.disk_info or {},
                "os": node.os_info or {},
                "platform": node.platform_info or {}
            },
            "services": node.services or {},
            "capabilities": node.capabilities or {},
            "ports": node.ports or {},
            "client_version": node.client_version,
            "first_registered": node.first_registered.isoformat(),
            "last_heartbeat": node.last_heartbeat.isoformat(),
            "registration_metadata": node.registration_metadata or {}
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to get node details for {node_id}: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve node details")

@router.delete("/cluster/nodes/{node_id}")
async def remove_node(node_id: str) -> Dict[str, Any]:
    """
    Remove a node from the cluster.

    This will unregister the node and stop accepting its heartbeats.
    The node will need to re-register to rejoin the cluster.
    """
    try:
        success = await cluster_registration_service.remove_node(node_id)
        if not success:
            raise HTTPException(status_code=404, detail="Node not found")

        return {
            "node_id": node_id,
            "status": "removed",
            "message": "Node successfully removed from cluster"
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to remove node {node_id}: {e}")
        raise HTTPException(status_code=500, detail="Failed to remove node")

# Token management endpoints
@router.post("/cluster/tokens")
async def create_cluster_token(token_request: TokenCreateRequest) -> Dict[str, Any]:
    """
    Create a new cluster registration token.

    Tokens are used by Bzzz clients to authenticate and register with the cluster.
    Only administrators should have access to this endpoint.
    """
    try:
        # For now, use a default admin user ID
        # TODO: Extract from JWT token or session
        admin_user_id = "admin"  # This should come from authentication

        token = await cluster_registration_service.generate_cluster_token(
            description=token_request.description,
            created_by_user_id=admin_user_id,
            expires_in_days=token_request.expires_in_days,
            max_registrations=token_request.max_registrations,
            allowed_ip_ranges=token_request.allowed_ip_ranges
        )

        return {
            "id": token.id,
            "token": token.token,
            "description": token.description,
            "created_at": token.created_at.isoformat(),
            "expires_at": token.expires_at.isoformat() if token.expires_at else None,
            "is_active": token.is_active,
            "max_registrations": token.max_registrations,
            "current_registrations": token.current_registrations,
            "allowed_ip_ranges": token.allowed_ip_ranges
        }

    except Exception as e:
        logger.error(f"Failed to create cluster token: {e}")
        raise HTTPException(status_code=500, detail="Failed to create token")
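Token issuance from an admin shell, as a sketch. The base URL is assumed, and note that authentication is still a TODO in the endpoint itself.

# create_token.py - illustrative sketch only
import httpx

resp = httpx.post(
    "http://localhost:8000/cluster/tokens",
    json={"description": "lab nodes", "expires_in_days": 30, "max_registrations": 10},
    timeout=10.0,
)
token = resp.json()["token"]
print("give this to new nodes:", token)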
@router.get("/cluster/tokens")
async def list_cluster_tokens() -> Dict[str, Any]:
    """
    List all cluster registration tokens.

    Returns information about all tokens including their usage statistics.
    Only administrators should have access to this endpoint.
    """
    try:
        tokens = await cluster_registration_service.list_tokens()

        tokens_data = []
        for token in tokens:
            tokens_data.append({
                "id": token.id,
                "token": token.token[:20] + "..." if len(token.token) > 20 else token.token,  # Partial token for security
                "description": token.description,
                "created_at": token.created_at.isoformat(),
                "expires_at": token.expires_at.isoformat() if token.expires_at else None,
                "is_active": token.is_active,
                "max_registrations": token.max_registrations,
                "current_registrations": token.current_registrations,
                "allowed_ip_ranges": token.allowed_ip_ranges
            })

        return {
            "tokens": tokens_data,
            "total_count": len(tokens_data)
        }

    except Exception as e:
        logger.error(f"Failed to list cluster tokens: {e}")
        raise HTTPException(status_code=500, detail="Failed to list tokens")

@router.delete("/cluster/tokens/{token}")
async def revoke_cluster_token(token: str) -> Dict[str, Any]:
    """
    Revoke a cluster registration token.

    This will prevent new registrations using this token, but won't affect
    nodes that are already registered.
    """
    try:
        success = await cluster_registration_service.revoke_token(token)
        if not success:
            raise HTTPException(status_code=404, detail="Token not found")

        return {
            "token": token[:20] + "..." if len(token) > 20 else token,
            "status": "revoked",
            "message": "Token successfully revoked"
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to revoke token {token}: {e}")
        raise HTTPException(status_code=500, detail="Failed to revoke token")

# Cluster statistics and monitoring
@router.get("/cluster/statistics")
async def get_cluster_statistics() -> Dict[str, Any]:
    """
    Get cluster health and usage statistics.

    Returns information about node counts, token usage, and overall cluster health.
    """
    try:
        stats = await cluster_registration_service.get_cluster_statistics()
        return stats

    except Exception as e:
        logger.error(f"Failed to get cluster statistics: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve cluster statistics")

# Maintenance endpoints
@router.post("/cluster/maintenance/cleanup-offline")
async def cleanup_offline_nodes(offline_threshold_minutes: int = 10) -> Dict[str, Any]:
    """
    Mark nodes as offline if they haven't sent heartbeats recently.

    This maintenance endpoint should be called periodically to keep
    the cluster status accurate.
    """
    try:
        count = await cluster_registration_service.cleanup_offline_nodes(offline_threshold_minutes)
        return {
            "nodes_marked_offline": count,
            "threshold_minutes": offline_threshold_minutes,
            "message": f"Marked {count} nodes as offline"
        }

    except Exception as e:
        logger.error(f"Failed to cleanup offline nodes: {e}")
        raise HTTPException(status_code=500, detail="Failed to cleanup offline nodes")

@router.post("/cluster/maintenance/cleanup-heartbeats")
async def cleanup_old_heartbeats(retention_days: int = 30) -> Dict[str, Any]:
    """
    Remove old heartbeat data to manage database size.

    This maintenance endpoint should be called periodically to prevent
    the heartbeat table from growing too large.
    """
    try:
        count = await cluster_registration_service.cleanup_old_heartbeats(retention_days)
        return {
            "heartbeats_deleted": count,
            "retention_days": retention_days,
            "message": f"Deleted {count} old heartbeat records"
        }

    except Exception as e:
        logger.error(f"Failed to cleanup old heartbeats: {e}")
        raise HTTPException(status_code=500, detail="Failed to cleanup old heartbeats")
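Both cleanup endpoints are meant to run on a schedule; here is a minimal in-process scheduler sketch. The base URL and five-minute cadence are assumptions, and a cron job hitting the same URLs would work equally well.

# maintenance_loop.py - illustrative sketch only
import asyncio
import httpx

BASE = "http://localhost:8000"  # assumed mount point

async def maintenance_loop():
    async with httpx.AsyncClient() as client:
        while True:
            await client.post(f"{BASE}/cluster/maintenance/cleanup-offline")
            await client.post(f"{BASE}/cluster/maintenance/cleanup-heartbeats")
            await asyncio.sleep(300)  # every 5 minutes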
# Health check endpoint
@router.get("/cluster/health")
async def cluster_registration_health() -> Dict[str, Any]:
    """
    Health check for the cluster registration system.
    """
    try:
        # Test database connection
        stats = await cluster_registration_service.get_cluster_statistics()

        return {
            "status": "healthy",
            "database_connected": True,
            "cluster_health": stats.get("cluster_health", {}),
            "timestamp": stats.get("last_updated")
        }

    except Exception as e:
        logger.error(f"Cluster registration health check failed: {e}")
        return {
            "status": "unhealthy",
            "database_connected": False,
            "error": str(e),
            "timestamp": None
        }
backend/app/api/feedback.py (new file, 474 lines)
@@ -0,0 +1,474 @@
"""
Context Feedback API endpoints for RL Context Curator integration
"""

from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy.orm import Session
from typing import List, Optional, Dict, Any
from datetime import datetime, timedelta
from pydantic import BaseModel, Field

from ..core.database import get_db
from ..models.context_feedback import ContextFeedback, AgentPermissions, PromotionRuleHistory
from ..models.task import Task
from ..models.agent import Agent
from ..services.auth import get_current_user
from ..models.responses import StatusResponse

router = APIRouter(prefix="/api/feedback", tags=["Context Feedback"])


# Pydantic models for API
class ContextFeedbackRequest(BaseModel):
    """Request model for context feedback"""
    context_id: str = Field(..., description="HCFS context ID")
    feedback_type: str = Field(..., description="Type of feedback: upvote, downvote, forgetfulness, task_success, task_failure")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence in feedback")
    reason: Optional[str] = Field(None, description="Optional reason for feedback")
    usage_context: Optional[str] = Field(None, description="Context of usage")
    directory_scope: Optional[str] = Field(None, description="Directory where context was used")
    task_type: Optional[str] = Field(None, description="Type of task being performed")


class TaskOutcomeFeedbackRequest(BaseModel):
    """Request model for task outcome feedback"""
    task_id: str = Field(..., description="Task ID")
    outcome: str = Field(..., description="Task outcome: completed, failed, abandoned")
    completion_time: Optional[int] = Field(None, description="Time to complete in seconds")
    errors_encountered: int = Field(0, description="Number of errors during execution")
    follow_up_questions: int = Field(0, description="Number of follow-up questions")
    context_used: Optional[List[str]] = Field(None, description="Context IDs used in task")
    context_relevance_score: Optional[float] = Field(None, ge=0.0, le=1.0, description="Average relevance of used context")
    outcome_confidence: Optional[float] = Field(None, ge=0.0, le=1.0, description="Confidence in outcome classification")


class AgentPermissionsRequest(BaseModel):
    """Request model for agent permissions"""
    agent_id: str = Field(..., description="Agent ID")
    role: str = Field(..., description="Agent role")
    directory_patterns: List[str] = Field(..., description="Directory patterns for this role")
    task_types: List[str] = Field(..., description="Task types this agent can handle")
    context_weight: float = Field(1.0, ge=0.1, le=2.0, description="Weight for context relevance")


class ContextFeedbackResponse(BaseModel):
    """Response model for context feedback"""
    id: int
    context_id: str
    agent_id: str
    task_id: Optional[str]
    feedback_type: str
    role: str
    confidence: float
    reason: Optional[str]
    usage_context: Optional[str]
    directory_scope: Optional[str]
    task_type: Optional[str]
    timestamp: datetime


class FeedbackStatsResponse(BaseModel):
    """Response model for feedback statistics"""
    total_feedback: int
    feedback_by_type: Dict[str, int]
    feedback_by_role: Dict[str, int]
    average_confidence: float
    recent_feedback_count: int
    top_contexts: List[Dict[str, Any]]


@router.post("/context/{context_id}", response_model=StatusResponse)
async def submit_context_feedback(
    context_id: str,
    request: ContextFeedbackRequest,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: dict = Depends(get_current_user)
):
    """
    Submit feedback for a specific context
    """
    try:
        # Get agent information
        agent = db.query(Agent).filter(Agent.id == current_user.get("agent_id", "unknown")).first()
        if not agent:
            raise HTTPException(status_code=404, detail="Agent not found")

        # Validate feedback type
        valid_types = ["upvote", "downvote", "forgetfulness", "task_success", "task_failure"]
        if request.feedback_type not in valid_types:
            raise HTTPException(status_code=400, detail=f"Invalid feedback type. Must be one of: {valid_types}")

        # Create feedback record
        feedback = ContextFeedback(
            context_id=request.context_id,
            agent_id=agent.id,
            feedback_type=request.feedback_type,
            role=agent.role if agent.role else "general",
            confidence=request.confidence,
            reason=request.reason,
            usage_context=request.usage_context,
            directory_scope=request.directory_scope,
            task_type=request.task_type
        )

        db.add(feedback)
        db.commit()
        db.refresh(feedback)

        # Send feedback to RL Context Curator in background
        background_tasks.add_task(
            send_feedback_to_rl_curator,
            feedback.id,
            request.context_id,
            request.feedback_type,
            agent.id,
            agent.role if agent.role else "general",
            request.confidence
        )

        return StatusResponse(
            status="success",
            message="Context feedback submitted successfully",
            data={"feedback_id": feedback.id, "context_id": request.context_id}
        )

    except HTTPException:
        # Preserve the intended 404/400 instead of converting them to 500
        raise
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=500, detail=f"Failed to submit feedback: {str(e)}")
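A client-side sketch of submitting feedback. The base URL, context ID, and bearer token are placeholders; authentication comes from get_current_user, whose scheme isn't shown in this diff.

# submit_feedback.py - illustrative sketch only
import httpx

body = {
    "context_id": "ctx-123",   # hypothetical HCFS context ID
    "feedback_type": "upvote",
    "confidence": 0.9,
    "reason": "context matched the task",
    "task_type": "code_review",
}
resp = httpx.post(
    "http://localhost:8000/api/feedback/context/ctx-123",
    json=body,
    headers={"Authorization": "Bearer <token>"},  # placeholder credential
    timeout=10.0,
)
print(resp.json())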
@router.post("/task-outcome/{task_id}", response_model=StatusResponse)
async def submit_task_outcome_feedback(
    task_id: str,
    request: TaskOutcomeFeedbackRequest,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: dict = Depends(get_current_user)
):
    """
    Submit task outcome feedback for RL learning
    """
    try:
        # Get task
        task = db.query(Task).filter(Task.id == task_id).first()
        if not task:
            raise HTTPException(status_code=404, detail="Task not found")

        # Update task with outcome metrics
        task.task_outcome = request.outcome
        task.completion_time = request.completion_time
        task.errors_encountered = request.errors_encountered
        task.follow_up_questions = request.follow_up_questions
        task.context_relevance_score = request.context_relevance_score
        task.outcome_confidence = request.outcome_confidence
        task.feedback_collected = True

        if request.context_used:
            task.context_used = request.context_used

        if request.outcome in ["completed", "failed", "abandoned"] and not task.completed_at:
            task.completed_at = datetime.utcnow()

        # Calculate success rate
        if request.outcome == "completed":
            task.success_rate = 1.0 - (request.errors_encountered * 0.1)  # Simple calculation
            task.success_rate = max(0.0, min(1.0, task.success_rate))
        else:
            task.success_rate = 0.0

        db.commit()

        # Create feedback events for used contexts
        if request.context_used and task.assigned_agent_id:
            agent = db.query(Agent).filter(Agent.id == task.assigned_agent_id).first()
            if agent:
                feedback_type = "task_success" if request.outcome == "completed" else "task_failure"

                for context_id in request.context_used:
                    feedback = ContextFeedback(
                        context_id=context_id,
                        agent_id=agent.id,
                        task_id=task.id,
                        feedback_type=feedback_type,
                        role=agent.role if agent.role else "general",
                        confidence=request.outcome_confidence or 0.8,
                        reason=f"Task {request.outcome}",
                        usage_context=f"task_execution_{request.outcome}",
                        # TaskOutcomeFeedbackRequest defines no task_type field,
                        # so fall back to None instead of raising AttributeError
                        task_type=getattr(request, "task_type", None)
                    )
                    db.add(feedback)

                db.commit()

        return StatusResponse(
            status="success",
            message="Task outcome feedback submitted successfully",
            data={"task_id": task_id, "outcome": request.outcome}
        )

    except HTTPException:
        # Preserve the intended 404 instead of converting it to a 500
        raise
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=500, detail=f"Failed to submit task outcome: {str(e)}")
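And the matching task-outcome call, again with placeholder IDs, base URL, and token.

# report_outcome.py - illustrative sketch only
import httpx

outcome = {
    "task_id": "task-42",      # hypothetical task ID
    "outcome": "completed",
    "completion_time": 180,
    "errors_encountered": 1,
    "context_used": ["ctx-123", "ctx-456"],
    "outcome_confidence": 0.85,
}
httpx.post(
    "http://localhost:8000/api/feedback/task-outcome/task-42",
    json=outcome,
    headers={"Authorization": "Bearer <token>"},  # placeholder credential
    timeout=10.0,
)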
@router.get("/stats", response_model=FeedbackStatsResponse)
async def get_feedback_stats(
    days: int = 7,
    role: Optional[str] = None,
    db: Session = Depends(get_db),
    current_user: dict = Depends(get_current_user)
):
    """
    Get feedback statistics for analysis
    """
    try:
        # Base query
        query = db.query(ContextFeedback)

        # Filter by date range
        if days > 0:
            since_date = datetime.utcnow() - timedelta(days=days)
            query = query.filter(ContextFeedback.timestamp >= since_date)

        # Filter by role if specified
        if role:
            query = query.filter(ContextFeedback.role == role)

        feedback_records = query.all()

        # Calculate statistics
        total_feedback = len(feedback_records)

        feedback_by_type = {}
        feedback_by_role = {}
        confidence_values = []
        context_usage = {}

        for feedback in feedback_records:
            # Count by type
            feedback_by_type[feedback.feedback_type] = feedback_by_type.get(feedback.feedback_type, 0) + 1

            # Count by role
            feedback_by_role[feedback.role] = feedback_by_role.get(feedback.role, 0) + 1

            # Collect confidence values
            confidence_values.append(feedback.confidence)

            # Count context usage
            context_usage[feedback.context_id] = context_usage.get(feedback.context_id, 0) + 1

        # Calculate average confidence
        average_confidence = sum(confidence_values) / len(confidence_values) if confidence_values else 0.0

        # Get recent feedback count (last 24 hours)
        recent_since = datetime.utcnow() - timedelta(days=1)
        recent_count = db.query(ContextFeedback).filter(
            ContextFeedback.timestamp >= recent_since
        ).count()

        # Get top contexts by usage
        top_contexts = [
            {"context_id": ctx_id, "usage_count": count}
            for ctx_id, count in sorted(context_usage.items(), key=lambda x: x[1], reverse=True)[:10]
        ]

        return FeedbackStatsResponse(
            total_feedback=total_feedback,
            feedback_by_type=feedback_by_type,
            feedback_by_role=feedback_by_role,
            average_confidence=average_confidence,
            recent_feedback_count=recent_count,
            top_contexts=top_contexts
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get feedback stats: {str(e)}")


@router.get("/recent", response_model=List[ContextFeedbackResponse])
async def get_recent_feedback(
    limit: int = 50,
    feedback_type: Optional[str] = None,
    role: Optional[str] = None,
    db: Session = Depends(get_db),
    current_user: dict = Depends(get_current_user)
):
    """
    Get recent feedback events
    """
    try:
        query = db.query(ContextFeedback).order_by(ContextFeedback.timestamp.desc())

        if feedback_type:
            query = query.filter(ContextFeedback.feedback_type == feedback_type)

        if role:
            query = query.filter(ContextFeedback.role == role)

        feedback_records = query.limit(limit).all()

        return [
            ContextFeedbackResponse(
                id=fb.id,
                context_id=fb.context_id,
                agent_id=fb.agent_id,
                task_id=str(fb.task_id) if fb.task_id else None,
                feedback_type=fb.feedback_type,
                role=fb.role,
                confidence=fb.confidence,
                reason=fb.reason,
                usage_context=fb.usage_context,
                directory_scope=fb.directory_scope,
                task_type=fb.task_type,
                timestamp=fb.timestamp
            )
            for fb in feedback_records
        ]

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get recent feedback: {str(e)}")


@router.post("/agent-permissions", response_model=StatusResponse)
async def set_agent_permissions(
    request: AgentPermissionsRequest,
    db: Session = Depends(get_db),
    current_user: dict = Depends(get_current_user)
):
    """
    Set or update agent permissions for context filtering
    """
    try:
        # Check if permissions already exist
        existing = db.query(AgentPermissions).filter(
            AgentPermissions.agent_id == request.agent_id,
            AgentPermissions.role == request.role
        ).first()

        if existing:
            # Update existing permissions
            existing.directory_patterns = ",".join(request.directory_patterns)
            existing.task_types = ",".join(request.task_types)
            existing.context_weight = request.context_weight
            existing.updated_at = datetime.utcnow()
        else:
            # Create new permissions
            permissions = AgentPermissions(
                agent_id=request.agent_id,
                role=request.role,
                directory_patterns=",".join(request.directory_patterns),
                task_types=",".join(request.task_types),
                context_weight=request.context_weight
            )
            db.add(permissions)

        db.commit()

        return StatusResponse(
            status="success",
            message="Agent permissions updated successfully",
            data={"agent_id": request.agent_id, "role": request.role}
        )

    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=500, detail=f"Failed to set agent permissions: {str(e)}")


@router.get("/agent-permissions/{agent_id}")
async def get_agent_permissions(
    agent_id: str,
    db: Session = Depends(get_db),
    current_user: dict = Depends(get_current_user)
):
    """
    Get agent permissions for context filtering
    """
    try:
        permissions = db.query(AgentPermissions).filter(
            AgentPermissions.agent_id == agent_id,
            AgentPermissions.active == "true"
        ).all()

        return [
            {
                "id": perm.id,
                "agent_id": perm.agent_id,
                "role": perm.role,
                "directory_patterns": perm.directory_patterns.split(",") if perm.directory_patterns else [],
                "task_types": perm.task_types.split(",") if perm.task_types else [],
                "context_weight": perm.context_weight,
                "created_at": perm.created_at,
                "updated_at": perm.updated_at
            }
            for perm in permissions
        ]

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get agent permissions: {str(e)}")


async def send_feedback_to_rl_curator(
    feedback_id: int,
    context_id: str,
    feedback_type: str,
    agent_id: str,
    role: str,
    confidence: float
):
    """
    Background task to send feedback to RL Context Curator
    """
    try:
        import httpx
        import json
        from datetime import datetime

        # Prepare feedback event in Bzzz format
        feedback_event = {
            "bzzz_type": "feedback_event",
            "timestamp": datetime.utcnow().isoformat(),
            "origin": {
                "node_id": "hive",
                "agent_id": agent_id,
                "task_id": f"hive-feedback-{feedback_id}",
                "workspace": "hive://context-feedback",
                "directory": "/feedback/"
            },
            "feedback": {
                "type": feedback_type,
                "category": "general",  # Could be enhanced with category detection
                "role": role,
                "context_id": context_id,
                "reason": f"Feedback from Hive agent {agent_id}",
                "confidence": confidence,
                "usage_context": "hive_platform"
            },
            "task_outcome": {
                "completed": feedback_type in ["upvote", "task_success"],
                "completion_time": 0,
                "errors_encountered": 0,
                "follow_up_questions": 0
            }
        }

        # Send to HCFS RL Tuner Service
        async with httpx.AsyncClient() as client:
            try:
                response = await client.post(
                    "http://localhost:8001/api/feedback",
                    json=feedback_event,
                    timeout=10.0
                )
                if response.status_code == 200:
                    print(f"✅ Feedback sent to RL Curator: {feedback_id}")
                else:
                    print(f"⚠️ RL Curator responded with status {response.status_code}")
            except httpx.ConnectError:
                print(f"⚠️ Could not connect to RL Curator service (feedback {feedback_id})")
            except Exception as e:
                print(f"❌ Error sending feedback to RL Curator: {e}")

    except Exception as e:
        print(f"❌ Background feedback task failed: {e}")
@@ -47,6 +47,37 @@ async def get_project_tasks(project_id: str, current_user: Dict[str, Any] = Depe
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@router.put("/projects/{project_id}")
async def update_project(project_id: str, project_data: Dict[str, Any], current_user: Dict[str, Any] = Depends(get_current_user_context)) -> Dict[str, Any]:
    """Update a project configuration."""
    try:
        updated_project = project_service.update_project(project_id, project_data)
        if not updated_project:
            raise HTTPException(status_code=404, detail="Project not found")
        return updated_project
    except HTTPException:
        # Keep the 404 instead of re-wrapping it as a 500
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@router.post("/projects")
async def create_project(project_data: Dict[str, Any], current_user: Dict[str, Any] = Depends(get_current_user_context)) -> Dict[str, Any]:
    """Create a new project."""
    try:
        new_project = project_service.create_project(project_data)
        return new_project
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@router.delete("/projects/{project_id}")
async def delete_project(project_id: str, current_user: Dict[str, Any] = Depends(get_current_user_context)) -> Dict[str, Any]:
    """Delete a project."""
    try:
        result = project_service.delete_project(project_id)
        if not result:
            raise HTTPException(status_code=404, detail="Project not found")
        return {"success": True, "message": "Project deleted successfully"}
    except HTTPException:
        # Keep the 404 instead of re-wrapping it as a 500
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# === Bzzz Integration Endpoints ===

@bzzz_router.get("/active-repos")