Files
hive/src/executors/ssh_executor.py
anthonyrawlins 6933a6ccb1 Add CCLI (CLI agent integration) complete implementation
- Complete Gemini CLI agent adapter with SSH execution
- CLI agent factory with connection pooling
- SSH executor with AsyncSSH for remote CLI execution
- Backend integration with CLI agent manager
- MCP server updates with CLI agent tools
- Frontend UI updates for mixed agent types
- Database migrations for CLI agent support
- Docker deployment with CLI source integration
- Comprehensive documentation and testing

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-10 12:45:43 +10:00

221 lines
8.0 KiB
Python

"""
SSH Executor for CCLI
Handles SSH connections, command execution, and connection pooling for CLI agents.
"""
import asyncio
import asyncssh
import time
import logging
from dataclasses import dataclass
from typing import Optional, Dict, Any
from contextlib import asynccontextmanager
@dataclass(init=True, repr=True, eq=True)
class SSHResult:
    """Outcome of running one command on a remote host over SSH.

    Produced by ``SSHExecutor`` from the completed remote process. The
    executor runs commands with ``check=False``, so a non-zero exit status
    is reported in ``returncode`` rather than raised.
    """

    stdout: str  # captured standard output of the remote command
    stderr: str  # captured standard error of the remote command
    returncode: int  # remote exit status (taken from the process' exit_status)
    duration: float  # wall-clock seconds, including acquiring the connection
    host: str  # host the command was run on
    command: str  # the command string exactly as submitted
@dataclass(init=True, repr=True, eq=True)
class SSHConfig:
    """Connection and execution settings for a single SSH target.

    Read by ``SSHConnectionPool.get_connection`` (host, username,
    connect_timeout, known_hosts) and by ``SSHExecutor`` (command_timeout,
    max_retries).
    """

    host: str  # hostname or address handed to asyncssh.connect
    # Login user. NOTE(review): defaults to a hard-coded personal account
    # name; consider sourcing this from configuration.
    username: str = "tony"
    connect_timeout: int = 5  # seconds allowed to establish the connection
    command_timeout: int = 30  # seconds allowed for one remote command
    max_retries: int = 2  # extra attempts after the first failure (3 total by default)
    # Forwarded verbatim as asyncssh.connect(known_hosts=...).
    # NOTE(review): per the asyncssh docs an explicit None disables server
    # host-key validation (MITM risk); omitting the argument would instead
    # check ~/.ssh/known_hosts. Confirm None is really intended here.
    known_hosts: Optional[str] = None
