package execution

import (
	"context"
	"testing"
	"time"

	"chorus/pkg/ai"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
)

// MockProvider implements ai.ModelProvider for testing.
type MockProvider struct {
	mock.Mock
}

// ExecuteTask records the call and returns whatever the test configured via
// On("ExecuteTask", ...).Return(...). A nil first return value is passed
// through as a nil response instead of panicking on the type assertion, so
// tests can configure Return(nil, err) to exercise error paths.
func (m *MockProvider) ExecuteTask(ctx context.Context, request *ai.TaskRequest) (*ai.TaskResponse, error) {
	args := m.Called(ctx, request)

	var resp *ai.TaskResponse
	if v := args.Get(0); v != nil {
		resp = v.(*ai.TaskResponse)
	}
	return resp, args.Error(1)
}

// GetCapabilities returns the capabilities configured on the mock.
func (m *MockProvider) GetCapabilities() ai.ProviderCapabilities {
	args := m.Called()
	return args.Get(0).(ai.ProviderCapabilities)
}

// ValidateConfig returns the error configured on the mock.
func (m *MockProvider) ValidateConfig() error {
	args := m.Called()
	return args.Error(0)
}

// GetProviderInfo returns the provider info configured on the mock.
func (m *MockProvider) GetProviderInfo() ai.ProviderInfo {
	args := m.Called()
	return args.Get(0).(ai.ProviderInfo)
}

// MockProviderFactory is a testify mock of the AI provider factory used by
// the engine configuration in these tests.
type MockProviderFactory struct {
	mock.Mock
	provider ai.ModelProvider
	config   ai.ProviderConfig
}

// GetProviderForRole returns the provider, config and error configured for
// role. Nil provider/config return values are tolerated so that tests can
// configure Return(nil, nil, err).
func (m *MockProviderFactory) GetProviderForRole(role string) (ai.ModelProvider, ai.ProviderConfig, error) {
	args := m.Called(role)

	var provider ai.ModelProvider
	if v := args.Get(0); v != nil {
		provider = v.(ai.ModelProvider)
	}
	var cfg ai.ProviderConfig
	if v := args.Get(1); v != nil {
		cfg = v.(ai.ProviderConfig)
	}
	return provider, cfg, args.Error(2)
}

// GetProvider returns the provider and error configured for name. A nil
// provider return value is tolerated so tests can configure Return(nil, err).
func (m *MockProviderFactory) GetProvider(name string) (ai.ModelProvider, error) {
	args := m.Called(name)

	var provider ai.ModelProvider
	if v := args.Get(0); v != nil {
		provider = v.(ai.ModelProvider)
	}
	return provider, args.Error(1)
}

// ListProviders returns the provider names configured on the mock, or nil
// when the test configured Return(nil).
func (m *MockProviderFactory) ListProviders() []string {
	args := m.Called()

	var names []string
	if v := args.Get(0); v != nil {
		names = v.([]string)
	}
	return names
}

// GetHealthStatus returns the health map configured on the mock, or nil when
// the test configured Return(nil).
func (m *MockProviderFactory) GetHealthStatus() map[string]bool {
	args := m.Called()

	var status map[string]bool
	if v := args.Get(0); v != nil {
		status = v.(map[string]bool)
	}
	return status
}

// TestNewTaskExecutionEngine checks that a freshly constructed engine has its
// metrics, active-task registry and logger populated.
func TestNewTaskExecutionEngine(t *testing.T) {
	engine := NewTaskExecutionEngine()

	assert.NotNil(t, engine)
	assert.NotNil(t, engine.metrics)
	assert.NotNil(t, engine.activeTasks)
	assert.NotNil(t, engine.logger)
}

