Files
BACKBEAT/pkg/sdk/client.go
2025-10-17 08:56:25 +11:00

480 lines
13 KiB
Go

// Package sdk provides the BACKBEAT Go SDK for enabling CHORUS services
// to become BACKBEAT-aware with beat synchronization and status emission.
package sdk
import (
"context"
"crypto/ed25519"
"encoding/json"
"fmt"
"log/slog"
"sync"
"time"
"github.com/google/uuid"
"github.com/nats-io/nats.go"
)
// Client interface defines the core BACKBEAT SDK functionality
// Implements BACKBEAT-REQ-040, 041, 042, 043, 044
type Client interface {
// Beat subscription (BACKBEAT-REQ-040)
OnBeat(callback func(BeatFrame)) error
OnDownbeat(callback func(BeatFrame)) error
// Status emission (BACKBEAT-REQ-041)
EmitStatusClaim(claim StatusClaim) error
// Beat budgets (BACKBEAT-REQ-042)
WithBeatBudget(n int, fn func() error) error
// Utilities
GetCurrentBeat() int64
GetCurrentWindow() string
IsInWindow(windowID string) bool
GetCurrentTempo() int
GetTempoDrift() time.Duration
// Lifecycle management
Start(ctx context.Context) error
Stop() error
Health() HealthStatus
}
// Config carries everything NewClient needs to build a BACKBEAT client.
// Start from DefaultConfig and then fill in the identity and connection
// fields.
type Config struct {
	// ClusterID is the BACKBEAT cluster identifier; status claims are
	// published on "backbeat.status.<ClusterID>".
	ClusterID string
	// AgentID uniquely identifies this agent; it is stamped onto every
	// status claim and used in the metrics prefix.
	AgentID string
	// NATSUrl is the NATS connection URL.
	NATSUrl string
	// SigningKey, when non-nil, is used to sign status claims
	// (BACKBEAT-REQ-044).
	SigningKey ed25519.PrivateKey
	// Logger receives structured log output; NewClient substitutes
	// slog.Default() when it is nil.
	Logger *slog.Logger
	// JitterTolerance is the maximum jitter tolerance (default: 50ms).
	JitterTolerance time.Duration
	// ReconnectDelay is the NATS reconnection delay (default: 1s).
	ReconnectDelay time.Duration
	// MaxReconnects caps NATS reconnection attempts (default: -1 for
	// infinite).
	MaxReconnects int
}

