Copy CCLI source to backend for Docker builds

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
anthonyrawlins
2025-07-10 12:46:52 +10:00
parent baa48bfcd4
commit 8b32d54e79
26 changed files with 2930 additions and 0 deletions

View File

@@ -0,0 +1 @@
# CCLI Source Package

View File

@@ -0,0 +1 @@
# CLI Agents Package

View File

@@ -0,0 +1,344 @@
"""
CLI Agent Factory
Creates and manages CLI-based agents with predefined configurations.
"""
import logging
from typing import Dict, List, Optional, Any
from enum import Enum
from dataclasses import dataclass
from agents.gemini_cli_agent import GeminiCliAgent, GeminiCliConfig
class CliAgentType(Enum):
"""Supported CLI agent types"""
GEMINI = "gemini"
class Specialization(Enum):
"""Agent specializations"""
GENERAL_AI = "general_ai"
REASONING = "reasoning"
CODE_ANALYSIS = "code_analysis"
DOCUMENTATION = "documentation"
TESTING = "testing"
@dataclass
class CliAgentDefinition:
"""Definition for a CLI agent instance"""
agent_id: str
agent_type: CliAgentType
config: Dict[str, Any]
specialization: Specialization
description: str
enabled: bool = True
class CliAgentFactory:
"""
Factory for creating and managing CLI agents
Provides predefined configurations for known agent instances and
supports dynamic agent creation with custom configurations.
"""
# Predefined agent configurations based on verified environment testing
PREDEFINED_AGENTS = {
"walnut-gemini": CliAgentDefinition(
agent_id="walnut-gemini",
agent_type=CliAgentType.GEMINI,
config={
"host": "walnut",
"node_version": "v22.14.0",
"model": "gemini-2.5-pro",
"max_concurrent": 2,
"command_timeout": 60,
"ssh_timeout": 5
},
specialization=Specialization.GENERAL_AI,
description="Gemini CLI agent on WALNUT for general AI tasks",
enabled=True
),
"ironwood-gemini": CliAgentDefinition(
agent_id="ironwood-gemini",
agent_type=CliAgentType.GEMINI,
config={
"host": "ironwood",
"node_version": "v22.17.0",
"model": "gemini-2.5-pro",
"max_concurrent": 2,
"command_timeout": 60,
"ssh_timeout": 5
},
specialization=Specialization.REASONING,
description="Gemini CLI agent on IRONWOOD for reasoning tasks (faster)",
enabled=True
),
# Additional specialized configurations
"walnut-gemini-code": CliAgentDefinition(
agent_id="walnut-gemini-code",
agent_type=CliAgentType.GEMINI,
config={
"host": "walnut",
"node_version": "v22.14.0",
"model": "gemini-2.5-pro",
"max_concurrent": 1, # More conservative for code analysis
"command_timeout": 90, # Longer timeout for complex code analysis
"ssh_timeout": 5
},
specialization=Specialization.CODE_ANALYSIS,
description="Gemini CLI agent specialized for code analysis tasks",
enabled=False # Start disabled, enable when needed
),
"ironwood-gemini-docs": CliAgentDefinition(
agent_id="ironwood-gemini-docs",
agent_type=CliAgentType.GEMINI,
config={
"host": "ironwood",
"node_version": "v22.17.0",
"model": "gemini-2.5-pro",
"max_concurrent": 2,
"command_timeout": 45,
"ssh_timeout": 5
},
specialization=Specialization.DOCUMENTATION,
description="Gemini CLI agent for documentation generation",
enabled=False
)
}
def __init__(self):
self.logger = logging.getLogger(__name__)
self.active_agents: Dict[str, GeminiCliAgent] = {}
@classmethod
def get_predefined_agent_ids(cls) -> List[str]:
"""Get list of all predefined agent IDs"""
return list(cls.PREDEFINED_AGENTS.keys())
@classmethod
def get_enabled_agent_ids(cls) -> List[str]:
"""Get list of enabled predefined agent IDs"""
return [
agent_id for agent_id, definition in cls.PREDEFINED_AGENTS.items()
if definition.enabled
]
@classmethod
def get_agent_definition(cls, agent_id: str) -> Optional[CliAgentDefinition]:
"""Get predefined agent definition by ID"""
return cls.PREDEFINED_AGENTS.get(agent_id)
def create_agent(self, agent_id: str, custom_config: Optional[Dict[str, Any]] = None) -> GeminiCliAgent:
"""
Create a CLI agent instance
Args:
agent_id: ID of predefined agent or custom ID
custom_config: Optional custom configuration to override defaults
Returns:
GeminiCliAgent instance
Raises:
ValueError: If agent_id is unknown and no custom_config provided
"""
# Check if agent already exists
if agent_id in self.active_agents:
self.logger.warning(f"Agent {agent_id} already exists, returning existing instance")
return self.active_agents[agent_id]
# Get configuration
if agent_id in self.PREDEFINED_AGENTS:
definition = self.PREDEFINED_AGENTS[agent_id]
if not definition.enabled:
self.logger.warning(f"Agent {agent_id} is disabled but being created anyway")
config_dict = definition.config.copy()
specialization = definition.specialization.value
# Apply custom overrides
if custom_config:
config_dict.update(custom_config)
elif custom_config:
# Custom agent configuration
config_dict = custom_config
specialization = custom_config.get("specialization", "general_ai")
else:
raise ValueError(f"Unknown agent ID '{agent_id}' and no custom configuration provided")
# Determine agent type and create appropriate agent
agent_type = config_dict.get("agent_type", "gemini")
if agent_type == "gemini" or agent_type == CliAgentType.GEMINI:
agent = self._create_gemini_agent(agent_id, config_dict, specialization)
else:
raise ValueError(f"Unsupported agent type: {agent_type}")
# Store in active agents
self.active_agents[agent_id] = agent
self.logger.info(f"Created CLI agent: {agent_id} ({specialization})")
return agent
def _create_gemini_agent(self, agent_id: str, config_dict: Dict[str, Any], specialization: str) -> GeminiCliAgent:
"""Create a Gemini CLI agent with the given configuration"""
# Create GeminiCliConfig from dictionary
config = GeminiCliConfig(
host=config_dict["host"],
node_version=config_dict["node_version"],
model=config_dict.get("model", "gemini-2.5-pro"),
max_concurrent=config_dict.get("max_concurrent", 2),
command_timeout=config_dict.get("command_timeout", 60),
ssh_timeout=config_dict.get("ssh_timeout", 5),
node_path=config_dict.get("node_path"),
gemini_path=config_dict.get("gemini_path")
)
return GeminiCliAgent(config, specialization)
def get_agent(self, agent_id: str) -> Optional[GeminiCliAgent]:
"""Get an existing agent instance"""
return self.active_agents.get(agent_id)
def remove_agent(self, agent_id: str) -> bool:
"""Remove an agent instance"""
if agent_id in self.active_agents:
agent = self.active_agents.pop(agent_id)
# Note: Cleanup should be called by the caller if needed
self.logger.info(f"Removed CLI agent: {agent_id}")
return True
return False
def get_active_agents(self) -> Dict[str, GeminiCliAgent]:
"""Get all active agent instances"""
return self.active_agents.copy()
def get_agent_info(self, agent_id: str) -> Optional[Dict[str, Any]]:
"""Get information about an agent"""
# Check active agents
if agent_id in self.active_agents:
agent = self.active_agents[agent_id]
return {
"agent_id": agent_id,
"status": "active",
"host": agent.config.host,
"model": agent.config.model,
"specialization": agent.specialization,
"active_tasks": len(agent.active_tasks),
"max_concurrent": agent.config.max_concurrent,
"statistics": agent.get_statistics()
}
# Check predefined but not active
if agent_id in self.PREDEFINED_AGENTS:
definition = self.PREDEFINED_AGENTS[agent_id]
return {
"agent_id": agent_id,
"status": "available" if definition.enabled else "disabled",
"agent_type": definition.agent_type.value,
"specialization": definition.specialization.value,
"description": definition.description,
"config": definition.config
}
return None
def list_all_agents(self) -> Dict[str, Dict[str, Any]]:
"""List all agents (predefined and active)"""
all_agents = {}
# Add predefined agents
for agent_id in self.PREDEFINED_AGENTS:
all_agents[agent_id] = self.get_agent_info(agent_id)
# Add any custom active agents not in predefined list
for agent_id in self.active_agents:
if agent_id not in all_agents:
all_agents[agent_id] = self.get_agent_info(agent_id)
return all_agents
async def health_check_all(self) -> Dict[str, Dict[str, Any]]:
"""Perform health checks on all active agents"""
health_results = {}
for agent_id, agent in self.active_agents.items():
try:
health_results[agent_id] = await agent.health_check()
except Exception as e:
health_results[agent_id] = {
"agent_id": agent_id,
"error": str(e),
"healthy": False
}
return health_results
async def cleanup_all(self):
"""Clean up all active agents"""
for agent_id, agent in list(self.active_agents.items()):
try:
await agent.cleanup()
self.logger.info(f"Cleaned up agent: {agent_id}")
except Exception as e:
self.logger.error(f"Error cleaning up agent {agent_id}: {e}")
self.active_agents.clear()
@classmethod
def create_custom_agent_config(cls, host: str, node_version: str,
specialization: str = "general_ai",
**kwargs) -> Dict[str, Any]:
"""
Helper to create custom agent configuration
Args:
host: Target host for SSH connection
node_version: Node.js version (e.g., "v22.14.0")
specialization: Agent specialization
**kwargs: Additional configuration options
Returns:
Configuration dictionary for create_agent()
"""
config = {
"host": host,
"node_version": node_version,
"specialization": specialization,
"agent_type": "gemini",
"model": "gemini-2.5-pro",
"max_concurrent": 2,
"command_timeout": 60,
"ssh_timeout": 5
}
config.update(kwargs)
return config
# Module-level convenience functions
_default_factory = None
def get_default_factory() -> CliAgentFactory:
"""Get the default CLI agent factory instance"""
global _default_factory
if _default_factory is None:
_default_factory = CliAgentFactory()
return _default_factory
def create_agent(agent_id: str, custom_config: Optional[Dict[str, Any]] = None) -> GeminiCliAgent:
"""Convenience function to create an agent using the default factory"""
factory = get_default_factory()
return factory.create_agent(agent_id, custom_config)

