package ai

import (
	"context"
	"fmt"
	"sort"
	"sync"
	"time"
)

// ProviderFactory creates, registers, and health-tracks AI model providers and
// selects a provider for an agent role or a task.
//
// Its methods are safe for concurrent use: StartHealthCheckRoutine updates the
// health maps from a background goroutine while other callers read them, so
// every access to the maps and to roleMapping goes through mu.
type ProviderFactory struct {
	// mu guards configs, providers, roleMapping, healthChecks and lastHealthCheck.
	mu sync.RWMutex

	configs         map[string]ProviderConfig // provider name -> config
	providers       map[string]ModelProvider  // provider name -> instance
	roleMapping     RoleModelMapping          // role-based model selection
	healthChecks    map[string]bool           // provider name -> last health result
	lastHealthCheck map[string]time.Time      // provider name -> last check time

	// CreateProvider builds a provider instance from a config. NewProviderFactory
	// sets it to defaultCreateProvider. It is not guarded by mu, so replace it
	// only before the factory is shared between goroutines.
	CreateProvider func(config ProviderConfig) (ModelProvider, error)
}

// NewProviderFactory returns an empty factory whose CreateProvider builds the
// built-in provider types ("ollama", "openai", "resetdata").
func NewProviderFactory() *ProviderFactory {
	f := &ProviderFactory{
		configs:         make(map[string]ProviderConfig),
		providers:       make(map[string]ModelProvider),
		healthChecks:    make(map[string]bool),
		lastHealthCheck: make(map[string]time.Time),
	}
	f.CreateProvider = f.defaultCreateProvider
	return f
}

// RegisterProvider creates a provider from config via CreateProvider, validates
// it with ValidateConfig, and stores it under name, replacing any previous
// registration. The new provider is marked healthy as of now.
//
// It returns an error, and registers nothing, if creation or validation fails.
func (f *ProviderFactory) RegisterProvider(name string, config ProviderConfig) error {
	// Create and validate outside the lock so a slow provider constructor or
	// validator never blocks concurrent lookups.
	provider, err := f.CreateProvider(config)
	if err != nil {
		return fmt.Errorf("failed to create provider %s: %w", name, err)
	}
	if err := provider.ValidateConfig(); err != nil {
		return fmt.Errorf("invalid configuration for provider %s: %w", name, err)
	}

	f.mu.Lock()
	defer f.mu.Unlock()
	f.configs[name] = config
	f.providers[name] = provider
	// A freshly validated provider starts out healthy; later health checks
	// confirm or revoke that.
	f.healthChecks[name] = true
	f.lastHealthCheck[name] = time.Now()
	return nil
}

// SetRoleMapping replaces the role-to-model mapping used by GetProviderForRole.
// The mapping's Roles map is stored as-is, so callers should not mutate it
// afterwards.
func (f *ProviderFactory) SetRoleMapping(mapping RoleModelMapping) {
	f.mu.Lock()
	defer f.mu.Unlock()
	f.roleMapping = mapping
}

// GetProvider returns the provider registered under name.
//
// It returns an ErrProviderNotFound error if no such provider is registered,
// and an ErrProviderUnavailable error if the provider is not currently healthy
// (see isProviderHealthy).
func (f *ProviderFactory) GetProvider(name string) (ModelProvider, error) {
	f.mu.RLock()
	defer f.mu.RUnlock()

	provider, exists := f.providers[name]
	if !exists {
		return nil, NewProviderError(ErrProviderNotFound, fmt.Sprintf("provider %s not found", name))
	}
	if !f.isProviderHealthyLocked(name) {
		return nil, NewProviderError(ErrProviderUnavailable, fmt.Sprintf("provider %s is unhealthy", name))
	}
	return provider, nil
}

// GetProviderForRole returns the provider and effective configuration for an
// agent role.
//
// For a role present in the mapping it tries, in order, the role's provider,
// the role's fallback provider, and the global fallback provider. Role-specific
// settings are merged into the returned config only when the role's provider or
// the role's fallback is found on the first attempt; the global-fallback path
// returns the provider's own config. A role missing from the mapping uses the
// default provider (with the global fallback) and its unmerged config, or
// fails with ErrProviderNotFound when no default provider is configured.
func (f *ProviderFactory) GetProviderForRole(role string) (ModelProvider, ProviderConfig, error) {
	f.mu.RLock()
	mapping := f.roleMapping
	f.mu.RUnlock()

	roleConfig, exists := mapping.Roles[role]
	if !exists {
		// Unmapped role: use the default provider, if one is configured.
		if mapping.DefaultProvider != "" {
			return f.getProviderWithFallback(mapping.DefaultProvider, mapping.FallbackProvider)
		}
		return nil, ProviderConfig{}, NewProviderError(ErrProviderNotFound,
			fmt.Sprintf("no provider configured for role %s", role))
	}

	// Try the role's primary provider, then its own fallback.
	provider, config, err := f.getProviderWithFallback(roleConfig.Provider, roleConfig.FallbackProvider)
	if err != nil {
		// Retry via the role fallback paired with the global fallback.
		if roleConfig.FallbackProvider != "" {
			return f.getProviderWithFallback(roleConfig.FallbackProvider, mapping.FallbackProvider)
		}
		// Last resort: the global fallback on its own.
		if mapping.FallbackProvider != "" {
			return f.getProviderWithFallback(mapping.FallbackProvider, "")
		}
		return nil, ProviderConfig{}, err
	}

	return provider, f.mergeRoleConfig(config, roleConfig), nil
}

