package ageio import ( "bytes" "os" "path/filepath" "testing" ) // TestGoldenEncryptionRoundTrip validates encryption/decryption with golden test data func TestGoldenEncryptionRoundTrip(t *testing.T) { // Generate test key pair once tmpDir := t.TempDir() identityPath, recipientPath, err := GenerateTestKeyPair(tmpDir) if err != nil { t.Fatalf("generate test key pair: %v", err) } // Create encryptor and decryptor enc, err := NewEncryptor(recipientPath) if err != nil { t.Fatalf("create encryptor: %v", err) } dec, err := NewDecryptor(identityPath) if err != nil { t.Fatalf("create decryptor: %v", err) } // Golden test cases representing real MCP payloads goldenTests := []struct { name string payload []byte description string }{ { name: "sequential_thinking_request", payload: []byte(`{ "tool": "mcp__sequential-thinking__sequentialthinking", "payload": { "thought": "First, I need to analyze the problem by breaking it down into smaller components.", "thoughtNumber": 1, "totalThoughts": 5, "nextThoughtNeeded": true, "isRevision": false } }`), description: "Initial sequential thinking request", }, { name: "sequential_thinking_revision", payload: []byte(`{ "tool": "mcp__sequential-thinking__sequentialthinking", "payload": { "thought": "Wait, I need to revise my previous thought - I missed considering edge cases.", "thoughtNumber": 3, "totalThoughts": 6, "nextThoughtNeeded": true, "isRevision": true, "revisesThought": 2 } }`), description: "Revision of previous thought", }, { name: "sequential_thinking_branching", payload: []byte(`{ "tool": "mcp__sequential-thinking__sequentialthinking", "payload": { "thought": "Let me explore an alternative approach using event sourcing instead.", "thoughtNumber": 4, "totalThoughts": 8, "nextThoughtNeeded": true, "branchFromThought": 2, "branchId": "alternative-approach-1" } }`), description: "Branching to explore alternative", }, { name: "sequential_thinking_final", payload: []byte(`{ "tool": "mcp__sequential-thinking__sequentialthinking", "payload": { "thought": "Based on all previous analysis, I recommend implementing the event sourcing pattern with CQRS for optimal scalability.", "thoughtNumber": 8, "totalThoughts": 8, "nextThoughtNeeded": false, "confidence": 0.85 } }`), description: "Final thought with conclusion", }, { name: "large_context_payload", payload: bytes.Repeat([]byte(`{"key": "value", "data": "ABCDEFGHIJ"}`), 100), description: "Large payload testing encryption of substantial data", }, { name: "unicode_payload", payload: []byte(`{ "tool": "mcp__sequential-thinking__sequentialthinking", "payload": { "thought": "分析日本語でのデータ処理 🌸🎌 and mixed language content: 你好世界", "thoughtNumber": 1, "totalThoughts": 1, "nextThoughtNeeded": false } }`), description: "Unicode and emoji content", }, { name: "special_characters", payload: []byte(`{ "tool": "test", "payload": { "special": "Testing: \n\t\r\b\"'\\\/\u0000\u001f", "symbols": "!@#$%^&*()_+-=[]{}|;:,.<>?~" } }`), description: "Special characters and escape sequences", }, } for _, gt := range goldenTests { t.Run(gt.name, func(t *testing.T) { t.Logf("Testing: %s", gt.description) t.Logf("Original size: %d bytes", len(gt.payload)) // Encrypt ciphertext, err := enc.Encrypt(gt.payload) if err != nil { t.Fatalf("encrypt failed: %v", err) } t.Logf("Encrypted size: %d bytes (%.1f%% overhead)", len(ciphertext), float64(len(ciphertext)-len(gt.payload))/float64(len(gt.payload))*100) // Verify ciphertext is different from plaintext if bytes.Equal(ciphertext, gt.payload) { t.Fatal("ciphertext equals plaintext - encryption failed") } // Verify ciphertext doesn't contain plaintext patterns // (basic sanity check - not cryptographically rigorous) if bytes.Contains(ciphertext, []byte("mcp__sequential-thinking")) { t.Error("ciphertext contains plaintext patterns - weak encryption") } // Decrypt decrypted, err := dec.Decrypt(ciphertext) if err != nil { t.Fatalf("decrypt failed: %v", err) } // Verify perfect round-trip if !bytes.Equal(decrypted, gt.payload) { t.Errorf("decrypted data doesn't match original\nOriginal: %s\nDecrypted: %s", string(gt.payload), string(decrypted)) } // Optional: Save golden files for inspection if os.Getenv("SAVE_GOLDEN") == "1" { goldenDir := filepath.Join(tmpDir, "golden") os.MkdirAll(goldenDir, 0755) plainPath := filepath.Join(goldenDir, gt.name+".plain.json") encPath := filepath.Join(goldenDir, gt.name+".encrypted.age") os.WriteFile(plainPath, gt.payload, 0644) os.WriteFile(encPath, ciphertext, 0644) t.Logf("Golden files saved to: %s", goldenDir) } }) } } // TestGoldenDecryptionFailures validates proper error handling func TestGoldenDecryptionFailures(t *testing.T) { tmpDir := t.TempDir() identityPath, recipientPath, err := GenerateTestKeyPair(tmpDir) if err != nil { t.Fatalf("generate test key pair: %v", err) } dec, err := NewDecryptor(identityPath) if err != nil { t.Fatalf("create decryptor: %v", err) } enc, err := NewEncryptor(recipientPath) if err != nil { t.Fatalf("create encryptor: %v", err) } failureTests := []struct { name string ciphertext []byte expectError string }{ { name: "empty_ciphertext", ciphertext: []byte{}, expectError: "ciphertext is empty", }, { name: "invalid_age_format", ciphertext: []byte("not a valid age ciphertext"), expectError: "create decryptor", }, { name: "corrupted_header", ciphertext: []byte("-----BEGIN AGE ENCRYPTED FILE-----\ngarbage\n-----END AGE ENCRYPTED FILE-----"), expectError: "create decryptor", }, } for _, ft := range failureTests { t.Run(ft.name, func(t *testing.T) { _, err := dec.Decrypt(ft.ciphertext) if err == nil { t.Fatal("expected error but got none") } // Just verify we got an error - specific error messages may vary t.Logf("Got expected error: %v", err) }) } // Test truncated ciphertext t.Run("truncated_ciphertext", func(t *testing.T) { // Create valid ciphertext validPlaintext := []byte("test message") validCiphertext, err := enc.Encrypt(validPlaintext) if err != nil { t.Fatalf("encrypt: %v", err) } // Truncate it truncated := validCiphertext[:len(validCiphertext)/2] // Try to decrypt _, err = dec.Decrypt(truncated) if err == nil { t.Fatal("expected error decrypting truncated ciphertext") } t.Logf("Got expected error for truncated ciphertext: %v", err) }) // Test modified ciphertext t.Run("modified_ciphertext", func(t *testing.T) { // Create valid ciphertext validPlaintext := []byte("test message") validCiphertext, err := enc.Encrypt(validPlaintext) if err != nil { t.Fatalf("encrypt: %v", err) } // Flip a bit in the middle modified := make([]byte, len(validCiphertext)) copy(modified, validCiphertext) modified[len(modified)/2] ^= 0x01 // Try to decrypt _, err = dec.Decrypt(modified) if err == nil { t.Fatal("expected error decrypting modified ciphertext") } t.Logf("Got expected error for modified ciphertext: %v", err) }) } // BenchmarkEncryption benchmarks encryption performance func BenchmarkEncryption(b *testing.B) { tmpDir := b.TempDir() _, recipientPath, err := GenerateTestKeyPair(tmpDir) if err != nil { b.Fatalf("generate test key pair: %v", err) } enc, err := NewEncryptor(recipientPath) if err != nil { b.Fatalf("create encryptor: %v", err) } payloads := map[string][]byte{ "small_1KB": bytes.Repeat([]byte("A"), 1024), "medium_10KB": bytes.Repeat([]byte("A"), 10*1024), "large_100KB": bytes.Repeat([]byte("A"), 100*1024), } for name, payload := range payloads { b.Run(name, func(b *testing.B) { b.SetBytes(int64(len(payload))) b.ResetTimer() for i := 0; i < b.N; i++ { _, err := enc.Encrypt(payload) if err != nil { b.Fatalf("encrypt: %v", err) } } }) } } // BenchmarkDecryption benchmarks decryption performance func BenchmarkDecryption(b *testing.B) { tmpDir := b.TempDir() identityPath, recipientPath, err := GenerateTestKeyPair(tmpDir) if err != nil { b.Fatalf("generate test key pair: %v", err) } enc, err := NewEncryptor(recipientPath) if err != nil { b.Fatalf("create encryptor: %v", err) } dec, err := NewDecryptor(identityPath) if err != nil { b.Fatalf("create decryptor: %v", err) } payloads := map[string][]byte{ "small_1KB": bytes.Repeat([]byte("A"), 1024), "medium_10KB": bytes.Repeat([]byte("A"), 10*1024), "large_100KB": bytes.Repeat([]byte("A"), 100*1024), } for name, payload := range payloads { // Pre-encrypt ciphertext, err := enc.Encrypt(payload) if err != nil { b.Fatalf("encrypt: %v", err) } b.Run(name, func(b *testing.B) { b.SetBytes(int64(len(payload))) b.ResetTimer() for i := 0; i < b.N; i++ { _, err := dec.Decrypt(ciphertext) if err != nil { b.Fatalf("decrypt: %v", err) } } }) } }