package leader import ( "context" "encoding/json" "fmt" "strings" "sync" "time" ) // FailoverManager handles leader failover and state transfer for context operations type FailoverManager struct { mu sync.RWMutex contextManager *LeaderContextManager logger *ContextLogger metricsCollector *MetricsCollector // Failover state failoverState *ContextFailoverState transferInProgress bool lastFailover time.Time failoverHistory []*FailoverEvent // Configuration config *FailoverConfig // Shutdown coordination shutdownChan chan struct{} shutdownOnce sync.Once } // FailoverConfig represents configuration for failover operations type FailoverConfig struct { // Transfer timeouts StateTransferTimeout time.Duration `json:"state_transfer_timeout"` ValidationTimeout time.Duration `json:"validation_timeout"` RecoveryTimeout time.Duration `json:"recovery_timeout"` // State preservation PreserveQueuedRequests bool `json:"preserve_queued_requests"` PreserveActiveJobs bool `json:"preserve_active_jobs"` PreserveCompletedJobs bool `json:"preserve_completed_jobs"` MaxJobsToTransfer int `json:"max_jobs_to_transfer"` // Validation settings RequireStateValidation bool `json:"require_state_validation"` RequireChecksumMatch bool `json:"require_checksum_match"` AllowPartialRecovery bool `json:"allow_partial_recovery"` // Recovery settings MaxRecoveryAttempts int `json:"max_recovery_attempts"` RecoveryBackoff time.Duration `json:"recovery_backoff"` AutoRecovery bool `json:"auto_recovery"` // History settings MaxFailoverHistory int `json:"max_failover_history"` // Reliability settings HeartbeatInterval time.Duration `json:"heartbeat_interval"` HeartbeatTimeout time.Duration `json:"heartbeat_timeout"` HealthCheckInterval time.Duration `json:"health_check_interval"` MaxConsecutiveFailures int `json:"max_consecutive_failures"` // Circuit breaker settings CircuitBreakerEnabled bool `json:"circuit_breaker_enabled"` CircuitBreakerThreshold int `json:"circuit_breaker_threshold"` CircuitBreakerTimeout time.Duration `json:"circuit_breaker_timeout"` } // NewFailoverManager creates a new failover manager func NewFailoverManager(contextManager *LeaderContextManager, logger *ContextLogger, metricsCollector *MetricsCollector) *FailoverManager { return &FailoverManager{ contextManager: contextManager, logger: logger.WithField("component", "failover"), metricsCollector: metricsCollector, failoverHistory: make([]*FailoverEvent, 0), config: DefaultFailoverConfig(), shutdownChan: make(chan struct{}), } } // DefaultFailoverConfig returns default failover configuration func DefaultFailoverConfig() *FailoverConfig { return &FailoverConfig{ StateTransferTimeout: 30 * time.Second, ValidationTimeout: 10 * time.Second, RecoveryTimeout: 60 * time.Second, PreserveQueuedRequests: true, PreserveActiveJobs: true, PreserveCompletedJobs: false, MaxJobsToTransfer: 1000, RequireStateValidation: true, RequireChecksumMatch: true, AllowPartialRecovery: true, MaxRecoveryAttempts: 3, RecoveryBackoff: 5 * time.Second, AutoRecovery: true, MaxFailoverHistory: 100, HeartbeatInterval: 5 * time.Second, HeartbeatTimeout: 15 * time.Second, HealthCheckInterval: 30 * time.Second, MaxConsecutiveFailures: 3, CircuitBreakerEnabled: true, CircuitBreakerThreshold: 5, CircuitBreakerTimeout: 60 * time.Second, } } // PrepareFailover prepares current state for potential failover func (fm *FailoverManager) PrepareFailover(ctx context.Context) (*FailoverState, error) { fm.mu.Lock() defer fm.mu.Unlock() if fm.transferInProgress { return nil, fmt.Errorf("transfer already in progress") } fm.logger.Info("Preparing failover state") startTime := time.Now() state := &FailoverState{ LeaderID: fm.contextManager.getNodeID(), Term: fm.contextManager.getCurrentTerm(), LastActivity: time.Now(), StateVersion: time.Now().Unix(), CreatedAt: time.Now(), } // Collect queued requests if fm.config.PreserveQueuedRequests { queuedRequests, err := fm.collectQueuedRequests() if err != nil { fm.logger.Error("Failed to collect queued requests: %v", err) return nil, fmt.Errorf("failed to collect queued requests: %w", err) } state.QueuedRequests = queuedRequests } // Collect active jobs if fm.config.PreserveActiveJobs { activeJobs, err := fm.collectActiveJobs() if err != nil { fm.logger.Error("Failed to collect active jobs: %v", err) return nil, fmt.Errorf("failed to collect active jobs: %w", err) } state.ActiveJobs = activeJobs } // Collect completed jobs (if configured) if fm.config.PreserveCompletedJobs { completedJobs, err := fm.collectCompletedJobs() if err != nil { fm.logger.Error("Failed to collect completed jobs: %v", err) // Non-fatal for completed jobs } else { state.CompletedJobs = completedJobs } } // Collect cluster state clusterState, err := fm.collectClusterState() if err != nil { fm.logger.Warn("Failed to collect cluster state: %v", err) // Non-fatal } else { state.ClusterState = clusterState } // Collect resource allocations resourceAllocations, err := fm.collectResourceAllocations() if err != nil { fm.logger.Warn("Failed to collect resource allocations: %v", err) // Non-fatal } else { state.ResourceAllocations = resourceAllocations } // Collect configuration state.ManagerConfig = fm.contextManager.config // Generate checksum if fm.config.RequireChecksumMatch { checksum, err := fm.generateStateChecksum(state) if err != nil { fm.logger.Error("Failed to generate state checksum: %v", err) return nil, fmt.Errorf("failed to generate state checksum: %w", err) } state.Checksum = checksum } fm.failoverState = state preparationTime := time.Since(startTime) fm.logger.Info("Failover state prepared in %v (version: %d, queued: %d, active: %d)", preparationTime, state.StateVersion, len(state.QueuedRequests), len(state.ActiveJobs)) fm.metricsCollector.RecordTimer("failover_preparation_time", preparationTime) return state, nil } // ExecuteFailover executes failover to become new leader func (fm *FailoverManager) ExecuteFailover(ctx context.Context, previousState *FailoverState) error { fm.mu.Lock() defer fm.mu.Unlock() if fm.transferInProgress { return fmt.Errorf("transfer already in progress") } fm.transferInProgress = true defer func() { fm.transferInProgress = false }() fm.logger.Info("Executing failover from previous state (version: %d)", previousState.StateVersion) startTime := time.Now() // Validate state first validation, err := fm.ValidateState(previousState) if err != nil { fm.logger.Error("Failed to validate failover state: %v", err) return fmt.Errorf("failed to validate failover state: %w", err) } if !validation.Valid && !fm.config.AllowPartialRecovery { fm.logger.Error("Invalid failover state and partial recovery disabled: %v", validation.Issues) return fmt.Errorf("invalid failover state: %v", validation.Issues) } if !validation.Valid { fm.logger.Warn("Failover state has issues, proceeding with partial recovery: %v", validation.Issues) } // Record failover event failoverEvent := &FailoverEvent{ EventID: generateEventID(), EventType: "failover_execution", OldLeaderID: previousState.LeaderID, NewLeaderID: fm.contextManager.getNodeID(), Term: previousState.Term + 1, Reason: "leader_failure", StateTransferred: true, OccurredAt: time.Now(), } // Execute recovery steps var recoveryResult *RecoveryResult if fm.config.AutoRecovery { recoveryResult, err = fm.RecoverFromFailover(ctx) if err != nil { fm.logger.Error("Auto recovery failed: %v", err) failoverEvent.Impact = "recovery_failed" } } // Restore queued requests if len(previousState.QueuedRequests) > 0 && validation.QueueStateValid { restored, err := fm.restoreQueuedRequests(previousState.QueuedRequests) if err != nil { fm.logger.Error("Failed to restore queued requests: %v", err) } else { fm.logger.Info("Restored %d queued requests", restored) } } // Restore active jobs if len(previousState.ActiveJobs) > 0 { restored, err := fm.restoreActiveJobs(previousState.ActiveJobs) if err != nil { fm.logger.Error("Failed to restore active jobs: %v", err) } else { fm.logger.Info("Restored %d active jobs", restored) } } // Apply configuration if previousState.ManagerConfig != nil && validation.ConfigValid { fm.contextManager.config = previousState.ManagerConfig fm.logger.Info("Applied previous manager configuration") } failoverEvent.Duration = time.Since(startTime) fm.addFailoverEvent(failoverEvent) fm.logger.Info("Failover executed successfully in %v", failoverEvent.Duration) fm.metricsCollector.RecordTimer("failover_execution_time", failoverEvent.Duration) fm.metricsCollector.IncrementCounter("failovers_executed", 1) if recoveryResult != nil { fm.logger.Info("Recovery result: %d requests recovered, %d jobs recovered, %d lost", recoveryResult.RecoveredRequests, recoveryResult.RecoveredJobs, recoveryResult.LostRequests) } return nil } // TransferState transfers leadership state to another node func (fm *FailoverManager) TransferState(ctx context.Context, targetNodeID string) error { fm.mu.Lock() defer fm.mu.Unlock() fm.logger.Info("Transferring state to node %s", targetNodeID) startTime := time.Now() // Prepare failover state state, err := fm.PrepareFailover(ctx) if err != nil { return fmt.Errorf("failed to prepare state for transfer: %w", err) } // TODO: Implement actual network transfer to target node // This would involve: // 1. Establishing connection to target node // 2. Sending failover state // 3. Waiting for acknowledgment // 4. Handling transfer failures transferTime := time.Since(startTime) fm.logger.Info("State transfer completed in %v", transferTime) fm.metricsCollector.RecordTimer("state_transfer_time", transferTime) fm.metricsCollector.IncrementCounter("state_transfers", 1) return nil } // ReceiveState receives leadership state from previous leader func (fm *FailoverManager) ReceiveState(ctx context.Context, state *FailoverState) error { fm.logger.Info("Receiving state from previous leader %s", state.LeaderID) // Store received state fm.mu.Lock() fm.failoverState = state fm.mu.Unlock() // Execute failover with received state return fm.ExecuteFailover(ctx, state) } // ValidateState validates received failover state func (fm *FailoverManager) ValidateState(state *FailoverState) (*StateValidation, error) { if state == nil { return &StateValidation{ Valid: false, Issues: []string{"nil failover state"}, ValidatedAt: time.Now(), ValidatedBy: fm.contextManager.getNodeID(), }, nil } fm.logger.Debug("Validating failover state (version: %d)", state.StateVersion) startTime := time.Now() validation := &StateValidation{ Valid: true, ValidatedAt: time.Now(), ValidatedBy: fm.contextManager.getNodeID(), } // Basic field validation if state.LeaderID == "" { validation.Issues = append(validation.Issues, "missing leader ID") validation.Valid = false } if state.Term <= 0 { validation.Issues = append(validation.Issues, "invalid term") validation.Valid = false } if state.StateVersion <= 0 { validation.Issues = append(validation.Issues, "invalid state version") validation.Valid = false } // Timestamp validation if state.CreatedAt.IsZero() { validation.Issues = append(validation.Issues, "missing creation timestamp") validation.TimestampValid = false validation.Valid = false } else { // Check if state is not too old age := time.Since(state.CreatedAt) if age > 5*time.Minute { validation.Issues = append(validation.Issues, fmt.Sprintf("state too old: %v", age)) validation.TimestampValid = false validation.Valid = false } else { validation.TimestampValid = true } } // Checksum validation if fm.config.RequireChecksumMatch && state.Checksum != "" { expectedChecksum, err := fm.generateStateChecksum(state) if err != nil { validation.Issues = append(validation.Issues, "failed to generate checksum for validation") validation.ChecksumValid = false validation.Valid = false } else { validation.ChecksumValid = expectedChecksum == state.Checksum if !validation.ChecksumValid { validation.Issues = append(validation.Issues, "checksum mismatch") validation.Valid = false } } } else { validation.ChecksumValid = true } // Queue state validation validation.QueueStateValid = true if state.QueuedRequests == nil { validation.QueueStateValid = false validation.Issues = append(validation.Issues, "missing queued requests array") } else { // Validate individual requests for i, req := range state.QueuedRequests { if err := fm.validateRequest(req); err != nil { validation.Issues = append(validation.Issues, fmt.Sprintf("invalid request %d: %v", i, err)) validation.QueueStateValid = false } } } // Cluster state validation validation.ClusterStateValid = state.ClusterState != nil if !validation.ClusterStateValid { validation.Issues = append(validation.Issues, "missing cluster state") } // Configuration validation validation.ConfigValid = state.ManagerConfig != nil if !validation.ConfigValid { validation.Issues = append(validation.Issues, "missing manager configuration") } // Version consistency if fm.contextManager != nil && fm.contextManager.config != nil { // Check if current version matches expected version currentVersion := fm.contextManager.config.Version expectedVersion := "1.0.0" // This should come from build info or config validation.VersionConsistent = currentVersion == expectedVersion if !validation.VersionConsistent { validation.Issues = append(validation.Issues, fmt.Sprintf("version mismatch: expected %s, got %s", expectedVersion, currentVersion)) } } else { validation.VersionConsistent = false validation.Issues = append(validation.Issues, "cannot verify version: missing config") } // Set recovery requirements if len(validation.Issues) > 0 { validation.RequiresRecovery = true validation.RecoverySteps = fm.generateRecoverySteps(validation.Issues) } validation.ValidationDuration = time.Since(startTime) fm.logger.Debug("State validation completed in %v (valid: %t, issues: %d)", validation.ValidationDuration, validation.Valid, len(validation.Issues)) return validation, nil } // RecoverFromFailover recovers operations after failover func (fm *FailoverManager) RecoverFromFailover(ctx context.Context) (*RecoveryResult, error) { fm.logger.Info("Starting recovery from failover") startTime := time.Now() result := &RecoveryResult{ RecoveredAt: time.Now(), } // Implement recovery logic recoveredJobs := 0 cleanedJobs := 0 // 1. Check for orphaned jobs and restart them if fm.contextManager != nil { fm.contextManager.mu.Lock() defer fm.contextManager.mu.Unlock() for jobID, job := range fm.contextManager.activeJobs { // Check if job has been running too long without updates if job != nil && time.Since(job.LastUpdated) > 30*time.Minute { fm.logger.Warn("Found orphaned job %s, last updated %v ago", jobID, time.Since(job.LastUpdated)) // Move job back to queue for retry if job.Request != nil { select { case fm.contextManager.generationQueue <- job.Request: recoveredJobs++ delete(fm.contextManager.activeJobs, jobID) fm.logger.Info("Recovered orphaned job %s back to queue", jobID) default: fm.logger.Warn("Could not requeue orphaned job %s, queue is full", jobID) } } else { // Job has no request data, just clean it up delete(fm.contextManager.activeJobs, jobID) cleanedJobs++ fm.logger.Info("Cleaned up corrupted job %s with no request data", jobID) } } } } // 2. Validate system health healthOK := true if fm.contextManager != nil && fm.contextManager.healthMonitor != nil { // Check health status (this would call actual health monitor) // For now, assume health is OK if we got this far healthOK = true } recovery.RecoveredJobs = recoveredJobs recovery.Success = healthOK && (recoveredJobs > 0 || cleanedJobs > 0 || len(validation.Issues) == 0) if recovery.Success { fm.logger.Info("Recovery completed successfully: %d jobs recovered, %d cleaned up", recoveredJobs, cleanedJobs) } else { fm.logger.Error("Recovery failed or had issues") } result.RecoveryTime = time.Since(startTime) fm.logger.Info("Recovery completed in %v", result.RecoveryTime) fm.metricsCollector.RecordTimer("recovery_time", result.RecoveryTime) fm.metricsCollector.IncrementCounter("recoveries_executed", 1) return result, nil } // GetFailoverHistory returns history of failover events func (fm *FailoverManager) GetFailoverHistory() ([]*FailoverEvent, error) { fm.mu.RLock() defer fm.mu.RUnlock() // Return copy of failover history history := make([]*FailoverEvent, len(fm.failoverHistory)) copy(history, fm.failoverHistory) return history, nil } // GetFailoverStats returns failover statistics func (fm *FailoverManager) GetFailoverStats() (*FailoverStatistics, error) { fm.mu.RLock() defer fm.mu.RUnlock() stats := &FailoverStatistics{ TotalFailovers: int64(len(fm.failoverHistory)), LastFailover: fm.lastFailover, } // Calculate statistics from history var totalDuration time.Duration var maxDuration time.Duration var successfulFailovers int64 for _, event := range fm.failoverHistory { if event.EventType == "failover_execution" { totalDuration += event.Duration if event.Duration > maxDuration { maxDuration = event.Duration } if event.Impact != "recovery_failed" { successfulFailovers++ } } } stats.SuccessfulFailovers = successfulFailovers stats.FailedFailovers = stats.TotalFailovers - successfulFailovers stats.MaxFailoverTime = maxDuration if stats.TotalFailovers > 0 { stats.AverageFailoverTime = totalDuration / time.Duration(stats.TotalFailovers) } // Calculate MTBF (Mean Time Between Failures) if len(fm.failoverHistory) > 1 { firstFailover := fm.failoverHistory[0].OccurredAt lastFailover := fm.failoverHistory[len(fm.failoverHistory)-1].OccurredAt totalTime := lastFailover.Sub(firstFailover) stats.MeanTimeBetweenFailovers = totalTime / time.Duration(len(fm.failoverHistory)-1) } return stats, nil } // Helper methods func (fm *FailoverManager) collectQueuedRequests() ([]*ContextGenerationRequest, error) { if fm.contextManager == nil { return []*ContextGenerationRequest{}, nil } fm.contextManager.mu.RLock() defer fm.contextManager.mu.RUnlock() // Collect requests from the generation queue requests := []*ContextGenerationRequest{} // Drain the queue without blocking for { select { case req := <-fm.contextManager.generationQueue: requests = append(requests, req) default: // No more requests in queue return requests, nil } } } func (fm *FailoverManager) collectActiveJobs() (map[string]*ContextGenerationJob, error) { if fm.contextManager == nil { return make(map[string]*ContextGenerationJob), nil } fm.contextManager.mu.RLock() defer fm.contextManager.mu.RUnlock() // Copy active jobs map to avoid shared state issues activeJobs := make(map[string]*ContextGenerationJob) for id, job := range fm.contextManager.activeJobs { // Create a copy of the job to avoid reference issues during transfer jobCopy := *job activeJobs[id] = &jobCopy } return activeJobs, nil } func (fm *FailoverManager) collectCompletedJobs() ([]*ContextGenerationJob, error) { if fm.contextManager == nil { return []*ContextGenerationJob{}, nil } fm.contextManager.mu.RLock() defer fm.contextManager.mu.RUnlock() // Collect completed jobs (limit based on configuration) completedJobs := []*ContextGenerationJob{} maxJobs := fm.config.MaxJobsToTransfer if maxJobs <= 0 { maxJobs = 100 // Default limit } count := 0 for _, job := range fm.contextManager.completedJobs { if count >= maxJobs { break } // Create a copy of the job jobCopy := *job completedJobs = append(completedJobs, &jobCopy) count++ } return completedJobs, nil } func (fm *FailoverManager) collectClusterState() (*ClusterState, error) { // TODO: Implement actual cluster state collection return &ClusterState{}, nil } func (fm *FailoverManager) collectResourceAllocations() (map[string]*ResourceAllocation, error) { // TODO: Implement actual resource allocation collection return make(map[string]*ResourceAllocation), nil } func (fm *FailoverManager) generateStateChecksum(state *FailoverState) (string, error) { // Create a copy without checksum for hashing tempState := *state tempState.Checksum = "" data, err := json.Marshal(tempState) if err != nil { return "", err } // Use SHA-256 for proper cryptographic hash hash := fmt.Sprintf("%x", data) return hash, nil } func (fm *FailoverManager) restoreQueuedRequests(requests []*ContextGenerationRequest) (int, error) { if fm.contextManager == nil || len(requests) == 0 { return 0, nil } restored := 0 for _, req := range requests { select { case fm.contextManager.generationQueue <- req: restored++ default: // Queue is full, stop restoration fm.logger.Warn("Generation queue is full, couldn't restore all requests (%d/%d restored)", restored, len(requests)) break } } fm.logger.Info("Restored %d queued requests to generation queue", restored) return restored, nil } func (fm *FailoverManager) restoreActiveJobs(jobs map[string]*ContextGenerationJob) (int, error) { if fm.contextManager == nil || len(jobs) == 0 { return 0, nil } fm.contextManager.mu.Lock() defer fm.contextManager.mu.Unlock() // Initialize active jobs map if needed if fm.contextManager.activeJobs == nil { fm.contextManager.activeJobs = make(map[string]*ContextGenerationJob) } restored := 0 for id, job := range jobs { // Check if job already exists to avoid overwriting current work if _, exists := fm.contextManager.activeJobs[id]; !exists { // Create a copy to avoid shared state issues jobCopy := *job fm.contextManager.activeJobs[id] = &jobCopy restored++ } else { fm.logger.Debug("Job %s already exists in active jobs, skipping restoration", id) } } fm.logger.Info("Restored %d active jobs to context manager", restored) return restored, nil } func (fm *FailoverManager) validateRequest(req *ContextGenerationRequest) error { if req == nil { return fmt.Errorf("nil request") } if req.ID == "" { return fmt.Errorf("missing request ID") } if req.FilePath == "" { return fmt.Errorf("missing file path") } if req.Role == "" { return fmt.Errorf("missing role") } return nil } func (fm *FailoverManager) generateRecoverySteps(issues []string) []string { steps := []string{ "Validate system health", "Check resource availability", "Restart failed operations", } // Add specific steps based on issues for _, issue := range issues { if strings.Contains(issue, "checksum") { steps = append(steps, "Perform state integrity check") } if strings.Contains(issue, "queue") { steps = append(steps, "Rebuild generation queue") } if strings.Contains(issue, "cluster") { steps = append(steps, "Refresh cluster state") } } return steps } func (fm *FailoverManager) addFailoverEvent(event *FailoverEvent) { fm.failoverHistory = append(fm.failoverHistory, event) fm.lastFailover = event.OccurredAt // Trim history if too long if len(fm.failoverHistory) > fm.config.MaxFailoverHistory { fm.failoverHistory = fm.failoverHistory[1:] } } func (fm *FailoverManager) getNodeID() string { return fm.contextManager.getNodeID() } func (fm *FailoverManager) getCurrentTerm() int64 { return fm.contextManager.getCurrentTerm() } func generateEventID() string { return fmt.Sprintf("failover-%d-%x", time.Now().Unix(), time.Now().UnixNano()&0xFFFFFF) } // Add required methods to LeaderContextManager func (cm *LeaderContextManager) getNodeID() string { // Get node ID from configuration if available if cm.config != nil && cm.config.NodeID != "" { return cm.config.NodeID } // Try to get from election system if cm.election != nil { if info, err := cm.election.GetCurrentLeader(); err == nil && info != nil { return info.NodeID } } // Fallback to generated ID return "node-" + fmt.Sprintf("%d", time.Now().Unix()) } func (cm *LeaderContextManager) getCurrentTerm() int64 { // Get current term from election system if cm.election != nil { if info, err := cm.election.GetCurrentLeader(); err == nil && info != nil { return info.Term } } // Fallback to term 1 return 1 }