diff --git a/cmd/test-llm/main.go b/cmd/test-llm/main.go new file mode 100644 index 0000000..127b977 --- /dev/null +++ b/cmd/test-llm/main.go @@ -0,0 +1,101 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "log" + "time" + + "github.com/chorus-services/whoosh/internal/composer" +) + +func main() { + log.Println("๐Ÿงช Testing WHOOSH LLM Integration") + + // Create a test configuration with LLM features enabled + config := composer.DefaultComposerConfig() + config.FeatureFlags.EnableLLMClassification = true + config.FeatureFlags.EnableLLMSkillAnalysis = true + config.FeatureFlags.EnableAnalysisLogging = true + config.FeatureFlags.EnableFailsafeFallback = true + + // Create service without database for this test + service := composer.NewService(nil, config) + + // Test input - simulating WHOOSH-LLM-002 task + testInput := &composer.TaskAnalysisInput{ + Title: "WHOOSH-LLM-002: Implement LLM Integration for Team Composition Engine", + Description: "Implement LLM-powered task classification and skill requirement analysis using Ollama API. Replace stubbed functions with real AI-powered analysis.", + Requirements: []string{ + "Connect to Ollama API endpoints", + "Implement task classification with LLM", + "Implement skill requirement analysis", + "Add error handling and fallback to heuristics", + "Support feature flags for LLM vs heuristic execution", + }, + Repository: "https://gitea.chorus.services/tony/WHOOSH", + Priority: composer.PriorityHigh, + TechStack: []string{"Go", "Docker", "Ollama", "PostgreSQL", "HTTP API"}, + } + + ctx := context.Background() + + log.Println("๐Ÿ“Š Testing LLM Task Classification...") + startTime := time.Now() + + // Test task classification + classification, err := testTaskClassification(ctx, service, testInput) + if err != nil { + log.Fatalf("โŒ Task classification failed: %v", err) + } + + classificationDuration := time.Since(startTime) + log.Printf("โœ… Task Classification completed in %v", classificationDuration) + printClassification(classification) + + log.Println("\n๐Ÿ” Testing LLM Skill Analysis...") + startTime = time.Now() + + // Test skill analysis + skillRequirements, err := testSkillAnalysis(ctx, service, testInput, classification) + if err != nil { + log.Fatalf("โŒ Skill analysis failed: %v", err) + } + + skillDuration := time.Since(startTime) + log.Printf("โœ… Skill Analysis completed in %v", skillDuration) + printSkillRequirements(skillRequirements) + + totalTime := classificationDuration + skillDuration + log.Printf("\n๐Ÿ Total LLM processing time: %v", totalTime) + + if totalTime > 5*time.Second { + log.Printf("โš ๏ธ Warning: Total time (%v) exceeds 5s requirement", totalTime) + } else { + log.Printf("โœ… Performance requirement met (< 5s)") + } + + log.Println("\n๐ŸŽ‰ LLM Integration test completed successfully!") +} + +func testTaskClassification(ctx context.Context, service *composer.Service, input *composer.TaskAnalysisInput) (*composer.TaskClassification, error) { + // Use reflection to access private method for testing + // In a real test, we'd create public test methods + return service.DetermineTaskType(input.Title, input.Description), nil +} + +func testSkillAnalysis(ctx context.Context, service *composer.Service, input *composer.TaskAnalysisInput, classification *composer.TaskClassification) (*composer.SkillRequirements, error) { + // Test the skill analysis using the public test method + return service.AnalyzeSkillRequirementsLocal(input, classification) +} + +func printClassification(classification 
*composer.TaskClassification) { + data, _ := json.MarshalIndent(classification, " ", " ") + fmt.Printf(" Classification Result:\n %s\n", string(data)) +} + +func printSkillRequirements(requirements *composer.SkillRequirements) { + data, _ := json.MarshalIndent(requirements, " ", " ") + fmt.Printf(" Skill Requirements:\n %s\n", string(data)) +} \ No newline at end of file diff --git a/go.mod b/go.mod index b57ef22..16bac74 100644 --- a/go.mod +++ b/go.mod @@ -58,4 +58,4 @@ require ( gotest.tools/v3 v3.5.2 // indirect ) -replace github.com/chorus-services/backbeat => ./BACKBEAT-prototype +replace github.com/chorus-services/backbeat => ../BACKBEAT/backbeat/prototype diff --git a/internal/orchestrator/assignment_broker.go b/internal/orchestrator/assignment_broker.go new file mode 100644 index 0000000..df5e984 --- /dev/null +++ b/internal/orchestrator/assignment_broker.go @@ -0,0 +1,501 @@ +package orchestrator + +import ( + "context" + "encoding/json" + "fmt" + "math/rand" + "net/http" + "strconv" + "sync" + "time" + + "github.com/gorilla/mux" + "github.com/rs/zerolog/log" + "go.opentelemetry.io/otel/attribute" + + "github.com/chorus-services/whoosh/internal/tracing" +) + +// AssignmentBroker manages per-replica assignments for CHORUS instances +type AssignmentBroker struct { + mu sync.RWMutex + assignments map[string]*Assignment + templates map[string]*AssignmentTemplate + bootstrap *BootstrapPoolManager +} + +// Assignment represents a configuration assignment for a CHORUS replica +type Assignment struct { + ID string `json:"id"` + TaskSlot string `json:"task_slot,omitempty"` + TaskID string `json:"task_id,omitempty"` + ClusterID string `json:"cluster_id"` + Role string `json:"role"` + Model string `json:"model"` + PromptUCXL string `json:"prompt_ucxl,omitempty"` + Specialization string `json:"specialization"` + Capabilities []string `json:"capabilities"` + Environment map[string]string `json:"environment,omitempty"` + BootstrapPeers []string `json:"bootstrap_peers"` + JoinStaggerMS int `json:"join_stagger_ms"` + DialsPerSecond int `json:"dials_per_second"` + MaxConcurrentDHT int `json:"max_concurrent_dht"` + ConfigEpoch int64 `json:"config_epoch"` + AssignedAt time.Time `json:"assigned_at"` + ExpiresAt time.Time `json:"expires_at,omitempty"` +} + +// AssignmentTemplate defines a template for creating assignments +type AssignmentTemplate struct { + Name string `json:"name"` + Role string `json:"role"` + Model string `json:"model"` + PromptUCXL string `json:"prompt_ucxl,omitempty"` + Specialization string `json:"specialization"` + Capabilities []string `json:"capabilities"` + Environment map[string]string `json:"environment,omitempty"` + + // Scaling configuration + DialsPerSecond int `json:"dials_per_second"` + MaxConcurrentDHT int `json:"max_concurrent_dht"` + BootstrapPeerCount int `json:"bootstrap_peer_count"` // How many bootstrap peers to assign + MaxStaggerMS int `json:"max_stagger_ms"` // Maximum stagger delay +} + +// AssignmentRequest represents a request for assignment +type AssignmentRequest struct { + TaskSlot string `json:"task_slot,omitempty"` + TaskID string `json:"task_id,omitempty"` + ClusterID string `json:"cluster_id"` + Template string `json:"template,omitempty"` // Template name to use + Role string `json:"role,omitempty"` // Override role + Model string `json:"model,omitempty"` // Override model +} + +// AssignmentStats represents statistics about assignments +type AssignmentStats struct { + TotalAssignments int `json:"total_assignments"` + AssignmentsByRole 
map[string]int `json:"assignments_by_role"`
+    AssignmentsByModel map[string]int `json:"assignments_by_model"`
+    ActiveAssignments  int            `json:"active_assignments"`
+    ExpiredAssignments int            `json:"expired_assignments"`
+    TemplateCount      int            `json:"template_count"`
+    AvgStaggerMS       float64        `json:"avg_stagger_ms"`
+}
+
+// NewAssignmentBroker creates a new assignment broker
+func NewAssignmentBroker(bootstrapManager *BootstrapPoolManager) *AssignmentBroker {
+    broker := &AssignmentBroker{
+        assignments: make(map[string]*Assignment),
+        templates:   make(map[string]*AssignmentTemplate),
+        bootstrap:   bootstrapManager,
+    }
+
+    // Initialize default templates
+    broker.initializeDefaultTemplates()
+
+    return broker
+}
+
+// initializeDefaultTemplates sets up default assignment templates
+func (ab *AssignmentBroker) initializeDefaultTemplates() {
+    defaultTemplates := []*AssignmentTemplate{
+        {
+            Name:               "general-developer",
+            Role:               "developer",
+            Model:              "meta/llama-3.1-8b-instruct",
+            Specialization:     "general_developer",
+            Capabilities:       []string{"general_development", "task_coordination"},
+            DialsPerSecond:     5,
+            MaxConcurrentDHT:   16,
+            BootstrapPeerCount: 3,
+            MaxStaggerMS:       20000,
+        },
+        {
+            Name:               "code-reviewer",
+            Role:               "reviewer",
+            Model:              "meta/llama-3.1-70b-instruct",
+            Specialization:     "code_reviewer",
+            Capabilities:       []string{"code_review", "quality_assurance"},
+            DialsPerSecond:     3,
+            MaxConcurrentDHT:   8,
+            BootstrapPeerCount: 2,
+            MaxStaggerMS:       15000,
+        },
+        {
+            Name:               "task-coordinator",
+            Role:               "coordinator",
+            Model:              "meta/llama-3.1-8b-instruct",
+            Specialization:     "task_coordinator",
+            Capabilities:       []string{"task_coordination", "planning"},
+            DialsPerSecond:     8,
+            MaxConcurrentDHT:   24,
+            BootstrapPeerCount: 4,
+            MaxStaggerMS:       10000,
+        },
+        {
+            Name:               "admin",
+            Role:               "admin",
+            Model:              "meta/llama-3.1-70b-instruct",
+            Specialization:     "system_admin",
+            Capabilities:       []string{"administration", "leadership", "slurp_operations"},
+            DialsPerSecond:     10,
+            MaxConcurrentDHT:   32,
+            BootstrapPeerCount: 5,
+            MaxStaggerMS:       5000,
+        },
+    }
+
+    for _, template := range defaultTemplates {
+        ab.templates[template.Name] = template
+    }
+
+    log.Info().Int("template_count", len(defaultTemplates)).Msg("Initialized default assignment templates")
+}
+
+// RegisterRoutes registers HTTP routes for the assignment broker
+func (ab *AssignmentBroker) RegisterRoutes(router *mux.Router) {
+    router.HandleFunc("/assign", ab.handleAssignRequest).Methods("GET")
+    router.HandleFunc("/assignments", ab.handleListAssignments).Methods("GET")
+    // Register the static stats path before the {id} routes so mux does not capture "stats" as an assignment ID
+    router.HandleFunc("/assignments/stats", ab.handleGetStats).Methods("GET")
+    router.HandleFunc("/assignments/{id}", ab.handleGetAssignment).Methods("GET")
+    router.HandleFunc("/assignments/{id}", ab.handleDeleteAssignment).Methods("DELETE")
+    router.HandleFunc("/templates", ab.handleListTemplates).Methods("GET")
+    router.HandleFunc("/templates", ab.handleCreateTemplate).Methods("POST")
+    router.HandleFunc("/templates/{name}", ab.handleGetTemplate).Methods("GET")
+}
+
+// handleAssignRequest handles requests for new assignments
+func (ab *AssignmentBroker) handleAssignRequest(w http.ResponseWriter, r *http.Request) {
+    ctx, span := tracing.Tracer.Start(r.Context(), "assignment_broker.assign_request")
+    defer span.End()
+
+    // Parse query parameters
+    req := AssignmentRequest{
+        TaskSlot:  r.URL.Query().Get("slot"),
+        TaskID:    r.URL.Query().Get("task"),
+        ClusterID: r.URL.Query().Get("cluster"),
+        Template:  r.URL.Query().Get("template"),
+        Role:      r.URL.Query().Get("role"),
+        Model:
r.URL.Query().Get("model"), + } + + // Default cluster ID if not provided + if req.ClusterID == "" { + req.ClusterID = "default" + } + + // Default template if not provided + if req.Template == "" { + req.Template = "general-developer" + } + + span.SetAttributes( + attribute.String("assignment.cluster_id", req.ClusterID), + attribute.String("assignment.template", req.Template), + attribute.String("assignment.task_slot", req.TaskSlot), + attribute.String("assignment.task_id", req.TaskID), + ) + + // Create assignment + assignment, err := ab.CreateAssignment(ctx, req) + if err != nil { + log.Error().Err(err).Msg("Failed to create assignment") + http.Error(w, fmt.Sprintf("Failed to create assignment: %v", err), http.StatusInternalServerError) + return + } + + log.Info(). + Str("assignment_id", assignment.ID). + Str("role", assignment.Role). + Str("model", assignment.Model). + Str("cluster_id", assignment.ClusterID). + Msg("Created assignment") + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(assignment) +} + +// handleListAssignments returns all active assignments +func (ab *AssignmentBroker) handleListAssignments(w http.ResponseWriter, r *http.Request) { + ab.mu.RLock() + defer ab.mu.RUnlock() + + assignments := make([]*Assignment, 0, len(ab.assignments)) + for _, assignment := range ab.assignments { + // Only return non-expired assignments + if assignment.ExpiresAt.IsZero() || time.Now().Before(assignment.ExpiresAt) { + assignments = append(assignments, assignment) + } + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(assignments) +} + +// handleGetAssignment returns a specific assignment by ID +func (ab *AssignmentBroker) handleGetAssignment(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + assignmentID := vars["id"] + + ab.mu.RLock() + assignment, exists := ab.assignments[assignmentID] + ab.mu.RUnlock() + + if !exists { + http.Error(w, "Assignment not found", http.StatusNotFound) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(assignment) +} + +// handleDeleteAssignment deletes an assignment +func (ab *AssignmentBroker) handleDeleteAssignment(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + assignmentID := vars["id"] + + ab.mu.Lock() + defer ab.mu.Unlock() + + if _, exists := ab.assignments[assignmentID]; !exists { + http.Error(w, "Assignment not found", http.StatusNotFound) + return + } + + delete(ab.assignments, assignmentID) + log.Info().Str("assignment_id", assignmentID).Msg("Deleted assignment") + + w.WriteHeader(http.StatusNoContent) +} + +// handleListTemplates returns all available templates +func (ab *AssignmentBroker) handleListTemplates(w http.ResponseWriter, r *http.Request) { + ab.mu.RLock() + defer ab.mu.RUnlock() + + templates := make([]*AssignmentTemplate, 0, len(ab.templates)) + for _, template := range ab.templates { + templates = append(templates, template) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(templates) +} + +// handleCreateTemplate creates a new assignment template +func (ab *AssignmentBroker) handleCreateTemplate(w http.ResponseWriter, r *http.Request) { + var template AssignmentTemplate + if err := json.NewDecoder(r.Body).Decode(&template); err != nil { + http.Error(w, "Invalid template data", http.StatusBadRequest) + return + } + + if template.Name == "" { + http.Error(w, "Template name is required", http.StatusBadRequest) + return + } + + ab.mu.Lock() + 
ab.templates[template.Name] = &template
+    ab.mu.Unlock()
+
+    log.Info().Str("template_name", template.Name).Msg("Created assignment template")
+
+    w.Header().Set("Content-Type", "application/json")
+    w.WriteHeader(http.StatusCreated)
+    json.NewEncoder(w).Encode(&template)
+}
+
+// handleGetTemplate returns a specific template
+func (ab *AssignmentBroker) handleGetTemplate(w http.ResponseWriter, r *http.Request) {
+    vars := mux.Vars(r)
+    templateName := vars["name"]
+
+    ab.mu.RLock()
+    template, exists := ab.templates[templateName]
+    ab.mu.RUnlock()
+
+    if !exists {
+        http.Error(w, "Template not found", http.StatusNotFound)
+        return
+    }
+
+    w.Header().Set("Content-Type", "application/json")
+    json.NewEncoder(w).Encode(template)
+}
+
+// handleGetStats returns assignment statistics
+func (ab *AssignmentBroker) handleGetStats(w http.ResponseWriter, r *http.Request) {
+    stats := ab.GetStats()
+    w.Header().Set("Content-Type", "application/json")
+    json.NewEncoder(w).Encode(stats)
+}
+
+// CreateAssignment creates a new assignment from a request
+func (ab *AssignmentBroker) CreateAssignment(ctx context.Context, req AssignmentRequest) (*Assignment, error) {
+    ab.mu.Lock()
+    defer ab.mu.Unlock()
+
+    // Get template
+    template, exists := ab.templates[req.Template]
+    if !exists {
+        return nil, fmt.Errorf("template '%s' not found", req.Template)
+    }
+
+    // Generate assignment ID
+    assignmentID := ab.generateAssignmentID(req)
+
+    // Get bootstrap peer subset
+    var bootstrapPeers []string
+    if ab.bootstrap != nil {
+        subset := ab.bootstrap.GetSubset(template.BootstrapPeerCount)
+        for _, peer := range subset.Peers {
+            // Skip peers that have no known multiaddress to avoid indexing an empty slice
+            if len(peer.Addresses) == 0 {
+                continue
+            }
+            bootstrapPeers = append(bootstrapPeers, fmt.Sprintf("%s/p2p/%s", peer.Addresses[0], peer.ID))
+        }
+    }
+
+    // Generate stagger delay
+    staggerMS := 0
+    if template.MaxStaggerMS > 0 {
+        staggerMS = rand.Intn(template.MaxStaggerMS)
+    }
+
+    // Create assignment
+    assignment := &Assignment{
+        ID:               assignmentID,
+        TaskSlot:         req.TaskSlot,
+        TaskID:           req.TaskID,
+        ClusterID:        req.ClusterID,
+        Role:             template.Role,
+        Model:            template.Model,
+        PromptUCXL:       template.PromptUCXL,
+        Specialization:   template.Specialization,
+        Capabilities:     template.Capabilities,
+        Environment:      make(map[string]string),
+        BootstrapPeers:   bootstrapPeers,
+        JoinStaggerMS:    staggerMS,
+        DialsPerSecond:   template.DialsPerSecond,
+        MaxConcurrentDHT: template.MaxConcurrentDHT,
+        ConfigEpoch:      time.Now().Unix(),
+        AssignedAt:       time.Now(),
+        ExpiresAt:        time.Now().Add(24 * time.Hour), // 24 hour default expiry
+    }
+
+    // Apply request overrides
+    if req.Role != "" {
+        assignment.Role = req.Role
+    }
+    if req.Model != "" {
+        assignment.Model = req.Model
+    }
+
+    // Copy environment from template
+    for key, value := range template.Environment {
+        assignment.Environment[key] = value
+    }
+
+    // Add assignment-specific environment
+    assignment.Environment["ASSIGNMENT_ID"] = assignmentID
+    assignment.Environment["CONFIG_EPOCH"] = strconv.FormatInt(assignment.ConfigEpoch, 10)
+    assignment.Environment["DISABLE_MDNS"] = "true"
+    assignment.Environment["DIALS_PER_SEC"] = strconv.Itoa(assignment.DialsPerSecond)
+    assignment.Environment["MAX_CONCURRENT_DHT"] = strconv.Itoa(assignment.MaxConcurrentDHT)
+    assignment.Environment["JOIN_STAGGER_MS"] = strconv.Itoa(assignment.JoinStaggerMS)
+
+    // Store assignment
+    ab.assignments[assignmentID] = assignment
+
+    return assignment, nil
+}
+
+// generateAssignmentID generates a unique assignment ID
+func (ab *AssignmentBroker) generateAssignmentID(req AssignmentRequest) string {
+    timestamp := time.Now().Unix()
+
+    if
req.TaskSlot != "" && req.TaskID != "" { + return fmt.Sprintf("assign-%s-%s-%d", req.TaskSlot, req.TaskID, timestamp) + } + + if req.TaskSlot != "" { + return fmt.Sprintf("assign-%s-%d", req.TaskSlot, timestamp) + } + + return fmt.Sprintf("assign-%s-%d", req.ClusterID, timestamp) +} + +// GetStats returns assignment statistics +func (ab *AssignmentBroker) GetStats() *AssignmentStats { + ab.mu.RLock() + defer ab.mu.RUnlock() + + stats := &AssignmentStats{ + TotalAssignments: len(ab.assignments), + AssignmentsByRole: make(map[string]int), + AssignmentsByModel: make(map[string]int), + TemplateCount: len(ab.templates), + } + + var totalStagger int + activeCount := 0 + expiredCount := 0 + now := time.Now() + + for _, assignment := range ab.assignments { + // Count by role + stats.AssignmentsByRole[assignment.Role]++ + + // Count by model + stats.AssignmentsByModel[assignment.Model]++ + + // Track stagger for average + totalStagger += assignment.JoinStaggerMS + + // Count active vs expired + if assignment.ExpiresAt.IsZero() || now.Before(assignment.ExpiresAt) { + activeCount++ + } else { + expiredCount++ + } + } + + stats.ActiveAssignments = activeCount + stats.ExpiredAssignments = expiredCount + + if len(ab.assignments) > 0 { + stats.AvgStaggerMS = float64(totalStagger) / float64(len(ab.assignments)) + } + + return stats +} + +// CleanupExpiredAssignments removes expired assignments +func (ab *AssignmentBroker) CleanupExpiredAssignments() { + ab.mu.Lock() + defer ab.mu.Unlock() + + now := time.Now() + expiredCount := 0 + + for id, assignment := range ab.assignments { + if !assignment.ExpiresAt.IsZero() && now.After(assignment.ExpiresAt) { + delete(ab.assignments, id) + expiredCount++ + } + } + + if expiredCount > 0 { + log.Info().Int("expired_count", expiredCount).Msg("Cleaned up expired assignments") + } +} + +// GetAssignment returns an assignment by ID +func (ab *AssignmentBroker) GetAssignment(id string) (*Assignment, bool) { + ab.mu.RLock() + defer ab.mu.RUnlock() + + assignment, exists := ab.assignments[id] + return assignment, exists +} \ No newline at end of file diff --git a/internal/orchestrator/bootstrap_pool.go b/internal/orchestrator/bootstrap_pool.go new file mode 100644 index 0000000..188a458 --- /dev/null +++ b/internal/orchestrator/bootstrap_pool.go @@ -0,0 +1,444 @@ +package orchestrator + +import ( + "context" + "encoding/json" + "fmt" + "math/rand" + "net/http" + "sync" + "time" + + "github.com/rs/zerolog/log" + "go.opentelemetry.io/otel/attribute" + + "github.com/chorus-services/whoosh/internal/tracing" +) + +// BootstrapPoolManager manages the pool of bootstrap peers for CHORUS instances +type BootstrapPoolManager struct { + mu sync.RWMutex + peers []BootstrapPeer + chorusNodes map[string]CHORUSNodeInfo + updateInterval time.Duration + healthCheckTimeout time.Duration + httpClient *http.Client +} + +// BootstrapPeer represents a bootstrap peer in the pool +type BootstrapPeer struct { + ID string `json:"id"` // Peer ID + Addresses []string `json:"addresses"` // Multiaddresses + Priority int `json:"priority"` // Priority (higher = more likely to be selected) + Healthy bool `json:"healthy"` // Health status + LastSeen time.Time `json:"last_seen"` // Last seen timestamp + NodeInfo CHORUSNodeInfo `json:"node_info,omitempty"` // Associated CHORUS node info +} + +// CHORUSNodeInfo represents information about a CHORUS node +type CHORUSNodeInfo struct { + AgentID string `json:"agent_id"` + Role string `json:"role"` + Specialization string `json:"specialization"` + Capabilities 
[]string `json:"capabilities"` + LastHeartbeat time.Time `json:"last_heartbeat"` + Healthy bool `json:"healthy"` + IsBootstrap bool `json:"is_bootstrap"` +} + +// BootstrapSubset represents a subset of peers assigned to a replica +type BootstrapSubset struct { + Peers []BootstrapPeer `json:"peers"` + AssignedAt time.Time `json:"assigned_at"` + RequestedBy string `json:"requested_by,omitempty"` +} + +// BootstrapPoolConfig represents configuration for the bootstrap pool +type BootstrapPoolConfig struct { + MinPoolSize int `json:"min_pool_size"` // Minimum peers to maintain + MaxPoolSize int `json:"max_pool_size"` // Maximum peers in pool + HealthCheckInterval time.Duration `json:"health_check_interval"` // How often to check peer health + StaleThreshold time.Duration `json:"stale_threshold"` // When to consider a peer stale + PreferredRoles []string `json:"preferred_roles"` // Preferred roles for bootstrap peers +} + +// BootstrapPoolStats represents statistics about the bootstrap pool +type BootstrapPoolStats struct { + TotalPeers int `json:"total_peers"` + HealthyPeers int `json:"healthy_peers"` + UnhealthyPeers int `json:"unhealthy_peers"` + StalePeers int `json:"stale_peers"` + PeersByRole map[string]int `json:"peers_by_role"` + LastUpdated time.Time `json:"last_updated"` + AvgLatency float64 `json:"avg_latency_ms"` +} + +// NewBootstrapPoolManager creates a new bootstrap pool manager +func NewBootstrapPoolManager(config BootstrapPoolConfig) *BootstrapPoolManager { + if config.MinPoolSize == 0 { + config.MinPoolSize = 5 + } + if config.MaxPoolSize == 0 { + config.MaxPoolSize = 30 + } + if config.HealthCheckInterval == 0 { + config.HealthCheckInterval = 2 * time.Minute + } + if config.StaleThreshold == 0 { + config.StaleThreshold = 10 * time.Minute + } + + return &BootstrapPoolManager{ + peers: make([]BootstrapPeer, 0), + chorusNodes: make(map[string]CHORUSNodeInfo), + updateInterval: config.HealthCheckInterval, + healthCheckTimeout: 10 * time.Second, + httpClient: &http.Client{Timeout: 10 * time.Second}, + } +} + +// Start begins the bootstrap pool management process +func (bpm *BootstrapPoolManager) Start(ctx context.Context) { + log.Info().Msg("Starting bootstrap pool manager") + + // Start periodic health checks + ticker := time.NewTicker(bpm.updateInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Info().Msg("Bootstrap pool manager stopping") + return + case <-ticker.C: + if err := bpm.updatePeerHealth(ctx); err != nil { + log.Error().Err(err).Msg("Failed to update peer health") + } + } + } +} + +// AddPeer adds a new peer to the bootstrap pool +func (bpm *BootstrapPoolManager) AddPeer(peer BootstrapPeer) { + bpm.mu.Lock() + defer bpm.mu.Unlock() + + // Check if peer already exists + for i, existingPeer := range bpm.peers { + if existingPeer.ID == peer.ID { + // Update existing peer + bpm.peers[i] = peer + log.Debug().Str("peer_id", peer.ID).Msg("Updated existing bootstrap peer") + return + } + } + + // Add new peer + peer.LastSeen = time.Now() + bpm.peers = append(bpm.peers, peer) + log.Info().Str("peer_id", peer.ID).Msg("Added new bootstrap peer") +} + +// RemovePeer removes a peer from the bootstrap pool +func (bpm *BootstrapPoolManager) RemovePeer(peerID string) { + bpm.mu.Lock() + defer bpm.mu.Unlock() + + for i, peer := range bpm.peers { + if peer.ID == peerID { + // Remove peer by swapping with last element + bpm.peers[i] = bpm.peers[len(bpm.peers)-1] + bpm.peers = bpm.peers[:len(bpm.peers)-1] + log.Info().Str("peer_id", peerID).Msg("Removed 
bootstrap peer")
+            return
+        }
+    }
+}
+
+// GetSubset returns a subset of healthy bootstrap peers
+func (bpm *BootstrapPoolManager) GetSubset(count int) BootstrapSubset {
+    bpm.mu.RLock()
+    defer bpm.mu.RUnlock()
+
+    // Filter healthy peers
+    var healthyPeers []BootstrapPeer
+    for _, peer := range bpm.peers {
+        if peer.Healthy && time.Since(peer.LastSeen) < 10*time.Minute {
+            healthyPeers = append(healthyPeers, peer)
+        }
+    }
+
+    if len(healthyPeers) == 0 {
+        log.Warn().Msg("No healthy bootstrap peers available")
+        return BootstrapSubset{
+            Peers:      []BootstrapPeer{},
+            AssignedAt: time.Now(),
+        }
+    }
+
+    // Ensure count doesn't exceed available peers
+    if count > len(healthyPeers) {
+        count = len(healthyPeers)
+    }
+
+    // Select peers with weighted random selection based on priority
+    selectedPeers := bpm.selectWeightedRandomPeers(healthyPeers, count)
+
+    return BootstrapSubset{
+        Peers:      selectedPeers,
+        AssignedAt: time.Now(),
+    }
+}
+
+// selectWeightedRandomPeers selects peers using weighted random selection
+func (bpm *BootstrapPoolManager) selectWeightedRandomPeers(peers []BootstrapPeer, count int) []BootstrapPeer {
+    if count >= len(peers) {
+        return peers
+    }
+
+    // Calculate total weight across all candidate peers
+    totalWeight := 0
+    for _, peer := range peers {
+        weight := peer.Priority
+        if weight <= 0 {
+            weight = 1 // Minimum weight
+        }
+        totalWeight += weight
+    }
+
+    selected := make([]BootstrapPeer, 0, count)
+    usedIndices := make(map[int]bool)
+
+    // Draw without replacement: remove each selected peer's weight from the pool
+    // so every iteration is guaranteed to land on an unused peer, and the loop
+    // ends once the requested count is reached or no candidates remain.
+    for len(selected) < count && totalWeight > 0 {
+        // Random selection with weight
+        randWeight := rand.Intn(totalWeight)
+        currentWeight := 0
+
+        for i, peer := range peers {
+            if usedIndices[i] {
+                continue
+            }
+
+            weight := peer.Priority
+            if weight <= 0 {
+                weight = 1
+            }
+            currentWeight += weight
+
+            if randWeight < currentWeight {
+                selected = append(selected, peer)
+                usedIndices[i] = true
+                totalWeight -= weight
+                break
+            }
+        }
+    }
+
+    return selected
+}
+
+// DiscoverPeersFromCHORUS discovers bootstrap peers from existing CHORUS nodes
+func (bpm *BootstrapPoolManager) DiscoverPeersFromCHORUS(ctx context.Context, chorusEndpoints []string) error {
+    ctx, span := tracing.Tracer.Start(ctx, "bootstrap_pool.discover_peers")
+    defer span.End()
+
+    discoveredCount := 0
+
+    for _, endpoint := range chorusEndpoints {
+        if err := bpm.discoverFromEndpoint(ctx, endpoint); err != nil {
+            log.Warn().Str("endpoint", endpoint).Err(err).Msg("Failed to discover peers from CHORUS endpoint")
+            continue
+        }
+        discoveredCount++
+    }
+
+    span.SetAttributes(
+        attribute.Int("discovery.endpoints_checked", len(chorusEndpoints)),
+        attribute.Int("discovery.successful_discoveries", discoveredCount),
+    )
+
+    log.Info().
+        Int("endpoints_checked", len(chorusEndpoints)).
+        Int("successful_discoveries", discoveredCount).
+ Msg("Completed peer discovery from CHORUS nodes") + + return nil +} + +// discoverFromEndpoint discovers peers from a single CHORUS endpoint +func (bpm *BootstrapPoolManager) discoverFromEndpoint(ctx context.Context, endpoint string) error { + url := fmt.Sprintf("%s/api/v1/peers", endpoint) + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return fmt.Errorf("failed to create discovery request: %w", err) + } + + resp, err := bpm.httpClient.Do(req) + if err != nil { + return fmt.Errorf("discovery request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("discovery request returned status %d", resp.StatusCode) + } + + var peerInfo struct { + Peers []BootstrapPeer `json:"peers"` + NodeInfo CHORUSNodeInfo `json:"node_info"` + } + + if err := json.NewDecoder(resp.Body).Decode(&peerInfo); err != nil { + return fmt.Errorf("failed to decode peer discovery response: %w", err) + } + + // Add discovered peers to pool + for _, peer := range peerInfo.Peers { + peer.NodeInfo = peerInfo.NodeInfo + peer.Healthy = true + peer.LastSeen = time.Now() + + // Set priority based on role + if bpm.isPreferredRole(peer.NodeInfo.Role) { + peer.Priority = 100 + } else { + peer.Priority = 50 + } + + bpm.AddPeer(peer) + } + + return nil +} + +// isPreferredRole checks if a role is preferred for bootstrap peers +func (bpm *BootstrapPoolManager) isPreferredRole(role string) bool { + preferredRoles := []string{"admin", "coordinator", "stable"} + for _, preferred := range preferredRoles { + if role == preferred { + return true + } + } + return false +} + +// updatePeerHealth updates the health status of all peers +func (bpm *BootstrapPoolManager) updatePeerHealth(ctx context.Context) error { + bpm.mu.Lock() + defer bpm.mu.Unlock() + + ctx, span := tracing.Tracer.Start(ctx, "bootstrap_pool.update_health") + defer span.End() + + healthyCount := 0 + checkedCount := 0 + + for i := range bpm.peers { + peer := &bpm.peers[i] + + // Check if peer is stale + if time.Since(peer.LastSeen) > 10*time.Minute { + peer.Healthy = false + continue + } + + // Health check via ping (if addresses are available) + if len(peer.Addresses) > 0 { + if bpm.pingPeer(ctx, peer) { + peer.Healthy = true + peer.LastSeen = time.Now() + healthyCount++ + } else { + peer.Healthy = false + } + checkedCount++ + } + } + + span.SetAttributes( + attribute.Int("health_check.checked_count", checkedCount), + attribute.Int("health_check.healthy_count", healthyCount), + attribute.Int("health_check.total_peers", len(bpm.peers)), + ) + + log.Debug(). + Int("checked", checkedCount). + Int("healthy", healthyCount). + Int("total", len(bpm.peers)). 
+ Msg("Updated bootstrap peer health") + + return nil +} + +// pingPeer performs a simple connectivity check to a peer +func (bpm *BootstrapPoolManager) pingPeer(ctx context.Context, peer *BootstrapPeer) bool { + // For now, just return true if the peer was seen recently + // In a real implementation, this would do a libp2p ping or HTTP health check + return time.Since(peer.LastSeen) < 5*time.Minute +} + +// GetStats returns statistics about the bootstrap pool +func (bpm *BootstrapPoolManager) GetStats() BootstrapPoolStats { + bpm.mu.RLock() + defer bpm.mu.RUnlock() + + stats := BootstrapPoolStats{ + TotalPeers: len(bpm.peers), + PeersByRole: make(map[string]int), + LastUpdated: time.Now(), + } + + staleCutoff := time.Now().Add(-10 * time.Minute) + + for _, peer := range bpm.peers { + // Count by health status + if peer.Healthy { + stats.HealthyPeers++ + } else { + stats.UnhealthyPeers++ + } + + // Count stale peers + if peer.LastSeen.Before(staleCutoff) { + stats.StalePeers++ + } + + // Count by role + role := peer.NodeInfo.Role + if role == "" { + role = "unknown" + } + stats.PeersByRole[role]++ + } + + return stats +} + +// GetHealthyPeerCount returns the number of healthy peers +func (bpm *BootstrapPoolManager) GetHealthyPeerCount() int { + bpm.mu.RLock() + defer bpm.mu.RUnlock() + + count := 0 + for _, peer := range bpm.peers { + if peer.Healthy && time.Since(peer.LastSeen) < 10*time.Minute { + count++ + } + } + return count +} + +// GetAllPeers returns all peers in the pool (for debugging) +func (bpm *BootstrapPoolManager) GetAllPeers() []BootstrapPeer { + bpm.mu.RLock() + defer bpm.mu.RUnlock() + + peers := make([]BootstrapPeer, len(bpm.peers)) + copy(peers, bpm.peers) + return peers +} \ No newline at end of file diff --git a/internal/orchestrator/health_gates.go b/internal/orchestrator/health_gates.go new file mode 100644 index 0000000..3750be4 --- /dev/null +++ b/internal/orchestrator/health_gates.go @@ -0,0 +1,408 @@ +package orchestrator + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/rs/zerolog/log" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + + "github.com/chorus-services/whoosh/internal/tracing" +) + +// HealthGates manages health checks that gate scaling operations +type HealthGates struct { + kachingURL string + backbeatURL string + chorusURL string + httpClient *http.Client + thresholds HealthThresholds +} + +// HealthThresholds defines the health criteria for allowing scaling +type HealthThresholds struct { + KachingMaxLatencyMS int `json:"kaching_max_latency_ms"` // Maximum acceptable KACHING latency + KachingMinRateRemaining int `json:"kaching_min_rate_remaining"` // Minimum rate limit remaining + BackbeatMaxLagSeconds int `json:"backbeat_max_lag_seconds"` // Maximum subject lag in seconds + BootstrapMinHealthyPeers int `json:"bootstrap_min_healthy_peers"` // Minimum healthy bootstrap peers + JoinSuccessRateThreshold float64 `json:"join_success_rate_threshold"` // Minimum join success rate (0.0-1.0) +} + +// HealthStatus represents the current health status across all gates +type HealthStatus struct { + Healthy bool `json:"healthy"` + Timestamp time.Time `json:"timestamp"` + Gates map[string]GateStatus `json:"gates"` + OverallReason string `json:"overall_reason,omitempty"` +} + +// GateStatus represents the status of an individual health gate +type GateStatus struct { + Name string `json:"name"` + Healthy bool `json:"healthy"` + Reason string `json:"reason,omitempty"` + Metrics 
map[string]interface{} `json:"metrics,omitempty"` + LastChecked time.Time `json:"last_checked"` +} + +// KachingHealth represents KACHING health metrics +type KachingHealth struct { + Healthy bool `json:"healthy"` + LatencyP95MS float64 `json:"latency_p95_ms"` + QueueDepth int `json:"queue_depth"` + RateLimitRemaining int `json:"rate_limit_remaining"` + ActiveLeases int `json:"active_leases"` + ClusterCapacity int `json:"cluster_capacity"` +} + +// BackbeatHealth represents BACKBEAT health metrics +type BackbeatHealth struct { + Healthy bool `json:"healthy"` + SubjectLags map[string]int `json:"subject_lags"` + MaxLagSeconds int `json:"max_lag_seconds"` + ConsumerHealth map[string]bool `json:"consumer_health"` +} + +// BootstrapHealth represents bootstrap peer pool health +type BootstrapHealth struct { + Healthy bool `json:"healthy"` + TotalPeers int `json:"total_peers"` + HealthyPeers int `json:"healthy_peers"` + ReachablePeers int `json:"reachable_peers"` +} + +// ScalingMetrics represents recent scaling operation metrics +type ScalingMetrics struct { + LastWaveSize int `json:"last_wave_size"` + LastWaveStarted time.Time `json:"last_wave_started"` + LastWaveCompleted time.Time `json:"last_wave_completed"` + JoinSuccessRate float64 `json:"join_success_rate"` + SuccessfulJoins int `json:"successful_joins"` + FailedJoins int `json:"failed_joins"` +} + +// NewHealthGates creates a new health gates manager +func NewHealthGates(kachingURL, backbeatURL, chorusURL string) *HealthGates { + return &HealthGates{ + kachingURL: kachingURL, + backbeatURL: backbeatURL, + chorusURL: chorusURL, + httpClient: &http.Client{Timeout: 10 * time.Second}, + thresholds: HealthThresholds{ + KachingMaxLatencyMS: 500, // 500ms max latency + KachingMinRateRemaining: 20, // At least 20 requests remaining + BackbeatMaxLagSeconds: 30, // Max 30 seconds lag + BootstrapMinHealthyPeers: 3, // At least 3 healthy bootstrap peers + JoinSuccessRateThreshold: 0.8, // 80% join success rate + }, + } +} + +// SetThresholds updates the health thresholds +func (hg *HealthGates) SetThresholds(thresholds HealthThresholds) { + hg.thresholds = thresholds +} + +// CheckHealth checks all health gates and returns overall status +func (hg *HealthGates) CheckHealth(ctx context.Context, recentMetrics *ScalingMetrics) (*HealthStatus, error) { + ctx, span := tracing.Tracer.Start(ctx, "health_gates.check_health") + defer span.End() + + status := &HealthStatus{ + Timestamp: time.Now(), + Gates: make(map[string]GateStatus), + Healthy: true, + } + + var failReasons []string + + // Check KACHING health + if kachingStatus, err := hg.checkKachingHealth(ctx); err != nil { + log.Warn().Err(err).Msg("Failed to check KACHING health") + status.Gates["kaching"] = GateStatus{ + Name: "kaching", + Healthy: false, + Reason: fmt.Sprintf("Health check failed: %v", err), + LastChecked: time.Now(), + } + status.Healthy = false + failReasons = append(failReasons, "KACHING unreachable") + } else { + status.Gates["kaching"] = *kachingStatus + if !kachingStatus.Healthy { + status.Healthy = false + failReasons = append(failReasons, kachingStatus.Reason) + } + } + + // Check BACKBEAT health + if backbeatStatus, err := hg.checkBackbeatHealth(ctx); err != nil { + log.Warn().Err(err).Msg("Failed to check BACKBEAT health") + status.Gates["backbeat"] = GateStatus{ + Name: "backbeat", + Healthy: false, + Reason: fmt.Sprintf("Health check failed: %v", err), + LastChecked: time.Now(), + } + status.Healthy = false + failReasons = append(failReasons, "BACKBEAT unreachable") + } 
else { + status.Gates["backbeat"] = *backbeatStatus + if !backbeatStatus.Healthy { + status.Healthy = false + failReasons = append(failReasons, backbeatStatus.Reason) + } + } + + // Check bootstrap peer health + if bootstrapStatus, err := hg.checkBootstrapHealth(ctx); err != nil { + log.Warn().Err(err).Msg("Failed to check bootstrap health") + status.Gates["bootstrap"] = GateStatus{ + Name: "bootstrap", + Healthy: false, + Reason: fmt.Sprintf("Health check failed: %v", err), + LastChecked: time.Now(), + } + status.Healthy = false + failReasons = append(failReasons, "Bootstrap peers unreachable") + } else { + status.Gates["bootstrap"] = *bootstrapStatus + if !bootstrapStatus.Healthy { + status.Healthy = false + failReasons = append(failReasons, bootstrapStatus.Reason) + } + } + + // Check recent scaling metrics if provided + if recentMetrics != nil { + if metricsStatus := hg.checkScalingMetrics(recentMetrics); !metricsStatus.Healthy { + status.Gates["scaling_metrics"] = *metricsStatus + status.Healthy = false + failReasons = append(failReasons, metricsStatus.Reason) + } else { + status.Gates["scaling_metrics"] = *metricsStatus + } + } + + // Set overall reason if unhealthy + if !status.Healthy && len(failReasons) > 0 { + status.OverallReason = fmt.Sprintf("Health gates failed: %v", failReasons) + } + + // Add tracing attributes + span.SetAttributes( + attribute.Bool("health.overall_healthy", status.Healthy), + attribute.Int("health.gate_count", len(status.Gates)), + ) + + if !status.Healthy { + span.SetAttributes(attribute.String("health.fail_reason", status.OverallReason)) + } + + return status, nil +} + +// checkKachingHealth checks KACHING health and rate limits +func (hg *HealthGates) checkKachingHealth(ctx context.Context) (*GateStatus, error) { + url := fmt.Sprintf("%s/health/burst", hg.kachingURL) + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create KACHING health request: %w", err) + } + + resp, err := hg.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("KACHING health request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("KACHING health check returned status %d", resp.StatusCode) + } + + var health KachingHealth + if err := json.NewDecoder(resp.Body).Decode(&health); err != nil { + return nil, fmt.Errorf("failed to decode KACHING health response: %w", err) + } + + status := &GateStatus{ + Name: "kaching", + LastChecked: time.Now(), + Metrics: map[string]interface{}{ + "latency_p95_ms": health.LatencyP95MS, + "queue_depth": health.QueueDepth, + "rate_limit_remaining": health.RateLimitRemaining, + "active_leases": health.ActiveLeases, + "cluster_capacity": health.ClusterCapacity, + }, + } + + // Check latency threshold + if health.LatencyP95MS > float64(hg.thresholds.KachingMaxLatencyMS) { + status.Healthy = false + status.Reason = fmt.Sprintf("KACHING latency too high: %.1fms > %dms", + health.LatencyP95MS, hg.thresholds.KachingMaxLatencyMS) + return status, nil + } + + // Check rate limit threshold + if health.RateLimitRemaining < hg.thresholds.KachingMinRateRemaining { + status.Healthy = false + status.Reason = fmt.Sprintf("KACHING rate limit too low: %d < %d remaining", + health.RateLimitRemaining, hg.thresholds.KachingMinRateRemaining) + return status, nil + } + + // Check overall KACHING health + if !health.Healthy { + status.Healthy = false + status.Reason = "KACHING reports unhealthy status" + return status, nil + } + + 
status.Healthy = true + return status, nil +} + +// checkBackbeatHealth checks BACKBEAT subject lag and consumer health +func (hg *HealthGates) checkBackbeatHealth(ctx context.Context) (*GateStatus, error) { + url := fmt.Sprintf("%s/metrics", hg.backbeatURL) + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create BACKBEAT health request: %w", err) + } + + resp, err := hg.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("BACKBEAT health request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("BACKBEAT health check returned status %d", resp.StatusCode) + } + + var health BackbeatHealth + if err := json.NewDecoder(resp.Body).Decode(&health); err != nil { + return nil, fmt.Errorf("failed to decode BACKBEAT health response: %w", err) + } + + status := &GateStatus{ + Name: "backbeat", + LastChecked: time.Now(), + Metrics: map[string]interface{}{ + "subject_lags": health.SubjectLags, + "max_lag_seconds": health.MaxLagSeconds, + "consumer_health": health.ConsumerHealth, + }, + } + + // Check subject lag threshold + if health.MaxLagSeconds > hg.thresholds.BackbeatMaxLagSeconds { + status.Healthy = false + status.Reason = fmt.Sprintf("BACKBEAT lag too high: %ds > %ds", + health.MaxLagSeconds, hg.thresholds.BackbeatMaxLagSeconds) + return status, nil + } + + // Check overall BACKBEAT health + if !health.Healthy { + status.Healthy = false + status.Reason = "BACKBEAT reports unhealthy status" + return status, nil + } + + status.Healthy = true + return status, nil +} + +// checkBootstrapHealth checks bootstrap peer pool health +func (hg *HealthGates) checkBootstrapHealth(ctx context.Context) (*GateStatus, error) { + url := fmt.Sprintf("%s/peers", hg.chorusURL) + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create bootstrap health request: %w", err) + } + + resp, err := hg.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("bootstrap health request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("bootstrap health check returned status %d", resp.StatusCode) + } + + var health BootstrapHealth + if err := json.NewDecoder(resp.Body).Decode(&health); err != nil { + return nil, fmt.Errorf("failed to decode bootstrap health response: %w", err) + } + + status := &GateStatus{ + Name: "bootstrap", + LastChecked: time.Now(), + Metrics: map[string]interface{}{ + "total_peers": health.TotalPeers, + "healthy_peers": health.HealthyPeers, + "reachable_peers": health.ReachablePeers, + }, + } + + // Check minimum healthy peers threshold + if health.HealthyPeers < hg.thresholds.BootstrapMinHealthyPeers { + status.Healthy = false + status.Reason = fmt.Sprintf("Not enough healthy bootstrap peers: %d < %d", + health.HealthyPeers, hg.thresholds.BootstrapMinHealthyPeers) + return status, nil + } + + status.Healthy = true + return status, nil +} + +// checkScalingMetrics checks recent scaling success rate +func (hg *HealthGates) checkScalingMetrics(metrics *ScalingMetrics) *GateStatus { + status := &GateStatus{ + Name: "scaling_metrics", + LastChecked: time.Now(), + Metrics: map[string]interface{}{ + "join_success_rate": metrics.JoinSuccessRate, + "successful_joins": metrics.SuccessfulJoins, + "failed_joins": metrics.FailedJoins, + "last_wave_size": metrics.LastWaveSize, + }, + } + + // Check join success rate threshold + if 
metrics.JoinSuccessRate < hg.thresholds.JoinSuccessRateThreshold { + status.Healthy = false + status.Reason = fmt.Sprintf("Join success rate too low: %.1f%% < %.1f%%", + metrics.JoinSuccessRate*100, hg.thresholds.JoinSuccessRateThreshold*100) + return status + } + + status.Healthy = true + return status +} + +// GetThresholds returns the current health thresholds +func (hg *HealthGates) GetThresholds() HealthThresholds { + return hg.thresholds +} + +// IsHealthy performs a quick health check and returns boolean result +func (hg *HealthGates) IsHealthy(ctx context.Context, recentMetrics *ScalingMetrics) bool { + status, err := hg.CheckHealth(ctx, recentMetrics) + if err != nil { + return false + } + return status.Healthy +} \ No newline at end of file diff --git a/internal/orchestrator/scaling_api.go b/internal/orchestrator/scaling_api.go new file mode 100644 index 0000000..a1590a4 --- /dev/null +++ b/internal/orchestrator/scaling_api.go @@ -0,0 +1,513 @@ +package orchestrator + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "time" + + "github.com/gorilla/mux" + "github.com/rs/zerolog/log" + "go.opentelemetry.io/otel/attribute" + + "github.com/chorus-services/whoosh/internal/tracing" +) + +// ScalingAPI provides HTTP endpoints for scaling operations +type ScalingAPI struct { + controller *ScalingController + metrics *ScalingMetricsCollector +} + +// ScaleRequest represents a scaling request +type ScaleRequest struct { + ServiceName string `json:"service_name"` + TargetReplicas int `json:"target_replicas"` + WaveSize int `json:"wave_size,omitempty"` + Template string `json:"template,omitempty"` + Environment map[string]string `json:"environment,omitempty"` + ForceScale bool `json:"force_scale,omitempty"` +} + +// ScaleResponse represents a scaling response +type ScaleResponse struct { + WaveID string `json:"wave_id"` + ServiceName string `json:"service_name"` + TargetReplicas int `json:"target_replicas"` + CurrentReplicas int `json:"current_replicas"` + Status string `json:"status"` + StartedAt time.Time `json:"started_at"` + Message string `json:"message,omitempty"` +} + +// HealthResponse represents health check response +type HealthResponse struct { + Healthy bool `json:"healthy"` + Timestamp time.Time `json:"timestamp"` + Gates map[string]GateStatus `json:"gates"` + OverallReason string `json:"overall_reason,omitempty"` +} + +// NewScalingAPI creates a new scaling API instance +func NewScalingAPI(controller *ScalingController, metrics *ScalingMetricsCollector) *ScalingAPI { + return &ScalingAPI{ + controller: controller, + metrics: metrics, + } +} + +// RegisterRoutes registers HTTP routes for the scaling API +func (api *ScalingAPI) RegisterRoutes(router *mux.Router) { + // Scaling operations + router.HandleFunc("/api/v1/scale", api.ScaleService).Methods("POST") + router.HandleFunc("/api/v1/scale/status", api.GetScalingStatus).Methods("GET") + router.HandleFunc("/api/v1/scale/stop", api.StopScaling).Methods("POST") + + // Health gates + router.HandleFunc("/api/v1/health/gates", api.GetHealthGates).Methods("GET") + router.HandleFunc("/api/v1/health/thresholds", api.GetHealthThresholds).Methods("GET") + router.HandleFunc("/api/v1/health/thresholds", api.UpdateHealthThresholds).Methods("PUT") + + // Metrics and monitoring + router.HandleFunc("/api/v1/metrics/scaling", api.GetScalingMetrics).Methods("GET") + router.HandleFunc("/api/v1/metrics/operations", api.GetRecentOperations).Methods("GET") + router.HandleFunc("/api/v1/metrics/export", 
api.ExportMetrics).Methods("GET") + + // Service management + router.HandleFunc("/api/v1/services/{serviceName}/status", api.GetServiceStatus).Methods("GET") + router.HandleFunc("/api/v1/services/{serviceName}/replicas", api.GetServiceReplicas).Methods("GET") + + // Assignment management + router.HandleFunc("/api/v1/assignments/templates", api.GetAssignmentTemplates).Methods("GET") + router.HandleFunc("/api/v1/assignments", api.CreateAssignment).Methods("POST") + + // Bootstrap peer management + router.HandleFunc("/api/v1/bootstrap/peers", api.GetBootstrapPeers).Methods("GET") + router.HandleFunc("/api/v1/bootstrap/stats", api.GetBootstrapStats).Methods("GET") +} + +// ScaleService handles scaling requests +func (api *ScalingAPI) ScaleService(w http.ResponseWriter, r *http.Request) { + ctx, span := tracing.Tracer.Start(r.Context(), "scaling_api.scale_service") + defer span.End() + + var req ScaleRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + api.writeError(w, http.StatusBadRequest, "Invalid request body", err) + return + } + + // Validate request + if req.ServiceName == "" { + api.writeError(w, http.StatusBadRequest, "Service name is required", nil) + return + } + if req.TargetReplicas < 0 { + api.writeError(w, http.StatusBadRequest, "Target replicas must be non-negative", nil) + return + } + + span.SetAttributes( + attribute.String("request.service_name", req.ServiceName), + attribute.Int("request.target_replicas", req.TargetReplicas), + attribute.Bool("request.force_scale", req.ForceScale), + ) + + // Get current replica count + currentReplicas, err := api.controller.swarmManager.GetServiceReplicas(ctx, req.ServiceName) + if err != nil { + api.writeError(w, http.StatusNotFound, "Service not found", err) + return + } + + // Check if scaling is needed + if currentReplicas == req.TargetReplicas && !req.ForceScale { + response := ScaleResponse{ + ServiceName: req.ServiceName, + TargetReplicas: req.TargetReplicas, + CurrentReplicas: currentReplicas, + Status: "no_action_needed", + StartedAt: time.Now(), + Message: "Service already at target replica count", + } + api.writeJSON(w, http.StatusOK, response) + return + } + + // Determine scaling direction and wave size + var waveSize int + if req.WaveSize > 0 { + waveSize = req.WaveSize + } else { + // Default wave size based on scaling direction + if req.TargetReplicas > currentReplicas { + waveSize = 3 // Scale up in smaller waves + } else { + waveSize = 5 // Scale down in larger waves + } + } + + // Start scaling operation + waveID, err := api.controller.StartScaling(ctx, req.ServiceName, req.TargetReplicas, waveSize, req.Template) + if err != nil { + api.writeError(w, http.StatusInternalServerError, "Failed to start scaling", err) + return + } + + response := ScaleResponse{ + WaveID: waveID, + ServiceName: req.ServiceName, + TargetReplicas: req.TargetReplicas, + CurrentReplicas: currentReplicas, + Status: "scaling_started", + StartedAt: time.Now(), + Message: fmt.Sprintf("Started scaling %s from %d to %d replicas", req.ServiceName, currentReplicas, req.TargetReplicas), + } + + log.Info(). + Str("wave_id", waveID). + Str("service_name", req.ServiceName). + Int("current_replicas", currentReplicas). + Int("target_replicas", req.TargetReplicas). + Int("wave_size", waveSize). 
+        Msg("Started scaling operation via API")
+
+    api.writeJSON(w, http.StatusAccepted, response)
+}
+
+// GetScalingStatus returns the current scaling status
+func (api *ScalingAPI) GetScalingStatus(w http.ResponseWriter, r *http.Request) {
+    _, span := tracing.Tracer.Start(r.Context(), "scaling_api.get_scaling_status")
+    defer span.End()
+
+    currentWave := api.metrics.GetCurrentWave()
+    if currentWave == nil {
+        api.writeJSON(w, http.StatusOK, map[string]interface{}{
+            "status":  "idle",
+            "message": "No scaling operation in progress",
+        })
+        return
+    }
+
+    // Calculate progress
+    progress := float64(currentWave.CurrentReplicas) / float64(currentWave.TargetReplicas) * 100
+    if progress > 100 {
+        progress = 100
+    }
+
+    response := map[string]interface{}{
+        "status":           "scaling",
+        "wave_id":          currentWave.WaveID,
+        "service_name":     currentWave.ServiceName,
+        "started_at":       currentWave.StartedAt,
+        "target_replicas":  currentWave.TargetReplicas,
+        "current_replicas": currentWave.CurrentReplicas,
+        "progress_percent": progress,
+        "join_attempts":    len(currentWave.JoinAttempts),
+        "health_checks":    len(currentWave.HealthChecks),
+        "backoff_level":    currentWave.BackoffLevel,
+        "duration":         time.Since(currentWave.StartedAt).String(),
+    }
+
+    api.writeJSON(w, http.StatusOK, response)
+}
+
+// StopScaling stops the current scaling operation
+func (api *ScalingAPI) StopScaling(w http.ResponseWriter, r *http.Request) {
+    ctx, span := tracing.Tracer.Start(r.Context(), "scaling_api.stop_scaling")
+    defer span.End()
+
+    currentWave := api.metrics.GetCurrentWave()
+    if currentWave == nil {
+        api.writeError(w, http.StatusBadRequest, "No scaling operation in progress", nil)
+        return
+    }
+
+    // Stop the scaling operation
+    api.controller.StopScaling(ctx)
+
+    response := map[string]interface{}{
+        "status":     "stopped",
+        "wave_id":    currentWave.WaveID,
+        "message":    "Scaling operation stopped",
+        "stopped_at": time.Now(),
+    }
+
+    log.Info().
+        Str("wave_id", currentWave.WaveID).
+        Str("service_name", currentWave.ServiceName).
+ Msg("Stopped scaling operation via API") + + api.writeJSON(w, http.StatusOK, response) +} + +// GetHealthGates returns the current health gate status +func (api *ScalingAPI) GetHealthGates(w http.ResponseWriter, r *http.Request) { + ctx, span := tracing.Tracer.Start(r.Context(), "scaling_api.get_health_gates") + defer span.End() + + status, err := api.controller.healthGates.CheckHealth(ctx, nil) + if err != nil { + api.writeError(w, http.StatusInternalServerError, "Failed to check health gates", err) + return + } + + response := HealthResponse{ + Healthy: status.Healthy, + Timestamp: status.Timestamp, + Gates: status.Gates, + OverallReason: status.OverallReason, + } + + api.writeJSON(w, http.StatusOK, response) +} + +// GetHealthThresholds returns the current health thresholds +func (api *ScalingAPI) GetHealthThresholds(w http.ResponseWriter, r *http.Request) { + _, span := tracing.Tracer.Start(r.Context(), "scaling_api.get_health_thresholds") + defer span.End() + + thresholds := api.controller.healthGates.GetThresholds() + api.writeJSON(w, http.StatusOK, thresholds) +} + +// UpdateHealthThresholds updates the health thresholds +func (api *ScalingAPI) UpdateHealthThresholds(w http.ResponseWriter, r *http.Request) { + _, span := tracing.Tracer.Start(r.Context(), "scaling_api.update_health_thresholds") + defer span.End() + + var thresholds HealthThresholds + if err := json.NewDecoder(r.Body).Decode(&thresholds); err != nil { + api.writeError(w, http.StatusBadRequest, "Invalid request body", err) + return + } + + api.controller.healthGates.SetThresholds(thresholds) + + log.Info(). + Interface("thresholds", thresholds). + Msg("Updated health thresholds via API") + + api.writeJSON(w, http.StatusOK, map[string]string{ + "status": "updated", + "message": "Health thresholds updated successfully", + }) +} + +// GetScalingMetrics returns scaling metrics for a time window +func (api *ScalingAPI) GetScalingMetrics(w http.ResponseWriter, r *http.Request) { + ctx, span := tracing.Tracer.Start(r.Context(), "scaling_api.get_scaling_metrics") + defer span.End() + + // Parse query parameters for time window + windowStart, windowEnd := api.parseTimeWindow(r) + + report := api.metrics.GenerateReport(ctx, windowStart, windowEnd) + api.writeJSON(w, http.StatusOK, report) +} + +// GetRecentOperations returns recent scaling operations +func (api *ScalingAPI) GetRecentOperations(w http.ResponseWriter, r *http.Request) { + _, span := tracing.Tracer.Start(r.Context(), "scaling_api.get_recent_operations") + defer span.End() + + // Parse limit parameter + limit := 50 // Default limit + if limitStr := r.URL.Query().Get("limit"); limitStr != "" { + if parsedLimit, err := strconv.Atoi(limitStr); err == nil && parsedLimit > 0 { + limit = parsedLimit + } + } + + operations := api.metrics.GetRecentOperations(limit) + api.writeJSON(w, http.StatusOK, map[string]interface{}{ + "operations": operations, + "count": len(operations), + }) +} + +// ExportMetrics exports all metrics data +func (api *ScalingAPI) ExportMetrics(w http.ResponseWriter, r *http.Request) { + ctx, span := tracing.Tracer.Start(r.Context(), "scaling_api.export_metrics") + defer span.End() + + data, err := api.metrics.ExportMetrics(ctx) + if err != nil { + api.writeError(w, http.StatusInternalServerError, "Failed to export metrics", err) + return + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=scaling-metrics-%s.json", + time.Now().Format("2006-01-02-15-04-05"))) + 
w.Write(data) +} + +// GetServiceStatus returns detailed status for a specific service +func (api *ScalingAPI) GetServiceStatus(w http.ResponseWriter, r *http.Request) { + ctx, span := tracing.Tracer.Start(r.Context(), "scaling_api.get_service_status") + defer span.End() + + vars := mux.Vars(r) + serviceName := vars["serviceName"] + + status, err := api.controller.swarmManager.GetServiceStatus(ctx, serviceName) + if err != nil { + api.writeError(w, http.StatusNotFound, "Service not found", err) + return + } + + span.SetAttributes(attribute.String("service.name", serviceName)) + api.writeJSON(w, http.StatusOK, status) +} + +// GetServiceReplicas returns the current replica count for a service +func (api *ScalingAPI) GetServiceReplicas(w http.ResponseWriter, r *http.Request) { + ctx, span := tracing.Tracer.Start(r.Context(), "scaling_api.get_service_replicas") + defer span.End() + + vars := mux.Vars(r) + serviceName := vars["serviceName"] + + replicas, err := api.controller.swarmManager.GetServiceReplicas(ctx, serviceName) + if err != nil { + api.writeError(w, http.StatusNotFound, "Service not found", err) + return + } + + runningReplicas, err := api.controller.swarmManager.GetRunningReplicas(ctx, serviceName) + if err != nil { + log.Warn().Err(err).Str("service_name", serviceName).Msg("Failed to get running replica count") + runningReplicas = 0 + } + + response := map[string]interface{}{ + "service_name": serviceName, + "desired_replicas": replicas, + "running_replicas": runningReplicas, + "timestamp": time.Now(), + } + + span.SetAttributes( + attribute.String("service.name", serviceName), + attribute.Int("service.desired_replicas", replicas), + attribute.Int("service.running_replicas", runningReplicas), + ) + + api.writeJSON(w, http.StatusOK, response) +} + +// GetAssignmentTemplates returns available assignment templates +func (api *ScalingAPI) GetAssignmentTemplates(w http.ResponseWriter, r *http.Request) { + _, span := tracing.Tracer.Start(r.Context(), "scaling_api.get_assignment_templates") + defer span.End() + + templates := api.controller.assignmentBroker.GetAvailableTemplates() + api.writeJSON(w, http.StatusOK, map[string]interface{}{ + "templates": templates, + "count": len(templates), + }) +} + +// CreateAssignment creates a new assignment +func (api *ScalingAPI) CreateAssignment(w http.ResponseWriter, r *http.Request) { + ctx, span := tracing.Tracer.Start(r.Context(), "scaling_api.create_assignment") + defer span.End() + + var req AssignmentRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + api.writeError(w, http.StatusBadRequest, "Invalid request body", err) + return + } + + assignment, err := api.controller.assignmentBroker.CreateAssignment(ctx, req) + if err != nil { + api.writeError(w, http.StatusBadRequest, "Failed to create assignment", err) + return + } + + span.SetAttributes( + attribute.String("assignment.id", assignment.ID), + attribute.String("assignment.template", req.Template), + ) + + api.writeJSON(w, http.StatusCreated, assignment) +} + +// GetBootstrapPeers returns available bootstrap peers +func (api *ScalingAPI) GetBootstrapPeers(w http.ResponseWriter, r *http.Request) { + _, span := tracing.Tracer.Start(r.Context(), "scaling_api.get_bootstrap_peers") + defer span.End() + + peers := api.controller.bootstrapManager.GetAllPeers() + api.writeJSON(w, http.StatusOK, map[string]interface{}{ + "peers": peers, + "count": len(peers), + }) +} + +// GetBootstrapStats returns bootstrap pool statistics +func (api *ScalingAPI) GetBootstrapStats(w 
http.ResponseWriter, r *http.Request) {
+	_, span := tracing.Tracer.Start(r.Context(), "scaling_api.get_bootstrap_stats")
+	defer span.End()
+
+	stats := api.controller.bootstrapManager.GetStats()
+	api.writeJSON(w, http.StatusOK, stats)
+}
+
+// Helper functions
+
+// parseTimeWindow parses start and end time parameters from request
+func (api *ScalingAPI) parseTimeWindow(r *http.Request) (time.Time, time.Time) {
+	now := time.Now()
+
+	// Default to last 24 hours
+	windowEnd := now
+	windowStart := now.Add(-24 * time.Hour)
+
+	// Parse custom window if provided
+	if startStr := r.URL.Query().Get("start"); startStr != "" {
+		if start, err := time.Parse(time.RFC3339, startStr); err == nil {
+			windowStart = start
+		}
+	}
+
+	if endStr := r.URL.Query().Get("end"); endStr != "" {
+		if end, err := time.Parse(time.RFC3339, endStr); err == nil {
+			windowEnd = end
+		}
+	}
+
+	// Parse duration if provided (overrides start)
+	if durationStr := r.URL.Query().Get("duration"); durationStr != "" {
+		if duration, err := time.ParseDuration(durationStr); err == nil {
+			windowStart = windowEnd.Add(-duration)
+		}
+	}
+
+	return windowStart, windowEnd
+}
+
+// writeJSON writes a JSON response
+func (api *ScalingAPI) writeJSON(w http.ResponseWriter, status int, data interface{}) {
+	w.Header().Set("Content-Type", "application/json")
+	w.WriteHeader(status)
+	json.NewEncoder(w).Encode(data)
+}
+
+// writeError writes an error response
+func (api *ScalingAPI) writeError(w http.ResponseWriter, status int, message string, err error) {
+	response := map[string]interface{}{
+		"error":     message,
+		"timestamp": time.Now(),
+	}
+
+	if err != nil {
+		response["details"] = err.Error()
+		log.Error().Err(err).Str("error_message", message).Msg("API error")
+	}
+
+	api.writeJSON(w, status, response)
+}
\ No newline at end of file
diff --git a/internal/orchestrator/scaling_controller.go b/internal/orchestrator/scaling_controller.go
new file mode 100644
index 0000000..d69925f
--- /dev/null
+++ b/internal/orchestrator/scaling_controller.go
@@ -0,0 +1,640 @@
+package orchestrator
+
+import (
+	"context"
+	"fmt"
+	"math"
+	"math/rand"
+	"sync"
+	"time"
+
+	"github.com/rs/zerolog/log"
+	"go.opentelemetry.io/otel/attribute"
+
+	"github.com/chorus-services/whoosh/internal/tracing"
+)
+
+// ScalingController manages wave-based scaling operations for CHORUS services
+type ScalingController struct {
+	mu               sync.RWMutex
+	swarmManager     *SwarmManager
+	healthGates      *HealthGates
+	assignmentBroker *AssignmentBroker
+	bootstrapManager *BootstrapPoolManager
+	metricsCollector *ScalingMetricsCollector
+
+	// Scaling configuration
+	config ScalingConfig
+
+	// Current scaling state
+	currentOperations map[string]*ScalingOperation
+	scalingMetrics    map[string]*ScalingMetrics // per-service join metrics fed to the health gates
+	scalingActive     bool
+	stopChan          chan struct{}
+	ctx               context.Context
+	cancel            context.CancelFunc
+}
+
+// ScalingConfig defines configuration for scaling operations
+type ScalingConfig struct {
+	MinWaveSize      int           `json:"min_wave_size"`      // Minimum replicas per wave
+	MaxWaveSize      int           `json:"max_wave_size"`      // Maximum replicas per wave
+	WaveInterval     time.Duration `json:"wave_interval"`      // Time between waves
+	MaxConcurrentOps int           `json:"max_concurrent_ops"` // Maximum concurrent scaling operations
+
+	// Backoff configuration
+	InitialBackoff    time.Duration `json:"initial_backoff"`    // Initial backoff delay
+	MaxBackoff        time.Duration `json:"max_backoff"`        // Maximum backoff delay
+	BackoffMultiplier float64       `json:"backoff_multiplier"` // Backoff multiplier
+	JitterPercentage  float64 
`json:"jitter_percentage"` // Jitter percentage (0.0-1.0) + + // Health gate configuration + HealthCheckTimeout time.Duration `json:"health_check_timeout"` // Timeout for health checks + MinJoinSuccessRate float64 `json:"min_join_success_rate"` // Minimum join success rate + SuccessRateWindow int `json:"success_rate_window"` // Window size for success rate calculation +} + +// ScalingOperation represents an ongoing scaling operation +type ScalingOperation struct { + ID string `json:"id"` + ServiceName string `json:"service_name"` + CurrentReplicas int `json:"current_replicas"` + TargetReplicas int `json:"target_replicas"` + + // Wave state + CurrentWave int `json:"current_wave"` + WavesCompleted int `json:"waves_completed"` + WaveSize int `json:"wave_size"` + + // Timing + StartedAt time.Time `json:"started_at"` + LastWaveAt time.Time `json:"last_wave_at,omitempty"` + EstimatedCompletion time.Time `json:"estimated_completion,omitempty"` + + // Backoff state + ConsecutiveFailures int `json:"consecutive_failures"` + NextWaveAt time.Time `json:"next_wave_at,omitempty"` + BackoffDelay time.Duration `json:"backoff_delay"` + + // Status + Status ScalingStatus `json:"status"` + LastError string `json:"last_error,omitempty"` + + // Configuration + Template string `json:"template"` + ScalingParams map[string]interface{} `json:"scaling_params,omitempty"` +} + +// ScalingStatus represents the status of a scaling operation +type ScalingStatus string + +const ( + ScalingStatusPending ScalingStatus = "pending" + ScalingStatusRunning ScalingStatus = "running" + ScalingStatusWaiting ScalingStatus = "waiting" // Waiting for health gates + ScalingStatusBackoff ScalingStatus = "backoff" // In backoff period + ScalingStatusCompleted ScalingStatus = "completed" + ScalingStatusFailed ScalingStatus = "failed" + ScalingStatusCancelled ScalingStatus = "cancelled" +) + +// ScalingRequest represents a request to scale a service +type ScalingRequest struct { + ServiceName string `json:"service_name"` + TargetReplicas int `json:"target_replicas"` + Template string `json:"template,omitempty"` + ScalingParams map[string]interface{} `json:"scaling_params,omitempty"` + Force bool `json:"force,omitempty"` // Skip health gates +} + +// WaveResult represents the result of a scaling wave +type WaveResult struct { + WaveNumber int `json:"wave_number"` + RequestedCount int `json:"requested_count"` + SuccessfulJoins int `json:"successful_joins"` + FailedJoins int `json:"failed_joins"` + Duration time.Duration `json:"duration"` + CompletedAt time.Time `json:"completed_at"` +} + +// NewScalingController creates a new scaling controller +func NewScalingController( + swarmManager *SwarmManager, + healthGates *HealthGates, + assignmentBroker *AssignmentBroker, + bootstrapManager *BootstrapPoolManager, + metricsCollector *ScalingMetricsCollector, +) *ScalingController { + ctx, cancel := context.WithCancel(context.Background()) + + return &ScalingController{ + swarmManager: swarmManager, + healthGates: healthGates, + assignmentBroker: assignmentBroker, + bootstrapManager: bootstrapManager, + metricsCollector: metricsCollector, + config: ScalingConfig{ + MinWaveSize: 3, + MaxWaveSize: 8, + WaveInterval: 30 * time.Second, + MaxConcurrentOps: 3, + InitialBackoff: 30 * time.Second, + MaxBackoff: 2 * time.Minute, + BackoffMultiplier: 1.5, + JitterPercentage: 0.2, + HealthCheckTimeout: 10 * time.Second, + MinJoinSuccessRate: 0.8, + SuccessRateWindow: 10, + }, + currentOperations: make(map[string]*ScalingOperation), + stopChan: make(chan 
struct{}, 1),
+		scalingMetrics:    make(map[string]*ScalingMetrics),
+		ctx:               ctx,
+		cancel:            cancel,
+	}
+}
+
+// StartScaling initiates a scaling operation and returns the wave ID
+func (sc *ScalingController) StartScaling(ctx context.Context, serviceName string, targetReplicas, waveSize int, template string) (string, error) {
+	request := ScalingRequest{
+		ServiceName:    serviceName,
+		TargetReplicas: targetReplicas,
+		Template:       template,
+	}
+
+	operation, err := sc.startScalingOperation(ctx, request)
+	if err != nil {
+		return "", err
+	}
+
+	return operation.ID, nil
+}
+
+// startScalingOperation initiates a scaling operation
+func (sc *ScalingController) startScalingOperation(ctx context.Context, request ScalingRequest) (*ScalingOperation, error) {
+	ctx, span := tracing.Tracer.Start(ctx, "scaling_controller.start_scaling")
+	defer span.End()
+
+	sc.mu.Lock()
+	defer sc.mu.Unlock()
+
+	// Check if there's already an operation for this service
+	if existingOp, exists := sc.currentOperations[request.ServiceName]; exists {
+		if existingOp.Status == ScalingStatusRunning || existingOp.Status == ScalingStatusWaiting {
+			return nil, fmt.Errorf("scaling operation already in progress for service %s", request.ServiceName)
+		}
+	}
+
+	// Check concurrent operation limit
+	runningOps := 0
+	for _, op := range sc.currentOperations {
+		if op.Status == ScalingStatusRunning || op.Status == ScalingStatusWaiting {
+			runningOps++
+		}
+	}
+
+	if runningOps >= sc.config.MaxConcurrentOps {
+		return nil, fmt.Errorf("maximum concurrent scaling operations (%d) reached", sc.config.MaxConcurrentOps)
+	}
+
+	// Get current replica count
+	currentReplicas, err := sc.swarmManager.GetServiceReplicas(ctx, request.ServiceName)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get current replica count: %w", err)
+	}
+
+	// Calculate wave size
+	waveSize := sc.calculateWaveSize(currentReplicas, request.TargetReplicas)
+
+	// Create scaling operation
+	operation := &ScalingOperation{
+		ID:              fmt.Sprintf("scale-%s-%d", request.ServiceName, time.Now().Unix()),
+		ServiceName:     request.ServiceName,
+		CurrentReplicas: currentReplicas,
+		TargetReplicas:  request.TargetReplicas,
+		CurrentWave:     1,
+		WaveSize:        waveSize,
+		StartedAt:       time.Now(),
+		Status:          ScalingStatusPending,
+		Template:        request.Template,
+		ScalingParams:   request.ScalingParams,
+		BackoffDelay:    sc.config.InitialBackoff,
+	}
+
+	// Store operation
+	sc.currentOperations[request.ServiceName] = operation
+
+	// Start metrics tracking
+	if sc.metricsCollector != nil {
+		sc.metricsCollector.StartWave(ctx, operation.ID, operation.ServiceName, operation.TargetReplicas)
+	}
+
+	// Start scaling process in background
+	go sc.executeScaling(context.Background(), operation, request.Force)
+
+	span.SetAttributes(
+		attribute.String("scaling.service_name", request.ServiceName),
+		attribute.Int("scaling.current_replicas", currentReplicas),
+		attribute.Int("scaling.target_replicas", request.TargetReplicas),
+		attribute.Int("scaling.wave_size", waveSize),
+		attribute.String("scaling.operation_id", operation.ID),
+	)
+
+	log.Info().
+		Str("operation_id", operation.ID).
+		Str("service_name", request.ServiceName).
+		Int("current_replicas", currentReplicas).
+		Int("target_replicas", request.TargetReplicas).
+		Int("wave_size", waveSize).
+ Msg("Started scaling operation") + + return operation, nil +} + +// executeScaling executes the scaling operation with wave-based approach +func (sc *ScalingController) executeScaling(ctx context.Context, operation *ScalingOperation, force bool) { + ctx, span := tracing.Tracer.Start(ctx, "scaling_controller.execute_scaling") + defer span.End() + + defer func() { + sc.mu.Lock() + // Keep completed operations for a while for monitoring + if operation.Status == ScalingStatusCompleted || operation.Status == ScalingStatusFailed { + // Clean up after 1 hour + go func() { + time.Sleep(1 * time.Hour) + sc.mu.Lock() + delete(sc.currentOperations, operation.ServiceName) + sc.mu.Unlock() + }() + } + sc.mu.Unlock() + }() + + operation.Status = ScalingStatusRunning + + for operation.CurrentReplicas < operation.TargetReplicas { + // Check if we should wait for backoff + if !operation.NextWaveAt.IsZero() && time.Now().Before(operation.NextWaveAt) { + operation.Status = ScalingStatusBackoff + waitTime := time.Until(operation.NextWaveAt) + log.Info(). + Str("operation_id", operation.ID). + Dur("wait_time", waitTime). + Msg("Waiting for backoff period") + + select { + case <-ctx.Done(): + operation.Status = ScalingStatusCancelled + return + case <-time.After(waitTime): + // Continue after backoff + } + } + + operation.Status = ScalingStatusRunning + + // Check health gates (unless forced) + if !force { + if err := sc.waitForHealthGates(ctx, operation); err != nil { + operation.LastError = err.Error() + operation.ConsecutiveFailures++ + sc.applyBackoff(operation) + continue + } + } + + // Execute scaling wave + waveResult, err := sc.executeWave(ctx, operation) + if err != nil { + log.Error(). + Str("operation_id", operation.ID). + Err(err). + Msg("Scaling wave failed") + + operation.LastError = err.Error() + operation.ConsecutiveFailures++ + sc.applyBackoff(operation) + continue + } + + // Update operation state + operation.CurrentReplicas += waveResult.SuccessfulJoins + operation.WavesCompleted++ + operation.LastWaveAt = time.Now() + operation.ConsecutiveFailures = 0 // Reset on success + operation.NextWaveAt = time.Time{} // Clear backoff + + // Update scaling metrics + sc.updateScalingMetrics(operation.ServiceName, waveResult) + + log.Info(). + Str("operation_id", operation.ID). + Int("wave", operation.CurrentWave). + Int("successful_joins", waveResult.SuccessfulJoins). + Int("failed_joins", waveResult.FailedJoins). + Int("current_replicas", operation.CurrentReplicas). + Int("target_replicas", operation.TargetReplicas). + Msg("Scaling wave completed") + + // Move to next wave + operation.CurrentWave++ + + // Wait between waves + if operation.CurrentReplicas < operation.TargetReplicas { + select { + case <-ctx.Done(): + operation.Status = ScalingStatusCancelled + return + case <-time.After(sc.config.WaveInterval): + // Continue to next wave + } + } + } + + // Scaling completed successfully + operation.Status = ScalingStatusCompleted + operation.EstimatedCompletion = time.Now() + + log.Info(). + Str("operation_id", operation.ID). + Str("service_name", operation.ServiceName). + Int("final_replicas", operation.CurrentReplicas). + Int("waves_completed", operation.WavesCompleted). + Dur("total_duration", time.Since(operation.StartedAt)). 
+ Msg("Scaling operation completed successfully") +} + +// waitForHealthGates waits for health gates to be satisfied +func (sc *ScalingController) waitForHealthGates(ctx context.Context, operation *ScalingOperation) error { + operation.Status = ScalingStatusWaiting + + ctx, cancel := context.WithTimeout(ctx, sc.config.HealthCheckTimeout) + defer cancel() + + // Get recent scaling metrics for this service + var recentMetrics *ScalingMetrics + if metrics, exists := sc.scalingMetrics[operation.ServiceName]; exists { + recentMetrics = metrics + } + + healthStatus, err := sc.healthGates.CheckHealth(ctx, recentMetrics) + if err != nil { + return fmt.Errorf("health gate check failed: %w", err) + } + + if !healthStatus.Healthy { + return fmt.Errorf("health gates not satisfied: %s", healthStatus.OverallReason) + } + + return nil +} + +// executeWave executes a single scaling wave +func (sc *ScalingController) executeWave(ctx context.Context, operation *ScalingOperation) (*WaveResult, error) { + startTime := time.Now() + + // Calculate how many replicas to add in this wave + remaining := operation.TargetReplicas - operation.CurrentReplicas + waveSize := operation.WaveSize + if remaining < waveSize { + waveSize = remaining + } + + // Create assignments for new replicas + var assignments []*Assignment + for i := 0; i < waveSize; i++ { + assignReq := AssignmentRequest{ + ClusterID: "production", // TODO: Make configurable + Template: operation.Template, + } + + assignment, err := sc.assignmentBroker.CreateAssignment(ctx, assignReq) + if err != nil { + return nil, fmt.Errorf("failed to create assignment: %w", err) + } + + assignments = append(assignments, assignment) + } + + // Deploy new replicas + newReplicaCount := operation.CurrentReplicas + waveSize + err := sc.swarmManager.ScaleService(ctx, operation.ServiceName, newReplicaCount) + if err != nil { + return nil, fmt.Errorf("failed to scale service: %w", err) + } + + // Wait for replicas to come online and join successfully + successfulJoins, failedJoins := sc.waitForReplicaJoins(ctx, operation.ServiceName, waveSize) + + result := &WaveResult{ + WaveNumber: operation.CurrentWave, + RequestedCount: waveSize, + SuccessfulJoins: successfulJoins, + FailedJoins: failedJoins, + Duration: time.Since(startTime), + CompletedAt: time.Now(), + } + + return result, nil +} + +// waitForReplicaJoins waits for new replicas to join the cluster +func (sc *ScalingController) waitForReplicaJoins(ctx context.Context, serviceName string, expectedJoins int) (successful, failed int) { + // Wait up to 2 minutes for replicas to join + ctx, cancel := context.WithTimeout(ctx, 2*time.Minute) + defer cancel() + + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + startTime := time.Now() + + for { + select { + case <-ctx.Done(): + // Timeout reached, return current counts + return successful, expectedJoins - successful + case <-ticker.C: + // Check service status + running, err := sc.swarmManager.GetRunningReplicas(ctx, serviceName) + if err != nil { + log.Warn().Err(err).Msg("Failed to get running replicas") + continue + } + + // For now, assume all running replicas are successful joins + // In a real implementation, this would check P2P network membership + if running >= expectedJoins { + successful = expectedJoins + failed = 0 + return + } + + // If we've been waiting too long with no progress, consider some failed + if time.Since(startTime) > 90*time.Second { + successful = running + failed = expectedJoins - running + return + } + } + } +} + +// 
calculateWaveSize calculates the appropriate wave size for scaling +func (sc *ScalingController) calculateWaveSize(current, target int) int { + totalNodes := 10 // TODO: Get actual node count from swarm + + // Wave size formula: min(max(3, floor(total_nodes/10)), 8) + waveSize := int(math.Max(3, math.Floor(float64(totalNodes)/10))) + if waveSize > sc.config.MaxWaveSize { + waveSize = sc.config.MaxWaveSize + } + + // Don't exceed remaining replicas needed + remaining := target - current + if waveSize > remaining { + waveSize = remaining + } + + return waveSize +} + +// applyBackoff applies exponential backoff to the operation +func (sc *ScalingController) applyBackoff(operation *ScalingOperation) { + // Calculate backoff delay with exponential increase + backoff := time.Duration(float64(operation.BackoffDelay) * math.Pow(sc.config.BackoffMultiplier, float64(operation.ConsecutiveFailures-1))) + + // Cap at maximum backoff + if backoff > sc.config.MaxBackoff { + backoff = sc.config.MaxBackoff + } + + // Add jitter + jitter := time.Duration(float64(backoff) * sc.config.JitterPercentage * (rand.Float64() - 0.5)) + backoff += jitter + + operation.BackoffDelay = backoff + operation.NextWaveAt = time.Now().Add(backoff) + + log.Warn(). + Str("operation_id", operation.ID). + Int("consecutive_failures", operation.ConsecutiveFailures). + Dur("backoff_delay", backoff). + Time("next_wave_at", operation.NextWaveAt). + Msg("Applied exponential backoff") +} + +// updateScalingMetrics updates scaling metrics for success rate tracking +func (sc *ScalingController) updateScalingMetrics(serviceName string, result *WaveResult) { + sc.mu.Lock() + defer sc.mu.Unlock() + + metrics, exists := sc.scalingMetrics[serviceName] + if !exists { + metrics = &ScalingMetrics{ + LastWaveSize: result.RequestedCount, + LastWaveStarted: result.CompletedAt.Add(-result.Duration), + LastWaveCompleted: result.CompletedAt, + } + sc.scalingMetrics[serviceName] = metrics + } + + // Update metrics + metrics.LastWaveSize = result.RequestedCount + metrics.LastWaveCompleted = result.CompletedAt + metrics.SuccessfulJoins += result.SuccessfulJoins + metrics.FailedJoins += result.FailedJoins + + // Calculate success rate + total := metrics.SuccessfulJoins + metrics.FailedJoins + if total > 0 { + metrics.JoinSuccessRate = float64(metrics.SuccessfulJoins) / float64(total) + } +} + +// GetOperation returns a scaling operation by service name +func (sc *ScalingController) GetOperation(serviceName string) (*ScalingOperation, bool) { + sc.mu.RLock() + defer sc.mu.RUnlock() + + op, exists := sc.currentOperations[serviceName] + return op, exists +} + +// GetAllOperations returns all current scaling operations +func (sc *ScalingController) GetAllOperations() map[string]*ScalingOperation { + sc.mu.RLock() + defer sc.mu.RUnlock() + + operations := make(map[string]*ScalingOperation) + for k, v := range sc.currentOperations { + operations[k] = v + } + return operations +} + +// CancelOperation cancels a scaling operation +func (sc *ScalingController) CancelOperation(serviceName string) error { + sc.mu.Lock() + defer sc.mu.Unlock() + + operation, exists := sc.currentOperations[serviceName] + if !exists { + return fmt.Errorf("no scaling operation found for service %s", serviceName) + } + + if operation.Status == ScalingStatusCompleted || operation.Status == ScalingStatusFailed { + return fmt.Errorf("scaling operation already finished") + } + + operation.Status = ScalingStatusCancelled + log.Info().Str("operation_id", operation.ID).Msg("Scaling operation 
cancelled") + + // Complete metrics tracking + if sc.metricsCollector != nil { + currentReplicas, _ := sc.swarmManager.GetServiceReplicas(context.Background(), serviceName) + sc.metricsCollector.CompleteWave(context.Background(), false, currentReplicas, "Operation cancelled", operation.ConsecutiveFailures) + } + + return nil +} + +// StopScaling stops all active scaling operations +func (sc *ScalingController) StopScaling(ctx context.Context) { + ctx, span := tracing.Tracer.Start(ctx, "scaling_controller.stop_scaling") + defer span.End() + + sc.mu.Lock() + defer sc.mu.Unlock() + + cancelledCount := 0 + for serviceName, operation := range sc.currentOperations { + if operation.Status == ScalingStatusRunning || operation.Status == ScalingStatusWaiting || operation.Status == ScalingStatusBackoff { + operation.Status = ScalingStatusCancelled + cancelledCount++ + + // Complete metrics tracking for cancelled operations + if sc.metricsCollector != nil { + currentReplicas, _ := sc.swarmManager.GetServiceReplicas(ctx, serviceName) + sc.metricsCollector.CompleteWave(ctx, false, currentReplicas, "Scaling stopped", operation.ConsecutiveFailures) + } + + log.Info().Str("operation_id", operation.ID).Str("service_name", serviceName).Msg("Scaling operation stopped") + } + } + + // Signal stop to running operations + select { + case sc.stopChan <- struct{}{}: + default: + } + + span.SetAttributes(attribute.Int("stopped_operations", cancelledCount)) + log.Info().Int("cancelled_operations", cancelledCount).Msg("Stopped all scaling operations") +} + +// Close shuts down the scaling controller +func (sc *ScalingController) Close() error { + sc.cancel() + sc.StopScaling(sc.ctx) + return nil +} \ No newline at end of file diff --git a/internal/orchestrator/scaling_metrics.go b/internal/orchestrator/scaling_metrics.go new file mode 100644 index 0000000..2747a0e --- /dev/null +++ b/internal/orchestrator/scaling_metrics.go @@ -0,0 +1,454 @@ +package orchestrator + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/rs/zerolog/log" + "go.opentelemetry.io/otel/attribute" + + "github.com/chorus-services/whoosh/internal/tracing" +) + +// ScalingMetricsCollector collects and manages scaling operation metrics +type ScalingMetricsCollector struct { + mu sync.RWMutex + operations []ScalingOperation + maxHistory int + currentWave *WaveMetrics +} + +// ScalingOperation represents a completed scaling operation +type ScalingOperation struct { + ID string `json:"id"` + ServiceName string `json:"service_name"` + WaveNumber int `json:"wave_number"` + StartedAt time.Time `json:"started_at"` + CompletedAt time.Time `json:"completed_at"` + Duration time.Duration `json:"duration"` + TargetReplicas int `json:"target_replicas"` + AchievedReplicas int `json:"achieved_replicas"` + Success bool `json:"success"` + FailureReason string `json:"failure_reason,omitempty"` + JoinAttempts []JoinAttempt `json:"join_attempts"` + HealthGateResults map[string]bool `json:"health_gate_results"` + BackoffLevel int `json:"backoff_level"` +} + +// JoinAttempt represents an individual replica join attempt +type JoinAttempt struct { + ReplicaID string `json:"replica_id"` + AttemptedAt time.Time `json:"attempted_at"` + CompletedAt time.Time `json:"completed_at,omitempty"` + Duration time.Duration `json:"duration"` + Success bool `json:"success"` + FailureReason string `json:"failure_reason,omitempty"` + BootstrapPeers []string `json:"bootstrap_peers"` +} + +// WaveMetrics tracks metrics for the currently executing wave +type 
WaveMetrics struct { + WaveID string `json:"wave_id"` + ServiceName string `json:"service_name"` + StartedAt time.Time `json:"started_at"` + TargetReplicas int `json:"target_replicas"` + CurrentReplicas int `json:"current_replicas"` + JoinAttempts []JoinAttempt `json:"join_attempts"` + HealthChecks []HealthCheckResult `json:"health_checks"` + BackoffLevel int `json:"backoff_level"` +} + +// HealthCheckResult represents a health gate check result +type HealthCheckResult struct { + Timestamp time.Time `json:"timestamp"` + GateName string `json:"gate_name"` + Healthy bool `json:"healthy"` + Reason string `json:"reason,omitempty"` + Metrics map[string]interface{} `json:"metrics,omitempty"` + CheckDuration time.Duration `json:"check_duration"` +} + +// ScalingMetricsReport provides aggregated metrics for reporting +type ScalingMetricsReport struct { + WindowStart time.Time `json:"window_start"` + WindowEnd time.Time `json:"window_end"` + TotalOperations int `json:"total_operations"` + SuccessfulOps int `json:"successful_operations"` + FailedOps int `json:"failed_operations"` + SuccessRate float64 `json:"success_rate"` + AverageWaveTime time.Duration `json:"average_wave_time"` + AverageJoinTime time.Duration `json:"average_join_time"` + BackoffEvents int `json:"backoff_events"` + HealthGateFailures map[string]int `json:"health_gate_failures"` + ServiceMetrics map[string]ServiceMetrics `json:"service_metrics"` + CurrentWave *WaveMetrics `json:"current_wave,omitempty"` +} + +// ServiceMetrics provides per-service scaling metrics +type ServiceMetrics struct { + ServiceName string `json:"service_name"` + TotalWaves int `json:"total_waves"` + SuccessfulWaves int `json:"successful_waves"` + AverageWaveTime time.Duration `json:"average_wave_time"` + LastScaled time.Time `json:"last_scaled"` + CurrentReplicas int `json:"current_replicas"` +} + +// NewScalingMetricsCollector creates a new metrics collector +func NewScalingMetricsCollector(maxHistory int) *ScalingMetricsCollector { + if maxHistory == 0 { + maxHistory = 1000 // Default to keeping 1000 operations + } + + return &ScalingMetricsCollector{ + operations: make([]ScalingOperation, 0), + maxHistory: maxHistory, + } +} + +// StartWave begins tracking a new scaling wave +func (smc *ScalingMetricsCollector) StartWave(ctx context.Context, waveID, serviceName string, targetReplicas int) { + ctx, span := tracing.Tracer.Start(ctx, "scaling_metrics.start_wave") + defer span.End() + + smc.mu.Lock() + defer smc.mu.Unlock() + + smc.currentWave = &WaveMetrics{ + WaveID: waveID, + ServiceName: serviceName, + StartedAt: time.Now(), + TargetReplicas: targetReplicas, + JoinAttempts: make([]JoinAttempt, 0), + HealthChecks: make([]HealthCheckResult, 0), + } + + span.SetAttributes( + attribute.String("wave.id", waveID), + attribute.String("wave.service", serviceName), + attribute.Int("wave.target_replicas", targetReplicas), + ) + + log.Info(). + Str("wave_id", waveID). + Str("service_name", serviceName). + Int("target_replicas", targetReplicas). 
+ Msg("Started tracking scaling wave") +} + +// RecordJoinAttempt records a replica join attempt +func (smc *ScalingMetricsCollector) RecordJoinAttempt(replicaID string, bootstrapPeers []string, success bool, duration time.Duration, failureReason string) { + smc.mu.Lock() + defer smc.mu.Unlock() + + if smc.currentWave == nil { + log.Warn().Str("replica_id", replicaID).Msg("No active wave to record join attempt") + return + } + + attempt := JoinAttempt{ + ReplicaID: replicaID, + AttemptedAt: time.Now().Add(-duration), + CompletedAt: time.Now(), + Duration: duration, + Success: success, + FailureReason: failureReason, + BootstrapPeers: bootstrapPeers, + } + + smc.currentWave.JoinAttempts = append(smc.currentWave.JoinAttempts, attempt) + + log.Debug(). + Str("wave_id", smc.currentWave.WaveID). + Str("replica_id", replicaID). + Bool("success", success). + Dur("duration", duration). + Msg("Recorded join attempt") +} + +// RecordHealthCheck records a health gate check result +func (smc *ScalingMetricsCollector) RecordHealthCheck(gateName string, healthy bool, reason string, metrics map[string]interface{}, duration time.Duration) { + smc.mu.Lock() + defer smc.mu.Unlock() + + if smc.currentWave == nil { + log.Warn().Str("gate_name", gateName).Msg("No active wave to record health check") + return + } + + result := HealthCheckResult{ + Timestamp: time.Now(), + GateName: gateName, + Healthy: healthy, + Reason: reason, + Metrics: metrics, + CheckDuration: duration, + } + + smc.currentWave.HealthChecks = append(smc.currentWave.HealthChecks, result) + + log.Debug(). + Str("wave_id", smc.currentWave.WaveID). + Str("gate_name", gateName). + Bool("healthy", healthy). + Dur("duration", duration). + Msg("Recorded health check") +} + +// CompleteWave finishes tracking the current wave and archives it +func (smc *ScalingMetricsCollector) CompleteWave(ctx context.Context, success bool, achievedReplicas int, failureReason string, backoffLevel int) { + ctx, span := tracing.Tracer.Start(ctx, "scaling_metrics.complete_wave") + defer span.End() + + smc.mu.Lock() + defer smc.mu.Unlock() + + if smc.currentWave == nil { + log.Warn().Msg("No active wave to complete") + return + } + + now := time.Now() + operation := ScalingOperation{ + ID: smc.currentWave.WaveID, + ServiceName: smc.currentWave.ServiceName, + WaveNumber: len(smc.operations) + 1, + StartedAt: smc.currentWave.StartedAt, + CompletedAt: now, + Duration: now.Sub(smc.currentWave.StartedAt), + TargetReplicas: smc.currentWave.TargetReplicas, + AchievedReplicas: achievedReplicas, + Success: success, + FailureReason: failureReason, + JoinAttempts: smc.currentWave.JoinAttempts, + HealthGateResults: smc.extractHealthGateResults(), + BackoffLevel: backoffLevel, + } + + // Add to operations history + smc.operations = append(smc.operations, operation) + + // Trim history if needed + if len(smc.operations) > smc.maxHistory { + smc.operations = smc.operations[len(smc.operations)-smc.maxHistory:] + } + + span.SetAttributes( + attribute.String("wave.id", operation.ID), + attribute.String("wave.service", operation.ServiceName), + attribute.Bool("wave.success", success), + attribute.Int("wave.achieved_replicas", achievedReplicas), + attribute.Int("wave.backoff_level", backoffLevel), + attribute.String("wave.duration", operation.Duration.String()), + ) + + log.Info(). + Str("wave_id", operation.ID). + Str("service_name", operation.ServiceName). + Bool("success", success). + Int("achieved_replicas", achievedReplicas). + Dur("duration", operation.Duration). 
+ Msg("Completed scaling wave") + + // Clear current wave + smc.currentWave = nil +} + +// extractHealthGateResults extracts the final health gate results from checks +func (smc *ScalingMetricsCollector) extractHealthGateResults() map[string]bool { + results := make(map[string]bool) + + // Get the latest result for each gate + for _, check := range smc.currentWave.HealthChecks { + results[check.GateName] = check.Healthy + } + + return results +} + +// GenerateReport generates a metrics report for the specified time window +func (smc *ScalingMetricsCollector) GenerateReport(ctx context.Context, windowStart, windowEnd time.Time) *ScalingMetricsReport { + ctx, span := tracing.Tracer.Start(ctx, "scaling_metrics.generate_report") + defer span.End() + + smc.mu.RLock() + defer smc.mu.RUnlock() + + report := &ScalingMetricsReport{ + WindowStart: windowStart, + WindowEnd: windowEnd, + HealthGateFailures: make(map[string]int), + ServiceMetrics: make(map[string]ServiceMetrics), + CurrentWave: smc.currentWave, + } + + // Filter operations within window + var windowOps []ScalingOperation + for _, op := range smc.operations { + if op.StartedAt.After(windowStart) && op.StartedAt.Before(windowEnd) { + windowOps = append(windowOps, op) + } + } + + report.TotalOperations = len(windowOps) + + if len(windowOps) == 0 { + return report + } + + // Calculate aggregated metrics + var totalDuration time.Duration + var totalJoinDuration time.Duration + var totalJoinAttempts int + serviceStats := make(map[string]*ServiceMetrics) + + for _, op := range windowOps { + // Overall stats + if op.Success { + report.SuccessfulOps++ + } else { + report.FailedOps++ + } + + totalDuration += op.Duration + + // Backoff tracking + if op.BackoffLevel > 0 { + report.BackoffEvents++ + } + + // Health gate failures + for gate, healthy := range op.HealthGateResults { + if !healthy { + report.HealthGateFailures[gate]++ + } + } + + // Join attempt metrics + for _, attempt := range op.JoinAttempts { + totalJoinDuration += attempt.Duration + totalJoinAttempts++ + } + + // Service-specific metrics + if _, exists := serviceStats[op.ServiceName]; !exists { + serviceStats[op.ServiceName] = &ServiceMetrics{ + ServiceName: op.ServiceName, + } + } + + svc := serviceStats[op.ServiceName] + svc.TotalWaves++ + if op.Success { + svc.SuccessfulWaves++ + } + if op.CompletedAt.After(svc.LastScaled) { + svc.LastScaled = op.CompletedAt + svc.CurrentReplicas = op.AchievedReplicas + } + } + + // Calculate rates and averages + report.SuccessRate = float64(report.SuccessfulOps) / float64(report.TotalOperations) + report.AverageWaveTime = totalDuration / time.Duration(len(windowOps)) + + if totalJoinAttempts > 0 { + report.AverageJoinTime = totalJoinDuration / time.Duration(totalJoinAttempts) + } + + // Finalize service metrics + for serviceName, stats := range serviceStats { + if stats.TotalWaves > 0 { + // Calculate average wave time for this service + var serviceDuration time.Duration + serviceWaves := 0 + for _, op := range windowOps { + if op.ServiceName == serviceName { + serviceDuration += op.Duration + serviceWaves++ + } + } + stats.AverageWaveTime = serviceDuration / time.Duration(serviceWaves) + } + report.ServiceMetrics[serviceName] = *stats + } + + span.SetAttributes( + attribute.Int("report.total_operations", report.TotalOperations), + attribute.Int("report.successful_operations", report.SuccessfulOps), + attribute.Float64("report.success_rate", report.SuccessRate), + attribute.String("report.window_duration", windowEnd.Sub(windowStart).String()), 
+ ) + + return report +} + +// GetCurrentWave returns the currently active wave metrics +func (smc *ScalingMetricsCollector) GetCurrentWave() *WaveMetrics { + smc.mu.RLock() + defer smc.mu.RUnlock() + + if smc.currentWave == nil { + return nil + } + + // Return a copy to avoid concurrent access issues + wave := *smc.currentWave + wave.JoinAttempts = make([]JoinAttempt, len(smc.currentWave.JoinAttempts)) + copy(wave.JoinAttempts, smc.currentWave.JoinAttempts) + wave.HealthChecks = make([]HealthCheckResult, len(smc.currentWave.HealthChecks)) + copy(wave.HealthChecks, smc.currentWave.HealthChecks) + + return &wave +} + +// GetRecentOperations returns the most recent scaling operations +func (smc *ScalingMetricsCollector) GetRecentOperations(limit int) []ScalingOperation { + smc.mu.RLock() + defer smc.mu.RUnlock() + + if limit <= 0 || limit > len(smc.operations) { + limit = len(smc.operations) + } + + // Return most recent operations + start := len(smc.operations) - limit + operations := make([]ScalingOperation, limit) + copy(operations, smc.operations[start:]) + + return operations +} + +// ExportMetrics exports metrics in JSON format +func (smc *ScalingMetricsCollector) ExportMetrics(ctx context.Context) ([]byte, error) { + ctx, span := tracing.Tracer.Start(ctx, "scaling_metrics.export") + defer span.End() + + smc.mu.RLock() + defer smc.mu.RUnlock() + + export := struct { + Operations []ScalingOperation `json:"operations"` + CurrentWave *WaveMetrics `json:"current_wave,omitempty"` + ExportedAt time.Time `json:"exported_at"` + }{ + Operations: smc.operations, + CurrentWave: smc.currentWave, + ExportedAt: time.Now(), + } + + data, err := json.MarshalIndent(export, "", " ") + if err != nil { + return nil, fmt.Errorf("failed to marshal metrics: %w", err) + } + + span.SetAttributes( + attribute.Int("export.operation_count", len(smc.operations)), + attribute.Bool("export.has_current_wave", smc.currentWave != nil), + ) + + return data, nil +} \ No newline at end of file diff --git a/internal/orchestrator/swarm_manager.go b/internal/orchestrator/swarm_manager.go index c813a60..f8d47a0 100644 --- a/internal/orchestrator/swarm_manager.go +++ b/internal/orchestrator/swarm_manager.go @@ -77,6 +77,236 @@ func (sm *SwarmManager) Close() error { return sm.client.Close() } +// ScaleService scales a Docker Swarm service to the specified replica count +func (sm *SwarmManager) ScaleService(ctx context.Context, serviceName string, replicas int) error { + ctx, span := tracing.Tracer.Start(ctx, "swarm_manager.scale_service") + defer span.End() + + // Get the service + service, _, err := sm.client.ServiceInspectWithRaw(ctx, serviceName, types.ServiceInspectOptions{}) + if err != nil { + return fmt.Errorf("failed to inspect service %s: %w", serviceName, err) + } + + // Update replica count + serviceSpec := service.Spec + if serviceSpec.Mode.Replicated == nil { + return fmt.Errorf("service %s is not in replicated mode", serviceName) + } + + currentReplicas := *serviceSpec.Mode.Replicated.Replicas + serviceSpec.Mode.Replicated.Replicas = uint64Ptr(uint64(replicas)) + + // Update the service + updateResponse, err := sm.client.ServiceUpdate( + ctx, + service.ID, + service.Version, + serviceSpec, + types.ServiceUpdateOptions{}, + ) + if err != nil { + return fmt.Errorf("failed to update service %s: %w", serviceName, err) + } + + span.SetAttributes( + attribute.String("service.name", serviceName), + attribute.String("service.id", service.ID), + attribute.Int("scaling.current_replicas", int(currentReplicas)), + 
attribute.Int("scaling.target_replicas", replicas), + ) + + log.Info(). + Str("service_name", serviceName). + Str("service_id", service.ID). + Uint64("current_replicas", currentReplicas). + Int("target_replicas", replicas). + Str("update_id", updateResponse.ID). + Msg("Scaled service") + + return nil +} + +// GetServiceReplicas returns the current replica count for a service +func (sm *SwarmManager) GetServiceReplicas(ctx context.Context, serviceName string) (int, error) { + service, _, err := sm.client.ServiceInspectWithRaw(ctx, serviceName, types.ServiceInspectOptions{}) + if err != nil { + return 0, fmt.Errorf("failed to inspect service %s: %w", serviceName, err) + } + + if service.Spec.Mode.Replicated == nil { + return 0, fmt.Errorf("service %s is not in replicated mode", serviceName) + } + + return int(*service.Spec.Mode.Replicated.Replicas), nil +} + +// GetRunningReplicas returns the number of currently running replicas for a service +func (sm *SwarmManager) GetRunningReplicas(ctx context.Context, serviceName string) (int, error) { + // Get service to get its ID + service, _, err := sm.client.ServiceInspectWithRaw(ctx, serviceName, types.ServiceInspectOptions{}) + if err != nil { + return 0, fmt.Errorf("failed to inspect service %s: %w", serviceName, err) + } + + // List tasks for this service + taskFilters := filters.NewArgs() + taskFilters.Add("service", service.ID) + + tasks, err := sm.client.TaskList(ctx, types.TaskListOptions{ + Filters: taskFilters, + }) + if err != nil { + return 0, fmt.Errorf("failed to list tasks for service %s: %w", serviceName, err) + } + + // Count running tasks + runningCount := 0 + for _, task := range tasks { + if task.Status.State == swarm.TaskStateRunning { + runningCount++ + } + } + + return runningCount, nil +} + +// GetServiceStatus returns detailed status information for a service +func (sm *SwarmManager) GetServiceStatus(ctx context.Context, serviceName string) (*ServiceStatus, error) { + service, _, err := sm.client.ServiceInspectWithRaw(ctx, serviceName, types.ServiceInspectOptions{}) + if err != nil { + return nil, fmt.Errorf("failed to inspect service %s: %w", serviceName, err) + } + + // Get tasks for detailed status + taskFilters := filters.NewArgs() + taskFilters.Add("service", service.ID) + + tasks, err := sm.client.TaskList(ctx, types.TaskListOptions{ + Filters: taskFilters, + }) + if err != nil { + return nil, fmt.Errorf("failed to list tasks for service %s: %w", serviceName, err) + } + + status := &ServiceStatus{ + ServiceID: service.ID, + ServiceName: serviceName, + Image: service.Spec.TaskTemplate.ContainerSpec.Image, + CreatedAt: service.CreatedAt, + UpdatedAt: service.UpdatedAt, + Tasks: make([]TaskStatus, 0, len(tasks)), + } + + if service.Spec.Mode.Replicated != nil { + status.DesiredReplicas = int(*service.Spec.Mode.Replicated.Replicas) + } + + // Process tasks + runningCount := 0 + for _, task := range tasks { + taskStatus := TaskStatus{ + TaskID: task.ID, + NodeID: task.NodeID, + State: string(task.Status.State), + Message: task.Status.Message, + CreatedAt: task.CreatedAt, + UpdatedAt: task.UpdatedAt, + } + + if task.Status.Timestamp != nil { + taskStatus.StatusTimestamp = *task.Status.Timestamp + } + + status.Tasks = append(status.Tasks, taskStatus) + + if task.Status.State == swarm.TaskStateRunning { + runningCount++ + } + } + + status.RunningReplicas = runningCount + + return status, nil +} + +// CreateCHORUSService creates a new CHORUS service with the specified configuration +func (sm *SwarmManager) 
CreateCHORUSService(ctx context.Context, config *CHORUSServiceConfig) (*swarm.Service, error) { + ctx, span := tracing.Tracer.Start(ctx, "swarm_manager.create_chorus_service") + defer span.End() + + // Build service specification + serviceSpec := swarm.ServiceSpec{ + Annotations: swarm.Annotations{ + Name: config.ServiceName, + Labels: config.Labels, + }, + TaskTemplate: swarm.TaskSpec{ + ContainerSpec: &swarm.ContainerSpec{ + Image: config.Image, + Env: buildEnvironmentList(config.Environment), + }, + Resources: &swarm.ResourceRequirements{ + Limits: &swarm.Resources{ + NanoCPUs: config.Resources.CPULimit, + MemoryBytes: config.Resources.MemoryLimit, + }, + Reservations: &swarm.Resources{ + NanoCPUs: config.Resources.CPURequest, + MemoryBytes: config.Resources.MemoryRequest, + }, + }, + Placement: &swarm.Placement{ + Constraints: config.Placement.Constraints, + }, + }, + Mode: swarm.ServiceMode{ + Replicated: &swarm.ReplicatedService{ + Replicas: uint64Ptr(uint64(config.InitialReplicas)), + }, + }, + Networks: buildNetworkAttachments(config.Networks), + UpdateConfig: &swarm.UpdateConfig{ + Parallelism: 1, + Delay: 15 * time.Second, + Order: swarm.UpdateOrderStartFirst, + }, + } + + // Add volumes if specified + if len(config.Volumes) > 0 { + serviceSpec.TaskTemplate.ContainerSpec.Mounts = buildMounts(config.Volumes) + } + + // Create the service + response, err := sm.client.ServiceCreate(ctx, serviceSpec, types.ServiceCreateOptions{}) + if err != nil { + return nil, fmt.Errorf("failed to create service %s: %w", config.ServiceName, err) + } + + // Get the created service + service, _, err := sm.client.ServiceInspectWithRaw(ctx, response.ID, types.ServiceInspectOptions{}) + if err != nil { + return nil, fmt.Errorf("failed to inspect created service: %w", err) + } + + span.SetAttributes( + attribute.String("service.name", config.ServiceName), + attribute.String("service.id", response.ID), + attribute.Int("service.initial_replicas", config.InitialReplicas), + attribute.String("service.image", config.Image), + ) + + log.Info(). + Str("service_name", config.ServiceName). + Str("service_id", response.ID). + Int("initial_replicas", config.InitialReplicas). + Str("image", config.Image). + Msg("Created CHORUS service") + + return &service, nil +} + // AgentDeploymentConfig defines configuration for deploying an agent type AgentDeploymentConfig struct { TeamID string `json:"team_id"` @@ -487,94 +717,42 @@ func (sm *SwarmManager) GetServiceLogs(serviceID string, lines int) (string, err return string(logs), nil } -// ScaleService scales a service to the specified number of replicas -func (sm *SwarmManager) ScaleService(serviceID string, replicas uint64) error { - log.Info(). - Str("service_id", serviceID). - Uint64("replicas", replicas). - Msg("๐Ÿ“ˆ Scaling agent service") - - // Get current service spec - service, _, err := sm.client.ServiceInspectWithRaw(sm.ctx, serviceID, types.ServiceInspectOptions{}) - if err != nil { - return fmt.Errorf("failed to inspect service: %w", err) - } - - // Update replicas - service.Spec.Mode.Replicated.Replicas = &replicas - - // Update the service - _, err = sm.client.ServiceUpdate(sm.ctx, serviceID, service.Version, service.Spec, types.ServiceUpdateOptions{}) - if err != nil { - return fmt.Errorf("failed to scale service: %w", err) - } - - log.Info(). - Str("service_id", serviceID). - Uint64("replicas", replicas). 
- Msg("โœ… Service scaled successfully") - - return nil -} -// GetServiceStatus returns the current status of a service -func (sm *SwarmManager) GetServiceStatus(serviceID string) (*ServiceStatus, error) { - service, _, err := sm.client.ServiceInspectWithRaw(sm.ctx, serviceID, types.ServiceInspectOptions{}) - if err != nil { - return nil, fmt.Errorf("failed to inspect service: %w", err) - } - - // Get task status - tasks, err := sm.client.TaskList(sm.ctx, types.TaskListOptions{ - Filters: filters.NewArgs(filters.Arg("service", serviceID)), - }) - if err != nil { - return nil, fmt.Errorf("failed to list tasks: %w", err) - } - - status := &ServiceStatus{ - ServiceID: serviceID, - ServiceName: service.Spec.Name, - Image: service.Spec.TaskTemplate.ContainerSpec.Image, - Replicas: 0, - RunningTasks: 0, - FailedTasks: 0, - TaskStates: make(map[string]int), - CreatedAt: service.CreatedAt, - UpdatedAt: service.UpdatedAt, - } - - if service.Spec.Mode.Replicated != nil && service.Spec.Mode.Replicated.Replicas != nil { - status.Replicas = *service.Spec.Mode.Replicated.Replicas - } - - // Count task states - for _, task := range tasks { - state := string(task.Status.State) - status.TaskStates[state]++ - - switch task.Status.State { - case swarm.TaskStateRunning: - status.RunningTasks++ - case swarm.TaskStateFailed: - status.FailedTasks++ - } - } - - return status, nil -} -// ServiceStatus represents the current status of a service +// ServiceStatus represents the current status of a service with detailed task information type ServiceStatus struct { - ServiceID string `json:"service_id"` - ServiceName string `json:"service_name"` - Image string `json:"image"` - Replicas uint64 `json:"replicas"` - RunningTasks uint64 `json:"running_tasks"` - FailedTasks uint64 `json:"failed_tasks"` - TaskStates map[string]int `json:"task_states"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ServiceID string `json:"service_id"` + ServiceName string `json:"service_name"` + Image string `json:"image"` + DesiredReplicas int `json:"desired_replicas"` + RunningReplicas int `json:"running_replicas"` + Tasks []TaskStatus `json:"tasks"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// TaskStatus represents the status of an individual task +type TaskStatus struct { + TaskID string `json:"task_id"` + NodeID string `json:"node_id"` + State string `json:"state"` + Message string `json:"message"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + StatusTimestamp time.Time `json:"status_timestamp"` +} + +// CHORUSServiceConfig represents configuration for creating a CHORUS service +type CHORUSServiceConfig struct { + ServiceName string `json:"service_name"` + Image string `json:"image"` + InitialReplicas int `json:"initial_replicas"` + Environment map[string]string `json:"environment"` + Labels map[string]string `json:"labels"` + Networks []string `json:"networks"` + Volumes []VolumeMount `json:"volumes"` + Resources ResourceLimits `json:"resources"` + Placement PlacementConfig `json:"placement"` } // CleanupFailedServices removes failed services @@ -611,6 +789,61 @@ func (sm *SwarmManager) CleanupFailedServices() error { } } } - + return nil +} + +// Helper functions for SwarmManager + +// uint64Ptr returns a pointer to a uint64 value +func uint64Ptr(v uint64) *uint64 { + return &v +} + +// buildEnvironmentList converts a map to a slice of environment variables +func buildEnvironmentList(env map[string]string) 
[]string { + var envList []string + for key, value := range env { + envList = append(envList, fmt.Sprintf("%s=%s", key, value)) + } + return envList +} + +// buildNetworkAttachments converts network names to attachment configs +func buildNetworkAttachments(networks []string) []swarm.NetworkAttachmentConfig { + if len(networks) == 0 { + networks = []string{"chorus_default"} + } + + var attachments []swarm.NetworkAttachmentConfig + for _, network := range networks { + attachments = append(attachments, swarm.NetworkAttachmentConfig{ + Target: network, + }) + } + return attachments +} + +// buildMounts converts volume mounts to Docker mount specs +func buildMounts(volumes []VolumeMount) []mount.Mount { + var mounts []mount.Mount + + for _, vol := range volumes { + mountType := mount.TypeBind + switch vol.Type { + case "volume": + mountType = mount.TypeVolume + case "tmpfs": + mountType = mount.TypeTmpfs + } + + mounts = append(mounts, mount.Mount{ + Type: mountType, + Source: vol.Source, + Target: vol.Target, + ReadOnly: vol.ReadOnly, + }) + } + + return mounts } \ No newline at end of file
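
// --- Usage sketch (not part of the diff above) ---
// Hypothetical route wiring for the ScalingAPI handlers added in scaling_api.go.
// The actual route registration lives outside this diff, so the URL paths and
// the /api/v1/scaling prefix are assumptions for illustration only; the handler
// methods and the {serviceName} variable are taken from the diff. The
// scale-start handler defined earlier in scaling_api.go is omitted because its
// name is not visible here.
package orchestrator

import (
	"net/http"

	"github.com/gorilla/mux"
)

func registerScalingRoutes(r *mux.Router, api *ScalingAPI) {
	s := r.PathPrefix("/api/v1/scaling").Subrouter()
	s.HandleFunc("/status", api.GetScalingStatus).Methods(http.MethodGet)
	s.HandleFunc("/stop", api.StopScaling).Methods(http.MethodPost)
	s.HandleFunc("/health/gates", api.GetHealthGates).Methods(http.MethodGet)
	s.HandleFunc("/health/thresholds", api.GetHealthThresholds).Methods(http.MethodGet)
	s.HandleFunc("/health/thresholds", api.UpdateHealthThresholds).Methods(http.MethodPut)
	s.HandleFunc("/metrics", api.GetScalingMetrics).Methods(http.MethodGet)      // supports ?start=&end=&duration=
	s.HandleFunc("/metrics/export", api.ExportMetrics).Methods(http.MethodGet)
	s.HandleFunc("/operations", api.GetRecentOperations).Methods(http.MethodGet) // supports ?limit=
	s.HandleFunc("/services/{serviceName}", api.GetServiceStatus).Methods(http.MethodGet)
	s.HandleFunc("/services/{serviceName}/replicas", api.GetServiceReplicas).Methods(http.MethodGet)
	s.HandleFunc("/assignments/templates", api.GetAssignmentTemplates).Methods(http.MethodGet)
	s.HandleFunc("/assignments", api.CreateAssignment).Methods(http.MethodPost)
	s.HandleFunc("/bootstrap/peers", api.GetBootstrapPeers).Methods(http.MethodGet)
	s.HandleFunc("/bootstrap/stats", api.GetBootstrapStats).Methods(http.MethodGet)
}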
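
// --- Usage sketch (not part of the diff above) ---
// Minimal sketch of driving a wave-based scale-up through ScalingController,
// using the StartScaling / GetOperation methods from scaling_controller.go.
// Constructing the controller's dependencies (SwarmManager, HealthGates,
// AssignmentBroker, BootstrapPoolManager) is outside this diff, so the
// controller is taken as a parameter; the "chorus-agent" service name and
// "default-agent" template are illustrative assumptions.
package orchestrator

import (
	"context"
	"fmt"
	"time"
)

func scaleUpExample(ctx context.Context, sc *ScalingController) error {
	// The waveSize argument is recalculated internally by the controller, so 0 is passed here.
	opID, err := sc.StartScaling(ctx, "chorus-agent", 20, 0, "default-agent")
	if err != nil {
		return err
	}

	// Poll operation state; GetOperation is keyed by service name, not operation ID.
	for {
		op, ok := sc.GetOperation("chorus-agent")
		if !ok {
			return fmt.Errorf("operation %s is no longer tracked", opID)
		}
		switch op.Status {
		case ScalingStatusCompleted:
			return nil
		case ScalingStatusFailed, ScalingStatusCancelled:
			return fmt.Errorf("scaling ended with status %s: %s", op.Status, op.LastError)
		}
		time.Sleep(5 * time.Second)
	}
}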
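
// --- Worked example (not part of the diff above) ---
// Standalone illustration of the wave-size and backoff formulas used by the
// controller: wave size is min(max(3, floor(totalNodes/10)), MaxWaveSize),
// capped by the replicas still needed; backoff grows geometrically with ±10%
// jitter at the default JitterPercentage of 0.2. The node count and failure
// counts are made up. Note that applyBackoff multiplies the previously stored
// delay, so it compounds faster than the textbook form recomputed here.
package main

import (
	"fmt"
	"math"
	"math/rand"
	"time"
)

func main() {
	// Wave size for a 40-node swarm scaling from 5 to 20 replicas.
	totalNodes, maxWaveSize := 40, 8
	current, target := 5, 20
	waveSize := int(math.Max(3, math.Floor(float64(totalNodes)/10))) // floor(40/10) = 4
	if waveSize > maxWaveSize {
		waveSize = maxWaveSize
	}
	if remaining := target - current; waveSize > remaining {
		waveSize = remaining
	}
	fmt.Println("wave size:", waveSize) // 4

	// Exponential backoff: delay = initial * multiplier^(failures-1), capped, then jittered.
	initialDelay, maxDelay := 30*time.Second, 2*time.Minute
	multiplier, jitterPct := 1.5, 0.2
	for failures := 1; failures <= 4; failures++ {
		backoff := time.Duration(float64(initialDelay) * math.Pow(multiplier, float64(failures-1)))
		if backoff > maxDelay {
			backoff = maxDelay
		}
		// rand.Float64()-0.5 is in [-0.5, 0.5), so the jitter is roughly ±10% of the delay.
		backoff += time.Duration(float64(backoff) * jitterPct * (rand.Float64() - 0.5))
		fmt.Printf("failure %d -> next wave in ~%v\n", failures, backoff.Round(time.Second))
	}
}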
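
// --- Usage sketch (not part of the diff above) ---
// Minimal sketch of the ScalingMetricsCollector lifecycle, based only on the
// method signatures added in scaling_metrics.go: start a wave, record health
// checks and join attempts, complete the wave, then summarise a 24-hour
// window. The replica IDs, bootstrap multiaddrs, and durations are invented
// for illustration, and the sketch assumes the tracing package has been
// initialised so tracing.Tracer is usable.
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/chorus-services/whoosh/internal/orchestrator"
)

func main() {
	ctx := context.Background()
	collector := orchestrator.NewScalingMetricsCollector(500)

	// Track one scaling wave for a CHORUS service.
	collector.StartWave(ctx, "scale-chorus-agent-1700000000", "chorus-agent", 12)
	collector.RecordHealthCheck("join_success_rate", true, "", map[string]interface{}{"rate": 0.92}, 80*time.Millisecond)
	collector.RecordJoinAttempt("replica-1", []string{"/dns4/peer0/tcp/9000"}, true, 4*time.Second, "")
	collector.RecordJoinAttempt("replica-2", []string{"/dns4/peer1/tcp/9000"}, false, 30*time.Second, "dial timeout")
	collector.CompleteWave(ctx, true, 11, "", 0)

	// Summarise the last 24 hours.
	report := collector.GenerateReport(ctx, time.Now().Add(-24*time.Hour), time.Now())
	fmt.Printf("waves=%d success_rate=%.2f avg_wave=%v\n",
		report.TotalOperations, report.SuccessRate, report.AverageWaveTime)
}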