package ai

import (
	"context"
	"errors"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// MockProvider is a test double for ModelProvider.
//
// Its behaviour is controlled by unexported fields:
//   - executeFunc, when non-nil, completely replaces ExecuteTask's logic.
//   - shouldFail makes ExecuteTask and ValidateConfig return ProviderErrors.
//   - response is the template that is (shallowly) copied for each
//     successful ExecuteTask call.
type MockProvider struct {
	name         string
	capabilities ProviderCapabilities
	shouldFail   bool
	response     *TaskResponse
	executeFunc  func(ctx context.Context, request *TaskRequest) (*TaskResponse, error)
}

// NewMockProvider returns a MockProvider called name that succeeds by default.
// It advertises MCP, tool, streaming and file support, a 4096-token limit and
// the models "test-model" and "test-model-2". It answers every task with
// "Mock response".
func NewMockProvider(name string) *MockProvider {
	return &MockProvider{
		name: name,
		capabilities: ProviderCapabilities{
			SupportsMCP:       true,
			SupportsTools:     true,
			SupportsStreaming: true,
			SupportsFunctions: false,
			MaxTokens:         4096,
			SupportedModels:   []string{"test-model", "test-model-2"},
			SupportsImages:    false,
			SupportsFiles:     true,
		},
		response: &TaskResponse{
			Success:  true,
			Response: "Mock response",
		},
	}
}

// ExecuteTask simulates running a task.
//
// If executeFunc is set, it is called and its result is returned unchanged.
// Otherwise, if shouldFail is set, a ProviderError built from
// ErrTaskExecutionFailed is returned. On success, a shallow copy of the
// response template is returned with these fields set:
//   - TaskID and AgentID, echoed from request
//   - Provider, set to the mock's name
//   - a timing window of exactly 100ms
//
// ctx is only forwarded to executeFunc; the default path ignores it.
func (m *MockProvider) ExecuteTask(ctx context.Context, request *TaskRequest) (*TaskResponse, error) {
	if m.executeFunc != nil {
		return m.executeFunc(ctx, request)
	}
	if m.shouldFail {
		return nil, NewProviderError(ErrTaskExecutionFailed, "mock execution failed")
	}

	// Copy the template so per-call fields never leak between calls. The copy
	// is shallow: slices and maps inside it stay shared with the template.
	// A nil template (e.g. a hand-built MockProvider) yields a zero response
	// instead of a nil-pointer panic.
	var response TaskResponse
	if m.response != nil {
		response = *m.response
	}
	response.TaskID = request.TaskID
	response.AgentID = request.AgentID
	response.Provider = m.name

	// Use a single timestamp so Duration equals the simulated span exactly,
	// rather than that span plus the gap between two time.Now() calls.
	start := time.Now()
	response.StartTime = start
	response.EndTime = start.Add(100 * time.Millisecond)
	response.Duration = response.EndTime.Sub(response.StartTime)

	return &response, nil
}

// GetCapabilities returns the capabilities configured on the mock.
func (m *MockProvider) GetCapabilities() ProviderCapabilities {
	return m.capabilities
}

// ValidateConfig returns a ProviderError built from ErrInvalidConfiguration
// when shouldFail is set. Otherwise it returns nil.
func (m *MockProvider) ValidateConfig() error {
	if m.shouldFail {
		return NewProviderError(ErrInvalidConfiguration, "mock config validation failed")
	}
	return nil
}

// GetProviderInfo returns fixed metadata for the mock. Only Name varies, and
// it is taken from the mock's name.
func (m *MockProvider) GetProviderInfo() ProviderInfo {
	return ProviderInfo{
		Name:           m.name,
		Type:           "mock",
		Version:        "1.0.0",
		Endpoint:       "mock://localhost",
		DefaultModel:   "test-model",
		RequiresAPIKey: false,
		RateLimit:      0,
	}
}

// TestProviderError checks Error() formatting and IsRetryable() for three
// cases: a predefined error, an error decorated via NewProviderError, and a
// hand-built retryable error.
func TestProviderError(t *testing.T) {
	tests := []struct {
		name      string
		err       *ProviderError
		expected  string
		retryable bool
	}{
		{
			name:      "simple error",
			err:       ErrProviderNotFound,
			expected:  "Provider not found",
			retryable: false,
		},
		{
			name:      "error with details",
			err:       NewProviderError(ErrRateLimitExceeded, "API rate limit of 1000/hour exceeded"),
			expected:  "Rate limit exceeded: API rate limit of 1000/hour exceeded",
			retryable: false,
		},
		{
			name: "retryable error",
			err: &ProviderError{
				Code:      "TEMPORARY_ERROR",
				Message:   "Temporary failure",
				Retryable: true,
			},
			expected:  "Temporary failure",
			retryable: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.expected, tt.err.Error())
			assert.Equal(t, tt.retryable, tt.err.IsRetryable())
		})
	}
}

