// Package crypto provides cluster-scoped key derivation (HKDF-SHA256) and
// AGE-based encryption helpers built on the derived keys.
package crypto

import (
	"crypto/sha256"
	"fmt"
	"io"

	"filippo.io/age"
	"filippo.io/age/armor"
	"golang.org/x/crypto/hkdf"
)

// KeyDerivationManager handles cluster-scoped key derivation for DHT encryption.
//
// All keys are derived with HKDF-SHA256 using the cluster ID as salt, so two
// managers built from the same root key and cluster ID derive identical keys.
type KeyDerivationManager struct {
	clusterRootKey []byte
	clusterID      string
}

// DerivedKeySet contains keys derived for a specific role/scope.
type DerivedKeySet struct {
	// RoleKey is the role-specific key derived from the cluster root key.
	RoleKey []byte
	// NodeKey is derived from RoleKey: per agent for DeriveRoleKeys, or the
	// shared "cluster-shared" key for DeriveClusterWideKeys.
	NodeKey []byte
	// AGEIdentity is the AGE identity (private key) derived from NodeKey.
	AGEIdentity *age.X25519Identity
	// AGERecipient is the AGE recipient (public key) matching AGEIdentity.
	AGERecipient *age.X25519Recipient
}

// NewKeyDerivationManager creates a new key derivation manager.
//
// clusterRootKey is stored as-is (not copied), so later modifications of the
// slice by the caller are visible to the manager.
func NewKeyDerivationManager(clusterRootKey []byte, clusterID string) *KeyDerivationManager {
	return &KeyDerivationManager{
		clusterRootKey: clusterRootKey,
		clusterID:      clusterID,
	}
}

// NewKeyDerivationManagerFromSeed creates a manager whose 32-byte root key is
// derived from seed with HKDF-SHA256, salted with clusterID.
//
// It panics if the HKDF read fails; the signature has no error return.
func NewKeyDerivationManagerFromSeed(seed, clusterID string) *KeyDerivationManager {
	reader := hkdf.New(sha256.New, []byte(seed), []byte(clusterID), []byte("CHORUS-cluster-root"))

	rootKey := make([]byte, 32)
	if _, err := io.ReadFull(reader, rootKey); err != nil {
		panic(fmt.Errorf("failed to derive cluster root key: %w", err))
	}

	return &KeyDerivationManager{
		clusterRootKey: rootKey,
		clusterID:      clusterID,
	}
}

// DeriveRoleKeys derives encryption keys for a specific role and agent.
//
// The role key is derived from the cluster root key, the node key from the
// role key and agentID, and the AGE identity deterministically from the node
// key. It returns an error if the root key is missing or a derivation fails.
func (kdm *KeyDerivationManager) DeriveRoleKeys(role, agentID string) (*DerivedKeySet, error) {
	return kdm.deriveKeySet(role, fmt.Sprintf("node-%s", agentID), "node")
}

// DeriveClusterWideKeys derives keys that are shared across the entire
// cluster for a role.
//
// It follows the same chain as DeriveRoleKeys but uses the fixed
// "cluster-shared" label instead of an agent ID, so every manager with the
// same root key and cluster ID obtains the same key set for the role.
func (kdm *KeyDerivationManager) DeriveClusterWideKeys(role string) (*DerivedKeySet, error) {
	return kdm.deriveKeySet(role, "cluster-shared", "cluster node")
}

// deriveKeySet runs the root -> role -> node -> AGE identity derivation chain.
// nodeInfo is the HKDF info string for the node-level key; nodeLabel only
// names that key in error messages.
func (kdm *KeyDerivationManager) deriveKeySet(role, nodeInfo, nodeLabel string) (*DerivedKeySet, error) {
	// An empty root key is as unusable as a nil one.
	if len(kdm.clusterRootKey) == 0 {
		return nil, fmt.Errorf("cluster root key not initialized")
	}

	roleKey, err := kdm.deriveKey(fmt.Sprintf("role-%s", role), 32)
	if err != nil {
		return nil, fmt.Errorf("failed to derive role key: %w", err)
	}

	nodeKey, err := kdm.deriveKeyFromParent(roleKey, nodeInfo, 32)
	if err != nil {
		return nil, fmt.Errorf("failed to derive %s key: %w", nodeLabel, err)
	}

	ageIdentity, err := kdm.generateAGEIdentityFromKey(nodeKey)
	if err != nil {
		return nil, fmt.Errorf("failed to generate AGE identity: %w", err)
	}

	return &DerivedKeySet{
		RoleKey:      roleKey,
		NodeKey:      nodeKey,
		AGEIdentity:  ageIdentity,
		AGERecipient: ageIdentity.Recipient(),
	}, nil
}

// deriveKey derives a key of the given length from the cluster root key
// using HKDF.
func (kdm *KeyDerivationManager) deriveKey(info string, length int) ([]byte, error) {
	return kdm.deriveKeyFromParent(kdm.clusterRootKey, info, length)
}

