package leader import ( "context" "encoding/json" "fmt" "strings" "sync" "time" ) // FailoverManager handles leader failover and state transfer for context operations type FailoverManager struct { mu sync.RWMutex contextManager *LeaderContextManager logger *ContextLogger metricsCollector *MetricsCollector // Failover state failoverState *ContextFailoverState transferInProgress bool lastFailover time.Time failoverHistory []*FailoverEvent // Configuration config *FailoverConfig // Shutdown coordination shutdownChan chan struct{} shutdownOnce sync.Once } // FailoverConfig represents configuration for failover operations type FailoverConfig struct { // Transfer timeouts StateTransferTimeout time.Duration `json:"state_transfer_timeout"` ValidationTimeout time.Duration `json:"validation_timeout"` RecoveryTimeout time.Duration `json:"recovery_timeout"` // State preservation PreserveQueuedRequests bool `json:"preserve_queued_requests"` PreserveActiveJobs bool `json:"preserve_active_jobs"` PreserveCompletedJobs bool `json:"preserve_completed_jobs"` MaxJobsToTransfer int `json:"max_jobs_to_transfer"` // Validation settings RequireStateValidation bool `json:"require_state_validation"` RequireChecksumMatch bool `json:"require_checksum_match"` AllowPartialRecovery bool `json:"allow_partial_recovery"` // Recovery settings MaxRecoveryAttempts int `json:"max_recovery_attempts"` RecoveryBackoff time.Duration `json:"recovery_backoff"` AutoRecovery bool `json:"auto_recovery"` // History settings MaxFailoverHistory int `json:"max_failover_history"` // Reliability settings HeartbeatInterval time.Duration `json:"heartbeat_interval"` HeartbeatTimeout time.Duration `json:"heartbeat_timeout"` HealthCheckInterval time.Duration `json:"health_check_interval"` MaxConsecutiveFailures int `json:"max_consecutive_failures"` // Circuit breaker settings CircuitBreakerEnabled bool `json:"circuit_breaker_enabled"` CircuitBreakerThreshold int `json:"circuit_breaker_threshold"` CircuitBreakerTimeout time.Duration `json:"circuit_breaker_timeout"` } // NewFailoverManager creates a new failover manager func NewFailoverManager(contextManager *LeaderContextManager, logger *ContextLogger, metricsCollector *MetricsCollector) *FailoverManager { return &FailoverManager{ contextManager: contextManager, logger: logger.WithField("component", "failover"), metricsCollector: metricsCollector, failoverHistory: make([]*FailoverEvent, 0), config: DefaultFailoverConfig(), shutdownChan: make(chan struct{}), } } // DefaultFailoverConfig returns default failover configuration func DefaultFailoverConfig() *FailoverConfig { return &FailoverConfig{ StateTransferTimeout: 30 * time.Second, ValidationTimeout: 10 * time.Second, RecoveryTimeout: 60 * time.Second, PreserveQueuedRequests: true, PreserveActiveJobs: true, PreserveCompletedJobs: false, MaxJobsToTransfer: 1000, RequireStateValidation: true, RequireChecksumMatch: true, AllowPartialRecovery: true, MaxRecoveryAttempts: 3, RecoveryBackoff: 5 * time.Second, AutoRecovery: true, MaxFailoverHistory: 100, HeartbeatInterval: 5 * time.Second, HeartbeatTimeout: 15 * time.Second, HealthCheckInterval: 30 * time.Second, MaxConsecutiveFailures: 3, CircuitBreakerEnabled: true, CircuitBreakerThreshold: 5, CircuitBreakerTimeout: 60 * time.Second, } } // PrepareFailover prepares current state for potential failover func (fm *FailoverManager) PrepareFailover(ctx context.Context) (*FailoverState, error) { fm.mu.Lock() defer fm.mu.Unlock() if fm.transferInProgress { return nil, fmt.Errorf("transfer already in progress") } fm.logger.Info("Preparing failover state") startTime := time.Now() state := &FailoverState{ LeaderID: fm.contextManager.getNodeID(), Term: fm.contextManager.getCurrentTerm(), LastActivity: time.Now(), StateVersion: time.Now().Unix(), CreatedAt: time.Now(), } // Collect queued requests if fm.config.PreserveQueuedRequests { queuedRequests, err := fm.collectQueuedRequests() if err != nil { fm.logger.Error("Failed to collect queued requests: %v", err) return nil, fmt.Errorf("failed to collect queued requests: %w", err) } state.QueuedRequests = queuedRequests } // Collect active jobs if fm.config.PreserveActiveJobs { activeJobs, err := fm.collectActiveJobs() if err != nil { fm.logger.Error("Failed to collect active jobs: %v", err) return nil, fmt.Errorf("failed to collect active jobs: %w", err) } state.ActiveJobs = activeJobs } // Collect completed jobs (if configured) if fm.config.PreserveCompletedJobs { completedJobs, err := fm.collectCompletedJobs() if err != nil { fm.logger.Error("Failed to collect completed jobs: %v", err) // Non-fatal for completed jobs } else { state.CompletedJobs = completedJobs } } // Collect cluster state clusterState, err := fm.collectClusterState() if err != nil { fm.logger.Warn("Failed to collect cluster state: %v", err) // Non-fatal } else { state.ClusterState = clusterState } // Collect resource allocations resourceAllocations, err := fm.collectResourceAllocations() if err != nil { fm.logger.Warn("Failed to collect resource allocations: %v", err) // Non-fatal } else { state.ResourceAllocations = resourceAllocations } // Collect configuration state.ManagerConfig = fm.contextManager.config // Generate checksum if fm.config.RequireChecksumMatch { checksum, err := fm.generateStateChecksum(state) if err != nil { fm.logger.Error("Failed to generate state checksum: %v", err) return nil, fmt.Errorf("failed to generate state checksum: %w", err) } state.Checksum = checksum } fm.failoverState = state preparationTime := time.Since(startTime) fm.logger.Info("Failover state prepared in %v (version: %d, queued: %d, active: %d)", preparationTime, state.StateVersion, len(state.QueuedRequests), len(state.ActiveJobs)) fm.metricsCollector.RecordTimer("failover_preparation_time", preparationTime) return state, nil } // ExecuteFailover executes failover to become new leader func (fm *FailoverManager) ExecuteFailover(ctx context.Context, previousState *FailoverState) error { fm.mu.Lock() defer fm.mu.Unlock() if fm.transferInProgress { return fmt.Errorf("transfer already in progress") } fm.transferInProgress = true defer func() { fm.transferInProgress = false }() fm.logger.Info("Executing failover from previous state (version: %d)", previousState.StateVersion) startTime := time.Now() // Validate state first validation, err := fm.ValidateState(previousState) if err != nil { fm.logger.Error("Failed to validate failover state: %v", err) return fmt.Errorf("failed to validate failover state: %w", err) } if !validation.Valid && !fm.config.AllowPartialRecovery { fm.logger.Error("Invalid failover state and partial recovery disabled: %v", validation.Issues) return fmt.Errorf("invalid failover state: %v", validation.Issues) } if !validation.Valid { fm.logger.Warn("Failover state has issues, proceeding with partial recovery: %v", validation.Issues) } // Record failover event failoverEvent := &FailoverEvent{ EventID: generateEventID(), EventType: "failover_execution", OldLeaderID: previousState.LeaderID, NewLeaderID: fm.contextManager.getNodeID(), Term: previousState.Term + 1, Reason: "leader_failure", StateTransferred: true, OccurredAt: time.Now(), } // Execute recovery steps var recoveryResult *RecoveryResult if fm.config.AutoRecovery { recoveryResult, err = fm.RecoverFromFailover(ctx) if err != nil { fm.logger.Error("Auto recovery failed: %v", err) failoverEvent.Impact = "recovery_failed" } } // Restore queued requests if len(previousState.QueuedRequests) > 0 && validation.QueueStateValid { restored, err := fm.restoreQueuedRequests(previousState.QueuedRequests) if err != nil { fm.logger.Error("Failed to restore queued requests: %v", err) } else { fm.logger.Info("Restored %d queued requests", restored) } } // Restore active jobs if len(previousState.ActiveJobs) > 0 { restored, err := fm.restoreActiveJobs(previousState.ActiveJobs) if err != nil { fm.logger.Error("Failed to restore active jobs: %v", err) } else { fm.logger.Info("Restored %d active jobs", restored) } } // Apply configuration if previousState.ManagerConfig != nil && validation.ConfigValid { fm.contextManager.config = previousState.ManagerConfig fm.logger.Info("Applied previous manager configuration") } failoverEvent.Duration = time.Since(startTime) fm.addFailoverEvent(failoverEvent) fm.logger.Info("Failover executed successfully in %v", failoverEvent.Duration) fm.metricsCollector.RecordTimer("failover_execution_time", failoverEvent.Duration) fm.metricsCollector.IncrementCounter("failovers_executed", 1) if recoveryResult != nil { fm.logger.Info("Recovery result: %d requests recovered, %d jobs recovered, %d lost", recoveryResult.RecoveredRequests, recoveryResult.RecoveredJobs, recoveryResult.LostRequests) } return nil } // TransferState transfers leadership state to another node func (fm *FailoverManager) TransferState(ctx context.Context, targetNodeID string) error { fm.mu.Lock() defer fm.mu.Unlock() fm.logger.Info("Transferring state to node %s", targetNodeID) startTime := time.Now() // Prepare failover state state, err := fm.PrepareFailover(ctx) if err != nil { return fmt.Errorf("failed to prepare state for transfer: %w", err) } // TODO: Implement actual network transfer to target node // This would involve: // 1. Establishing connection to target node // 2. Sending failover state // 3. Waiting for acknowledgment // 4. Handling transfer failures transferTime := time.Since(startTime) fm.logger.Info("State transfer completed in %v", transferTime) fm.metricsCollector.RecordTimer("state_transfer_time", transferTime) fm.metricsCollector.IncrementCounter("state_transfers", 1) return nil } // ReceiveState receives leadership state from previous leader func (fm *FailoverManager) ReceiveState(ctx context.Context, state *FailoverState) error { fm.logger.Info("Receiving state from previous leader %s", state.LeaderID) // Store received state fm.mu.Lock() fm.failoverState = state fm.mu.Unlock() // Execute failover with received state return fm.ExecuteFailover(ctx, state) } // ValidateState validates received failover state func (fm *FailoverManager) ValidateState(state *FailoverState) (*StateValidation, error) { if state == nil { return &StateValidation{ Valid: false, Issues: []string{"nil failover state"}, ValidatedAt: time.Now(), ValidatedBy: fm.contextManager.getNodeID(), }, nil } fm.logger.Debug("Validating failover state (version: %d)", state.StateVersion) startTime := time.Now() validation := &StateValidation{ Valid: true, ValidatedAt: time.Now(), ValidatedBy: fm.contextManager.getNodeID(), } // Basic field validation if state.LeaderID == "" { validation.Issues = append(validation.Issues, "missing leader ID") validation.Valid = false } if state.Term <= 0 { validation.Issues = append(validation.Issues, "invalid term") validation.Valid = false } if state.StateVersion <= 0 { validation.Issues = append(validation.Issues, "invalid state version") validation.Valid = false } // Timestamp validation if state.CreatedAt.IsZero() { validation.Issues = append(validation.Issues, "missing creation timestamp") validation.TimestampValid = false validation.Valid = false } else { // Check if state is not too old age := time.Since(state.CreatedAt) if age > 5*time.Minute { validation.Issues = append(validation.Issues, fmt.Sprintf("state too old: %v", age)) validation.TimestampValid = false validation.Valid = false } else { validation.TimestampValid = true } } // Checksum validation if fm.config.RequireChecksumMatch && state.Checksum != "" { expectedChecksum, err := fm.generateStateChecksum(state) if err != nil { validation.Issues = append(validation.Issues, "failed to generate checksum for validation") validation.ChecksumValid = false validation.Valid = false } else { validation.ChecksumValid = expectedChecksum == state.Checksum if !validation.ChecksumValid { validation.Issues = append(validation.Issues, "checksum mismatch") validation.Valid = false } } } else { validation.ChecksumValid = true } // Queue state validation validation.QueueStateValid = true if state.QueuedRequests == nil { validation.QueueStateValid = false validation.Issues = append(validation.Issues, "missing queued requests array") } else { // Validate individual requests for i, req := range state.QueuedRequests { if err := fm.validateRequest(req); err != nil { validation.Issues = append(validation.Issues, fmt.Sprintf("invalid request %d: %v", i, err)) validation.QueueStateValid = false } } } // Cluster state validation validation.ClusterStateValid = state.ClusterState != nil if !validation.ClusterStateValid { validation.Issues = append(validation.Issues, "missing cluster state") } // Configuration validation validation.ConfigValid = state.ManagerConfig != nil if !validation.ConfigValid { validation.Issues = append(validation.Issues, "missing manager configuration") } // Version consistency validation.VersionConsistent = true // TODO: Implement actual version checking // Set recovery requirements if len(validation.Issues) > 0 { validation.RequiresRecovery = true validation.RecoverySteps = fm.generateRecoverySteps(validation.Issues) } validation.ValidationDuration = time.Since(startTime) fm.logger.Debug("State validation completed in %v (valid: %t, issues: %d)", validation.ValidationDuration, validation.Valid, len(validation.Issues)) return validation, nil } // RecoverFromFailover recovers operations after failover func (fm *FailoverManager) RecoverFromFailover(ctx context.Context) (*RecoveryResult, error) { fm.logger.Info("Starting recovery from failover") startTime := time.Now() result := &RecoveryResult{ RecoveredAt: time.Now(), } // TODO: Implement actual recovery logic // This would involve: // 1. Checking for orphaned jobs // 2. Restarting failed operations // 3. Cleaning up inconsistent state // 4. Validating system health result.RecoveryTime = time.Since(startTime) fm.logger.Info("Recovery completed in %v", result.RecoveryTime) fm.metricsCollector.RecordTimer("recovery_time", result.RecoveryTime) fm.metricsCollector.IncrementCounter("recoveries_executed", 1) return result, nil } // GetFailoverHistory returns history of failover events func (fm *FailoverManager) GetFailoverHistory() ([]*FailoverEvent, error) { fm.mu.RLock() defer fm.mu.RUnlock() // Return copy of failover history history := make([]*FailoverEvent, len(fm.failoverHistory)) copy(history, fm.failoverHistory) return history, nil } // GetFailoverStats returns failover statistics func (fm *FailoverManager) GetFailoverStats() (*FailoverStatistics, error) { fm.mu.RLock() defer fm.mu.RUnlock() stats := &FailoverStatistics{ TotalFailovers: int64(len(fm.failoverHistory)), LastFailover: fm.lastFailover, } // Calculate statistics from history var totalDuration time.Duration var maxDuration time.Duration var successfulFailovers int64 for _, event := range fm.failoverHistory { if event.EventType == "failover_execution" { totalDuration += event.Duration if event.Duration > maxDuration { maxDuration = event.Duration } if event.Impact != "recovery_failed" { successfulFailovers++ } } } stats.SuccessfulFailovers = successfulFailovers stats.FailedFailovers = stats.TotalFailovers - successfulFailovers stats.MaxFailoverTime = maxDuration if stats.TotalFailovers > 0 { stats.AverageFailoverTime = totalDuration / time.Duration(stats.TotalFailovers) } // Calculate MTBF (Mean Time Between Failures) if len(fm.failoverHistory) > 1 { firstFailover := fm.failoverHistory[0].OccurredAt lastFailover := fm.failoverHistory[len(fm.failoverHistory)-1].OccurredAt totalTime := lastFailover.Sub(firstFailover) stats.MeanTimeBetweenFailovers = totalTime / time.Duration(len(fm.failoverHistory)-1) } return stats, nil } // Helper methods func (fm *FailoverManager) collectQueuedRequests() ([]*ContextGenerationRequest, error) { // TODO: Implement actual queue collection from context manager return []*ContextGenerationRequest{}, nil } func (fm *FailoverManager) collectActiveJobs() (map[string]*ContextGenerationJob, error) { // TODO: Implement actual active jobs collection from context manager return make(map[string]*ContextGenerationJob), nil } func (fm *FailoverManager) collectCompletedJobs() ([]*ContextGenerationJob, error) { // TODO: Implement actual completed jobs collection from context manager return []*ContextGenerationJob{}, nil } func (fm *FailoverManager) collectClusterState() (*ClusterState, error) { // TODO: Implement actual cluster state collection return &ClusterState{}, nil } func (fm *FailoverManager) collectResourceAllocations() (map[string]*ResourceAllocation, error) { // TODO: Implement actual resource allocation collection return make(map[string]*ResourceAllocation), nil } func (fm *FailoverManager) generateStateChecksum(state *FailoverState) (string, error) { // Create a copy without checksum for hashing tempState := *state tempState.Checksum = "" data, err := json.Marshal(tempState) if err != nil { return "", err } // TODO: Use proper cryptographic hash return fmt.Sprintf("%x", data[:32]), nil } func (fm *FailoverManager) restoreQueuedRequests(requests []*ContextGenerationRequest) (int, error) { // TODO: Implement actual queue restoration return len(requests), nil } func (fm *FailoverManager) restoreActiveJobs(jobs map[string]*ContextGenerationJob) (int, error) { // TODO: Implement actual active jobs restoration return len(jobs), nil } func (fm *FailoverManager) validateRequest(req *ContextGenerationRequest) error { if req == nil { return fmt.Errorf("nil request") } if req.ID == "" { return fmt.Errorf("missing request ID") } if req.FilePath == "" { return fmt.Errorf("missing file path") } if req.Role == "" { return fmt.Errorf("missing role") } return nil } func (fm *FailoverManager) generateRecoverySteps(issues []string) []string { steps := []string{ "Validate system health", "Check resource availability", "Restart failed operations", } // Add specific steps based on issues for _, issue := range issues { if strings.Contains(issue, "checksum") { steps = append(steps, "Perform state integrity check") } if strings.Contains(issue, "queue") { steps = append(steps, "Rebuild generation queue") } if strings.Contains(issue, "cluster") { steps = append(steps, "Refresh cluster state") } } return steps } func (fm *FailoverManager) addFailoverEvent(event *FailoverEvent) { fm.failoverHistory = append(fm.failoverHistory, event) fm.lastFailover = event.OccurredAt // Trim history if too long if len(fm.failoverHistory) > fm.config.MaxFailoverHistory { fm.failoverHistory = fm.failoverHistory[1:] } } func (fm *FailoverManager) getNodeID() string { return fm.contextManager.getNodeID() } func (fm *FailoverManager) getCurrentTerm() int64 { return fm.contextManager.getCurrentTerm() } func generateEventID() string { return fmt.Sprintf("failover-%d-%x", time.Now().Unix(), time.Now().UnixNano()&0xFFFFFF) } // Add required methods to LeaderContextManager func (cm *LeaderContextManager) getNodeID() string { // TODO: Get actual node ID from configuration or election system return "node-" + fmt.Sprintf("%d", time.Now().Unix()) } func (cm *LeaderContextManager) getCurrentTerm() int64 { // TODO: Get actual term from election system return 1 }