package orchestrator

import (
	"context"
	"fmt"
	"math"
	"math/rand"
	"sync"
	"time"

	"github.com/rs/zerolog/log"
	"go.opentelemetry.io/otel/attribute"

	"github.com/chorus-services/whoosh/internal/tracing"
)

// ScalingController manages wave-based scaling operations for CHORUS services.
type ScalingController struct {
	mu               sync.RWMutex
	swarmManager     *SwarmManager
	healthGates      *HealthGates
	assignmentBroker *AssignmentBroker
	bootstrapManager *BootstrapPoolManager
	metricsCollector *ScalingMetricsCollector

	// Scaling configuration
	config ScalingConfig

	// Current scaling state
	currentOperations map[string]*ScalingOperation
	scalingMetrics    map[string]*ScalingMetrics // per-service wave metrics consumed by the health gates
	scalingActive     bool
	stopChan          chan struct{}
	ctx               context.Context
	cancel            context.CancelFunc
}

// ScalingConfig defines configuration for scaling operations.
type ScalingConfig struct {
	MinWaveSize      int           `json:"min_wave_size"`      // Minimum replicas per wave
	MaxWaveSize      int           `json:"max_wave_size"`      // Maximum replicas per wave
	WaveInterval     time.Duration `json:"wave_interval"`      // Time between waves
	MaxConcurrentOps int           `json:"max_concurrent_ops"` // Maximum concurrent scaling operations

	// Backoff configuration
	InitialBackoff    time.Duration `json:"initial_backoff"`    // Initial backoff delay
	MaxBackoff        time.Duration `json:"max_backoff"`        // Maximum backoff delay
	BackoffMultiplier float64       `json:"backoff_multiplier"` // Backoff multiplier
	JitterPercentage  float64       `json:"jitter_percentage"`  // Jitter percentage (0.0-1.0)

	// Health gate configuration
	HealthCheckTimeout time.Duration `json:"health_check_timeout"`  // Timeout for health checks
	MinJoinSuccessRate float64       `json:"min_join_success_rate"` // Minimum join success rate
	SuccessRateWindow  int           `json:"success_rate_window"`   // Window size for success rate calculation
}

// ScalingOperation represents an ongoing scaling operation.
type ScalingOperation struct {
	ID              string `json:"id"`
	ServiceName     string `json:"service_name"`
	CurrentReplicas int    `json:"current_replicas"`
	TargetReplicas  int    `json:"target_replicas"`

	// Wave state
	CurrentWave    int `json:"current_wave"`
	WavesCompleted int `json:"waves_completed"`
	WaveSize       int `json:"wave_size"`

	// Timing
	StartedAt           time.Time `json:"started_at"`
	LastWaveAt          time.Time `json:"last_wave_at,omitempty"`
	EstimatedCompletion time.Time `json:"estimated_completion,omitempty"`

	// Backoff state
	ConsecutiveFailures int           `json:"consecutive_failures"`
	NextWaveAt          time.Time     `json:"next_wave_at,omitempty"`
	BackoffDelay        time.Duration `json:"backoff_delay"`

	// Status
	Status    ScalingStatus `json:"status"`
	LastError string        `json:"last_error,omitempty"`

	// Configuration
	Template      string                 `json:"template"`
	ScalingParams map[string]interface{} `json:"scaling_params,omitempty"`
}

// ScalingStatus represents the status of a scaling operation.
type ScalingStatus string

const (
	ScalingStatusPending   ScalingStatus = "pending"
	ScalingStatusRunning   ScalingStatus = "running"
	ScalingStatusWaiting   ScalingStatus = "waiting" // Waiting for health gates
	ScalingStatusBackoff   ScalingStatus = "backoff" // In backoff period
	ScalingStatusCompleted ScalingStatus = "completed"
	ScalingStatusFailed    ScalingStatus = "failed"
	ScalingStatusCancelled ScalingStatus = "cancelled"
)

// ScalingRequest represents a request to scale a service.
type ScalingRequest struct {
	ServiceName    string                 `json:"service_name"`
	TargetReplicas int                    `json:"target_replicas"`
	Template       string                 `json:"template,omitempty"`
	ScalingParams  map[string]interface{} `json:"scaling_params,omitempty"`
	Force          bool                   `json:"force,omitempty"` // Skip health gates
}
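// exampleConfig is a hedged illustration, not part of the original source:
// it shows how the ScalingConfig fields above fit together. All values here
// are assumptions. With this config, a repeatedly failing service backs off
// 15s, 30s, 60s, then stays at the 60s cap (plus up to ±5% jitter, since
// applyBackoff spreads JitterPercentage across ±half its value).
var exampleConfig = ScalingConfig{
	MinWaveSize:        2,
	MaxWaveSize:        6,
	WaveInterval:       20 * time.Second,
	MaxConcurrentOps:   2,
	InitialBackoff:     15 * time.Second,
	MaxBackoff:         time.Minute,
	BackoffMultiplier:  2.0,
	JitterPercentage:   0.1,
	HealthCheckTimeout: 5 * time.Second,
	MinJoinSuccessRate: 0.9,
	SuccessRateWindow:  20,
}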
// WaveResult represents the result of a scaling wave.
type WaveResult struct {
	WaveNumber      int           `json:"wave_number"`
	RequestedCount  int           `json:"requested_count"`
	SuccessfulJoins int           `json:"successful_joins"`
	FailedJoins     int           `json:"failed_joins"`
	Duration        time.Duration `json:"duration"`
	CompletedAt     time.Time     `json:"completed_at"`
}

// NewScalingController creates a new scaling controller.
func NewScalingController(
	swarmManager *SwarmManager,
	healthGates *HealthGates,
	assignmentBroker *AssignmentBroker,
	bootstrapManager *BootstrapPoolManager,
	metricsCollector *ScalingMetricsCollector,
) *ScalingController {
	ctx, cancel := context.WithCancel(context.Background())

	return &ScalingController{
		swarmManager:     swarmManager,
		healthGates:      healthGates,
		assignmentBroker: assignmentBroker,
		bootstrapManager: bootstrapManager,
		metricsCollector: metricsCollector,
		config: ScalingConfig{
			MinWaveSize:        3,
			MaxWaveSize:        8,
			WaveInterval:       30 * time.Second,
			MaxConcurrentOps:   3,
			InitialBackoff:     30 * time.Second,
			MaxBackoff:         2 * time.Minute,
			BackoffMultiplier:  1.5,
			JitterPercentage:   0.2,
			HealthCheckTimeout: 10 * time.Second,
			MinJoinSuccessRate: 0.8,
			SuccessRateWindow:  10,
		},
		currentOperations: make(map[string]*ScalingOperation),
		scalingMetrics:    make(map[string]*ScalingMetrics),
		stopChan:          make(chan struct{}, 1),
		ctx:               ctx,
		cancel:            cancel,
	}
}

// StartScaling initiates a scaling operation and returns the operation ID.
// The waveSize argument is currently ignored: the wave size is derived from
// cluster size and config limits in calculateWaveSize.
func (sc *ScalingController) StartScaling(ctx context.Context, serviceName string, targetReplicas, waveSize int, template string) (string, error) {
	request := ScalingRequest{
		ServiceName:    serviceName,
		TargetReplicas: targetReplicas,
		Template:       template,
	}

	operation, err := sc.startScalingOperation(ctx, request)
	if err != nil {
		return "", err
	}
	return operation.ID, nil
}
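// ExampleStartScaling is a hedged usage sketch, not part of the original
// source: it shows how a caller would wire up the controller and kick off a
// wave-based scale-up. The service name "chorus-agent" and template
// "default" are assumptions; the collaborators are constructed elsewhere.
func ExampleStartScaling(ctx context.Context, sm *SwarmManager, hg *HealthGates, ab *AssignmentBroker, bp *BootstrapPoolManager, mc *ScalingMetricsCollector) error {
	sc := NewScalingController(sm, hg, ab, bp, mc)
	defer sc.Close()

	// Scale to 20 replicas; the waveSize argument (0 here) is ignored
	// because the controller derives the wave size itself.
	opID, err := sc.StartScaling(ctx, "chorus-agent", 20, 0, "default")
	if err != nil {
		return err
	}
	log.Info().Str("operation_id", opID).Msg("scaling started")
	return nil
}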
// startScalingOperation initiates a scaling operation.
func (sc *ScalingController) startScalingOperation(ctx context.Context, request ScalingRequest) (*ScalingOperation, error) {
	ctx, span := tracing.Tracer.Start(ctx, "scaling_controller.start_scaling")
	defer span.End()

	sc.mu.Lock()
	defer sc.mu.Unlock()

	// Reject if an operation is already in progress for this service.
	if existingOp, exists := sc.currentOperations[request.ServiceName]; exists {
		if existingOp.Status == ScalingStatusRunning || existingOp.Status == ScalingStatusWaiting {
			return nil, fmt.Errorf("scaling operation already in progress for service %s", request.ServiceName)
		}
	}

	// Enforce the concurrent operation limit.
	runningOps := 0
	for _, op := range sc.currentOperations {
		if op.Status == ScalingStatusRunning || op.Status == ScalingStatusWaiting {
			runningOps++
		}
	}
	if runningOps >= sc.config.MaxConcurrentOps {
		return nil, fmt.Errorf("maximum concurrent scaling operations (%d) reached", sc.config.MaxConcurrentOps)
	}

	// Get the current replica count.
	currentReplicas, err := sc.swarmManager.GetServiceReplicas(ctx, request.ServiceName)
	if err != nil {
		return nil, fmt.Errorf("failed to get current replica count: %w", err)
	}

	// Calculate the wave size.
	waveSize := sc.calculateWaveSize(currentReplicas, request.TargetReplicas)

	// Create the scaling operation.
	operation := &ScalingOperation{
		ID:              fmt.Sprintf("scale-%s-%d", request.ServiceName, time.Now().Unix()),
		ServiceName:     request.ServiceName,
		CurrentReplicas: currentReplicas,
		TargetReplicas:  request.TargetReplicas,
		CurrentWave:     1,
		WaveSize:        waveSize,
		StartedAt:       time.Now(),
		Status:          ScalingStatusPending,
		Template:        request.Template,
		ScalingParams:   request.ScalingParams,
		BackoffDelay:    sc.config.InitialBackoff,
	}

	sc.currentOperations[request.ServiceName] = operation

	// Start metrics tracking.
	if sc.metricsCollector != nil {
		sc.metricsCollector.StartWave(ctx, operation.ID, operation.ServiceName, operation.TargetReplicas)
	}

	// Run the scaling process in the background. Using sc.ctx (rather than
	// context.Background()) lets Close cancel in-flight operations.
	go sc.executeScaling(sc.ctx, operation, request.Force)

	span.SetAttributes(
		attribute.String("scaling.service_name", request.ServiceName),
		attribute.Int("scaling.current_replicas", currentReplicas),
		attribute.Int("scaling.target_replicas", request.TargetReplicas),
		attribute.Int("scaling.wave_size", waveSize),
		attribute.String("scaling.operation_id", operation.ID),
	)

	log.Info().
		Str("operation_id", operation.ID).
		Str("service_name", request.ServiceName).
		Int("current_replicas", currentReplicas).
		Int("target_replicas", request.TargetReplicas).
		Int("wave_size", waveSize).
		Msg("Started scaling operation")

	return operation, nil
}

// executeScaling executes the scaling operation with a wave-based approach.
func (sc *ScalingController) executeScaling(ctx context.Context, operation *ScalingOperation, force bool) {
	ctx, span := tracing.Tracer.Start(ctx, "scaling_controller.execute_scaling")
	defer span.End()

	defer func() {
		sc.mu.Lock()
		// Keep finished operations around for an hour so they stay visible
		// to monitoring before being cleaned up.
		if operation.Status == ScalingStatusCompleted || operation.Status == ScalingStatusFailed {
			go func() {
				time.Sleep(1 * time.Hour)
				sc.mu.Lock()
				delete(sc.currentOperations, operation.ServiceName)
				sc.mu.Unlock()
			}()
		}
		sc.mu.Unlock()
	}()

	operation.Status = ScalingStatusRunning

	for operation.CurrentReplicas < operation.TargetReplicas {
		// Honor any pending backoff before starting the next wave.
		if !operation.NextWaveAt.IsZero() && time.Now().Before(operation.NextWaveAt) {
			operation.Status = ScalingStatusBackoff
			waitTime := time.Until(operation.NextWaveAt)

			log.Info().
				Str("operation_id", operation.ID).
				Dur("wait_time", waitTime).
				Msg("Waiting for backoff period")

			select {
			case <-ctx.Done():
				operation.Status = ScalingStatusCancelled
				return
			case <-time.After(waitTime):
				// Continue after backoff.
			}
		}

		operation.Status = ScalingStatusRunning

		// Check health gates (unless forced).
		if !force {
			if err := sc.waitForHealthGates(ctx, operation); err != nil {
				operation.LastError = err.Error()
				operation.ConsecutiveFailures++
				sc.applyBackoff(operation)
				continue
			}
		}

		// Execute the scaling wave.
		waveResult, err := sc.executeWave(ctx, operation)
		if err != nil {
			log.Error().
				Str("operation_id", operation.ID).
				Err(err).
				Msg("Scaling wave failed")

			operation.LastError = err.Error()
			operation.ConsecutiveFailures++
			sc.applyBackoff(operation)
			continue
		}

		// Record wave metrics regardless of outcome.
		sc.updateScalingMetrics(operation.ServiceName, waveResult)

		// A wave that added nothing cannot make progress; treat it as a
		// failure so backoff applies instead of retrying every interval.
		if waveResult.SuccessfulJoins == 0 {
			operation.LastError = "wave completed with no successful joins"
			operation.ConsecutiveFailures++
			sc.applyBackoff(operation)
			continue
		}

		// Update operation state.
		operation.CurrentReplicas += waveResult.SuccessfulJoins
		operation.WavesCompleted++
		operation.LastWaveAt = time.Now()
		operation.ConsecutiveFailures = 0  // Reset on success
		operation.NextWaveAt = time.Time{} // Clear backoff

		log.Info().
			Str("operation_id", operation.ID).
			Int("wave", operation.CurrentWave).
			Int("successful_joins", waveResult.SuccessfulJoins).
			Int("failed_joins", waveResult.FailedJoins).
			Int("current_replicas", operation.CurrentReplicas).
			Int("target_replicas", operation.TargetReplicas).
			Msg("Scaling wave completed")

		// Move to the next wave.
		operation.CurrentWave++

		// Wait between waves.
		if operation.CurrentReplicas < operation.TargetReplicas {
			select {
			case <-ctx.Done():
				operation.Status = ScalingStatusCancelled
				return
			case <-time.After(sc.config.WaveInterval):
				// Continue to the next wave.
			}
		}
	}

	// Scaling completed successfully.
	operation.Status = ScalingStatusCompleted
	operation.EstimatedCompletion = time.Now()

	log.Info().
		Str("operation_id", operation.ID).
		Str("service_name", operation.ServiceName).
		Int("final_replicas", operation.CurrentReplicas).
		Int("waves_completed", operation.WavesCompleted).
		Dur("total_duration", time.Since(operation.StartedAt)).
		Msg("Scaling operation completed successfully")
}
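// estimatedWaves is a hedged helper sketch, not in the original file: it
// makes the loop arithmetic above concrete, assuming every wave lands fully.
// With current=5, target=20 and waveSize=3, ceil(15/3) = 5 waves are needed.
func estimatedWaves(current, target, waveSize int) int {
	remaining := target - current
	if remaining <= 0 || waveSize <= 0 {
		return 0
	}
	return (remaining + waveSize - 1) / waveSize // integer ceiling
}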
Str("operation_id", operation.ID). Str("service_name", operation.ServiceName). Int("final_replicas", operation.CurrentReplicas). Int("waves_completed", operation.WavesCompleted). Dur("total_duration", time.Since(operation.StartedAt)). Msg("Scaling operation completed successfully") } // waitForHealthGates waits for health gates to be satisfied func (sc *ScalingController) waitForHealthGates(ctx context.Context, operation *ScalingOperation) error { operation.Status = ScalingStatusWaiting ctx, cancel := context.WithTimeout(ctx, sc.config.HealthCheckTimeout) defer cancel() // Get recent scaling metrics for this service var recentMetrics *ScalingMetrics if metrics, exists := sc.scalingMetrics[operation.ServiceName]; exists { recentMetrics = metrics } healthStatus, err := sc.healthGates.CheckHealth(ctx, recentMetrics) if err != nil { return fmt.Errorf("health gate check failed: %w", err) } if !healthStatus.Healthy { return fmt.Errorf("health gates not satisfied: %s", healthStatus.OverallReason) } return nil } // executeWave executes a single scaling wave func (sc *ScalingController) executeWave(ctx context.Context, operation *ScalingOperation) (*WaveResult, error) { startTime := time.Now() // Calculate how many replicas to add in this wave remaining := operation.TargetReplicas - operation.CurrentReplicas waveSize := operation.WaveSize if remaining < waveSize { waveSize = remaining } // Create assignments for new replicas var assignments []*Assignment for i := 0; i < waveSize; i++ { assignReq := AssignmentRequest{ ClusterID: "production", // TODO: Make configurable Template: operation.Template, } assignment, err := sc.assignmentBroker.CreateAssignment(ctx, assignReq) if err != nil { return nil, fmt.Errorf("failed to create assignment: %w", err) } assignments = append(assignments, assignment) } // Deploy new replicas newReplicaCount := operation.CurrentReplicas + waveSize err := sc.swarmManager.ScaleService(ctx, operation.ServiceName, newReplicaCount) if err != nil { return nil, fmt.Errorf("failed to scale service: %w", err) } // Wait for replicas to come online and join successfully successfulJoins, failedJoins := sc.waitForReplicaJoins(ctx, operation.ServiceName, waveSize) result := &WaveResult{ WaveNumber: operation.CurrentWave, RequestedCount: waveSize, SuccessfulJoins: successfulJoins, FailedJoins: failedJoins, Duration: time.Since(startTime), CompletedAt: time.Now(), } return result, nil } // waitForReplicaJoins waits for new replicas to join the cluster func (sc *ScalingController) waitForReplicaJoins(ctx context.Context, serviceName string, expectedJoins int) (successful, failed int) { // Wait up to 2 minutes for replicas to join ctx, cancel := context.WithTimeout(ctx, 2*time.Minute) defer cancel() ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() startTime := time.Now() for { select { case <-ctx.Done(): // Timeout reached, return current counts return successful, expectedJoins - successful case <-ticker.C: // Check service status running, err := sc.swarmManager.GetRunningReplicas(ctx, serviceName) if err != nil { log.Warn().Err(err).Msg("Failed to get running replicas") continue } // For now, assume all running replicas are successful joins // In a real implementation, this would check P2P network membership if running >= expectedJoins { successful = expectedJoins failed = 0 return } // If we've been waiting too long with no progress, consider some failed if time.Since(startTime) > 90*time.Second { successful = running failed = expectedJoins - running return } } } } // 
// calculateWaveSize calculates the appropriate wave size for scaling.
// Formula: min(max(MinWaveSize, floor(total_nodes/10)), MaxWaveSize),
// further capped by the number of replicas still needed.
func (sc *ScalingController) calculateWaveSize(current, target int) int {
	totalNodes := 10 // TODO: Get actual node count from swarm

	waveSize := int(math.Max(float64(sc.config.MinWaveSize), math.Floor(float64(totalNodes)/10)))
	if waveSize > sc.config.MaxWaveSize {
		waveSize = sc.config.MaxWaveSize
	}

	// Don't exceed the remaining replicas needed.
	if remaining := target - current; waveSize > remaining {
		waveSize = remaining
	}
	return waveSize
}

// applyBackoff applies exponential backoff to the operation.
func (sc *ScalingController) applyBackoff(operation *ScalingOperation) {
	// Grow the delay exponentially from the configured initial backoff.
	// Computing from InitialBackoff (rather than the stored BackoffDelay)
	// avoids compounding the multiplier onto an already-grown delay.
	backoff := time.Duration(float64(sc.config.InitialBackoff) *
		math.Pow(sc.config.BackoffMultiplier, float64(operation.ConsecutiveFailures-1)))

	// Cap at the maximum backoff.
	if backoff > sc.config.MaxBackoff {
		backoff = sc.config.MaxBackoff
	}

	// Add jitter of up to ±(JitterPercentage/2) of the delay.
	jitter := time.Duration(float64(backoff) * sc.config.JitterPercentage * (rand.Float64() - 0.5))
	backoff += jitter

	operation.BackoffDelay = backoff
	operation.NextWaveAt = time.Now().Add(backoff)

	log.Warn().
		Str("operation_id", operation.ID).
		Int("consecutive_failures", operation.ConsecutiveFailures).
		Dur("backoff_delay", backoff).
		Time("next_wave_at", operation.NextWaveAt).
		Msg("Applied exponential backoff")
}
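// backoffSchedule is a hedged sketch, not in the original file: it tabulates
// the pre-jitter delays applyBackoff produces. With the defaults (30s
// initial, 1.5x multiplier, 2m cap) the sequence is 30s, 45s, 67.5s,
// 101.25s, then 120s thereafter.
func backoffSchedule(cfg ScalingConfig, failures int) []time.Duration {
	out := make([]time.Duration, 0, failures)
	for n := 1; n <= failures; n++ {
		d := time.Duration(float64(cfg.InitialBackoff) * math.Pow(cfg.BackoffMultiplier, float64(n-1)))
		if d > cfg.MaxBackoff {
			d = cfg.MaxBackoff
		}
		out = append(out, d)
	}
	return out
}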
// updateScalingMetrics updates per-service metrics used for join
// success-rate tracking by the health gates.
func (sc *ScalingController) updateScalingMetrics(serviceName string, result *WaveResult) {
	sc.mu.Lock()
	defer sc.mu.Unlock()

	metrics, exists := sc.scalingMetrics[serviceName]
	if !exists {
		metrics = &ScalingMetrics{
			LastWaveSize:      result.RequestedCount,
			LastWaveStarted:   result.CompletedAt.Add(-result.Duration),
			LastWaveCompleted: result.CompletedAt,
		}
		sc.scalingMetrics[serviceName] = metrics
	}

	metrics.LastWaveSize = result.RequestedCount
	metrics.LastWaveCompleted = result.CompletedAt
	metrics.SuccessfulJoins += result.SuccessfulJoins
	metrics.FailedJoins += result.FailedJoins

	// Recalculate the cumulative join success rate.
	if total := metrics.SuccessfulJoins + metrics.FailedJoins; total > 0 {
		metrics.JoinSuccessRate = float64(metrics.SuccessfulJoins) / float64(total)
	}
}

// GetOperation returns a scaling operation by service name.
func (sc *ScalingController) GetOperation(serviceName string) (*ScalingOperation, bool) {
	sc.mu.RLock()
	defer sc.mu.RUnlock()
	op, exists := sc.currentOperations[serviceName]
	return op, exists
}

// GetAllOperations returns all current scaling operations.
func (sc *ScalingController) GetAllOperations() map[string]*ScalingOperation {
	sc.mu.RLock()
	defer sc.mu.RUnlock()

	operations := make(map[string]*ScalingOperation, len(sc.currentOperations))
	for k, v := range sc.currentOperations {
		operations[k] = v
	}
	return operations
}

// CancelOperation cancels a scaling operation.
func (sc *ScalingController) CancelOperation(serviceName string) error {
	sc.mu.Lock()
	defer sc.mu.Unlock()

	operation, exists := sc.currentOperations[serviceName]
	if !exists {
		return fmt.Errorf("no scaling operation found for service %s", serviceName)
	}
	if operation.Status == ScalingStatusCompleted || operation.Status == ScalingStatusFailed {
		return fmt.Errorf("scaling operation already finished")
	}

	operation.Status = ScalingStatusCancelled
	log.Info().Str("operation_id", operation.ID).Msg("Scaling operation cancelled")

	// Complete metrics tracking.
	if sc.metricsCollector != nil {
		currentReplicas, _ := sc.swarmManager.GetServiceReplicas(context.Background(), serviceName)
		sc.metricsCollector.CompleteWave(context.Background(), false, currentReplicas, "Operation cancelled", operation.ConsecutiveFailures)
	}
	return nil
}

// StopScaling stops all active scaling operations.
func (sc *ScalingController) StopScaling(ctx context.Context) {
	ctx, span := tracing.Tracer.Start(ctx, "scaling_controller.stop_scaling")
	defer span.End()

	sc.mu.Lock()
	defer sc.mu.Unlock()

	cancelledCount := 0
	for serviceName, operation := range sc.currentOperations {
		if operation.Status == ScalingStatusRunning ||
			operation.Status == ScalingStatusWaiting ||
			operation.Status == ScalingStatusBackoff {
			operation.Status = ScalingStatusCancelled
			cancelledCount++

			// Complete metrics tracking for cancelled operations.
			if sc.metricsCollector != nil {
				currentReplicas, _ := sc.swarmManager.GetServiceReplicas(ctx, serviceName)
				sc.metricsCollector.CompleteWave(ctx, false, currentReplicas, "Scaling stopped", operation.ConsecutiveFailures)
			}

			log.Info().Str("operation_id", operation.ID).Str("service_name", serviceName).Msg("Scaling operation stopped")
		}
	}

	// Signal stop without blocking if nothing is listening.
	select {
	case sc.stopChan <- struct{}{}:
	default:
	}

	span.SetAttributes(attribute.Int("stopped_operations", cancelledCount))
	log.Info().Int("cancelled_operations", cancelledCount).Msg("Stopped all scaling operations")
}

// Close shuts down the scaling controller. StopScaling runs before the
// controller context is cancelled so cleanup calls still have a live context.
func (sc *ScalingController) Close() error {
	sc.StopScaling(sc.ctx)
	sc.cancel()
	return nil
}
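// monitorOperations is a hedged sketch, not part of the original source: it
// shows how a caller might poll GetAllOperations to surface progress. The
// 10s cadence is an assumption, and the returned pointers are mutated by the
// background scaling goroutine, so values read here are best-effort snapshots.
func monitorOperations(ctx context.Context, sc *ScalingController) {
	ticker := time.NewTicker(10 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			for svc, op := range sc.GetAllOperations() {
				log.Debug().
					Str("service", svc).
					Str("status", string(op.Status)).
					Int("current_replicas", op.CurrentReplicas).
					Int("target_replicas", op.TargetReplicas).
					Msg("scaling progress")
			}
		}
	}
}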