package crypto

import (
	"crypto/rand"
	"encoding/base64"
	"encoding/binary"
	"fmt"
	"math/big"

	"github.com/anthonyrawlins/bzzz/pkg/config"
)

// Share payload layout produced by SplitSecret (format version 2):
//
//	[0x02][x][secret length: uint32 big-endian][y_1]...[y_k]
//
// where k = ceil(length / shamirChunkSize) and every y_i is a fixed-width
// (shamirElementSize bytes) big-endian field element. Shares written by the
// earlier implementation use the legacy layout [x_length][x_bytes][y_bytes]
// (see encodeShare); since x was always in 1..255 its first byte is 1, so the
// two layouts cannot be confused.

// shamirShareVersion is the leading byte of a version-2 share payload.
const shamirShareVersion byte = 2

// shamirChunkSize is the number of secret bytes packed into one field element.
// 31 bytes (248 bits) is strictly below the 255-bit field prime, so a chunk is
// never reduced modulo the prime and can always be recovered exactly.
const shamirChunkSize = 31

// shamirElementSize is the encoded width of one field element; the field prime
// is below 2^255, so every element fits in 32 bytes.
const shamirElementSize = 32

// shamirHeaderSize is version byte + x byte + 4-byte secret length.
const shamirHeaderSize = 1 + 1 + 4

// shamirMaxSecretLen is the largest secret length representable in the header.
const shamirMaxSecretLen = 1<<32 - 1

// ShamirSecretSharing implements Shamir's Secret Sharing algorithm for Age keys
type ShamirSecretSharing struct {
	threshold   int
	totalShares int
}

// NewShamirSecretSharing creates a new Shamir secret sharing instance.
//
// threshold is the number of shares needed to reconstruct a secret and
// totalShares the number of shares produced; 1 <= threshold <= totalShares <= 255.
// The 255 limit exists because each share's x-coordinate is stored in one byte.
func NewShamirSecretSharing(threshold, totalShares int) (*ShamirSecretSharing, error) {
	if threshold <= 0 || totalShares <= 0 {
		return nil, fmt.Errorf("threshold and total shares must be positive")
	}
	if threshold > totalShares {
		return nil, fmt.Errorf("threshold cannot be greater than total shares")
	}
	if totalShares > 255 {
		return nil, fmt.Errorf("total shares cannot exceed 255")
	}
	return &ShamirSecretSharing{
		threshold:   threshold,
		totalShares: totalShares,
	}, nil
}

// Share represents a single share of a secret
type Share struct {
	Index int    `json:"index"`
	Value string `json:"value"` // Base64 encoded
}

// SplitSecret splits secret (for example an Age private key) into
// sss.totalShares shares, any sss.threshold of which reconstruct it via
// ReconstructSecret.
//
// The secret is cut into shamirChunkSize-byte chunks and each chunk is shared
// with its own independent random polynomial of degree threshold-1 over
// GF(p), p = 2^255 - 19. This supports secrets of any length up to 2^32-1
// bytes without reducing them modulo the prime, and the recorded length keeps
// leading zero bytes intact. Share i (1-based) is the polynomials evaluated
// at x = i.
//
// Returns an error for an empty or oversized secret, or if the system random
// source fails.
func (sss *ShamirSecretSharing) SplitSecret(secret string) ([]Share, error) {
	if secret == "" {
		return nil, fmt.Errorf("secret cannot be empty")
	}
	if sss.threshold <= 0 || sss.totalShares < sss.threshold {
		return nil, fmt.Errorf("invalid Shamir configuration: threshold %d, total shares %d", sss.threshold, sss.totalShares)
	}

	secretBytes := []byte(secret)
	if uint64(len(secretBytes)) > shamirMaxSecretLen {
		return nil, fmt.Errorf("secret too long: %d bytes", len(secretBytes))
	}

	prime := shamirFieldPrime()
	chunks := shamirSplitChunks(secretBytes)

	xs := make([]*big.Int, sss.totalShares)
	for i := range xs {
		xs[i] = big.NewInt(int64(i + 1)) // x values from 1 to totalShares
	}

	// ys[i][c] is share i's y-coordinate for chunk c.
	ys := make([][]*big.Int, sss.totalShares)
	for i := range ys {
		ys[i] = make([]*big.Int, len(chunks))
	}

	for c, chunk := range chunks {
		// Constant term is the chunk; the remaining coefficients are uniformly
		// random field elements, drawn independently for every chunk.
		coefficients := make([]*big.Int, sss.threshold)
		coefficients[0] = new(big.Int).SetBytes(chunk)
		for k := 1; k < sss.threshold; k++ {
			coeff, err := rand.Int(rand.Reader, prime)
			if err != nil {
				return nil, fmt.Errorf("failed to generate random coefficient: %w", err)
			}
			coefficients[k] = coeff
		}
		for i, x := range xs {
			ys[i][c] = evaluatePolynomial(coefficients, x, prime)
		}
	}

	shares := make([]Share, sss.totalShares)
	for i := range shares {
		// totalShares <= 255 (enforced by NewShamirSecretSharing), so i+1 fits in a byte.
		data := shamirEncodeShare(byte(i+1), uint32(len(secretBytes)), ys[i])
		shares[i] = Share{
			Index: i + 1,
			Value: base64.StdEncoding.EncodeToString(data),
		}
	}
	return shares, nil
}

