package security import ( "fmt" "net" "regexp" "strconv" "strings" "unicode" ) // ValidationError represents a security validation error type ValidationError struct { Field string Message string } func (e ValidationError) Error() string { return fmt.Sprintf("%s: %s", e.Field, e.Message) } // SecurityValidator provides zero-trust input validation type SecurityValidator struct { maxStringLength int maxIPLength int maxUsernameLength int maxPasswordLength int } // NewSecurityValidator creates a new validator with safe defaults func NewSecurityValidator() *SecurityValidator { return &SecurityValidator{ maxStringLength: 1024, // Maximum string length maxIPLength: 45, // IPv6 max length maxUsernameLength: 32, // Standard Unix username limit maxPasswordLength: 128, // Reasonable password limit } } // ValidateIP validates IP addresses with zero-trust approach func (v *SecurityValidator) ValidateIP(ip string) error { if ip == "" { return ValidationError{"ip", "IP address is required"} } if len(ip) > v.maxIPLength { return ValidationError{"ip", "IP address too long"} } // Check for dangerous characters that could be used in command injection if containsUnsafeChars(ip, []rune{'`', '$', '(', ')', ';', '&', '|', '<', '>', '\n', '\r'}) { return ValidationError{"ip", "IP address contains invalid characters"} } // Validate IP format if net.ParseIP(ip) == nil { return ValidationError{"ip", "Invalid IP address format"} } return nil } // ValidateUsername validates SSH usernames func (v *SecurityValidator) ValidateUsername(username string) error { if username == "" { return ValidationError{"username", "Username is required"} } if len(username) > v.maxUsernameLength { return ValidationError{"username", fmt.Sprintf("Username too long (max %d characters)", v.maxUsernameLength)} } // Check for command injection characters if containsUnsafeChars(username, []rune{'`', '$', '(', ')', ';', '&', '|', '<', '>', '\n', '\r', ' ', '"', '\'', '\\', '/'}) { return ValidationError{"username", "Username contains invalid characters"} } // Validate Unix username format (alphanumeric, underscore, dash, starting with letter/underscore) matched, err := regexp.MatchString("^[a-zA-Z_][a-zA-Z0-9_-]*$", username) if err != nil || !matched { return ValidationError{"username", "Username must start with letter/underscore and contain only alphanumeric characters, underscores, and dashes"} } return nil } // ValidatePassword validates SSH passwords func (v *SecurityValidator) ValidatePassword(password string) error { // Password can be empty if SSH keys are used if password == "" { return nil } if len(password) > v.maxPasswordLength { return ValidationError{"password", fmt.Sprintf("Password too long (max %d characters)", v.maxPasswordLength)} } // Check for shell metacharacters that could break command execution if containsUnsafeChars(password, []rune{'`', '$', '\n', '\r', '\'', ';', '|', '&'}) { return ValidationError{"password", "Password contains characters that could cause security issues"} } return nil } // ValidateSSHKey validates SSH private keys func (v *SecurityValidator) ValidateSSHKey(key string) error { // SSH key can be empty if password auth is used if key == "" { return nil } // Increased limit to accommodate large RSA keys (8192-bit RSA can be ~6.5KB) if len(key) > 16384 { // 16KB should handle even very large keys return ValidationError{"ssh_key", "SSH key too long (max 16KB)"} } // Check for basic SSH key format if strings.Contains(key, "-----BEGIN") { // Private key format - check for proper termination if !strings.Contains(key, "-----END") { return ValidationError{"ssh_key", "SSH private key appears malformed - missing END marker"} } // Check for common private key types validKeyTypes := []string{ "-----BEGIN RSA PRIVATE KEY-----", "-----BEGIN DSA PRIVATE KEY-----", "-----BEGIN EC PRIVATE KEY-----", "-----BEGIN OPENSSH PRIVATE KEY-----", "-----BEGIN PRIVATE KEY-----", // PKCS#8 format } hasValidType := false for _, keyType := range validKeyTypes { if strings.Contains(key, keyType) { hasValidType = true break } } if !hasValidType { return ValidationError{"ssh_key", "SSH private key type not recognized"} } } else if strings.HasPrefix(key, "ssh-") { // Public key format - shouldn't be used for private key field return ValidationError{"ssh_key", "Public key provided where private key expected"} } else { return ValidationError{"ssh_key", "SSH key format not recognized - must be PEM-encoded private key"} } // Check for suspicious content that could indicate injection attempts suspiciousPatterns := []string{ "$(", "`", ";", "|", "&", "<", ">", "\n\n\n", // command injection patterns } for _, pattern := range suspiciousPatterns { if strings.Contains(key, pattern) && !strings.Contains(pattern, "\n") { // newlines are normal in keys return ValidationError{"ssh_key", "SSH key contains suspicious content"} } } return nil } // ValidatePort validates port numbers func (v *SecurityValidator) ValidatePort(port int) error { if port <= 0 || port > 65535 { return ValidationError{"port", "Port must be between 1 and 65535"} } // Warn about privileged ports if port < 1024 && port != 22 && port != 80 && port != 443 { return ValidationError{"port", "Avoid using privileged ports (< 1024) unless necessary"} } return nil } // ValidateHostname validates hostnames func (v *SecurityValidator) ValidateHostname(hostname string) error { if hostname == "" { return ValidationError{"hostname", "Hostname is required"} } if len(hostname) > 253 { return ValidationError{"hostname", "Hostname too long (max 253 characters)"} } // Check for command injection characters if containsUnsafeChars(hostname, []rune{'`', '$', '(', ')', ';', '&', '|', '<', '>', '\n', '\r', ' ', '"', '\''}) { return ValidationError{"hostname", "Hostname contains invalid characters"} } // Validate hostname format (RFC 1123) matched, err := regexp.MatchString("^[a-zA-Z0-9]([a-zA-Z0-9\\-]{0,61}[a-zA-Z0-9])?(\\.([a-zA-Z0-9]([a-zA-Z0-9\\-]{0,61}[a-zA-Z0-9])?))*$", hostname) if err != nil || !matched { return ValidationError{"hostname", "Invalid hostname format"} } return nil } // ValidateClusterSecret validates cluster secrets func (v *SecurityValidator) ValidateClusterSecret(secret string) error { if secret == "" { return ValidationError{"cluster_secret", "Cluster secret is required"} } if len(secret) < 32 { return ValidationError{"cluster_secret", "Cluster secret too short (minimum 32 characters)"} } if len(secret) > 128 { return ValidationError{"cluster_secret", "Cluster secret too long (maximum 128 characters)"} } // Ensure it's hexadecimal (common for generated secrets) matched, err := regexp.MatchString("^[a-fA-F0-9]+$", secret) if err != nil || !matched { // If not hex, ensure it's at least alphanumeric if !isAlphanumeric(secret) { return ValidationError{"cluster_secret", "Cluster secret must be alphanumeric or hexadecimal"} } } return nil } // ValidateFilePath validates file paths func (v *SecurityValidator) ValidateFilePath(path string) error { if path == "" { return ValidationError{"file_path", "File path is required"} } if len(path) > 4096 { return ValidationError{"file_path", "File path too long"} } // Check for command injection and directory traversal if containsUnsafeChars(path, []rune{'`', '$', '(', ')', ';', '&', '|', '<', '>', '\n', '\r'}) { return ValidationError{"file_path", "File path contains unsafe characters"} } // Check for directory traversal attempts if strings.Contains(path, "..") { return ValidationError{"file_path", "Directory traversal detected in file path"} } // Ensure absolute paths if !strings.HasPrefix(path, "/") { return ValidationError{"file_path", "File path must be absolute"} } return nil } // SanitizeForCommand sanitizes strings for use in shell commands func (v *SecurityValidator) SanitizeForCommand(input string) string { // Remove dangerous characters and control characters result := strings.Map(func(r rune) rune { if r < 32 || r == 127 { return -1 // Remove control characters } switch r { case '`', '$', ';', '&', '|', '<', '>', '(', ')', '"', '\'', '\\', '*', '?', '[', ']', '{', '}': return -1 // Remove shell metacharacters and globbing chars } return r }, input) // Trim whitespace and collapse multiple spaces result = strings.TrimSpace(result) // Replace multiple spaces with single space for strings.Contains(result, " ") { result = strings.ReplaceAll(result, " ", " ") } return result } // Helper function to check for unsafe characters func containsUnsafeChars(s string, unsafeChars []rune) bool { for _, char := range s { for _, unsafe := range unsafeChars { if char == unsafe { return true } } } return false } // Helper function to check if string is alphanumeric func isAlphanumeric(s string) bool { for _, char := range s { if !unicode.IsLetter(char) && !unicode.IsDigit(char) { return false } } return true } // ValidateSSHConnectionRequest validates an SSH connection request func (v *SecurityValidator) ValidateSSHConnectionRequest(ip, username, password, sshKey string, port int) error { if err := v.ValidateIP(ip); err != nil { return err } if err := v.ValidateUsername(username); err != nil { return err } if err := v.ValidatePassword(password); err != nil { return err } if err := v.ValidateSSHKey(sshKey); err != nil { return err } if err := v.ValidatePort(port); err != nil { return err } // Ensure at least one authentication method is provided if password == "" && sshKey == "" { return ValidationError{"auth", "Either password or SSH key must be provided"} } return nil } // ValidatePortList validates a list of port numbers func (v *SecurityValidator) ValidatePortList(ports []string) error { if len(ports) > 50 { // Reasonable limit return ValidationError{"ports", "Too many ports specified (max 50)"} } for i, portStr := range ports { port, err := strconv.Atoi(portStr) if err != nil { return ValidationError{"ports", fmt.Sprintf("Port %d is not a valid number: %s", i+1, portStr)} } if err := v.ValidatePort(port); err != nil { return ValidationError{"ports", fmt.Sprintf("Port %d invalid: %s", i+1, err.Error())} } } return nil } // ValidateIPList validates a list of IP addresses func (v *SecurityValidator) ValidateIPList(ips []string) error { if len(ips) > 100 { // Reasonable limit return ValidationError{"ip_list", "Too many IPs specified (max 100)"} } for i, ip := range ips { if err := v.ValidateIP(ip); err != nil { return ValidationError{"ip_list", fmt.Sprintf("IP %d invalid: %s", i+1, err.Error())} } } return nil }