package policy import ( "context" "crypto/ed25519" "crypto/rsa" "encoding/base64" "encoding/json" "fmt" "io" "math/big" "net/http" "strings" "sync" "time" "github.com/golang-jwt/jwt/v5" "github.com/rs/zerolog/log" ) // Claims represents the JWT claims structure type Claims struct { Subject string `json:"sub"` Scopes []string `json:"scopes,omitempty"` Scope string `json:"scope,omitempty"` // Space-separated scopes jwt.RegisteredClaims } // JWKS represents a JSON Web Key Set type JWKS struct { Keys []JWK `json:"keys"` } // JWK represents a JSON Web Key type JWK struct { Kid string `json:"kid"` Kty string `json:"kty"` Alg string `json:"alg"` Use string `json:"use"` N string `json:"n"` E string `json:"e"` X string `json:"x"` Crv string `json:"crv"` } // Validator validates JWT tokens type Validator struct { jwksURL string requiredScope string httpClient *http.Client keys map[string]interface{} keysMutex sync.RWMutex lastFetch time.Time cacheDuration time.Duration } // NewValidator creates a new JWT validator func NewValidator(jwksURL, requiredScope string) *Validator { return &Validator{ jwksURL: jwksURL, requiredScope: requiredScope, httpClient: &http.Client{ Timeout: 10 * time.Second, }, keys: make(map[string]interface{}), cacheDuration: 1 * time.Hour, // Cache JWKS for 1 hour } } // ValidateToken validates a JWT token and checks required scopes func (v *Validator) ValidateToken(tokenString string) (*Claims, error) { // Parse token token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { // Get key ID from header kid, ok := token.Header["kid"].(string) if !ok { return nil, fmt.Errorf("no kid in token header") } // Get public key for this kid publicKey, err := v.getPublicKey(kid) if err != nil { return nil, fmt.Errorf("get public key: %w", err) } switch token.Method.(type) { case *jwt.SigningMethodRSA, *jwt.SigningMethodRSAPSS: rsaKey, ok := publicKey.(*rsa.PublicKey) if !ok { return nil, fmt.Errorf("expected RSA public key for kid %s", kid) } return rsaKey, nil case *jwt.SigningMethodEd25519: edKey, ok := publicKey.(ed25519.PublicKey) if !ok { return nil, fmt.Errorf("expected Ed25519 public key for kid %s", kid) } return edKey, nil default: return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } }) if err != nil { return nil, fmt.Errorf("parse token: %w", err) } // Extract claims claims, ok := token.Claims.(*Claims) if !ok || !token.Valid { return nil, fmt.Errorf("invalid token claims") } // Validate expiration if claims.ExpiresAt != nil && claims.ExpiresAt.Before(time.Now()) { return nil, fmt.Errorf("token expired") } // Validate not before if claims.NotBefore != nil && claims.NotBefore.After(time.Now()) { return nil, fmt.Errorf("token not yet valid") } // Check required scope if v.requiredScope != "" { if !v.hasRequiredScope(claims) { return nil, fmt.Errorf("missing required scope: %s", v.requiredScope) } } return claims, nil } // hasRequiredScope checks if claims contain the required scope func (v *Validator) hasRequiredScope(claims *Claims) bool { // Check scopes array for _, scope := range claims.Scopes { if scope == v.requiredScope { return true } } // Check space-separated scope string (OAuth2 style) if claims.Scope != "" { for _, scope := range parseScopes(claims.Scope) { if scope == v.requiredScope { return true } } } return false } // getPublicKey retrieves a public key by kid, fetching JWKS if needed func (v *Validator) getPublicKey(kid string) (interface{}, error) { // Check if cache is expired v.keysMutex.RLock() cacheExpired := time.Since(v.lastFetch) > v.cacheDuration key, keyExists := v.keys[kid] v.keysMutex.RUnlock() // If key exists and cache is not expired, return it if keyExists && !cacheExpired { return key, nil } // Need to fetch JWKS (either key not found or cache expired) if err := v.fetchJWKS(); err != nil { return nil, fmt.Errorf("fetch JWKS: %w", err) } // Try again after fetch v.keysMutex.RLock() defer v.keysMutex.RUnlock() if key, ok := v.keys[kid]; ok { return key, nil } return nil, fmt.Errorf("key not found: %s", kid) } // fetchJWKS fetches and caches the JWKS from the server func (v *Validator) fetchJWKS() error { log.Info().Str("url", v.jwksURL).Msg("Fetching JWKS") ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() req, err := http.NewRequestWithContext(ctx, "GET", v.jwksURL, nil) if err != nil { return fmt.Errorf("create request: %w", err) } resp, err := v.httpClient.Do(req) if err != nil { return fmt.Errorf("http request: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return fmt.Errorf("JWKS fetch failed: status %d", resp.StatusCode) } body, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("read response: %w", err) } var jwks JWKS if err := json.Unmarshal(body, &jwks); err != nil { return fmt.Errorf("unmarshal JWKS: %w", err) } // Parse and cache all keys newKeys := make(map[string]interface{}) for _, jwk := range jwks.Keys { switch jwk.Kty { case "RSA": publicKey, err := jwk.toRSAPublicKey() if err != nil { log.Error().Err(err).Str("kid", jwk.Kid).Msg("Failed to parse RSA JWK") continue } newKeys[jwk.Kid] = publicKey case "OKP": if strings.EqualFold(jwk.Crv, "Ed25519") { publicKey, err := jwk.toEd25519PublicKey() if err != nil { log.Error().Err(err).Str("kid", jwk.Kid).Msg("Failed to parse Ed25519 JWK") continue } newKeys[jwk.Kid] = publicKey } else { log.Warn().Str("kid", jwk.Kid).Str("crv", jwk.Crv).Msg("Skipping unsupported OKP curve") } default: log.Warn().Str("kid", jwk.Kid).Str("kty", jwk.Kty).Msg("Skipping unsupported key type") } } if len(newKeys) == 0 { return fmt.Errorf("no valid keys found in JWKS") } // Update cache v.keysMutex.Lock() v.keys = newKeys v.lastFetch = time.Now() v.keysMutex.Unlock() log.Info().Int("key_count", len(newKeys)).Msg("JWKS cached successfully") return nil } // toRSAPublicKey converts a JWK to an RSA public key func (jwk *JWK) toRSAPublicKey() (*rsa.PublicKey, error) { // Decode N (modulus) - use base64 URL encoding without padding nBytes, err := base64URLDecode(jwk.N) if err != nil { return nil, fmt.Errorf("decode N: %w", err) } // Decode E (exponent) eBytes, err := base64URLDecode(jwk.E) if err != nil { return nil, fmt.Errorf("decode E: %w", err) } // Convert E bytes to int var e int for _, b := range eBytes { e = e<<8 | int(b) } // Create RSA public key publicKey := &rsa.PublicKey{ N: new(big.Int).SetBytes(nBytes), E: e, } return publicKey, nil } // toEd25519PublicKey converts a JWK to an Ed25519 public key func (jwk *JWK) toEd25519PublicKey() (ed25519.PublicKey, error) { if jwk.X == "" { return nil, fmt.Errorf("missing x coordinate for Ed25519 key") } xBytes, err := base64URLDecode(jwk.X) if err != nil { return nil, fmt.Errorf("decode x: %w", err) } if len(xBytes) != ed25519.PublicKeySize { return nil, fmt.Errorf("invalid Ed25519 public key length: expected %d, got %d", ed25519.PublicKeySize, len(xBytes)) } return ed25519.PublicKey(xBytes), nil } // parseScopes splits a space-separated scope string func parseScopes(scopeString string) []string { if scopeString == "" { return nil } var scopes []string current := "" for _, ch := range scopeString { if ch == ' ' { if current != "" { scopes = append(scopes, current) current = "" } } else { current += string(ch) } } if current != "" { scopes = append(scopes, current) } return scopes } // RefreshJWKS forces a refresh of the JWKS cache func (v *Validator) RefreshJWKS() error { return v.fetchJWKS() } // GetCachedKeyCount returns the number of cached keys func (v *Validator) GetCachedKeyCount() int { v.keysMutex.RLock() defer v.keysMutex.RUnlock() return len(v.keys) } // base64URLDecode decodes a base64 URL-encoded string (with or without padding) func base64URLDecode(s string) ([]byte, error) { // Add padding if needed if l := len(s) % 4; l > 0 { s += strings.Repeat("=", 4-l) } return base64.URLEncoding.DecodeString(s) } // base64URLEncode encodes bytes to base64 URL encoding without padding func base64URLEncode(data []byte) string { return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=") }