// ReconstructSecret reconstructs the original secret from at least
// sss.threshold shares.
//
// Shares are decoded in order until sss.threshold distinct x-coordinates have
// been collected; later shares are ignored, and shares repeating an
// already-seen x-coordinate are skipped. Both the current share format and
// the legacy single-element format are accepted, but all shares used must
// come from the same format and secret length.
//
// A wrong combination of shares (e.g. from different secrets of equal
// length) cannot always be detected and may yield garbage; a reconstructed
// chunk that does not fit its recorded size is reported as an error.
func (sss *ShamirSecretSharing) ReconstructSecret(shares []Share) (string, error) {
	if sss.threshold <= 0 {
		return "", fmt.Errorf("invalid Shamir configuration: threshold must be positive")
	}
	if len(shares) < sss.threshold {
		return "", fmt.Errorf("need at least %d shares to reconstruct secret, got %d", sss.threshold, len(shares))
	}

	decoded := make([]shamirDecodedShare, 0, sss.threshold)
	seen := make(map[int64]bool, sss.threshold)
	for _, share := range shares {
		if len(decoded) == sss.threshold {
			break
		}
		raw, err := base64.StdEncoding.DecodeString(share.Value)
		if err != nil {
			return "", fmt.Errorf("failed to decode share %d: %w", share.Index, err)
		}
		d, err := shamirDecodeShare(raw)
		if err != nil {
			return "", fmt.Errorf("failed to parse share %d: %w", share.Index, err)
		}
		if len(decoded) > 0 && !d.compatibleWith(decoded[0]) {
			return "", fmt.Errorf("share %d was not produced by the same split as the other shares", share.Index)
		}
		// A repeated x-coordinate adds no information and would make a
		// Lagrange denominator zero (non-invertible).
		if seen[d.x.Int64()] {
			continue
		}
		seen[d.x.Int64()] = true
		decoded = append(decoded, d)
	}
	if len(decoded) < sss.threshold {
		return "", fmt.Errorf("need at least %d distinct shares to reconstruct secret, got %d", sss.threshold, len(decoded))
	}

	first := decoded[0]
	if first.legacy {
		return shamirReconstructLegacy(decoded)
	}

	prime := shamirFieldPrime()
	zero := big.NewInt(0)
	secret := make([]byte, 0, first.secretLen)
	points := make([]Point, len(decoded))
	for c := range first.ys {
		for i := range decoded {
			points[i] = Point{X: decoded[i].x, Y: decoded[i].ys[c]}
		}
		value, err := shamirInterpolate(points, zero, prime)
		if err != nil {
			return "", fmt.Errorf("failed to interpolate chunk %d: %w", c, err)
		}

		// Every chunk is shamirChunkSize bytes except possibly the last.
		size := first.secretLen - c*shamirChunkSize
		if size > shamirChunkSize {
			size = shamirChunkSize
		}
		if value.BitLen() > 8*size {
			return "", fmt.Errorf("shares are inconsistent: chunk %d does not fit in %d bytes", c, size)
		}
		// FillBytes left-pads with zeros, restoring any leading zero bytes.
		secret = append(secret, value.FillBytes(make([]byte, size))...)
	}
	return string(secret), nil
}

