package policy

import (
	"net/http"
	"strings"

	"github.com/rs/zerolog/log"
)

// AuthMiddleware is HTTP middleware that enforces JWT bearer-token
// authentication on the handlers it wraps.
//
// When constructed with a nil Validator, enforcement is disabled and every
// request is passed straight through (a warning is logged per request).
type AuthMiddleware struct {
	validator          *Validator
	policyDenials      func() // Metrics callback invoked once per denied request; may be nil.
	enforcementEnabled bool
}

// NewAuthMiddleware creates a new authentication middleware.
//
// validator checks bearer tokens; passing nil disables enforcement entirely.
// policyDenials is invoked once for every request the middleware rejects
// (a metrics hook); it may be nil.
func NewAuthMiddleware(validator *Validator, policyDenials func()) *AuthMiddleware {
	return &AuthMiddleware{
		validator:          validator,
		policyDenials:      policyDenials,
		enforcementEnabled: validator != nil,
	}
}

// Wrap returns a handler that authenticates each request before delegating
// to next.
//
// The request must carry an "Authorization: Bearer <token>" header. The
// scheme is matched case-insensitively, and surrounding whitespace around the
// token is ignored. A missing header, a non-Bearer scheme, an empty token, or
// a token rejected by the validator results in 401 Unauthorized (with a
// WWW-Authenticate challenge), and next is not called.
func (m *AuthMiddleware) Wrap(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// If enforcement is disabled, pass through.
		if !m.enforcementEnabled {
			log.Warn().Msg("Policy enforcement disabled - allowing request")
			next.ServeHTTP(w, r)
			return
		}

		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			log.Error().Msg("Missing Authorization header")
			m.deny(w, "Unauthorized: missing authorization header")
			return
		}

		tokenString, ok := extractBearerToken(authHeader)
		if !ok {
			// Deliberately do not log the header value: it carries
			// credentials (a raw token, or Basic user:password).
			log.Error().Msg("Invalid Authorization header format")
			m.deny(w, "Unauthorized: invalid authorization format")
			return
		}

		claims, err := m.validator.ValidateToken(tokenString)
		if err != nil {
			log.Error().Err(err).Msg("Token validation failed")
			// NOTE(review): the validator's error text is echoed to the
			// client; confirm it never contains sensitive detail.
			m.deny(w, "Unauthorized: "+err.Error())
			return
		}

		log.Info().
			Str("subject", claims.Subject).
			Strs("scopes", claims.Scopes).
			Msg("Request authorized")

		// Token is valid, pass to next handler.
		next.ServeHTTP(w, r)
	})
}

// WrapFunc is the http.HandlerFunc counterpart of Wrap.
func (m *AuthMiddleware) WrapFunc(next http.HandlerFunc) http.HandlerFunc {
	// Build the wrapped handler once instead of on every request.
	return m.Wrap(next).ServeHTTP
}

// deny records a policy denial (if a metrics callback was supplied) and
// writes a 401 response with body msg. The WWW-Authenticate challenge is
// required on every 401 response by RFC 7235 §3.1.
func (m *AuthMiddleware) deny(w http.ResponseWriter, msg string) {
	if m.policyDenials != nil {
		m.policyDenials()
	}
	w.Header().Set("WWW-Authenticate", "Bearer")
	http.Error(w, msg, http.StatusUnauthorized)
}

// extractBearerToken returns the token from an Authorization header value of
// the form "Bearer <token>". The scheme is compared case-insensitively
// (auth schemes are case-insensitive per RFC 7235 §2.1) and whitespace around
// the token is trimmed. ok is false when the scheme is not Bearer or the
// token is empty.
func extractBearerToken(header string) (token string, ok bool) {
	parts := strings.SplitN(header, " ", 2)
	if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
		return "", false
	}
	token = strings.TrimSpace(parts[1])
	if token == "" {
		return "", false
	}
	return token, true
}