// TestTaskRequest builds a fully populated TaskRequest and checks that its
// identifying fields are set and that its priority and complexity are
// positive.
func TestTaskRequest(t *testing.T) {
	request := &TaskRequest{
		TaskID:          "test-task-123",
		AgentID:         "agent-456",
		AgentRole:       "developer",
		Repository:      "test/repo",
		TaskTitle:       "Test Task",
		TaskDescription: "A test task for unit testing",
		TaskLabels:      []string{"bug", "urgent"},
		Priority:        8,
		Complexity:      6,
		ModelName:       "test-model",
		Temperature:     0.7,
		MaxTokens:       4096,
		EnableTools:     true,
	}

	// Validate required fields
	assert.NotEmpty(t, request.TaskID)
	assert.NotEmpty(t, request.AgentID)
	assert.NotEmpty(t, request.AgentRole)
	assert.NotEmpty(t, request.Repository)
	assert.NotEmpty(t, request.TaskTitle)
	assert.Greater(t, request.Priority, 0)
	assert.Greater(t, request.Complexity, 0)
}

// TestTaskResponse builds a TaskResponse containing one action, one artifact,
// a 2s timing window and token usage, then checks those fields round-trip.
func TestTaskResponse(t *testing.T) {
	startTime := time.Now()
	endTime := startTime.Add(2 * time.Second)

	response := &TaskResponse{
		Success:   true,
		TaskID:    "test-task-123",
		AgentID:   "agent-456",
		ModelUsed: "test-model",
		Provider:  "mock",
		Response:  "Task completed successfully",
		Actions: []TaskAction{
			{
				Type:      "file_create",
				Target:    "test.go",
				Content:   "package main",
				Result:    "File created",
				Success:   true,
				Timestamp: time.Now(),
			},
		},
		Artifacts: []Artifact{
			{
				Name:      "test.go",
				Type:      "file",
				Path:      "./test.go",
				Content:   "package main",
				Size:      12,
				CreatedAt: time.Now(),
			},
		},
		StartTime: startTime,
		EndTime:   endTime,
		Duration:  endTime.Sub(startTime),
		TokensUsed: TokenUsage{
			PromptTokens:     50,
			CompletionTokens: 100,
			TotalTokens:      150,
		},
	}

	// Validate response structure
	assert.True(t, response.Success)
	assert.NotEmpty(t, response.TaskID)
	assert.NotEmpty(t, response.Provider)
	assert.Len(t, response.Actions, 1)
	assert.Len(t, response.Artifacts, 1)
	assert.Equal(t, 2*time.Second, response.Duration)
	assert.Equal(t, 150, response.TokensUsed.TotalTokens)
}

// TestTaskAction checks that a TaskAction keeps its fields and that its
// free-form Metadata map keeps its values.
func TestTaskAction(t *testing.T) {
	action := TaskAction{
		Type:      "file_edit",
		Target:    "main.go",
		Content:   "updated content",
		Result:    "File updated successfully",
		Success:   true,
		Timestamp: time.Now(),
		Metadata: map[string]interface{}{
			"line_count": 42,
			"backup":     true,
		},
	}

	assert.Equal(t, "file_edit", action.Type)
	assert.True(t, action.Success)
	assert.NotNil(t, action.Metadata)
	assert.Equal(t, 42, action.Metadata["line_count"])
}

// TestArtifact checks that an Artifact keeps its name, type, size and
// checksum.
func TestArtifact(t *testing.T) {
	artifact := Artifact{
		Name:      "output.log",
		Type:      "log",
		Path:      "/tmp/output.log",
		Content:   "Log content here",
		Size:      16,
		CreatedAt: time.Now(),
		Checksum:  "abc123",
	}

	assert.Equal(t, "output.log", artifact.Name)
	assert.Equal(t, "log", artifact.Type)
	assert.Equal(t, int64(16), artifact.Size)
	assert.NotEmpty(t, artifact.Checksum)
}