// deriveKeyFromParent derives a key of the given length from parentKey using
// HKDF-SHA256, with the cluster ID as salt and info as the context string.
func (kdm *KeyDerivationManager) deriveKeyFromParent(parentKey []byte, info string, length int) ([]byte, error) {
	// make panics on a negative length; report it as an error instead.
	if length < 0 {
		return nil, fmt.Errorf("invalid key length %d", length)
	}

	reader := hkdf.New(sha256.New, parentKey, []byte(kdm.clusterID), []byte(info))

	key := make([]byte, length)
	if _, err := io.ReadFull(reader, key); err != nil {
		return nil, fmt.Errorf("HKDF key derivation failed: %w", err)
	}
	return key, nil
}

// generateAGEIdentityFromKey generates a deterministic AGE identity from a key.
//
// The first 32 bytes of key are used as the X25519 secret key. They are
// Bech32-encoded into the textual "AGE-SECRET-KEY-1..." form and loaded with
// age.ParseX25519Identity, so the same key always yields the same identity
// and recipient. It returns an error if key is shorter than 32 bytes or the
// encoded key is rejected by the age library.
func (kdm *KeyDerivationManager) generateAGEIdentityFromKey(key []byte) (*age.X25519Identity, error) {
	if len(key) < 32 {
		return nil, fmt.Errorf("key must be at least 32 bytes")
	}

	identity, err := age.ParseX25519Identity(encodeAGESecretKey(key[:32]))
	if err != nil {
		return nil, fmt.Errorf("failed to create AGE identity: %w", err)
	}
	return identity, nil
}

// encodeAGESecretKey Bech32-encodes (BIP-173) secret under the
// "AGE-SECRET-KEY-" human-readable part and returns the upper-case string.
func encodeAGESecretKey(secret []byte) string {
	const (
		// The checksum is computed over the lower-case human-readable part;
		// the emitted string is entirely upper-case.
		hrp     = "age-secret-key-"
		prefix  = "AGE-SECRET-KEY-1"
		charset = "QPZRY9X8GF2TVDW0S3JN54KHCE6MUA7L"
	)

	// Regroup the 8-bit input into 5-bit values, zero-padding the last group.
	values := make([]byte, 0, (len(secret)*8+4)/5)
	var acc uint32
	var bits uint
	for _, b := range secret {
		acc = acc<<8 | uint32(b)
		bits += 8
		for bits >= 5 {
			bits -= 5
			values = append(values, byte(acc>>bits)&31)
		}
		acc &= 1<<bits - 1 // keep only the bits not yet emitted
	}
	if bits > 0 {
		values = append(values, byte(acc<<(5-bits))&31)
	}

	// Checksum input: expanded HRP, the data values, six zero placeholders.
	input := make([]byte, 0, 2*len(hrp)+1+len(values)+6)
	for i := 0; i < len(hrp); i++ {
		input = append(input, hrp[i]>>5)
	}
	input = append(input, 0)
	for i := 0; i < len(hrp); i++ {
		input = append(input, hrp[i]&31)
	}
	input = append(input, values...)
	input = append(input, 0, 0, 0, 0, 0, 0)
	checksum := bech32Polymod(input) ^ 1

	out := make([]byte, 0, len(prefix)+len(values)+6)
	out = append(out, prefix...)
	for _, v := range values {
		out = append(out, charset[v])
	}
	for i := 0; i < 6; i++ {
		out = append(out, charset[(checksum>>(5*uint(5-i)))&31])
	}
	return string(out)
}

// bech32Polymod computes the Bech32 (BIP-173) checksum polynomial over values.
func bech32Polymod(values []byte) uint32 {
	generators := [5]uint32{0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3}

	chk := uint32(1)
	for _, v := range values {
		top := chk >> 25
		chk = (chk&0x1ffffff)<<5 ^ uint32(v)
		for i, g := range generators {
			if (top>>uint(i))&1 == 1 {
				chk ^= g
			}
		}
	}
	return chk
}

// EncryptForRole encrypts data for a specific role (all nodes in that role
// can decrypt).
//
// The data is encrypted to the role's cluster-wide AGE recipient and returned
// in ASCII-armored form.
func (kdm *KeyDerivationManager) EncryptForRole(data []byte, role string) ([]byte, error) {
	keySet, err := kdm.DeriveClusterWideKeys(role)
	if err != nil {
		return nil, fmt.Errorf("failed to derive cluster keys: %w", err)
	}

	var encrypted []byte
	armorWriter := armor.NewWriter(&writeBuffer{data: &encrypted})

	ageWriter, err := age.Encrypt(armorWriter, keySet.AGERecipient)
	if err != nil {
		return nil, fmt.Errorf("failed to create age writer: %w", err)
	}
	if _, err := ageWriter.Write(data); err != nil {
		return nil, fmt.Errorf("failed to write encrypted data: %w", err)
	}

	// Close the age stream before the armor writer that wraps its output.
	if err := ageWriter.Close(); err != nil {
		return nil, fmt.Errorf("failed to close age writer: %w", err)
	}
	if err := armorWriter.Close(); err != nil {
		return nil, fmt.Errorf("failed to close armor writer: %w", err)
	}

	return encrypted, nil
}

