Add Sequential Thinking compatibility server and JWKS support
This commit is contained in:
@@ -2,6 +2,7 @@ package policy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
@@ -38,6 +39,8 @@ type JWK struct {
|
||||
Use string `json:"use"`
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
X string `json:"x"`
|
||||
Crv string `json:"crv"`
|
||||
}
|
||||
|
||||
// Validator validates JWT tokens
|
||||
@@ -45,7 +48,7 @@ type Validator struct {
|
||||
jwksURL string
|
||||
requiredScope string
|
||||
httpClient *http.Client
|
||||
keys map[string]*rsa.PublicKey
|
||||
keys map[string]interface{}
|
||||
keysMutex sync.RWMutex
|
||||
lastFetch time.Time
|
||||
cacheDuration time.Duration
|
||||
@@ -59,7 +62,7 @@ func NewValidator(jwksURL, requiredScope string) *Validator {
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
keys: make(map[string]*rsa.PublicKey),
|
||||
keys: make(map[string]interface{}),
|
||||
cacheDuration: 1 * time.Hour, // Cache JWKS for 1 hour
|
||||
}
|
||||
}
|
||||
@@ -68,11 +71,6 @@ func NewValidator(jwksURL, requiredScope string) *Validator {
|
||||
func (v *Validator) ValidateToken(tokenString string) (*Claims, error) {
|
||||
// Parse token
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
// Verify signing algorithm
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
|
||||
// Get key ID from header
|
||||
kid, ok := token.Header["kid"].(string)
|
||||
if !ok {
|
||||
@@ -85,7 +83,22 @@ func (v *Validator) ValidateToken(tokenString string) (*Claims, error) {
|
||||
return nil, fmt.Errorf("get public key: %w", err)
|
||||
}
|
||||
|
||||
return publicKey, nil
|
||||
switch token.Method.(type) {
|
||||
case *jwt.SigningMethodRSA, *jwt.SigningMethodRSAPSS:
|
||||
rsaKey, ok := publicKey.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected RSA public key for kid %s", kid)
|
||||
}
|
||||
return rsaKey, nil
|
||||
case *jwt.SigningMethodEd25519:
|
||||
edKey, ok := publicKey.(ed25519.PublicKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected Ed25519 public key for kid %s", kid)
|
||||
}
|
||||
return edKey, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
@@ -140,7 +153,7 @@ func (v *Validator) hasRequiredScope(claims *Claims) bool {
|
||||
}
|
||||
|
||||
// getPublicKey retrieves a public key by kid, fetching JWKS if needed
|
||||
func (v *Validator) getPublicKey(kid string) (*rsa.PublicKey, error) {
|
||||
func (v *Validator) getPublicKey(kid string) (interface{}, error) {
|
||||
// Check if cache is expired
|
||||
v.keysMutex.RLock()
|
||||
cacheExpired := time.Since(v.lastFetch) > v.cacheDuration
|
||||
@@ -201,20 +214,30 @@ func (v *Validator) fetchJWKS() error {
|
||||
}
|
||||
|
||||
// Parse and cache all keys
|
||||
newKeys := make(map[string]*rsa.PublicKey)
|
||||
newKeys := make(map[string]interface{})
|
||||
for _, jwk := range jwks.Keys {
|
||||
if jwk.Kty != "RSA" {
|
||||
log.Warn().Str("kid", jwk.Kid).Str("kty", jwk.Kty).Msg("Skipping non-RSA key")
|
||||
continue
|
||||
switch jwk.Kty {
|
||||
case "RSA":
|
||||
publicKey, err := jwk.toRSAPublicKey()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("kid", jwk.Kid).Msg("Failed to parse RSA JWK")
|
||||
continue
|
||||
}
|
||||
newKeys[jwk.Kid] = publicKey
|
||||
case "OKP":
|
||||
if strings.EqualFold(jwk.Crv, "Ed25519") {
|
||||
publicKey, err := jwk.toEd25519PublicKey()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("kid", jwk.Kid).Msg("Failed to parse Ed25519 JWK")
|
||||
continue
|
||||
}
|
||||
newKeys[jwk.Kid] = publicKey
|
||||
} else {
|
||||
log.Warn().Str("kid", jwk.Kid).Str("crv", jwk.Crv).Msg("Skipping unsupported OKP curve")
|
||||
}
|
||||
default:
|
||||
log.Warn().Str("kid", jwk.Kid).Str("kty", jwk.Kty).Msg("Skipping unsupported key type")
|
||||
}
|
||||
|
||||
publicKey, err := jwk.toRSAPublicKey()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("kid", jwk.Kid).Msg("Failed to parse JWK")
|
||||
continue
|
||||
}
|
||||
|
||||
newKeys[jwk.Kid] = publicKey
|
||||
}
|
||||
|
||||
if len(newKeys) == 0 {
|
||||
@@ -261,6 +284,24 @@ func (jwk *JWK) toRSAPublicKey() (*rsa.PublicKey, error) {
|
||||
return publicKey, nil
|
||||
}
|
||||
|
||||
// toEd25519PublicKey converts a JWK to an Ed25519 public key
|
||||
func (jwk *JWK) toEd25519PublicKey() (ed25519.PublicKey, error) {
|
||||
if jwk.X == "" {
|
||||
return nil, fmt.Errorf("missing x coordinate for Ed25519 key")
|
||||
}
|
||||
|
||||
xBytes, err := base64URLDecode(jwk.X)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode x: %w", err)
|
||||
}
|
||||
|
||||
if len(xBytes) != ed25519.PublicKeySize {
|
||||
return nil, fmt.Errorf("invalid Ed25519 public key length: expected %d, got %d", ed25519.PublicKeySize, len(xBytes))
|
||||
}
|
||||
|
||||
return ed25519.PublicKey(xBytes), nil
|
||||
}
|
||||
|
||||
// parseScopes splits a space-separated scope string
|
||||
func parseScopes(scopeString string) []string {
|
||||
if scopeString == "" {
|
||||
|
||||
Reference in New Issue
Block a user