 8b32d54e79
			
		
	
	8b32d54e79
	
	
	
		
			
			🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
		
			
				
	
	
		
			221 lines
		
	
	
		
			8.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			221 lines
		
	
	
		
			8.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| SSH Executor for CCLI
 | |
| Handles SSH connections, command execution, and connection pooling for CLI agents.
 | |
| """
 | |
| 
 | |
| import asyncio
 | |
| import asyncssh
 | |
| import time
 | |
| import logging
 | |
| from dataclasses import dataclass
 | |
| from typing import Optional, Dict, Any
 | |
| from contextlib import asynccontextmanager
 | |
| 
 | |
| 
 | |
@dataclass
class SSHResult:
    """Outcome of one SSH command run by SSHExecutor.

    A non-zero ``returncode`` is reported here rather than raised, because the
    executor runs commands with ``check=False``.
    """
    stdout: str  # captured standard output of the remote command
    stderr: str  # captured standard error of the remote command
    returncode: int  # remote exit status (copied from the asyncssh result's exit_status)
    duration: float  # elapsed seconds measured by the executor around the run
    host: str  # the SSHConfig.host the command ran on
    command: str  # the exact command string that was executed
 | |
| 
 | |
| 
 | |
@dataclass
class SSHConfig:
    """Connection and retry settings for a single remote host.

    ``host``, ``username``, ``connect_timeout`` and ``known_hosts`` are used when
    opening a connection. ``command_timeout`` and ``max_retries`` are used by
    SSHExecutor when running commands.
    """
    host: str  # remote host passed as the first argument to asyncssh.connect
    # NOTE(review): a personal username is hard-coded as the default; consider
    # making it configurable (e.g. from the environment) — confirm with callers.
    username: str = "tony"
    connect_timeout: int = 5  # forwarded to asyncssh.connect(connect_timeout=...)
    command_timeout: int = 30  # seconds a single command may run (asyncio.wait_for timeout)
    max_retries: int = 2  # extra attempts after the first; execute() tries max_retries + 1 times
    # Forwarded unchanged to asyncssh.connect(known_hosts=...).
    # NOTE(review): asyncssh documents an explicit known_hosts=None as disabling
    # server host key validation — confirm this default is intended.
    known_hosts: Optional[str] = None
 | |
| 
 | |
| 
 | |
class SSHConnectionPool:
    """Cache of SSH connections keyed by ``"username@host"``.

    At most one connection is kept per user/host pair. It is reused until it is
    closed or becomes ``persist_timeout`` seconds old (measured from creation),
    after which it is closed and replaced.
    """

    def __init__(self, pool_size: int = 3, persist_timeout: int = 60):
        # NOTE(review): pool_size is stored but never read by this class; each
        # user@host key holds at most one connection regardless of its value.
        self.pool_size = pool_size
        self.persist_timeout = persist_timeout
        # host_key -> {'connection': conn, 'created': epoch seconds, 'uses': int}
        self.connections: Dict[str, Dict[str, Any]] = {}
        # One lock per host_key. Concurrent callers for the same host wait for
        # a single connect instead of each opening their own, which would leak
        # every connection but the last one stored.
        self._locks: Dict[str, asyncio.Lock] = {}
        self.logger = logging.getLogger(__name__)

    def _lock_for(self, host_key: str) -> asyncio.Lock:
        """Return (creating on first use) the lock guarding *host_key*'s pool slot."""
        lock = self._locks.get(host_key)
        if lock is None:
            lock = self._locks[host_key] = asyncio.Lock()
        return lock

    async def get_connection(self, config: "SSHConfig") -> "asyncssh.SSHClientConnection":
        """Return a pooled connection for *config*, opening a new one if needed.

        Args:
            config: supplies ``host``, ``username``, ``connect_timeout`` and
                ``known_hosts``.

        Returns:
            A live connection. The same object is returned to every caller for
            the same user@host until it closes or expires.

        Raises:
            Whatever ``asyncssh.connect`` raises when a new connection fails.
        """
        host_key = f"{config.username}@{config.host}"

        async with self._lock_for(host_key):
            conn_info = self.connections.get(host_key)
            if conn_info is not None:
                connection = conn_info['connection']
                age = time.time() - conn_info['created']
                if not connection.is_closed() and age < self.persist_timeout:
                    self.logger.debug(f"Reusing connection to {host_key}")
                    return connection
                # Connection expired or closed: drop it before reconnecting.
                self.logger.debug(f"Connection to {host_key} expired, creating new one")
                await self._close_connection(host_key)

            self.logger.debug(f"Creating new SSH connection to {host_key}")
            connection = await asyncssh.connect(
                config.host,
                username=config.username,
                connect_timeout=config.connect_timeout,
                # NOTE(review): asyncssh treats an explicit None as "skip host
                # key validation" — confirm that is acceptable for these hosts.
                known_hosts=config.known_hosts,
            )

            self.connections[host_key] = {
                'connection': connection,
                'created': time.time(),
                'uses': 0,
            }
            return connection

    async def _close_connection(self, host_key: str):
        """Remove *host_key* from the pool and close its connection, if present.

        The entry is popped *before* awaiting the close. A replacement connection
        stored under the same key while we wait is therefore never deleted, and a
        concurrent removal cannot cause a KeyError. Errors raised while closing
        are logged as warnings and swallowed: closing is best-effort.
        """
        conn_info = self.connections.pop(host_key, None)
        if conn_info is None:
            return
        connection = conn_info['connection']
        try:
            if not connection.is_closed():
                connection.close()
                await connection.wait_closed()
        except Exception as e:
            self.logger.warning(f"Error closing connection to {host_key}: {e}")

    async def close_all(self):
        """Close every pooled connection, leaving the pool empty."""
        # Snapshot the keys: _close_connection mutates the dict.
        for host_key in list(self.connections):
            await self._close_connection(host_key)
 | |