// shamirReconstructLegacy interpolates shares in the legacy single-element
// format modulo getPrime257. The legacy format never recorded the secret
// length, so leading zero bytes cannot be restored (as in the original code).
func shamirReconstructLegacy(decoded []shamirDecodedShare) (string, error) {
	points := make([]Point, len(decoded))
	for i, d := range decoded {
		points[i] = Point{X: d.x, Y: d.ys[0]}
	}
	value, err := shamirInterpolate(points, big.NewInt(0), getPrime257())
	if err != nil {
		return "", fmt.Errorf("failed to interpolate legacy shares: %w", err)
	}
	return string(value.Bytes()), nil
}

// Point represents a point on the polynomial
type Point struct {
	X, Y *big.Int
}

// evaluatePolynomial evaluates the polynomial with the given coefficients
// (coefficients[k] multiplies x^k) at x, reducing modulo prime after every
// step. The result is in [0, prime).
func evaluatePolynomial(coefficients []*big.Int, x, prime *big.Int) *big.Int {
	result := big.NewInt(0)
	xPower := big.NewInt(1) // x^0 = 1

	for _, coeff := range coefficients {
		// result += coeff * x^power
		term := new(big.Int).Mul(coeff, xPower)
		result.Add(result, term)
		result.Mod(result, prime)

		// Update x^power for next iteration
		xPower.Mul(xPower, x)
		xPower.Mod(xPower, prime)
	}

	return result
}

// lagrangeInterpolation evaluates at targetX the polynomial of degree
// < len(points) passing through points, with arithmetic modulo prime.
//
// It returns nil when interpolation is impossible (for example duplicate
// x-coordinates, which make a denominator non-invertible). Callers that need
// the reason should use shamirInterpolate.
func lagrangeInterpolation(points []Point, targetX, prime *big.Int) *big.Int {
	value, err := shamirInterpolate(points, targetX, prime)
	if err != nil {
		return nil
	}
	return value
}

// shamirInterpolate is lagrangeInterpolation with an error instead of a
// panic/nil when some Lagrange denominator has no inverse modulo prime.
func shamirInterpolate(points []Point, targetX, prime *big.Int) (*big.Int, error) {
	result := new(big.Int)
	for i, pi := range points {
		// Lagrange basis L_i(targetX) = prod_{j != i} (targetX - x_j) / (x_i - x_j).
		numerator := big.NewInt(1)
		denominator := big.NewInt(1)
		for j, pj := range points {
			if i == j {
				continue
			}
			numerator.Mul(numerator, new(big.Int).Sub(targetX, pj.X))
			numerator.Mod(numerator, prime)
			denominator.Mul(denominator, new(big.Int).Sub(pi.X, pj.X))
			denominator.Mod(denominator, prime)
		}

		denominatorInv := modularInverse(denominator, prime)
		if denominatorInv == nil {
			return nil, fmt.Errorf("lagrange denominator is not invertible (duplicate or invalid x-coordinates)")
		}

		basis := numerator.Mul(numerator, denominatorInv)
		basis.Mod(basis, prime)

		// result += y_i * L_i(targetX)
		result.Add(result, basis.Mul(basis, pi.Y))
		result.Mod(result, prime)
	}
	return result, nil
}

// modularInverse returns a^-1 mod m, or nil when a and m are not coprime
// (the behavior of big.Int.ModInverse).
func modularInverse(a, m *big.Int) *big.Int {
	return new(big.Int).ModInverse(a, m)
}

// encodeShare encodes x,y coordinates in the legacy share layout
// [x_length][x_bytes][y_bytes]. SplitSecret no longer produces this layout;
// the function is retained for compatibility.
func encodeShare(x, y *big.Int) []byte {
	xBytes := x.Bytes()
	yBytes := y.Bytes()

	// Simple encoding: [x_length][x_bytes][y_bytes]
	result := make([]byte, 0, 1+len(xBytes)+len(yBytes))
	result = append(result, byte(len(xBytes)))
	result = append(result, xBytes...)
	result = append(result, yBytes...)

	return result
}

