"""
|
|
Unit tests for GeminiCliAgent
|
|
"""
|
|
|
|
import pytest
|
|
import asyncio
|
|
from unittest.mock import Mock, AsyncMock, patch
|
|
from dataclasses import dataclass
|
|
|
|
import sys
|
|
import os
|
|
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
|
|
|
from agents.gemini_cli_agent import (
|
|
GeminiCliAgent, GeminiCliConfig, TaskRequest, TaskResult, TaskStatus
|
|
)
|
|
from executors.ssh_executor import SSHResult
|
|
|
|
|
|
class TestGeminiCliAgent:
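    """Tests for GeminiCliAgent configuration, CLI command building, task
    execution via a mocked SSH executor, health checks, statistics, and cleanup."""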

    @pytest.fixture
    def agent_config(self):
        return GeminiCliConfig(
            host="test-host",
            node_version="v22.14.0",
            model="gemini-2.5-pro",
            max_concurrent=2,
            command_timeout=30
        )

    @pytest.fixture
    def agent(self, agent_config):
        return GeminiCliAgent(agent_config, "test_specialty")

    @pytest.fixture
    def task_request(self):
        return TaskRequest(
            prompt="What is 2+2?",
            task_id="test-task-123"
        )

    def test_agent_initialization(self, agent_config):
        """Test agent initialization with proper configuration"""
        agent = GeminiCliAgent(agent_config, "general_ai")

        assert agent.config.host == "test-host"
        assert agent.config.node_version == "v22.14.0"
        assert agent.specialization == "general_ai"
        assert agent.agent_id == "test-host-gemini"
        assert len(agent.active_tasks) == 0
        assert agent.stats["total_tasks"] == 0

    def test_config_auto_paths(self):
        """Test automatic path generation in config"""
        config = GeminiCliConfig(
            host="walnut",
            node_version="v22.14.0"
        )

        expected_node_path = "/home/tony/.nvm/versions/node/v22.14.0/bin/node"
        expected_gemini_path = "/home/tony/.nvm/versions/node/v22.14.0/bin/gemini"

        assert config.node_path == expected_node_path
        assert config.gemini_path == expected_gemini_path

    def test_build_cli_command(self, agent):
        """Test CLI command building"""
        prompt = "What is Python?"
        model = "gemini-2.5-pro"

        command = agent._build_cli_command(prompt, model)

        assert "source ~/.nvm/nvm.sh" in command
        assert "nvm use v22.14.0" in command
        assert "gemini --model gemini-2.5-pro" in command
        assert "What is Python?" in command

    def test_build_cli_command_escaping(self, agent):
        """Test CLI command with special characters"""
        prompt = "What's the meaning of 'life'?"
        model = "gemini-2.5-pro"

        command = agent._build_cli_command(prompt, model)

        # Should properly escape single quotes
        assert "What\\'s the meaning of \\'life\\'?" in command

    def test_clean_response(self, agent):
        """Test response cleaning"""
        raw_output = """Now using node v22.14.0 (npm v11.3.0)
MCP STDERR (hive): Warning message

This is the actual response
from Gemini CLI

"""

        cleaned = agent._clean_response(raw_output)
        expected = "This is the actual response\nfrom Gemini CLI"

        assert cleaned == expected

    @pytest.mark.asyncio
    async def test_execute_task_success(self, agent, task_request, mocker):
        """Test successful task execution"""
        # Mock SSH executor
        mock_ssh_result = SSHResult(
            stdout="Now using node v22.14.0\n4\n",
            stderr="",
            returncode=0,
            duration=1.5,
            host="test-host",
            command="test-command"
        )

        mock_execute = AsyncMock(return_value=mock_ssh_result)
        mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)

        result = await agent.execute_task(task_request)

        assert result.status == TaskStatus.COMPLETED
        assert result.task_id == "test-task-123"
        assert result.response == "4"
        assert result.execution_time > 0
        assert result.model == "gemini-2.5-pro"
        assert result.agent_id == "test-host-gemini"

        # Check statistics update
        assert agent.stats["successful_tasks"] == 1
        assert agent.stats["total_tasks"] == 1

    @pytest.mark.asyncio
    async def test_execute_task_failure(self, agent, task_request, mocker):
        """Test task execution failure handling"""
        mock_ssh_result = SSHResult(
            stdout="",
            stderr="Command failed: invalid model",
            returncode=1,
            duration=0.5,
            host="test-host",
            command="test-command"
        )

        mock_execute = AsyncMock(return_value=mock_ssh_result)
        mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)

        result = await agent.execute_task(task_request)

        assert result.status == TaskStatus.FAILED
        assert "CLI execution failed" in result.error
        assert result.execution_time > 0

        # Check statistics update
        assert agent.stats["failed_tasks"] == 1
        assert agent.stats["total_tasks"] == 1

    @pytest.mark.asyncio
    async def test_execute_task_exception(self, agent, task_request, mocker):
        """Test task execution with exception"""
        mock_execute = AsyncMock(side_effect=Exception("SSH connection failed"))
        mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)

        result = await agent.execute_task(task_request)

        assert result.status == TaskStatus.FAILED
        assert "SSH connection failed" in result.error
        assert result.execution_time > 0

        # Check statistics update
        assert agent.stats["failed_tasks"] == 1
        assert agent.stats["total_tasks"] == 1

    @pytest.mark.asyncio
    async def test_concurrent_task_limit(self, agent, mocker):
        """Test concurrent task execution limits"""
        # Mock a slow SSH execution
        slow_ssh_result = SSHResult(
            stdout="result",
            stderr="",
            returncode=0,
            duration=2.0,
            host="test-host",
            command="test-command"
        )

        async def slow_execute(*args, **kwargs):
            await asyncio.sleep(0.1)  # Simulate slow execution
            return slow_ssh_result

        mock_execute = AsyncMock(side_effect=slow_execute)
        mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)

        # Start maximum concurrent tasks
        task1 = TaskRequest(prompt="Task 1", task_id="task-1")
        task2 = TaskRequest(prompt="Task 2", task_id="task-2")
        task3 = TaskRequest(prompt="Task 3", task_id="task-3")

        # Schedule the first two tasks so they actually start running;
        # a bare coroutine does not execute until it is awaited.
        result1_task = asyncio.create_task(agent.execute_task(task1))
        result2_task = asyncio.create_task(agent.execute_task(task2))

        # Give tasks time to start
        await asyncio.sleep(0.01)

        # Third task should fail due to limit
        result3 = await agent.execute_task(task3)
        assert result3.status == TaskStatus.FAILED
        assert "maximum concurrent tasks" in result3.error

        # Wait for first two to complete
        result1 = await result1_task
        result2 = await result2_task

        assert result1.status == TaskStatus.COMPLETED
        assert result2.status == TaskStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_health_check_success(self, agent, mocker):
        """Test successful health check"""
        # Mock SSH connection test
        mock_test_connection = AsyncMock(return_value=True)
        mocker.patch.object(agent.ssh_executor, 'test_connection', mock_test_connection)

        # Mock successful CLI execution
        mock_ssh_result = SSHResult(
            stdout="health check ok\n",
            stderr="",
            returncode=0,
            duration=1.0,
            host="test-host",
            command="test-command"
        )
        mock_execute = AsyncMock(return_value=mock_ssh_result)
        mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)

        # Mock connection stats
        mock_get_stats = AsyncMock(return_value={"total_connections": 1})
        mocker.patch.object(agent.ssh_executor, 'get_connection_stats', mock_get_stats)

        health = await agent.health_check()

        assert health["agent_id"] == "test-host-gemini"
        assert health["ssh_healthy"] is True
        assert health["cli_healthy"] is True
        assert health["response_time"] > 0
        assert health["active_tasks"] == 0
        assert health["max_concurrent"] == 2

    @pytest.mark.asyncio
    async def test_health_check_failure(self, agent, mocker):
        """Test health check with failures"""
        # Mock SSH connection failure
        mock_test_connection = AsyncMock(return_value=False)
        mocker.patch.object(agent.ssh_executor, 'test_connection', mock_test_connection)

        health = await agent.health_check()

        assert health["ssh_healthy"] is False
        assert health["cli_healthy"] is False

    @pytest.mark.asyncio
    async def test_task_status_tracking(self, agent, mocker):
        """Test task status tracking"""
        # Mock SSH execution
        mock_ssh_result = SSHResult(
            stdout="result\n",
            stderr="",
            returncode=0,
            duration=1.0,
            host="test-host",
            command="test-command"
        )
        mock_execute = AsyncMock(return_value=mock_ssh_result)
        mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)

        task_request = TaskRequest(prompt="Test", task_id="status-test")

        # Execute task
        result = await agent.execute_task(task_request)

        # Check task in history
        status = await agent.get_task_status("status-test")
        assert status is not None
        assert status.status == TaskStatus.COMPLETED
        assert status.task_id == "status-test"

        # Check non-existent task
        status = await agent.get_task_status("non-existent")
        assert status is None

    def test_statistics(self, agent):
        """Test statistics tracking"""
        stats = agent.get_statistics()

        assert stats["agent_id"] == "test-host-gemini"
        assert stats["host"] == "test-host"
        assert stats["specialization"] == "test_specialty"
        assert stats["model"] == "gemini-2.5-pro"
        assert stats["stats"]["total_tasks"] == 0
        assert stats["active_tasks"] == 0

    @pytest.mark.asyncio
    async def test_task_cancellation(self, agent, mocker):
        """Test task cancellation"""
        # Mock a long-running SSH execution
        async def long_execute(*args, **kwargs):
            await asyncio.sleep(10)  # Long execution
            return SSHResult("", "", 0, 10.0, "test-host", "cmd")

        mock_execute = AsyncMock(side_effect=long_execute)
        mocker.patch.object(agent.ssh_executor, 'execute', mock_execute)

        task_request = TaskRequest(prompt="Long task", task_id="cancel-test")

        # Start the task as a background asyncio task so it begins executing;
        # a bare coroutine would not run until awaited.
        exec_task = asyncio.create_task(agent.execute_task(task_request))

        # Let it start
        await asyncio.sleep(0.01)

        # Cancel it
        cancelled = await agent.cancel_task("cancel-test")
        assert cancelled is True

        # The task should be cancelled
        try:
            await exec_task
        except asyncio.CancelledError:
            pass  # Expected

    @pytest.mark.asyncio
    async def test_cleanup(self, agent, mocker):
        """Test agent cleanup"""
        # Mock SSH executor cleanup
        mock_cleanup = AsyncMock()
        mocker.patch.object(agent.ssh_executor, 'cleanup', mock_cleanup)

        await agent.cleanup()

        mock_cleanup.assert_called_once()


class TestTaskRequest:
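    """Tests for TaskRequest task ID generation."""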

    def test_task_request_auto_id(self):
        """Test automatic task ID generation"""
        request = TaskRequest(prompt="Test prompt")

        assert request.task_id is not None
        assert len(request.task_id) == 12  # MD5 hash truncated to 12 chars

    def test_task_request_custom_id(self):
        """Test custom task ID"""
        request = TaskRequest(prompt="Test", task_id="custom-123")

        assert request.task_id == "custom-123"


class TestTaskResult:
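    """Tests for TaskResult serialization."""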

    def test_task_result_to_dict(self):
        """Test TaskResult serialization"""
        result = TaskResult(
            task_id="test-123",
            status=TaskStatus.COMPLETED,
            response="Test response",
            execution_time=1.5,
            model="gemini-2.5-pro",
            agent_id="test-agent"
        )

        result_dict = result.to_dict()

        assert result_dict["task_id"] == "test-123"
        assert result_dict["status"] == "completed"
        assert result_dict["response"] == "Test response"
        assert result_dict["execution_time"] == 1.5
        assert result_dict["model"] == "gemini-2.5-pro"
        assert result_dict["agent_id"] == "test-agent"