// GetProviderForTask returns a provider for request, which must be non-nil.
//
// When request.ModelName is set, it returns the first healthy provider (in
// name order, so the choice is deterministic) whose capabilities list that
// model, with the config's DefaultModel overridden to the requested model. If
// none qualifies it returns an ErrModelNotSupported error. Without a model
// name it delegates to GetProviderForRole(request.AgentRole).
func (f *ProviderFactory) GetProviderForTask(request *TaskRequest) (ModelProvider, ProviderConfig, error) {
	if request.ModelName == "" {
		return f.GetProviderForRole(request.AgentRole)
	}

	names, providers := f.snapshotProviders()
	for _, name := range names {
		provider := providers[name]

		supported := false
		for _, model := range provider.GetCapabilities().SupportedModels {
			if model == request.ModelName {
				supported = true
				break
			}
		}
		if !supported {
			continue
		}

		f.mu.RLock()
		healthy := f.isProviderHealthyLocked(name)
		config := f.configs[name]
		f.mu.RUnlock()
		if !healthy {
			continue
		}

		config.DefaultModel = request.ModelName // honor the explicitly requested model
		return provider, config, nil
	}

	return nil, ProviderConfig{}, NewProviderError(ErrModelNotSupported,
		fmt.Sprintf("model %s not available", request.ModelName))
}

// ListProviders returns the names of all registered providers, sorted. It
// returns nil when no provider is registered.
func (f *ProviderFactory) ListProviders() []string {
	f.mu.RLock()
	defer f.mu.RUnlock()

	var names []string
	for name := range f.providers {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}

// ListHealthyProviders returns the sorted names of registered providers that
// isProviderHealthy currently reports as healthy, or nil if there are none.
func (f *ProviderFactory) ListHealthyProviders() []string {
	f.mu.RLock()
	defer f.mu.RUnlock()

	var names []string
	for name := range f.providers {
		if f.isProviderHealthyLocked(name) {
			names = append(names, name)
		}
	}
	sort.Strings(names)
	return names
}

// GetProviderInfo returns each registered provider's ProviderInfo keyed by its
// registered name. The Name field is overwritten with that registered name.
func (f *ProviderFactory) GetProviderInfo() map[string]ProviderInfo {
	names, providers := f.snapshotProviders()

	info := make(map[string]ProviderInfo, len(names))
	for _, name := range names {
		providerInfo := providers[name].GetProviderInfo()
		providerInfo.Name = name // report the registration name, not the provider's own
		info[name] = providerInfo
	}
	return info
}

// HealthCheck probes every registered provider with checkProviderHealth, one
// after another, and records each result and its timestamp. It returns the
// probe error per provider name (nil for a healthy provider).
//
// Probes run without holding the factory lock, so lookups are not blocked
// while a slow provider is being checked. Each probe gets its own timeout
// derived from ctx.
func (f *ProviderFactory) HealthCheck(ctx context.Context) map[string]error {
	names, providers := f.snapshotProviders()

	results := make(map[string]error, len(names))
	for _, name := range names {
		err := f.checkProviderHealth(ctx, name, providers[name])
		results[name] = err

		f.mu.Lock()
		f.healthChecks[name] = err == nil
		f.lastHealthCheck[name] = time.Now()
		f.mu.Unlock()
	}
	return results
}

// GetHealthStatus returns a snapshot of every registered provider's health.
//
// Healthy is the raw result of the last check (or registration). It does not
// apply the staleness rule that isProviderHealthy uses, so inspect LastCheck
// to judge freshness.
func (f *ProviderFactory) GetHealthStatus() map[string]ProviderHealth {
	type entry struct {
		provider  ModelProvider
		healthy   bool
		lastCheck time.Time
	}

	// Copy state under the lock, then call into providers without holding it.
	f.mu.RLock()
	entries := make(map[string]entry, len(f.providers))
	for name, provider := range f.providers {
		entries[name] = entry{
			provider:  provider,
			healthy:   f.healthChecks[name],
			lastCheck: f.lastHealthCheck[name],
		}
	}
	f.mu.RUnlock()

	status := make(map[string]ProviderHealth, len(entries))
	for name, e := range entries {
		status[name] = ProviderHealth{
			Name:         name,
			Healthy:      e.healthy,
			LastCheck:    e.lastCheck,
			ProviderInfo: e.provider.GetProviderInfo(),
			Capabilities: e.provider.GetCapabilities(),
		}
	}
	return status
}

// StartHealthCheckRoutine starts a goroutine that calls HealthCheck every
// interval until ctx is cancelled, then stops its ticker and exits. It returns
// immediately; the first check happens one interval after the call.
//
// A non-positive interval defaults to 5 minutes (time.NewTicker panics on
// non-positive durations). Because isProviderHealthy treats results older than
// 10 minutes as unhealthy, an interval longer than that leaves providers
// unavailable between checks.
func (f *ProviderFactory) StartHealthCheckRoutine(ctx context.Context, interval time.Duration) {
	if interval <= 0 {
		interval = 5 * time.Minute
	}

	ticker := time.NewTicker(interval)
	go func() {
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				// Pass ctx directly: checkProviderHealth gives each provider its
				// own 30s timeout. A single shared 30s deadline for the whole
				// round would let one slow provider expire the context for every
				// provider probed after it, falsely marking them unhealthy.
				f.HealthCheck(ctx)
			}
		}
	}()
}