// decodeShare decodes the legacy [x_length][x_bytes][y_bytes] layout back
// into x,y coordinates.
func decodeShare(data []byte) (*big.Int, *big.Int, error) {
	if len(data) < 2 {
		return nil, nil, fmt.Errorf("share data too short")
	}

	xLength := int(data[0])
	if len(data) < 1+xLength {
		return nil, nil, fmt.Errorf("invalid share data")
	}

	xBytes := data[1 : 1+xLength]
	yBytes := data[1+xLength:]

	x := new(big.Int).SetBytes(xBytes)
	y := new(big.Int).SetBytes(yBytes)

	return x, y, nil
}

// shamirDecodedShare is a parsed share payload in either format.
type shamirDecodedShare struct {
	x         *big.Int   // x-coordinate, validated to be in 1..255
	ys        []*big.Int // one y-coordinate per chunk (exactly one for legacy shares)
	secretLen int        // secret length in bytes; 0 for legacy shares (not recorded)
	legacy    bool       // true for the [x_length][x_bytes][y_bytes] layout
}

// compatibleWith reports whether d and o could come from the same split.
func (d shamirDecodedShare) compatibleWith(o shamirDecodedShare) bool {
	return d.legacy == o.legacy && d.secretLen == o.secretLen && len(d.ys) == len(o.ys)
}

// shamirEncodeShare serializes one share in the version-2 layout described
// at the top of this file. Each y must be in [0, 2^256).
func shamirEncodeShare(x byte, secretLen uint32, ys []*big.Int) []byte {
	data := make([]byte, shamirHeaderSize+len(ys)*shamirElementSize)
	data[0] = shamirShareVersion
	data[1] = x
	binary.BigEndian.PutUint32(data[2:shamirHeaderSize], secretLen)
	for c, y := range ys {
		start := shamirHeaderSize + c*shamirElementSize
		y.FillBytes(data[start : start+shamirElementSize])
	}
	return data
}

// shamirDecodeShare parses a share payload, accepting the version-2 layout
// and falling back to the legacy layout for any other leading byte.
func shamirDecodeShare(data []byte) (shamirDecodedShare, error) {
	if len(data) == 0 {
		return shamirDecodedShare{}, fmt.Errorf("share data too short")
	}

	var d shamirDecodedShare
	if data[0] == shamirShareVersion {
		if len(data) < shamirHeaderSize {
			return shamirDecodedShare{}, fmt.Errorf("share data too short")
		}
		secretLen := uint64(binary.BigEndian.Uint32(data[2:shamirHeaderSize]))
		body := data[shamirHeaderSize:]
		chunks := (secretLen + shamirChunkSize - 1) / shamirChunkSize
		if secretLen == 0 || uint64(len(body)) != chunks*shamirElementSize {
			return shamirDecodedShare{}, fmt.Errorf("invalid share data")
		}
		d.x = big.NewInt(int64(data[1]))
		// secretLen <= chunks*shamirChunkSize < len(body), so it fits in an int.
		d.secretLen = int(secretLen)
		d.ys = make([]*big.Int, chunks)
		for c := range d.ys {
			d.ys[c] = new(big.Int).SetBytes(body[c*shamirElementSize : (c+1)*shamirElementSize])
		}
	} else {
		x, y, err := decodeShare(data)
		if err != nil {
			return shamirDecodedShare{}, err
		}
		d.x, d.ys, d.legacy = x, []*big.Int{y}, true
	}

	// SplitSecret only ever uses x in 1..255; x = 0 would expose a secret chunk directly.
	if d.x.Sign() <= 0 || d.x.Cmp(big.NewInt(255)) > 0 {
		return shamirDecodedShare{}, fmt.Errorf("invalid share x-coordinate")
	}
	return d, nil
}

// shamirSplitChunks cuts b into consecutive sub-slices of at most
// shamirChunkSize bytes. The sub-slices alias b.
func shamirSplitChunks(b []byte) [][]byte {
	chunks := make([][]byte, 0, (len(b)+shamirChunkSize-1)/shamirChunkSize)
	for len(b) > shamirChunkSize {
		chunks = append(chunks, b[:shamirChunkSize])
		b = b[shamirChunkSize:]
	}
	if len(b) > 0 {
		chunks = append(chunks, b)
	}
	return chunks
}

// shamirFieldPrime returns p = 2^255 - 19, the Curve25519 field prime
// (RFC 7748), used as the modulus for all shares produced by SplitSecret.
func shamirFieldPrime() *big.Int {
	p := new(big.Int).Lsh(big.NewInt(1), 255)
	return p.Sub(p, big.NewInt(19))
}

