 8b32d54e79
			
		
	
	8b32d54e79
	
	
	
		
			
			🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
		
			
				
	
	
		
			369 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			369 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| Gemini CLI Agent Adapter
 | |
| Provides a standardized interface for executing tasks on Gemini CLI via SSH.
 | |
| """
 | |
| 
 | |
| import asyncio
 | |
| import json
 | |
| import time
 | |
| import logging
 | |
| import hashlib
 | |
| from dataclasses import dataclass, asdict
 | |
| from typing import Dict, Any, Optional, List
 | |
| from enum import Enum
 | |
| 
 | |
| from executors.ssh_executor import SSHExecutor, SSHConfig, SSHResult
 | |
| 
 | |
| 
 | |
| class TaskStatus(Enum):
 | |
|     """Task execution status"""
 | |
|     PENDING = "pending"
 | |
|     RUNNING = "running"
 | |
|     COMPLETED = "completed"
 | |
|     FAILED = "failed"
 | |
|     TIMEOUT = "timeout"
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class GeminiCliConfig:
 | |
|     """Configuration for Gemini CLI agent"""
 | |
|     host: str
 | |
|     node_version: str
 | |
|     model: str = "gemini-2.5-pro"
 | |
|     max_concurrent: int = 2
 | |
|     command_timeout: int = 60
 | |
|     ssh_timeout: int = 5
 | |
|     node_path: Optional[str] = None
 | |
|     gemini_path: Optional[str] = None
 | |
|     
 | |
|     def __post_init__(self):
 | |
|         """Auto-generate paths if not provided"""
 | |
|         if self.node_path is None:
 | |
|             self.node_path = f"/home/tony/.nvm/versions/node/{self.node_version}/bin/node"
 | |
|         if self.gemini_path is None:
 | |
|             self.gemini_path = f"/home/tony/.nvm/versions/node/{self.node_version}/bin/gemini"
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class TaskRequest:
 | |
|     """Represents a task to be executed"""
 | |
|     prompt: str
 | |
|     model: Optional[str] = None
 | |
|     task_id: Optional[str] = None
 | |
|     priority: int = 3
 | |
|     metadata: Optional[Dict[str, Any]] = None
 | |
|     
 | |
|     def __post_init__(self):
 | |
|         """Generate task ID if not provided"""
 | |
|         if self.task_id is None:
 | |
|             # Generate a unique task ID based on prompt and timestamp
 | |
|             content = f"{self.prompt}_{time.time()}"
 | |
|             self.task_id = hashlib.md5(content.encode()).hexdigest()[:12]
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class TaskResult:
 | |
|     """Result of a task execution"""
 | |
|     task_id: str
 | |
|     status: TaskStatus
 | |
|     response: Optional[str] = None
 | |
|     error: Optional[str] = None
 | |
|     execution_time: float = 0.0
 | |
|     model: Optional[str] = None
 | |
|     agent_id: Optional[str] = None
 | |
|     metadata: Optional[Dict[str, Any]] = None
 | |
|     
 | |
|     def to_dict(self) -> Dict[str, Any]:
 | |
|         """Convert to dictionary for JSON serialization"""
 | |
|         result = asdict(self)
 | |
|         result['status'] = self.status.value
 | |
|         return result
 | |
| 
 | |
| 
 | |
| class GeminiCliAgent:
 | |
|     """
 | |
|     Adapter for Google Gemini CLI execution via SSH
 | |
|     
 | |
|     Provides a consistent interface for executing AI tasks on remote Gemini CLI installations
 | |
|     while handling SSH connections, environment setup, error recovery, and concurrent execution.
 | |