// TestTaskExecutionEngine_Initialize exercises Initialize with invalid and
// valid configurations and verifies the defaults expected by this test
// (5 minute timeout, 10 concurrent tasks) when the fields are left zero.
func TestTaskExecutionEngine_Initialize(t *testing.T) {
	tests := []struct {
		name        string
		config      *EngineConfig
		expectError bool
	}{
		{
			name:        "nil config",
			config:      nil,
			expectError: true,
		},
		{
			name: "missing AI factory",
			config: &EngineConfig{
				DefaultTimeout: 1 * time.Minute,
			},
			expectError: true,
		},
		{
			name: "valid config",
			config: &EngineConfig{
				AIProviderFactory: &MockProviderFactory{},
				DefaultTimeout:    1 * time.Minute,
			},
			expectError: false,
		},
		{
			name: "config with defaults",
			config: &EngineConfig{
				AIProviderFactory: &MockProviderFactory{},
			},
			expectError: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Use a fresh engine per case so one case's state cannot leak
			// into the next.
			engine := NewTaskExecutionEngine()

			// Record which fields were left unset BEFORE calling Initialize.
			// The equality assertion below means engine.config mirrors
			// tt.config afterwards, so checking tt.config after the call
			// would skip the default assertions.
			var wantDefaultTimeout, wantDefaultConcurrency bool
			if tt.config != nil {
				wantDefaultTimeout = tt.config.DefaultTimeout == 0
				wantDefaultConcurrency = tt.config.MaxConcurrentTasks == 0
			}

			err := engine.Initialize(context.Background(), tt.config)
			if tt.expectError {
				assert.Error(t, err)
				return
			}

			// Stop here on failure: engine.config is dereferenced below.
			require.NoError(t, err)
			assert.Equal(t, tt.config, engine.config)

			// Check defaults are set
			if wantDefaultTimeout {
				assert.Equal(t, 5*time.Minute, engine.config.DefaultTimeout)
			}
			if wantDefaultConcurrency {
				assert.Equal(t, 10, engine.config.MaxConcurrentTasks)
			}
		})
	}
}

// TestTaskExecutionEngine_ExecuteTask_SimpleResponse runs a task whose mocked
// AI response contains no actions and checks the result and its metrics.
// Skipped in -short mode.
func TestTaskExecutionEngine_ExecuteTask_SimpleResponse(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test in short mode")
	}

	engine := NewTaskExecutionEngine()

	// Setup mock AI provider
	mockProvider := &MockProvider{}
	mockFactory := &MockProviderFactory{}

	// Configure mock responses
	mockProvider.On("ExecuteTask", mock.Anything, mock.Anything).Return(
		&ai.TaskResponse{
			TaskID:   "test-123",
			Content:  "Task completed successfully",
			Success:  true,
			Actions:  []ai.ActionResult{},
			Metadata: map[string]interface{}{},
		}, nil)

	// NOTE(review): the role-mapping test in this file expects Type
	// "analysis" to resolve to "analyst"; confirm that ExecuteTask really
	// requests the "general" role for this request.
	mockFactory.On("GetProviderForRole", "general").Return(
		mockProvider,
		ai.ProviderConfig{
			Provider: "mock",
			Model:    "test-model",
		},
		nil)

	config := &EngineConfig{
		AIProviderFactory: mockFactory,
		DefaultTimeout:    30 * time.Second,
		EnableMetrics:     true,
	}

	err := engine.Initialize(context.Background(), config)
	require.NoError(t, err)

	// Execute simple task (no sandbox commands)
	request := &TaskExecutionRequest{
		ID:          "test-123",
		Type:        "analysis",
		Description: "Analyze the given data",
		Context:     map[string]interface{}{"data": "sample data"},
	}

	ctx := context.Background()
	result, err := engine.ExecuteTask(ctx, request)

	require.NoError(t, err)
	assert.True(t, result.Success)
	assert.Equal(t, "test-123", result.TaskID)
	assert.Contains(t, result.Output, "Task completed successfully")
	assert.NotNil(t, result.Metrics)
	assert.False(t, result.Metrics.StartTime.IsZero())
	assert.False(t, result.Metrics.EndTime.IsZero())
	assert.Greater(t, result.Metrics.Duration, time.Duration(0))

	// Verify mocks were called
	mockProvider.AssertCalled(t, "ExecuteTask", mock.Anything, mock.Anything)
	mockFactory.AssertCalled(t, "GetProviderForRole", "general")
}