// getPrime257 returns the modulus used by the legacy share format.
//
// NOTE(review): the original comment described this as "a well-known 257-bit
// prime"; its primality has not been verified here. It is kept only so that
// ReconstructSecret can still read legacy shares; new shares use
// shamirFieldPrime.
func getPrime257() *big.Int {
	primeStr := "208351617316091241234326746312124448251235562226470491514186331217050270460481"
	prime, _ := new(big.Int).SetString(primeStr, 10) // constant literal; parsing cannot fail
	return prime
}

// AdminKeyManager manages admin key reconstruction using Shamir shares
type AdminKeyManager struct {
	config    *config.Config
	nodeID    string
	nodeShare *config.ShamirShare
}

// NewAdminKeyManager creates a new admin key manager
func NewAdminKeyManager(cfg *config.Config, nodeID string) *AdminKeyManager {
	return &AdminKeyManager{
		config: cfg,
		nodeID: nodeID,
	}
}

// SetNodeShare sets this node's Shamir share
func (akm *AdminKeyManager) SetNodeShare(share *config.ShamirShare) {
	akm.nodeShare = share
}

// GetNodeShare returns this node's Shamir share
func (akm *AdminKeyManager) GetNodeShare() *config.ShamirShare {
	return akm.nodeShare
}

// ReconstructAdminKey reconstructs the admin private key from collected
// shares, using the threshold and total share count from
// config.Security.AdminKeyShares.
func (akm *AdminKeyManager) ReconstructAdminKey(shares []config.ShamirShare) (string, error) {
	if len(shares) < akm.config.Security.AdminKeyShares.Threshold {
		return "", fmt.Errorf("insufficient shares: need %d, have %d",
			akm.config.Security.AdminKeyShares.Threshold, len(shares))
	}

	// Convert config shares to crypto shares
	cryptoShares := make([]Share, len(shares))
	for i, share := range shares {
		cryptoShares[i] = Share{
			Index: share.Index,
			Value: share.Share,
		}
	}

	// Create Shamir instance with config parameters
	sss, err := NewShamirSecretSharing(
		akm.config.Security.AdminKeyShares.Threshold,
		akm.config.Security.AdminKeyShares.TotalShares,
	)
	if err != nil {
		return "", fmt.Errorf("failed to create Shamir instance: %w", err)
	}

	// Reconstruct the secret
	return sss.ReconstructSecret(cryptoShares)
}

// SplitAdminKey splits an admin private key into Shamir shares, stamping each
// share with the threshold and total share count from
// config.Security.AdminKeyShares.
func (akm *AdminKeyManager) SplitAdminKey(adminPrivateKey string) ([]config.ShamirShare, error) {
	// Create Shamir instance with config parameters
	sss, err := NewShamirSecretSharing(
		akm.config.Security.AdminKeyShares.Threshold,
		akm.config.Security.AdminKeyShares.TotalShares,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create Shamir instance: %w", err)
	}

	// Split the secret
	shares, err := sss.SplitSecret(adminPrivateKey)
	if err != nil {
		return nil, fmt.Errorf("failed to split admin key: %w", err)
	}

	// Convert to config shares
	configShares := make([]config.ShamirShare, len(shares))
	for i, share := range shares {
		configShares[i] = config.ShamirShare{
			Index:       share.Index,
			Share:       share.Value,
			Threshold:   akm.config.Security.AdminKeyShares.Threshold,
			TotalShares: akm.config.Security.AdminKeyShares.TotalShares,
		}
	}

	return configShares, nil
}