View File

@@ -0,0 +1,369 @@
"""
Gemini CLI Agent Adapter
Provides a standardized interface for executing tasks on Gemini CLI via SSH.
"""
import asyncio
import json
import time
import logging
import hashlib
from dataclasses import dataclass, asdict
from typing import Dict, Any, Optional, List
from enum import Enum
from executors.ssh_executor import SSHExecutor, SSHConfig, SSHResult
class TaskStatus(Enum):
"""Task execution status"""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
TIMEOUT = "timeout"
@dataclass
class GeminiCliConfig:
"""Configuration for Gemini CLI agent"""
host: str
node_version: str
model: str = "gemini-2.5-pro"
max_concurrent: int = 2
command_timeout: int = 60
ssh_timeout: int = 5
node_path: Optional[str] = None
gemini_path: Optional[str] = None
def __post_init__(self):
"""Auto-generate paths if not provided"""
if self.node_path is None:
self.node_path = f"/home/tony/.nvm/versions/node/{self.node_version}/bin/node"
if self.gemini_path is None:
self.gemini_path = f"/home/tony/.nvm/versions/node/{self.node_version}/bin/gemini"
@dataclass
class TaskRequest:
"""Represents a task to be executed"""
prompt: str
model: Optional[str] = None
task_id: Optional[str] = None
priority: int = 3
metadata: Optional[Dict[str, Any]] = None
def __post_init__(self):
"""Generate task ID if not provided"""
if self.task_id is None:
# Generate a unique task ID based on prompt and timestamp
content = f"{self.prompt}_{time.time()}"
self.task_id = hashlib.md5(content.encode()).hexdigest()[:12]
@dataclass
class TaskResult:
"""Result of a task execution"""
task_id: str
status: TaskStatus
response: Optional[str] = None
error: Optional[str] = None
execution_time: float = 0.0
model: Optional[str] = None
agent_id: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization"""
result = asdict(self)
result['status'] = self.status.value
return result
class GeminiCliAgent:
"""
Adapter for Google Gemini CLI execution via SSH
Provides a consistent interface for executing AI tasks on remote Gemini CLI installations
while handling SSH connections, environment setup, error recovery, and concurrent execution.
"""
def __init__(self, config: GeminiCliConfig, specialization: str = "general_ai"):
self.config = config
self.specialization = specialization
self.agent_id = f"{config.host}-gemini"
# SSH configuration
self.ssh_config = SSHConfig(
host=config.host,
connect_timeout=config.ssh_timeout,
command_timeout=config.command_timeout
)
# SSH executor with connection pooling
self.ssh_executor = SSHExecutor(pool_size=3, persist_timeout=120)
# Task management
self.active_tasks: Dict[str, asyncio.Task] = {}
self.task_history: List[TaskResult] = []
self.max_history = 100
# Logging
self.logger = logging.getLogger(f"gemini_cli.{config.host}")
# Performance tracking
self.stats = {
"total_tasks": 0,
"successful_tasks": 0,
"failed_tasks": 0,
"total_execution_time": 0.0,
"average_execution_time": 0.0
}
async def execute_task(self, request: TaskRequest) -> TaskResult:
"""
Execute a task on the Gemini CLI
Args:
request: TaskRequest containing prompt and configuration
Returns:
TaskResult with execution status and response
"""
# Check concurrent task limit
if len(self.active_tasks) >= self.config.max_concurrent:
return TaskResult(
task_id=request.task_id,
status=TaskStatus.FAILED,
error=f"Agent at maximum concurrent tasks ({self.config.max_concurrent})",
agent_id=self.agent_id
)
# Start task execution
task = asyncio.create_task(self._execute_task_impl(request))
self.active_tasks[request.task_id] = task
try:
result = await task
return result
finally:
# Clean up task from active list
self.active_tasks.pop(request.task_id, None)
async def _execute_task_impl(self, request: TaskRequest) -> TaskResult:
"""Internal implementation of task execution"""
start_time = time.time()
model = request.model or self.config.model
try:
self.logger.info(f"Starting task {request.task_id} with model {model}")
# Build the CLI command
command = self._build_cli_command(request.prompt, model)
# Execute via SSH
ssh_result = await self.ssh_executor.execute(self.ssh_config, command)
execution_time = time.time() - start_time
# Process result
if ssh_result.returncode == 0:
result = TaskResult(
task_id=request.task_id,
status=TaskStatus.COMPLETED,
response=self._clean_response(ssh_result.stdout),
execution_time=execution_time,
model=model,
agent_id=self.agent_id,
metadata={
"ssh_duration": ssh_result.duration,
"command": command,
"stderr": ssh_result.stderr
}
)
self.stats["successful_tasks"] += 1
else:
result = TaskResult(
task_id=request.task_id,
status=TaskStatus.FAILED,
error=f"CLI execution failed: {ssh_result.stderr}",
execution_time=execution_time,
model=model,
agent_id=self.agent_id,
metadata={
"returncode": ssh_result.returncode,
"command": command,
"stdout": ssh_result.stdout,
"stderr": ssh_result.stderr
}
)
self.stats["failed_tasks"] += 1
except Exception as e:
execution_time = time.time() - start_time
self.logger.error(f"Task {request.task_id} failed: {e}")
result = TaskResult(
task_id=request.task_id,
status=TaskStatus.FAILED,
error=str(e),
execution_time=execution_time,
model=model,
agent_id=self.agent_id
)
self.stats["failed_tasks"] += 1
# Update statistics
self.stats["total_tasks"] += 1
self.stats["total_execution_time"] += execution_time
self.stats["average_execution_time"] = (
self.stats["total_execution_time"] / self.stats["total_tasks"]
)
# Add to history (with size limit)
self.task_history.append(result)
if len(self.task_history) > self.max_history:
self.task_history.pop(0)
self.logger.info(f"Task {request.task_id} completed with status {result.status.value}")
return result
def _build_cli_command(self, prompt: str, model: str) -> str:
"""Build the complete CLI command for execution"""
# Environment setup
env_setup = f"source ~/.nvm/nvm.sh && nvm use {self.config.node_version}"
# Escape the prompt for shell safety
escaped_prompt = prompt.replace("'", "'\\''")
# Build gemini command
gemini_cmd = f"echo '{escaped_prompt}' | {self.config.gemini_path} --model {model}"
# Complete command
full_command = f"{env_setup} && {gemini_cmd}"
return full_command
def _clean_response(self, raw_output: str) -> str:
"""Clean up the raw CLI output"""
lines = raw_output.strip().split('\n')
# Remove NVM output lines
cleaned_lines = []
for line in lines:
if not (line.startswith('Now using node') or
line.startswith('MCP STDERR') or
line.strip() == ''):
cleaned_lines.append(line)
return '\n'.join(cleaned_lines).strip()
async def health_check(self) -> Dict[str, Any]:
"""Perform a health check on the agent"""
try:
# Test SSH connection
ssh_healthy = await self.ssh_executor.test_connection(self.ssh_config)
# Test Gemini CLI with a simple prompt
if ssh_healthy:
test_request = TaskRequest(
prompt="Say 'health check ok'",
task_id="health_check"
)
result = await self.execute_task(test_request)
cli_healthy = result.status == TaskStatus.COMPLETED
response_time = result.execution_time
else:
cli_healthy = False
response_time = None
# Get connection stats
connection_stats = await self.ssh_executor.get_connection_stats()
return {
"agent_id": self.agent_id,
"host": self.config.host,
"ssh_healthy": ssh_healthy,
"cli_healthy": cli_healthy,
"response_time": response_time,
"active_tasks": len(self.active_tasks),
"max_concurrent": self.config.max_concurrent,
"total_tasks": self.stats["total_tasks"],
"success_rate": (
self.stats["successful_tasks"] / max(self.stats["total_tasks"], 1)
),
"average_execution_time": self.stats["average_execution_time"],
"connection_stats": connection_stats,
"model": self.config.model,
"specialization": self.specialization
}
except Exception as e:
self.logger.error(f"Health check failed: {e}")
return {
"agent_id": self.agent_id,
"host": self.config.host,
"ssh_healthy": False,
"cli_healthy": False,
"error": str(e)
}
async def get_task_status(self, task_id: str) -> Optional[TaskResult]:
"""Get the status of a specific task"""
# Check active tasks
if task_id in self.active_tasks:
task = self.active_tasks[task_id]
if task.done():
return task.result()
else:
return TaskResult(
task_id=task_id,
status=TaskStatus.RUNNING,
agent_id=self.agent_id
)
# Check history
for result in reversed(self.task_history):
if result.task_id == task_id:
return result
return None
async def cancel_task(self, task_id: str) -> bool:
"""Cancel a running task"""
if task_id in self.active_tasks:
task = self.active_tasks[task_id]
if not task.done():
task.cancel()
return True
return False
def get_statistics(self) -> Dict[str, Any]:
"""Get agent performance statistics"""
return {
"agent_id": self.agent_id,
"host": self.config.host,
"specialization": self.specialization,
"model": self.config.model,
"stats": self.stats.copy(),
"active_tasks": len(self.active_tasks),
"history_length": len(self.task_history)
}
async def cleanup(self):
"""Clean up resources"""
# Cancel any active tasks
for task_id, task in list(self.active_tasks.items()):
if not task.done():
task.cancel()
# Wait for tasks to complete
if self.active_tasks:
await asyncio.gather(*self.active_tasks.values(), return_exceptions=True)
# Close SSH connections
await self.ssh_executor.cleanup()
self.logger.info(f"Agent {self.agent_id} cleaned up successfully")

View File

@@ -0,0 +1 @@
# Executors Package

View File

@@ -0,0 +1,148 @@
"""
Simple SSH Executor for CCLI
Uses subprocess for SSH execution without external dependencies.
"""
import asyncio
import subprocess
import time
import logging
from dataclasses import dataclass
from typing import Optional, Dict, Any
@dataclass
class SSHResult:
"""Result of an SSH command execution"""
stdout: str
stderr: str
returncode: int
duration: float
host: str
command: str
@dataclass
class SSHConfig:
"""SSH connection configuration"""
host: str
username: str = "tony"
connect_timeout: int = 5
command_timeout: int = 30
max_retries: int = 2
ssh_options: Optional[Dict[str, str]] = None
def __post_init__(self):
if self.ssh_options is None:
self.ssh_options = {
"BatchMode": "yes",
"ConnectTimeout": str(self.connect_timeout),
"StrictHostKeyChecking": "no"
}
class SimpleSSHExecutor:
"""Simple SSH command executor using subprocess"""
def __init__(self):
self.logger = logging.getLogger(__name__)
async def execute(self, config: SSHConfig, command: str, **kwargs) -> SSHResult:
"""Execute a command via SSH with retries and error handling"""
for attempt in range(config.max_retries + 1):
try:
return await self._execute_once(config, command, **kwargs)
except Exception as e:
self.logger.warning(f"SSH execution attempt {attempt + 1} failed for {config.host}: {e}")
if attempt < config.max_retries:
await asyncio.sleep(1) # Brief delay before retry
else:
# Final attempt failed
raise Exception(f"SSH execution failed after {config.max_retries + 1} attempts: {e}")
async def _execute_once(self, config: SSHConfig, command: str, **kwargs) -> SSHResult:
"""Execute command once via SSH"""
start_time = time.time()
# Build SSH command
ssh_cmd = self._build_ssh_command(config, command)
try:
# Execute command with timeout
process = await asyncio.create_subprocess_exec(
*ssh_cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
**kwargs
)
stdout, stderr = await asyncio.wait_for(
process.communicate(),
timeout=config.command_timeout
)
duration = time.time() - start_time
return SSHResult(
stdout=stdout.decode('utf-8'),
stderr=stderr.decode('utf-8'),
returncode=process.returncode,
duration=duration,
host=config.host,
command=command
)
except asyncio.TimeoutError:
duration = time.time() - start_time
raise Exception(f"SSH command timeout after {config.command_timeout}s on {config.host}")
except Exception as e:
duration = time.time() - start_time
self.logger.error(f"SSH execution error on {config.host}: {e}")
raise
def _build_ssh_command(self, config: SSHConfig, command: str) -> list:
"""Build SSH command array"""
ssh_cmd = ["ssh"]
# Add SSH options
for option, value in config.ssh_options.items():
ssh_cmd.extend(["-o", f"{option}={value}"])
# Add destination
if config.username:
destination = f"{config.username}@{config.host}"
else:
destination = config.host
ssh_cmd.append(destination)
ssh_cmd.append(command)
return ssh_cmd
async def test_connection(self, config: SSHConfig) -> bool:
"""Test if SSH connection is working"""
try:
result = await self.execute(config, "echo 'connection_test'")
return result.returncode == 0 and "connection_test" in result.stdout
except Exception as e:
self.logger.error(f"Connection test failed for {config.host}: {e}")
return False
async def get_connection_stats(self) -> Dict[str, Any]:
"""Get statistics about current connections (simplified for subprocess)"""
return {
"total_connections": 0, # subprocess doesn't maintain persistent connections
"connection_type": "subprocess"
}
async def cleanup(self):
"""Cleanup resources (no-op for subprocess)"""
pass
# Alias for compatibility
SSHExecutor = SimpleSSHExecutor

View File

@@ -0,0 +1,221 @@
"""
SSH Executor for CCLI
Handles SSH connections, command execution, and connection pooling for CLI agents.
"""
import asyncio
import asyncssh
import time
import logging
from dataclasses import dataclass
from typing import Optional, Dict, Any
from contextlib import asynccontextmanager
@dataclass
class SSHResult:
"""Result of an SSH command execution"""
stdout: str
stderr: str
returncode: int
duration: float
host: str
command: str
@dataclass
class SSHConfig:
"""SSH connection configuration"""
host: str
username: str = "tony"
connect_timeout: int = 5
command_timeout: int = 30
max_retries: int = 2
known_hosts: Optional[str] = None
class SSHConnectionPool:
"""Manages SSH connection pooling for efficiency"""
def __init__(self, pool_size: int = 3, persist_timeout: int = 60):
self.pool_size = pool_size
self.persist_timeout = persist_timeout
self.connections: Dict[str, Dict[str, Any]] = {}
self.logger = logging.getLogger(__name__)
async def get_connection(self, config: SSHConfig) -> asyncssh.SSHClientConnection:
"""Get a pooled SSH connection, creating if needed"""
host_key = f"{config.username}@{config.host}"
# Check if we have a valid connection
if host_key in self.connections:
conn_info = self.connections[host_key]
connection = conn_info['connection']
# Check if connection is still alive and not expired
if (not connection.is_closed() and
time.time() - conn_info['created'] < self.persist_timeout):
self.logger.debug(f"Reusing connection to {host_key}")
return connection
else:
# Connection expired or closed, remove it
self.logger.debug(f"Connection to {host_key} expired, creating new one")
await self._close_connection(host_key)
# Create new connection
self.logger.debug(f"Creating new SSH connection to {host_key}")
connection = await asyncssh.connect(
config.host,
username=config.username,
connect_timeout=config.connect_timeout,
known_hosts=config.known_hosts
)
self.connections[host_key] = {
'connection': connection,
'created': time.time(),
'uses': 0
}
return connection
async def _close_connection(self, host_key: str):
"""Close and remove a connection from the pool"""
if host_key in self.connections:
try:
conn_info = self.connections[host_key]
if not conn_info['connection'].is_closed():
conn_info['connection'].close()
await conn_info['connection'].wait_closed()
except Exception as e:
self.logger.warning(f"Error closing connection to {host_key}: {e}")
finally:
del self.connections[host_key]
async def close_all(self):
"""Close all pooled connections"""
for host_key in list(self.connections.keys()):
await self._close_connection(host_key)
class SSHExecutor:
"""Main SSH command executor with connection pooling and error handling"""
def __init__(self, pool_size: int = 3, persist_timeout: int = 60):
self.pool = SSHConnectionPool(pool_size, persist_timeout)
self.logger = logging.getLogger(__name__)
async def execute(self, config: SSHConfig, command: str, **kwargs) -> SSHResult:
"""Execute a command via SSH with retries and error handling"""
for attempt in range(config.max_retries + 1):
try:
return await self._execute_once(config, command, **kwargs)
except (asyncssh.Error, asyncio.TimeoutError, OSError) as e:
self.logger.warning(f"SSH execution attempt {attempt + 1} failed for {config.host}: {e}")
if attempt < config.max_retries:
# Close any bad connections and retry
host_key = f"{config.username}@{config.host}"
await self.pool._close_connection(host_key)
await asyncio.sleep(1) # Brief delay before retry
else:
# Final attempt failed
raise Exception(f"SSH execution failed after {config.max_retries + 1} attempts: {e}")
async def _execute_once(self, config: SSHConfig, command: str, **kwargs) -> SSHResult:
"""Execute command once via SSH"""
start_time = time.time()
try:
connection = await self.pool.get_connection(config)
# Execute command with timeout
result = await asyncio.wait_for(
connection.run(command, check=False, **kwargs),
timeout=config.command_timeout
)
duration = time.time() - start_time
# Update connection usage stats
host_key = f"{config.username}@{config.host}"
if host_key in self.pool.connections:
self.pool.connections[host_key]['uses'] += 1
return SSHResult(
stdout=result.stdout,
stderr=result.stderr,
returncode=result.exit_status,
duration=duration,
host=config.host,
command=command
)
except asyncio.TimeoutError:
duration = time.time() - start_time
raise Exception(f"SSH command timeout after {config.command_timeout}s on {config.host}")
except Exception as e:
duration = time.time() - start_time
self.logger.error(f"SSH execution error on {config.host}: {e}")
raise
async def test_connection(self, config: SSHConfig) -> bool:
"""Test if SSH connection is working"""
try:
result = await self.execute(config, "echo 'connection_test'")
return result.returncode == 0 and "connection_test" in result.stdout
except Exception as e:
self.logger.error(f"Connection test failed for {config.host}: {e}")
return False
async def get_connection_stats(self) -> Dict[str, Any]:
"""Get statistics about current connections"""
stats = {
"total_connections": len(self.pool.connections),
"connections": {}
}
for host_key, conn_info in self.pool.connections.items():
stats["connections"][host_key] = {
"created": conn_info["created"],
"age_seconds": time.time() - conn_info["created"],
"uses": conn_info["uses"],
"is_closed": conn_info["connection"].is_closed()
}
return stats
async def cleanup(self):
"""Close all connections and cleanup resources"""
await self.pool.close_all()
@asynccontextmanager
async def connection_context(self, config: SSHConfig):
"""Context manager for SSH connections"""
try:
connection = await self.pool.get_connection(config)
yield connection
except Exception as e:
self.logger.error(f"SSH connection context error: {e}")
raise
# Connection stays in pool for reuse
# Module-level convenience functions
_default_executor = None
def get_default_executor() -> SSHExecutor:
"""Get the default SSH executor instance"""
global _default_executor
if _default_executor is None:
_default_executor = SSHExecutor()
return _default_executor
async def execute_ssh_command(host: str, command: str, **kwargs) -> SSHResult:
"""Convenience function for simple SSH command execution"""
config = SSHConfig(host=host)
executor = get_default_executor()
return await executor.execute(config, command, **kwargs)

View File

@@ -0,0 +1,380 @@
"""
Unit tests for GeminiCliAgent
"""
import pytest
import asyncio
from unittest.mock import Mock, AsyncMock, patch
from dataclasses import dataclass
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from agents.gemini_cli_agent import (
GeminiCliAgent, GeminiCliConfig, TaskRequest, TaskResult, TaskStatus
)
from executors.ssh_executor import SSHResult
class TestGeminiCliAgent:
@pytest.fixture
def agent_config(self):
return GeminiCliConfig(
host="test-host",
node_version="v22.14.0",
model="gemini-2.5-pro",
max_concurrent=2,
command_timeout=30
)
@pytest.fixture
def agent(self, agent_config):
return GeminiCliAgent(agent_config, "test_specialty")
@pytest.fixture
def task_request(self):
return TaskRequest(
prompt="What is 2+2?",
task_id="test-task-123"
)
def test_agent_initialization(self, agent_config):
"""Test agent initialization with proper configuration"""
agent = GeminiCliAgent(agent_config, "general_ai")
assert agent.config.host == "test-host"
assert agent.config.node_version == "v22.14.0"
assert agent.specialization == "general_ai"
assert agent.agent_id == "test-host-gemini"
assert len(agent.active_tasks) == 0
assert agent.stats["total_tasks"] == 0
def test_config_auto_paths(self):
"""Test automatic path generation in config"""
config = GeminiCliConfig(
host="walnut",
node_version="v22.14.0"
)
expected_node_path = "/home/tony/.nvm/versions/node/v22.14.0/bin/node"
expected_gemini_path = "/home/tony/.nvm/versions/node/v22.14.0/bin/gemini"
assert config.node_path == expected_node_path
assert config.gemini_path == expected_gemini_path
def test_build_cli_command(self, agent):
"""Test CLI command building"""
prompt = "What is Python?"
model = "gemini-2.5-pro"
command = agent._build_cli_command(prompt, model)
assert "source ~/.nvm/nvm.sh" in command
assert "nvm use v22.14.0" in command
assert "gemini --model gemini-2.5-pro" in command
assert "What is Python?" in command
def test_build_cli_command_escaping(self, agent):
"""Test CLI command with special characters"""
prompt = "What's the meaning of 'life'?"
model = "gemini-2.5-pro"
command = agent._build_cli_command(prompt, model)
# Should properly escape single quotes
assert "What\\'s the meaning of \\'life\\'?" in command
def test_clean_response(self, agent):
"""Test response cleaning"""
raw_output = """Now using node v22.14.0 (npm v11.3.0)
MCP STDERR (hive): Warning message
This is the actual response
from Gemini CLI
"""
cleaned = agent._clean_response(raw_output)
expected = "This is the actual response\nfrom Gemini CLI"
assert cleaned == expected
@pytest.mark.asyncio
async def test_execute_task_success(self, agent, task_request, mocker):
"""Test successful task execution"""
# Mock SSH executor
mock_ssh_result = SSHResult(
stdout="Now using node v22.14.0\n4\n",
stderr="",
returncode=0,
duration=1.5,
host="test-host",
command="test-command"
)
mock_execute = AsyncMock(return_value=mock_ssh_result)
mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)
result = await agent.execute_task(task_request)
assert result.status == TaskStatus.COMPLETED
assert result.task_id == "test-task-123"
assert result.response == "4"
assert result.execution_time > 0
assert result.model == "gemini-2.5-pro"
assert result.agent_id == "test-host-gemini"
# Check statistics update
assert agent.stats["successful_tasks"] == 1
assert agent.stats["total_tasks"] == 1
@pytest.mark.asyncio
async def test_execute_task_failure(self, agent, task_request, mocker):
"""Test task execution failure handling"""
mock_ssh_result = SSHResult(
stdout="",
stderr="Command failed: invalid model",
returncode=1,
duration=0.5,
host="test-host",
command="test-command"
)
mock_execute = AsyncMock(return_value=mock_ssh_result)
mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)
result = await agent.execute_task(task_request)
assert result.status == TaskStatus.FAILED
assert "CLI execution failed" in result.error
assert result.execution_time > 0
# Check statistics update
assert agent.stats["failed_tasks"] == 1
assert agent.stats["total_tasks"] == 1
@pytest.mark.asyncio
async def test_execute_task_exception(self, agent, task_request, mocker):
"""Test task execution with exception"""
mock_execute = AsyncMock(side_effect=Exception("SSH connection failed"))
mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)
result = await agent.execute_task(task_request)
assert result.status == TaskStatus.FAILED
assert "SSH connection failed" in result.error
assert result.execution_time > 0
# Check statistics update
assert agent.stats["failed_tasks"] == 1
assert agent.stats["total_tasks"] == 1
@pytest.mark.asyncio
async def test_concurrent_task_limit(self, agent, mocker):
"""Test concurrent task execution limits"""
# Mock a slow SSH execution
slow_ssh_result = SSHResult(
stdout="result",
stderr="",
returncode=0,
duration=2.0,
host="test-host",
command="test-command"
)
async def slow_execute(*args, **kwargs):
await asyncio.sleep(0.1) # Simulate slow execution
return slow_ssh_result
mock_execute = AsyncMock(side_effect=slow_execute)
mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)
# Start maximum concurrent tasks
task1 = TaskRequest(prompt="Task 1", task_id="task-1")
task2 = TaskRequest(prompt="Task 2", task_id="task-2")
task3 = TaskRequest(prompt="Task 3", task_id="task-3")
# Start first two tasks (should succeed)
result1_coro = agent.execute_task(task1)
result2_coro = agent.execute_task(task2)
# Give tasks time to start
await asyncio.sleep(0.01)
# Third task should fail due to limit
result3 = await agent.execute_task(task3)
assert result3.status == TaskStatus.FAILED
assert "maximum concurrent tasks" in result3.error
# Wait for first two to complete
result1 = await result1_coro
result2 = await result2_coro
assert result1.status == TaskStatus.COMPLETED
assert result2.status == TaskStatus.COMPLETED
@pytest.mark.asyncio
async def test_health_check_success(self, agent, mocker):
"""Test successful health check"""
# Mock SSH connection test
mock_test_connection = AsyncMock(return_value=True)
mocker.patch.object(agent.ssh_executor, 'test_connection', mock_test_connection)
# Mock successful CLI execution
mock_ssh_result = SSHResult(
stdout="health check ok\n",
stderr="",
returncode=0,
duration=1.0,
host="test-host",
command="test-command"
)
mock_execute = AsyncMock(return_value=mock_ssh_result)
mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)
# Mock connection stats
mock_get_stats = AsyncMock(return_value={"total_connections": 1})
mocker.patch.object(agent.ssh_executor, 'get_connection_stats', mock_get_stats)
health = await agent.health_check()
assert health["agent_id"] == "test-host-gemini"
assert health["ssh_healthy"] is True
assert health["cli_healthy"] is True
assert health["response_time"] > 0
assert health["active_tasks"] == 0
assert health["max_concurrent"] == 2
@pytest.mark.asyncio
async def test_health_check_failure(self, agent, mocker):
"""Test health check with failures"""
# Mock SSH connection failure
mock_test_connection = AsyncMock(return_value=False)
mocker.patch.object(agent.ssh_executor, 'test_connection', mock_test_connection)
health = await agent.health_check()
assert health["ssh_healthy"] is False
assert health["cli_healthy"] is False
@pytest.mark.asyncio
async def test_task_status_tracking(self, agent, mocker):
"""Test task status tracking"""
# Mock SSH execution
mock_ssh_result = SSHResult(
stdout="result\n",
stderr="",
returncode=0,
duration=1.0,
host="test-host",
command="test-command"
)
mock_execute = AsyncMock(return_value=mock_ssh_result)
mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)
task_request = TaskRequest(prompt="Test", task_id="status-test")
# Execute task
result = await agent.execute_task(task_request)
# Check task in history
status = await agent.get_task_status("status-test")
assert status is not None
assert status.status == TaskStatus.COMPLETED
assert status.task_id == "status-test"
# Check non-existent task
status = await agent.get_task_status("non-existent")
assert status is None
def test_statistics(self, agent):
"""Test statistics tracking"""
stats = agent.get_statistics()
assert stats["agent_id"] == "test-host-gemini"
assert stats["host"] == "test-host"
assert stats["specialization"] == "test_specialty"
assert stats["model"] == "gemini-2.5-pro"
assert stats["stats"]["total_tasks"] == 0
assert stats["active_tasks"] == 0
@pytest.mark.asyncio
async def test_task_cancellation(self, agent, mocker):
"""Test task cancellation"""
# Mock a long-running SSH execution
async def long_execute(*args, **kwargs):
await asyncio.sleep(10) # Long execution
return SSHResult("", "", 0, 10.0, "test-host", "cmd")
mock_execute = AsyncMock(side_effect=long_execute)
mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)
task_request = TaskRequest(prompt="Long task", task_id="cancel-test")
# Start task
task_coro = agent.execute_task(task_request)
# Let it start
await asyncio.sleep(0.01)
# Cancel it
cancelled = await agent.cancel_task("cancel-test")
assert cancelled is True
# The task should be cancelled
try:
await task_coro
except asyncio.CancelledError:
pass # Expected
@pytest.mark.asyncio
async def test_cleanup(self, agent, mocker):
"""Test agent cleanup"""
# Mock SSH executor cleanup
mock_cleanup = AsyncMock()
mocker.patch.object(agent.ssh_executor, 'cleanup', mock_cleanup)
await agent.cleanup()
mock_cleanup.assert_called_once()
class TestTaskRequest:
def test_task_request_auto_id(self):
"""Test automatic task ID generation"""
request = TaskRequest(prompt="Test prompt")
assert request.task_id is not None
assert len(request.task_id) == 12 # MD5 hash truncated to 12 chars
def test_task_request_custom_id(self):
"""Test custom task ID"""
request = TaskRequest(prompt="Test", task_id="custom-123")
assert request.task_id == "custom-123"
class TestTaskResult:
def test_task_result_to_dict(self):
"""Test TaskResult serialization"""
result = TaskResult(
task_id="test-123",
status=TaskStatus.COMPLETED,
response="Test response",
execution_time=1.5,
model="gemini-2.5-pro",
agent_id="test-agent"
)
result_dict = result.to_dict()
assert result_dict["task_id"] == "test-123"
assert result_dict["status"] == "completed"
assert result_dict["response"] == "Test response"
assert result_dict["execution_time"] == 1.5
assert result_dict["model"] == "gemini-2.5-pro"
assert result_dict["agent_id"] == "test-agent"