Copy CCLI source to backend for Docker builds
🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
1
backend/ccli_src/src/__init__.py
Normal file
1
backend/ccli_src/src/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# CCLI Source Package
|
||||
1
backend/ccli_src/src/agents/__init__.py
Normal file
1
backend/ccli_src/src/agents/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# CLI Agents Package
|
||||
BIN
backend/ccli_src/src/agents/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
backend/ccli_src/src/agents/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
344
backend/ccli_src/src/agents/cli_agent_factory.py
Normal file
344
backend/ccli_src/src/agents/cli_agent_factory.py
Normal file
@@ -0,0 +1,344 @@
|
||||
"""
|
||||
CLI Agent Factory
|
||||
Creates and manages CLI-based agents with predefined configurations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agents.gemini_cli_agent import GeminiCliAgent, GeminiCliConfig
|
||||
|
||||
|
||||
class CliAgentType(Enum):
|
||||
"""Supported CLI agent types"""
|
||||
GEMINI = "gemini"
|
||||
|
||||
|
||||
class Specialization(Enum):
|
||||
"""Agent specializations"""
|
||||
GENERAL_AI = "general_ai"
|
||||
REASONING = "reasoning"
|
||||
CODE_ANALYSIS = "code_analysis"
|
||||
DOCUMENTATION = "documentation"
|
||||
TESTING = "testing"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CliAgentDefinition:
|
||||
"""Definition for a CLI agent instance"""
|
||||
agent_id: str
|
||||
agent_type: CliAgentType
|
||||
config: Dict[str, Any]
|
||||
specialization: Specialization
|
||||
description: str
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class CliAgentFactory:
|
||||
"""
|
||||
Factory for creating and managing CLI agents
|
||||
|
||||
Provides predefined configurations for known agent instances and
|
||||
supports dynamic agent creation with custom configurations.
|
||||
"""
|
||||
|
||||
# Predefined agent configurations based on verified environment testing
|
||||
PREDEFINED_AGENTS = {
|
||||
"walnut-gemini": CliAgentDefinition(
|
||||
agent_id="walnut-gemini",
|
||||
agent_type=CliAgentType.GEMINI,
|
||||
config={
|
||||
"host": "walnut",
|
||||
"node_version": "v22.14.0",
|
||||
"model": "gemini-2.5-pro",
|
||||
"max_concurrent": 2,
|
||||
"command_timeout": 60,
|
||||
"ssh_timeout": 5
|
||||
},
|
||||
specialization=Specialization.GENERAL_AI,
|
||||
description="Gemini CLI agent on WALNUT for general AI tasks",
|
||||
enabled=True
|
||||
),
|
||||
|
||||
"ironwood-gemini": CliAgentDefinition(
|
||||
agent_id="ironwood-gemini",
|
||||
agent_type=CliAgentType.GEMINI,
|
||||
config={
|
||||
"host": "ironwood",
|
||||
"node_version": "v22.17.0",
|
||||
"model": "gemini-2.5-pro",
|
||||
"max_concurrent": 2,
|
||||
"command_timeout": 60,
|
||||
"ssh_timeout": 5
|
||||
},
|
||||
specialization=Specialization.REASONING,
|
||||
description="Gemini CLI agent on IRONWOOD for reasoning tasks (faster)",
|
||||
enabled=True
|
||||
),
|
||||
|
||||
# Additional specialized configurations
|
||||
"walnut-gemini-code": CliAgentDefinition(
|
||||
agent_id="walnut-gemini-code",
|
||||
agent_type=CliAgentType.GEMINI,
|
||||
config={
|
||||
"host": "walnut",
|
||||
"node_version": "v22.14.0",
|
||||
"model": "gemini-2.5-pro",
|
||||
"max_concurrent": 1, # More conservative for code analysis
|
||||
"command_timeout": 90, # Longer timeout for complex code analysis
|
||||
"ssh_timeout": 5
|
||||
},
|
||||
specialization=Specialization.CODE_ANALYSIS,
|
||||
description="Gemini CLI agent specialized for code analysis tasks",
|
||||
enabled=False # Start disabled, enable when needed
|
||||
),
|
||||
|
||||
"ironwood-gemini-docs": CliAgentDefinition(
|
||||
agent_id="ironwood-gemini-docs",
|
||||
agent_type=CliAgentType.GEMINI,
|
||||
config={
|
||||
"host": "ironwood",
|
||||
"node_version": "v22.17.0",
|
||||
"model": "gemini-2.5-pro",
|
||||
"max_concurrent": 2,
|
||||
"command_timeout": 45,
|
||||
"ssh_timeout": 5
|
||||
},
|
||||
specialization=Specialization.DOCUMENTATION,
|
||||
description="Gemini CLI agent for documentation generation",
|
||||
enabled=False
|
||||
)
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.active_agents: Dict[str, GeminiCliAgent] = {}
|
||||
|
||||
@classmethod
|
||||
def get_predefined_agent_ids(cls) -> List[str]:
|
||||
"""Get list of all predefined agent IDs"""
|
||||
return list(cls.PREDEFINED_AGENTS.keys())
|
||||
|
||||
@classmethod
|
||||
def get_enabled_agent_ids(cls) -> List[str]:
|
||||
"""Get list of enabled predefined agent IDs"""
|
||||
return [
|
||||
agent_id for agent_id, definition in cls.PREDEFINED_AGENTS.items()
|
||||
if definition.enabled
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_agent_definition(cls, agent_id: str) -> Optional[CliAgentDefinition]:
|
||||
"""Get predefined agent definition by ID"""
|
||||
return cls.PREDEFINED_AGENTS.get(agent_id)
|
||||
|
||||
def create_agent(self, agent_id: str, custom_config: Optional[Dict[str, Any]] = None) -> GeminiCliAgent:
|
||||
"""
|
||||
Create a CLI agent instance
|
||||
|
||||
Args:
|
||||
agent_id: ID of predefined agent or custom ID
|
||||
custom_config: Optional custom configuration to override defaults
|
||||
|
||||
Returns:
|
||||
GeminiCliAgent instance
|
||||
|
||||
Raises:
|
||||
ValueError: If agent_id is unknown and no custom_config provided
|
||||
"""
|
||||
|
||||
# Check if agent already exists
|
||||
if agent_id in self.active_agents:
|
||||
self.logger.warning(f"Agent {agent_id} already exists, returning existing instance")
|
||||
return self.active_agents[agent_id]
|
||||
|
||||
# Get configuration
|
||||
if agent_id in self.PREDEFINED_AGENTS:
|
||||
definition = self.PREDEFINED_AGENTS[agent_id]
|
||||
|
||||
if not definition.enabled:
|
||||
self.logger.warning(f"Agent {agent_id} is disabled but being created anyway")
|
||||
|
||||
config_dict = definition.config.copy()
|
||||
specialization = definition.specialization.value
|
||||
|
||||
# Apply custom overrides
|
||||
if custom_config:
|
||||
config_dict.update(custom_config)
|
||||
|
||||
elif custom_config:
|
||||
# Custom agent configuration
|
||||
config_dict = custom_config
|
||||
specialization = custom_config.get("specialization", "general_ai")
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown agent ID '{agent_id}' and no custom configuration provided")
|
||||
|
||||
# Determine agent type and create appropriate agent
|
||||
agent_type = config_dict.get("agent_type", "gemini")
|
||||
|
||||
if agent_type == "gemini" or agent_type == CliAgentType.GEMINI:
|
||||
agent = self._create_gemini_agent(agent_id, config_dict, specialization)
|
||||
else:
|
||||
raise ValueError(f"Unsupported agent type: {agent_type}")
|
||||
|
||||
# Store in active agents
|
||||
self.active_agents[agent_id] = agent
|
||||
|
||||
self.logger.info(f"Created CLI agent: {agent_id} ({specialization})")
|
||||
return agent
|
||||
|
||||
def _create_gemini_agent(self, agent_id: str, config_dict: Dict[str, Any], specialization: str) -> GeminiCliAgent:
|
||||
"""Create a Gemini CLI agent with the given configuration"""
|
||||
|
||||
# Create GeminiCliConfig from dictionary
|
||||
config = GeminiCliConfig(
|
||||
host=config_dict["host"],
|
||||
node_version=config_dict["node_version"],
|
||||
model=config_dict.get("model", "gemini-2.5-pro"),
|
||||
max_concurrent=config_dict.get("max_concurrent", 2),
|
||||
command_timeout=config_dict.get("command_timeout", 60),
|
||||
ssh_timeout=config_dict.get("ssh_timeout", 5),
|
||||
node_path=config_dict.get("node_path"),
|
||||
gemini_path=config_dict.get("gemini_path")
|
||||
)
|
||||
|
||||
return GeminiCliAgent(config, specialization)
|
||||
|
||||
def get_agent(self, agent_id: str) -> Optional[GeminiCliAgent]:
|
||||
"""Get an existing agent instance"""
|
||||
return self.active_agents.get(agent_id)
|
||||
|
||||
def remove_agent(self, agent_id: str) -> bool:
|
||||
"""Remove an agent instance"""
|
||||
if agent_id in self.active_agents:
|
||||
agent = self.active_agents.pop(agent_id)
|
||||
# Note: Cleanup should be called by the caller if needed
|
||||
self.logger.info(f"Removed CLI agent: {agent_id}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_active_agents(self) -> Dict[str, GeminiCliAgent]:
|
||||
"""Get all active agent instances"""
|
||||
return self.active_agents.copy()
|
||||
|
||||
def get_agent_info(self, agent_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get information about an agent"""
|
||||
|
||||
# Check active agents
|
||||
if agent_id in self.active_agents:
|
||||
agent = self.active_agents[agent_id]
|
||||
return {
|
||||
"agent_id": agent_id,
|
||||
"status": "active",
|
||||
"host": agent.config.host,
|
||||
"model": agent.config.model,
|
||||
"specialization": agent.specialization,
|
||||
"active_tasks": len(agent.active_tasks),
|
||||
"max_concurrent": agent.config.max_concurrent,
|
||||
"statistics": agent.get_statistics()
|
||||
}
|
||||
|
||||
# Check predefined but not active
|
||||
if agent_id in self.PREDEFINED_AGENTS:
|
||||
definition = self.PREDEFINED_AGENTS[agent_id]
|
||||
return {
|
||||
"agent_id": agent_id,
|
||||
"status": "available" if definition.enabled else "disabled",
|
||||
"agent_type": definition.agent_type.value,
|
||||
"specialization": definition.specialization.value,
|
||||
"description": definition.description,
|
||||
"config": definition.config
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
def list_all_agents(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""List all agents (predefined and active)"""
|
||||
all_agents = {}
|
||||
|
||||
# Add predefined agents
|
||||
for agent_id in self.PREDEFINED_AGENTS:
|
||||
all_agents[agent_id] = self.get_agent_info(agent_id)
|
||||
|
||||
# Add any custom active agents not in predefined list
|
||||
for agent_id in self.active_agents:
|
||||
if agent_id not in all_agents:
|
||||
all_agents[agent_id] = self.get_agent_info(agent_id)
|
||||
|
||||
return all_agents
|
||||
|
||||
async def health_check_all(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Perform health checks on all active agents"""
|
||||
health_results = {}
|
||||
|
||||
for agent_id, agent in self.active_agents.items():
|
||||
try:
|
||||
health_results[agent_id] = await agent.health_check()
|
||||
except Exception as e:
|
||||
health_results[agent_id] = {
|
||||
"agent_id": agent_id,
|
||||
"error": str(e),
|
||||
"healthy": False
|
||||
}
|
||||
|
||||
return health_results
|
||||
|
||||
async def cleanup_all(self):
|
||||
"""Clean up all active agents"""
|
||||
for agent_id, agent in list(self.active_agents.items()):
|
||||
try:
|
||||
await agent.cleanup()
|
||||
self.logger.info(f"Cleaned up agent: {agent_id}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error cleaning up agent {agent_id}: {e}")
|
||||
|
||||
self.active_agents.clear()
|
||||
|
||||
@classmethod
|
||||
def create_custom_agent_config(cls, host: str, node_version: str,
|
||||
specialization: str = "general_ai",
|
||||
**kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Helper to create custom agent configuration
|
||||
|
||||
Args:
|
||||
host: Target host for SSH connection
|
||||
node_version: Node.js version (e.g., "v22.14.0")
|
||||
specialization: Agent specialization
|
||||
**kwargs: Additional configuration options
|
||||
|
||||
Returns:
|
||||
Configuration dictionary for create_agent()
|
||||
"""
|
||||
config = {
|
||||
"host": host,
|
||||
"node_version": node_version,
|
||||
"specialization": specialization,
|
||||
"agent_type": "gemini",
|
||||
"model": "gemini-2.5-pro",
|
||||
"max_concurrent": 2,
|
||||
"command_timeout": 60,
|
||||
"ssh_timeout": 5
|
||||
}
|
||||
|
||||
config.update(kwargs)
|
||||
return config
|
||||
|
||||
|
||||
# Module-level convenience functions
|
||||
_default_factory = None
|
||||
|
||||
def get_default_factory() -> CliAgentFactory:
|
||||
"""Get the default CLI agent factory instance"""
|
||||
global _default_factory
|
||||
if _default_factory is None:
|
||||
_default_factory = CliAgentFactory()
|
||||
return _default_factory
|
||||
|
||||
def create_agent(agent_id: str, custom_config: Optional[Dict[str, Any]] = None) -> GeminiCliAgent:
|
||||
"""Convenience function to create an agent using the default factory"""
|
||||
factory = get_default_factory()
|
||||
return factory.create_agent(agent_id, custom_config)
|
||||
369
backend/ccli_src/src/agents/gemini_cli_agent.py
Normal file
369
backend/ccli_src/src/agents/gemini_cli_agent.py
Normal file
@@ -0,0 +1,369 @@
|
||||
"""
|
||||
Gemini CLI Agent Adapter
|
||||
Provides a standardized interface for executing tasks on Gemini CLI via SSH.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
import hashlib
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import Dict, Any, Optional, List
|
||||
from enum import Enum
|
||||
|
||||
from executors.ssh_executor import SSHExecutor, SSHConfig, SSHResult
|
||||
|
||||
|
||||
class TaskStatus(Enum):
|
||||
"""Task execution status"""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
TIMEOUT = "timeout"
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeminiCliConfig:
|
||||
"""Configuration for Gemini CLI agent"""
|
||||
host: str
|
||||
node_version: str
|
||||
model: str = "gemini-2.5-pro"
|
||||
max_concurrent: int = 2
|
||||
command_timeout: int = 60
|
||||
ssh_timeout: int = 5
|
||||
node_path: Optional[str] = None
|
||||
gemini_path: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Auto-generate paths if not provided"""
|
||||
if self.node_path is None:
|
||||
self.node_path = f"/home/tony/.nvm/versions/node/{self.node_version}/bin/node"
|
||||
if self.gemini_path is None:
|
||||
self.gemini_path = f"/home/tony/.nvm/versions/node/{self.node_version}/bin/gemini"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskRequest:
|
||||
"""Represents a task to be executed"""
|
||||
prompt: str
|
||||
model: Optional[str] = None
|
||||
task_id: Optional[str] = None
|
||||
priority: int = 3
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Generate task ID if not provided"""
|
||||
if self.task_id is None:
|
||||
# Generate a unique task ID based on prompt and timestamp
|
||||
content = f"{self.prompt}_{time.time()}"
|
||||
self.task_id = hashlib.md5(content.encode()).hexdigest()[:12]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskResult:
|
||||
"""Result of a task execution"""
|
||||
task_id: str
|
||||
status: TaskStatus
|
||||
response: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
execution_time: float = 0.0
|
||||
model: Optional[str] = None
|
||||
agent_id: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization"""
|
||||
result = asdict(self)
|
||||
result['status'] = self.status.value
|
||||
return result
|
||||
|
||||
|
||||
class GeminiCliAgent:
|
||||
"""
|
||||
Adapter for Google Gemini CLI execution via SSH
|
||||
|
||||
Provides a consistent interface for executing AI tasks on remote Gemini CLI installations
|
||||
while handling SSH connections, environment setup, error recovery, and concurrent execution.
|
||||
"""
|
||||
|
||||
def __init__(self, config: GeminiCliConfig, specialization: str = "general_ai"):
|
||||
self.config = config
|
||||
self.specialization = specialization
|
||||
self.agent_id = f"{config.host}-gemini"
|
||||
|
||||
# SSH configuration
|
||||
self.ssh_config = SSHConfig(
|
||||
host=config.host,
|
||||
connect_timeout=config.ssh_timeout,
|
||||
command_timeout=config.command_timeout
|
||||
)
|
||||
|
||||
# SSH executor with connection pooling
|
||||
self.ssh_executor = SSHExecutor(pool_size=3, persist_timeout=120)
|
||||
|
||||
# Task management
|
||||
self.active_tasks: Dict[str, asyncio.Task] = {}
|
||||
self.task_history: List[TaskResult] = []
|
||||
self.max_history = 100
|
||||
|
||||
# Logging
|
||||
self.logger = logging.getLogger(f"gemini_cli.{config.host}")
|
||||
|
||||
# Performance tracking
|
||||
self.stats = {
|
||||
"total_tasks": 0,
|
||||
"successful_tasks": 0,
|
||||
"failed_tasks": 0,
|
||||
"total_execution_time": 0.0,
|
||||
"average_execution_time": 0.0
|
||||
}
|
||||
|
||||
async def execute_task(self, request: TaskRequest) -> TaskResult:
|
||||
"""
|
||||
Execute a task on the Gemini CLI
|
||||
|
||||
Args:
|
||||
request: TaskRequest containing prompt and configuration
|
||||
|
||||
Returns:
|
||||
TaskResult with execution status and response
|
||||
"""
|
||||
|
||||
# Check concurrent task limit
|
||||
if len(self.active_tasks) >= self.config.max_concurrent:
|
||||
return TaskResult(
|
||||
task_id=request.task_id,
|
||||
status=TaskStatus.FAILED,
|
||||
error=f"Agent at maximum concurrent tasks ({self.config.max_concurrent})",
|
||||
agent_id=self.agent_id
|
||||
)
|
||||
|
||||
# Start task execution
|
||||
task = asyncio.create_task(self._execute_task_impl(request))
|
||||
self.active_tasks[request.task_id] = task
|
||||
|
||||
try:
|
||||
result = await task
|
||||
return result
|
||||
finally:
|
||||
# Clean up task from active list
|
||||
self.active_tasks.pop(request.task_id, None)
|
||||
|
||||
async def _execute_task_impl(self, request: TaskRequest) -> TaskResult:
|
||||
"""Internal implementation of task execution"""
|
||||
start_time = time.time()
|
||||
model = request.model or self.config.model
|
||||
|
||||
try:
|
||||
self.logger.info(f"Starting task {request.task_id} with model {model}")
|
||||
|
||||
# Build the CLI command
|
||||
command = self._build_cli_command(request.prompt, model)
|
||||
|
||||
# Execute via SSH
|
||||
ssh_result = await self.ssh_executor.execute(self.ssh_config, command)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
# Process result
|
||||
if ssh_result.returncode == 0:
|
||||
result = TaskResult(
|
||||
task_id=request.task_id,
|
||||
status=TaskStatus.COMPLETED,
|
||||
response=self._clean_response(ssh_result.stdout),
|
||||
execution_time=execution_time,
|
||||
model=model,
|
||||
agent_id=self.agent_id,
|
||||
metadata={
|
||||
"ssh_duration": ssh_result.duration,
|
||||
"command": command,
|
||||
"stderr": ssh_result.stderr
|
||||
}
|
||||
)
|
||||
self.stats["successful_tasks"] += 1
|
||||
else:
|
||||
result = TaskResult(
|
||||
task_id=request.task_id,
|
||||
status=TaskStatus.FAILED,
|
||||
error=f"CLI execution failed: {ssh_result.stderr}",
|
||||
execution_time=execution_time,
|
||||
model=model,
|
||||
agent_id=self.agent_id,
|
||||
metadata={
|
||||
"returncode": ssh_result.returncode,
|
||||
"command": command,
|
||||
"stdout": ssh_result.stdout,
|
||||
"stderr": ssh_result.stderr
|
||||
}
|
||||
)
|
||||
self.stats["failed_tasks"] += 1
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
self.logger.error(f"Task {request.task_id} failed: {e}")
|
||||
|
||||
result = TaskResult(
|
||||
task_id=request.task_id,
|
||||
status=TaskStatus.FAILED,
|
||||
error=str(e),
|
||||
execution_time=execution_time,
|
||||
model=model,
|
||||
agent_id=self.agent_id
|
||||
)
|
||||
self.stats["failed_tasks"] += 1
|
||||
|
||||
# Update statistics
|
||||
self.stats["total_tasks"] += 1
|
||||
self.stats["total_execution_time"] += execution_time
|
||||
self.stats["average_execution_time"] = (
|
||||
self.stats["total_execution_time"] / self.stats["total_tasks"]
|
||||
)
|
||||
|
||||
# Add to history (with size limit)
|
||||
self.task_history.append(result)
|
||||
if len(self.task_history) > self.max_history:
|
||||
self.task_history.pop(0)
|
||||
|
||||
self.logger.info(f"Task {request.task_id} completed with status {result.status.value}")
|
||||
return result
|
||||
|
||||
def _build_cli_command(self, prompt: str, model: str) -> str:
|
||||
"""Build the complete CLI command for execution"""
|
||||
|
||||
# Environment setup
|
||||
env_setup = f"source ~/.nvm/nvm.sh && nvm use {self.config.node_version}"
|
||||
|
||||
# Escape the prompt for shell safety
|
||||
escaped_prompt = prompt.replace("'", "'\\''")
|
||||
|
||||
# Build gemini command
|
||||
gemini_cmd = f"echo '{escaped_prompt}' | {self.config.gemini_path} --model {model}"
|
||||
|
||||
# Complete command
|
||||
full_command = f"{env_setup} && {gemini_cmd}"
|
||||
|
||||
return full_command
|
||||
|
||||
def _clean_response(self, raw_output: str) -> str:
|
||||
"""Clean up the raw CLI output"""
|
||||
lines = raw_output.strip().split('\n')
|
||||
|
||||
# Remove NVM output lines
|
||||
cleaned_lines = []
|
||||
for line in lines:
|
||||
if not (line.startswith('Now using node') or
|
||||
line.startswith('MCP STDERR') or
|
||||
line.strip() == ''):
|
||||
cleaned_lines.append(line)
|
||||
|
||||
return '\n'.join(cleaned_lines).strip()
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""Perform a health check on the agent"""
|
||||
try:
|
||||
# Test SSH connection
|
||||
ssh_healthy = await self.ssh_executor.test_connection(self.ssh_config)
|
||||
|
||||
# Test Gemini CLI with a simple prompt
|
||||
if ssh_healthy:
|
||||
test_request = TaskRequest(
|
||||
prompt="Say 'health check ok'",
|
||||
task_id="health_check"
|
||||
)
|
||||
result = await self.execute_task(test_request)
|
||||
cli_healthy = result.status == TaskStatus.COMPLETED
|
||||
response_time = result.execution_time
|
||||
else:
|
||||
cli_healthy = False
|
||||
response_time = None
|
||||
|
||||
# Get connection stats
|
||||
connection_stats = await self.ssh_executor.get_connection_stats()
|
||||
|
||||
return {
|
||||
"agent_id": self.agent_id,
|
||||
"host": self.config.host,
|
||||
"ssh_healthy": ssh_healthy,
|
||||
"cli_healthy": cli_healthy,
|
||||
"response_time": response_time,
|
||||
"active_tasks": len(self.active_tasks),
|
||||
"max_concurrent": self.config.max_concurrent,
|
||||
"total_tasks": self.stats["total_tasks"],
|
||||
"success_rate": (
|
||||
self.stats["successful_tasks"] / max(self.stats["total_tasks"], 1)
|
||||
),
|
||||
"average_execution_time": self.stats["average_execution_time"],
|
||||
"connection_stats": connection_stats,
|
||||
"model": self.config.model,
|
||||
"specialization": self.specialization
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Health check failed: {e}")
|
||||
return {
|
||||
"agent_id": self.agent_id,
|
||||
"host": self.config.host,
|
||||
"ssh_healthy": False,
|
||||
"cli_healthy": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def get_task_status(self, task_id: str) -> Optional[TaskResult]:
|
||||
"""Get the status of a specific task"""
|
||||
# Check active tasks
|
||||
if task_id in self.active_tasks:
|
||||
task = self.active_tasks[task_id]
|
||||
if task.done():
|
||||
return task.result()
|
||||
else:
|
||||
return TaskResult(
|
||||
task_id=task_id,
|
||||
status=TaskStatus.RUNNING,
|
||||
agent_id=self.agent_id
|
||||
)
|
||||
|
||||
# Check history
|
||||
for result in reversed(self.task_history):
|
||||
if result.task_id == task_id:
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
async def cancel_task(self, task_id: str) -> bool:
|
||||
"""Cancel a running task"""
|
||||
if task_id in self.active_tasks:
|
||||
task = self.active_tasks[task_id]
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""Get agent performance statistics"""
|
||||
return {
|
||||
"agent_id": self.agent_id,
|
||||
"host": self.config.host,
|
||||
"specialization": self.specialization,
|
||||
"model": self.config.model,
|
||||
"stats": self.stats.copy(),
|
||||
"active_tasks": len(self.active_tasks),
|
||||
"history_length": len(self.task_history)
|
||||
}
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up resources"""
|
||||
# Cancel any active tasks
|
||||
for task_id, task in list(self.active_tasks.items()):
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
|
||||
# Wait for tasks to complete
|
||||
if self.active_tasks:
|
||||
await asyncio.gather(*self.active_tasks.values(), return_exceptions=True)
|
||||
|
||||
# Close SSH connections
|
||||
await self.ssh_executor.cleanup()
|
||||
|
||||
self.logger.info(f"Agent {self.agent_id} cleaned up successfully")
|
||||
1
backend/ccli_src/src/executors/__init__.py
Normal file
1
backend/ccli_src/src/executors/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Executors Package
|
||||
Binary file not shown.
Binary file not shown.
148
backend/ccli_src/src/executors/simple_ssh_executor.py
Normal file
148
backend/ccli_src/src/executors/simple_ssh_executor.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
Simple SSH Executor for CCLI
|
||||
Uses subprocess for SSH execution without external dependencies.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import subprocess
|
||||
import time
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class SSHResult:
|
||||
"""Result of an SSH command execution"""
|
||||
stdout: str
|
||||
stderr: str
|
||||
returncode: int
|
||||
duration: float
|
||||
host: str
|
||||
command: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class SSHConfig:
|
||||
"""SSH connection configuration"""
|
||||
host: str
|
||||
username: str = "tony"
|
||||
connect_timeout: int = 5
|
||||
command_timeout: int = 30
|
||||
max_retries: int = 2
|
||||
ssh_options: Optional[Dict[str, str]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.ssh_options is None:
|
||||
self.ssh_options = {
|
||||
"BatchMode": "yes",
|
||||
"ConnectTimeout": str(self.connect_timeout),
|
||||
"StrictHostKeyChecking": "no"
|
||||
}
|
||||
|
||||
|
||||
class SimpleSSHExecutor:
|
||||
"""Simple SSH command executor using subprocess"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
async def execute(self, config: SSHConfig, command: str, **kwargs) -> SSHResult:
|
||||
"""Execute a command via SSH with retries and error handling"""
|
||||
|
||||
for attempt in range(config.max_retries + 1):
|
||||
try:
|
||||
return await self._execute_once(config, command, **kwargs)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"SSH execution attempt {attempt + 1} failed for {config.host}: {e}")
|
||||
|
||||
if attempt < config.max_retries:
|
||||
await asyncio.sleep(1) # Brief delay before retry
|
||||
else:
|
||||
# Final attempt failed
|
||||
raise Exception(f"SSH execution failed after {config.max_retries + 1} attempts: {e}")
|
||||
|
||||
async def _execute_once(self, config: SSHConfig, command: str, **kwargs) -> SSHResult:
|
||||
"""Execute command once via SSH"""
|
||||
start_time = time.time()
|
||||
|
||||
# Build SSH command
|
||||
ssh_cmd = self._build_ssh_command(config, command)
|
||||
|
||||
try:
|
||||
# Execute command with timeout
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*ssh_cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(),
|
||||
timeout=config.command_timeout
|
||||
)
|
||||
|
||||
duration = time.time() - start_time
|
||||
|
||||
return SSHResult(
|
||||
stdout=stdout.decode('utf-8'),
|
||||
stderr=stderr.decode('utf-8'),
|
||||
returncode=process.returncode,
|
||||
duration=duration,
|
||||
host=config.host,
|
||||
command=command
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
duration = time.time() - start_time
|
||||
raise Exception(f"SSH command timeout after {config.command_timeout}s on {config.host}")
|
||||
|
||||
except Exception as e:
|
||||
duration = time.time() - start_time
|
||||
self.logger.error(f"SSH execution error on {config.host}: {e}")
|
||||
raise
|
||||
|
||||
def _build_ssh_command(self, config: SSHConfig, command: str) -> list:
|
||||
"""Build SSH command array"""
|
||||
ssh_cmd = ["ssh"]
|
||||
|
||||
# Add SSH options
|
||||
for option, value in config.ssh_options.items():
|
||||
ssh_cmd.extend(["-o", f"{option}={value}"])
|
||||
|
||||
# Add destination
|
||||
if config.username:
|
||||
destination = f"{config.username}@{config.host}"
|
||||
else:
|
||||
destination = config.host
|
||||
|
||||
ssh_cmd.append(destination)
|
||||
ssh_cmd.append(command)
|
||||
|
||||
return ssh_cmd
|
||||
|
||||
async def test_connection(self, config: SSHConfig) -> bool:
|
||||
"""Test if SSH connection is working"""
|
||||
try:
|
||||
result = await self.execute(config, "echo 'connection_test'")
|
||||
return result.returncode == 0 and "connection_test" in result.stdout
|
||||
except Exception as e:
|
||||
self.logger.error(f"Connection test failed for {config.host}: {e}")
|
||||
return False
|
||||
|
||||
async def get_connection_stats(self) -> Dict[str, Any]:
|
||||
"""Get statistics about current connections (simplified for subprocess)"""
|
||||
return {
|
||||
"total_connections": 0, # subprocess doesn't maintain persistent connections
|
||||
"connection_type": "subprocess"
|
||||
}
|
||||
|
||||
async def cleanup(self):
|
||||
"""Cleanup resources (no-op for subprocess)"""
|
||||
pass
|
||||
|
||||
|
||||
# Alias for compatibility
|
||||
SSHExecutor = SimpleSSHExecutor
|
||||
221
backend/ccli_src/src/executors/ssh_executor.py
Normal file
221
backend/ccli_src/src/executors/ssh_executor.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""
|
||||
SSH Executor for CCLI
|
||||
Handles SSH connections, command execution, and connection pooling for CLI agents.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import asyncssh
|
||||
import time
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict, Any
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
|
||||
@dataclass
|
||||
class SSHResult:
|
||||
"""Result of an SSH command execution"""
|
||||
stdout: str
|
||||
stderr: str
|
||||
returncode: int
|
||||
duration: float
|
||||
host: str
|
||||
command: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class SSHConfig:
|
||||
"""SSH connection configuration"""
|
||||
host: str
|
||||
username: str = "tony"
|
||||
connect_timeout: int = 5
|
||||
command_timeout: int = 30
|
||||
max_retries: int = 2
|
||||
known_hosts: Optional[str] = None
|
||||
|
||||
|
||||
class SSHConnectionPool:
|
||||
"""Manages SSH connection pooling for efficiency"""
|
||||
|
||||
def __init__(self, pool_size: int = 3, persist_timeout: int = 60):
|
||||
self.pool_size = pool_size
|
||||
self.persist_timeout = persist_timeout
|
||||
self.connections: Dict[str, Dict[str, Any]] = {}
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
async def get_connection(self, config: SSHConfig) -> asyncssh.SSHClientConnection:
|
||||
"""Get a pooled SSH connection, creating if needed"""
|
||||
host_key = f"{config.username}@{config.host}"
|
||||
|
||||
# Check if we have a valid connection
|
||||
if host_key in self.connections:
|
||||
conn_info = self.connections[host_key]
|
||||
connection = conn_info['connection']
|
||||
|
||||
# Check if connection is still alive and not expired
|
||||
if (not connection.is_closed() and
|
||||
time.time() - conn_info['created'] < self.persist_timeout):
|
||||
self.logger.debug(f"Reusing connection to {host_key}")
|
||||
return connection
|
||||
else:
|
||||
# Connection expired or closed, remove it
|
||||
self.logger.debug(f"Connection to {host_key} expired, creating new one")
|
||||
await self._close_connection(host_key)
|
||||
|
||||
# Create new connection
|
||||
self.logger.debug(f"Creating new SSH connection to {host_key}")
|
||||
connection = await asyncssh.connect(
|
||||
config.host,
|
||||
username=config.username,
|
||||
connect_timeout=config.connect_timeout,
|
||||
known_hosts=config.known_hosts
|
||||
)
|
||||
|
||||
self.connections[host_key] = {
|
||||
'connection': connection,
|
||||
'created': time.time(),
|
||||
'uses': 0
|
||||
}
|
||||
|
||||
return connection
|
||||
|
||||
async def _close_connection(self, host_key: str):
|
||||
"""Close and remove a connection from the pool"""
|
||||
if host_key in self.connections:
|
||||
try:
|
||||
conn_info = self.connections[host_key]
|
||||
if not conn_info['connection'].is_closed():
|
||||
conn_info['connection'].close()
|
||||
await conn_info['connection'].wait_closed()
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error closing connection to {host_key}: {e}")
|
||||
finally:
|
||||
del self.connections[host_key]
|
||||
|
||||
async def close_all(self):
|
||||
"""Close all pooled connections"""
|
||||
for host_key in list(self.connections.keys()):
|
||||
await self._close_connection(host_key)
|
||||
|
||||
|
||||
class SSHExecutor:
|
||||
"""Main SSH command executor with connection pooling and error handling"""
|
||||
|
||||
def __init__(self, pool_size: int = 3, persist_timeout: int = 60):
|
||||
self.pool = SSHConnectionPool(pool_size, persist_timeout)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
async def execute(self, config: SSHConfig, command: str, **kwargs) -> SSHResult:
|
||||
"""Execute a command via SSH with retries and error handling"""
|
||||
|
||||
for attempt in range(config.max_retries + 1):
|
||||
try:
|
||||
return await self._execute_once(config, command, **kwargs)
|
||||
|
||||
except (asyncssh.Error, asyncio.TimeoutError, OSError) as e:
|
||||
self.logger.warning(f"SSH execution attempt {attempt + 1} failed for {config.host}: {e}")
|
||||
|
||||
if attempt < config.max_retries:
|
||||
# Close any bad connections and retry
|
||||
host_key = f"{config.username}@{config.host}"
|
||||
await self.pool._close_connection(host_key)
|
||||
await asyncio.sleep(1) # Brief delay before retry
|
||||
else:
|
||||
# Final attempt failed
|
||||
raise Exception(f"SSH execution failed after {config.max_retries + 1} attempts: {e}")
|
||||
|
||||
async def _execute_once(self, config: SSHConfig, command: str, **kwargs) -> SSHResult:
|
||||
"""Execute command once via SSH"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
connection = await self.pool.get_connection(config)
|
||||
|
||||
# Execute command with timeout
|
||||
result = await asyncio.wait_for(
|
||||
connection.run(command, check=False, **kwargs),
|
||||
timeout=config.command_timeout
|
||||
)
|
||||
|
||||
duration = time.time() - start_time
|
||||
|
||||
# Update connection usage stats
|
||||
host_key = f"{config.username}@{config.host}"
|
||||
if host_key in self.pool.connections:
|
||||
self.pool.connections[host_key]['uses'] += 1
|
||||
|
||||
return SSHResult(
|
||||
stdout=result.stdout,
|
||||
stderr=result.stderr,
|
||||
returncode=result.exit_status,
|
||||
duration=duration,
|
||||
host=config.host,
|
||||
command=command
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
duration = time.time() - start_time
|
||||
raise Exception(f"SSH command timeout after {config.command_timeout}s on {config.host}")
|
||||
|
||||
except Exception as e:
|
||||
duration = time.time() - start_time
|
||||
self.logger.error(f"SSH execution error on {config.host}: {e}")
|
||||
raise
|
||||
|
||||
async def test_connection(self, config: SSHConfig) -> bool:
|
||||
"""Test if SSH connection is working"""
|
||||
try:
|
||||
result = await self.execute(config, "echo 'connection_test'")
|
||||
return result.returncode == 0 and "connection_test" in result.stdout
|
||||
except Exception as e:
|
||||
self.logger.error(f"Connection test failed for {config.host}: {e}")
|
||||
return False
|
||||
|
||||
async def get_connection_stats(self) -> Dict[str, Any]:
|
||||
"""Get statistics about current connections"""
|
||||
stats = {
|
||||
"total_connections": len(self.pool.connections),
|
||||
"connections": {}
|
||||
}
|
||||
|
||||
for host_key, conn_info in self.pool.connections.items():
|
||||
stats["connections"][host_key] = {
|
||||
"created": conn_info["created"],
|
||||
"age_seconds": time.time() - conn_info["created"],
|
||||
"uses": conn_info["uses"],
|
||||
"is_closed": conn_info["connection"].is_closed()
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
async def cleanup(self):
|
||||
"""Close all connections and cleanup resources"""
|
||||
await self.pool.close_all()
|
||||
|
||||
@asynccontextmanager
|
||||
async def connection_context(self, config: SSHConfig):
|
||||
"""Context manager for SSH connections"""
|
||||
try:
|
||||
connection = await self.pool.get_connection(config)
|
||||
yield connection
|
||||
except Exception as e:
|
||||
self.logger.error(f"SSH connection context error: {e}")
|
||||
raise
|
||||
# Connection stays in pool for reuse
|
||||
|
||||
|
||||
# Module-level convenience functions
|
||||
_default_executor = None
|
||||
|
||||
def get_default_executor() -> SSHExecutor:
|
||||
"""Get the default SSH executor instance"""
|
||||
global _default_executor
|
||||
if _default_executor is None:
|
||||
_default_executor = SSHExecutor()
|
||||
return _default_executor
|
||||
|
||||
async def execute_ssh_command(host: str, command: str, **kwargs) -> SSHResult:
|
||||
"""Convenience function for simple SSH command execution"""
|
||||
config = SSHConfig(host=host)
|
||||
executor = get_default_executor()
|
||||
return await executor.execute(config, command, **kwargs)
|
||||
380
backend/ccli_src/src/tests/test_gemini_cli_agent.py
Normal file
380
backend/ccli_src/src/tests/test_gemini_cli_agent.py
Normal file
@@ -0,0 +1,380 @@
|
||||
"""
|
||||
Unit tests for GeminiCliAgent
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from dataclasses import dataclass
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from agents.gemini_cli_agent import (
|
||||
GeminiCliAgent, GeminiCliConfig, TaskRequest, TaskResult, TaskStatus
|
||||
)
|
||||
from executors.ssh_executor import SSHResult
|
||||
|
||||
|
||||
class TestGeminiCliAgent:
|
||||
|
||||
@pytest.fixture
|
||||
def agent_config(self):
|
||||
return GeminiCliConfig(
|
||||
host="test-host",
|
||||
node_version="v22.14.0",
|
||||
model="gemini-2.5-pro",
|
||||
max_concurrent=2,
|
||||
command_timeout=30
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def agent(self, agent_config):
|
||||
return GeminiCliAgent(agent_config, "test_specialty")
|
||||
|
||||
@pytest.fixture
|
||||
def task_request(self):
|
||||
return TaskRequest(
|
||||
prompt="What is 2+2?",
|
||||
task_id="test-task-123"
|
||||
)
|
||||
|
||||
def test_agent_initialization(self, agent_config):
|
||||
"""Test agent initialization with proper configuration"""
|
||||
agent = GeminiCliAgent(agent_config, "general_ai")
|
||||
|
||||
assert agent.config.host == "test-host"
|
||||
assert agent.config.node_version == "v22.14.0"
|
||||
assert agent.specialization == "general_ai"
|
||||
assert agent.agent_id == "test-host-gemini"
|
||||
assert len(agent.active_tasks) == 0
|
||||
assert agent.stats["total_tasks"] == 0
|
||||
|
||||
def test_config_auto_paths(self):
|
||||
"""Test automatic path generation in config"""
|
||||
config = GeminiCliConfig(
|
||||
host="walnut",
|
||||
node_version="v22.14.0"
|
||||
)
|
||||
|
||||
expected_node_path = "/home/tony/.nvm/versions/node/v22.14.0/bin/node"
|
||||
expected_gemini_path = "/home/tony/.nvm/versions/node/v22.14.0/bin/gemini"
|
||||
|
||||
assert config.node_path == expected_node_path
|
||||
assert config.gemini_path == expected_gemini_path
|
||||
|
||||
def test_build_cli_command(self, agent):
|
||||
"""Test CLI command building"""
|
||||
prompt = "What is Python?"
|
||||
model = "gemini-2.5-pro"
|
||||
|
||||
command = agent._build_cli_command(prompt, model)
|
||||
|
||||
assert "source ~/.nvm/nvm.sh" in command
|
||||
assert "nvm use v22.14.0" in command
|
||||
assert "gemini --model gemini-2.5-pro" in command
|
||||
assert "What is Python?" in command
|
||||
|
||||
def test_build_cli_command_escaping(self, agent):
|
||||
"""Test CLI command with special characters"""
|
||||
prompt = "What's the meaning of 'life'?"
|
||||
model = "gemini-2.5-pro"
|
||||
|
||||
command = agent._build_cli_command(prompt, model)
|
||||
|
||||
# Should properly escape single quotes
|
||||
assert "What\\'s the meaning of \\'life\\'?" in command
|
||||
|
||||
def test_clean_response(self, agent):
|
||||
"""Test response cleaning"""
|
||||
raw_output = """Now using node v22.14.0 (npm v11.3.0)
|
||||
MCP STDERR (hive): Warning message
|
||||
|
||||
This is the actual response
|
||||
from Gemini CLI
|
||||
|
||||
"""
|
||||
|
||||
cleaned = agent._clean_response(raw_output)
|
||||
expected = "This is the actual response\nfrom Gemini CLI"
|
||||
|
||||
assert cleaned == expected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_task_success(self, agent, task_request, mocker):
|
||||
"""Test successful task execution"""
|
||||
# Mock SSH executor
|
||||
mock_ssh_result = SSHResult(
|
||||
stdout="Now using node v22.14.0\n4\n",
|
||||
stderr="",
|
||||
returncode=0,
|
||||
duration=1.5,
|
||||
host="test-host",
|
||||
command="test-command"
|
||||
)
|
||||
|
||||
mock_execute = AsyncMock(return_value=mock_ssh_result)
|
||||
mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)
|
||||
|
||||
result = await agent.execute_task(task_request)
|
||||
|
||||
assert result.status == TaskStatus.COMPLETED
|
||||
assert result.task_id == "test-task-123"
|
||||
assert result.response == "4"
|
||||
assert result.execution_time > 0
|
||||
assert result.model == "gemini-2.5-pro"
|
||||
assert result.agent_id == "test-host-gemini"
|
||||
|
||||
# Check statistics update
|
||||
assert agent.stats["successful_tasks"] == 1
|
||||
assert agent.stats["total_tasks"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_task_failure(self, agent, task_request, mocker):
|
||||
"""Test task execution failure handling"""
|
||||
mock_ssh_result = SSHResult(
|
||||
stdout="",
|
||||
stderr="Command failed: invalid model",
|
||||
returncode=1,
|
||||
duration=0.5,
|
||||
host="test-host",
|
||||
command="test-command"
|
||||
)
|
||||
|
||||
mock_execute = AsyncMock(return_value=mock_ssh_result)
|
||||
mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)
|
||||
|
||||
result = await agent.execute_task(task_request)
|
||||
|
||||
assert result.status == TaskStatus.FAILED
|
||||
assert "CLI execution failed" in result.error
|
||||
assert result.execution_time > 0
|
||||
|
||||
# Check statistics update
|
||||
assert agent.stats["failed_tasks"] == 1
|
||||
assert agent.stats["total_tasks"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_task_exception(self, agent, task_request, mocker):
|
||||
"""Test task execution with exception"""
|
||||
mock_execute = AsyncMock(side_effect=Exception("SSH connection failed"))
|
||||
mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)
|
||||
|
||||
result = await agent.execute_task(task_request)
|
||||
|
||||
assert result.status == TaskStatus.FAILED
|
||||
assert "SSH connection failed" in result.error
|
||||
assert result.execution_time > 0
|
||||
|
||||
# Check statistics update
|
||||
assert agent.stats["failed_tasks"] == 1
|
||||
assert agent.stats["total_tasks"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_task_limit(self, agent, mocker):
|
||||
"""Test concurrent task execution limits"""
|
||||
# Mock a slow SSH execution
|
||||
slow_ssh_result = SSHResult(
|
||||
stdout="result",
|
||||
stderr="",
|
||||
returncode=0,
|
||||
duration=2.0,
|
||||
host="test-host",
|
||||
command="test-command"
|
||||
)
|
||||
|
||||
async def slow_execute(*args, **kwargs):
|
||||
await asyncio.sleep(0.1) # Simulate slow execution
|
||||
return slow_ssh_result
|
||||
|
||||
mock_execute = AsyncMock(side_effect=slow_execute)
|
||||
mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)
|
||||
|
||||
# Start maximum concurrent tasks
|
||||
task1 = TaskRequest(prompt="Task 1", task_id="task-1")
|
||||
task2 = TaskRequest(prompt="Task 2", task_id="task-2")
|
||||
task3 = TaskRequest(prompt="Task 3", task_id="task-3")
|
||||
|
||||
# Start first two tasks (should succeed)
|
||||
result1_coro = agent.execute_task(task1)
|
||||
result2_coro = agent.execute_task(task2)
|
||||
|
||||
# Give tasks time to start
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Third task should fail due to limit
|
||||
result3 = await agent.execute_task(task3)
|
||||
assert result3.status == TaskStatus.FAILED
|
||||
assert "maximum concurrent tasks" in result3.error
|
||||
|
||||
# Wait for first two to complete
|
||||
result1 = await result1_coro
|
||||
result2 = await result2_coro
|
||||
|
||||
assert result1.status == TaskStatus.COMPLETED
|
||||
assert result2.status == TaskStatus.COMPLETED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_success(self, agent, mocker):
|
||||
"""Test successful health check"""
|
||||
# Mock SSH connection test
|
||||
mock_test_connection = AsyncMock(return_value=True)
|
||||
mocker.patch.object(agent.ssh_executor, 'test_connection', mock_test_connection)
|
||||
|
||||
# Mock successful CLI execution
|
||||
mock_ssh_result = SSHResult(
|
||||
stdout="health check ok\n",
|
||||
stderr="",
|
||||
returncode=0,
|
||||
duration=1.0,
|
||||
host="test-host",
|
||||
command="test-command"
|
||||
)
|
||||
mock_execute = AsyncMock(return_value=mock_ssh_result)
|
||||
mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)
|
||||
|
||||
# Mock connection stats
|
||||
mock_get_stats = AsyncMock(return_value={"total_connections": 1})
|
||||
mocker.patch.object(agent.ssh_executor, 'get_connection_stats', mock_get_stats)
|
||||
|
||||
health = await agent.health_check()
|
||||
|
||||
assert health["agent_id"] == "test-host-gemini"
|
||||
assert health["ssh_healthy"] is True
|
||||
assert health["cli_healthy"] is True
|
||||
assert health["response_time"] > 0
|
||||
assert health["active_tasks"] == 0
|
||||
assert health["max_concurrent"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_failure(self, agent, mocker):
|
||||
"""Test health check with failures"""
|
||||
# Mock SSH connection failure
|
||||
mock_test_connection = AsyncMock(return_value=False)
|
||||
mocker.patch.object(agent.ssh_executor, 'test_connection', mock_test_connection)
|
||||
|
||||
health = await agent.health_check()
|
||||
|
||||
assert health["ssh_healthy"] is False
|
||||
assert health["cli_healthy"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_status_tracking(self, agent, mocker):
|
||||
"""Test task status tracking"""
|
||||
# Mock SSH execution
|
||||
mock_ssh_result = SSHResult(
|
||||
stdout="result\n",
|
||||
stderr="",
|
||||
returncode=0,
|
||||
duration=1.0,
|
||||
host="test-host",
|
||||
command="test-command"
|
||||
)
|
||||
mock_execute = AsyncMock(return_value=mock_ssh_result)
|
||||
mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)
|
||||
|
||||
task_request = TaskRequest(prompt="Test", task_id="status-test")
|
||||
|
||||
# Execute task
|
||||
result = await agent.execute_task(task_request)
|
||||
|
||||
# Check task in history
|
||||
status = await agent.get_task_status("status-test")
|
||||
assert status is not None
|
||||
assert status.status == TaskStatus.COMPLETED
|
||||
assert status.task_id == "status-test"
|
||||
|
||||
# Check non-existent task
|
||||
status = await agent.get_task_status("non-existent")
|
||||
assert status is None
|
||||
|
||||
def test_statistics(self, agent):
|
||||
"""Test statistics tracking"""
|
||||
stats = agent.get_statistics()
|
||||
|
||||
assert stats["agent_id"] == "test-host-gemini"
|
||||
assert stats["host"] == "test-host"
|
||||
assert stats["specialization"] == "test_specialty"
|
||||
assert stats["model"] == "gemini-2.5-pro"
|
||||
assert stats["stats"]["total_tasks"] == 0
|
||||
assert stats["active_tasks"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_cancellation(self, agent, mocker):
|
||||
"""Test task cancellation"""
|
||||
# Mock a long-running SSH execution
|
||||
async def long_execute(*args, **kwargs):
|
||||
await asyncio.sleep(10) # Long execution
|
||||
return SSHResult("", "", 0, 10.0, "test-host", "cmd")
|
||||
|
||||
mock_execute = AsyncMock(side_effect=long_execute)
|
||||
mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)
|
||||
|
||||
task_request = TaskRequest(prompt="Long task", task_id="cancel-test")
|
||||
|
||||
# Start task
|
||||
task_coro = agent.execute_task(task_request)
|
||||
|
||||
# Let it start
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Cancel it
|
||||
cancelled = await agent.cancel_task("cancel-test")
|
||||
assert cancelled is True
|
||||
|
||||
# The task should be cancelled
|
||||
try:
|
||||
await task_coro
|
||||
except asyncio.CancelledError:
|
||||
pass # Expected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup(self, agent, mocker):
|
||||
"""Test agent cleanup"""
|
||||
# Mock SSH executor cleanup
|
||||
mock_cleanup = AsyncMock()
|
||||
mocker.patch.object(agent.ssh_executor, 'cleanup', mock_cleanup)
|
||||
|
||||
await agent.cleanup()
|
||||
|
||||
mock_cleanup.assert_called_once()
|
||||
|
||||
|
||||
class TestTaskRequest:
|
||||
|
||||
def test_task_request_auto_id(self):
|
||||
"""Test automatic task ID generation"""
|
||||
request = TaskRequest(prompt="Test prompt")
|
||||
|
||||
assert request.task_id is not None
|
||||
assert len(request.task_id) == 12 # MD5 hash truncated to 12 chars
|
||||
|
||||
def test_task_request_custom_id(self):
|
||||
"""Test custom task ID"""
|
||||
request = TaskRequest(prompt="Test", task_id="custom-123")
|
||||
|
||||
assert request.task_id == "custom-123"
|
||||
|
||||
|
||||
class TestTaskResult:
|
||||
|
||||
def test_task_result_to_dict(self):
|
||||
"""Test TaskResult serialization"""
|
||||
result = TaskResult(
|
||||
task_id="test-123",
|
||||
status=TaskStatus.COMPLETED,
|
||||
response="Test response",
|
||||
execution_time=1.5,
|
||||
model="gemini-2.5-pro",
|
||||
agent_id="test-agent"
|
||||
)
|
||||
|
||||
result_dict = result.to_dict()
|
||||
|
||||
assert result_dict["task_id"] == "test-123"
|
||||
assert result_dict["status"] == "completed"
|
||||
assert result_dict["response"] == "Test response"
|
||||
assert result_dict["execution_time"] == 1.5
|
||||
assert result_dict["model"] == "gemini-2.5-pro"
|
||||
assert result_dict["agent_id"] == "test-agent"
|
||||
Reference in New Issue
Block a user