package ai import ( "context" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestNewProviderFactory(t *testing.T) { factory := NewProviderFactory() assert.NotNil(t, factory) assert.Empty(t, factory.configs) assert.Empty(t, factory.providers) assert.Empty(t, factory.healthChecks) assert.Empty(t, factory.lastHealthCheck) } func TestProviderFactoryRegisterProvider(t *testing.T) { factory := NewProviderFactory() // Create a valid mock provider config (since validation will be called) config := ProviderConfig{ Type: "mock", Endpoint: "mock://localhost", DefaultModel: "test-model", Temperature: 0.7, MaxTokens: 4096, Timeout: 300 * time.Second, } // Override CreateProvider to return our mock originalCreate := factory.CreateProvider factory.CreateProvider = func(config ProviderConfig) (ModelProvider, error) { return NewMockProvider("test-provider"), nil } defer func() { factory.CreateProvider = originalCreate }() err := factory.RegisterProvider("test", config) require.NoError(t, err) // Verify provider was registered assert.Len(t, factory.providers, 1) assert.Contains(t, factory.providers, "test") assert.True(t, factory.healthChecks["test"]) } func TestProviderFactoryRegisterProviderValidationFailure(t *testing.T) { factory := NewProviderFactory() // Create a mock provider that will fail validation config := ProviderConfig{ Type: "mock", Endpoint: "mock://localhost", DefaultModel: "test-model", } // Override CreateProvider to return a failing mock factory.CreateProvider = func(config ProviderConfig) (ModelProvider, error) { mock := NewMockProvider("failing-provider") mock.shouldFail = true // This will make ValidateConfig fail return mock, nil } err := factory.RegisterProvider("failing", config) require.Error(t, err) assert.Contains(t, err.Error(), "invalid configuration") // Verify provider was not registered assert.Empty(t, factory.providers) } func TestProviderFactoryGetProvider(t *testing.T) { factory := NewProviderFactory() mockProvider := NewMockProvider("test-provider") // Manually add provider and mark as healthy factory.providers["test"] = mockProvider factory.healthChecks["test"] = true factory.lastHealthCheck["test"] = time.Now() provider, err := factory.GetProvider("test") require.NoError(t, err) assert.Equal(t, mockProvider, provider) } func TestProviderFactoryGetProviderNotFound(t *testing.T) { factory := NewProviderFactory() _, err := factory.GetProvider("nonexistent") require.Error(t, err) assert.IsType(t, &ProviderError{}, err) providerErr := err.(*ProviderError) assert.Equal(t, "PROVIDER_NOT_FOUND", providerErr.Code) } func TestProviderFactoryGetProviderUnhealthy(t *testing.T) { factory := NewProviderFactory() mockProvider := NewMockProvider("test-provider") // Add provider but mark as unhealthy factory.providers["test"] = mockProvider factory.healthChecks["test"] = false factory.lastHealthCheck["test"] = time.Now() _, err := factory.GetProvider("test") require.Error(t, err) assert.IsType(t, &ProviderError{}, err) providerErr := err.(*ProviderError) assert.Equal(t, "PROVIDER_UNAVAILABLE", providerErr.Code) } func TestProviderFactorySetRoleMapping(t *testing.T) { factory := NewProviderFactory() mapping := RoleModelMapping{ DefaultProvider: "test", FallbackProvider: "backup", Roles: map[string]RoleConfig{ "developer": { Provider: "test", Model: "dev-model", }, }, } factory.SetRoleMapping(mapping) assert.Equal(t, mapping, factory.roleMapping) } func TestProviderFactoryGetProviderForRole(t *testing.T) { factory := NewProviderFactory() // Set up providers devProvider := NewMockProvider("dev-provider") backupProvider := NewMockProvider("backup-provider") factory.providers["dev"] = devProvider factory.providers["backup"] = backupProvider factory.healthChecks["dev"] = true factory.healthChecks["backup"] = true factory.lastHealthCheck["dev"] = time.Now() factory.lastHealthCheck["backup"] = time.Now() factory.configs["dev"] = ProviderConfig{ Type: "mock", DefaultModel: "dev-model", Temperature: 0.7, } factory.configs["backup"] = ProviderConfig{ Type: "mock", DefaultModel: "backup-model", Temperature: 0.8, } // Set up role mapping mapping := RoleModelMapping{ DefaultProvider: "backup", FallbackProvider: "backup", Roles: map[string]RoleConfig{ "developer": { Provider: "dev", Model: "custom-dev-model", Temperature: 0.3, }, }, } factory.SetRoleMapping(mapping) provider, config, err := factory.GetProviderForRole("developer") require.NoError(t, err) assert.Equal(t, devProvider, provider) assert.Equal(t, "custom-dev-model", config.DefaultModel) assert.Equal(t, float32(0.3), config.Temperature) } func TestProviderFactoryGetProviderForRoleWithFallback(t *testing.T) { factory := NewProviderFactory() // Set up only backup provider (primary is missing) backupProvider := NewMockProvider("backup-provider") factory.providers["backup"] = backupProvider factory.healthChecks["backup"] = true factory.lastHealthCheck["backup"] = time.Now() factory.configs["backup"] = ProviderConfig{Type: "mock", DefaultModel: "backup-model"} // Set up role mapping with primary provider that doesn't exist mapping := RoleModelMapping{ DefaultProvider: "backup", FallbackProvider: "backup", Roles: map[string]RoleConfig{ "developer": { Provider: "nonexistent", FallbackProvider: "backup", }, }, } factory.SetRoleMapping(mapping) provider, config, err := factory.GetProviderForRole("developer") require.NoError(t, err) assert.Equal(t, backupProvider, provider) assert.Equal(t, "backup-model", config.DefaultModel) } func TestProviderFactoryGetProviderForRoleNotFound(t *testing.T) { factory := NewProviderFactory() // No providers registered and no default mapping := RoleModelMapping{ Roles: make(map[string]RoleConfig), } factory.SetRoleMapping(mapping) _, _, err := factory.GetProviderForRole("nonexistent") require.Error(t, err) assert.IsType(t, &ProviderError{}, err) } func TestProviderFactoryGetProviderForTask(t *testing.T) { factory := NewProviderFactory() // Set up a provider that supports a specific model mockProvider := NewMockProvider("test-provider") mockProvider.capabilities.SupportedModels = []string{"specific-model", "another-model"} factory.providers["test"] = mockProvider factory.healthChecks["test"] = true factory.lastHealthCheck["test"] = time.Now() factory.configs["test"] = ProviderConfig{Type: "mock", DefaultModel: "default-model"} request := &TaskRequest{ TaskID: "test-123", AgentRole: "developer", ModelName: "specific-model", // Request specific model } provider, config, err := factory.GetProviderForTask(request) require.NoError(t, err) assert.Equal(t, mockProvider, provider) assert.Equal(t, "specific-model", config.DefaultModel) // Should override default } func TestProviderFactoryGetProviderForTaskModelNotSupported(t *testing.T) { factory := NewProviderFactory() mockProvider := NewMockProvider("test-provider") mockProvider.capabilities.SupportedModels = []string{"model-1", "model-2"} factory.providers["test"] = mockProvider factory.healthChecks["test"] = true factory.lastHealthCheck["test"] = time.Now() request := &TaskRequest{ TaskID: "test-123", AgentRole: "developer", ModelName: "unsupported-model", } _, _, err := factory.GetProviderForTask(request) require.Error(t, err) assert.IsType(t, &ProviderError{}, err) providerErr := err.(*ProviderError) assert.Equal(t, "MODEL_NOT_SUPPORTED", providerErr.Code) } func TestProviderFactoryListProviders(t *testing.T) { factory := NewProviderFactory() // Add some mock providers factory.providers["provider1"] = NewMockProvider("provider1") factory.providers["provider2"] = NewMockProvider("provider2") factory.providers["provider3"] = NewMockProvider("provider3") providers := factory.ListProviders() assert.Len(t, providers, 3) assert.Contains(t, providers, "provider1") assert.Contains(t, providers, "provider2") assert.Contains(t, providers, "provider3") } func TestProviderFactoryListHealthyProviders(t *testing.T) { factory := NewProviderFactory() // Add providers with different health states factory.providers["healthy1"] = NewMockProvider("healthy1") factory.providers["healthy2"] = NewMockProvider("healthy2") factory.providers["unhealthy"] = NewMockProvider("unhealthy") factory.healthChecks["healthy1"] = true factory.healthChecks["healthy2"] = true factory.healthChecks["unhealthy"] = false factory.lastHealthCheck["healthy1"] = time.Now() factory.lastHealthCheck["healthy2"] = time.Now() factory.lastHealthCheck["unhealthy"] = time.Now() healthyProviders := factory.ListHealthyProviders() assert.Len(t, healthyProviders, 2) assert.Contains(t, healthyProviders, "healthy1") assert.Contains(t, healthyProviders, "healthy2") assert.NotContains(t, healthyProviders, "unhealthy") } func TestProviderFactoryGetProviderInfo(t *testing.T) { factory := NewProviderFactory() mock1 := NewMockProvider("mock1") mock2 := NewMockProvider("mock2") factory.providers["provider1"] = mock1 factory.providers["provider2"] = mock2 info := factory.GetProviderInfo() assert.Len(t, info, 2) assert.Contains(t, info, "provider1") assert.Contains(t, info, "provider2") // Verify that the name is overridden with the registered name assert.Equal(t, "provider1", info["provider1"].Name) assert.Equal(t, "provider2", info["provider2"].Name) } func TestProviderFactoryHealthCheck(t *testing.T) { factory := NewProviderFactory() // Add a healthy and an unhealthy provider healthyProvider := NewMockProvider("healthy") unhealthyProvider := NewMockProvider("unhealthy") unhealthyProvider.shouldFail = true factory.providers["healthy"] = healthyProvider factory.providers["unhealthy"] = unhealthyProvider ctx := context.Background() results := factory.HealthCheck(ctx) assert.Len(t, results, 2) assert.NoError(t, results["healthy"]) assert.Error(t, results["unhealthy"]) // Verify health states were updated assert.True(t, factory.healthChecks["healthy"]) assert.False(t, factory.healthChecks["unhealthy"]) } func TestProviderFactoryGetHealthStatus(t *testing.T) { factory := NewProviderFactory() mockProvider := NewMockProvider("test") factory.providers["test"] = mockProvider now := time.Now() factory.healthChecks["test"] = true factory.lastHealthCheck["test"] = now status := factory.GetHealthStatus() assert.Len(t, status, 1) assert.Contains(t, status, "test") testStatus := status["test"] assert.Equal(t, "test", testStatus.Name) assert.True(t, testStatus.Healthy) assert.Equal(t, now, testStatus.LastCheck) } func TestProviderFactoryIsProviderHealthy(t *testing.T) { factory := NewProviderFactory() // Test healthy provider factory.healthChecks["healthy"] = true factory.lastHealthCheck["healthy"] = time.Now() assert.True(t, factory.isProviderHealthy("healthy")) // Test unhealthy provider factory.healthChecks["unhealthy"] = false factory.lastHealthCheck["unhealthy"] = time.Now() assert.False(t, factory.isProviderHealthy("unhealthy")) // Test provider with old health check (should be considered unhealthy) factory.healthChecks["stale"] = true factory.lastHealthCheck["stale"] = time.Now().Add(-15 * time.Minute) assert.False(t, factory.isProviderHealthy("stale")) // Test non-existent provider assert.False(t, factory.isProviderHealthy("nonexistent")) } func TestProviderFactoryMergeRoleConfig(t *testing.T) { factory := NewProviderFactory() baseConfig := ProviderConfig{ Type: "test", DefaultModel: "base-model", Temperature: 0.7, MaxTokens: 4096, EnableTools: false, EnableMCP: false, MCPServers: []string{"base-server"}, } roleConfig := RoleConfig{ Model: "role-model", Temperature: 0.3, MaxTokens: 8192, EnableTools: true, EnableMCP: true, MCPServers: []string{"role-server"}, } merged := factory.mergeRoleConfig(baseConfig, roleConfig) assert.Equal(t, "role-model", merged.DefaultModel) assert.Equal(t, float32(0.3), merged.Temperature) assert.Equal(t, 8192, merged.MaxTokens) assert.True(t, merged.EnableTools) assert.True(t, merged.EnableMCP) assert.Len(t, merged.MCPServers, 2) // Should be merged assert.Contains(t, merged.MCPServers, "base-server") assert.Contains(t, merged.MCPServers, "role-server") } func TestDefaultProviderFactory(t *testing.T) { factory := DefaultProviderFactory() // Should have at least the default ollama provider providers := factory.ListProviders() assert.Contains(t, providers, "ollama") // Should have role mappings configured assert.NotEmpty(t, factory.roleMapping.Roles) assert.Contains(t, factory.roleMapping.Roles, "developer") assert.Contains(t, factory.roleMapping.Roles, "reviewer") // Test getting provider for developer role _, config, err := factory.GetProviderForRole("developer") require.NoError(t, err) assert.Equal(t, "codellama:13b", config.DefaultModel) assert.Equal(t, float32(0.3), config.Temperature) } func TestProviderFactoryCreateProvider(t *testing.T) { factory := NewProviderFactory() tests := []struct { name string config ProviderConfig expectErr bool }{ { name: "ollama provider", config: ProviderConfig{ Type: "ollama", Endpoint: "http://localhost:11434", DefaultModel: "llama2", }, expectErr: false, }, { name: "openai provider", config: ProviderConfig{ Type: "openai", Endpoint: "https://api.openai.com/v1", APIKey: "test-key", DefaultModel: "gpt-4", }, expectErr: false, }, { name: "resetdata provider", config: ProviderConfig{ Type: "resetdata", Endpoint: "https://api.resetdata.ai", APIKey: "test-key", DefaultModel: "llama2", }, expectErr: false, }, { name: "unknown provider", config: ProviderConfig{ Type: "unknown", }, expectErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { provider, err := factory.CreateProvider(tt.config) if tt.expectErr { assert.Error(t, err) assert.Nil(t, provider) } else { assert.NoError(t, err) assert.NotNil(t, provider) } }) } }