// TestTaskExecutionEngine_ExecuteTask_WithCommands runs a task whose mocked
// AI response contains one command action and one file action, using a
// Docker sandbox configuration. Skipped in -short mode and whenever
// ExecuteTask returns an error.
func TestTaskExecutionEngine_ExecuteTask_WithCommands(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping Docker integration test in short mode")
	}

	engine := NewTaskExecutionEngine()

	// Setup mock AI provider with commands
	mockProvider := &MockProvider{}
	mockFactory := &MockProviderFactory{}

	// Configure mock to return commands
	mockProvider.On("ExecuteTask", mock.Anything, mock.Anything).Return(
		&ai.TaskResponse{
			TaskID:  "test-456",
			Content: "Executing commands",
			Success: true,
			Actions: []ai.ActionResult{
				{
					Type: "command",
					Content: map[string]interface{}{
						"command": "echo 'Hello World'",
					},
				},
				{
					Type: "file",
					Content: map[string]interface{}{
						"name":    "test.txt",
						"content": "Test file content",
					},
				},
			},
			Metadata: map[string]interface{}{},
		}, nil)

	mockFactory.On("GetProviderForRole", "developer").Return(
		mockProvider,
		ai.ProviderConfig{
			Provider: "mock",
			Model:    "test-model",
		},
		nil)

	config := &EngineConfig{
		AIProviderFactory: mockFactory,
		DefaultTimeout:    1 * time.Minute,
		SandboxDefaults: &SandboxConfig{
			Type:  "docker",
			Image: "alpine:latest",
			Resources: ResourceLimits{
				MemoryLimit: 256 * 1024 * 1024,
				CPULimit:    0.5,
			},
			Security: SecurityPolicy{
				NoNewPrivileges: true,
				AllowNetworking: false,
			},
		},
	}

	err := engine.Initialize(context.Background(), config)
	require.NoError(t, err)

	// Execute task with commands
	request := &TaskExecutionRequest{
		ID:          "test-456",
		Type:        "code_generation",
		Description: "Generate a simple script",
		Timeout:     2 * time.Minute,
	}

	ctx := context.Background()
	result, err := engine.ExecuteTask(ctx, request)
	if err != nil {
		// If Docker is not available, skip this test. Skipf stops the test
		// here, so no further error assertion is needed below.
		t.Skipf("Docker not available for sandbox testing: %v", err)
	}

	assert.True(t, result.Success)
	assert.Equal(t, "test-456", result.TaskID)
	assert.NotEmpty(t, result.Output)
	assert.GreaterOrEqual(t, len(result.Artifacts), 1) // At least the file artifact
	assert.Equal(t, 1, result.Metrics.CommandsExecuted)
	assert.Greater(t, result.Metrics.SandboxTime, time.Duration(0))

	// Check artifacts
	var foundTestFile bool
	for _, artifact := range result.Artifacts {
		if artifact.Name == "test.txt" {
			foundTestFile = true
			assert.Equal(t, "file", artifact.Type)
			assert.Equal(t, "Test file content", string(artifact.Content))
		}
	}
	assert.True(t, foundTestFile, "Expected test.txt artifact not found")
}

// TestTaskExecutionEngine_DetermineRoleFromTask checks the role chosen for a
// request based on its Type and Description.
func TestTaskExecutionEngine_DetermineRoleFromTask(t *testing.T) {
	engine := NewTaskExecutionEngine()

	tests := []struct {
		name         string
		request      *TaskExecutionRequest
		expectedRole string
	}{
		{
			name: "code task",
			request: &TaskExecutionRequest{
				Type:        "code_generation",
				Description: "Write a function to sort array",
			},
			expectedRole: "developer",
		},
		{
			name: "analysis task",
			request: &TaskExecutionRequest{
				Type:        "analysis",
				Description: "Analyze the performance metrics",
			},
			expectedRole: "analyst",
		},
		{
			name: "test task",
			request: &TaskExecutionRequest{
				Type:        "testing",
				Description: "Write tests for the function",
			},
			expectedRole: "tester",
		},
		{
			name: "program task by description",
			request: &TaskExecutionRequest{
				Type:        "general",
				Description: "Create a program that processes data",
			},
			expectedRole: "developer",
		},
		{
			name: "review task by description",
			request: &TaskExecutionRequest{
				Type:        "general",
				Description: "Review the code quality",
			},
			expectedRole: "analyst",
		},
		{
			name: "general task",
			request: &TaskExecutionRequest{
				Type:        "documentation",
				Description: "Write user documentation",
			},
			expectedRole: "general",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			role := engine.determineRoleFromTask(tt.request)
			assert.Equal(t, tt.expectedRole, role)
		})
	}
}

