| """
 | |
| Unit tests for GeminiCliAgent
 | |
| """
 | |
| 
 | |
| import pytest
 | |
| import asyncio
 | |
| from unittest.mock import Mock, AsyncMock, patch
 | |
| from dataclasses import dataclass
 | |
| 
 | |
| import sys
 | |
| import os
 | |
| sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 | |
| 
 | |
| from agents.gemini_cli_agent import (
 | |
|     GeminiCliAgent, GeminiCliConfig, TaskRequest, TaskResult, TaskStatus
 | |
| )
 | |
| from executors.ssh_executor import SSHResult
 | |
| 
 | |
| 
 | |
| class TestGeminiCliAgent:
 | |
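    """Tests for GeminiCliAgent configuration, CLI command building, task execution, health checks, and cleanup."""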
    @pytest.fixture
    def agent_config(self):
        return GeminiCliConfig(
            host="test-host",
            node_version="v22.14.0",
            model="gemini-2.5-pro",
            max_concurrent=2,
            command_timeout=30
        )

    @pytest.fixture
    def agent(self, agent_config):
        return GeminiCliAgent(agent_config, "test_specialty")

    @pytest.fixture
    def task_request(self):
        return TaskRequest(
            prompt="What is 2+2?",
            task_id="test-task-123"
        )

    def test_agent_initialization(self, agent_config):
        """Test agent initialization with proper configuration"""
        agent = GeminiCliAgent(agent_config, "general_ai")

        assert agent.config.host == "test-host"
        assert agent.config.node_version == "v22.14.0"
        assert agent.specialization == "general_ai"
        assert agent.agent_id == "test-host-gemini"
        assert len(agent.active_tasks) == 0
        assert agent.stats["total_tasks"] == 0

    def test_config_auto_paths(self):
        """Test automatic path generation in config"""
        config = GeminiCliConfig(
            host="walnut",
            node_version="v22.14.0"
        )

        expected_node_path = "/home/tony/.nvm/versions/node/v22.14.0/bin/node"
        expected_gemini_path = "/home/tony/.nvm/versions/node/v22.14.0/bin/gemini"

        assert config.node_path == expected_node_path
        assert config.gemini_path == expected_gemini_path

    def test_build_cli_command(self, agent):
        """Test CLI command building"""
        prompt = "What is Python?"
        model = "gemini-2.5-pro"

        command = agent._build_cli_command(prompt, model)

        assert "source ~/.nvm/nvm.sh" in command
        assert "nvm use v22.14.0" in command
        assert "gemini --model gemini-2.5-pro" in command
        assert "What is Python?" in command

    def test_build_cli_command_escaping(self, agent):
        """Test CLI command with special characters"""
        prompt = "What's the meaning of 'life'?"
        model = "gemini-2.5-pro"

        command = agent._build_cli_command(prompt, model)

        # Should properly escape single quotes
        assert "What\\'s the meaning of \\'life\\'?" in command

    def test_clean_response(self, agent):
        """Test response cleaning"""
        raw_output = """Now using node v22.14.0 (npm v11.3.0)
MCP STDERR (hive): Warning message

This is the actual response
from Gemini CLI

"""

        cleaned = agent._clean_response(raw_output)
        expected = "This is the actual response\nfrom Gemini CLI"

        assert cleaned == expected

    @pytest.mark.asyncio
    async def test_execute_task_success(self, agent, task_request, mocker):
        """Test successful task execution"""
        # Mock SSH executor
        mock_ssh_result = SSHResult(
            stdout="Now using node v22.14.0\n4\n",
            stderr="",
            returncode=0,
            duration=1.5,
            host="test-host",
            command="test-command"
        )

        mock_execute = AsyncMock(return_value=mock_ssh_result)
        mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)

        result = await agent.execute_task(task_request)

        assert result.status == TaskStatus.COMPLETED
        assert result.task_id == "test-task-123"
        assert result.response == "4"
        assert result.execution_time > 0
        assert result.model == "gemini-2.5-pro"
        assert result.agent_id == "test-host-gemini"

        # Check statistics update
        assert agent.stats["successful_tasks"] == 1
        assert agent.stats["total_tasks"] == 1

    @pytest.mark.asyncio
    async def test_execute_task_failure(self, agent, task_request, mocker):
        """Test task execution failure handling"""
        mock_ssh_result = SSHResult(
            stdout="",
            stderr="Command failed: invalid model",
            returncode=1,
            duration=0.5,
            host="test-host",
            command="test-command"
        )

        mock_execute = AsyncMock(return_value=mock_ssh_result)
        mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)

        result = await agent.execute_task(task_request)

        assert result.status == TaskStatus.FAILED
        assert "CLI execution failed" in result.error
        assert result.execution_time > 0

        # Check statistics update
        assert agent.stats["failed_tasks"] == 1
        assert agent.stats["total_tasks"] == 1

    @pytest.mark.asyncio
    async def test_execute_task_exception(self, agent, task_request, mocker):
        """Test task execution with exception"""
        mock_execute = AsyncMock(side_effect=Exception("SSH connection failed"))
        mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)

        result = await agent.execute_task(task_request)

        assert result.status == TaskStatus.FAILED
        assert "SSH connection failed" in result.error
        assert result.execution_time > 0

        # Check statistics update
        assert agent.stats["failed_tasks"] == 1
        assert agent.stats["total_tasks"] == 1

    @pytest.mark.asyncio
    async def test_concurrent_task_limit(self, agent, mocker):
        """Test concurrent task execution limits"""
        # Mock a slow SSH execution
        slow_ssh_result = SSHResult(
            stdout="result",
            stderr="",
            returncode=0,
            duration=2.0,
            host="test-host",
            command="test-command"
        )

        async def slow_execute(*args, **kwargs):
            await asyncio.sleep(0.1)  # Simulate slow execution
            return slow_ssh_result

        mock_execute = AsyncMock(side_effect=slow_execute)
        mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)

        # Start maximum concurrent tasks
        task1 = TaskRequest(prompt="Task 1", task_id="task-1")
        task2 = TaskRequest(prompt="Task 2", task_id="task-2")
        task3 = TaskRequest(prompt="Task 3", task_id="task-3")

        # Start first two tasks (should succeed); wrap them in asyncio tasks so
        # they are actually scheduled -- a bare coroutine does not run until awaited
        result1_task = asyncio.create_task(agent.execute_task(task1))
        result2_task = asyncio.create_task(agent.execute_task(task2))

        # Give tasks time to start
        await asyncio.sleep(0.01)

        # Third task should fail due to limit
        result3 = await agent.execute_task(task3)
        assert result3.status == TaskStatus.FAILED
        assert "maximum concurrent tasks" in result3.error

        # Wait for first two to complete
        result1 = await result1_task
        result2 = await result2_task

        assert result1.status == TaskStatus.COMPLETED
        assert result2.status == TaskStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_health_check_success(self, agent, mocker):
        """Test successful health check"""
        # Mock SSH connection test
        mock_test_connection = AsyncMock(return_value=True)
        mocker.patch.object(agent.ssh_executor, 'test_connection', mock_test_connection)

        # Mock successful CLI execution
        mock_ssh_result = SSHResult(
            stdout="health check ok\n",
            stderr="",
            returncode=0,
            duration=1.0,
            host="test-host",
            command="test-command"
        )
        mock_execute = AsyncMock(return_value=mock_ssh_result)
        mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)

        # Mock connection stats
        mock_get_stats = AsyncMock(return_value={"total_connections": 1})
        mocker.patch.object(agent.ssh_executor, 'get_connection_stats', mock_get_stats)

        health = await agent.health_check()

        assert health["agent_id"] == "test-host-gemini"
        assert health["ssh_healthy"] is True
        assert health["cli_healthy"] is True
        assert health["response_time"] > 0
        assert health["active_tasks"] == 0
        assert health["max_concurrent"] == 2

    @pytest.mark.asyncio
    async def test_health_check_failure(self, agent, mocker):
        """Test health check with failures"""
        # Mock SSH connection failure
        mock_test_connection = AsyncMock(return_value=False)
        mocker.patch.object(agent.ssh_executor, 'test_connection', mock_test_connection)

        health = await agent.health_check()

        assert health["ssh_healthy"] is False
        assert health["cli_healthy"] is False

    @pytest.mark.asyncio
    async def test_task_status_tracking(self, agent, mocker):
        """Test task status tracking"""
        # Mock SSH execution
        mock_ssh_result = SSHResult(
            stdout="result\n",
            stderr="",
            returncode=0,
            duration=1.0,
            host="test-host",
            command="test-command"
        )
        mock_execute = AsyncMock(return_value=mock_ssh_result)
        mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)

        task_request = TaskRequest(prompt="Test", task_id="status-test")

        # Execute task
        result = await agent.execute_task(task_request)

        # Check task in history
        status = await agent.get_task_status("status-test")
        assert status is not None
        assert status.status == TaskStatus.COMPLETED
        assert status.task_id == "status-test"

        # Check non-existent task
        status = await agent.get_task_status("non-existent")
        assert status is None

    def test_statistics(self, agent):
        """Test statistics tracking"""
        stats = agent.get_statistics()

        assert stats["agent_id"] == "test-host-gemini"
        assert stats["host"] == "test-host"
        assert stats["specialization"] == "test_specialty"
        assert stats["model"] == "gemini-2.5-pro"
        assert stats["stats"]["total_tasks"] == 0
        assert stats["active_tasks"] == 0

    @pytest.mark.asyncio
    async def test_task_cancellation(self, agent, mocker):
        """Test task cancellation"""
        # Mock a long-running SSH execution
        async def long_execute(*args, **kwargs):
            await asyncio.sleep(10)  # Long execution
            return SSHResult("", "", 0, 10.0, "test-host", "cmd")

        mock_execute = AsyncMock(side_effect=long_execute)
        mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)

        task_request = TaskRequest(prompt="Long task", task_id="cancel-test")

        # Start the task as a real asyncio task so it begins executing;
        # a bare coroutine would never be scheduled and could not be cancelled
        task = asyncio.create_task(agent.execute_task(task_request))

        # Let it start
        await asyncio.sleep(0.01)

        # Cancel it
        cancelled = await agent.cancel_task("cancel-test")
        assert cancelled is True

        # The task should be cancelled
        try:
            await task
        except asyncio.CancelledError:
            pass  # Expected

    @pytest.mark.asyncio
    async def test_cleanup(self, agent, mocker):
        """Test agent cleanup"""
        # Mock SSH executor cleanup
        mock_cleanup = AsyncMock()
        mocker.patch.object(agent.ssh_executor, 'cleanup', mock_cleanup)

        await agent.cleanup()

        mock_cleanup.assert_called_once()


class TestTaskRequest:
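    """Tests for TaskRequest task ID generation and overrides."""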
    def test_task_request_auto_id(self):
        """Test automatic task ID generation"""
        request = TaskRequest(prompt="Test prompt")

        assert request.task_id is not None
        assert len(request.task_id) == 12  # MD5 hash truncated to 12 chars

    def test_task_request_custom_id(self):
        """Test custom task ID"""
        request = TaskRequest(prompt="Test", task_id="custom-123")

        assert request.task_id == "custom-123"


class TestTaskResult:
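    """Tests for TaskResult serialization."""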
    def test_task_result_to_dict(self):
        """Test TaskResult serialization"""
        result = TaskResult(
            task_id="test-123",
            status=TaskStatus.COMPLETED,
            response="Test response",
            execution_time=1.5,
            model="gemini-2.5-pro",
            agent_id="test-agent"
        )

        result_dict = result.to_dict()

        assert result_dict["task_id"] == "test-123"
        assert result_dict["status"] == "completed"
        assert result_dict["response"] == "Test response"
        assert result_dict["execution_time"] == 1.5
        assert result_dict["model"] == "gemini-2.5-pro"
        assert result_dict["agent_id"] == "test-agent"