574 lines
15 KiB
Go
574 lines
15 KiB
Go
package sdk
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ed25519"
|
|
"crypto/rand"
|
|
"fmt"
|
|
"testing"
|
|
"time"
|
|
|
|
"log/slog"
|
|
"os"
|
|
|
|
"github.com/nats-io/nats.go"
|
|
)
|
|
|
|
// testCounter is advanced by generateUniqueAgentID so that each client built
// in this file receives a distinct agent ID. It is not synchronized; none of
// the tests in this file call t.Parallel.
var testCounter int

// generateUniqueAgentID returns prefix, a dash, and the next counter value
// (starting at 1), e.g. "test-agent-3". Tests use it so that clients do not
// clash on expvar names.
func generateUniqueAgentID(prefix string) string {
	testCounter++
	next := testCounter
	// Sprint inserts no spaces here because each adjacent pair has a string.
	return fmt.Sprint(prefix, "-", next)
}
|
|
|
|
// TestClient tests basic client creation and configuration
|
|
func TestClient(t *testing.T) {
|
|
config := DefaultConfig()
|
|
config.ClusterID = "test-cluster"
|
|
config.AgentID = generateUniqueAgentID("test-agent")
|
|
config.NATSUrl = "nats://localhost:4222"
|
|
|
|
client := NewClient(config)
|
|
if client == nil {
|
|
t.Fatal("Expected client to be created")
|
|
}
|
|
|
|
// Test health before start
|
|
health := client.Health()
|
|
if health.Connected {
|
|
t.Error("Expected client to be disconnected before start")
|
|
}
|
|
}
|
|
|
|
// TestBeatCallbacks tests beat and downbeat callback registration
|
|
func TestBeatCallbacks(t *testing.T) {
|
|
config := DefaultConfig()
|
|
config.ClusterID = "test-cluster"
|
|
config.AgentID = generateUniqueAgentID("test-agent-callbacks")
|
|
|
|
client := NewClient(config)
|
|
|
|
var beatCalled, downbeatCalled bool
|
|
|
|
// Register callbacks
|
|
err := client.OnBeat(func(beat BeatFrame) {
|
|
beatCalled = true
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to register beat callback: %v", err)
|
|
}
|
|
|
|
err = client.OnDownbeat(func(beat BeatFrame) {
|
|
downbeatCalled = true
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to register downbeat callback: %v", err)
|
|
}
|
|
|
|
// Test nil callback rejection
|
|
err = client.OnBeat(nil)
|
|
if err == nil {
|
|
t.Error("Expected error when registering nil beat callback")
|
|
}
|
|
|
|
err = client.OnDownbeat(nil)
|
|
if err == nil {
|
|
t.Error("Expected error when registering nil downbeat callback")
|
|
}
|
|
|
|
// Use variables to prevent unused warnings
|
|
_ = beatCalled
|
|
_ = downbeatCalled
|
|
}
|
|
|
|
// TestStatusClaim tests status claim validation and emission
|
|
func TestStatusClaim(t *testing.T) {
|
|
_, signingKey, err := ed25519.GenerateKey(rand.Reader)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate signing key: %v", err)
|
|
}
|
|
|
|
config := DefaultConfig()
|
|
config.ClusterID = "test-cluster"
|
|
config.AgentID = generateUniqueAgentID("test-agent")
|
|
config.SigningKey = signingKey
|
|
|
|
client := NewClient(config).(*client)
|
|
|
|
// Test valid status claim
|
|
claim := StatusClaim{
|
|
State: "executing",
|
|
BeatsLeft: 5,
|
|
Progress: 0.5,
|
|
Notes: "Test status",
|
|
}
|
|
|
|
// Test validation without connection (should work for validation)
|
|
client.currentBeat = 1
|
|
client.currentHLC = "test-hlc"
|
|
|
|
// Test auto-population
|
|
if claim.AgentID != "" {
|
|
t.Error("Expected AgentID to be empty before emission")
|
|
}
|
|
|
|
// Since we can't actually emit without NATS connection, test validation directly
|
|
claim.Type = "backbeat.statusclaim.v1"
|
|
claim.AgentID = config.AgentID
|
|
claim.TaskID = "test-task"
|
|
claim.BeatIndex = 1
|
|
claim.HLC = "test-hlc"
|
|
|
|
err = client.validateStatusClaim(&claim)
|
|
if err != nil {
|
|
t.Errorf("Expected valid status claim to pass validation: %v", err)
|
|
}
|
|
|
|
// Test invalid states
|
|
invalidClaim := claim
|
|
invalidClaim.State = "invalid-state"
|
|
err = client.validateStatusClaim(&invalidClaim)
|
|
if err == nil {
|
|
t.Error("Expected invalid state to fail validation")
|
|
}
|
|
|
|
// Test invalid progress
|
|
invalidClaim = claim
|
|
invalidClaim.Progress = 1.5
|
|
err = client.validateStatusClaim(&invalidClaim)
|
|
if err == nil {
|
|
t.Error("Expected invalid progress to fail validation")
|
|
}
|
|
|
|
// Test negative beats left
|
|
invalidClaim = claim
|
|
invalidClaim.BeatsLeft = -1
|
|
err = client.validateStatusClaim(&invalidClaim)
|
|
if err == nil {
|
|
t.Error("Expected negative beats_left to fail validation")
|
|
}
|
|
}
|
|
|
|
// TestBeatBudget tests beat budget functionality
|
|
func TestBeatBudget(t *testing.T) {
|
|
config := DefaultConfig()
|
|
config.ClusterID = "test-cluster"
|
|
config.AgentID = generateUniqueAgentID("test-agent")
|
|
|
|
client := NewClient(config).(*client)
|
|
client.currentTempo = 120 // 120 BPM = 0.5 seconds per beat
|
|
|
|
ctx := context.Background()
|
|
client.ctx = ctx
|
|
|
|
// Test successful execution within budget
|
|
executed := false
|
|
err := client.WithBeatBudget(2, func() error {
|
|
executed = true
|
|
time.Sleep(100 * time.Millisecond) // Much less than 2 beats (1 second)
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
t.Errorf("Expected function to complete successfully: %v", err)
|
|
}
|
|
if !executed {
|
|
t.Error("Expected function to be executed")
|
|
}
|
|
|
|
// Test timeout (need to be careful with timing)
|
|
timeoutErr := client.WithBeatBudget(1, func() error {
|
|
time.Sleep(2 * time.Second) // More than 1 beat at 120 BPM (0.5s)
|
|
return nil
|
|
})
|
|
|
|
if timeoutErr == nil {
|
|
t.Error("Expected function to timeout")
|
|
}
|
|
if timeoutErr.Error() != "beat budget of 1 beats exceeded" {
|
|
t.Errorf("Expected timeout error message, got: %v", timeoutErr)
|
|
}
|
|
|
|
// Test invalid budget
|
|
err = client.WithBeatBudget(0, func() error { return nil })
|
|
if err == nil {
|
|
t.Error("Expected error for zero beat budget")
|
|
}
|
|
|
|
err = client.WithBeatBudget(-1, func() error { return nil })
|
|
if err == nil {
|
|
t.Error("Expected error for negative beat budget")
|
|
}
|
|
}
|
|
|
|
// TestTempoTracking tests tempo tracking and drift calculation
|
|
func TestTempoTracking(t *testing.T) {
|
|
config := DefaultConfig()
|
|
config.ClusterID = "test-cluster"
|
|
config.AgentID = generateUniqueAgentID("test-agent")
|
|
|
|
client := NewClient(config).(*client)
|
|
|
|
// Test initial values
|
|
if client.GetCurrentTempo() != 60 {
|
|
t.Errorf("Expected default tempo to be 60, got %d", client.GetCurrentTempo())
|
|
}
|
|
|
|
if client.GetTempoDrift() != 0 {
|
|
t.Errorf("Expected initial tempo drift to be 0, got %v", client.GetTempoDrift())
|
|
}
|
|
|
|
// Simulate tempo changes
|
|
client.beatMutex.Lock()
|
|
client.currentTempo = 120
|
|
client.tempoHistory = append(client.tempoHistory, tempoSample{
|
|
BeatIndex: 1,
|
|
Tempo: 120,
|
|
MeasuredTime: time.Now(),
|
|
ActualBPM: 118.0, // Slightly slower than expected
|
|
})
|
|
client.tempoHistory = append(client.tempoHistory, tempoSample{
|
|
BeatIndex: 2,
|
|
Tempo: 120,
|
|
MeasuredTime: time.Now().Add(500 * time.Millisecond),
|
|
ActualBPM: 119.0, // Still slightly slower
|
|
})
|
|
client.beatMutex.Unlock()
|
|
|
|
if client.GetCurrentTempo() != 120 {
|
|
t.Errorf("Expected current tempo to remain at 120 BPM, got %d", client.GetCurrentTempo())
|
|
}
|
|
|
|
// Test drift calculation (should be non-zero due to difference between 120 and measured BPM)
|
|
drift := client.GetTempoDrift()
|
|
if drift == 0 {
|
|
t.Error("Expected non-zero tempo drift")
|
|
}
|
|
}
|
|
|
|
// TestLegacyCompatibility tests legacy beat conversion
|
|
func TestLegacyCompatibility(t *testing.T) {
|
|
config := DefaultConfig()
|
|
config.ClusterID = "test-cluster"
|
|
config.AgentID = generateUniqueAgentID("test-agent")
|
|
|
|
client := NewClient(config).(*client)
|
|
|
|
// Test legacy beat conversion
|
|
beatIndex := client.ConvertLegacyBeat(2, 3) // Bar 2, Beat 3
|
|
expectedBeatIndex := int64(7) // (2-1)*4 + 3 = 7
|
|
if beatIndex != expectedBeatIndex {
|
|
t.Errorf("Expected beat index %d, got %d", expectedBeatIndex, beatIndex)
|
|
}
|
|
|
|
// Test reverse conversion
|
|
client.beatMutex.Lock()
|
|
client.currentBeat = 7
|
|
client.beatMutex.Unlock()
|
|
|
|
legacyInfo := client.GetLegacyBeatInfo()
|
|
if legacyInfo.Bar != 2 || legacyInfo.Beat != 3 {
|
|
t.Errorf("Expected bar=2, beat=3, got bar=%d, beat=%d", legacyInfo.Bar, legacyInfo.Beat)
|
|
}
|
|
|
|
// Test edge cases
|
|
beatIndex = client.ConvertLegacyBeat(1, 1) // First beat
|
|
if beatIndex != 1 {
|
|
t.Errorf("Expected beat index 1 for first beat, got %d", beatIndex)
|
|
}
|
|
|
|
client.beatMutex.Lock()
|
|
client.currentBeat = 0 // Edge case
|
|
client.beatMutex.Unlock()
|
|
|
|
legacyInfo = client.GetLegacyBeatInfo()
|
|
if legacyInfo.Bar != 1 || legacyInfo.Beat != 1 {
|
|
t.Errorf("Expected bar=1, beat=1 for zero beat, got bar=%d, beat=%d", legacyInfo.Bar, legacyInfo.Beat)
|
|
}
|
|
}
|
|
|
|
// TestHealthStatus tests health status reporting
|
|
func TestHealthStatus(t *testing.T) {
|
|
config := DefaultConfig()
|
|
config.ClusterID = "test-cluster"
|
|
config.AgentID = generateUniqueAgentID("test-agent")
|
|
|
|
client := NewClient(config).(*client)
|
|
|
|
// Test initial health
|
|
health := client.Health()
|
|
if health.Connected {
|
|
t.Error("Expected client to be disconnected initially")
|
|
}
|
|
if health.LastBeat != 0 {
|
|
t.Error("Expected last beat to be 0 initially")
|
|
}
|
|
if health.CurrentTempo != 60 {
|
|
t.Errorf("Expected default tempo 60, got %d", health.CurrentTempo)
|
|
}
|
|
|
|
// Simulate some activity
|
|
client.beatMutex.Lock()
|
|
client.currentBeat = 10
|
|
client.currentTempo = 90
|
|
client.lastBeatTime = time.Now().Add(-100 * time.Millisecond)
|
|
client.beatMutex.Unlock()
|
|
|
|
client.addError("test error")
|
|
|
|
health = client.Health()
|
|
if health.LastBeat != 10 {
|
|
t.Errorf("Expected last beat to be 10, got %d", health.LastBeat)
|
|
}
|
|
if health.CurrentTempo != 90 {
|
|
t.Errorf("Expected current tempo to be 90, got %d", health.CurrentTempo)
|
|
}
|
|
if len(health.Errors) != 1 {
|
|
t.Errorf("Expected 1 error, got %d", len(health.Errors))
|
|
}
|
|
if health.TimeDrift <= 0 {
|
|
t.Error("Expected positive time drift")
|
|
}
|
|
}
|
|
|
|
// TestMetrics tests metrics integration
|
|
func TestMetrics(t *testing.T) {
|
|
config := DefaultConfig()
|
|
config.ClusterID = "test-cluster"
|
|
config.AgentID = generateUniqueAgentID("test-agent")
|
|
|
|
client := NewClient(config).(*client)
|
|
|
|
if client.metrics == nil {
|
|
t.Fatal("Expected metrics to be initialized")
|
|
}
|
|
|
|
// Test metrics snapshot
|
|
snapshot := client.metrics.GetMetricsSnapshot()
|
|
if snapshot == nil {
|
|
t.Error("Expected metrics snapshot to be available")
|
|
}
|
|
|
|
// Check for expected metric keys
|
|
expectedKeys := []string{
|
|
"connection_status",
|
|
"reconnect_count",
|
|
"beats_received",
|
|
"status_claims_emitted",
|
|
"budgets_created",
|
|
"total_errors",
|
|
}
|
|
|
|
for _, key := range expectedKeys {
|
|
if _, exists := snapshot[key]; !exists {
|
|
t.Errorf("Expected metric key '%s' to exist in snapshot", key)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestConfig tests configuration validation and defaults
|
|
func TestConfig(t *testing.T) {
|
|
// Test default config
|
|
config := DefaultConfig()
|
|
if config.JitterTolerance != 50*time.Millisecond {
|
|
t.Errorf("Expected default jitter tolerance 50ms, got %v", config.JitterTolerance)
|
|
}
|
|
if config.ReconnectDelay != 1*time.Second {
|
|
t.Errorf("Expected default reconnect delay 1s, got %v", config.ReconnectDelay)
|
|
}
|
|
if config.MaxReconnects != -1 {
|
|
t.Errorf("Expected default max reconnects -1, got %d", config.MaxReconnects)
|
|
}
|
|
|
|
// Test logger initialization
|
|
config.Logger = nil
|
|
client := NewClient(config)
|
|
if client == nil {
|
|
t.Error("Expected client to be created even with nil logger")
|
|
}
|
|
|
|
// Test with custom config
|
|
_, signingKey, err := ed25519.GenerateKey(rand.Reader)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate signing key: %v", err)
|
|
}
|
|
|
|
config.ClusterID = "custom-cluster"
|
|
config.AgentID = "custom-agent"
|
|
config.SigningKey = signingKey
|
|
config.JitterTolerance = 100 * time.Millisecond
|
|
config.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
|
|
|
client = NewClient(config)
|
|
if client == nil {
|
|
t.Error("Expected client to be created with custom config")
|
|
}
|
|
}
|
|
|
|
// TestBeatDurationCalculation tests beat duration calculation
|
|
func TestBeatDurationCalculation(t *testing.T) {
|
|
config := DefaultConfig()
|
|
config.ClusterID = "test-cluster"
|
|
config.AgentID = generateUniqueAgentID("test-agent")
|
|
|
|
client := NewClient(config).(*client)
|
|
|
|
// Test default 60 BPM (1 second per beat)
|
|
duration := client.getBeatDuration()
|
|
expected := 1000 * time.Millisecond
|
|
if duration != expected {
|
|
t.Errorf("Expected beat duration %v for 60 BPM, got %v", expected, duration)
|
|
}
|
|
|
|
// Test 120 BPM (0.5 seconds per beat)
|
|
client.beatMutex.Lock()
|
|
client.currentTempo = 120
|
|
client.beatMutex.Unlock()
|
|
|
|
duration = client.getBeatDuration()
|
|
expected = 500 * time.Millisecond
|
|
if duration != expected {
|
|
t.Errorf("Expected beat duration %v for 120 BPM, got %v", expected, duration)
|
|
}
|
|
|
|
// Test 30 BPM (2 seconds per beat)
|
|
client.beatMutex.Lock()
|
|
client.currentTempo = 30
|
|
client.beatMutex.Unlock()
|
|
|
|
duration = client.getBeatDuration()
|
|
expected = 2000 * time.Millisecond
|
|
if duration != expected {
|
|
t.Errorf("Expected beat duration %v for 30 BPM, got %v", expected, duration)
|
|
}
|
|
|
|
// Test edge case: zero tempo (should default to 60 BPM)
|
|
client.beatMutex.Lock()
|
|
client.currentTempo = 0
|
|
client.beatMutex.Unlock()
|
|
|
|
duration = client.getBeatDuration()
|
|
expected = 1000 * time.Millisecond
|
|
if duration != expected {
|
|
t.Errorf("Expected beat duration %v for 0 BPM (default 60), got %v", expected, duration)
|
|
}
|
|
}
|
|
|
|
// BenchmarkBeatCallback benchmarks beat callback execution
|
|
func BenchmarkBeatCallback(b *testing.B) {
|
|
config := DefaultConfig()
|
|
config.ClusterID = "bench-cluster"
|
|
config.AgentID = "bench-agent"
|
|
|
|
client := NewClient(config).(*client)
|
|
|
|
beatFrame := BeatFrame{
|
|
Type: "backbeat.beatframe.v1",
|
|
ClusterID: "bench-cluster",
|
|
BeatIndex: 1,
|
|
Downbeat: false,
|
|
Phase: "test",
|
|
HLC: "test-hlc",
|
|
DeadlineAt: time.Now().Add(time.Second),
|
|
TempoBPM: 60,
|
|
WindowID: "test-window",
|
|
}
|
|
|
|
callbackCount := 0
|
|
client.OnBeat(func(beat BeatFrame) {
|
|
callbackCount++
|
|
})
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
client.safeExecuteCallback(client.beatCallbacks[0], beatFrame, "beat")
|
|
}
|
|
|
|
if callbackCount != b.N {
|
|
b.Errorf("Expected callback to be called %d times, got %d", b.N, callbackCount)
|
|
}
|
|
}
|
|
|
|
// BenchmarkStatusClaimValidation benchmarks status claim validation
|
|
func BenchmarkStatusClaimValidation(b *testing.B) {
|
|
config := DefaultConfig()
|
|
config.ClusterID = "bench-cluster"
|
|
config.AgentID = "bench-agent"
|
|
|
|
client := NewClient(config).(*client)
|
|
|
|
claim := StatusClaim{
|
|
Type: "backbeat.statusclaim.v1",
|
|
AgentID: "bench-agent",
|
|
TaskID: "bench-task",
|
|
BeatIndex: 1,
|
|
State: "executing",
|
|
BeatsLeft: 5,
|
|
Progress: 0.5,
|
|
Notes: "Benchmark test",
|
|
HLC: "bench-hlc",
|
|
}
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
err := client.validateStatusClaim(&claim)
|
|
if err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Mock NATS server for integration tests (if needed)
|
|
func setupTestNATSServer(t *testing.T) *nats.Conn {
|
|
// This would start an embedded NATS server for testing
|
|
// For now, we'll skip tests that require NATS if it's not available
|
|
nc, err := nats.Connect(nats.DefaultURL)
|
|
if err != nil {
|
|
t.Skipf("NATS server not available: %v", err)
|
|
return nil
|
|
}
|
|
return nc
|
|
}
|
|
|
|
func TestIntegrationWithNATS(t *testing.T) {
|
|
nc := setupTestNATSServer(t)
|
|
if nc == nil {
|
|
return // Skipped
|
|
}
|
|
defer nc.Close()
|
|
|
|
config := DefaultConfig()
|
|
config.ClusterID = "integration-test"
|
|
config.AgentID = generateUniqueAgentID("test-agent")
|
|
config.NATSUrl = nats.DefaultURL
|
|
|
|
client := NewClient(config)
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
// Test start/stop cycle
|
|
err := client.Start(ctx)
|
|
if err != nil {
|
|
t.Fatalf("Failed to start client: %v", err)
|
|
}
|
|
|
|
// Check health after start
|
|
health := client.Health()
|
|
if !health.Connected {
|
|
t.Error("Expected client to be connected after start")
|
|
}
|
|
|
|
// Test stop
|
|
err = client.Stop()
|
|
if err != nil {
|
|
t.Errorf("Failed to stop client: %v", err)
|
|
}
|
|
|
|
// Check health after stop
|
|
health = client.Health()
|
|
if health.Connected {
|
|
t.Error("Expected client to be disconnected after stop")
|
|
}
|
|
}
|