🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
221 lines
8.0 KiB
Python
221 lines
8.0 KiB
Python
"""
|
|
SSH Executor for CCLI
|
|
Handles SSH connections, command execution, and connection pooling for CLI agents.
|
|
"""
|
|
|
|
import asyncio
|
|
import asyncssh
|
|
import time
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from typing import Optional, Dict, Any
|
|
from contextlib import asynccontextmanager
|
|
|
|
|
|
@dataclass
class SSHResult:
    """Outcome of running one command on a remote host.

    The field order below is also the positional constructor order.
    """

    stdout: str  # text captured from the command's standard output
    stderr: str  # text captured from the command's standard error
    returncode: int  # exit status reported for the command
    duration: float  # elapsed time of the run, in seconds
    host: str  # host the command was run on
    command: str  # the command string that was executed
|
|
|
|
|
|
@dataclass
class SSHConfig:
    """Settings describing how to reach and use one SSH host.

    Only ``host`` is required; every other field has a default.
    """

    host: str  # remote host to connect to
    username: str = "tony"  # login user; combined with host to form the pool key
    connect_timeout: int = 5  # forwarded to asyncssh.connect as connect_timeout
    command_timeout: int = 30  # seconds allowed for a single command run
    max_retries: int = 2  # additional attempts made after the first failure
    # Forwarded to asyncssh.connect unchanged.
    # NOTE(review): asyncssh documents an explicit known_hosts=None as
    # disabling server host key validation -- confirm this default is intended.
    known_hosts: Optional[str] = None
|
|
|
|
|
|
class SSHConnectionPool:
    """Keep one reusable SSH connection per ``user@host`` key.

    Entries live in ``self.connections`` as dicts with the keys
    ``'connection'``, ``'created'`` (``time.time()`` at creation) and
    ``'uses'``. A pooled connection is reused until it reports itself closed
    or is older than ``persist_timeout`` seconds.

    NOTE: ``pool_size`` is stored but nothing in this class enforces it.
    NOTE: there is no locking. Two tasks asking for the same uncached host at
    the same time will each open a connection, and the later store replaces
    the earlier pool entry.
    """

    def __init__(self, pool_size: int = 3, persist_timeout: int = 60):
        """Create an empty pool.

        Args:
            pool_size: Stored on the instance; currently not enforced.
            persist_timeout: Maximum age in seconds before a pooled
                connection is replaced.
        """
        self.pool_size = pool_size
        self.persist_timeout = persist_timeout
        self.connections: Dict[str, Dict[str, Any]] = {}
        self.logger = logging.getLogger(__name__)

    async def get_connection(self, config: SSHConfig) -> asyncssh.SSHClientConnection:
        """Return a usable connection for ``config``, opening one if needed.

        Args:
            config: Host, username, connect timeout and known_hosts to use.

        Returns:
            The pooled connection when it is still open and younger than
            ``persist_timeout``; otherwise a newly opened connection, which
            is stored in the pool with a fresh ``'created'`` time and
            ``'uses'`` set to 0.

        Raises:
            Whatever ``asyncssh.connect`` raises when a new connection
            cannot be established.
        """
        host_key = f"{config.username}@{config.host}"

        conn_info = self.connections.get(host_key)
        if conn_info is not None:
            connection = conn_info['connection']
            expired = time.time() - conn_info['created'] >= self.persist_timeout

            if not connection.is_closed() and not expired:
                self.logger.debug(f"Reusing connection to {host_key}")
                return connection

            # Closed or too old: drop it before opening a replacement.
            self.logger.debug(f"Connection to {host_key} expired, creating new one")
            await self._close_connection(host_key)

        self.logger.debug(f"Creating new SSH connection to {host_key}")
        connection = await asyncssh.connect(
            config.host,
            username=config.username,
            connect_timeout=config.connect_timeout,
            known_hosts=config.known_hosts,
        )

        self.connections[host_key] = {
            'connection': connection,
            'created': time.time(),
            'uses': 0,
        }

        return connection

    async def _close_connection(self, host_key: str):
        """Remove ``host_key`` from the pool and close its connection.

        Unknown keys are ignored. Errors raised while closing are logged as
        warnings and not propagated.
        """
        # Detach the entry *before* awaiting. Deleting by key after the await
        # could remove an entry another task stored in the meantime, or raise
        # KeyError if another task already removed it.
        conn_info = self.connections.pop(host_key, None)
        if conn_info is None:
            return

        connection = conn_info['connection']
        try:
            if not connection.is_closed():
                connection.close()
                await connection.wait_closed()
        except Exception as e:
            self.logger.warning(f"Error closing connection to {host_key}: {e}")

    async def close_all(self):
        """Close and remove every pooled connection."""
        # Iterate over a snapshot because closing mutates the dict.
        for host_key in list(self.connections):
            await self._close_connection(host_key)
