package policy import ( "context" "crypto/rsa" "encoding/base64" "encoding/json" "fmt" "io" "math/big" "net/http" "strings" "sync" "time" "github.com/golang-jwt/jwt/v5" "github.com/rs/zerolog/log" ) // Claims represents the JWT claims structure type Claims struct { Subject string `json:"sub"` Scopes []string `json:"scopes,omitempty"` Scope string `json:"scope,omitempty"` // Space-separated scopes jwt.RegisteredClaims } // JWKS represents a JSON Web Key Set type JWKS struct { Keys []JWK `json:"keys"` } // JWK represents a JSON Web Key type JWK struct { Kid string `json:"kid"` Kty string `json:"kty"` Alg string `json:"alg"` Use string `json:"use"` N string `json:"n"` E string `json:"e"` } // Validator validates JWT tokens type Validator struct { jwksURL string requiredScope string httpClient *http.Client keys map[string]*rsa.PublicKey keysMutex sync.RWMutex lastFetch time.Time cacheDuration time.Duration } // NewValidator creates a new JWT validator func NewValidator(jwksURL, requiredScope string) *Validator { return &Validator{ jwksURL: jwksURL, requiredScope: requiredScope, httpClient: &http.Client{ Timeout: 10 * time.Second, }, keys: make(map[string]*rsa.PublicKey), cacheDuration: 1 * time.Hour, // Cache JWKS for 1 hour } } // ValidateToken validates a JWT token and checks required scopes func (v *Validator) ValidateToken(tokenString string) (*Claims, error) { // Parse token token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { // Verify signing algorithm if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } // Get key ID from header kid, ok := token.Header["kid"].(string) if !ok { return nil, fmt.Errorf("no kid in token header") } // Get public key for this kid publicKey, err := v.getPublicKey(kid) if err != nil { return nil, fmt.Errorf("get public key: %w", err) } return publicKey, nil }) if err != nil { return nil, fmt.Errorf("parse token: %w", err) } // Extract claims claims, ok := token.Claims.(*Claims) if !ok || !token.Valid { return nil, fmt.Errorf("invalid token claims") } // Validate expiration if claims.ExpiresAt != nil && claims.ExpiresAt.Before(time.Now()) { return nil, fmt.Errorf("token expired") } // Validate not before if claims.NotBefore != nil && claims.NotBefore.After(time.Now()) { return nil, fmt.Errorf("token not yet valid") } // Check required scope if v.requiredScope != "" { if !v.hasRequiredScope(claims) { return nil, fmt.Errorf("missing required scope: %s", v.requiredScope) } } return claims, nil } // hasRequiredScope checks if claims contain the required scope func (v *Validator) hasRequiredScope(claims *Claims) bool { // Check scopes array for _, scope := range claims.Scopes { if scope == v.requiredScope { return true } } // Check space-separated scope string (OAuth2 style) if claims.Scope != "" { for _, scope := range parseScopes(claims.Scope) { if scope == v.requiredScope { return true } } } return false } // getPublicKey retrieves a public key by kid, fetching JWKS if needed func (v *Validator) getPublicKey(kid string) (*rsa.PublicKey, error) { // Check if cache is expired v.keysMutex.RLock() cacheExpired := time.Since(v.lastFetch) > v.cacheDuration key, keyExists := v.keys[kid] v.keysMutex.RUnlock() // If key exists and cache is not expired, return it if keyExists && !cacheExpired { return key, nil } // Need to fetch JWKS (either key not found or cache expired) if err := v.fetchJWKS(); err != nil { return nil, fmt.Errorf("fetch JWKS: %w", err) } // Try again after fetch v.keysMutex.RLock() defer v.keysMutex.RUnlock() if key, ok := v.keys[kid]; ok { return key, nil } return nil, fmt.Errorf("key not found: %s", kid) } // fetchJWKS fetches and caches the JWKS from the server func (v *Validator) fetchJWKS() error { log.Info().Str("url", v.jwksURL).Msg("Fetching JWKS") ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() req, err := http.NewRequestWithContext(ctx, "GET", v.jwksURL, nil) if err != nil { return fmt.Errorf("create request: %w", err) } resp, err := v.httpClient.Do(req) if err != nil { return fmt.Errorf("http request: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return fmt.Errorf("JWKS fetch failed: status %d", resp.StatusCode) } body, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("read response: %w", err) } var jwks JWKS if err := json.Unmarshal(body, &jwks); err != nil { return fmt.Errorf("unmarshal JWKS: %w", err) } // Parse and cache all keys newKeys := make(map[string]*rsa.PublicKey) for _, jwk := range jwks.Keys { if jwk.Kty != "RSA" { log.Warn().Str("kid", jwk.Kid).Str("kty", jwk.Kty).Msg("Skipping non-RSA key") continue } publicKey, err := jwk.toRSAPublicKey() if err != nil { log.Error().Err(err).Str("kid", jwk.Kid).Msg("Failed to parse JWK") continue } newKeys[jwk.Kid] = publicKey } if len(newKeys) == 0 { return fmt.Errorf("no valid keys found in JWKS") } // Update cache v.keysMutex.Lock() v.keys = newKeys v.lastFetch = time.Now() v.keysMutex.Unlock() log.Info().Int("key_count", len(newKeys)).Msg("JWKS cached successfully") return nil } // toRSAPublicKey converts a JWK to an RSA public key func (jwk *JWK) toRSAPublicKey() (*rsa.PublicKey, error) { // Decode N (modulus) - use base64 URL encoding without padding nBytes, err := base64URLDecode(jwk.N) if err != nil { return nil, fmt.Errorf("decode N: %w", err) } // Decode E (exponent) eBytes, err := base64URLDecode(jwk.E) if err != nil { return nil, fmt.Errorf("decode E: %w", err) } // Convert E bytes to int var e int for _, b := range eBytes { e = e<<8 | int(b) } // Create RSA public key publicKey := &rsa.PublicKey{ N: new(big.Int).SetBytes(nBytes), E: e, } return publicKey, nil } // parseScopes splits a space-separated scope string func parseScopes(scopeString string) []string { if scopeString == "" { return nil } var scopes []string current := "" for _, ch := range scopeString { if ch == ' ' { if current != "" { scopes = append(scopes, current) current = "" } } else { current += string(ch) } } if current != "" { scopes = append(scopes, current) } return scopes } // RefreshJWKS forces a refresh of the JWKS cache func (v *Validator) RefreshJWKS() error { return v.fetchJWKS() } // GetCachedKeyCount returns the number of cached keys func (v *Validator) GetCachedKeyCount() int { v.keysMutex.RLock() defer v.keysMutex.RUnlock() return len(v.keys) } // base64URLDecode decodes a base64 URL-encoded string (with or without padding) func base64URLDecode(s string) ([]byte, error) { // Add padding if needed if l := len(s) % 4; l > 0 { s += strings.Repeat("=", 4-l) } return base64.URLEncoding.DecodeString(s) } // base64URLEncode encodes bytes to base64 URL encoding without padding func base64URLEncode(data []byte) string { return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=") }