Files
BACKBEAT/pkg/sdk/client_test.go
2025-10-17 08:56:25 +11:00

574 lines
15 KiB
Go

package sdk
import (
"context"
"crypto/ed25519"
"crypto/rand"
"fmt"
"testing"
"time"
"log/slog"
"os"
"github.com/nats-io/nats.go"
)
// testCounter is the running suffix handed out by generateUniqueAgentID.
// It lives at package level so IDs stay distinct across every test and
// benchmark in the test binary. NOTE: not safe for concurrent use; no test in
// this file calls t.Parallel.
var testCounter int

// generateUniqueAgentID returns prefix followed by "-<n>", where n is the next
// counter value. Tests use it so that each client they build gets its own
// agent ID and does not hit expvar name conflicts.
func generateUniqueAgentID(prefix string) string {
	testCounter += 1
	return prefix + "-" + fmt.Sprint(testCounter)
}
// TestClient checks that NewClient builds a client from a basic config and
// that the client does not report a connection before Start is called.
func TestClient(t *testing.T) {
	cfg := DefaultConfig()
	cfg.ClusterID = "test-cluster"
	cfg.AgentID = generateUniqueAgentID("test-agent")
	cfg.NATSUrl = "nats://localhost:4222"

	c := NewClient(cfg)
	if c == nil {
		t.Fatal("Expected client to be created")
	}

	// Start has not been called, so Health must not claim a connection.
	if c.Health().Connected {
		t.Error("Expected client to be disconnected before start")
	}
}
// TestBeatCallbacks verifies that OnBeat and OnDownbeat accept non-nil
// callbacks and return an error when handed a nil callback.
func TestBeatCallbacks(t *testing.T) {
	cfg := DefaultConfig()
	cfg.ClusterID = "test-cluster"
	cfg.AgentID = generateUniqueAgentID("test-agent-callbacks")
	c := NewClient(cfg)

	// These flags are only written by the callbacks. Nothing in this test
	// drives a beat, so they are never asserted on.
	var beatCalled, downbeatCalled bool

	if err := c.OnBeat(func(BeatFrame) { beatCalled = true }); err != nil {
		t.Fatalf("Failed to register beat callback: %v", err)
	}
	if err := c.OnDownbeat(func(BeatFrame) { downbeatCalled = true }); err != nil {
		t.Fatalf("Failed to register downbeat callback: %v", err)
	}

	// Both registration methods must refuse a nil callback.
	if c.OnBeat(nil) == nil {
		t.Error("Expected error when registering nil beat callback")
	}
	if c.OnDownbeat(nil) == nil {
		t.Error("Expected error when registering nil downbeat callback")
	}

	// Reference the flags so the compiler does not reject them as unused.
	_, _ = beatCalled, downbeatCalled
}
// TestStatusClaim exercises validateStatusClaim directly. It checks that a
// fully populated claim passes, and that an out-of-range state, progress, or
// beats_left value is rejected.
func TestStatusClaim(t *testing.T) {
	_, signingKey, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("Failed to generate signing key: %v", err)
	}

	cfg := DefaultConfig()
	cfg.ClusterID = "test-cluster"
	cfg.AgentID = generateUniqueAgentID("test-agent")
	cfg.SigningKey = signingKey
	c := NewClient(cfg).(*client)

	claim := StatusClaim{
		State:     "executing",
		BeatsLeft: 5,
		Progress:  0.5,
		Notes:     "Test status",
	}

	// Seed beat state directly; validation itself needs no NATS connection.
	c.currentBeat = 1
	c.currentHLC = "test-hlc"

	// AgentID is not set by the caller (presumably filled in on emission).
	if claim.AgentID != "" {
		t.Error("Expected AgentID to be empty before emission")
	}

	// Emission can't run without NATS, so fill the envelope fields by hand
	// and call the validator directly.
	claim.Type = "backbeat.statusclaim.v1"
	claim.AgentID = cfg.AgentID
	claim.TaskID = "test-task"
	claim.BeatIndex = 1
	claim.HLC = "test-hlc"

	if err := c.validateStatusClaim(&claim); err != nil {
		t.Errorf("Expected valid status claim to pass validation: %v", err)
	}

	// Each case breaks exactly one field of an otherwise valid claim.
	invalidCases := []struct {
		mutate func(*StatusClaim)
		msg    string
	}{
		{func(sc *StatusClaim) { sc.State = "invalid-state" }, "Expected invalid state to fail validation"},
		{func(sc *StatusClaim) { sc.Progress = 1.5 }, "Expected invalid progress to fail validation"},
		{func(sc *StatusClaim) { sc.BeatsLeft = -1 }, "Expected negative beats_left to fail validation"},
	}
	for _, tc := range invalidCases {
		broken := claim
		tc.mutate(&broken)
		if c.validateStatusClaim(&broken) == nil {
			t.Error(tc.msg)
		}
	}
}
// TestBeatBudget exercises WithBeatBudget. A function that finishes well
// inside its budget succeeds, a function that overruns its budget yields the
// budget-exceeded error, and zero or negative budgets are rejected.
func TestBeatBudget(t *testing.T) {
	config := DefaultConfig()
	config.ClusterID = "test-cluster"
	config.AgentID = generateUniqueAgentID("test-agent")
	c := NewClient(config).(*client)
	c.currentTempo = 120 // 120 BPM = 0.5 seconds per beat
	c.ctx = context.Background()

	// Within budget: 100ms is far below 2 beats (1 second).
	executed := false
	err := c.WithBeatBudget(2, func() error {
		executed = true
		time.Sleep(100 * time.Millisecond)
		return nil
	})
	if err != nil {
		t.Errorf("Expected function to complete successfully: %v", err)
	}
	if !executed {
		t.Error("Expected function to be executed")
	}

	// Over budget: 2s is far beyond 1 beat at 120 BPM (0.5s).
	timeoutErr := c.WithBeatBudget(1, func() error {
		time.Sleep(2 * time.Second)
		return nil
	})
	// The message check must only run on a non-nil error. Calling Error() on
	// a nil error would panic and abort the test.
	switch {
	case timeoutErr == nil:
		t.Error("Expected function to timeout")
	case timeoutErr.Error() != "beat budget of 1 beats exceeded":
		t.Errorf("Expected timeout error message, got: %v", timeoutErr)
	}

	// Non-positive budgets are invalid.
	if err := c.WithBeatBudget(0, func() error { return nil }); err == nil {
		t.Error("Expected error for zero beat budget")
	}
	if err := c.WithBeatBudget(-1, func() error { return nil }); err == nil {
		t.Error("Expected error for negative beat budget")
	}
}
// TestTempoTracking checks the default tempo and zero initial drift. It then
// injects tempo samples measured slightly below 120 BPM and expects
// GetTempoDrift to become non-zero.
func TestTempoTracking(t *testing.T) {
	cfg := DefaultConfig()
	cfg.ClusterID = "test-cluster"
	cfg.AgentID = generateUniqueAgentID("test-agent")
	c := NewClient(cfg).(*client)

	// A fresh client starts at 60 BPM with no drift.
	if tempo := c.GetCurrentTempo(); tempo != 60 {
		t.Errorf("Expected default tempo to be 60, got %d", tempo)
	}
	if drift := c.GetTempoDrift(); drift != 0 {
		t.Errorf("Expected initial tempo drift to be 0, got %v", drift)
	}

	// Switch to 120 BPM and record two measurements that both run a little slow.
	c.beatMutex.Lock()
	c.currentTempo = 120
	c.tempoHistory = append(c.tempoHistory,
		tempoSample{
			BeatIndex:    1,
			Tempo:        120,
			MeasuredTime: time.Now(),
			ActualBPM:    118.0,
		},
		tempoSample{
			BeatIndex:    2,
			Tempo:        120,
			MeasuredTime: time.Now().Add(500 * time.Millisecond),
			ActualBPM:    119.0,
		},
	)
	c.beatMutex.Unlock()

	if tempo := c.GetCurrentTempo(); tempo != 120 {
		t.Errorf("Expected current tempo to remain at 120 BPM, got %d", tempo)
	}

	// Measured BPM differs from the nominal 120, so drift must be non-zero.
	if c.GetTempoDrift() == 0 {
		t.Error("Expected non-zero tempo drift")
	}
}
// TestLegacyCompatibility checks the mapping between legacy (bar, beat)
// coordinates and absolute beat indices in both directions. The expected
// values assume four beats per bar.
func TestLegacyCompatibility(t *testing.T) {
	cfg := DefaultConfig()
	cfg.ClusterID = "test-cluster"
	cfg.AgentID = generateUniqueAgentID("test-agent")
	c := NewClient(cfg).(*client)

	// Forward: bar 2, beat 3 -> (2-1)*4 + 3 = 7.
	if got, want := c.ConvertLegacyBeat(2, 3), int64(7); got != want {
		t.Errorf("Expected beat index %d, got %d", want, got)
	}

	// Reverse: absolute beat 7 -> bar 2, beat 3.
	c.beatMutex.Lock()
	c.currentBeat = 7
	c.beatMutex.Unlock()
	if info := c.GetLegacyBeatInfo(); info.Bar != 2 || info.Beat != 3 {
		t.Errorf("Expected bar=2, beat=3, got bar=%d, beat=%d", info.Bar, info.Beat)
	}

	// Edge: the first beat of the first bar is index 1.
	if got := c.ConvertLegacyBeat(1, 1); got != 1 {
		t.Errorf("Expected beat index 1 for first beat, got %d", got)
	}

	// Edge: a zero beat index is reported as bar 1, beat 1.
	c.beatMutex.Lock()
	c.currentBeat = 0
	c.beatMutex.Unlock()
	if info := c.GetLegacyBeatInfo(); info.Bar != 1 || info.Beat != 1 {
		t.Errorf("Expected bar=1, beat=1 for zero beat, got bar=%d, beat=%d", info.Bar, info.Beat)
	}
}
// TestHealthStatus checks Health on a fresh client. It then checks that the
// report reflects simulated beat state and a recorded error.
func TestHealthStatus(t *testing.T) {
	cfg := DefaultConfig()
	cfg.ClusterID = "test-cluster"
	cfg.AgentID = generateUniqueAgentID("test-agent")
	c := NewClient(cfg).(*client)

	// Baseline: disconnected, no beats seen, default 60 BPM tempo.
	initial := c.Health()
	if initial.Connected {
		t.Error("Expected client to be disconnected initially")
	}
	if initial.LastBeat != 0 {
		t.Error("Expected last beat to be 0 initially")
	}
	if initial.CurrentTempo != 60 {
		t.Errorf("Expected default tempo 60, got %d", initial.CurrentTempo)
	}

	// Simulate activity: beat 10 at 90 BPM, last beat 100ms ago, one error.
	c.beatMutex.Lock()
	c.currentBeat = 10
	c.currentTempo = 90
	c.lastBeatTime = time.Now().Add(-100 * time.Millisecond)
	c.beatMutex.Unlock()
	c.addError("test error")

	after := c.Health()
	if after.LastBeat != 10 {
		t.Errorf("Expected last beat to be 10, got %d", after.LastBeat)
	}
	if after.CurrentTempo != 90 {
		t.Errorf("Expected current tempo to be 90, got %d", after.CurrentTempo)
	}
	if len(after.Errors) != 1 {
		t.Errorf("Expected 1 error, got %d", len(after.Errors))
	}
	// NOTE(review): presumably TimeDrift is derived from lastBeatTime — confirm.
	if after.TimeDrift <= 0 {
		t.Error("Expected positive time drift")
	}
}
// TestMetrics checks that a new client has metrics initialised. It also checks
// that the metrics snapshot exposes the expected set of keys.
func TestMetrics(t *testing.T) {
	config := DefaultConfig()
	config.ClusterID = "test-cluster"
	config.AgentID = generateUniqueAgentID("test-agent")
	c := NewClient(config).(*client)

	if c.metrics == nil {
		t.Fatal("Expected metrics to be initialized")
	}

	snapshot := c.metrics.GetMetricsSnapshot()
	if snapshot == nil {
		// Stop here: with no snapshot, every key check below would also fail
		// and bury the real cause under a cascade of errors.
		t.Fatal("Expected metrics snapshot to be available")
	}

	for _, key := range []string{
		"connection_status",
		"reconnect_count",
		"beats_received",
		"status_claims_emitted",
		"budgets_created",
		"total_errors",
	} {
		if _, exists := snapshot[key]; !exists {
			t.Errorf("Expected metric key '%s' to exist in snapshot", key)
		}
	}
}
// TestConfig checks DefaultConfig's default timing and reconnect settings.
// It also checks that NewClient accepts a config with a nil Logger and a
// fully customised config.
func TestConfig(t *testing.T) {
	// Default config values.
	config := DefaultConfig()
	if config.JitterTolerance != 50*time.Millisecond {
		t.Errorf("Expected default jitter tolerance 50ms, got %v", config.JitterTolerance)
	}
	if config.ReconnectDelay != 1*time.Second {
		t.Errorf("Expected default reconnect delay 1s, got %v", config.ReconnectDelay)
	}
	if config.MaxReconnects != -1 {
		t.Errorf("Expected default max reconnects -1, got %d", config.MaxReconnects)
	}

	// Every client built here gets its own agent ID, as the rest of this
	// file's tests do, to avoid expvar conflicts. Fixed IDs may collide when
	// the test binary runs this test more than once (e.g. go test -count=2).

	// A nil logger must still yield a client.
	config.Logger = nil
	config.AgentID = generateUniqueAgentID("test-agent")
	c := NewClient(config)
	if c == nil {
		t.Error("Expected client to be created even with nil logger")
	}

	// Fully customised config.
	_, signingKey, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("Failed to generate signing key: %v", err)
	}
	config.ClusterID = "custom-cluster"
	config.AgentID = generateUniqueAgentID("custom-agent")
	config.SigningKey = signingKey
	config.JitterTolerance = 100 * time.Millisecond
	config.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
	c = NewClient(config)
	if c == nil {
		t.Error("Expected client to be created with custom config")
	}
}
// TestBeatDurationCalculation checks that getBeatDuration returns 60s/BPM for
// the default, 120 and 30 BPM tempos. It also checks that a zero tempo falls
// back to the 60 BPM duration.
func TestBeatDurationCalculation(t *testing.T) {
	cfg := DefaultConfig()
	cfg.ClusterID = "test-cluster"
	cfg.AgentID = generateUniqueAgentID("test-agent")
	c := NewClient(cfg).(*client)

	// expect compares the current beat duration with want; label names the
	// tempo under test in the failure message.
	expect := func(want time.Duration, label string) {
		t.Helper()
		if got := c.getBeatDuration(); got != want {
			t.Errorf("Expected beat duration %v for %s, got %v", want, label, got)
		}
	}

	// Default tempo (60 BPM): one second per beat.
	expect(1000*time.Millisecond, "60 BPM")

	c.beatMutex.Lock()
	c.currentTempo = 120
	c.beatMutex.Unlock()
	expect(500*time.Millisecond, "120 BPM")

	c.beatMutex.Lock()
	c.currentTempo = 30
	c.beatMutex.Unlock()
	expect(2000*time.Millisecond, "30 BPM")

	// A zero tempo must not divide by zero; it falls back to 60 BPM.
	c.beatMutex.Lock()
	c.currentTempo = 0
	c.beatMutex.Unlock()
	expect(1000*time.Millisecond, "0 BPM (default 60)")
}
// BenchmarkBeatCallback measures the cost of running one registered beat
// callback through safeExecuteCallback. It then checks that the callback ran
// once per iteration.
func BenchmarkBeatCallback(b *testing.B) {
	config := DefaultConfig()
	config.ClusterID = "bench-cluster"
	// The testing package calls a benchmark function several times with
	// growing b.N, so a fixed agent ID would build clients with the same ID
	// again and again. Draw a fresh one each run to avoid expvar conflicts.
	config.AgentID = generateUniqueAgentID("bench-agent")
	c := NewClient(config).(*client)

	beatFrame := BeatFrame{
		Type:       "backbeat.beatframe.v1",
		ClusterID:  "bench-cluster",
		BeatIndex:  1,
		Downbeat:   false,
		Phase:      "test",
		HLC:        "test-hlc",
		DeadlineAt: time.Now().Add(time.Second),
		TempoBPM:   60,
		WindowID:   "test-window",
	}

	callbackCount := 0
	if err := c.OnBeat(func(BeatFrame) { callbackCount++ }); err != nil {
		b.Fatalf("Failed to register beat callback: %v", err)
	}
	// Look the callback up once so the timed loop measures only its execution.
	cb := c.beatCallbacks[0]

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		c.safeExecuteCallback(cb, beatFrame, "beat")
	}
	b.StopTimer()

	if callbackCount != b.N {
		b.Errorf("Expected callback to be called %d times, got %d", b.N, callbackCount)
	}
}
// BenchmarkStatusClaimValidation measures validateStatusClaim on a fully
// populated, valid claim. The benchmark fails if validation ever returns an
// error.
func BenchmarkStatusClaimValidation(b *testing.B) {
	config := DefaultConfig()
	config.ClusterID = "bench-cluster"
	// The benchmark function is invoked repeatedly with growing b.N. A unique
	// agent ID per run avoids building clients with the same ID, which could
	// cause expvar conflicts.
	config.AgentID = generateUniqueAgentID("bench-agent")
	c := NewClient(config).(*client)

	claim := StatusClaim{
		Type:      "backbeat.statusclaim.v1",
		AgentID:   config.AgentID, // keep the claim's agent in sync with the client
		TaskID:    "bench-task",
		BeatIndex: 1,
		State:     "executing",
		BeatsLeft: 5,
		Progress:  0.5,
		Notes:     "Benchmark test",
		HLC:       "bench-hlc",
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if err := c.validateStatusClaim(&claim); err != nil {
			b.Fatal(err)
		}
	}
}
// setupTestNATSServer connects to a NATS server at nats.DefaultURL and returns
// the connection. It is not a mock; it needs a real server. If the connection
// fails, the calling test is skipped. t.Skipf stops the test goroutine, so
// this function never returns nil. A possible improvement is to start an
// embedded NATS server here so integration tests need no external process.
func setupTestNATSServer(t *testing.T) *nats.Conn {
	t.Helper() // attribute the skip message to the calling test
	nc, err := nats.Connect(nats.DefaultURL)
	if err != nil {
		t.Skipf("NATS server not available: %v", err)
	}
	return nc
}
// TestIntegrationWithNATS runs a Start/Stop cycle against a live NATS server
// at nats.DefaultURL. It checks that Health reports connected after Start and
// disconnected after Stop.
func TestIntegrationWithNATS(t *testing.T) {
	probe := setupTestNATSServer(t)
	if probe == nil {
		return // Skipped
	}
	defer probe.Close()

	cfg := DefaultConfig()
	cfg.ClusterID = "integration-test"
	cfg.AgentID = generateUniqueAgentID("test-agent")
	cfg.NATSUrl = nats.DefaultURL
	c := NewClient(cfg)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Start, then confirm the connection is visible through Health.
	if err := c.Start(ctx); err != nil {
		t.Fatalf("Failed to start client: %v", err)
	}
	if !c.Health().Connected {
		t.Error("Expected client to be connected after start")
	}

	// Stop, then confirm the connection is gone.
	if err := c.Stop(); err != nil {
		t.Errorf("Failed to stop client: %v", err)
	}
	if c.Health().Connected {
		t.Error("Expected client to be disconnected after stop")
	}
}