|     """
 | |
|     
 | |
|     def __init__(self, config: GeminiCliConfig, specialization: str = "general_ai"):
 | |
|         self.config = config
 | |
|         self.specialization = specialization
 | |
|         self.agent_id = f"{config.host}-gemini"
 | |
|         
 | |
|         # SSH configuration
 | |
|         self.ssh_config = SSHConfig(
 | |
|             host=config.host,
 | |
|             connect_timeout=config.ssh_timeout,
 | |
|             command_timeout=config.command_timeout
 | |
|         )
 | |
|         
 | |
|         # SSH executor with connection pooling
 | |
|         self.ssh_executor = SSHExecutor(pool_size=3, persist_timeout=120)
 | |
|         
 | |
|         # Task management
 | |
|         self.active_tasks: Dict[str, asyncio.Task] = {}
 | |
|         self.task_history: List[TaskResult] = []
 | |
|         self.max_history = 100
 | |
|         
 | |
|         # Logging
 | |
|         self.logger = logging.getLogger(f"gemini_cli.{config.host}")
 | |
|         
 | |
|         # Performance tracking
 | |
|         self.stats = {
 | |
|             "total_tasks": 0,
 | |
|             "successful_tasks": 0,
 | |
|             "failed_tasks": 0,
 | |
|             "total_execution_time": 0.0,
 | |
|             "average_execution_time": 0.0
 | |
|         }
 | |
|     
 | |
|     async def execute_task(self, request: TaskRequest) -> TaskResult:
 | |
|         """
 | |
|         Execute a task on the Gemini CLI
 | |
|         
 | |
|         Args:
 | |
|             request: TaskRequest containing prompt and configuration
 | |
|             
 | |
|         Returns:
 | |
|             TaskResult with execution status and response
 | |