class SSHConnectionPool:
    """Cache one reusable SSH connection per ``username@host``.

    Connections live in ``self.connections`` keyed by ``"user@host"``. A
    cached connection is reused until it is closed or older than
    ``persist_timeout`` seconds; age is measured from creation, not from
    last use.

    Args:
        pool_size: Stored for reference only. This class keeps at most one
            connection per ``user@host`` and does not use this value.
        persist_timeout: Maximum age in seconds of a cached connection before
            it is closed and replaced.
    """

    def __init__(self, pool_size: int = 3, persist_timeout: int = 60):
        self.pool_size = pool_size
        self.persist_timeout = persist_timeout
        # host_key -> {'connection': conn, 'created': epoch secs, 'uses': int}.
        # 'uses' is never incremented here; callers (e.g. SSHExecutor) bump it.
        self.connections: Dict[str, Dict[str, Any]] = {}
        # Per-host locks that serialise connection setup. Without them, two
        # concurrent callers for the same host would each connect, and the
        # first connection would be overwritten in self.connections and leaked.
        # Locks are created lazily so they are made inside the running loop.
        self._locks: Dict[str, asyncio.Lock] = {}
        self.logger = logging.getLogger(__name__)

    def _lock_for(self, host_key: str) -> asyncio.Lock:
        """Return (creating if needed) the setup lock for ``host_key``."""
        lock = self._locks.get(host_key)
        if lock is None:
            lock = self._locks[host_key] = asyncio.Lock()
        return lock

    async def get_connection(self, config: SSHConfig) -> asyncssh.SSHClientConnection:
        """Return a live pooled connection for ``config``, connecting if needed.

        A cached connection is reused only if it is not closed and is younger
        than ``persist_timeout``. Otherwise it is closed and a new one is
        opened with ``asyncssh.connect``. Setup for a given ``user@host`` is
        serialised, so concurrent callers end up sharing one connection.

        Raises:
            Whatever ``asyncssh.connect`` raises (e.g. ``asyncssh.Error``,
            ``OSError`` or a timeout) is propagated unchanged.
        """
        host_key = f"{config.username}@{config.host}"
        async with self._lock_for(host_key):
            conn_info = self.connections.get(host_key)
            if conn_info is not None:
                connection = conn_info['connection']
                # Reuse only if still open and not past its lifetime.
                if (not connection.is_closed() and
                        time.time() - conn_info['created'] < self.persist_timeout):
                    self.logger.debug(f"Reusing connection to {host_key}")
                    return connection
                # Connection expired or closed, remove it
                self.logger.debug(f"Connection to {host_key} expired, creating new one")
                await self._close_connection(host_key)

            self.logger.debug(f"Creating new SSH connection to {host_key}")
            connection = await asyncssh.connect(
                config.host,
                username=config.username,
                connect_timeout=config.connect_timeout,
                known_hosts=config.known_hosts,
            )
            self.connections[host_key] = {
                'connection': connection,
                'created': time.time(),
                'uses': 0,
            }
            return connection

    async def _close_connection(self, host_key: str):
        """Remove ``host_key`` from the pool and close its connection.

        The entry is popped before the close is awaited. This means a
        concurrent close of the same key is a no-op instead of a KeyError,
        and a newer entry stored during the await is never deleted. Errors
        raised while closing are logged and suppressed. This method does not
        take the per-host lock, because ``get_connection`` calls it while
        already holding that lock.
        """
        conn_info = self.connections.pop(host_key, None)
        if conn_info is None:
            return
        connection = conn_info['connection']
        try:
            if not connection.is_closed():
                connection.close()
                await connection.wait_closed()
        except Exception as e:
            self.logger.warning(f"Error closing connection to {host_key}: {e}")

    async def close_all(self):
        """Close every pooled connection and empty the pool."""
        # Snapshot the keys: _close_connection mutates self.connections.
        for host_key in list(self.connections):
            await self._close_connection(host_key)
