diff --git a/go.mod b/go.mod index dc192e4..b4474c1 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/docker/go-connections v0.6.0 github.com/docker/go-units v0.5.0 github.com/go-redis/redis/v8 v8.11.5 + github.com/golang-jwt/jwt/v5 v5.3.0 github.com/google/uuid v1.6.0 github.com/gorilla/mux v1.8.1 github.com/gorilla/websocket v1.5.0 diff --git a/go.sum b/go.sum index 60986e7..0384137 100644 --- a/go.sum +++ b/go.sum @@ -147,6 +147,8 @@ github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7a github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= diff --git a/pkg/seqthink/policy/jwt.go b/pkg/seqthink/policy/jwt.go new file mode 100644 index 0000000..025c3eb --- /dev/null +++ b/pkg/seqthink/policy/jwt.go @@ -0,0 +1,313 @@ +package policy + +import ( + "context" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math/big" + "net/http" + "strings" + "sync" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/rs/zerolog/log" +) + +// Claims represents the JWT claims structure +type Claims struct { + Subject string `json:"sub"` + Scopes []string `json:"scopes,omitempty"` + Scope string `json:"scope,omitempty"` // Space-separated scopes + jwt.RegisteredClaims +} + +// JWKS represents a JSON Web Key Set +type JWKS struct { + Keys []JWK `json:"keys"` +} + +// JWK represents a JSON Web Key +type JWK struct { + Kid string `json:"kid"` + Kty string `json:"kty"` + Alg string `json:"alg"` + Use string `json:"use"` + N string `json:"n"` + E string `json:"e"` +} + +// Validator validates JWT tokens +type Validator struct { + jwksURL string + requiredScope string + httpClient *http.Client + keys map[string]*rsa.PublicKey + keysMutex sync.RWMutex + lastFetch time.Time + cacheDuration time.Duration +} + +// NewValidator creates a new JWT validator +func NewValidator(jwksURL, requiredScope string) *Validator { + return &Validator{ + jwksURL: jwksURL, + requiredScope: requiredScope, + httpClient: &http.Client{ + Timeout: 10 * time.Second, + }, + keys: make(map[string]*rsa.PublicKey), + cacheDuration: 1 * time.Hour, // Cache JWKS for 1 hour + } +} + +// ValidateToken validates a JWT token and checks required scopes +func (v *Validator) ValidateToken(tokenString string) (*Claims, error) { + // Parse token + token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { + // Verify signing algorithm + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + // Get key ID from header + kid, ok := token.Header["kid"].(string) + if !ok { + return nil, fmt.Errorf("no kid in token header") + } + + // Get public key for this kid + publicKey, err := v.getPublicKey(kid) + if err != nil { + return nil, fmt.Errorf("get public key: %w", err) + } + + return publicKey, nil + }) + + if err != nil { + return nil, fmt.Errorf("parse token: %w", err) + } + + // Extract claims + claims, ok := token.Claims.(*Claims) + if !ok || !token.Valid { + return nil, fmt.Errorf("invalid token claims") + } + + // Validate expiration + if claims.ExpiresAt != nil && claims.ExpiresAt.Before(time.Now()) { + return nil, fmt.Errorf("token expired") + } + + // Validate not before + if claims.NotBefore != nil && claims.NotBefore.After(time.Now()) { + return nil, fmt.Errorf("token not yet valid") + } + + // Check required scope + if v.requiredScope != "" { + if !v.hasRequiredScope(claims) { + return nil, fmt.Errorf("missing required scope: %s", v.requiredScope) + } + } + + return claims, nil +} + +// hasRequiredScope checks if claims contain the required scope +func (v *Validator) hasRequiredScope(claims *Claims) bool { + // Check scopes array + for _, scope := range claims.Scopes { + if scope == v.requiredScope { + return true + } + } + + // Check space-separated scope string (OAuth2 style) + if claims.Scope != "" { + for _, scope := range parseScopes(claims.Scope) { + if scope == v.requiredScope { + return true + } + } + } + + return false +} + +// getPublicKey retrieves a public key by kid, fetching JWKS if needed +func (v *Validator) getPublicKey(kid string) (*rsa.PublicKey, error) { + // Check if cache is expired + v.keysMutex.RLock() + cacheExpired := time.Since(v.lastFetch) > v.cacheDuration + key, keyExists := v.keys[kid] + v.keysMutex.RUnlock() + + // If key exists and cache is not expired, return it + if keyExists && !cacheExpired { + return key, nil + } + + // Need to fetch JWKS (either key not found or cache expired) + if err := v.fetchJWKS(); err != nil { + return nil, fmt.Errorf("fetch JWKS: %w", err) + } + + // Try again after fetch + v.keysMutex.RLock() + defer v.keysMutex.RUnlock() + + if key, ok := v.keys[kid]; ok { + return key, nil + } + + return nil, fmt.Errorf("key not found: %s", kid) +} + +// fetchJWKS fetches and caches the JWKS from the server +func (v *Validator) fetchJWKS() error { + log.Info().Str("url", v.jwksURL).Msg("Fetching JWKS") + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", v.jwksURL, nil) + if err != nil { + return fmt.Errorf("create request: %w", err) + } + + resp, err := v.httpClient.Do(req) + if err != nil { + return fmt.Errorf("http request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("JWKS fetch failed: status %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("read response: %w", err) + } + + var jwks JWKS + if err := json.Unmarshal(body, &jwks); err != nil { + return fmt.Errorf("unmarshal JWKS: %w", err) + } + + // Parse and cache all keys + newKeys := make(map[string]*rsa.PublicKey) + for _, jwk := range jwks.Keys { + if jwk.Kty != "RSA" { + log.Warn().Str("kid", jwk.Kid).Str("kty", jwk.Kty).Msg("Skipping non-RSA key") + continue + } + + publicKey, err := jwk.toRSAPublicKey() + if err != nil { + log.Error().Err(err).Str("kid", jwk.Kid).Msg("Failed to parse JWK") + continue + } + + newKeys[jwk.Kid] = publicKey + } + + if len(newKeys) == 0 { + return fmt.Errorf("no valid keys found in JWKS") + } + + // Update cache + v.keysMutex.Lock() + v.keys = newKeys + v.lastFetch = time.Now() + v.keysMutex.Unlock() + + log.Info().Int("key_count", len(newKeys)).Msg("JWKS cached successfully") + + return nil +} + +// toRSAPublicKey converts a JWK to an RSA public key +func (jwk *JWK) toRSAPublicKey() (*rsa.PublicKey, error) { + // Decode N (modulus) - use base64 URL encoding without padding + nBytes, err := base64URLDecode(jwk.N) + if err != nil { + return nil, fmt.Errorf("decode N: %w", err) + } + + // Decode E (exponent) + eBytes, err := base64URLDecode(jwk.E) + if err != nil { + return nil, fmt.Errorf("decode E: %w", err) + } + + // Convert E bytes to int + var e int + for _, b := range eBytes { + e = e<<8 | int(b) + } + + // Create RSA public key + publicKey := &rsa.PublicKey{ + N: new(big.Int).SetBytes(nBytes), + E: e, + } + + return publicKey, nil +} + +// parseScopes splits a space-separated scope string +func parseScopes(scopeString string) []string { + if scopeString == "" { + return nil + } + + var scopes []string + current := "" + for _, ch := range scopeString { + if ch == ' ' { + if current != "" { + scopes = append(scopes, current) + current = "" + } + } else { + current += string(ch) + } + } + if current != "" { + scopes = append(scopes, current) + } + + return scopes +} + +// RefreshJWKS forces a refresh of the JWKS cache +func (v *Validator) RefreshJWKS() error { + return v.fetchJWKS() +} + +// GetCachedKeyCount returns the number of cached keys +func (v *Validator) GetCachedKeyCount() int { + v.keysMutex.RLock() + defer v.keysMutex.RUnlock() + return len(v.keys) +} + +// base64URLDecode decodes a base64 URL-encoded string (with or without padding) +func base64URLDecode(s string) ([]byte, error) { + // Add padding if needed + if l := len(s) % 4; l > 0 { + s += strings.Repeat("=", 4-l) + } + return base64.URLEncoding.DecodeString(s) +} + +// base64URLEncode encodes bytes to base64 URL encoding without padding +func base64URLEncode(data []byte) string { + return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=") +} diff --git a/pkg/seqthink/policy/jwt_test.go b/pkg/seqthink/policy/jwt_test.go new file mode 100644 index 0000000..ed34b3f --- /dev/null +++ b/pkg/seqthink/policy/jwt_test.go @@ -0,0 +1,354 @@ +package policy + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +// generateTestKeyPair generates an RSA key pair for testing +func generateTestKeyPair() (*rsa.PrivateKey, error) { + return rsa.GenerateKey(rand.Reader, 2048) +} + +// createTestJWKS creates a test JWKS server +func createTestJWKS(t *testing.T, privateKey *rsa.PrivateKey) *httptest.Server { + publicKey := &privateKey.PublicKey + + // Create JWK from public key + jwk := JWK{ + Kid: "test-key-1", + Kty: "RSA", + Alg: "RS256", + Use: "sig", + N: base64URLEncode(publicKey.N.Bytes()), + E: base64URLEncode([]byte{1, 0, 1}), // 65537 + } + + jwks := JWKS{ + Keys: []JWK{jwk}, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(jwks) + })) + + return server +} + +// createTestToken creates a test JWT token +func createTestToken(privateKey *rsa.PrivateKey, claims *Claims) (string, error) { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = "test-key-1" + return token.SignedString(privateKey) +} + +func TestValidateToken(t *testing.T) { + // Generate test key pair + privateKey, err := generateTestKeyPair() + if err != nil { + t.Fatalf("generate key pair: %v", err) + } + + // Create test JWKS server + jwksServer := createTestJWKS(t, privateKey) + defer jwksServer.Close() + + // Create validator + validator := NewValidator(jwksServer.URL, "sequentialthinking.run") + + // Test valid token + t.Run("valid_token", func(t *testing.T) { + claims := &Claims{ + Subject: "test-user", + Scopes: []string{"sequentialthinking.run"}, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + } + + tokenString, err := createTestToken(privateKey, claims) + if err != nil { + t.Fatalf("create token: %v", err) + } + + validatedClaims, err := validator.ValidateToken(tokenString) + if err != nil { + t.Fatalf("validate token: %v", err) + } + + if validatedClaims.Subject != "test-user" { + t.Errorf("wrong subject: got %s, want test-user", validatedClaims.Subject) + } + }) + + // Test expired token + t.Run("expired_token", func(t *testing.T) { + claims := &Claims{ + Subject: "test-user", + Scopes: []string{"sequentialthinking.run"}, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)), + }, + } + + tokenString, err := createTestToken(privateKey, claims) + if err != nil { + t.Fatalf("create token: %v", err) + } + + _, err = validator.ValidateToken(tokenString) + if err == nil { + t.Fatal("expected error for expired token") + } + }) + + // Test missing scope + t.Run("missing_scope", func(t *testing.T) { + claims := &Claims{ + Subject: "test-user", + Scopes: []string{"other.scope"}, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + } + + tokenString, err := createTestToken(privateKey, claims) + if err != nil { + t.Fatalf("create token: %v", err) + } + + _, err = validator.ValidateToken(tokenString) + if err == nil { + t.Fatal("expected error for missing scope") + } + }) + + // Test space-separated scopes + t.Run("space_separated_scopes", func(t *testing.T) { + claims := &Claims{ + Subject: "test-user", + Scope: "read write sequentialthinking.run admin", + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + } + + tokenString, err := createTestToken(privateKey, claims) + if err != nil { + t.Fatalf("create token: %v", err) + } + + validatedClaims, err := validator.ValidateToken(tokenString) + if err != nil { + t.Fatalf("validate token: %v", err) + } + + if validatedClaims.Subject != "test-user" { + t.Errorf("wrong subject: got %s, want test-user", validatedClaims.Subject) + } + }) + + // Test not before + t.Run("not_yet_valid", func(t *testing.T) { + claims := &Claims{ + Subject: "test-user", + Scopes: []string{"sequentialthinking.run"}, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)), + NotBefore: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + } + + tokenString, err := createTestToken(privateKey, claims) + if err != nil { + t.Fatalf("create token: %v", err) + } + + _, err = validator.ValidateToken(tokenString) + if err == nil { + t.Fatal("expected error for not-yet-valid token") + } + }) +} + +func TestJWKSCaching(t *testing.T) { + privateKey, err := generateTestKeyPair() + if err != nil { + t.Fatalf("generate key pair: %v", err) + } + + fetchCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fetchCount++ + publicKey := &privateKey.PublicKey + + jwk := JWK{ + Kid: "test-key-1", + Kty: "RSA", + Alg: "RS256", + Use: "sig", + N: base64URLEncode(publicKey.N.Bytes()), + E: base64URLEncode([]byte{1, 0, 1}), + } + + jwks := JWKS{Keys: []JWK{jwk}} + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(jwks) + })) + defer server.Close() + + validator := NewValidator(server.URL, "sequentialthinking.run") + validator.cacheDuration = 100 * time.Millisecond // Short cache for testing + + claims := &Claims{ + Subject: "test-user", + Scopes: []string{"sequentialthinking.run"}, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + } + + tokenString, err := createTestToken(privateKey, claims) + if err != nil { + t.Fatalf("create token: %v", err) + } + + // First validation - should fetch JWKS + _, err = validator.ValidateToken(tokenString) + if err != nil { + t.Fatalf("validate token: %v", err) + } + + if fetchCount != 1 { + t.Errorf("expected 1 fetch, got %d", fetchCount) + } + + // Second validation - should use cache + _, err = validator.ValidateToken(tokenString) + if err != nil { + t.Fatalf("validate token: %v", err) + } + + if fetchCount != 1 { + t.Errorf("expected 1 fetch (cached), got %d", fetchCount) + } + + // Wait for cache to expire + time.Sleep(150 * time.Millisecond) + + // Third validation - should fetch again + _, err = validator.ValidateToken(tokenString) + if err != nil { + t.Fatalf("validate token: %v", err) + } + + if fetchCount != 2 { + t.Errorf("expected 2 fetches (cache expired), got %d", fetchCount) + } +} + +func TestParseScopes(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + { + name: "single_scope", + input: "read", + expected: []string{"read"}, + }, + { + name: "multiple_scopes", + input: "read write admin", + expected: []string{"read", "write", "admin"}, + }, + { + name: "extra_spaces", + input: "read write admin", + expected: []string{"read", "write", "admin"}, + }, + { + name: "empty_string", + input: "", + expected: nil, + }, + { + name: "spaces_only", + input: " ", + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseScopes(tt.input) + + if len(result) != len(tt.expected) { + t.Errorf("wrong length: got %d, want %d", len(result), len(tt.expected)) + return + } + + for i, expected := range tt.expected { + if result[i] != expected { + t.Errorf("scope %d: got %s, want %s", i, result[i], expected) + } + } + }) + } +} + +func TestInvalidJWKS(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + validator := NewValidator(server.URL, "sequentialthinking.run") + + err := validator.RefreshJWKS() + if err == nil { + t.Fatal("expected error for invalid JWKS server") + } +} + +func TestGetCachedKeyCount(t *testing.T) { + privateKey, err := generateTestKeyPair() + if err != nil { + t.Fatalf("generate key pair: %v", err) + } + + jwksServer := createTestJWKS(t, privateKey) + defer jwksServer.Close() + + validator := NewValidator(jwksServer.URL, "sequentialthinking.run") + + // Initially no keys + if count := validator.GetCachedKeyCount(); count != 0 { + t.Errorf("expected 0 cached keys initially, got %d", count) + } + + // Refresh JWKS + if err := validator.RefreshJWKS(); err != nil { + t.Fatalf("refresh JWKS: %v", err) + } + + // Should have 1 key + if count := validator.GetCachedKeyCount(); count != 1 { + t.Errorf("expected 1 cached key after refresh, got %d", count) + } +} diff --git a/pkg/seqthink/policy/middleware.go b/pkg/seqthink/policy/middleware.go new file mode 100644 index 0000000..94238da --- /dev/null +++ b/pkg/seqthink/policy/middleware.go @@ -0,0 +1,80 @@ +package policy + +import ( + "net/http" + "strings" + + "github.com/rs/zerolog/log" +) + +// AuthMiddleware creates HTTP middleware for JWT authentication +type AuthMiddleware struct { + validator *Validator + policyDenials func() // Metrics callback for policy denials + enforcementEnabled bool +} + +// NewAuthMiddleware creates a new authentication middleware +func NewAuthMiddleware(validator *Validator, policyDenials func()) *AuthMiddleware { + return &AuthMiddleware{ + validator: validator, + policyDenials: policyDenials, + enforcementEnabled: validator != nil, + } +} + +// Wrap wraps an HTTP handler with JWT authentication +func (m *AuthMiddleware) Wrap(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // If enforcement is disabled, pass through + if !m.enforcementEnabled { + log.Warn().Msg("Policy enforcement disabled - allowing request") + next.ServeHTTP(w, r) + return + } + + // Extract token from Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + log.Error().Msg("Missing Authorization header") + m.policyDenials() + http.Error(w, "Unauthorized: missing authorization header", http.StatusUnauthorized) + return + } + + // Check Bearer scheme + parts := strings.SplitN(authHeader, " ", 2) + if len(parts) != 2 || parts[0] != "Bearer" { + log.Error().Str("auth_header", authHeader).Msg("Invalid Authorization header format") + m.policyDenials() + http.Error(w, "Unauthorized: invalid authorization format", http.StatusUnauthorized) + return + } + + tokenString := parts[1] + + // Validate token + claims, err := m.validator.ValidateToken(tokenString) + if err != nil { + log.Error().Err(err).Msg("Token validation failed") + m.policyDenials() + http.Error(w, "Unauthorized: "+err.Error(), http.StatusUnauthorized) + return + } + + log.Info(). + Str("subject", claims.Subject). + Strs("scopes", claims.Scopes). + Msg("Request authorized") + + // Token is valid, pass to next handler + next.ServeHTTP(w, r) + }) +} + +// WrapFunc wraps an HTTP handler function with JWT authentication +func (m *AuthMiddleware) WrapFunc(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + m.Wrap(next).ServeHTTP(w, r) + } +} diff --git a/pkg/seqthink/proxy/server.go b/pkg/seqthink/proxy/server.go index 7da3939..23b6703 100644 --- a/pkg/seqthink/proxy/server.go +++ b/pkg/seqthink/proxy/server.go @@ -10,6 +10,7 @@ import ( "chorus/pkg/seqthink/mcpclient" "chorus/pkg/seqthink/observability" + "chorus/pkg/seqthink/policy" "github.com/gorilla/mux" "github.com/rs/zerolog/log" ) @@ -27,8 +28,9 @@ type ServerConfig struct { // Server is the proxy server handling requests type Server struct { - config ServerConfig - router *mux.Router + config ServerConfig + router *mux.Router + authMiddleware *policy.AuthMiddleware } // NewServer creates a new proxy server @@ -38,6 +40,26 @@ func NewServer(cfg ServerConfig) (*Server, error) { router: mux.NewRouter(), } + // Setup policy enforcement if configured + if cfg.KachingJWKSURL != "" && cfg.RequiredScope != "" { + log.Info(). + Str("jwks_url", cfg.KachingJWKSURL). + Str("required_scope", cfg.RequiredScope). + Msg("Policy enforcement enabled") + + validator := policy.NewValidator(cfg.KachingJWKSURL, cfg.RequiredScope) + + // Pre-fetch JWKS + if err := validator.RefreshJWKS(); err != nil { + log.Warn().Err(err).Msg("Failed to pre-fetch JWKS, will retry on first request") + } + + s.authMiddleware = policy.NewAuthMiddleware(validator, cfg.Metrics.IncrementPolicyDenials) + } else { + log.Warn().Msg("Policy enforcement disabled - no JWKS URL or required scope configured") + s.authMiddleware = policy.NewAuthMiddleware(nil, cfg.Metrics.IncrementPolicyDenials) + } + // Setup routes s.setupRoutes() @@ -51,27 +73,31 @@ func (s *Server) Handler() http.Handler { // setupRoutes configures the HTTP routes func (s *Server) setupRoutes() { - // Health checks + // Health checks (no auth required) s.router.HandleFunc("/health", s.handleHealth).Methods("GET") s.router.HandleFunc("/ready", s.handleReady).Methods("GET") - // MCP tool endpoint - route based on encryption config + // MCP tool endpoint - route based on encryption config, with auth if s.isEncryptionEnabled() { log.Info().Msg("Encryption enabled - using encrypted endpoint") - s.router.HandleFunc("/mcp/tool", s.handleToolCallEncrypted).Methods("POST") + s.router.Handle("/mcp/tool", + s.authMiddleware.Wrap(http.HandlerFunc(s.handleToolCallEncrypted))).Methods("POST") } else { log.Warn().Msg("Encryption disabled - using plaintext endpoint") - s.router.HandleFunc("/mcp/tool", s.handleToolCall).Methods("POST") + s.router.Handle("/mcp/tool", + s.authMiddleware.Wrap(http.HandlerFunc(s.handleToolCall))).Methods("POST") } - // SSE endpoint - route based on encryption config + // SSE endpoint - route based on encryption config, with auth if s.isEncryptionEnabled() { - s.router.HandleFunc("/mcp/sse", s.handleSSEEncrypted).Methods("GET") + s.router.Handle("/mcp/sse", + s.authMiddleware.Wrap(http.HandlerFunc(s.handleSSEEncrypted))).Methods("GET") } else { - s.router.HandleFunc("/mcp/sse", s.handleSSEPlaintext).Methods("GET") + s.router.Handle("/mcp/sse", + s.authMiddleware.Wrap(http.HandlerFunc(s.handleSSEPlaintext))).Methods("GET") } - // Metrics endpoint + // Metrics endpoint (no auth required for internal monitoring) s.router.Handle("/metrics", s.config.Metrics.Handler()) }