|         """
 | |
|         
 | |
|         # Check concurrent task limit
 | |
|         if len(self.active_tasks) >= self.config.max_concurrent:
 | |
|             return TaskResult(
 | |
|                 task_id=request.task_id,
 | |
|                 status=TaskStatus.FAILED,
 | |
|                 error=f"Agent at maximum concurrent tasks ({self.config.max_concurrent})",
 | |
|                 agent_id=self.agent_id
 | |
|             )
 | |
|         
 | |
|         # Start task execution
 | |
|         task = asyncio.create_task(self._execute_task_impl(request))
 | |
|         self.active_tasks[request.task_id] = task
 | |
|         
 | |
|         try:
 | |
|             result = await task
 | |
|             return result
 | |
|         finally:
 | |
|             # Clean up task from active list
 | |
|             self.active_tasks.pop(request.task_id, None)
 | |
|     
 | |
|     async def _execute_task_impl(self, request: TaskRequest) -> TaskResult:
 | |
|         """Internal implementation of task execution"""
 | |
|         start_time = time.time()
 | |
|         model = request.model or self.config.model
 | |
|         
 | |
|         try:
 | |
|             self.logger.info(f"Starting task {request.task_id} with model {model}")
 | |
|             
 | |
|             # Build the CLI command
 | |
|             command = self._build_cli_command(request.prompt, model)
 | |
|             
 | |
|             # Execute via SSH
 | |
|             ssh_result = await self.ssh_executor.execute(self.ssh_config, command)
 | |
|             
 | |
|             execution_time = time.time() - start_time
 | |
|             
 | |
|             # Process result
 | |
|             if ssh_result.returncode == 0:
 | |
|                 result = TaskResult(
 | |
|                     task_id=request.task_id,
 | |
|                     status=TaskStatus.COMPLETED,
 | |
|                     response=self._clean_response(ssh_result.stdout),
 | |
|                     execution_time=execution_time,
 | |
|                     model=model,
 | |
|                     agent_id=self.agent_id,
 | |
|                     metadata={
 | |
|                         "ssh_duration": ssh_result.duration,
 | |
|                         "command": command,
 | |
|                         "stderr": ssh_result.stderr
 | |
|                     }
 | |
|                 )
 | |
|                 self.stats["successful_tasks"] += 1
 | |
|             else:
 | |
|                 result = TaskResult(
 | |
|                     task_id=request.task_id,
 | |
|                     status=TaskStatus.FAILED,
 | |
|                     error=f"CLI execution failed: {ssh_result.stderr}",
 | |
|                     execution_time=execution_time,
 | |
|                     model=model,
 | |
|                     agent_id=self.agent_id,
 | |
|                     metadata={
 | |
|                         "returncode": ssh_result.returncode,
 | |
|                         "command": command,
 | |
|                         "stdout": ssh_result.stdout,
 | |
|                         "stderr": ssh_result.stderr
 | |
|                     }
 | |
|                 )
 | |
|                 self.stats["failed_tasks"] += 1
 | |
|                 
 | |
|         except Exception as e:
 | |
|             execution_time = time.time() - start_time
 | |
|             self.logger.error(f"Task {request.task_id} failed: {e}")
 | |
|             
 | |
|             result = TaskResult(
 | |
|                 task_id=request.task_id,
 | |
|                 status=TaskStatus.FAILED,
 | |
|                 error=str(e),
 | |
|                 execution_time=execution_time,
 | |
|                 model=model,
 | |
|                 agent_id=self.agent_id
 | |
|             )
 | |
|             self.stats["failed_tasks"] += 1
 | |
|         
 | |
|         # Update statistics
 | |
|         self.stats["total_tasks"] += 1
 | |
|         self.stats["total_execution_time"] += execution_time
 | |
|         self.stats["average_execution_time"] = (
 | |
|             self.stats["total_execution_time"] / self.stats["total_tasks"]
 | |
|         )
 | |
|         
 | |
|         # Add to history (with size limit)
 | |
|         self.task_history.append(result)
 | |
|         if len(self.task_history) > self.max_history:
 | |
|             self.task_history.pop(0)
 | |
|         
 | |
|         self.logger.info(f"Task {request.task_id} completed with status {result.status.value}")
 | |
|         return result
 | |
|     
 | |
|     def _build_cli_command(self, prompt: str, model: str) -> str:
 | |
|         """Build the complete CLI command for execution"""
 | |
|         
 | |
|         # Environment setup
 | |
|         env_setup = f"source ~/.nvm/nvm.sh && nvm use {self.config.node_version}"
 | |
|         
 | |
|         # Escape the prompt for shell safety
 | |
|         escaped_prompt = prompt.replace("'", "'\\''")
 | |
|         
 | |
|         # Build gemini command
 | |
|         gemini_cmd = f"echo '{escaped_prompt}' | {self.config.gemini_path} --model {model}"
 | |
|         
 | |
|         # Complete command
 | |
|         full_command = f"{env_setup} && {gemini_cmd}"
 | |
|         
 | |
|         return full_command
 | |
|     
 | |
|     def _clean_response(self, raw_output: str) -> str:
 | |
|         """Clean up the raw CLI output"""
 | |
|         lines = raw_output.strip().split('\n')
 | |
|         
 | |
|         # Remove NVM output lines
 | |
|         cleaned_lines = []
 | |
|         for line in lines:
 | |
|             if not (line.startswith('Now using node') or 
 | |
|                    line.startswith('MCP STDERR') or
 | |
|                    line.strip() == ''):
 | |
|                 cleaned_lines.append(line)
 | |
|         
 | |
|         return '\n'.join(cleaned_lines).strip()
 | |
|     
 | |
|     async def health_check(self) -> Dict[str, Any]:
 | |
|         """Perform a health check on the agent"""
 | |
|         try:
 | |
|             # Test SSH connection
 | |
|             ssh_healthy = await self.ssh_executor.test_connection(self.ssh_config)
 | |
|             
 | |
|             # Test Gemini CLI with a simple prompt
 | |
|             if ssh_healthy:
 | |
|                 test_request = TaskRequest(
 | |
|                     prompt="Say 'health check ok'",
 | |
|                     task_id="health_check"
 | |
|                 )
 | |
|                 result = await self.execute_task(test_request)
 | |
|                 cli_healthy = result.status == TaskStatus.COMPLETED
 | |
|                 response_time = result.execution_time
 | |
|             else:
 | |
|                 cli_healthy = False
 | |
|                 response_time = None
 | |
|             
 | |
|             # Get connection stats
 | |
|             connection_stats = await self.ssh_executor.get_connection_stats()
 | |
|             
 | |
|             return {
 | |
|                 "agent_id": self.agent_id,
 | |
|                 "host": self.config.host,
 | |
|                 "ssh_healthy": ssh_healthy,
 | |
|                 "cli_healthy": cli_healthy,
 | |
|                 "response_time": response_time,
 | |
|                 "active_tasks": len(self.active_tasks),
 | |
|                 "max_concurrent": self.config.max_concurrent,
 | |
|                 "total_tasks": self.stats["total_tasks"],
 | |
|                 "success_rate": (
 | |
|                     self.stats["successful_tasks"] / max(self.stats["total_tasks"], 1)
 | |
|                 ),
 | |
|                 "average_execution_time": self.stats["average_execution_time"],
 | |
|                 "connection_stats": connection_stats,
 | |
|                 "model": self.config.model,
 | |
|                 "specialization": self.specialization
 | |
|             }
 | |
|             
 | |
|         except Exception as e:
 | |
|             self.logger.error(f"Health check failed: {e}")
 | |
|             return {
 | |
|                 "agent_id": self.agent_id,
 | |
|                 "host": self.config.host,
 | |
|                 "ssh_healthy": False,
 | |
|                 "cli_healthy": False,
 | |
|                 "error": str(e)
 | |
|             }
 | |
|     
 | |
|     async def get_task_status(self, task_id: str) -> Optional[TaskResult]:
 | |
|         """Get the status of a specific task"""
 | |
|         # Check active tasks
 | |
|         if task_id in self.active_tasks:
 | |
|             task = self.active_tasks[task_id]
 | |
|             if task.done():
 | |
|                 return task.result()
 | |
|             else:
 | |
|                 return TaskResult(
 | |
|                     task_id=task_id,
 | |
|                     status=TaskStatus.RUNNING,
 | |
|                     agent_id=self.agent_id
 | |
|                 )
 | |
|         
 | |
|         # Check history
 | |
|         for result in reversed(self.task_history):
 | |
|             if result.task_id == task_id:
 | |
|                 return result
 | |
|         
 | |
|         return None
 | |
|     
 | |
|     async def cancel_task(self, task_id: str) -> bool:
 | |
|         """Cancel a running task"""
 | |
|         if task_id in self.active_tasks:
 | |
|             task = self.active_tasks[task_id]
 | |
|             if not task.done():
 | |
|                 task.cancel()
 | |
|                 return True
 | |
|         return False
 | |
|     
 | |
|     def get_statistics(self) -> Dict[str, Any]:
 | |
|         """Get agent performance statistics"""
 | |
|         return {
 | |
|             "agent_id": self.agent_id,
 | |
|             "host": self.config.host,
 | |
|             "specialization": self.specialization,
 | |
|             "model": self.config.model,
 | |
|             "stats": self.stats.copy(),
 | |
|             "active_tasks": len(self.active_tasks),
 | |
|             "history_length": len(self.task_history)
 | |
|         }
 | |
|     
 | |
|     async def cleanup(self):
 | |
|         """Clean up resources"""
 | |
|         # Cancel any active tasks
 | |
|         for task_id, task in list(self.active_tasks.items()):
 | |
|             if not task.done():
 | |
|                 task.cancel()
 | |
|         
 | |
|         # Wait for tasks to complete
 | |
|         if self.active_tasks:
 | |
|             await asyncio.gather(*self.active_tasks.values(), return_exceptions=True)
 | |
|         
 | |
|         # Close SSH connections
 | |
|         await self.ssh_executor.cleanup()
 | |
|         
 | |
|         self.logger.info(f"Agent {self.agent_id} cleaned up successfully") |