package ai import ( "fmt" "os" "strings" "time" "gopkg.in/yaml.v3" ) // ModelConfig represents the complete model configuration loaded from YAML type ModelConfig struct { Providers map[string]ProviderConfig `yaml:"providers" json:"providers"` DefaultProvider string `yaml:"default_provider" json:"default_provider"` FallbackProvider string `yaml:"fallback_provider" json:"fallback_provider"` Roles map[string]RoleConfig `yaml:"roles" json:"roles"` Environments map[string]EnvConfig `yaml:"environments" json:"environments"` ModelPreferences map[string]TaskPreference `yaml:"model_preferences" json:"model_preferences"` } // EnvConfig represents environment-specific configuration overrides type EnvConfig struct { DefaultProvider string `yaml:"default_provider" json:"default_provider"` FallbackProvider string `yaml:"fallback_provider" json:"fallback_provider"` } // TaskPreference represents preferred models for specific task types type TaskPreference struct { PreferredModels []string `yaml:"preferred_models" json:"preferred_models"` MinContextTokens int `yaml:"min_context_tokens" json:"min_context_tokens"` } // ConfigLoader loads and manages AI provider configurations type ConfigLoader struct { configPath string environment string } // NewConfigLoader creates a new configuration loader func NewConfigLoader(configPath, environment string) *ConfigLoader { return &ConfigLoader{ configPath: configPath, environment: environment, } } // LoadConfig loads the complete configuration from the YAML file func (c *ConfigLoader) LoadConfig() (*ModelConfig, error) { data, err := os.ReadFile(c.configPath) if err != nil { return nil, fmt.Errorf("failed to read config file %s: %w", c.configPath, err) } // Expand environment variables in the config configData := c.expandEnvVars(string(data)) var config ModelConfig if err := yaml.Unmarshal([]byte(configData), &config); err != nil { return nil, fmt.Errorf("failed to parse config file %s: %w", c.configPath, err) } // Apply environment-specific overrides if c.environment != "" { c.applyEnvironmentOverrides(&config) } // Validate the configuration if err := c.validateConfig(&config); err != nil { return nil, fmt.Errorf("invalid configuration: %w", err) } return &config, nil } // LoadProviderFactory creates a provider factory from the configuration func (c *ConfigLoader) LoadProviderFactory() (*ProviderFactory, error) { config, err := c.LoadConfig() if err != nil { return nil, err } factory := NewProviderFactory() // Register all providers for name, providerConfig := range config.Providers { if err := factory.RegisterProvider(name, providerConfig); err != nil { // Log warning but continue with other providers fmt.Printf("Warning: Failed to register provider %s: %v\n", name, err) continue } } // Set up role mapping roleMapping := RoleModelMapping{ DefaultProvider: config.DefaultProvider, FallbackProvider: config.FallbackProvider, Roles: config.Roles, } factory.SetRoleMapping(roleMapping) return factory, nil } // expandEnvVars expands environment variables in the configuration func (c *ConfigLoader) expandEnvVars(config string) string { // Replace ${VAR} and $VAR patterns with environment variable values expanded := config // Handle ${VAR} pattern for { start := strings.Index(expanded, "${") if start == -1 { break } end := strings.Index(expanded[start:], "}") if end == -1 { break } end += start varName := expanded[start+2 : end] varValue := os.Getenv(varName) expanded = expanded[:start] + varValue + expanded[end+1:] } return expanded } // applyEnvironmentOverrides applies environment-specific configuration overrides func (c *ConfigLoader) applyEnvironmentOverrides(config *ModelConfig) { envConfig, exists := config.Environments[c.environment] if !exists { return } // Override default and fallback providers if specified if envConfig.DefaultProvider != "" { config.DefaultProvider = envConfig.DefaultProvider } if envConfig.FallbackProvider != "" { config.FallbackProvider = envConfig.FallbackProvider } } // validateConfig validates the loaded configuration func (c *ConfigLoader) validateConfig(config *ModelConfig) error { // Check that default provider exists if config.DefaultProvider != "" { if _, exists := config.Providers[config.DefaultProvider]; !exists { return fmt.Errorf("default_provider '%s' not found in providers", config.DefaultProvider) } } // Check that fallback provider exists if config.FallbackProvider != "" { if _, exists := config.Providers[config.FallbackProvider]; !exists { return fmt.Errorf("fallback_provider '%s' not found in providers", config.FallbackProvider) } } // Validate each provider configuration for name, providerConfig := range config.Providers { if err := c.validateProviderConfig(name, providerConfig); err != nil { return fmt.Errorf("invalid provider config '%s': %w", name, err) } } // Validate role configurations for roleName, roleConfig := range config.Roles { if err := c.validateRoleConfig(roleName, roleConfig, config.Providers); err != nil { return fmt.Errorf("invalid role config '%s': %w", roleName, err) } } return nil } // validateProviderConfig validates a single provider configuration func (c *ConfigLoader) validateProviderConfig(name string, config ProviderConfig) error { // Check required fields if config.Type == "" { return fmt.Errorf("type is required") } // Validate provider type validTypes := []string{"ollama", "openai", "resetdata"} typeValid := false for _, validType := range validTypes { if config.Type == validType { typeValid = true break } } if !typeValid { return fmt.Errorf("invalid provider type '%s', must be one of: %s", config.Type, strings.Join(validTypes, ", ")) } // Check endpoint for all types if config.Endpoint == "" { return fmt.Errorf("endpoint is required") } // Check API key for providers that require it if (config.Type == "openai" || config.Type == "resetdata") && config.APIKey == "" { return fmt.Errorf("api_key is required for %s provider", config.Type) } // Check default model if config.DefaultModel == "" { return fmt.Errorf("default_model is required") } // Validate timeout if config.Timeout == 0 { config.Timeout = 300 * time.Second // Set default } // Validate temperature range if config.Temperature < 0 || config.Temperature > 2.0 { return fmt.Errorf("temperature must be between 0 and 2.0") } // Validate max tokens if config.MaxTokens <= 0 { config.MaxTokens = 4096 // Set default } return nil } // validateRoleConfig validates a role configuration func (c *ConfigLoader) validateRoleConfig(roleName string, config RoleConfig, providers map[string]ProviderConfig) error { // Check that provider exists if config.Provider != "" { if _, exists := providers[config.Provider]; !exists { return fmt.Errorf("provider '%s' not found", config.Provider) } } // Check fallback provider exists if specified if config.FallbackProvider != "" { if _, exists := providers[config.FallbackProvider]; !exists { return fmt.Errorf("fallback_provider '%s' not found", config.FallbackProvider) } } // Validate temperature range if config.Temperature < 0 || config.Temperature > 2.0 { return fmt.Errorf("temperature must be between 0 and 2.0") } // Validate max tokens if config.MaxTokens < 0 { return fmt.Errorf("max_tokens cannot be negative") } return nil } // GetProviderForTaskType returns the best provider for a specific task type func (c *ConfigLoader) GetProviderForTaskType(config *ModelConfig, factory *ProviderFactory, taskType string) (ModelProvider, ProviderConfig, error) { // Check if we have preferences for this task type if preference, exists := config.ModelPreferences[taskType]; exists { // Try each preferred model in order for _, modelName := range preference.PreferredModels { for providerName, provider := range factory.providers { capabilities := provider.GetCapabilities() for _, supportedModel := range capabilities.SupportedModels { if supportedModel == modelName && factory.isProviderHealthy(providerName) { providerConfig := factory.configs[providerName] providerConfig.DefaultModel = modelName // Ensure minimum context if specified if preference.MinContextTokens > providerConfig.MaxTokens { providerConfig.MaxTokens = preference.MinContextTokens } return provider, providerConfig, nil } } } } } // Fall back to default provider selection if config.DefaultProvider != "" { provider, err := factory.GetProvider(config.DefaultProvider) if err != nil { return nil, ProviderConfig{}, err } return provider, factory.configs[config.DefaultProvider], nil } return nil, ProviderConfig{}, NewProviderError(ErrProviderNotFound, "no suitable provider found for task type "+taskType) } // DefaultConfigPath returns the default path for the model configuration file func DefaultConfigPath() string { // Try environment variable first if path := os.Getenv("CHORUS_MODEL_CONFIG"); path != "" { return path } // Try relative to current working directory if _, err := os.Stat("configs/models.yaml"); err == nil { return "configs/models.yaml" } // Try relative to executable if _, err := os.Stat("./configs/models.yaml"); err == nil { return "./configs/models.yaml" } // Default fallback return "configs/models.yaml" } // GetEnvironment returns the current environment (from env var or default) func GetEnvironment() string { if env := os.Getenv("CHORUS_ENVIRONMENT"); env != "" { return env } if env := os.Getenv("NODE_ENV"); env != "" { return env } return "development" // default }