// DecryptForRole decrypts data encrypted for a specific role.
//
// It tries the role's cluster-wide identity first and falls back to the
// node-specific identity for agentID. If both fail, the returned error
// reports both failures and wraps the node-key error.
func (kdm *KeyDerivationManager) DecryptForRole(encryptedData []byte, role, agentID string) ([]byte, error) {
	clusterKeys, err := kdm.DeriveClusterWideKeys(role)
	if err != nil {
		return nil, fmt.Errorf("failed to derive cluster keys: %w", err)
	}

	decrypted, clusterErr := kdm.decryptWithIdentity(encryptedData, clusterKeys.AGEIdentity)
	if clusterErr == nil {
		return decrypted, nil
	}

	nodeKeys, err := kdm.DeriveRoleKeys(role, agentID)
	if err != nil {
		return nil, fmt.Errorf("failed to derive node keys: %w", err)
	}

	decrypted, nodeErr := kdm.decryptWithIdentity(encryptedData, nodeKeys.AGEIdentity)
	if nodeErr != nil {
		// Surface the cluster-wide failure too instead of dropping it.
		return nil, fmt.Errorf("cluster-wide key: %v; node key: %w", clusterErr, nodeErr)
	}
	return decrypted, nil
}

// decryptWithIdentity decrypts ASCII-armored AGE data using an AGE identity.
func (kdm *KeyDerivationManager) decryptWithIdentity(encryptedData []byte, identity *age.X25519Identity) ([]byte, error) {
	armorReader := armor.NewReader(newReadBuffer(encryptedData))

	ageReader, err := age.Decrypt(armorReader, identity)
	if err != nil {
		return nil, fmt.Errorf("failed to decrypt: %w", err)
	}

	decrypted, err := io.ReadAll(ageReader)
	if err != nil {
		return nil, fmt.Errorf("failed to read decrypted data: %w", err)
	}
	return decrypted, nil
}

// GetRoleRecipients returns AGE recipients for all nodes in a role (for
// multi-recipient encryption).
//
// The first element is the role's cluster-wide recipient, followed by one
// recipient per agent ID. An agent whose keys cannot be derived is skipped
// (best effort); only a cluster-wide derivation failure is an error.
func (kdm *KeyDerivationManager) GetRoleRecipients(role string, agentIDs []string) ([]*age.X25519Recipient, error) {
	clusterKeys, err := kdm.DeriveClusterWideKeys(role)
	if err != nil {
		return nil, fmt.Errorf("failed to derive cluster keys: %w", err)
	}

	recipients := make([]*age.X25519Recipient, 0, len(agentIDs)+1)
	recipients = append(recipients, clusterKeys.AGERecipient)

	for _, agentID := range agentIDs {
		nodeKeys, err := kdm.DeriveRoleKeys(role, agentID)
		if err != nil {
			continue // Deliberate best effort: skip this agent on error.
		}
		recipients = append(recipients, nodeKeys.AGERecipient)
	}

	return recipients, nil
}

// GetKeySetStats returns statistics about derived key sets.
//
// The map always contains "cluster_id", "role" and "agent_id". When the keys
// for role and agentID can be derived it also contains "node_key_length",
// "role_key_length" and "age_recipient"; no key material is included.
func (kdm *KeyDerivationManager) GetKeySetStats(role, agentID string) map[string]interface{} {
	stats := map[string]interface{}{
		"cluster_id": kdm.clusterID,
		"role":       role,
		"agent_id":   agentID,
	}

	// Derivation failures are not reported; the key fields are just omitted.
	if keySet, err := kdm.DeriveRoleKeys(role, agentID); err == nil {
		stats["node_key_length"] = len(keySet.NodeKey)
		stats["role_key_length"] = len(keySet.RoleKey)
		stats["age_recipient"] = keySet.AGERecipient.String()
	}

	return stats
}

// Helper types for AGE encryption/decryption.

var (
	_ io.Writer = (*writeBuffer)(nil)
	_ io.Reader = (*readBuffer)(nil)
)

// writeBuffer is an io.Writer that appends everything written to the byte
// slice it points at.
type writeBuffer struct {
	data *[]byte
}

// Write appends p to the target slice. It never fails.
func (w *writeBuffer) Write(p []byte) (n int, err error) {
	*w.data = append(*w.data, p...)
	return len(p), nil
}

// readBuffer is an io.Reader over an in-memory byte slice.
type readBuffer struct {
	data []byte
	pos  int
}

// newReadBuffer returns a readBuffer positioned at the start of data.
func newReadBuffer(data []byte) *readBuffer {
	return &readBuffer{data: data, pos: 0}
}

// Read copies the next unread bytes into p and advances the position. It
// returns io.EOF once all data has been consumed.
func (r *readBuffer) Read(p []byte) (n int, err error) {
	if r.pos >= len(r.data) {
		return 0, io.EOF
	}
	n = copy(p, r.data[r.pos:])
	r.pos += n
	return n, nil
}