355 lines
8.4 KiB
Go
355 lines
8.4 KiB
Go
package policy
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ed25519"
|
|
"crypto/rsa"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"math/big"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/rs/zerolog/log"
|
|
)
|
|
|
|
// Claims represents the JWT claims structure accepted by Validator.
//
// Scopes may be supplied either as a JSON array ("scopes") or as a single
// space-separated string ("scope", OAuth2 style); hasRequiredScope consults
// both.
//
// NOTE(review): jwt.RegisteredClaims appears to declare its own Subject field
// for "sub" as well. Because the Subject below sits at a shallower depth,
// encoding/json should decode "sub" into Claims.Subject, leaving the embedded
// RegisteredClaims.Subject (and the promoted GetSubject accessor) empty —
// confirm before relying on either.
type Claims struct {
	// Subject is the token's "sub" claim.
	Subject string `json:"sub"`

	// Scopes holds an array-form "scopes" claim, when present.
	Scopes []string `json:"scopes,omitempty"`

	// Scope holds a space-separated list of scopes ("scope" claim).
	Scope string `json:"scope,omitempty"`

	// RegisteredClaims supplies the standard claims, including the ExpiresAt
	// and NotBefore values that ValidateToken checks.
	jwt.RegisteredClaims
}
|
|
|
|
// JWKS represents a JSON Web Key Set: the document fetchJWKS downloads from
// the validator's JWKS URL, whose "keys" array is decoded into Keys.
type JWKS struct {
	// Keys lists every key published in the set, supported or not.
	Keys []JWK `json:"keys"`
}
|
|
|
|
// JWK represents a single JSON Web Key as published in a JWKS document.
//
// Only the members needed by this package are decoded. The RSA members (N, E)
// are consumed by toRSAPublicKey and the OKP members (X, Crv) by
// toEd25519PublicKey; all binary members are base64url-encoded strings.
type JWK struct {
	Kid string `json:"kid"` // key ID, matched against a token's "kid" header
	Kty string `json:"kty"` // key type; fetchJWKS handles "RSA" and "OKP"
	Alg string `json:"alg"` // algorithm declared for the key
	Use string `json:"use"` // intended use declared for the key
	N   string `json:"n"`   // RSA modulus (big-endian), base64url-encoded
	E   string `json:"e"`   // RSA public exponent (big-endian), base64url-encoded
	X   string `json:"x"`   // OKP public key bytes, base64url-encoded
	Crv string `json:"crv"` // OKP curve name; only Ed25519 is accepted
}
|
|
|
|
// Validator validates JWT tokens against public keys published at a JWKS URL
// and, optionally, requires every token to carry a specific scope. Construct
// one with NewValidator.
//
// The key cache (keys) and its timestamp (lastFetch) are read and replaced
// under keysMutex.
type Validator struct {
	jwksURL       string                 // URL the key set is downloaded from
	requiredScope string                 // scope every token must carry; "" disables the check
	httpClient    *http.Client           // client used for JWKS requests
	keys          map[string]interface{} // kid -> *rsa.PublicKey or ed25519.PublicKey
	keysMutex     sync.RWMutex           // guards keys and lastFetch
	lastFetch     time.Time              // time of the last successful fetch; zero before the first
	cacheDuration time.Duration          // how long cached keys are trusted before refetching
}
|
|
|
|
// NewValidator creates a new JWT validator
|
|
func NewValidator(jwksURL, requiredScope string) *Validator {
|
|
return &Validator{
|
|
jwksURL: jwksURL,
|
|
requiredScope: requiredScope,
|
|
httpClient: &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
},
|
|
keys: make(map[string]interface{}),
|
|
cacheDuration: 1 * time.Hour, // Cache JWKS for 1 hour
|
|
}
|
|
}
|
|
|
|
// ValidateToken validates a JWT token and checks required scopes
|
|
func (v *Validator) ValidateToken(tokenString string) (*Claims, error) {
|
|
// Parse token
|
|
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
|
// Get key ID from header
|
|
kid, ok := token.Header["kid"].(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("no kid in token header")
|
|
}
|
|
|
|
// Get public key for this kid
|
|
publicKey, err := v.getPublicKey(kid)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get public key: %w", err)
|
|
}
|
|
|
|
switch token.Method.(type) {
|
|
case *jwt.SigningMethodRSA, *jwt.SigningMethodRSAPSS:
|
|
rsaKey, ok := publicKey.(*rsa.PublicKey)
|
|
if !ok {
|
|
return nil, fmt.Errorf("expected RSA public key for kid %s", kid)
|
|
}
|
|
return rsaKey, nil
|
|
case *jwt.SigningMethodEd25519:
|
|
edKey, ok := publicKey.(ed25519.PublicKey)
|
|
if !ok {
|
|
return nil, fmt.Errorf("expected Ed25519 public key for kid %s", kid)
|
|
}
|
|
return edKey, nil
|
|
default:
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse token: %w", err)
|
|
}
|
|
|
|
// Extract claims
|
|
claims, ok := token.Claims.(*Claims)
|
|
if !ok || !token.Valid {
|
|
return nil, fmt.Errorf("invalid token claims")
|
|
}
|
|
|
|
// Validate expiration
|
|
if claims.ExpiresAt != nil && claims.ExpiresAt.Before(time.Now()) {
|
|
return nil, fmt.Errorf("token expired")
|
|
}
|
|
|
|
// Validate not before
|
|
if claims.NotBefore != nil && claims.NotBefore.After(time.Now()) {
|
|
return nil, fmt.Errorf("token not yet valid")
|
|
}
|
|
|
|
// Check required scope
|
|
if v.requiredScope != "" {
|
|
if !v.hasRequiredScope(claims) {
|
|
return nil, fmt.Errorf("missing required scope: %s", v.requiredScope)
|
|
}
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
// hasRequiredScope checks if claims contain the required scope
|
|
func (v *Validator) hasRequiredScope(claims *Claims) bool {
|
|
// Check scopes array
|
|
for _, scope := range claims.Scopes {
|
|
if scope == v.requiredScope {
|
|
return true
|
|
}
|
|
}
|
|
|
|
// Check space-separated scope string (OAuth2 style)
|
|
if claims.Scope != "" {
|
|
for _, scope := range parseScopes(claims.Scope) {
|
|
if scope == v.requiredScope {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// getPublicKey retrieves a public key by kid, fetching JWKS if needed
|
|
func (v *Validator) getPublicKey(kid string) (interface{}, error) {
|
|
// Check if cache is expired
|
|
v.keysMutex.RLock()
|
|
cacheExpired := time.Since(v.lastFetch) > v.cacheDuration
|
|
key, keyExists := v.keys[kid]
|
|
v.keysMutex.RUnlock()
|
|
|
|
// If key exists and cache is not expired, return it
|
|
if keyExists && !cacheExpired {
|
|
return key, nil
|
|
}
|
|
|
|
// Need to fetch JWKS (either key not found or cache expired)
|
|
if err := v.fetchJWKS(); err != nil {
|
|
return nil, fmt.Errorf("fetch JWKS: %w", err)
|
|
}
|
|
|
|
// Try again after fetch
|
|
v.keysMutex.RLock()
|
|
defer v.keysMutex.RUnlock()
|
|
|
|
if key, ok := v.keys[kid]; ok {
|
|
return key, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("key not found: %s", kid)
|
|
}
|
|
|
|
// fetchJWKS fetches and caches the JWKS from the server
|
|
func (v *Validator) fetchJWKS() error {
|
|
log.Info().Str("url", v.jwksURL).Msg("Fetching JWKS")
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", v.jwksURL, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("create request: %w", err)
|
|
}
|
|
|
|
resp, err := v.httpClient.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("http request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("JWKS fetch failed: status %d", resp.StatusCode)
|
|
}
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return fmt.Errorf("read response: %w", err)
|
|
}
|
|
|
|
var jwks JWKS
|
|
if err := json.Unmarshal(body, &jwks); err != nil {
|
|
return fmt.Errorf("unmarshal JWKS: %w", err)
|
|
}
|
|
|
|
// Parse and cache all keys
|
|
newKeys := make(map[string]interface{})
|
|
for _, jwk := range jwks.Keys {
|
|
switch jwk.Kty {
|
|
case "RSA":
|
|
publicKey, err := jwk.toRSAPublicKey()
|
|
if err != nil {
|
|
log.Error().Err(err).Str("kid", jwk.Kid).Msg("Failed to parse RSA JWK")
|
|
continue
|
|
}
|
|
newKeys[jwk.Kid] = publicKey
|
|
case "OKP":
|
|
if strings.EqualFold(jwk.Crv, "Ed25519") {
|
|
publicKey, err := jwk.toEd25519PublicKey()
|
|
if err != nil {
|
|
log.Error().Err(err).Str("kid", jwk.Kid).Msg("Failed to parse Ed25519 JWK")
|
|
continue
|
|
}
|
|
newKeys[jwk.Kid] = publicKey
|
|
} else {
|
|
log.Warn().Str("kid", jwk.Kid).Str("crv", jwk.Crv).Msg("Skipping unsupported OKP curve")
|
|
}
|
|
default:
|
|
log.Warn().Str("kid", jwk.Kid).Str("kty", jwk.Kty).Msg("Skipping unsupported key type")
|
|
}
|
|
}
|
|
|
|
if len(newKeys) == 0 {
|
|
return fmt.Errorf("no valid keys found in JWKS")
|
|
}
|
|
|
|
// Update cache
|
|
v.keysMutex.Lock()
|
|
v.keys = newKeys
|
|
v.lastFetch = time.Now()
|
|
v.keysMutex.Unlock()
|
|
|
|
log.Info().Int("key_count", len(newKeys)).Msg("JWKS cached successfully")
|
|
|
|
return nil
|
|
}
|
|
|
|
// toRSAPublicKey converts a JWK to an RSA public key
|
|
func (jwk *JWK) toRSAPublicKey() (*rsa.PublicKey, error) {
|
|
// Decode N (modulus) - use base64 URL encoding without padding
|
|
nBytes, err := base64URLDecode(jwk.N)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decode N: %w", err)
|
|
}
|
|
|
|
// Decode E (exponent)
|
|
eBytes, err := base64URLDecode(jwk.E)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decode E: %w", err)
|
|
}
|
|
|
|
// Convert E bytes to int
|
|
var e int
|
|
for _, b := range eBytes {
|
|
e = e<<8 | int(b)
|
|
}
|
|
|
|
// Create RSA public key
|
|
publicKey := &rsa.PublicKey{
|
|
N: new(big.Int).SetBytes(nBytes),
|
|
E: e,
|
|
}
|
|
|
|
return publicKey, nil
|
|
}
|
|
|
|
// toEd25519PublicKey converts a JWK to an Ed25519 public key
|
|
func (jwk *JWK) toEd25519PublicKey() (ed25519.PublicKey, error) {
|
|
if jwk.X == "" {
|
|
return nil, fmt.Errorf("missing x coordinate for Ed25519 key")
|
|
}
|
|
|
|
xBytes, err := base64URLDecode(jwk.X)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decode x: %w", err)
|
|
}
|
|
|
|
if len(xBytes) != ed25519.PublicKeySize {
|
|
return nil, fmt.Errorf("invalid Ed25519 public key length: expected %d, got %d", ed25519.PublicKeySize, len(xBytes))
|
|
}
|
|
|
|
return ed25519.PublicKey(xBytes), nil
|
|
}
|
|
|
|
// parseScopes splits an OAuth2-style scope string on ASCII spaces.
//
// Empty tokens produced by leading, trailing, or repeated spaces are dropped.
// Other whitespace (tabs, newlines) is not treated as a separator, and token
// bytes are returned unchanged, including any invalid UTF-8. It returns nil
// when the string contains no tokens.
func parseScopes(scopeString string) []string {
	scopes := strings.FieldsFunc(scopeString, func(r rune) bool { return r == ' ' })
	if len(scopes) == 0 {
		// Keep returning nil (not an empty slice) for empty or blank input.
		return nil
	}
	return scopes
}
|
|
|
|
// RefreshJWKS forces a refresh of the JWKS cache
|
|
func (v *Validator) RefreshJWKS() error {
|
|
return v.fetchJWKS()
|
|
}
|
|
|
|
// GetCachedKeyCount returns the number of cached keys
|
|
func (v *Validator) GetCachedKeyCount() int {
|
|
v.keysMutex.RLock()
|
|
defer v.keysMutex.RUnlock()
|
|
return len(v.keys)
|
|
}
|
|
|
|
// base64URLDecode decodes base64url input that may or may not carry "="
// padding. Unpadded input is padded to a multiple of four characters before
// being handed to the padded URL-safe decoder.
func base64URLDecode(s string) ([]byte, error) {
	// (4 - len%4) % 4 is 0 when already aligned, else the number of '=' needed.
	missing := (4 - len(s)%4) % 4
	padded := s + strings.Repeat("=", missing)
	return base64.URLEncoding.DecodeString(padded)
}
|
|
|
|
// base64URLEncode encodes data with the URL-safe base64 alphabet and no "="
// padding (the encoding JWK members use).
func base64URLEncode(data []byte) string {
	// RawURLEncoding omits padding natively, so no trimming pass is needed.
	return base64.RawURLEncoding.EncodeToString(data)
}
|