// DefaultConfig returns a freshly allocated Config preloaded with the SDK's
// defaults: 50ms jitter tolerance, 1s reconnect delay, unlimited reconnects
// and the process-wide default slog logger. Identity and connection fields
// (ClusterID, AgentID, NATSUrl, SigningKey) are left empty for the caller.
func DefaultConfig() *Config {
	cfg := new(Config)
	cfg.JitterTolerance = 50 * time.Millisecond
	cfg.ReconnectDelay = time.Second
	cfg.MaxReconnects = -1 // negative means reconnect forever
	cfg.Logger = slog.Default()
	return cfg
}
// BeatFrame represents a beat frame with timing information.
// Fields map to the snake_case JSON keys in the struct tags; field order also
// fixes the key order produced by encoding/json, so do not reorder.
type BeatFrame struct {
// Type is the message type tag carried in the frame.
Type string `json:"type"`
// ClusterID names the BACKBEAT cluster the frame belongs to.
ClusterID string `json:"cluster_id"`
// BeatIndex is the beat's sequence number (int64).
BeatIndex int64 `json:"beat_index"`
// Downbeat is true for downbeat frames (the ones OnDownbeat callbacks are
// registered for).
Downbeat bool `json:"downbeat"`
// Phase names the beat phase; the set of valid values is not defined in
// this file — NOTE(review): confirm against the pulse producer.
Phase string `json:"phase"`
// HLC is the frame's clock stamp as a string (presumably a hybrid logical
// clock value — confirm).
HLC string `json:"hlc"`
// DeadlineAt is a wall-clock deadline carried by the frame; NOTE(review):
// presumably the point by which this beat's work should complete — confirm.
DeadlineAt time.Time `json:"deadline_at"`
// TempoBPM is the tempo, in beats per minute, advertised by the frame.
TempoBPM int `json:"tempo_bpm"`
// WindowID identifies the window this beat belongs to.
WindowID string `json:"window_id"`
}
// StatusClaim represents a status claim emission, published by
// EmitStatusClaim as JSON using the keys in the struct tags.
// Callers fill in only the user-provided fields; EmitStatusClaim stamps the
// rest before validating, optionally signing and publishing the claim.
type StatusClaim struct {
// Auto-populated by SDK: Type, AgentID, BeatIndex and HLC are always
// overwritten by EmitStatusClaim; TaskID is filled only when left empty.
Type string `json:"type"` // Always "backbeat.statusclaim.v1"
AgentID string `json:"agent_id"` // Auto-populated from config
TaskID string `json:"task_id"` // Auto-generated ("task:" + 8 chars of a UUID) if not provided
BeatIndex int64 `json:"beat_index"` // Auto-populated from current beat
HLC string `json:"hlc"` // Auto-populated from current HLC
// User-provided
State string `json:"state"` // executing|planning|waiting|review|done|failed
WaitFor []string `json:"wait_for,omitempty"` // refs (e.g., hmmm://thread/...); key omitted when empty
BeatsLeft int `json:"beats_left"` // estimated beats remaining
Progress float64 `json:"progress"` // progress ratio (0.0-1.0)
Notes string `json:"notes"` // status description
}
// HealthStatus represents the current health of the SDK client, as
// snapshotted by Client.Health. Durations marshal to JSON as integer
// nanoseconds (encoding/json's default for time.Duration).
type HealthStatus struct {
// Connected reports whether a NATS connection exists and is connected.
Connected bool `json:"connected"`
// LastBeat is the most recent beat index the client holds.
LastBeat int64 `json:"last_beat"`
// LastBeatTime is when that beat was recorded (zero if none yet).
LastBeatTime time.Time `json:"last_beat_time"`
// TimeDrift is the time elapsed since LastBeatTime, not a clock offset.
TimeDrift time.Duration `json:"time_drift"`
// ReconnectCount is the number of NATS reconnections observed.
ReconnectCount int `json:"reconnect_count"`
// LocalDegradation mirrors the client's local-degradation flag.
LocalDegradation bool `json:"local_degradation"`
// CurrentTempo is the tempo in BPM the client is currently using.
CurrentTempo int `json:"current_tempo"`
// TempoDrift is the measured-vs-expected beat period difference
// (see GetTempoDrift).
TempoDrift time.Duration `json:"tempo_drift"`
// MeasuredBPM is the average of recent measured tempo samples.
MeasuredBPM float64 `json:"measured_bpm"`
// Errors is a copy of the client's recorded error messages.
Errors []string `json:"errors,omitempty"`
}
// LegacyBeatInfo represents legacy {bar,beat} information
// For BACKBEAT-REQ-043 compatibility. How a BeatIndex maps onto a
// {bar, beat} pair is defined outside this section of the file.
type LegacyBeatInfo struct {
// Bar is the legacy bar number.
Bar int `json:"bar"`
// Beat is the legacy beat number within the bar.
Beat int `json:"beat"`
}
// tempoSample represents a tempo measurement for drift calculation.
// GetTempoDrift and Health average the ActualBPM of the most recent samples
// held in client.tempoHistory; the samples are appended elsewhere.
type tempoSample struct {
// BeatIndex is the beat the measurement was taken at.
BeatIndex int64
// Tempo is the nominal tempo (BPM) in effect at that beat.
Tempo int
// MeasuredTime is when the measurement was taken.
MeasuredTime time.Time
ActualBPM float64 // Measured BPM based on inter-beat timing
}
// client implements the Client interface.
// Construct it with NewClient; the zero value is not usable (budgetContexts
// would be a nil map and config a nil pointer).
type client struct {
// config is the caller-supplied configuration (shared pointer, not a copy).
config *Config
// nc is the NATS connection; nil until Start connects.
nc *nats.Conn
// ctx/cancel are derived from the context passed to Start; nil before Start.
ctx context.Context
cancel context.CancelFunc
// wg tracks background goroutines launched by Start; Stop waits on it.
wg sync.WaitGroup
// Beat tracking
// currentBeat, currentWindow, lastBeatTime, currentTempo and tempoHistory are
// read under beatMutex by the accessors and Health.
// NOTE(review): currentHLC and lastTempo are not read in this section;
// presumably also guarded by beatMutex — confirm with their writers.
currentBeat int64
currentWindow string
currentHLC string
lastBeatTime time.Time
currentTempo int // Current tempo in BPM
lastTempo int // Last known tempo for drift calculation
tempoHistory []tempoSample // History for drift calculation
beatMutex sync.RWMutex
// Callbacks
// Registered by OnBeat/OnDownbeat under callbackMutex.
beatCallbacks []func(BeatFrame)
downbeatCallbacks []func(BeatFrame)
callbackMutex sync.RWMutex
// Health and metrics
// errors is guarded by errorMutex. NOTE(review): Health reads reconnectCount
// and localDegradation without a lock; confirm how their writers synchronize.
reconnectCount int
localDegradation bool
errors []string
errorMutex sync.RWMutex
metrics *Metrics
// Beat budget tracking
// Maps a per-call budget ID to its cancel func so Stop can abort in-flight
// WithBeatBudget calls; guarded by budgetMutex.
budgetContexts map[string]context.CancelFunc
budgetMutex sync.Mutex
// Legacy compatibility
// NOTE(review): presumably a warn-once flag for legacy {bar,beat} usage —
// not used in this section of the file.
legacyWarned bool
legacyMutex sync.Mutex
}
// NewClient creates a new BACKBEAT SDK client. It does not connect; call
// Start for that.
//
// A nil config is replaced by DefaultConfig(). The signature has no error
// return, so this avoids an immediate nil-pointer panic. With a nil config,
// ClusterID and AgentID are empty and must be considered unset.
// If config.Logger is nil it is set to slog.Default(). This writes into the
// caller's Config, which the client keeps by pointer rather than copying.
//
// The client starts at a default tempo of 60 BPM. It reports metrics under
// the prefix "backbeat.sdk.<AgentID>".
func NewClient(config *Config) Client {
	if config == nil {
		config = DefaultConfig()
	}
	if config.Logger == nil {
		config.Logger = slog.Default()
	}

	c := &client{
		config:            config,
		beatCallbacks:     make([]func(BeatFrame), 0),
		downbeatCallbacks: make([]func(BeatFrame), 0),
		budgetContexts:    make(map[string]context.CancelFunc),
		errors:            make([]string, 0),
		tempoHistory:      make([]tempoSample, 0, 100),
		currentTempo:      60, // Default to 60 BPM until a beat reports otherwise
	}

	// Initialize metrics
	prefix := fmt.Sprintf("backbeat.sdk.%s", config.AgentID)
	c.metrics = NewMetrics(prefix)

	return c
}
// Start initializes the client and begins beat synchronization.
//
// It derives a cancellable context from ctx and connects to NATS. It then
// launches beatSubscriptionLoop, which Stop waits for.
//
// If the connection fails, the derived context is cancelled and cleared
// before the error is returned. This releases its registration with the
// parent context and leaves the client as if Start had never been called:
// Stop skips the cancel, and WithBeatBudget falls back to
// context.Background().
func (c *client) Start(ctx context.Context) error {
	c.ctx, c.cancel = context.WithCancel(ctx)

	if err := c.connect(); err != nil {
		c.cancel()
		c.ctx, c.cancel = nil, nil
		return fmt.Errorf("failed to connect to NATS: %w", err)
	}

	c.wg.Add(1)
	go c.beatSubscriptionLoop()

	c.config.Logger.Info("BACKBEAT SDK client started",
		slog.String("cluster_id", c.config.ClusterID),
		slog.String("agent_id", c.config.AgentID))
	return nil
}
// Stop gracefully stops the client. It runs these steps in order:
//  1. Cancels the context created by Start, if any.
//  2. Cancels every in-flight beat budget and forgets it.
//  3. Closes the NATS connection, if one exists.
//  4. Waits for the goroutines tracked by the client's WaitGroup.
//
// It always returns nil and is safe to call before Start.
func (c *client) Stop() error {
	if stopClient := c.cancel; stopClient != nil {
		stopClient()
	}

	c.budgetMutex.Lock()
	for _, cancelBudget := range c.budgetContexts {
		cancelBudget()
	}
	clear(c.budgetContexts)
	c.budgetMutex.Unlock()

	if conn := c.nc; conn != nil {
		conn.Close()
	}

	c.wg.Wait()
	c.config.Logger.Info("BACKBEAT SDK client stopped")
	return nil
}
// OnBeat adds callback to the list of beat-event handlers (BACKBEAT-REQ-040).
// Registration is guarded by callbackMutex, so it may be called concurrently.
// A nil callback is rejected with an error.
func (c *client) OnBeat(callback func(BeatFrame)) error {
	if callback == nil {
		return fmt.Errorf("callback cannot be nil")
	}

	c.callbackMutex.Lock()
	c.beatCallbacks = append(c.beatCallbacks, callback)
	c.callbackMutex.Unlock()

	return nil
}
// OnDownbeat adds callback to the list of downbeat-event handlers
// (BACKBEAT-REQ-040). Registration is guarded by callbackMutex, so it may be
// called concurrently. A nil callback is rejected with an error.
func (c *client) OnDownbeat(callback func(BeatFrame)) error {
	if callback == nil {
		return fmt.Errorf("callback cannot be nil")
	}

	c.callbackMutex.Lock()
	c.downbeatCallbacks = append(c.downbeatCallbacks, callback)
	c.callbackMutex.Unlock()

	return nil
}
// EmitStatusClaim emits a status claim (BACKBEAT-REQ-041) on the NATS
// subject "backbeat.status.<ClusterID>".
//
// The claim is taken by value, so the SDK can stamp it without touching the
// caller's copy. Type, AgentID, BeatIndex and HLC are always overwritten.
// TaskID becomes "task:" plus the first 8 characters of a fresh UUID when
// empty.
//
// The claim is then validated and, when Config.SigningKey is set, signed
// (BACKBEAT-REQ-044). Finally it is JSON-encoded and published with the
// headers from createHeaders. A publish failure is added to the client's
// error list and counted in metrics before being returned.
func (c *client) EmitStatusClaim(claim StatusClaim) error {
	claim.Type, claim.AgentID = "backbeat.statusclaim.v1", c.config.AgentID
	claim.BeatIndex, claim.HLC = c.GetCurrentBeat(), c.getCurrentHLC()
	if claim.TaskID == "" {
		claim.TaskID = "task:" + uuid.New().String()[:8]
	}

	if err := c.validateStatusClaim(&claim); err != nil {
		return fmt.Errorf("invalid status claim: %w", err)
	}
	if c.config.SigningKey != nil {
		if err := c.signStatusClaim(&claim); err != nil {
			return fmt.Errorf("failed to sign status claim: %w", err)
		}
	}

	payload, err := json.Marshal(claim)
	if err != nil {
		return fmt.Errorf("failed to marshal status claim: %w", err)
	}

	msg := &nats.Msg{
		Subject: "backbeat.status." + c.config.ClusterID,
		Header:  c.createHeaders(),
		Data:    payload,
	}
	if pubErr := c.nc.PublishMsg(msg); pubErr != nil {
		c.addError(fmt.Sprintf("failed to publish status claim: %v", pubErr))
		c.metrics.RecordStatusClaim(false)
		return fmt.Errorf("failed to publish status claim: %w", pubErr)
	}

	c.metrics.RecordStatusClaim(true)
	c.config.Logger.Debug("Status claim emitted",
		slog.String("agent_id", claim.AgentID),
		slog.String("task_id", claim.TaskID),
		slog.String("state", claim.State),
		slog.Int64("beat_index", claim.BeatIndex))
	return nil
}
// WithBeatBudget executes fn with a beat-based timeout (BACKBEAT-REQ-042).
//
// The timeout is n * getBeatDuration() at the time of the call. It is layered
// on the client context, or on context.Background() if Start has not set
// one. fn runs on its own goroutine and is not told about the deadline.
// When the budget runs out, WithBeatBudget returns at once while fn keeps
// running until it returns by itself.
//
// Returns:
//   - an error if n <= 0 or fn is nil;
//   - fn's own result (possibly nil) if fn finishes first;
//   - "beat budget of N beats exceeded" if the deadline passes first;
//   - an error wrapping context.Canceled if the budget is cancelled before
//     its deadline, e.g. by Stop or by cancelling the context given to Start.
func (c *client) WithBeatBudget(n int, fn func() error) error {
	if n <= 0 {
		return fmt.Errorf("beat budget must be positive, got %d", n)
	}
	if fn == nil {
		// Calling a nil fn would panic on the worker goroutine, where the
		// caller has no chance to recover it.
		return fmt.Errorf("beat budget function cannot be nil")
	}

	// Calculate timeout based on current tempo
	startBeat := c.GetCurrentBeat()
	timeout := time.Duration(n) * c.getBeatDuration()

	// Use background context if client context is not set (for testing)
	baseCtx := c.ctx
	if baseCtx == nil {
		baseCtx = context.Background()
	}
	ctx, cancel := context.WithTimeout(baseCtx, timeout)
	defer cancel()

	// Track the budget context so Stop can cancel it.
	budgetID := uuid.New().String()
	c.budgetMutex.Lock()
	c.budgetContexts[budgetID] = cancel
	c.budgetMutex.Unlock()

	c.metrics.RecordBudgetCreated()

	defer func() {
		c.budgetMutex.Lock()
		delete(c.budgetContexts, budgetID)
		c.budgetMutex.Unlock()
	}()

	// Buffered so the worker can always deliver its result and exit, even
	// after we have stopped waiting for it.
	done := make(chan error, 1)
	go func() {
		done <- fn()
	}()

	select {
	case err := <-done:
		c.metrics.RecordBudgetCompleted(false) // Not timed out
		if err != nil {
			c.config.Logger.Debug("Beat budget function completed with error",
				slog.Int("budget", n),
				slog.Int64("start_beat", startBeat),
				slog.String("error", err.Error()))
		} else {
			c.config.Logger.Debug("Beat budget function completed successfully",
				slog.Int("budget", n),
				slog.Int64("start_beat", startBeat))
		}
		return err

	case <-ctx.Done():
		// ctx.Err() returns the context.DeadlineExceeded sentinel itself when
		// the timeout fired; anything else means we were cancelled.
		if ctxErr := ctx.Err(); ctxErr != context.DeadlineExceeded {
			// NOTE(review): counted as "not timed out" because the budget was
			// aborted, not exhausted — confirm this matches the Metrics
			// semantics.
			c.metrics.RecordBudgetCompleted(false)
			c.config.Logger.Debug("Beat budget cancelled",
				slog.Int("budget", n),
				slog.Int64("start_beat", startBeat),
				slog.String("error", ctxErr.Error()))
			return fmt.Errorf("beat budget of %d beats cancelled: %w", n, ctxErr)
		}

		c.metrics.RecordBudgetCompleted(true) // Timed out
		c.config.Logger.Warn("Beat budget exceeded",
			slog.Int("budget", n),
			slog.Int64("start_beat", startBeat),
			slog.Duration("timeout", timeout))
		return fmt.Errorf("beat budget of %d beats exceeded", n)
	}
}
// GetCurrentBeat returns the beat index the client currently holds, read
// under beatMutex.
func (c *client) GetCurrentBeat() int64 {
	c.beatMutex.RLock()
	beat := c.currentBeat
	c.beatMutex.RUnlock()
	return beat
}
// GetCurrentWindow returns the window ID the client currently holds, read
// under beatMutex.
func (c *client) GetCurrentWindow() string {
	c.beatMutex.RLock()
	window := c.currentWindow
	c.beatMutex.RUnlock()
	return window
}
// IsInWindow reports whether windowID equals the client's current window ID
// (an exact string comparison).
func (c *client) IsInWindow(windowID string) bool {
	current := c.GetCurrentWindow()
	return windowID == current
}
// GetCurrentTempo returns the tempo, in BPM, the client is currently using,
// read under beatMutex.
func (c *client) GetCurrentTempo() int {
	c.beatMutex.RLock()
	tempo := c.currentTempo
	c.beatMutex.RUnlock()
	return tempo
}
// GetTempoDrift calculates the drift between expected and actual tempo.
// It compares the beat period implied by the averaged measured BPM with the
// period implied by currentTempo.
//
// The average uses the ActualBPM of up to the 10 most recent tempo samples.
// A positive result means beats are arriving slower than the nominal tempo,
// and a negative result means faster.
//
// It returns 0 when no meaningful drift can be computed. That covers fewer
// than two samples, a non-positive current tempo, and recent samples with no
// positive measurement. Without these guards, the division would yield ±Inf,
// and converting that to time.Duration is implementation-defined.
func (c *client) GetTempoDrift() time.Duration {
	c.beatMutex.RLock()
	defer c.beatMutex.RUnlock()

	historyLen := len(c.tempoHistory)
	if historyLen < 2 || c.currentTempo <= 0 {
		return 0
	}

	// Only the most recent samples, so drift tracks current behaviour.
	const driftWindow = 10
	recent := c.tempoHistory[historyLen-min(historyLen, driftWindow):]

	totalBPM, measured := 0.0, 0
	for _, sample := range recent {
		// A non-positive BPM cannot be a real inter-beat measurement; skip it
		// rather than let it drag the average toward zero.
		if sample.ActualBPM <= 0 {
			continue
		}
		totalBPM += sample.ActualBPM
		measured++
	}
	if measured == 0 {
		return 0
	}
	avgMeasuredBPM := totalBPM / float64(measured)

	expectedBeatDuration := 60.0 / float64(c.currentTempo)
	actualBeatDuration := 60.0 / avgMeasuredBPM
	drift := actualBeatDuration - expectedBeatDuration
	return time.Duration(drift * float64(time.Second))
}
// Health returns a point-in-time snapshot of the client's health.
//
// The beat state is read in a single beatMutex critical section, so the
// reported values are mutually consistent. That state is the last beat
// index, its arrival time, the current tempo and the recent tempo samples.
// Previously LastBeatTime was read outside the lock (a data race) and
// LastBeat came from a separate lock acquisition.
//
// TimeDrift is the time elapsed since the last beat was recorded.
// MeasuredBPM is the mean ActualBPM of up to the 5 most recent tempo samples,
// or 60 when there are none. Errors is a copy, so callers may keep it.
func (c *client) Health() HealthStatus {
	c.errorMutex.RLock()
	errs := make([]string, len(c.errors))
	copy(errs, c.errors)
	c.errorMutex.RUnlock()

	c.beatMutex.RLock()
	lastBeat := c.currentBeat
	lastBeatTime := c.lastBeatTime
	currentTempo := c.currentTempo
	measuredBPM := 60.0 // Default when no tempo samples exist yet
	if historyLen := len(c.tempoHistory); historyLen > 0 {
		recent := c.tempoHistory[historyLen-min(historyLen, 5):]
		totalBPM := 0.0
		for _, sample := range recent {
			totalBPM += sample.ActualBPM
		}
		measuredBPM = totalBPM / float64(len(recent))
	}
	c.beatMutex.RUnlock()

	timeDrift := time.Since(lastBeatTime)

	// GetTempoDrift takes beatMutex itself. Calling it while holding the read
	// lock could deadlock behind a waiting writer, so call it only after
	// RUnlock.
	tempoDrift := c.GetTempoDrift()

	return HealthStatus{
		Connected:    c.nc != nil && c.nc.IsConnected(),
		LastBeat:     lastBeat,
		LastBeatTime: lastBeatTime,
		TimeDrift:    timeDrift,
		// NOTE(review): reconnectCount and localDegradation are read without
		// a lock; their writers are not in this section — confirm they are
		// synchronized.
		ReconnectCount:   c.reconnectCount,
		LocalDegradation: c.localDegradation,
		CurrentTempo:     currentTempo,
		TempoDrift:       tempoDrift,
		MeasuredBPM:      measuredBPM,
		Errors:           errs,
	}
}