package ai import ( "context" "fmt" "strings" "time" "github.com/sashabaranov/go-openai" ) // OpenAIProvider implements ModelProvider for OpenAI API type OpenAIProvider struct { config ProviderConfig client *openai.Client } // NewOpenAIProvider creates a new OpenAI provider instance func NewOpenAIProvider(config ProviderConfig) *OpenAIProvider { client := openai.NewClient(config.APIKey) // Use custom endpoint if specified if config.Endpoint != "" && config.Endpoint != "https://api.openai.com/v1" { clientConfig := openai.DefaultConfig(config.APIKey) clientConfig.BaseURL = config.Endpoint client = openai.NewClientWithConfig(clientConfig) } return &OpenAIProvider{ config: config, client: client, } } // ExecuteTask implements the ModelProvider interface for OpenAI func (p *OpenAIProvider) ExecuteTask(ctx context.Context, request *TaskRequest) (*TaskResponse, error) { startTime := time.Now() // Build messages for the chat completion messages, err := p.buildChatMessages(request) if err != nil { return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to build messages: %v", err)) } // Prepare the chat completion request chatReq := openai.ChatCompletionRequest{ Model: p.selectModel(request.ModelName), Messages: messages, Temperature: p.getTemperature(request.Temperature), MaxTokens: p.getMaxTokens(request.MaxTokens), Stream: false, } // Add tools if enabled and supported if p.config.EnableTools && request.EnableTools { chatReq.Tools = p.getToolDefinitions(request) chatReq.ToolChoice = "auto" } // Execute the chat completion resp, err := p.client.CreateChatCompletion(ctx, chatReq) if err != nil { return nil, p.handleOpenAIError(err) } endTime := time.Now() // Process the response if len(resp.Choices) == 0 { return nil, NewProviderError(ErrTaskExecutionFailed, "no response choices returned from OpenAI") } choice := resp.Choices[0] responseText := choice.Message.Content // Process tool calls if present var actions []TaskAction var artifacts []Artifact if len(choice.Message.ToolCalls) > 0 { toolActions, toolArtifacts := p.processToolCalls(choice.Message.ToolCalls, request) actions = append(actions, toolActions...) artifacts = append(artifacts, toolArtifacts...) } // Parse response for additional actions responseActions, responseArtifacts := p.parseResponseForActions(responseText, request) actions = append(actions, responseActions...) artifacts = append(artifacts, responseArtifacts...) return &TaskResponse{ Success: true, TaskID: request.TaskID, AgentID: request.AgentID, ModelUsed: resp.Model, Provider: "openai", Response: responseText, Actions: actions, Artifacts: artifacts, StartTime: startTime, EndTime: endTime, Duration: endTime.Sub(startTime), TokensUsed: TokenUsage{ PromptTokens: resp.Usage.PromptTokens, CompletionTokens: resp.Usage.CompletionTokens, TotalTokens: resp.Usage.TotalTokens, }, }, nil } // GetCapabilities returns OpenAI provider capabilities func (p *OpenAIProvider) GetCapabilities() ProviderCapabilities { return ProviderCapabilities{ SupportsMCP: p.config.EnableMCP, SupportsTools: true, // OpenAI supports function calling SupportsStreaming: true, SupportsFunctions: true, MaxTokens: p.getModelMaxTokens(p.config.DefaultModel), SupportedModels: p.getSupportedModels(), SupportsImages: p.modelSupportsImages(p.config.DefaultModel), SupportsFiles: true, } } // ValidateConfig validates the OpenAI provider configuration func (p *OpenAIProvider) ValidateConfig() error { if p.config.APIKey == "" { return NewProviderError(ErrAPIKeyRequired, "API key is required for OpenAI provider") } if p.config.DefaultModel == "" { return NewProviderError(ErrInvalidConfiguration, "default_model is required for OpenAI provider") } // Test the API connection with a minimal request ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := p.testConnection(ctx); err != nil { return NewProviderError(ErrProviderUnavailable, fmt.Sprintf("failed to connect to OpenAI: %v", err)) } return nil } // GetProviderInfo returns information about the OpenAI provider func (p *OpenAIProvider) GetProviderInfo() ProviderInfo { endpoint := p.config.Endpoint if endpoint == "" { endpoint = "https://api.openai.com/v1" } return ProviderInfo{ Name: "OpenAI", Type: "openai", Version: "1.0.0", Endpoint: endpoint, DefaultModel: p.config.DefaultModel, RequiresAPIKey: true, RateLimit: 10000, // Approximate RPM for paid accounts } } // buildChatMessages constructs messages for the OpenAI chat completion func (p *OpenAIProvider) buildChatMessages(request *TaskRequest) ([]openai.ChatCompletionMessage, error) { var messages []openai.ChatCompletionMessage // System message systemPrompt := p.getSystemPrompt(request) if systemPrompt != "" { messages = append(messages, openai.ChatCompletionMessage{ Role: openai.ChatMessageRoleSystem, Content: systemPrompt, }) } // User message with task details userPrompt, err := p.buildTaskPrompt(request) if err != nil { return nil, err } messages = append(messages, openai.ChatCompletionMessage{ Role: openai.ChatMessageRoleUser, Content: userPrompt, }) return messages, nil } // buildTaskPrompt constructs a comprehensive prompt for task execution func (p *OpenAIProvider) buildTaskPrompt(request *TaskRequest) (string, error) { var prompt strings.Builder prompt.WriteString(fmt.Sprintf("You are working as a %s agent on the following task:\n\n", request.AgentRole)) prompt.WriteString(fmt.Sprintf("**Repository:** %s\n", request.Repository)) prompt.WriteString(fmt.Sprintf("**Task:** %s\n", request.TaskTitle)) prompt.WriteString(fmt.Sprintf("**Description:**\n%s\n\n", request.TaskDescription)) if len(request.TaskLabels) > 0 { prompt.WriteString(fmt.Sprintf("**Labels:** %s\n", strings.Join(request.TaskLabels, ", "))) } prompt.WriteString(fmt.Sprintf("**Priority:** %d/10 | **Complexity:** %d/10\n\n", request.Priority, request.Complexity)) if request.WorkingDirectory != "" { prompt.WriteString(fmt.Sprintf("**Working Directory:** %s\n", request.WorkingDirectory)) } if len(request.RepositoryFiles) > 0 { prompt.WriteString("**Relevant Files:**\n") for _, file := range request.RepositoryFiles { prompt.WriteString(fmt.Sprintf("- %s\n", file)) } prompt.WriteString("\n") } // Add role-specific guidance prompt.WriteString(p.getRoleSpecificGuidance(request.AgentRole)) prompt.WriteString("\nAnalyze this task and provide a detailed implementation plan. ") if request.EnableTools { prompt.WriteString("Use the available tools to make concrete changes or gather information as needed. ") } prompt.WriteString("Be specific about what needs to be done and how to accomplish it.") return prompt.String(), nil } // getRoleSpecificGuidance returns guidance specific to the agent role func (p *OpenAIProvider) getRoleSpecificGuidance(role string) string { switch strings.ToLower(role) { case "developer": return `**Developer Guidelines:** - Write clean, maintainable, and well-documented code - Follow language-specific best practices and conventions - Implement proper error handling and validation - Create or update tests to cover your changes - Consider performance and security implications` case "reviewer": return `**Code Review Guidelines:** - Analyze code quality, readability, and maintainability - Check for bugs, security vulnerabilities, and performance issues - Verify test coverage and quality - Ensure documentation is accurate and complete - Suggest improvements and alternatives` case "architect": return `**Architecture Guidelines:** - Design scalable and maintainable system architecture - Make informed technology and framework decisions - Define clear interfaces and API contracts - Consider security, performance, and scalability requirements - Document architectural decisions and rationale` case "tester": return `**Testing Guidelines:** - Create comprehensive test plans and test cases - Implement unit, integration, and end-to-end tests - Identify edge cases and potential failure scenarios - Set up test automation and continuous integration - Validate functionality against requirements` default: return `**General Guidelines:** - Understand requirements thoroughly before implementation - Follow software development best practices - Provide clear documentation and explanations - Consider maintainability and future extensibility` } } // getToolDefinitions returns tool definitions for OpenAI function calling func (p *OpenAIProvider) getToolDefinitions(request *TaskRequest) []openai.Tool { var tools []openai.Tool // File operations tool tools = append(tools, openai.Tool{ Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{ Name: "file_operation", Description: "Create, read, update, or delete files in the repository", Parameters: map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "operation": map[string]interface{}{ "type": "string", "enum": []string{"create", "read", "update", "delete"}, "description": "The file operation to perform", }, "path": map[string]interface{}{ "type": "string", "description": "The file path relative to the repository root", }, "content": map[string]interface{}{ "type": "string", "description": "The file content (for create/update operations)", }, }, "required": []string{"operation", "path"}, }, }, }) // Command execution tool tools = append(tools, openai.Tool{ Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{ Name: "execute_command", Description: "Execute shell commands in the repository working directory", Parameters: map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "command": map[string]interface{}{ "type": "string", "description": "The shell command to execute", }, "working_dir": map[string]interface{}{ "type": "string", "description": "Working directory for command execution (optional)", }, }, "required": []string{"command"}, }, }, }) return tools } // processToolCalls handles OpenAI function calls func (p *OpenAIProvider) processToolCalls(toolCalls []openai.ToolCall, request *TaskRequest) ([]TaskAction, []Artifact) { var actions []TaskAction var artifacts []Artifact for _, toolCall := range toolCalls { action := TaskAction{ Type: "function_call", Target: toolCall.Function.Name, Content: toolCall.Function.Arguments, Timestamp: time.Now(), Metadata: map[string]interface{}{ "tool_call_id": toolCall.ID, "function": toolCall.Function.Name, }, } // In a real implementation, you would actually execute these tool calls // For now, just mark them as successful action.Result = fmt.Sprintf("Function call %s processed", toolCall.Function.Name) action.Success = true actions = append(actions, action) } return actions, artifacts } // selectModel chooses the appropriate OpenAI model func (p *OpenAIProvider) selectModel(requestedModel string) string { if requestedModel != "" { return requestedModel } return p.config.DefaultModel } // getTemperature returns the temperature setting func (p *OpenAIProvider) getTemperature(requestTemp float32) float32 { if requestTemp > 0 { return requestTemp } if p.config.Temperature > 0 { return p.config.Temperature } return 0.7 // Default temperature } // getMaxTokens returns the max tokens setting func (p *OpenAIProvider) getMaxTokens(requestTokens int) int { if requestTokens > 0 { return requestTokens } if p.config.MaxTokens > 0 { return p.config.MaxTokens } return 4096 // Default max tokens } // getSystemPrompt constructs the system prompt func (p *OpenAIProvider) getSystemPrompt(request *TaskRequest) string { if request.SystemPrompt != "" { return request.SystemPrompt } return fmt.Sprintf(`You are an expert AI assistant specializing in software development. You are currently operating as a %s agent in the CHORUS autonomous development system. Your capabilities: - Code analysis, implementation, and optimization - Software architecture and design patterns - Testing strategies and implementation - Documentation and technical writing - DevOps and deployment practices Always provide thorough, actionable responses with specific implementation details. When using tools, explain your reasoning and the expected outcomes.`, request.AgentRole) } // getModelMaxTokens returns the maximum tokens for a specific model func (p *OpenAIProvider) getModelMaxTokens(model string) int { switch model { case "gpt-4o", "gpt-4o-2024-05-13": return 128000 case "gpt-4-turbo", "gpt-4-turbo-2024-04-09": return 128000 case "gpt-4", "gpt-4-0613": return 8192 case "gpt-3.5-turbo", "gpt-3.5-turbo-0125": return 16385 default: return 4096 // Conservative default } } // modelSupportsImages checks if a model supports image inputs func (p *OpenAIProvider) modelSupportsImages(model string) bool { visionModels := []string{"gpt-4o", "gpt-4o-2024-05-13", "gpt-4-turbo", "gpt-4-vision-preview"} for _, visionModel := range visionModels { if strings.Contains(model, visionModel) { return true } } return false } // getSupportedModels returns a list of supported OpenAI models func (p *OpenAIProvider) getSupportedModels() []string { return []string{ "gpt-4o", "gpt-4o-2024-05-13", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", "gpt-4", "gpt-4-0613", "gpt-3.5-turbo", "gpt-3.5-turbo-0125", } } // testConnection tests the OpenAI API connection func (p *OpenAIProvider) testConnection(ctx context.Context) error { // Simple test request to verify API key and connection _, err := p.client.ListModels(ctx) return err } // handleOpenAIError converts OpenAI errors to provider errors func (p *OpenAIProvider) handleOpenAIError(err error) *ProviderError { errStr := err.Error() if strings.Contains(errStr, "rate limit") { return &ProviderError{ Code: "RATE_LIMIT_EXCEEDED", Message: "OpenAI API rate limit exceeded", Details: errStr, Retryable: true, } } if strings.Contains(errStr, "quota") { return &ProviderError{ Code: "QUOTA_EXCEEDED", Message: "OpenAI API quota exceeded", Details: errStr, Retryable: false, } } if strings.Contains(errStr, "invalid_api_key") { return &ProviderError{ Code: "INVALID_API_KEY", Message: "Invalid OpenAI API key", Details: errStr, Retryable: false, } } return &ProviderError{ Code: "API_ERROR", Message: "OpenAI API error", Details: errStr, Retryable: true, } } // parseResponseForActions extracts actions from the response text func (p *OpenAIProvider) parseResponseForActions(response string, request *TaskRequest) ([]TaskAction, []Artifact) { var actions []TaskAction var artifacts []Artifact // Create a basic task analysis action action := TaskAction{ Type: "task_analysis", Target: request.TaskTitle, Content: response, Result: "Task analyzed by OpenAI model", Success: true, Timestamp: time.Now(), Metadata: map[string]interface{}{ "agent_role": request.AgentRole, "repository": request.Repository, "model": p.config.DefaultModel, }, } actions = append(actions, action) return actions, artifacts }