// TestTaskExecutionEngine_ParseAIResponse checks how many commands and
// artifacts parseAIResponse extracts from a response's actions, and that each
// extracted artifact has a name, type, positive size and creation time.
func TestTaskExecutionEngine_ParseAIResponse(t *testing.T) {
	engine := NewTaskExecutionEngine()

	tests := []struct {
		name              string
		response          *ai.TaskResponse
		expectedCommands  int
		expectedArtifacts int
	}{
		{
			name: "response with commands and files",
			response: &ai.TaskResponse{
				Actions: []ai.ActionResult{
					{
						Type: "command",
						Content: map[string]interface{}{
							"command": "ls -la",
						},
					},
					{
						Type: "command",
						Content: map[string]interface{}{
							"command": "echo 'test'",
						},
					},
					{
						Type: "file",
						Content: map[string]interface{}{
							"name":    "script.sh",
							"content": "#!/bin/bash\necho 'Hello'",
						},
					},
				},
			},
			expectedCommands:  2,
			expectedArtifacts: 1,
		},
		{
			name: "response with no actions",
			response: &ai.TaskResponse{
				Actions: []ai.ActionResult{},
			},
			expectedCommands:  0,
			expectedArtifacts: 0,
		},
		{
			name: "response with unknown action types",
			response: &ai.TaskResponse{
				Actions: []ai.ActionResult{
					{
						Type: "unknown",
						Content: map[string]interface{}{
							"data": "some data",
						},
					},
				},
			},
			expectedCommands:  0,
			expectedArtifacts: 0,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			commands, artifacts, err := engine.parseAIResponse(tt.response)

			require.NoError(t, err)
			assert.Len(t, commands, tt.expectedCommands)
			assert.Len(t, artifacts, tt.expectedArtifacts)

			// Validate artifact content if present
			for _, artifact := range artifacts {
				assert.NotEmpty(t, artifact.Name)
				assert.NotEmpty(t, artifact.Type)
				assert.Greater(t, artifact.Size, int64(0))
				assert.False(t, artifact.CreatedAt.IsZero())
			}
		})
	}
}

// TestTaskExecutionEngine_CreateSandboxConfig checks that createSandboxConfig
// falls back to the engine's sandbox defaults and applies per-request
// requirements when they are supplied.
func TestTaskExecutionEngine_CreateSandboxConfig(t *testing.T) {
	engine := NewTaskExecutionEngine()

	// Initialize with default config
	config := &EngineConfig{
		AIProviderFactory: &MockProviderFactory{},
		SandboxDefaults: &SandboxConfig{
			Image: "ubuntu:20.04",
			Resources: ResourceLimits{
				MemoryLimit: 1024 * 1024 * 1024,
				CPULimit:    2.0,
			},
			Security: SecurityPolicy{
				NoNewPrivileges: true,
			},
		},
	}
	// A failed Initialize would make every subtest below fail for an
	// unrelated reason, so surface it immediately.
	require.NoError(t, engine.Initialize(context.Background(), config))

	tests := []struct {
		name     string
		request  *TaskExecutionRequest
		validate func(t *testing.T, config *SandboxConfig)
	}{
		{
			name: "basic request uses defaults",
			request: &TaskExecutionRequest{
				ID:          "test",
				Type:        "general",
				Description: "test task",
			},
			validate: func(t *testing.T, config *SandboxConfig) {
				assert.Equal(t, "ubuntu:20.04", config.Image)
				assert.Equal(t, int64(1024*1024*1024), config.Resources.MemoryLimit)
				assert.Equal(t, 2.0, config.Resources.CPULimit)
				assert.True(t, config.Security.NoNewPrivileges)
			},
		},
		{
			name: "request with custom requirements",
			request: &TaskExecutionRequest{
				ID:          "test",
				Type:        "custom",
				Description: "custom task",
				Requirements: &TaskRequirements{
					SandboxType: "container",
					EnvironmentVars: map[string]string{
						"ENV_VAR": "test_value",
					},
					ResourceLimits: &ResourceLimits{
						MemoryLimit: 512 * 1024 * 1024,
						CPULimit:    1.0,
					},
					SecurityPolicy: &SecurityPolicy{
						ReadOnlyRoot: true,
					},
				},
			},
			validate: func(t *testing.T, config *SandboxConfig) {
				assert.Equal(t, "container", config.Type)
				assert.Equal(t, "test_value", config.Environment["ENV_VAR"])
				assert.Equal(t, int64(512*1024*1024), config.Resources.MemoryLimit)
				assert.Equal(t, 1.0, config.Resources.CPULimit)
				assert.True(t, config.Security.ReadOnlyRoot)
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			sandboxConfig := engine.createSandboxConfig(tt.request)
			tt.validate(t, sandboxConfig)
		})
	}
}