|
|
|
|
|
|
class SSHExecutor:
    """Run commands over SSH through a connection pool, with retries."""

    def __init__(self, pool_size: int = 3, persist_timeout: int = 60):
        """Create an executor backed by its own ``SSHConnectionPool``.

        Args:
            pool_size: Forwarded to the pool.
            persist_timeout: Forwarded to the pool.
        """
        self.pool = SSHConnectionPool(pool_size, persist_timeout)
        self.logger = logging.getLogger(__name__)

    async def execute(self, config: SSHConfig, command: str, **kwargs) -> SSHResult:
        """Run ``command`` on ``config.host``, retrying on transport failures.

        Up to ``config.max_retries + 1`` attempts are made (always at least
        one). After each failed attempt the pooled connection for this host
        is closed so the next attempt, or the next caller, reconnects.

        Args:
            config: Connection settings and retry/timeout limits.
            command: Command string to run remotely.
            **kwargs: Forwarded to the connection's ``run`` call.

        Returns:
            The ``SSHResult`` of the first successful attempt.

        Raises:
            Exception: When every attempt failed with ``asyncssh.Error``,
                ``asyncio.TimeoutError`` or ``OSError``; the last failure is
                attached as ``__cause__``. Any other exception propagates
                from the first attempt that raises it.
        """
        host_key = f"{config.username}@{config.host}"
        # A negative max_retries still gets one attempt instead of falling
        # out of the loop and returning None.
        total_attempts = max(config.max_retries, 0) + 1

        for attempt in range(1, total_attempts + 1):
            try:
                return await self._execute_once(config, command, **kwargs)
            except (asyncssh.Error, asyncio.TimeoutError, OSError) as e:
                self.logger.warning(f"SSH execution attempt {attempt} failed for {config.host}: {e}")

                # Never leave a connection that just failed in the pool.
                await self.pool._close_connection(host_key)

                if attempt == total_attempts:
                    raise Exception(
                        f"SSH execution failed after {total_attempts} attempts: {e}"
                    ) from e

                await asyncio.sleep(1)  # Brief delay before retry

    async def _execute_once(self, config: SSHConfig, command: str, **kwargs) -> SSHResult:
        """Run ``command`` a single time on a pooled connection.

        Raises:
            asyncio.TimeoutError: When the work does not finish within
                ``config.command_timeout`` seconds. It is raised as a
                timeout, not a plain ``Exception``, so that ``execute`` can
                recognise it and apply its retry handling.
            Exception: Any other failure is logged and re-raised unchanged.
        """
        start_time = time.monotonic()

        try:
            connection = await self.pool.get_connection(config)
            result = await asyncio.wait_for(
                connection.run(command, check=False, **kwargs),
                timeout=config.command_timeout,
            )
        except asyncio.TimeoutError as e:
            raise asyncio.TimeoutError(
                f"SSH command timeout after {config.command_timeout}s on {config.host}"
            ) from e
        except Exception as e:
            self.logger.error(f"SSH execution error on {config.host}: {e}")
            raise

        duration = time.monotonic() - start_time

        # Update connection usage stats if the entry is still pooled.
        conn_info = self.pool.connections.get(f"{config.username}@{config.host}")
        if conn_info is not None:
            conn_info['uses'] += 1

        return SSHResult(
            stdout=result.stdout,
            stderr=result.stderr,
            returncode=result.exit_status,
            duration=duration,
            host=config.host,
            command=command,
        )

    async def test_connection(self, config: SSHConfig) -> bool:
        """Return True when a trivial ``echo`` succeeds on the host.

        Any exception is logged and reported as False.
        """
        try:
            result = await self.execute(config, "echo 'connection_test'")
            return result.returncode == 0 and "connection_test" in result.stdout
        except Exception as e:
            self.logger.error(f"Connection test failed for {config.host}: {e}")
            return False

    async def get_connection_stats(self) -> Dict[str, Any]:
        """Summarise the pool.

        Returns:
            A dict with ``"total_connections"`` and ``"connections"``; the
            latter maps each pool key to its ``created`` time, ``age_seconds``,
            ``uses`` count and ``is_closed`` flag.
        """
        now = time.time()
        per_host = {
            host_key: {
                "created": conn_info["created"],
                "age_seconds": now - conn_info["created"],
                "uses": conn_info["uses"],
                "is_closed": conn_info["connection"].is_closed(),
            }
            for host_key, conn_info in self.pool.connections.items()
        }
        return {
            "total_connections": len(per_host),
            "connections": per_host,
        }

    async def cleanup(self):
        """Close all pooled connections."""
        await self.pool.close_all()

    @asynccontextmanager
    async def connection_context(self, config: SSHConfig):
        """Yield a pooled connection for ``config``.

        Errors raised while obtaining the connection or inside the ``with``
        body are logged and re-raised. The connection is not closed on exit;
        it stays in the pool for reuse.
        """
        try:
            connection = await self.pool.get_connection(config)
            yield connection
        except Exception as e:
            self.logger.error(f"SSH connection context error: {e}")
            raise
|
|
|
|
|
|
# Module-level convenience functions

# Shared executor; stays None until get_default_executor() is first called.
_default_executor = None


def get_default_executor() -> SSHExecutor:
    """Return the shared module-level executor, building it on first use."""
    global _default_executor
    executor = _default_executor
    if executor is None:
        # First call: construct with SSHExecutor's default arguments and cache it.
        executor = _default_executor = SSHExecutor()
    return executor
|
|
|
|
async def execute_ssh_command(host: str, command: str, **kwargs) -> SSHResult:
    """Run ``command`` on ``host`` using the shared default executor.

    The connection settings are ``SSHConfig``'s defaults for everything
    except ``host``; ``**kwargs`` are passed through to the executor's
    ``execute`` call.
    """
    ssh_config = SSHConfig(host=host)
    return await get_default_executor().execute(ssh_config, command, **kwargs)