// ValidateShare validates a Shamir share: index range, agreement of its
// threshold/total with the configuration, and that its value is a
// well-formed share payload whose encoded x-coordinate equals share.Index.
func (akm *AdminKeyManager) ValidateShare(share *config.ShamirShare) error {
	if share == nil {
		return fmt.Errorf("share is nil")
	}
	if share.Index < 1 || share.Index > share.TotalShares {
		return fmt.Errorf("invalid share index: %d (must be 1-%d)", share.Index, share.TotalShares)
	}

	if share.Threshold != akm.config.Security.AdminKeyShares.Threshold {
		return fmt.Errorf("share threshold mismatch: expected %d, got %d",
			akm.config.Security.AdminKeyShares.Threshold, share.Threshold)
	}

	if share.TotalShares != akm.config.Security.AdminKeyShares.TotalShares {
		return fmt.Errorf("share total mismatch: expected %d, got %d",
			akm.config.Security.AdminKeyShares.TotalShares, share.TotalShares)
	}

	// Decode and parse the share value
	raw, err := base64.StdEncoding.DecodeString(share.Share)
	if err != nil {
		return fmt.Errorf("invalid share encoding: %w", err)
	}
	decoded, err := shamirDecodeShare(raw)
	if err != nil {
		return fmt.Errorf("invalid share payload: %w", err)
	}
	if decoded.x.Cmp(big.NewInt(int64(share.Index))) != 0 {
		return fmt.Errorf("share index %d does not match encoded share index %d", share.Index, decoded.x.Int64())
	}

	return nil
}

// TestShamirSecretSharing is a self-check of the Shamir implementation: it
// round-trips an Age-style key through SplitSecret/ReconstructSecret with
// various share subsets and returns an error describing the first failure.
func TestShamirSecretSharing() error {
	// Test parameters
	threshold := 3
	totalShares := 5
	testSecret := "AGE-SECRET-KEY-1ABCDEF1234567890ABCDEF1234567890ABCDEF1234567890"

	// Create Shamir instance
	sss, err := NewShamirSecretSharing(threshold, totalShares)
	if err != nil {
		return fmt.Errorf("failed to create Shamir instance: %w", err)
	}

	// Split the secret
	shares, err := sss.SplitSecret(testSecret)
	if err != nil {
		return fmt.Errorf("failed to split secret: %w", err)
	}

	if len(shares) != totalShares {
		return fmt.Errorf("expected %d shares, got %d", totalShares, len(shares))
	}

	// Test reconstruction with minimum threshold
	minShares := shares[:threshold]
	reconstructed, err := sss.ReconstructSecret(minShares)
	if err != nil {
		return fmt.Errorf("failed to reconstruct secret: %w", err)
	}

	if reconstructed != testSecret {
		return fmt.Errorf("reconstructed secret doesn't match original")
	}

	// Test reconstruction with more than threshold
	extraShares := shares[:threshold+1]
	reconstructed2, err := sss.ReconstructSecret(extraShares)
	if err != nil {
		return fmt.Errorf("failed to reconstruct secret with extra shares: %w", err)
	}

	if reconstructed2 != testSecret {
		return fmt.Errorf("reconstructed secret with extra shares doesn't match original")
	}

	// Any threshold-sized subset must work, not only a prefix.
	reconstructed3, err := sss.ReconstructSecret(shares[totalShares-threshold:])
	if err != nil {
		return fmt.Errorf("failed to reconstruct secret from last %d shares: %w", threshold, err)
	}
	if reconstructed3 != testSecret {
		return fmt.Errorf("reconstructed secret from last shares doesn't match original")
	}

	// Duplicate shares must be skipped rather than break interpolation.
	withDuplicates := []Share{shares[0], shares[0], shares[1], shares[2]}
	reconstructed4, err := sss.ReconstructSecret(withDuplicates)
	if err != nil {
		return fmt.Errorf("failed to reconstruct secret with duplicate shares: %w", err)
	}
	if reconstructed4 != testSecret {
		return fmt.Errorf("reconstructed secret with duplicate shares doesn't match original")
	}

	// Leading zero bytes must survive the round trip.
	zeroPrefixed := "\x00\x00" + testSecret
	zeroShares, err := sss.SplitSecret(zeroPrefixed)
	if err != nil {
		return fmt.Errorf("failed to split zero-prefixed secret: %w", err)
	}
	reconstructed5, err := sss.ReconstructSecret(zeroShares[1 : 1+threshold])
	if err != nil {
		return fmt.Errorf("failed to reconstruct zero-prefixed secret: %w", err)
	}
	if reconstructed5 != zeroPrefixed {
		return fmt.Errorf("leading zero bytes were lost during reconstruction")
	}

	// Test that insufficient shares fail
	insufficientShares := shares[:threshold-1]
	_, err = sss.ReconstructSecret(insufficientShares)
	if err == nil {
		return fmt.Errorf("expected error with insufficient shares, but got none")
	}

	return nil
}