package policy import ( "crypto/rand" "crypto/rsa" "encoding/json" "net/http" "net/http/httptest" "testing" "time" "github.com/golang-jwt/jwt/v5" ) // generateTestKeyPair generates an RSA key pair for testing func generateTestKeyPair() (*rsa.PrivateKey, error) { return rsa.GenerateKey(rand.Reader, 2048) } // createTestJWKS creates a test JWKS server func createTestJWKS(t *testing.T, privateKey *rsa.PrivateKey) *httptest.Server { publicKey := &privateKey.PublicKey // Create JWK from public key jwk := JWK{ Kid: "test-key-1", Kty: "RSA", Alg: "RS256", Use: "sig", N: base64URLEncode(publicKey.N.Bytes()), E: base64URLEncode([]byte{1, 0, 1}), // 65537 } jwks := JWKS{ Keys: []JWK{jwk}, } server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(jwks) })) return server } // createTestToken creates a test JWT token func createTestToken(privateKey *rsa.PrivateKey, claims *Claims) (string, error) { token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) token.Header["kid"] = "test-key-1" return token.SignedString(privateKey) } func TestValidateToken(t *testing.T) { // Generate test key pair privateKey, err := generateTestKeyPair() if err != nil { t.Fatalf("generate key pair: %v", err) } // Create test JWKS server jwksServer := createTestJWKS(t, privateKey) defer jwksServer.Close() // Create validator validator := NewValidator(jwksServer.URL, "sequentialthinking.run") // Test valid token t.Run("valid_token", func(t *testing.T) { claims := &Claims{ Subject: "test-user", Scopes: []string{"sequentialthinking.run"}, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), IssuedAt: jwt.NewNumericDate(time.Now()), }, } tokenString, err := createTestToken(privateKey, claims) if err != nil { t.Fatalf("create token: %v", err) } validatedClaims, err := validator.ValidateToken(tokenString) if err != nil { t.Fatalf("validate token: %v", err) } if validatedClaims.Subject != "test-user" { t.Errorf("wrong subject: got %s, want test-user", validatedClaims.Subject) } }) // Test expired token t.Run("expired_token", func(t *testing.T) { claims := &Claims{ Subject: "test-user", Scopes: []string{"sequentialthinking.run"}, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)), IssuedAt: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)), }, } tokenString, err := createTestToken(privateKey, claims) if err != nil { t.Fatalf("create token: %v", err) } _, err = validator.ValidateToken(tokenString) if err == nil { t.Fatal("expected error for expired token") } }) // Test missing scope t.Run("missing_scope", func(t *testing.T) { claims := &Claims{ Subject: "test-user", Scopes: []string{"other.scope"}, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), IssuedAt: jwt.NewNumericDate(time.Now()), }, } tokenString, err := createTestToken(privateKey, claims) if err != nil { t.Fatalf("create token: %v", err) } _, err = validator.ValidateToken(tokenString) if err == nil { t.Fatal("expected error for missing scope") } }) // Test space-separated scopes t.Run("space_separated_scopes", func(t *testing.T) { claims := &Claims{ Subject: "test-user", Scope: "read write sequentialthinking.run admin", RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), IssuedAt: jwt.NewNumericDate(time.Now()), }, } tokenString, err := createTestToken(privateKey, claims) if err != nil { t.Fatalf("create token: %v", err) } validatedClaims, err := validator.ValidateToken(tokenString) if err != nil { t.Fatalf("validate token: %v", err) } if validatedClaims.Subject != "test-user" { t.Errorf("wrong subject: got %s, want test-user", validatedClaims.Subject) } }) // Test not before t.Run("not_yet_valid", func(t *testing.T) { claims := &Claims{ Subject: "test-user", Scopes: []string{"sequentialthinking.run"}, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)), NotBefore: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), IssuedAt: jwt.NewNumericDate(time.Now()), }, } tokenString, err := createTestToken(privateKey, claims) if err != nil { t.Fatalf("create token: %v", err) } _, err = validator.ValidateToken(tokenString) if err == nil { t.Fatal("expected error for not-yet-valid token") } }) } func TestJWKSCaching(t *testing.T) { privateKey, err := generateTestKeyPair() if err != nil { t.Fatalf("generate key pair: %v", err) } fetchCount := 0 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fetchCount++ publicKey := &privateKey.PublicKey jwk := JWK{ Kid: "test-key-1", Kty: "RSA", Alg: "RS256", Use: "sig", N: base64URLEncode(publicKey.N.Bytes()), E: base64URLEncode([]byte{1, 0, 1}), } jwks := JWKS{Keys: []JWK{jwk}} w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(jwks) })) defer server.Close() validator := NewValidator(server.URL, "sequentialthinking.run") validator.cacheDuration = 100 * time.Millisecond // Short cache for testing claims := &Claims{ Subject: "test-user", Scopes: []string{"sequentialthinking.run"}, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), IssuedAt: jwt.NewNumericDate(time.Now()), }, } tokenString, err := createTestToken(privateKey, claims) if err != nil { t.Fatalf("create token: %v", err) } // First validation - should fetch JWKS _, err = validator.ValidateToken(tokenString) if err != nil { t.Fatalf("validate token: %v", err) } if fetchCount != 1 { t.Errorf("expected 1 fetch, got %d", fetchCount) } // Second validation - should use cache _, err = validator.ValidateToken(tokenString) if err != nil { t.Fatalf("validate token: %v", err) } if fetchCount != 1 { t.Errorf("expected 1 fetch (cached), got %d", fetchCount) } // Wait for cache to expire time.Sleep(150 * time.Millisecond) // Third validation - should fetch again _, err = validator.ValidateToken(tokenString) if err != nil { t.Fatalf("validate token: %v", err) } if fetchCount != 2 { t.Errorf("expected 2 fetches (cache expired), got %d", fetchCount) } } func TestParseScopes(t *testing.T) { tests := []struct { name string input string expected []string }{ { name: "single_scope", input: "read", expected: []string{"read"}, }, { name: "multiple_scopes", input: "read write admin", expected: []string{"read", "write", "admin"}, }, { name: "extra_spaces", input: "read write admin", expected: []string{"read", "write", "admin"}, }, { name: "empty_string", input: "", expected: nil, }, { name: "spaces_only", input: " ", expected: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := parseScopes(tt.input) if len(result) != len(tt.expected) { t.Errorf("wrong length: got %d, want %d", len(result), len(tt.expected)) return } for i, expected := range tt.expected { if result[i] != expected { t.Errorf("scope %d: got %s, want %s", i, result[i], expected) } } }) } } func TestInvalidJWKS(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) })) defer server.Close() validator := NewValidator(server.URL, "sequentialthinking.run") err := validator.RefreshJWKS() if err == nil { t.Fatal("expected error for invalid JWKS server") } } func TestGetCachedKeyCount(t *testing.T) { privateKey, err := generateTestKeyPair() if err != nil { t.Fatalf("generate key pair: %v", err) } jwksServer := createTestJWKS(t, privateKey) defer jwksServer.Close() validator := NewValidator(jwksServer.URL, "sequentialthinking.run") // Initially no keys if count := validator.GetCachedKeyCount(); count != 0 { t.Errorf("expected 0 cached keys initially, got %d", count) } // Refresh JWKS if err := validator.RefreshJWKS(); err != nil { t.Fatalf("refresh JWKS: %v", err) } // Should have 1 key if count := validator.GetCachedKeyCount(); count != 1 { t.Errorf("expected 1 cached key after refresh, got %d", count) } }