backbeat: add module sources
This commit is contained in:
573
pkg/sdk/client_test.go
Normal file
573
pkg/sdk/client_test.go
Normal file
@@ -0,0 +1,573 @@
|
||||
package sdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
)
|
||||
|
||||
var testCounter int
|
||||
|
||||
// generateUniqueAgentID generates unique agent IDs for tests to avoid expvar conflicts
|
||||
func generateUniqueAgentID(prefix string) string {
|
||||
testCounter++
|
||||
return fmt.Sprintf("%s-%d", prefix, testCounter)
|
||||
}
|
||||
|
||||
// TestClient tests basic client creation and configuration
|
||||
func TestClient(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.ClusterID = "test-cluster"
|
||||
config.AgentID = generateUniqueAgentID("test-agent")
|
||||
config.NATSUrl = "nats://localhost:4222"
|
||||
|
||||
client := NewClient(config)
|
||||
if client == nil {
|
||||
t.Fatal("Expected client to be created")
|
||||
}
|
||||
|
||||
// Test health before start
|
||||
health := client.Health()
|
||||
if health.Connected {
|
||||
t.Error("Expected client to be disconnected before start")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBeatCallbacks tests beat and downbeat callback registration
|
||||
func TestBeatCallbacks(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.ClusterID = "test-cluster"
|
||||
config.AgentID = generateUniqueAgentID("test-agent-callbacks")
|
||||
|
||||
client := NewClient(config)
|
||||
|
||||
var beatCalled, downbeatCalled bool
|
||||
|
||||
// Register callbacks
|
||||
err := client.OnBeat(func(beat BeatFrame) {
|
||||
beatCalled = true
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to register beat callback: %v", err)
|
||||
}
|
||||
|
||||
err = client.OnDownbeat(func(beat BeatFrame) {
|
||||
downbeatCalled = true
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to register downbeat callback: %v", err)
|
||||
}
|
||||
|
||||
// Test nil callback rejection
|
||||
err = client.OnBeat(nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error when registering nil beat callback")
|
||||
}
|
||||
|
||||
err = client.OnDownbeat(nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error when registering nil downbeat callback")
|
||||
}
|
||||
|
||||
// Use variables to prevent unused warnings
|
||||
_ = beatCalled
|
||||
_ = downbeatCalled
|
||||
}
|
||||
|
||||
// TestStatusClaim tests status claim validation and emission
|
||||
func TestStatusClaim(t *testing.T) {
|
||||
_, signingKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate signing key: %v", err)
|
||||
}
|
||||
|
||||
config := DefaultConfig()
|
||||
config.ClusterID = "test-cluster"
|
||||
config.AgentID = generateUniqueAgentID("test-agent")
|
||||
config.SigningKey = signingKey
|
||||
|
||||
client := NewClient(config).(*client)
|
||||
|
||||
// Test valid status claim
|
||||
claim := StatusClaim{
|
||||
State: "executing",
|
||||
BeatsLeft: 5,
|
||||
Progress: 0.5,
|
||||
Notes: "Test status",
|
||||
}
|
||||
|
||||
// Test validation without connection (should work for validation)
|
||||
client.currentBeat = 1
|
||||
client.currentHLC = "test-hlc"
|
||||
|
||||
// Test auto-population
|
||||
if claim.AgentID != "" {
|
||||
t.Error("Expected AgentID to be empty before emission")
|
||||
}
|
||||
|
||||
// Since we can't actually emit without NATS connection, test validation directly
|
||||
claim.Type = "backbeat.statusclaim.v1"
|
||||
claim.AgentID = config.AgentID
|
||||
claim.TaskID = "test-task"
|
||||
claim.BeatIndex = 1
|
||||
claim.HLC = "test-hlc"
|
||||
|
||||
err = client.validateStatusClaim(&claim)
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid status claim to pass validation: %v", err)
|
||||
}
|
||||
|
||||
// Test invalid states
|
||||
invalidClaim := claim
|
||||
invalidClaim.State = "invalid-state"
|
||||
err = client.validateStatusClaim(&invalidClaim)
|
||||
if err == nil {
|
||||
t.Error("Expected invalid state to fail validation")
|
||||
}
|
||||
|
||||
// Test invalid progress
|
||||
invalidClaim = claim
|
||||
invalidClaim.Progress = 1.5
|
||||
err = client.validateStatusClaim(&invalidClaim)
|
||||
if err == nil {
|
||||
t.Error("Expected invalid progress to fail validation")
|
||||
}
|
||||
|
||||
// Test negative beats left
|
||||
invalidClaim = claim
|
||||
invalidClaim.BeatsLeft = -1
|
||||
err = client.validateStatusClaim(&invalidClaim)
|
||||
if err == nil {
|
||||
t.Error("Expected negative beats_left to fail validation")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBeatBudget tests beat budget functionality
|
||||
func TestBeatBudget(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.ClusterID = "test-cluster"
|
||||
config.AgentID = generateUniqueAgentID("test-agent")
|
||||
|
||||
client := NewClient(config).(*client)
|
||||
client.currentTempo = 120 // 120 BPM = 0.5 seconds per beat
|
||||
|
||||
ctx := context.Background()
|
||||
client.ctx = ctx
|
||||
|
||||
// Test successful execution within budget
|
||||
executed := false
|
||||
err := client.WithBeatBudget(2, func() error {
|
||||
executed = true
|
||||
time.Sleep(100 * time.Millisecond) // Much less than 2 beats (1 second)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected function to complete successfully: %v", err)
|
||||
}
|
||||
if !executed {
|
||||
t.Error("Expected function to be executed")
|
||||
}
|
||||
|
||||
// Test timeout (need to be careful with timing)
|
||||
timeoutErr := client.WithBeatBudget(1, func() error {
|
||||
time.Sleep(2 * time.Second) // More than 1 beat at 120 BPM (0.5s)
|
||||
return nil
|
||||
})
|
||||
|
||||
if timeoutErr == nil {
|
||||
t.Error("Expected function to timeout")
|
||||
}
|
||||
if timeoutErr.Error() != "beat budget of 1 beats exceeded" {
|
||||
t.Errorf("Expected timeout error message, got: %v", timeoutErr)
|
||||
}
|
||||
|
||||
// Test invalid budget
|
||||
err = client.WithBeatBudget(0, func() error { return nil })
|
||||
if err == nil {
|
||||
t.Error("Expected error for zero beat budget")
|
||||
}
|
||||
|
||||
err = client.WithBeatBudget(-1, func() error { return nil })
|
||||
if err == nil {
|
||||
t.Error("Expected error for negative beat budget")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTempoTracking tests tempo tracking and drift calculation
|
||||
func TestTempoTracking(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.ClusterID = "test-cluster"
|
||||
config.AgentID = generateUniqueAgentID("test-agent")
|
||||
|
||||
client := NewClient(config).(*client)
|
||||
|
||||
// Test initial values
|
||||
if client.GetCurrentTempo() != 60 {
|
||||
t.Errorf("Expected default tempo to be 60, got %d", client.GetCurrentTempo())
|
||||
}
|
||||
|
||||
if client.GetTempoDrift() != 0 {
|
||||
t.Errorf("Expected initial tempo drift to be 0, got %v", client.GetTempoDrift())
|
||||
}
|
||||
|
||||
// Simulate tempo changes
|
||||
client.beatMutex.Lock()
|
||||
client.currentTempo = 120
|
||||
client.tempoHistory = append(client.tempoHistory, tempoSample{
|
||||
BeatIndex: 1,
|
||||
Tempo: 120,
|
||||
MeasuredTime: time.Now(),
|
||||
ActualBPM: 118.0, // Slightly slower than expected
|
||||
})
|
||||
client.tempoHistory = append(client.tempoHistory, tempoSample{
|
||||
BeatIndex: 2,
|
||||
Tempo: 120,
|
||||
MeasuredTime: time.Now().Add(500 * time.Millisecond),
|
||||
ActualBPM: 119.0, // Still slightly slower
|
||||
})
|
||||
client.beatMutex.Unlock()
|
||||
|
||||
if client.GetCurrentTempo() != 120 {
|
||||
t.Errorf("Expected current tempo to remain at 120 BPM, got %d", client.GetCurrentTempo())
|
||||
}
|
||||
|
||||
// Test drift calculation (should be non-zero due to difference between 120 and measured BPM)
|
||||
drift := client.GetTempoDrift()
|
||||
if drift == 0 {
|
||||
t.Error("Expected non-zero tempo drift")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLegacyCompatibility tests legacy beat conversion
|
||||
func TestLegacyCompatibility(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.ClusterID = "test-cluster"
|
||||
config.AgentID = generateUniqueAgentID("test-agent")
|
||||
|
||||
client := NewClient(config).(*client)
|
||||
|
||||
// Test legacy beat conversion
|
||||
beatIndex := client.ConvertLegacyBeat(2, 3) // Bar 2, Beat 3
|
||||
expectedBeatIndex := int64(7) // (2-1)*4 + 3 = 7
|
||||
if beatIndex != expectedBeatIndex {
|
||||
t.Errorf("Expected beat index %d, got %d", expectedBeatIndex, beatIndex)
|
||||
}
|
||||
|
||||
// Test reverse conversion
|
||||
client.beatMutex.Lock()
|
||||
client.currentBeat = 7
|
||||
client.beatMutex.Unlock()
|
||||
|
||||
legacyInfo := client.GetLegacyBeatInfo()
|
||||
if legacyInfo.Bar != 2 || legacyInfo.Beat != 3 {
|
||||
t.Errorf("Expected bar=2, beat=3, got bar=%d, beat=%d", legacyInfo.Bar, legacyInfo.Beat)
|
||||
}
|
||||
|
||||
// Test edge cases
|
||||
beatIndex = client.ConvertLegacyBeat(1, 1) // First beat
|
||||
if beatIndex != 1 {
|
||||
t.Errorf("Expected beat index 1 for first beat, got %d", beatIndex)
|
||||
}
|
||||
|
||||
client.beatMutex.Lock()
|
||||
client.currentBeat = 0 // Edge case
|
||||
client.beatMutex.Unlock()
|
||||
|
||||
legacyInfo = client.GetLegacyBeatInfo()
|
||||
if legacyInfo.Bar != 1 || legacyInfo.Beat != 1 {
|
||||
t.Errorf("Expected bar=1, beat=1 for zero beat, got bar=%d, beat=%d", legacyInfo.Bar, legacyInfo.Beat)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHealthStatus tests health status reporting
|
||||
func TestHealthStatus(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.ClusterID = "test-cluster"
|
||||
config.AgentID = generateUniqueAgentID("test-agent")
|
||||
|
||||
client := NewClient(config).(*client)
|
||||
|
||||
// Test initial health
|
||||
health := client.Health()
|
||||
if health.Connected {
|
||||
t.Error("Expected client to be disconnected initially")
|
||||
}
|
||||
if health.LastBeat != 0 {
|
||||
t.Error("Expected last beat to be 0 initially")
|
||||
}
|
||||
if health.CurrentTempo != 60 {
|
||||
t.Errorf("Expected default tempo 60, got %d", health.CurrentTempo)
|
||||
}
|
||||
|
||||
// Simulate some activity
|
||||
client.beatMutex.Lock()
|
||||
client.currentBeat = 10
|
||||
client.currentTempo = 90
|
||||
client.lastBeatTime = time.Now().Add(-100 * time.Millisecond)
|
||||
client.beatMutex.Unlock()
|
||||
|
||||
client.addError("test error")
|
||||
|
||||
health = client.Health()
|
||||
if health.LastBeat != 10 {
|
||||
t.Errorf("Expected last beat to be 10, got %d", health.LastBeat)
|
||||
}
|
||||
if health.CurrentTempo != 90 {
|
||||
t.Errorf("Expected current tempo to be 90, got %d", health.CurrentTempo)
|
||||
}
|
||||
if len(health.Errors) != 1 {
|
||||
t.Errorf("Expected 1 error, got %d", len(health.Errors))
|
||||
}
|
||||
if health.TimeDrift <= 0 {
|
||||
t.Error("Expected positive time drift")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMetrics tests metrics integration
|
||||
func TestMetrics(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.ClusterID = "test-cluster"
|
||||
config.AgentID = generateUniqueAgentID("test-agent")
|
||||
|
||||
client := NewClient(config).(*client)
|
||||
|
||||
if client.metrics == nil {
|
||||
t.Fatal("Expected metrics to be initialized")
|
||||
}
|
||||
|
||||
// Test metrics snapshot
|
||||
snapshot := client.metrics.GetMetricsSnapshot()
|
||||
if snapshot == nil {
|
||||
t.Error("Expected metrics snapshot to be available")
|
||||
}
|
||||
|
||||
// Check for expected metric keys
|
||||
expectedKeys := []string{
|
||||
"connection_status",
|
||||
"reconnect_count",
|
||||
"beats_received",
|
||||
"status_claims_emitted",
|
||||
"budgets_created",
|
||||
"total_errors",
|
||||
}
|
||||
|
||||
for _, key := range expectedKeys {
|
||||
if _, exists := snapshot[key]; !exists {
|
||||
t.Errorf("Expected metric key '%s' to exist in snapshot", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfig tests configuration validation and defaults
|
||||
func TestConfig(t *testing.T) {
|
||||
// Test default config
|
||||
config := DefaultConfig()
|
||||
if config.JitterTolerance != 50*time.Millisecond {
|
||||
t.Errorf("Expected default jitter tolerance 50ms, got %v", config.JitterTolerance)
|
||||
}
|
||||
if config.ReconnectDelay != 1*time.Second {
|
||||
t.Errorf("Expected default reconnect delay 1s, got %v", config.ReconnectDelay)
|
||||
}
|
||||
if config.MaxReconnects != -1 {
|
||||
t.Errorf("Expected default max reconnects -1, got %d", config.MaxReconnects)
|
||||
}
|
||||
|
||||
// Test logger initialization
|
||||
config.Logger = nil
|
||||
client := NewClient(config)
|
||||
if client == nil {
|
||||
t.Error("Expected client to be created even with nil logger")
|
||||
}
|
||||
|
||||
// Test with custom config
|
||||
_, signingKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate signing key: %v", err)
|
||||
}
|
||||
|
||||
config.ClusterID = "custom-cluster"
|
||||
config.AgentID = "custom-agent"
|
||||
config.SigningKey = signingKey
|
||||
config.JitterTolerance = 100 * time.Millisecond
|
||||
config.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
|
||||
client = NewClient(config)
|
||||
if client == nil {
|
||||
t.Error("Expected client to be created with custom config")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBeatDurationCalculation tests beat duration calculation
|
||||
func TestBeatDurationCalculation(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.ClusterID = "test-cluster"
|
||||
config.AgentID = generateUniqueAgentID("test-agent")
|
||||
|
||||
client := NewClient(config).(*client)
|
||||
|
||||
// Test default 60 BPM (1 second per beat)
|
||||
duration := client.getBeatDuration()
|
||||
expected := 1000 * time.Millisecond
|
||||
if duration != expected {
|
||||
t.Errorf("Expected beat duration %v for 60 BPM, got %v", expected, duration)
|
||||
}
|
||||
|
||||
// Test 120 BPM (0.5 seconds per beat)
|
||||
client.beatMutex.Lock()
|
||||
client.currentTempo = 120
|
||||
client.beatMutex.Unlock()
|
||||
|
||||
duration = client.getBeatDuration()
|
||||
expected = 500 * time.Millisecond
|
||||
if duration != expected {
|
||||
t.Errorf("Expected beat duration %v for 120 BPM, got %v", expected, duration)
|
||||
}
|
||||
|
||||
// Test 30 BPM (2 seconds per beat)
|
||||
client.beatMutex.Lock()
|
||||
client.currentTempo = 30
|
||||
client.beatMutex.Unlock()
|
||||
|
||||
duration = client.getBeatDuration()
|
||||
expected = 2000 * time.Millisecond
|
||||
if duration != expected {
|
||||
t.Errorf("Expected beat duration %v for 30 BPM, got %v", expected, duration)
|
||||
}
|
||||
|
||||
// Test edge case: zero tempo (should default to 60 BPM)
|
||||
client.beatMutex.Lock()
|
||||
client.currentTempo = 0
|
||||
client.beatMutex.Unlock()
|
||||
|
||||
duration = client.getBeatDuration()
|
||||
expected = 1000 * time.Millisecond
|
||||
if duration != expected {
|
||||
t.Errorf("Expected beat duration %v for 0 BPM (default 60), got %v", expected, duration)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkBeatCallback benchmarks beat callback execution
|
||||
func BenchmarkBeatCallback(b *testing.B) {
|
||||
config := DefaultConfig()
|
||||
config.ClusterID = "bench-cluster"
|
||||
config.AgentID = "bench-agent"
|
||||
|
||||
client := NewClient(config).(*client)
|
||||
|
||||
beatFrame := BeatFrame{
|
||||
Type: "backbeat.beatframe.v1",
|
||||
ClusterID: "bench-cluster",
|
||||
BeatIndex: 1,
|
||||
Downbeat: false,
|
||||
Phase: "test",
|
||||
HLC: "test-hlc",
|
||||
DeadlineAt: time.Now().Add(time.Second),
|
||||
TempoBPM: 60,
|
||||
WindowID: "test-window",
|
||||
}
|
||||
|
||||
callbackCount := 0
|
||||
client.OnBeat(func(beat BeatFrame) {
|
||||
callbackCount++
|
||||
})
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
client.safeExecuteCallback(client.beatCallbacks[0], beatFrame, "beat")
|
||||
}
|
||||
|
||||
if callbackCount != b.N {
|
||||
b.Errorf("Expected callback to be called %d times, got %d", b.N, callbackCount)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkStatusClaimValidation benchmarks status claim validation
|
||||
func BenchmarkStatusClaimValidation(b *testing.B) {
|
||||
config := DefaultConfig()
|
||||
config.ClusterID = "bench-cluster"
|
||||
config.AgentID = "bench-agent"
|
||||
|
||||
client := NewClient(config).(*client)
|
||||
|
||||
claim := StatusClaim{
|
||||
Type: "backbeat.statusclaim.v1",
|
||||
AgentID: "bench-agent",
|
||||
TaskID: "bench-task",
|
||||
BeatIndex: 1,
|
||||
State: "executing",
|
||||
BeatsLeft: 5,
|
||||
Progress: 0.5,
|
||||
Notes: "Benchmark test",
|
||||
HLC: "bench-hlc",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
err := client.validateStatusClaim(&claim)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Mock NATS server for integration tests (if needed)
|
||||
func setupTestNATSServer(t *testing.T) *nats.Conn {
|
||||
// This would start an embedded NATS server for testing
|
||||
// For now, we'll skip tests that require NATS if it's not available
|
||||
nc, err := nats.Connect(nats.DefaultURL)
|
||||
if err != nil {
|
||||
t.Skipf("NATS server not available: %v", err)
|
||||
return nil
|
||||
}
|
||||
return nc
|
||||
}
|
||||
|
||||
func TestIntegrationWithNATS(t *testing.T) {
|
||||
nc := setupTestNATSServer(t)
|
||||
if nc == nil {
|
||||
return // Skipped
|
||||
}
|
||||
defer nc.Close()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.ClusterID = "integration-test"
|
||||
config.AgentID = generateUniqueAgentID("test-agent")
|
||||
config.NATSUrl = nats.DefaultURL
|
||||
|
||||
client := NewClient(config)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Test start/stop cycle
|
||||
err := client.Start(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start client: %v", err)
|
||||
}
|
||||
|
||||
// Check health after start
|
||||
health := client.Health()
|
||||
if !health.Connected {
|
||||
t.Error("Expected client to be connected after start")
|
||||
}
|
||||
|
||||
// Test stop
|
||||
err = client.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("Failed to stop client: %v", err)
|
||||
}
|
||||
|
||||
// Check health after stop
|
||||
health = client.Health()
|
||||
if health.Connected {
|
||||
t.Error("Expected client to be disconnected after stop")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user