// TestProviderCapabilities checks that a ProviderCapabilities value keeps its
// feature flags, token limit and model list.
func TestProviderCapabilities(t *testing.T) {
	capabilities := ProviderCapabilities{
		SupportsMCP:       true,
		SupportsTools:     true,
		SupportsStreaming: false,
		SupportsFunctions: true,
		MaxTokens:         8192,
		SupportedModels:   []string{"gpt-4", "gpt-3.5-turbo"},
		SupportsImages:    true,
		SupportsFiles:     true,
	}

	assert.True(t, capabilities.SupportsMCP)
	assert.True(t, capabilities.SupportsTools)
	assert.False(t, capabilities.SupportsStreaming)
	assert.Equal(t, 8192, capabilities.MaxTokens)
	assert.Len(t, capabilities.SupportedModels, 2)
}

// TestProviderInfo checks that a ProviderInfo value keeps its name, API-key
// requirement and rate limit.
func TestProviderInfo(t *testing.T) {
	info := ProviderInfo{
		Name:           "Test Provider",
		Type:           "test",
		Version:        "1.0.0",
		Endpoint:       "https://api.test.com",
		DefaultModel:   "test-model",
		RequiresAPIKey: true,
		RateLimit:      1000,
	}

	assert.Equal(t, "Test Provider", info.Name)
	assert.True(t, info.RequiresAPIKey)
	assert.Equal(t, 1000, info.RateLimit)
}

// TestProviderConfig checks that a ProviderConfig value keeps its type,
// sampling settings, timeout and tool flag.
func TestProviderConfig(t *testing.T) {
	config := ProviderConfig{
		Type:          "test",
		Endpoint:      "https://api.test.com",
		APIKey:        "test-key",
		DefaultModel:  "test-model",
		Temperature:   0.7,
		MaxTokens:     4096,
		Timeout:       300 * time.Second,
		RetryAttempts: 3,
		RetryDelay:    2 * time.Second,
		EnableTools:   true,
		EnableMCP:     true,
	}

	assert.Equal(t, "test", config.Type)
	assert.Equal(t, float32(0.7), config.Temperature)
	assert.Equal(t, 4096, config.MaxTokens)
	assert.Equal(t, 300*time.Second, config.Timeout)
	assert.True(t, config.EnableTools)
}

// TestRoleConfig checks that a RoleConfig value keeps its provider, model,
// temperature, allowed tools and tool/MCP flags.
func TestRoleConfig(t *testing.T) {
	roleConfig := RoleConfig{
		Provider:         "openai",
		Model:            "gpt-4",
		Temperature:      0.3,
		MaxTokens:        8192,
		SystemPrompt:     "You are a helpful assistant",
		FallbackProvider: "ollama",
		FallbackModel:    "llama2",
		EnableTools:      true,
		EnableMCP:        false,
		AllowedTools:     []string{"file_ops", "code_analysis"},
		MCPServers:       []string{"file-server"},
	}

	assert.Equal(t, "openai", roleConfig.Provider)
	assert.Equal(t, "gpt-4", roleConfig.Model)
	assert.Equal(t, float32(0.3), roleConfig.Temperature)
	assert.Len(t, roleConfig.AllowedTools, 2)
	assert.True(t, roleConfig.EnableTools)
	assert.False(t, roleConfig.EnableMCP)
}

// TestRoleModelMapping checks that a RoleModelMapping can look up a role's
// configuration by role name.
func TestRoleModelMapping(t *testing.T) {
	mapping := RoleModelMapping{
		DefaultProvider:  "ollama",
		FallbackProvider: "openai",
		Roles: map[string]RoleConfig{
			"developer": {
				Provider:    "ollama",
				Model:       "codellama",
				Temperature: 0.3,
			},
			"reviewer": {
				Provider:    "openai",
				Model:       "gpt-4",
				Temperature: 0.2,
			},
		},
	}

	assert.Equal(t, "ollama", mapping.DefaultProvider)
	assert.Len(t, mapping.Roles, 2)

	devConfig, exists := mapping.Roles["developer"]
	require.True(t, exists)
	assert.Equal(t, "codellama", devConfig.Model)
	assert.Equal(t, float32(0.3), devConfig.Temperature)
}