// defaultCreateProvider builds a provider instance for config.Type ("ollama",
// "openai" or "resetdata"). Any other type yields an ErrProviderNotFound error.
func (f *ProviderFactory) defaultCreateProvider(config ProviderConfig) (ModelProvider, error) {
	switch config.Type {
	case "ollama":
		return NewOllamaProvider(config), nil
	case "openai":
		return NewOpenAIProvider(config), nil
	case "resetdata":
		return NewResetDataProvider(config), nil
	default:
		return nil, NewProviderError(ErrProviderNotFound, fmt.Sprintf("unknown provider type: %s", config.Type))
	}
}

// getProviderWithFallback returns the first of primaryName, fallbackName that
// is registered and healthy, together with its stored config. Empty names are
// skipped.
//
// It fails with ErrProviderUnavailable when a primary name was given but
// neither candidate is usable, and with ErrProviderNotFound when primaryName
// is empty and the fallback is not usable. It acquires f.mu itself, so callers
// must not hold it.
func (f *ProviderFactory) getProviderWithFallback(primaryName, fallbackName string) (ModelProvider, ProviderConfig, error) {
	f.mu.RLock()
	defer f.mu.RUnlock()

	for _, name := range []string{primaryName, fallbackName} {
		if name == "" {
			continue
		}
		if provider, exists := f.providers[name]; exists && f.isProviderHealthyLocked(name) {
			return provider, f.configs[name], nil
		}
	}

	if primaryName != "" {
		return nil, ProviderConfig{}, NewProviderError(ErrProviderUnavailable,
			fmt.Sprintf("provider %s and fallback %s are unavailable", primaryName, fallbackName))
	}
	return nil, ProviderConfig{}, NewProviderError(ErrProviderNotFound, "no provider specified")
}

// mergeRoleConfig returns baseConfig with role-specific overrides applied.
//
// Overrides apply only for "set" role values: a non-empty Model, a Temperature
// or MaxTokens above zero, and EnableTools/EnableMCP when true (a role cannot
// switch a tool setting off). Role MCP servers are appended after the
// provider's own.
func (f *ProviderFactory) mergeRoleConfig(baseConfig ProviderConfig, roleConfig RoleConfig) ProviderConfig {
	merged := baseConfig

	if roleConfig.Model != "" {
		merged.DefaultModel = roleConfig.Model
	}
	if roleConfig.Temperature > 0 {
		merged.Temperature = roleConfig.Temperature
	}
	if roleConfig.MaxTokens > 0 {
		merged.MaxTokens = roleConfig.MaxTokens
	}
	if roleConfig.EnableTools {
		merged.EnableTools = true
	}
	if roleConfig.EnableMCP {
		merged.EnableMCP = true
	}

	if len(roleConfig.MCPServers) > 0 {
		// merged.MCPServers shares its backing array with the stored provider
		// config. Capping capacity with a full slice expression forces append to
		// allocate, so the stored config (and other merged results) are never
		// overwritten.
		base := merged.MCPServers
		merged.MCPServers = append(base[:len(base):len(base)], roleConfig.MCPServers...)
	}

	return merged
}

// isProviderHealthy reports whether the named provider's last health result
// was healthy and is no older than 10 minutes. Unknown providers are unhealthy.
// It acquires f.mu itself, so callers must not hold it.
func (f *ProviderFactory) isProviderHealthy(name string) bool {
	f.mu.RLock()
	defer f.mu.RUnlock()
	return f.isProviderHealthyLocked(name)
}