| 
 | |
| 
 | |
class SSHExecutor:
    """Runs shell commands over pooled SSH connections, retrying transient failures."""

    def __init__(self, pool_size: int = 3, persist_timeout: int = 60):
        self.pool = SSHConnectionPool(pool_size, persist_timeout)
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _host_key(config: SSHConfig) -> str:
        """Return the ``username@host`` key the pool stores *config*'s connection under."""
        return f"{config.username}@{config.host}"

    async def execute(self, config: SSHConfig, command: str, **kwargs) -> SSHResult:
        """Execute *command* on ``config.host``, retrying retryable failures.

        Makes ``config.max_retries + 1`` attempts (at least one, even if
        ``max_retries`` is negative). A retryable failure is an asyncssh error,
        an asyncio timeout or an OSError; connection timeouts are included. After
        each one, the pooled connection for the host is discarded so the next use
        reconnects. The executor sleeps one second between attempts.

        Args:
            config: target host and timeout/retry settings.
            command: shell command to run remotely.
            **kwargs: forwarded to asyncssh's ``SSHClientConnection.run``.

        Returns:
            SSHResult for the first successful attempt. A non-zero exit status
            counts as success here; inspect ``returncode``.

        Raises:
            Exception: every attempt failed with a retryable error. The last
                underlying error is chained as ``__cause__``.
            Exception: the command exceeded ``config.command_timeout``. This is
                raised by ``_execute_once`` and is not retried.
        """
        attempts = max(config.max_retries, 0) + 1
        for attempt in range(attempts):
            try:
                return await self._execute_once(config, command, **kwargs)
            except (asyncssh.Error, asyncio.TimeoutError, OSError) as e:
                self.logger.warning(f"SSH execution attempt {attempt + 1} failed for {config.host}: {e}")
                # Discard the connection that just failed so it is not reused,
                # including after the final attempt.
                await self.pool._close_connection(self._host_key(config))
                if attempt + 1 < attempts:
                    await asyncio.sleep(1)  # Brief delay before retry
                    continue
                raise Exception(f"SSH execution failed after {attempts} attempts: {e}") from e

    async def _execute_once(self, config: SSHConfig, command: str, **kwargs) -> SSHResult:
        """Run *command* once over a pooled connection and wrap the outcome.

        Errors while obtaining the connection, connect timeouts included, are
        logged and re-raised unchanged so ``execute`` can retry them. Only the
        command itself is bounded by ``config.command_timeout``. Exceeding it
        raises a plain ``Exception``, which is not logged here.
        """
        start_time = time.monotonic()  # monotonic: durations must not jump with clock changes

        try:
            connection = await self.pool.get_connection(config)
        except Exception as e:
            self.logger.error(f"SSH execution error on {config.host}: {e}")
            raise

        try:
            result = await asyncio.wait_for(
                connection.run(command, check=False, **kwargs),
                timeout=config.command_timeout,
            )
        except asyncio.TimeoutError as e:
            raise Exception(f"SSH command timeout after {config.command_timeout}s on {config.host}") from e
        except Exception as e:
            self.logger.error(f"SSH execution error on {config.host}: {e}")
            raise

        duration = time.monotonic() - start_time

        # Count the use only if the pooled entry still holds the connection we
        # ran on; another task may have replaced it in the meantime.
        conn_info = self.pool.connections.get(self._host_key(config))
        if conn_info is not None and conn_info['connection'] is connection:
            conn_info['uses'] += 1

        return SSHResult(
            stdout=result.stdout,
            stderr=result.stderr,
            returncode=result.exit_status,
            duration=duration,
            host=config.host,
            command=command,
        )

    async def test_connection(self, config: SSHConfig) -> bool:
        """Return True if an ``echo`` round-trip to ``config.host`` succeeds.

        Failures are logged and reported as False rather than raised.
        """
        try:
            result = await self.execute(config, "echo 'connection_test'")
            return result.returncode == 0 and "connection_test" in result.stdout
        except Exception as e:
            self.logger.error(f"Connection test failed for {config.host}: {e}")
            return False

    async def get_connection_stats(self) -> Dict[str, Any]:
        """Return a snapshot of the pool: total count plus per-connection age, uses and state."""
        now = time.time()  # 'created' is stored as time.time(), so compare on the same clock
        return {
            "total_connections": len(self.pool.connections),
            "connections": {
                host_key: {
                    "created": conn_info["created"],
                    "age_seconds": now - conn_info["created"],
                    "uses": conn_info["uses"],
                    "is_closed": conn_info["connection"].is_closed(),
                }
                for host_key, conn_info in self.pool.connections.items()
            },
        }

    async def cleanup(self):
        """Close all pooled connections."""
        await self.pool.close_all()

    @asynccontextmanager
    async def connection_context(self, config: SSHConfig):
        """Yield a pooled connection for *config* within an ``async with`` block.

        An exception raised while acquiring the connection, or inside the
        ``async with`` body, is logged and re-raised. The connection is not
        closed on exit.
        """
        try:
            connection = await self.pool.get_connection(config)
            yield connection
        except Exception as e:
            self.logger.error(f"SSH connection context error: {e}")
            raise
        # Connection stays in pool for reuse
 | |
| 
 | |
| 
 | |
# Module-level convenience functions
_default_executor = None


def get_default_executor() -> SSHExecutor:
    """Return the shared, lazily-constructed executor used by the helpers below.

    The first call builds an SSHExecutor with default settings. Later calls
    hand back that same instance.
    """
    global _default_executor
    if _default_executor is not None:
        return _default_executor
    _default_executor = SSHExecutor()
    return _default_executor
 | |
| 
 | |
async def execute_ssh_command(host: str, command: str, **kwargs) -> SSHResult:
    """Run *command* on *host* via the shared executor.

    The host is wrapped in an ``SSHConfig``; every other setting uses that
    class's defaults. Extra keyword arguments are passed through to
    ``SSHExecutor.execute``.
    """
    target = SSHConfig(host=host)
    runner = get_default_executor()
    outcome = await runner.execute(target, command, **kwargs)
    return outcome