package reasoning

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"chorus/pkg/mcp"
)

const (
	defaultTimeout = 60 * time.Second
)

var (
	availableModels     []string
	modelWebhookURL     string
	defaultModel        string
	ollamaEndpoint      = "http://localhost:11434" // Default fallback
	aiProvider          = "resetdata"              // Default provider
	resetdataConfig     ResetDataConfig
	defaultSystemPrompt string
	lightragClient      *mcp.LightRAGClient // Optional LightRAG client for context enrichment
)

// AIProvider represents the AI service provider.
type AIProvider string

const (
	ProviderOllama    AIProvider = "ollama"
	ProviderResetData AIProvider = "resetdata"
)

// ResetDataConfig holds ResetData API configuration.
type ResetDataConfig struct {
	BaseURL string
	APIKey  string
	Model   string
	Timeout time.Duration
}

// OllamaRequest represents the request payload for the Ollama API.
type OllamaRequest struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt"`
	Stream bool   `json:"stream"`
}

// OllamaResponse represents a single response object from the Ollama API.
type OllamaResponse struct {
	Model     string    `json:"model"`
	CreatedAt time.Time `json:"created_at"`
	Response  string    `json:"response"`
	Done      bool      `json:"done"`
}

// OpenAIMessage represents a message in the OpenAI API format.
type OpenAIMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// OpenAIRequest represents the request payload for OpenAI-compatible APIs.
type OpenAIRequest struct {
	Model       string          `json:"model"`
	Messages    []OpenAIMessage `json:"messages"`
	Temperature float64         `json:"temperature"`
	TopP        float64         `json:"top_p"`
	MaxTokens   int             `json:"max_tokens"`
	Stream      bool            `json:"stream"`
}

// OpenAIChoice represents a choice in the OpenAI response.
type OpenAIChoice struct {
	Message struct {
		Content string `json:"content"`
	} `json:"message"`
}

// OpenAIResponse represents the response from OpenAI-compatible APIs.
type OpenAIResponse struct {
	Choices []OpenAIChoice `json:"choices"`
}
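// For reference, an OpenAIRequest as built below marshals to JSON shaped
// like this (the model name and message contents are illustrative):
//
//	{
//	  "model": "example-model:latest",
//	  "messages": [
//	    {"role": "system", "content": "You are a helpful assistant."},
//	    {"role": "user", "content": "..."}
//	  ],
//	  "temperature": 0.2,
//	  "top_p": 0.7,
//	  "max_tokens": 1024,
//	  "stream": false
//	}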
// GenerateResponse queries the configured AI provider with a given prompt and
// model, and returns the complete generated response as a single string.
func GenerateResponse(ctx context.Context, model, prompt string) (string, error) {
	// Set up a timeout for the request.
	ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
	defer cancel()

	switch AIProvider(aiProvider) {
	case ProviderResetData:
		return generateResetDataResponse(ctx, model, prompt)
	case ProviderOllama:
		return generateOllamaResponse(ctx, model, prompt)
	default:
		// Default to ResetData if the provider is unknown.
		return generateResetDataResponse(ctx, model, prompt)
	}
}

// generateResetDataResponse queries the ResetData API.
func generateResetDataResponse(ctx context.Context, model, prompt string) (string, error) {
	if resetdataConfig.APIKey == "" {
		return "", fmt.Errorf("resetdata API key not configured")
	}

	// Prefer the configured model; otherwise use the one passed in. Guard
	// against sending an empty model name if neither is set.
	modelToUse := model
	if resetdataConfig.Model != "" {
		modelToUse = resetdataConfig.Model
	}
	if modelToUse == "" {
		return "", fmt.Errorf("no model specified for resetdata request")
	}

	// Create the request payload in OpenAI format.
	requestPayload := OpenAIRequest{
		Model: modelToUse,
		Messages: []OpenAIMessage{
			{Role: "system", Content: defaultSystemPromptOrFallback()},
			{Role: "user", Content: prompt},
		},
		Temperature: 0.2,
		TopP:        0.7,
		MaxTokens:   1024,
		Stream:      false,
	}

	payloadBytes, err := json.Marshal(requestPayload)
	if err != nil {
		return "", fmt.Errorf("failed to marshal resetdata request: %w", err)
	}

	// Create the HTTP request.
	apiURL := resetdataConfig.BaseURL + "/chat/completions"
	req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewBuffer(payloadBytes))
	if err != nil {
		return "", fmt.Errorf("failed to create http request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+resetdataConfig.APIKey)

	// Execute the request.
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("failed to execute http request to resetdata: %w", err)
	}
	defer resp.Body.Close()

	// Check for non-200 status codes.
	if resp.StatusCode != http.StatusOK {
		bodyBytes, _ := io.ReadAll(resp.Body)
		return "", fmt.Errorf("resetdata api returned non-200 status: %d - %s", resp.StatusCode, string(bodyBytes))
	}

	// Decode the JSON response.
	var openaiResp OpenAIResponse
	if err := json.NewDecoder(resp.Body).Decode(&openaiResp); err != nil {
		return "", fmt.Errorf("failed to decode resetdata response: %w", err)
	}

	if len(openaiResp.Choices) == 0 {
		return "", fmt.Errorf("no choices in resetdata response")
	}

	return openaiResp.Choices[0].Message.Content, nil
}

// generateOllamaResponse queries the Ollama API (legacy support).
func generateOllamaResponse(ctx context.Context, model, prompt string) (string, error) {
	// Create the request payload. Streaming is disabled so the full
	// response arrives as a single JSON object.
	requestPayload := OllamaRequest{
		Model:  model,
		Prompt: prompt,
		Stream: false,
	}

	payloadBytes, err := json.Marshal(requestPayload)
	if err != nil {
		return "", fmt.Errorf("failed to marshal ollama request: %w", err)
	}

	// Create the HTTP request.
	apiURL := ollamaEndpoint + "/api/generate"
	req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewBuffer(payloadBytes))
	if err != nil {
		return "", fmt.Errorf("failed to create http request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	// Execute the request.
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("failed to execute http request to ollama: %w", err)
	}
	defer resp.Body.Close()
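// Example usage (a sketch; assumes a provider has already been configured via
// the setters below, and the model name here is an illustrative placeholder):
//
//	ctx := context.Background()
//	reply, err := reasoning.GenerateResponse(ctx, "example-model:latest", "Summarize the deployment logs.")
//	if err != nil {
//	    // handle the error
//	}
//	fmt.Println(reply)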
	// Check for non-200 status codes.
	if resp.StatusCode != http.StatusOK {
		bodyBytes, _ := io.ReadAll(resp.Body)
		return "", fmt.Errorf("ollama api returned non-200 status: %d - %s", resp.StatusCode, string(bodyBytes))
	}

	// Decode the JSON response.
	var ollamaResp OllamaResponse
	if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil {
		return "", fmt.Errorf("failed to decode ollama response: %w", err)
	}

	return ollamaResp.Response, nil
}

// SetModelConfig configures the available models and webhook URL for smart
// model selection.
func SetModelConfig(models []string, webhookURL, defaultReasoningModel string) {
	availableModels = models
	modelWebhookURL = webhookURL
	defaultModel = defaultReasoningModel
}

// SetAIProvider configures which AI provider to use.
func SetAIProvider(provider string) {
	aiProvider = provider
}

// SetResetDataConfig configures the ResetData API settings.
func SetResetDataConfig(config ResetDataConfig) {
	resetdataConfig = config
}

// SetOllamaEndpoint configures the Ollama API endpoint.
func SetOllamaEndpoint(endpoint string) {
	ollamaEndpoint = endpoint
}

// SetDefaultSystemPrompt configures the default system message used when
// building prompts.
func SetDefaultSystemPrompt(systemPrompt string) {
	defaultSystemPrompt = systemPrompt
}

// SetLightRAGClient configures the optional LightRAG client for context
// enrichment.
func SetLightRAGClient(client *mcp.LightRAGClient) {
	lightragClient = client
}

// GenerateResponseWithRAG queries LightRAG for context, then generates a
// response enriched with relevant information from the knowledge base.
func GenerateResponseWithRAG(ctx context.Context, model, prompt string, queryMode mcp.QueryMode) (string, error) {
	// If LightRAG is not configured, fall back to regular generation.
	if lightragClient == nil {
		return GenerateResponse(ctx, model, prompt)
	}

	// Query LightRAG for relevant context. Failures are non-fatal: on
	// error we simply fall back to regular generation.
	ragCtx, err := lightragClient.GetContext(ctx, prompt, queryMode)
	if err != nil {
		return GenerateResponse(ctx, model, prompt)
	}

	// If we got context, enrich the prompt.
	enrichedPrompt := prompt
	if strings.TrimSpace(ragCtx) != "" {
		enrichedPrompt = fmt.Sprintf("Context from knowledge base:\n%s\n\nUser query:\n%s", ragCtx, prompt)
	}

	// Generate the response with the enriched context.
	return GenerateResponse(ctx, model, enrichedPrompt)
}

// GenerateResponseSmartWithRAG combines smart model selection with RAG
// context enrichment.
func GenerateResponseSmartWithRAG(ctx context.Context, prompt string, queryMode mcp.QueryMode) (string, error) {
	selectedModel := selectBestModel(availableModels, prompt)
	return GenerateResponseWithRAG(ctx, selectedModel, prompt, queryMode)
}
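// Typical startup wiring (a sketch; the URLs, key source, and model names are
// illustrative placeholders, not real endpoints):
//
//	reasoning.SetAIProvider("resetdata")
//	reasoning.SetResetDataConfig(reasoning.ResetDataConfig{
//	    BaseURL: "https://api.example.com/v1",
//	    APIKey:  os.Getenv("RESETDATA_API_KEY"),
//	    Model:   "example-model:latest",
//	    Timeout: 60 * time.Second,
//	})
//	reasoning.SetDefaultSystemPrompt("You are a helpful assistant.")
//	reasoning.SetModelConfig([]string{"model-a", "model-b"}, "https://selector.example.com/choose", "model-a")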
// selectBestModel calls the model selection webhook to choose the best model
// for a prompt. On any error it falls back to the first available model.
func selectBestModel(availableModels []string, prompt string) string {
	if modelWebhookURL == "" || len(availableModels) == 0 {
		// Fall back to the first available model.
		if len(availableModels) > 0 {
			return availableModels[0]
		}
		return defaultModel // Last-resort fallback.
	}

	requestPayload := map[string]interface{}{
		"models": availableModels,
		"prompt": prompt,
	}

	payloadBytes, err := json.Marshal(requestPayload)
	if err != nil {
		return availableModels[0] // Fallback on error.
	}

	// Bound the webhook call so model selection cannot hang generation.
	client := &http.Client{Timeout: defaultTimeout}
	resp, err := client.Post(modelWebhookURL, "application/json", bytes.NewBuffer(payloadBytes))
	if err != nil {
		return availableModels[0] // Fallback on error.
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return availableModels[0] // Fallback on error.
	}

	var response struct {
		Model string `json:"model"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
		return availableModels[0] // Fallback on error.
	}

	// Validate that the returned model is in our available list.
	for _, model := range availableModels {
		if model == response.Model {
			return response.Model
		}
	}

	// Fallback if the webhook returned a model we don't have.
	return availableModels[0]
}

// GenerateResponseSmart automatically selects the best model for the prompt.
func GenerateResponseSmart(ctx context.Context, prompt string) (string, error) {
	selectedModel := selectBestModel(availableModels, prompt)
	return GenerateResponse(ctx, selectedModel, prompt)
}

func defaultSystemPromptOrFallback() string {
	if strings.TrimSpace(defaultSystemPrompt) != "" {
		return defaultSystemPrompt
	}
	return "You are a helpful assistant."
}
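// For reference, the webhook exchange implied by selectBestModel looks like
// this (model names and prompt are illustrative):
//
//	POST <modelWebhookURL>
//	{"models": ["model-a", "model-b"], "prompt": "Summarize the deployment logs."}
//
//	200 OK
//	{"model": "model-b"}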