Files
hive/backend/app/services/agent_service.py
anthonyrawlins c6d69695a8 Refactor UnifiedCoordinator to follow Single Responsibility Principle
- Create dedicated service classes for separated concerns:
  * AgentService: Agent management and health monitoring
  * WorkflowService: Workflow parsing and execution tracking
  * PerformanceService: Metrics and load balancing
  * BackgroundService: Background processes and cleanup
  * TaskService: Database persistence (already existed)

- Refactor UnifiedCoordinator into UnifiedCoordinatorRefactored
  * Clean separation of responsibilities
  * Improved maintainability and testability
  * Dependency injection pattern for services
  * Clear service boundaries and interfaces

- Maintain backward compatibility through re-exports
- Update main.py to use refactored coordinator

🚀 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-11 09:09:11 +10:00

300 lines
11 KiB
Python

"""
Agent Management Service
Handles agent registration, health monitoring, and connectivity management.
"""
import asyncio
import aiohttp
import time
import logging
from typing import Dict, List, Optional, Set, Any
from dataclasses import dataclass, field
from sqlalchemy.orm import Session
from enum import Enum
from ..models.agent import Agent as ORMAgent
from ..core.database import SessionLocal
from ..cli_agents.cli_agent_manager import get_cli_agent_manager
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class AgentType(Enum):
    """Roles an agent can fill, covering original and distributed workflows.

    Members come in two groups: the original specialist roles and the task
    categories used by distributed workflows. Every member's value is the
    lowercase form of its name.
    """

    # --- Original specialist roles ---
    KERNEL_DEV = "kernel_dev"
    PYTORCH_DEV = "pytorch_dev"
    PROFILER = "profiler"
    DOCS_WRITER = "docs_writer"
    TESTER = "tester"
    CLI_GEMINI = "cli_gemini"
    GENERAL_AI = "general_ai"
    REASONING = "reasoning"

    # --- Distributed workflow task categories ---
    CODE_GENERATION = "code_generation"
    CODE_REVIEW = "code_review"
    TESTING = "testing"
    COMPILATION = "compilation"
    OPTIMIZATION = "optimization"
    DOCUMENTATION = "documentation"
    DEPLOYMENT = "deployment"
@dataclass
class Agent:
    """In-memory record for one agent, whether Ollama-backed or CLI-backed.

    After construction, the string value of every entry in
    ``specializations`` is merged into ``capabilities``.
    """

    # Identity and routing
    id: str
    endpoint: str
    model: str
    specialty: AgentType

    # Load bookkeeping
    max_concurrent: int = 2
    current_tasks: int = 0

    # Backend selection
    agent_type: str = "ollama"  # "ollama" or "cli"
    cli_config: Optional[Dict[str, Any]] = None

    # Extra fields used by distributed workflows
    gpu_type: str = "unknown"
    capabilities: Set[str] = field(default_factory=set)
    performance_history: List[float] = field(default_factory=list)
    specializations: List[AgentType] = field(default_factory=list)
    last_heartbeat: float = field(default_factory=time.time)

    def __post_init__(self):
        """Fold the values of all specializations into ``capabilities``."""
        if not self.specializations:
            return
        # Collect every value first so capabilities is updated in one step.
        role_values = [role.value for role in self.specializations]
        self.capabilities.update(role_values)