// TestTaskExecutionEngine_GetMetrics checks that a new engine reports zeroed
// task counters.
func TestTaskExecutionEngine_GetMetrics(t *testing.T) {
	engine := NewTaskExecutionEngine()

	metrics := engine.GetMetrics()

	assert.NotNil(t, metrics)
	assert.Equal(t, int64(0), metrics.TasksExecuted)
	assert.Equal(t, int64(0), metrics.TasksSuccessful)
	assert.Equal(t, int64(0), metrics.TasksFailed)
}

// TestTaskExecutionEngine_Shutdown registers a cancel function as an active
// task and checks that Shutdown cancels the associated context.
func TestTaskExecutionEngine_Shutdown(t *testing.T) {
	engine := NewTaskExecutionEngine()

	// Initialize engine
	config := &EngineConfig{
		AIProviderFactory: &MockProviderFactory{},
	}
	err := engine.Initialize(context.Background(), config)
	require.NoError(t, err)

	// Add a mock active task
	ctx, cancel := context.WithCancel(context.Background())
	// Release the context even if Shutdown does not cancel it; calling a
	// CancelFunc more than once is a no-op.
	defer cancel()
	engine.activeTasks["test-task"] = cancel

	// Shutdown should cancel active tasks
	err = engine.Shutdown()
	assert.NoError(t, err)

	// Verify task was cleaned up
	select {
	case <-ctx.Done():
		// Expected - task was canceled
	default:
		t.Error("Expected task context to be canceled")
	}
}

// Benchmark tests

// BenchmarkTaskExecutionEngine_ExecuteSimpleTask measures ExecuteTask for a
// request whose mocked AI response contains no actions.
func BenchmarkTaskExecutionEngine_ExecuteSimpleTask(b *testing.B) {
	engine := NewTaskExecutionEngine()

	// Setup mock AI provider
	mockProvider := &MockProvider{}
	mockFactory := &MockProviderFactory{}

	mockProvider.On("ExecuteTask", mock.Anything, mock.Anything).Return(
		&ai.TaskResponse{
			TaskID:  "bench",
			Content: "Benchmark task completed",
			Success: true,
			Actions: []ai.ActionResult{},
		}, nil)

	mockFactory.On("GetProviderForRole", mock.Anything).Return(
		mockProvider,
		ai.ProviderConfig{Provider: "mock", Model: "test"},
		nil)

	config := &EngineConfig{
		AIProviderFactory: mockFactory,
		DefaultTimeout:    30 * time.Second,
	}

	// Do not benchmark an engine that failed to initialize.
	if err := engine.Initialize(context.Background(), config); err != nil {
		b.Fatalf("Engine initialization failed: %v", err)
	}

	request := &TaskExecutionRequest{
		ID:          "bench",
		Type:        "benchmark",
		Description: "Benchmark task",
	}

	// Exclude the setup above from the measurement.
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		_, err := engine.ExecuteTask(context.Background(), request)
		if err != nil {
			b.Fatalf("Task execution failed: %v", err)
		}
	}
}