class SSHExecutor:
    """Run shell commands on remote hosts over pooled SSH connections.

    Connections come from an internal ``SSHConnectionPool``. Transport
    failures are retried according to ``SSHConfig.max_retries``.

    Args:
        pool_size: Forwarded to ``SSHConnectionPool``.
        persist_timeout: Forwarded to ``SSHConnectionPool`` (maximum age in
            seconds of a reused connection).
    """

    def __init__(self, pool_size: int = 3, persist_timeout: int = 60):
        self.pool = SSHConnectionPool(pool_size, persist_timeout)
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _host_key(config: SSHConfig) -> str:
        """Return the pool key for ``config``; must match the pool's ``user@host`` format."""
        return f"{config.username}@{config.host}"

    async def execute(self, config: SSHConfig, command: str, **kwargs) -> SSHResult:
        """Execute ``command`` on ``config.host``, retrying transport failures.

        Makes up to ``config.max_retries + 1`` attempts, and always at least
        one even if ``max_retries`` is negative. Only ``asyncssh.Error``,
        ``asyncio.TimeoutError`` and ``OSError`` trigger a retry. After each
        such failure the pooled connection for this host is discarded, so
        neither the next attempt nor a later call reuses it. Extra ``kwargs``
        are forwarded to ``SSHClientConnection.run``.

        Returns:
            The ``SSHResult`` of the first attempt that completes. A non-zero
            remote exit status is not treated as a failure.

        Raises:
            Exception: if every attempt fails. The last underlying error is
                chained as ``__cause__``.
            Exception: the "SSH command timeout ..." error from
                ``_execute_once``, which is not retried.
        """
        attempts = max(1, config.max_retries + 1)
        for attempt in range(attempts):
            try:
                return await self._execute_once(config, command, **kwargs)
            except (asyncssh.Error, asyncio.TimeoutError, OSError) as e:
                self.logger.warning(f"SSH execution attempt {attempt + 1} failed for {config.host}: {e}")
                # Drop the possibly broken connection so it is not reused,
                # whether by the next attempt or by a later call.
                await self.pool._close_connection(self._host_key(config))
                if attempt + 1 >= attempts:
                    # Final attempt failed
                    raise Exception(f"SSH execution failed after {attempts} attempts: {e}") from e
                await asyncio.sleep(1)  # Brief delay before retry

    async def _execute_once(self, config: SSHConfig, command: str, **kwargs) -> SSHResult:
        """Run ``command`` once on ``config.host``, without retries.

        ``duration`` in the result covers both acquiring the connection and
        running the command.

        Raises:
            Exception: "SSH command timeout ..." if the command runs longer
                than ``config.command_timeout``. Because this is a plain
                ``Exception``, ``execute`` does not retry it.
            Errors from connection setup (including a connect timeout) or
            from the remote run propagate unchanged, and ``execute`` may
            retry them.
        """
        start_time = time.time()
        try:
            # Acquire the connection outside the command-timeout handler, so a
            # timeout while connecting is not mislabelled as a command timeout.
            connection = await self.pool.get_connection(config)
            try:
                result = await asyncio.wait_for(
                    connection.run(command, check=False, **kwargs),
                    timeout=config.command_timeout,
                )
            except asyncio.TimeoutError as e:
                raise Exception(f"SSH command timeout after {config.command_timeout}s on {config.host}") from e
        except Exception as e:
            self.logger.error(f"SSH execution error on {config.host}: {e}")
            raise

        duration = time.time() - start_time

        # Update connection usage stats
        conn_info = self.pool.connections.get(self._host_key(config))
        if conn_info is not None:
            conn_info['uses'] += 1

        return SSHResult(
            stdout=result.stdout,
            stderr=result.stderr,
            returncode=result.exit_status,
            duration=duration,
            host=config.host,
            command=command,
        )

    async def test_connection(self, config: SSHConfig) -> bool:
        """Return True if ``echo 'connection_test'`` succeeds on the host.

        Runs through ``execute``, so retries apply. Any exception is logged
        and reported as False.
        """
        try:
            result = await self.execute(config, "echo 'connection_test'")
            return result.returncode == 0 and "connection_test" in result.stdout
        except Exception as e:
            self.logger.error(f"Connection test failed for {config.host}: {e}")
            return False

    async def get_connection_stats(self) -> Dict[str, Any]:
        """Return a snapshot of the pool: count plus per-connection details.

        Each entry, keyed by ``user@host``, reports its creation time, age in
        seconds, use count and whether it is closed.
        """
        stats = {
            "total_connections": len(self.pool.connections),
            "connections": {},
        }
        now = time.time()
        for host_key, conn_info in self.pool.connections.items():
            stats["connections"][host_key] = {
                "created": conn_info["created"],
                "age_seconds": now - conn_info["created"],
                "uses": conn_info["uses"],
                "is_closed": conn_info["connection"].is_closed(),
            }
        return stats

    async def cleanup(self):
        """Close all pooled connections and release resources."""
        await self.pool.close_all()

    @asynccontextmanager
    async def connection_context(self, config: SSHConfig):
        """Yield a pooled connection for ``config``.

        The connection is not closed on exit; it stays in the pool for reuse.
        Errors raised while acquiring the connection or inside the ``with``
        body are logged and re-raised.
        """
        try:
            connection = await self.pool.get_connection(config)
            yield connection
        except Exception as e:
            self.logger.error(f"SSH connection context error: {e}")
            raise
        # Connection stays in pool for reuse
# Module-level convenience functions
# Process-wide executor, created lazily on first use by get_default_executor().
_default_executor = None


def get_default_executor() -> SSHExecutor:
    """Return the shared module-level ``SSHExecutor``, creating it on first call.

    Every later call returns the same instance, and therefore the same
    connection pool.
    """
    global _default_executor
    if _default_executor is not None:
        return _default_executor
    _default_executor = SSHExecutor()
    return _default_executor
async def execute_ssh_command(host: str, command: str, **kwargs) -> SSHResult:
    """Run ``command`` on ``host`` through the shared default executor.

    Only ``host`` is set on the ``SSHConfig``. Every other setting (username,
    timeouts, retries, known_hosts) uses its dataclass default. Extra keyword
    arguments are forwarded to ``SSHExecutor.execute``.
    """
    return await get_default_executor().execute(SSHConfig(host=host), command, **kwargs)