// isProviderHealthyLocked is isProviderHealthy for callers that already hold
// f.mu (read or write).
func (f *ProviderFactory) isProviderHealthyLocked(name string) bool {
	healthy, exists := f.healthChecks[name]
	if !exists {
		return false
	}
	// A result older than 10 minutes is treated as unknown, and so unhealthy.
	lastCheck, exists := f.lastHealthCheck[name]
	if !exists || time.Since(lastCheck) > 10*time.Minute {
		return false
	}
	return healthy
}

// snapshotProviders copies the registered providers under the read lock. It
// returns their names sorted, plus a name -> provider map, so callers can
// call into providers without holding f.mu.
func (f *ProviderFactory) snapshotProviders() ([]string, map[string]ModelProvider) {
	f.mu.RLock()
	defer f.mu.RUnlock()

	names := make([]string, 0, len(f.providers))
	snapshot := make(map[string]ModelProvider, len(f.providers))
	for name, provider := range f.providers {
		names = append(names, name)
		snapshot[name] = provider
	}
	sort.Strings(names)
	return names, snapshot
}

// checkProviderHealth probes provider by executing a minimal task (default
// model, MaxTokens 50, tools disabled) with a 30-second timeout derived from
// ctx. It returns the error from ExecuteTask, nil meaning healthy. name is
// currently unused.
func (f *ProviderFactory) checkProviderHealth(ctx context.Context, name string, provider ModelProvider) error {
	healthRequest := &TaskRequest{
		TaskID:          "health-check",
		AgentID:         "health-checker",
		AgentRole:       "system",
		Repository:      "health-check",
		TaskTitle:       "Health Check",
		TaskDescription: "Simple health check task",
		ModelName:       "", // use the provider's default model
		MaxTokens:       50, // keep the response minimal
		EnableTools:     false,
	}

	healthCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()

	_, err := provider.ExecuteTask(healthCtx, healthRequest)
	return err
}

// ProviderHealth is a point-in-time health report for one registered provider.
type ProviderHealth struct {
	Name         string               `json:"name"`
	Healthy      bool                 `json:"healthy"`
	LastCheck    time.Time            `json:"last_check"`
	ProviderInfo ProviderInfo         `json:"provider_info"`
	Capabilities ProviderCapabilities `json:"capabilities"`
}

// DefaultProviderFactory returns a factory that registers a local Ollama
// provider named "ollama" (http://localhost:11434) and maps the "developer",
// "reviewer", "architect" and "tester" roles to Ollama models. "ollama" is also
// both the default and the global fallback provider.
func DefaultProviderFactory() *ProviderFactory {
	factory := NewProviderFactory()

	ollamaConfig := ProviderConfig{
		Type:          "ollama",
		Endpoint:      "http://localhost:11434",
		DefaultModel:  "llama3.1:8b",
		Temperature:   0.7,
		MaxTokens:     4096,
		Timeout:       300 * time.Second,
		RetryAttempts: 3,
		RetryDelay:    2 * time.Second,
		EnableTools:   true,
		EnableMCP:     true,
	}
	// Registration is best-effort: if the Ollama provider cannot be created or
	// validated, the factory is still returned, and role lookups report
	// "ollama" as unavailable.
	_ = factory.RegisterProvider("ollama", ollamaConfig)

	defaultMapping := RoleModelMapping{
		DefaultProvider:  "ollama",
		FallbackProvider: "ollama",
		Roles: map[string]RoleConfig{
			"developer": {
				Provider:     "ollama",
				Model:        "codellama:13b",
				Temperature:  0.3,
				MaxTokens:    8192,
				EnableTools:  true,
				EnableMCP:    true,
				SystemPrompt: "You are an expert software developer focused on writing clean, maintainable, and well-tested code.",
			},
			"reviewer": {
				Provider:     "ollama",
				Model:        "llama3.1:8b",
				Temperature:  0.2,
				MaxTokens:    6144,
				EnableTools:  true,
				SystemPrompt: "You are a thorough code reviewer focused on quality, security, and best practices.",
			},
			"architect": {
				Provider: "ollama",
				// NOTE(review): the llama3.1 family appears to be published as
				// 8b/70b/405b; confirm that a "13b" tag exists on the target
				// Ollama instance.
				Model:        "llama3.1:13b",
				Temperature:  0.5,
				MaxTokens:    8192,
				EnableTools:  true,
				SystemPrompt: "You are a senior software architect focused on system design and technical decision making.",
			},
			"tester": {
				Provider:     "ollama",
				Model:        "codellama:7b",
				Temperature:  0.3,
				MaxTokens:    6144,
				EnableTools:  true,
				SystemPrompt: "You are a QA engineer focused on comprehensive testing and quality assurance.",
			},
		},
	}
	factory.SetRoleMapping(defaultMapping)

	return factory
}