class AgentService:
    """Service for managing agents in the Hive cluster.

    Holds an in-memory registry of ``Agent`` objects keyed by agent id and
    provides lookup, selection, task-count bookkeeping and health probing
    on top of that registry.
    """

    def __init__(self):
        # Registry of known agents, keyed by Agent.id.
        self.agents: Dict[str, Agent] = {}
        # Assigned in initialize(); stays None until then.
        self.cli_agent_manager = None
        self._initialized = False

    async def initialize(self):
        """Populate the registry and probe every agent once.

        Returns immediately if a previous call already completed.

        Raises:
            Exception: any error raised by a setup step is logged and then
                re-raised. Database load failures are logged inside
                ``_load_database_agents`` and do not reach this method.
        """
        if self._initialized:
            return
        try:
            # Initialize CLI agent manager
            self.cli_agent_manager = get_cli_agent_manager()
            # Load agents from database
            await self._load_database_agents()
            # Initialize predefined cluster agents
            self._initialize_cluster_agents()
            # Test initial connectivity
            await self._test_initial_connectivity()
            self._initialized = True
            logger.info("✅ Agent Service initialized successfully")
        except Exception as e:
            logger.error(f"❌ Failed to initialize agent service: {e}")
            raise

    # ------------------------------------------------------------------
    # Registry access
    # ------------------------------------------------------------------

    def add_agent(self, agent: Agent):
        """Register an agent, replacing any existing entry with the same id."""
        self.agents[agent.id] = agent
        logger.info(f"✅ Added agent: {agent.id} ({agent.specialty.value})")

    def get_agent(self, agent_id: str) -> Optional[Agent]:
        """Return the agent with the given id, or None if it is unknown."""
        return self.agents.get(agent_id)

    def get_all_agents(self) -> Dict[str, Agent]:
        """Return a shallow copy of the registry."""
        return self.agents.copy()

    @staticmethod
    def _matches_specialty(agent: Agent, specialty: AgentType) -> bool:
        """Tell whether the specialty is the agent's primary or a secondary one."""
        return agent.specialty == specialty or specialty in agent.specializations

    @staticmethod
    def _has_capacity(agent: Agent) -> bool:
        """Tell whether the agent is below its concurrent-task limit."""
        return agent.current_tasks < agent.max_concurrent

    def get_agents_by_specialty(self, specialty: AgentType) -> List[Agent]:
        """Return every agent whose primary or secondary specialty matches."""
        return [
            agent for agent in self.agents.values()
            if self._matches_specialty(agent, specialty)
        ]

    def get_available_agents(self, specialty: Optional[AgentType] = None) -> List[Agent]:
        """Return agents with spare capacity, optionally filtered by specialty."""
        available = [
            agent for agent in self.agents.values()
            if self._has_capacity(agent)
        ]
        if specialty:
            available = [
                agent for agent in available
                if self._matches_specialty(agent, specialty)
            ]
        return available

    def get_optimal_agent(self, specialty: AgentType, load_balancer=None) -> Optional[Agent]:
        """Pick the best agent with spare capacity for a task type.

        Agents matching ``specialty`` are preferred. If none has capacity,
        agents whose primary specialty is ``AgentType.GENERAL_AI`` are
        considered instead.

        Args:
            specialty: The task type to find an agent for.
            load_balancer: Optional object exposing ``get_weight(agent_id)``.
                When given, the candidate with the smallest weight wins;
                otherwise the candidate with the fewest current tasks wins.

        Returns:
            The chosen agent, or None when no candidate has capacity.
        """
        candidates = [
            agent for agent in self.agents.values()
            if self._matches_specialty(agent, specialty) and self._has_capacity(agent)
        ]
        if not candidates:
            # Fallback to general AI agents
            candidates = [
                agent for agent in self.agents.values()
                if agent.specialty == AgentType.GENERAL_AI and self._has_capacity(agent)
            ]
        if not candidates:
            return None
        if load_balancer:
            return min(candidates, key=lambda a: load_balancer.get_weight(a.id))
        # Least-loaded selection: fewest tasks currently in flight.
        return min(candidates, key=lambda a: a.current_tasks)

    # ------------------------------------------------------------------
    # Task-count and heartbeat bookkeeping
    # ------------------------------------------------------------------

    def increment_agent_tasks(self, agent_id: str):
        """Add one to the agent's task count; unknown ids are ignored."""
        agent = self.agents.get(agent_id)
        if agent is not None:
            agent.current_tasks += 1

    def decrement_agent_tasks(self, agent_id: str):
        """Subtract one from the agent's task count, never going below zero."""
        agent = self.agents.get(agent_id)
        if agent is not None:
            agent.current_tasks = max(0, agent.current_tasks - 1)

    def update_agent_heartbeat(self, agent_id: str):
        """Stamp the agent's heartbeat with the current time; unknown ids are ignored."""
        agent = self.agents.get(agent_id)
        if agent is not None:
            agent.last_heartbeat = time.time()

    # ------------------------------------------------------------------
    # Registry population
    # ------------------------------------------------------------------

    async def _load_database_agents(self):
        """Load agents from the database into the registry.

        Best effort: any failure is logged and swallowed so that startup
        can continue. A row whose specialty is not a valid ``AgentType``
        value is skipped with a warning instead of aborting the whole load.
        """
        try:
            db = SessionLocal()
            try:
                orm_agents = db.query(ORMAgent).all()
                loaded = 0
                for orm_agent in orm_agents:
                    try:
                        specialty = (
                            AgentType(orm_agent.specialty)
                            if orm_agent.specialty
                            else AgentType.GENERAL_AI
                        )
                    except ValueError:
                        logger.warning(
                            f"⚠️ Skipping agent {orm_agent.id}: "
                            f"unknown specialty {orm_agent.specialty!r}"
                        )
                        continue
                    agent = Agent(
                        id=orm_agent.id,
                        endpoint=orm_agent.endpoint,
                        model=orm_agent.model or "unknown",
                        specialty=specialty,
                        max_concurrent=orm_agent.max_concurrent,
                        current_tasks=orm_agent.current_tasks,
                        agent_type=orm_agent.agent_type,
                        cli_config=orm_agent.cli_config,
                    )
                    self.add_agent(agent)
                    loaded += 1
                logger.info(f"📊 Loaded {loaded} agents from database")
            finally:
                # Always release the session, even when the query or a row
                # conversion fails.
                db.close()
        except Exception as e:
            logger.error(f"❌ Failed to load agents from database: {e}")

    def _initialize_cluster_agents(self):
        """Register the predefined cluster agents not already in the registry."""
        cluster_agents = [
            Agent(
                id="walnut-codellama",
                endpoint="http://walnut.local:11434",
                model="codellama:34b",
                specialty=AgentType.KERNEL_DEV,
            ),
            Agent(
                id="oak-gemma",
                endpoint="http://oak.local:11434",
                model="gemma2:27b",
                specialty=AgentType.PYTORCH_DEV,
            ),
            Agent(
                id="ironwood-llama",
                endpoint="http://ironwood.local:11434",
                model="llama3.1:70b",
                specialty=AgentType.GENERAL_AI,
            ),
        ]
        for agent in cluster_agents:
            # Entries already present (e.g. loaded from the database) win.
            if agent.id not in self.agents:
                self.add_agent(agent)

    # ------------------------------------------------------------------
    # Connectivity and health
    # ------------------------------------------------------------------

    async def _ollama_status(self, agent: Agent, timeout_seconds: float) -> int:
        """Return the HTTP status of ``GET {agent.endpoint}/api/tags``.

        Exceptions from the request (including the timeout) propagate to
        the caller.
        """
        async with aiohttp.ClientSession() as session:
            async with session.get(
                f"{agent.endpoint}/api/tags",
                timeout=aiohttp.ClientTimeout(total=timeout_seconds),
            ) as response:
                return response.status

    async def _test_initial_connectivity(self):
        """Probe every registered agent once and log the outcome.

        Failures are logged as warnings and never raised.
        """
        logger.info("🔍 Testing agent connectivity...")
        # Iterate over a snapshot: the registry may change while awaiting.
        for agent in list(self.agents.values()):
            try:
                if agent.agent_type == "cli":
                    # Test CLI agent
                    if self.cli_agent_manager:
                        await self.cli_agent_manager.test_agent(agent.id)
                else:
                    # Test Ollama agent
                    status = await self._ollama_status(agent, 5)
                    if status == 200:
                        logger.info(f"✅ Agent {agent.id} is responsive")
                    else:
                        logger.warning(f"⚠️ Agent {agent.id} returned HTTP {status}")
            except Exception as e:
                logger.warning(f"⚠️ Agent {agent.id} is not responsive: {e}")

    async def check_agent_health(self, agent: Agent) -> bool:
        """Check the health of a single agent.

        CLI agents are delegated to the CLI agent manager's ``test_agent``
        and count as unhealthy when no manager is set. Other agents are
        healthy when their ``/api/tags`` endpoint answers HTTP 200 within
        10 seconds.

        Returns:
            The probe result; False when the probe raises.
        """
        try:
            if agent.agent_type == "cli":
                # CLI agent health check
                if self.cli_agent_manager:
                    return await self.cli_agent_manager.test_agent(agent.id)
                return False
            # Ollama agent health check
            return await self._ollama_status(agent, 10) == 200
        except Exception as e:
            logger.warning(f"⚠️ Agent {agent.id} health check error: {e}")
            return False

    async def health_monitor_cycle(self):
        """Run one health check over every agent in the registry.

        Healthy agents get their heartbeat refreshed; unhealthy ones are
        logged. Errors are logged and never raised.
        """
        try:
            # Iterate over a snapshot: agents registered while a check is
            # awaited must not break the iteration.
            for agent in list(self.agents.values()):
                is_healthy = await self.check_agent_health(agent)
                if is_healthy:
                    agent.last_heartbeat = time.time()
                else:
                    logger.warning(f"⚠️ Agent {agent.id} health check failed")
        except Exception as e:
            logger.error(f"❌ Health monitor cycle error: {e}")

    def get_agent_status(self) -> Dict[str, Dict]:
        """Return a status summary for every agent, keyed by agent id.

        ``utilization`` is ``current_tasks / max_concurrent``, or 0 when
        ``max_concurrent`` is not positive.
        """
        agent_status = {}
        for agent_id, agent in self.agents.items():
            agent_status[agent_id] = {
                "type": agent.agent_type,
                "model": agent.model,
                "specialty": agent.specialty.value,
                "current_tasks": agent.current_tasks,
                "max_concurrent": agent.max_concurrent,
                "last_heartbeat": agent.last_heartbeat,
                "utilization": (
                    agent.current_tasks / agent.max_concurrent
                    if agent.max_concurrent > 0
                    else 0
                ),
            }
        return agent_status