// TestTokenUsage checks the TokenUsage fields and that the total equals
// prompt tokens plus completion tokens for this fixture.
func TestTokenUsage(t *testing.T) {
	usage := TokenUsage{
		PromptTokens:     100,
		CompletionTokens: 200,
		TotalTokens:      300,
	}

	assert.Equal(t, 100, usage.PromptTokens)
	assert.Equal(t, 200, usage.CompletionTokens)
	assert.Equal(t, 300, usage.TotalTokens)
	assert.Equal(t, usage.PromptTokens+usage.CompletionTokens, usage.TotalTokens)
}

// TestMockProviderExecuteTask checks the mock's default success path:
//   - request identifiers are echoed back
//   - Provider is set to the mock's name
//   - the reported timing window is exactly 100ms
func TestMockProviderExecuteTask(t *testing.T) {
	provider := NewMockProvider("test-provider")

	request := &TaskRequest{
		TaskID:     "test-123",
		AgentID:    "agent-456",
		AgentRole:  "developer",
		Repository: "test/repo",
		TaskTitle:  "Test Task",
	}

	response, err := provider.ExecuteTask(context.Background(), request)
	require.NoError(t, err)
	require.NotNil(t, response)

	assert.True(t, response.Success)
	assert.Equal(t, "test-123", response.TaskID)
	assert.Equal(t, "agent-456", response.AgentID)
	assert.Equal(t, "test-provider", response.Provider)
	assert.NotEmpty(t, response.Response)
	assert.Equal(t, 100*time.Millisecond, response.Duration)
	assert.Equal(t, response.Duration, response.EndTime.Sub(response.StartTime))
}

// TestMockProviderFailure checks that a failing mock returns no response and
// a *ProviderError carrying the TASK_EXECUTION_FAILED code.
func TestMockProviderFailure(t *testing.T) {
	provider := NewMockProvider("failing-provider")
	provider.shouldFail = true

	request := &TaskRequest{
		TaskID:    "test-123",
		AgentID:   "agent-456",
		AgentRole: "developer",
	}

	response, err := provider.ExecuteTask(context.Background(), request)
	require.Error(t, err)
	assert.Nil(t, response)

	// errors.As fails the test cleanly, and accepts wrapped errors. An
	// unchecked type assertion would panic if err were not a *ProviderError.
	var providerErr *ProviderError
	require.True(t, errors.As(err, &providerErr), "expected *ProviderError, got %T", err)
	assert.Equal(t, "TASK_EXECUTION_FAILED", providerErr.Code)
}

// TestMockProviderCustomExecuteFunc checks that a custom executeFunc replaces
// the mock's default ExecuteTask behaviour.
func TestMockProviderCustomExecuteFunc(t *testing.T) {
	provider := NewMockProvider("custom-provider")

	// Set custom execution function
	provider.executeFunc = func(ctx context.Context, request *TaskRequest) (*TaskResponse, error) {
		return &TaskResponse{
			Success:  true,
			TaskID:   request.TaskID,
			Response: "Custom response: " + request.TaskTitle,
			Provider: "custom-provider",
		}, nil
	}

	request := &TaskRequest{
		TaskID:    "test-123",
		TaskTitle: "Custom Task",
	}

	response, err := provider.ExecuteTask(context.Background(), request)
	require.NoError(t, err)
	assert.Equal(t, "Custom response: Custom Task", response.Response)
}

// TestMockProviderCapabilities checks the capabilities set up by
// NewMockProvider.
func TestMockProviderCapabilities(t *testing.T) {
	provider := NewMockProvider("test-provider")

	capabilities := provider.GetCapabilities()

	assert.True(t, capabilities.SupportsMCP)
	assert.True(t, capabilities.SupportsTools)
	assert.Equal(t, 4096, capabilities.MaxTokens)
	assert.Len(t, capabilities.SupportedModels, 2)
	assert.Contains(t, capabilities.SupportedModels, "test-model")
}

// TestMockProviderInfo checks the fixed metadata returned by
// MockProvider.GetProviderInfo.
func TestMockProviderInfo(t *testing.T) {
	provider := NewMockProvider("test-provider")

	info := provider.GetProviderInfo()

	assert.Equal(t, "test-provider", info.Name)
	assert.Equal(t, "mock", info.Type)
	assert.Equal(t, "test-model", info.DefaultModel)
	assert.False(t, info.RequiresAPIKey)
}