Compare commits: e820770409...feature/ph (22 commits)

Commits (SHA1):
17673c38a6, 9dbd361caf, 859e5e1e02, f010a0c8a2, d0973b2adf, 8d9b62daf3,
d1252ade69, 9fc9a2e3a2, 14b5125c12, ea04378962, 237e8699eb, 1de8695736,
c30c6dc480, e523c4b543, 26e4ef7d8b, eb2e05ff84, ef4bf1efe0, 2578876eeb,
95784822ce, 1bb736c09a, 57751f277a, 966225c3e2
Dockerfile.simple (new file, 44 lines)
@@ -0,0 +1,44 @@

# CHORUS - Simple Docker image using pre-built binary
FROM alpine:3.18

# Install runtime dependencies
RUN apk --no-cache add \
    ca-certificates \
    tzdata \
    curl

# Create non-root user for security
RUN addgroup -g 1000 chorus && \
    adduser -u 1000 -G chorus -s /bin/sh -D chorus

# Create application directories
RUN mkdir -p /app/data && \
    chown -R chorus:chorus /app

# Copy pre-built binary from build directory (ensure it exists and is the correct one)
COPY build/chorus-agent /app/chorus-agent
RUN chmod +x /app/chorus-agent && chown chorus:chorus /app/chorus-agent

# Switch to non-root user
USER chorus
WORKDIR /app

# Note: Using correct chorus-agent binary built with 'make build-agent'

# Expose ports
EXPOSE 8080 8081 9000

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8081/health || exit 1

# Set default environment variables
ENV LOG_LEVEL=info \
    LOG_FORMAT=structured \
    CHORUS_BIND_ADDRESS=0.0.0.0 \
    CHORUS_API_PORT=8080 \
    CHORUS_HEALTH_PORT=8081 \
    CHORUS_P2P_PORT=9000

# Start CHORUS
ENTRYPOINT ["/app/chorus-agent"]
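The HEALTHCHECK above assumes the agent answers `GET /health` on port 8081 with a 2xx status; `curl -f` fails on anything else, which is what flips the container to unhealthy. A minimal sketch of that contract in Go (the handler name here is illustrative, not the actual CHORUS implementation):

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
)

// healthHandler mirrors the contract the Dockerfile HEALTHCHECK probes:
// any 2xx response marks the container healthy.
func healthHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	io.WriteString(w, `{"status":"healthy"}`)
}

// healthStatus stands in for `curl -f http://localhost:8081/health`,
// returning the HTTP status code the probe would see.
func healthStatus() int {
	srv := httptest.NewServer(http.HandlerFunc(healthHandler))
	defer srv.Close()

	resp, err := http.Get(srv.URL + "/health")
	if err != nil {
		return 0
	}
	defer resp.Body.Close()
	return resp.StatusCode
}

func main() {
	fmt.Println(healthStatus()) // 200 => container reports healthy
}
```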
Dockerfile.ubuntu (new file, 43 lines)
@@ -0,0 +1,43 @@

# CHORUS - Ubuntu-based Docker image for glibc compatibility
FROM ubuntu:22.04

# Install runtime dependencies
RUN apt-get update && apt-get install -y \
    ca-certificates \
    tzdata \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Create non-root user for security
RUN groupadd -g 1000 chorus && \
    useradd -u 1000 -g chorus -s /bin/bash -d /home/chorus -m chorus

# Create application directories
RUN mkdir -p /app/data && \
    chown -R chorus:chorus /app

# Copy pre-built binary from build directory
COPY build/chorus-agent /app/chorus-agent
RUN chmod +x /app/chorus-agent && chown chorus:chorus /app/chorus-agent

# Switch to non-root user
USER chorus
WORKDIR /app

# Expose ports
EXPOSE 8080 8081 9000

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8081/health || exit 1

# Set default environment variables
ENV LOG_LEVEL=info \
    LOG_FORMAT=structured \
    CHORUS_BIND_ADDRESS=0.0.0.0 \
    CHORUS_API_PORT=8080 \
    CHORUS_HEALTH_PORT=8081 \
    CHORUS_P2P_PORT=9000

# Start CHORUS
ENTRYPOINT ["/app/chorus-agent"]
Makefile (2 lines changed)
@@ -5,7 +5,7 @@
BINARY_NAME_AGENT = chorus-agent
BINARY_NAME_HAP = chorus-hap
BINARY_NAME_COMPAT = chorus
VERSION ?= 0.1.0-dev
VERSION ?= 0.5.5
COMMIT_HASH ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
BUILD_DATE ?= $(shell date -u '+%Y-%m-%d_%H:%M:%S')
README.md (142 lines changed)
@@ -1,99 +1,87 @@

# CHORUS - Container-First P2P Task Coordination System
# CHORUS – Container-First Context Platform (Alpha)

CHORUS is a next-generation P2P task coordination and collaborative AI system designed from the ground up for containerized deployments. It takes the best lessons learned from CHORUS and reimagines them for Docker Swarm, Kubernetes, and modern container orchestration platforms.
CHORUS is the runtime that ties the CHORUS ecosystem together: libp2p mesh, DHT-backed storage, council/task coordination, and (eventually) SLURP contextual intelligence. The repository you are looking at is the in-progress container-first refactor. Several core systems boot today, but higher-level services (SLURP, SHHH, full HMMM routing) are still landing.

## Vision
## Current Status

CHORUS enables distributed AI agents to coordinate, collaborate, and execute tasks across container clusters, supporting deployments from single containers to hundreds of instances in enterprise environments.

| Area | Status | Notes |
| --- | --- | --- |
| libp2p node + PubSub | ✅ Running | `internal/runtime/shared.go` spins up the mesh, hypercore logging, availability broadcasts. |
| DHT + DecisionPublisher | ✅ Running | Encrypted storage wired through `pkg/dht`; decisions written via `ucxl.DecisionPublisher`. |
| Leader election | ✅ Fully functional | Complete admin election with consensus, discovery protocol, heartbeats, and SLURP activation. |
| SLURP (context intelligence) | 🚧 Stubbed | `pkg/slurp/slurp.go` contains TODOs for resolver, temporal graphs, intelligence. Leader integration scaffolding exists but uses placeholder IDs/request forwarding. |
| SHHH (secrets sentinel) | 🚧 Sentinel live | `pkg/shhh` redacts hypercore + PubSub payloads with audit + metrics hooks (policy replay TBD). |
| HMMM routing | 🚧 Partial | PubSub topics join, but capability/role announcements and HMMM router wiring are placeholders (`internal/runtime/agent_support.go`). |

## Key Design Principles
See `docs/progress/CHORUS-WHOOSH-development-plan.md` for the detailed build plan and `docs/progress/CHORUS-WHOOSH-roadmap.md` for sequencing.

- **Container-First**: Designed specifically for Docker/Kubernetes deployments
- **License-Controlled**: Simple environment-variable-based licensing
- **Cloud-Native Logging**: Structured logging to stdout/stderr for container runtime collection
- **Swarm-Ready P2P**: P2P protocols optimized for container networking
- **Scalable Agent IDs**: Agent identification that works across distributed deployments
- **Zero-Config**: Minimal configuration requirements via environment variables

## Quick Start (Alpha)

## Architecture
The container-first workflows are still evolving; expect frequent changes.

CHORUS follows a microservices architecture where each container runs a single agent instance:

```
┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐
│  CHORUS Agent   │    │  CHORUS Agent   │    │  CHORUS Agent   │
│   Container 1   │◄─┤    Container 2   │─►│    Container N   │
└─────────────────┘    └─────────────────┘    └─────────────────┘
         │                      │                      │
         └──────────────────────┼──────────────────────┘
                                │
                       ┌─────────────────┐
                       │    Container    │
                       │     Network     │
                       │   (P2P Mesh)    │
                       └─────────────────┘
```

## Quick Start

### Prerequisites

- Docker & Docker Compose
- A valid CHORUS license key
- Access to Ollama endpoints for AI functionality

### Basic Deployment

1. Clone and configure:
```bash
git clone https://gitea.chorus.services/tony/CHORUS.git
cd CHORUS
cp docker/chorus.env.example docker/chorus.env
# Edit docker/chorus.env with your license key and configuration
# adjust env vars (KACHING license, bootstrap peers, etc.)
docker compose -f docker/docker-compose.yml up --build
```

2. Deploy:
You’ll get a single agent container with:
- libp2p networking (mDNS + configured bootstrap peers)
- election heartbeat
- DHT storage (AGE-encrypted)
- HTTP API + health endpoints

**Missing today:** SLURP context resolution, advanced SHHH policy replay, HMMM per-issue routing. Expect log warnings/TODOs for those paths.
|
||||
## 🎉 Leader Election System (NEW!)

CHORUS now features a complete, production-ready leader election system:

### Core Features
- **Consensus-based election** with weighted scoring (uptime, capabilities, resources)
- **Admin discovery protocol** for network-wide leader identification
- **Heartbeat system** with automatic failover (15-second intervals)
- **Concurrent election prevention** with randomized delays
- **SLURP activation** on elected admin nodes

### How It Works
1. **Bootstrap**: Nodes start in idle state, no admin known
2. **Discovery**: Nodes send discovery requests to find existing admin
3. **Election trigger**: If no admin found after grace period, trigger election
4. **Candidacy**: Eligible nodes announce themselves with capability scores
5. **Consensus**: Network selects winner based on highest score
6. **Leadership**: Winner starts heartbeats, activates SLURP functionality
7. **Monitoring**: Nodes continuously verify admin health via heartbeats
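Steps 4–6 above rank candidates by a weighted capability score and pick the highest. A minimal sketch of that scoring and winner selection (the struct fields and weights are illustrative assumptions, not the actual CHORUS implementation):

```go
package main

import "fmt"

// Candidate summarizes what a node might announce during candidacy.
type Candidate struct {
	NodeID       string
	UptimeScore  float64 // 0..1, normalized uptime
	Capabilities float64 // 0..1, advertised capability coverage
	Resources    float64 // 0..1, free CPU/RAM headroom
}

// electionScore combines the factors with illustrative weights.
func electionScore(c Candidate) float64 {
	return 0.4*c.UptimeScore + 0.35*c.Capabilities + 0.25*c.Resources
}

// pickWinner returns the candidate with the highest score (steps 5-6).
func pickWinner(cands []Candidate) Candidate {
	best := cands[0]
	for _, c := range cands[1:] {
		if electionScore(c) > electionScore(best) {
			best = c
		}
	}
	return best
}

func main() {
	cands := []Candidate{
		{NodeID: "node-a", UptimeScore: 0.9, Capabilities: 0.5, Resources: 0.6},
		{NodeID: "node-b", UptimeScore: 0.7, Capabilities: 0.9, Resources: 0.8},
	}
	fmt.Println(pickWinner(cands).NodeID) // node-b (0.795 vs 0.685)
}
```

Ties and concurrent elections are what the randomized delays in the real system guard against; this sketch only shows the scoring step.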
### Debugging
Use these log patterns to monitor election health:
```bash
docker-compose -f docker/docker-compose.yml up -d
# Monitor WHOAMI messages and leader identification
docker service logs CHORUS_chorus | grep "🤖 WHOAMI\|👑\|📡.*Discovered"

# Track election cycles
docker service logs CHORUS_chorus | grep "🗳️\|📢.*candidacy\|🏆.*winner"

# Watch discovery protocol
docker service logs CHORUS_chorus | grep "📩\|📤\|📥"
```

3. Scale (Docker Swarm):
```bash
docker service scale chorus_agent=10
```
## Roadmap Highlights

## Licensing
1. **Security substrate** – land SHHH sentinel, finish SLURP leader-only operations, validate COOEE enrolment (see roadmap Phase 1).
2. **Autonomous teams** – coordinate with WHOOSH for deployment telemetry + SLURP context export.
3. **UCXL + KACHING** – hook runtime telemetry into KACHING and enforce UCXL validator.

CHORUS requires a valid license key to operate. Set your license key in the environment:
Track progress via the shared roadmap and weekly burndown dashboards.

```env
CHORUS_LICENSE_KEY=your-license-key-here
CHORUS_LICENSE_EMAIL=your-email@example.com
```

**No license = No operation.** CHORUS will not start without valid licensing.
## Differences from Bzzz

| Aspect | Bzzz | CHORUS |
|--------|------|--------|
| Deployment | systemd service (1 per host) | Container (N per cluster) |
| Configuration | Web UI setup | Environment variables |
| Logging | Journal/files | stdout/stderr (structured) |
| Licensing | Setup-time validation | Runtime environment variable |
| Agent IDs | Host-based | Container/cluster-based |
| P2P Discovery | mDNS local network | Container network + service discovery |

## Development Status

🚧 **Early Development** - CHORUS is being designed and built. Not yet ready for production use.

Current Phase: Architecture design and core foundation development.

## License

CHORUS is a commercial product. Contact chorus.services for licensing information.
## Related Projects
- [WHOOSH](https://gitea.chorus.services/tony/WHOOSH) – council/team orchestration
- [KACHING](https://gitea.chorus.services/tony/KACHING) – telemetry/licensing
- [SLURP](https://gitea.chorus.services/tony/SLURP) – contextual intelligence prototypes
- [HMMM](https://gitea.chorus.services/tony/hmmm) – meta-discussion layer

## Contributing

CHORUS is developed by the chorus.services team. For contributions or feedback, please use the issue tracker on our GITEA instance.
This repo is still alpha. Please coordinate via the roadmap tickets before landing changes. Major security/runtime decisions should include a Decision Record with a UCXL address so SLURP/BUBBLE can ingest it later.
@@ -9,10 +9,11 @@ import (

	"chorus/internal/logging"
	"chorus/pubsub"

	"github.com/gorilla/mux"
)

// HTTPServer provides HTTP API endpoints for Bzzz
// HTTPServer provides HTTP API endpoints for CHORUS
type HTTPServer struct {
	port         int
	hypercoreLog *logging.HypercoreLog
@@ -20,7 +21,7 @@ type HTTPServer struct {
	server *http.Server
}

// NewHTTPServer creates a new HTTP server for Bzzz API
// NewHTTPServer creates a new HTTP server for CHORUS API
func NewHTTPServer(port int, hlog *logging.HypercoreLog, ps *pubsub.PubSub) *HTTPServer {
	return &HTTPServer{
		port: port,
@@ -32,38 +33,38 @@ func NewHTTPServer(port int, hlog *logging.HypercoreLog, ps *pubsub.PubSub) *HTT
// Start starts the HTTP server
func (h *HTTPServer) Start() error {
	router := mux.NewRouter()

	// Enable CORS for all routes
	router.Use(func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Access-Control-Allow-Origin", "*")
			w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
			w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")

			if r.Method == "OPTIONS" {
				w.WriteHeader(http.StatusOK)
				return
			}

			next.ServeHTTP(w, r)
		})
	})

	// API routes
	api := router.PathPrefix("/api").Subrouter()

	// Hypercore log endpoints
	api.HandleFunc("/hypercore/logs", h.handleGetLogs).Methods("GET")
	api.HandleFunc("/hypercore/logs/recent", h.handleGetRecentLogs).Methods("GET")
	api.HandleFunc("/hypercore/logs/stats", h.handleGetLogStats).Methods("GET")
	api.HandleFunc("/hypercore/logs/since/{index}", h.handleGetLogsSince).Methods("GET")

	// Health check
	api.HandleFunc("/health", h.handleHealth).Methods("GET")

	// Status endpoint
	api.HandleFunc("/status", h.handleStatus).Methods("GET")

	h.server = &http.Server{
		Addr:    fmt.Sprintf(":%d", h.port),
		Handler: router,
@@ -71,7 +72,7 @@ func (h *HTTPServer) Start() error {
		WriteTimeout: 15 * time.Second,
		IdleTimeout:  60 * time.Second,
	}

	fmt.Printf("🌐 Starting HTTP API server on port %d\n", h.port)
	return h.server.ListenAndServe()
}
@@ -87,16 +88,16 @@ func (h *HTTPServer) Stop() error {
// handleGetLogs returns hypercore log entries
func (h *HTTPServer) handleGetLogs(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	// Parse query parameters
	query := r.URL.Query()
	startStr := query.Get("start")
	endStr := query.Get("end")
	limitStr := query.Get("limit")

	var start, end uint64
	var err error

	if startStr != "" {
		start, err = strconv.ParseUint(startStr, 10, 64)
		if err != nil {
@@ -104,7 +105,7 @@ func (h *HTTPServer) handleGetLogs(w http.ResponseWriter, r *http.Request) {
			return
		}
	}

	if endStr != "" {
		end, err = strconv.ParseUint(endStr, 10, 64)
		if err != nil {
@@ -114,7 +115,7 @@ func (h *HTTPServer) handleGetLogs(w http.ResponseWriter, r *http.Request) {
	} else {
		end = h.hypercoreLog.Length()
	}

	var limit int = 100 // Default limit
	if limitStr != "" {
		limit, err = strconv.Atoi(limitStr)
@@ -122,7 +123,7 @@ func (h *HTTPServer) handleGetLogs(w http.ResponseWriter, r *http.Request) {
			limit = 100
		}
	}

	// Get log entries
	var entries []logging.LogEntry
	if endStr != "" || startStr != "" {
@@ -130,87 +131,87 @@ func (h *HTTPServer) handleGetLogs(w http.ResponseWriter, r *http.Request) {
	} else {
		entries, err = h.hypercoreLog.GetRecentEntries(limit)
	}

	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to get log entries: %v", err), http.StatusInternalServerError)
		return
	}

	response := map[string]interface{}{
		"entries":   entries,
		"count":     len(entries),
		"timestamp": time.Now().Unix(),
		"total":     h.hypercoreLog.Length(),
	}

	json.NewEncoder(w).Encode(response)
}

// handleGetRecentLogs returns the most recent log entries
func (h *HTTPServer) handleGetRecentLogs(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	// Parse limit parameter
	query := r.URL.Query()
	limitStr := query.Get("limit")

	limit := 50 // Default
	if limitStr != "" {
		if l, err := strconv.Atoi(limitStr); err == nil && l > 0 && l <= 1000 {
			limit = l
		}
	}

	entries, err := h.hypercoreLog.GetRecentEntries(limit)
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to get recent entries: %v", err), http.StatusInternalServerError)
		return
	}

	response := map[string]interface{}{
		"entries":   entries,
		"count":     len(entries),
		"timestamp": time.Now().Unix(),
		"total":     h.hypercoreLog.Length(),
	}

	json.NewEncoder(w).Encode(response)
}

// handleGetLogsSince returns log entries since a given index
func (h *HTTPServer) handleGetLogsSince(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	vars := mux.Vars(r)
	indexStr := vars["index"]

	index, err := strconv.ParseUint(indexStr, 10, 64)
	if err != nil {
		http.Error(w, "Invalid index parameter", http.StatusBadRequest)
		return
	}

	entries, err := h.hypercoreLog.GetEntriesSince(index)
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to get entries since index: %v", err), http.StatusInternalServerError)
		return
	}

	response := map[string]interface{}{
		"entries":     entries,
		"count":       len(entries),
		"since_index": index,
		"timestamp":   time.Now().Unix(),
		"total":       h.hypercoreLog.Length(),
	}

	json.NewEncoder(w).Encode(response)
}

// handleGetLogStats returns statistics about the hypercore log
func (h *HTTPServer) handleGetLogStats(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	stats := h.hypercoreLog.GetStats()
	json.NewEncoder(w).Encode(stats)
}
@@ -218,26 +219,26 @@ func (h *HTTPServer) handleGetLogStats(w http.ResponseWriter, r *http.Request) {
// handleHealth returns health status
func (h *HTTPServer) handleHealth(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	health := map[string]interface{}{
		"status":      "healthy",
		"timestamp":   time.Now().Unix(),
		"log_entries": h.hypercoreLog.Length(),
	}

	json.NewEncoder(w).Encode(health)
}

// handleStatus returns detailed status information
func (h *HTTPServer) handleStatus(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	status := map[string]interface{}{
		"status":      "running",
		"timestamp":   time.Now().Unix(),
		"hypercore":   h.hypercoreLog.GetStats(),
		"api_version": "1.0.0",
	}

	json.NewEncoder(w).Encode(status)
}
chorus-agent (BIN, executable file)
Binary file not shown.
@@ -8,12 +8,19 @@ import (
	"chorus/internal/runtime"
)

// Build-time variables set by ldflags
var (
	version    = "0.5.0-dev"
	commitHash = "unknown"
	buildDate  = "unknown"
)

func main() {
	// Early CLI handling: print help/version without requiring env/config
	for _, a := range os.Args[1:] {
		switch a {
		case "--help", "-h", "help":
			fmt.Printf("%s-agent %s\n\n", runtime.AppName, runtime.AppVersion)
			fmt.Printf("%s-agent %s (build: %s, %s)\n\n", runtime.AppName, version, commitHash, buildDate)
			fmt.Println("Usage:")
			fmt.Printf("  %s [--help] [--version]\n\n", filepath.Base(os.Args[0]))
			fmt.Println("CHORUS Autonomous Agent - P2P Task Coordination")
@@ -46,11 +53,16 @@ func main() {
			fmt.Println("  - Health monitoring")
			return
		case "--version", "-v":
			fmt.Printf("%s-agent %s\n", runtime.AppName, runtime.AppVersion)
			fmt.Printf("%s-agent %s (build: %s, %s)\n", runtime.AppName, version, commitHash, buildDate)
			return
		}
	}

	// Set dynamic build information
	runtime.AppVersion = version
	runtime.AppCommitHash = commitHash
	runtime.AppBuildDate = buildDate

	// Initialize shared P2P runtime
	sharedRuntime, err := runtime.Initialize("agent")
	if err != nil {
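The `version`/`commitHash`/`buildDate` variables above are meant to be overridden at link time via `-ldflags "-X ..."` (the Makefile's `VERSION` and `COMMIT_HASH` presumably feed this). A standalone sketch of the pattern; the exact flag string is an assumption, not quoted from the Makefile:

```go
package main

import "fmt"

// These defaults are replaced at build time, e.g.:
//   go build -ldflags "-X main.version=0.5.5 -X main.commitHash=$(git rev-parse --short HEAD)"
var (
	version    = "0.0.0-dev"
	commitHash = "unknown"
	buildDate  = "unknown"
)

// buildString formats the banner the same way the --version flag does.
func buildString() string {
	return fmt.Sprintf("chorus-agent %s (build: %s, %s)", version, commitHash, buildDate)
}

func main() {
	fmt.Println(buildString())
}
```

Without ldflags the defaults print as-is, which is why the agent reports a `-dev` version on plain `go build`.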
configs/models.yaml (new file, 372 lines)
@@ -0,0 +1,372 @@

# CHORUS AI Provider and Model Configuration
# This file defines how different agent roles map to AI models and providers

# Global provider settings
providers:
  # Local Ollama instance (default for most roles)
  ollama:
    type: ollama
    endpoint: http://localhost:11434
    default_model: llama3.1:8b
    temperature: 0.7
    max_tokens: 4096
    timeout: 300s
    retry_attempts: 3
    retry_delay: 2s
    enable_tools: true
    enable_mcp: true
    mcp_servers: []

  # Ollama cluster nodes (for load balancing)
  ollama_cluster:
    type: ollama
    endpoint: http://192.168.1.72:11434  # Primary node
    default_model: llama3.1:8b
    temperature: 0.7
    max_tokens: 4096
    timeout: 300s
    retry_attempts: 3
    retry_delay: 2s
    enable_tools: true
    enable_mcp: true

  # OpenAI API (for advanced models)
  openai:
    type: openai
    endpoint: https://api.openai.com/v1
    api_key: ${OPENAI_API_KEY}
    default_model: gpt-4o
    temperature: 0.7
    max_tokens: 4096
    timeout: 120s
    retry_attempts: 3
    retry_delay: 5s
    enable_tools: true
    enable_mcp: true

  # ResetData LaaS (fallback/testing)
  resetdata:
    type: resetdata
    endpoint: ${RESETDATA_ENDPOINT}
    api_key: ${RESETDATA_API_KEY}
    default_model: llama3.1:8b
    temperature: 0.7
    max_tokens: 4096
    timeout: 300s
    retry_attempts: 3
    retry_delay: 2s
    enable_tools: false
    enable_mcp: false

# Global fallback settings
default_provider: ollama
fallback_provider: resetdata

# Role-based model mappings
roles:
  # Software Developer Agent
  developer:
    provider: ollama
    model: codellama:13b
    temperature: 0.3  # Lower temperature for more consistent code
    max_tokens: 8192  # Larger context for code generation
    system_prompt: |
      You are an expert software developer agent in the CHORUS autonomous development system.

      Your expertise includes:
      - Writing clean, maintainable, and well-documented code
      - Following language-specific best practices and conventions
      - Implementing proper error handling and validation
      - Creating comprehensive tests for your code
      - Considering performance, security, and scalability

      Always provide specific, actionable implementation steps with code examples.
      Focus on delivering production-ready solutions that follow industry best practices.
    fallback_provider: resetdata
    fallback_model: codellama:7b
    enable_tools: true
    enable_mcp: true
    allowed_tools:
      - file_operation
      - execute_command
      - git_operations
      - code_analysis
    mcp_servers:
      - file-server
      - git-server
      - code-tools

  # Code Reviewer Agent
  reviewer:
    provider: ollama
    model: llama3.1:8b
    temperature: 0.2  # Very low temperature for consistent analysis
    max_tokens: 6144
    system_prompt: |
      You are a thorough code reviewer agent in the CHORUS autonomous development system.

      Your responsibilities include:
      - Analyzing code quality, readability, and maintainability
      - Identifying bugs, security vulnerabilities, and performance issues
      - Checking test coverage and test quality
      - Verifying documentation completeness and accuracy
      - Suggesting improvements and refactoring opportunities
      - Ensuring compliance with coding standards and best practices

      Always provide constructive feedback with specific examples and suggestions for improvement.
      Focus on both technical correctness and long-term maintainability.
    fallback_provider: resetdata
    fallback_model: llama3.1:8b
    enable_tools: true
    enable_mcp: true
    allowed_tools:
      - code_analysis
      - security_scan
      - test_coverage
      - documentation_check
    mcp_servers:
      - code-analysis-server
      - security-tools

  # Software Architect Agent
  architect:
    provider: openai  # Use OpenAI for complex architectural decisions
    model: gpt-4o
    temperature: 0.5  # Balanced creativity and consistency
    max_tokens: 8192  # Large context for architectural discussions
    system_prompt: |
      You are a senior software architect agent in the CHORUS autonomous development system.

      Your expertise includes:
      - Designing scalable and maintainable system architectures
      - Making informed decisions about technologies and frameworks
      - Defining clear interfaces and API contracts
      - Considering scalability, performance, and security requirements
      - Creating architectural documentation and diagrams
      - Evaluating trade-offs between different architectural approaches

      Always provide well-reasoned architectural decisions with clear justifications.
      Consider both immediate requirements and long-term evolution of the system.
    fallback_provider: ollama
    fallback_model: llama3.1:13b
    enable_tools: true
    enable_mcp: true
    allowed_tools:
      - architecture_analysis
      - diagram_generation
      - technology_research
      - api_design
    mcp_servers:
      - architecture-tools
      - diagram-server

  # QA/Testing Agent
  tester:
    provider: ollama
    model: codellama:7b  # Smaller model, focused on test generation
    temperature: 0.3
    max_tokens: 6144
    system_prompt: |
      You are a quality assurance engineer agent in the CHORUS autonomous development system.

      Your responsibilities include:
      - Creating comprehensive test plans and test cases
      - Implementing unit, integration, and end-to-end tests
      - Identifying edge cases and potential failure scenarios
      - Setting up test automation and continuous integration
      - Validating functionality against requirements
      - Performing security and performance testing

      Always focus on thorough test coverage and quality assurance practices.
      Ensure tests are maintainable, reliable, and provide meaningful feedback.
    fallback_provider: resetdata
    fallback_model: llama3.1:8b
    enable_tools: true
    enable_mcp: true
    allowed_tools:
      - test_generation
      - test_execution
      - coverage_analysis
      - performance_testing
    mcp_servers:
      - testing-framework
      - coverage-tools

  # DevOps/Infrastructure Agent
  devops:
    provider: ollama_cluster
    model: llama3.1:8b
    temperature: 0.4
    max_tokens: 6144
    system_prompt: |
      You are a DevOps engineer agent in the CHORUS autonomous development system.

      Your expertise includes:
      - Automating deployment processes and CI/CD pipelines
      - Managing containerization with Docker and orchestration with Kubernetes
      - Implementing infrastructure as code (IaC)
      - Monitoring, logging, and observability setup
      - Security hardening and compliance management
      - Performance optimization and scaling strategies

      Always focus on automation, reliability, and security in your solutions.
      Ensure infrastructure is scalable, maintainable, and follows best practices.
    fallback_provider: resetdata
    fallback_model: llama3.1:8b
    enable_tools: true
    enable_mcp: true
    allowed_tools:
      - docker_operations
      - kubernetes_management
      - ci_cd_tools
      - monitoring_setup
      - security_hardening
    mcp_servers:
      - docker-server
      - k8s-tools
      - monitoring-server

  # Security Specialist Agent
  security:
    provider: openai
    model: gpt-4o  # Use advanced model for security analysis
    temperature: 0.1  # Very conservative for security
    max_tokens: 8192
    system_prompt: |
      You are a security specialist agent in the CHORUS autonomous development system.

      Your expertise includes:
      - Conducting security audits and vulnerability assessments
      - Implementing security best practices and controls
      - Analyzing code for security vulnerabilities
      - Setting up security monitoring and incident response
      - Ensuring compliance with security standards
      - Designing secure architectures and data flows

      Always prioritize security over convenience and thoroughly analyze potential threats.
      Provide specific, actionable security recommendations with risk assessments.
    fallback_provider: ollama
    fallback_model: llama3.1:8b
    enable_tools: true
    enable_mcp: true
    allowed_tools:
      - security_scan
      - vulnerability_assessment
      - compliance_check
      - threat_modeling
    mcp_servers:
      - security-tools
      - compliance-server

  # Documentation Agent
  documentation:
    provider: ollama
    model: llama3.1:8b
    temperature: 0.6  # Slightly higher for creative writing
    max_tokens: 8192
    system_prompt: |
      You are a technical documentation specialist agent in the CHORUS autonomous development system.

      Your expertise includes:
      - Creating clear, comprehensive technical documentation
      - Writing user guides, API documentation, and tutorials
      - Maintaining README files and project wikis
      - Creating architectural decision records (ADRs)
      - Developing onboarding materials and runbooks
      - Ensuring documentation accuracy and completeness

      Always write documentation that is clear, actionable, and accessible to your target audience.
      Focus on providing practical information that helps users accomplish their goals.
    fallback_provider: resetdata
    fallback_model: llama3.1:8b
    enable_tools: true
    enable_mcp: true
    allowed_tools:
      - documentation_generation
      - markdown_processing
      - diagram_creation
      - content_validation
    mcp_servers:
      - docs-server
      - markdown-tools

  # General Purpose Agent (fallback)
  general:
    provider: ollama
    model: llama3.1:8b
    temperature: 0.7
    max_tokens: 4096
    system_prompt: |
      You are a general-purpose AI agent in the CHORUS autonomous development system.

      Your capabilities include:
      - Analyzing and understanding various types of development tasks
      - Providing guidance on software development best practices
      - Assisting with problem-solving and decision-making
      - Coordinating with other specialized agents when needed

      Always provide helpful, accurate information and know when to defer to specialized agents.
      Focus on understanding the task requirements and providing appropriate guidance.
    fallback_provider: resetdata
    fallback_model: llama3.1:8b
    enable_tools: true
    enable_mcp: true
|
||||
|
||||
# Environment-specific overrides
|
||||
environments:
|
||||
development:
|
||||
# Use local models for development to reduce costs
|
||||
default_provider: ollama
|
||||
fallback_provider: resetdata
|
||||
|
||||
staging:
|
||||
# Mix of local and cloud models for realistic testing
|
||||
default_provider: ollama_cluster
|
||||
fallback_provider: openai
|
||||
|
||||
production:
|
||||
# Prefer reliable cloud providers with fallback to local
|
||||
default_provider: openai
|
||||
fallback_provider: ollama_cluster
|
||||
|
||||
# Model performance preferences (for auto-selection)
|
||||
model_preferences:
|
||||
# Code generation tasks
|
||||
code_generation:
|
||||
preferred_models:
|
||||
- codellama:13b
|
||||
- gpt-4o
|
||||
- codellama:34b
|
||||
min_context_tokens: 8192
|
||||
|
||||
# Code review tasks
|
||||
code_review:
|
||||
preferred_models:
|
||||
- llama3.1:8b
|
||||
- gpt-4o
|
||||
- llama3.1:13b
|
||||
min_context_tokens: 6144
|
||||
|
||||
# Architecture and design
|
||||
architecture:
|
||||
preferred_models:
|
||||
- gpt-4o
|
||||
- llama3.1:13b
|
||||
- llama3.1:70b
|
||||
min_context_tokens: 8192
|
||||
|
||||
# Testing and QA
|
||||
testing:
|
||||
preferred_models:
|
||||
- codellama:7b
|
||||
- llama3.1:8b
|
||||
- codellama:13b
|
||||
min_context_tokens: 6144
|
||||
|
||||
# Documentation
|
||||
documentation:
|
||||
preferred_models:
|
||||
- llama3.1:8b
|
||||
- gpt-4o
|
||||
- mistral:7b
|
||||
min_context_tokens: 8192
|
||||
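The `model_preferences` entries above pair an ordered candidate list with a context-size floor. A minimal Go sketch of how such an entry could drive auto-selection (the `TaskPreference` type, `pickModel` helper, and the context-window table are illustrative assumptions, not the actual CHORUS API):

```go
package main

import "fmt"

// TaskPreference mirrors one model_preferences entry (assumed shape).
type TaskPreference struct {
	PreferredModels  []string
	MinContextTokens int
}

// pickModel returns the first preferred model whose known context window
// satisfies the task's minimum, falling back to the last candidate.
func pickModel(pref TaskPreference, contextWindow map[string]int) string {
	for _, m := range pref.PreferredModels {
		if contextWindow[m] >= pref.MinContextTokens {
			return m
		}
	}
	if n := len(pref.PreferredModels); n > 0 {
		return pref.PreferredModels[n-1]
	}
	return ""
}

func main() {
	// Hypothetical context windows for the models named in the config.
	windows := map[string]int{"codellama:13b": 16384, "gpt-4o": 128000}
	pref := TaskPreference{
		PreferredModels:  []string{"codellama:13b", "gpt-4o", "codellama:34b"},
		MinContextTokens: 8192,
	}
	fmt.Println(pickModel(pref, windows))
}
```

Ordering the list by preference and filtering by `min_context_tokens` keeps selection deterministic while still honoring per-task requirements.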
@@ -8,51 +8,63 @@ import (
	"time"

	"chorus/internal/logging"
	"chorus/pkg/ai"
	"chorus/pkg/config"
	"chorus/pkg/execution"
	"chorus/pkg/hmmm"
	"chorus/pkg/repository"
	"chorus/pubsub"
	"github.com/google/uuid"
	"github.com/libp2p/go-libp2p/core/peer"
)

// TaskProgressTracker is notified when tasks start and complete so availability broadcasts stay accurate.
type TaskProgressTracker interface {
	AddTask(taskID string)
	RemoveTask(taskID string)
}

// TaskCoordinator manages task discovery, assignment, and execution across multiple repositories
type TaskCoordinator struct {
	pubsub     *pubsub.PubSub
	hlog       *logging.HypercoreLog
	ctx        context.Context
	config     *config.Config
	hmmmRouter *hmmm.Router

	// Repository management
	providers    map[int]repository.TaskProvider // projectID -> provider
	providerLock sync.RWMutex
	factory      repository.ProviderFactory

	// Task management
	activeTasks map[string]*ActiveTask // taskKey -> active task
	taskLock    sync.RWMutex
	taskMatcher repository.TaskMatcher
	taskTracker TaskProgressTracker

	// Task execution
	executionEngine execution.TaskExecutionEngine

	// Agent tracking
	nodeID    string
	agentInfo *repository.AgentInfo

	// Sync settings
	syncInterval time.Duration
	lastSync     map[int]time.Time
	syncLock     sync.RWMutex
}

// ActiveTask represents a task currently being worked on
type ActiveTask struct {
	Task      *repository.Task
	Provider  repository.TaskProvider
	ProjectID int
	ClaimedAt time.Time
	Status    string // claimed, working, completed, failed
	AgentID   string
	Results   map[string]interface{}
}

// NewTaskCoordinator creates a new task coordinator
@@ -63,7 +75,9 @@ func NewTaskCoordinator(
	cfg *config.Config,
	nodeID string,
	hmmmRouter *hmmm.Router,
	tracker TaskProgressTracker,
) *TaskCoordinator {
	coordinator := &TaskCoordinator{
		pubsub: ps,
		hlog:   hlog,
@@ -75,10 +89,11 @@ func NewTaskCoordinator(
		lastSync:     make(map[int]time.Time),
		factory:      &repository.DefaultProviderFactory{},
		taskMatcher:  &repository.DefaultTaskMatcher{},
		taskTracker:  tracker,
		nodeID:       nodeID,
		syncInterval: 30 * time.Second,
	}

	// Create agent info from config
	coordinator.agentInfo = &repository.AgentInfo{
		ID: cfg.Agent.ID,
@@ -91,23 +106,30 @@ func NewTaskCoordinator(
		Performance:  map[string]interface{}{"score": 0.8}, // Default performance score
		Availability: "available",
	}

	return coordinator
}

// Start begins the task coordination process
func (tc *TaskCoordinator) Start() {
	fmt.Printf("🎯 Starting task coordinator for agent %s (%s)\n", tc.agentInfo.ID, tc.agentInfo.Role)

	// Initialize task execution engine
	err := tc.initializeExecutionEngine()
	if err != nil {
		fmt.Printf("⚠️ Failed to initialize task execution engine: %v\n", err)
		fmt.Println("Task execution will fall back to mock implementation")
	}

	// Announce role and capabilities
	tc.announceAgentRole()

	// Start periodic task discovery and sync
	go tc.taskDiscoveryLoop()

	// Start role-based message handling
	tc.pubsub.SetAntennaeMessageHandler(tc.handleRoleMessage)

	fmt.Printf("✅ Task coordinator started\n")
}
@@ -185,13 +207,17 @@ func (tc *TaskCoordinator) processTask(task *repository.Task, provider repositor
	tc.agentInfo.CurrentTasks = len(tc.activeTasks)
	tc.taskLock.Unlock()

	if tc.taskTracker != nil {
		tc.taskTracker.AddTask(taskKey)
	}

	// Log task claim
	tc.hlog.Append(logging.TaskClaimed, map[string]interface{}{
		"task_number":   task.Number,
		"repository":    task.Repository,
		"title":         task.Title,
		"required_role": task.RequiredRole,
		"priority":      task.Priority,
	})

	// Announce task claim
@@ -212,11 +238,11 @@ func (tc *TaskCoordinator) processTask(task *repository.Task, provider repositor
	}
	if err := tc.hmmmRouter.Publish(tc.ctx, seedMsg); err != nil {
		fmt.Printf("⚠️ Failed to seed HMMM room for task %d: %v\n", task.Number, err)
		tc.hlog.AppendString("system_error", map[string]interface{}{
			"error":       "hmmm_seed_failed",
			"task_number": task.Number,
			"repository":  task.Repository,
			"message":     err.Error(),
		})
	} else {
		fmt.Printf("🐜 Seeded HMMM room for task %d\n", task.Number)
@@ -259,14 +285,14 @@ func (tc *TaskCoordinator) shouldRequestCollaboration(task *repository.Task) boo
// requestTaskCollaboration requests collaboration for a task
func (tc *TaskCoordinator) requestTaskCollaboration(task *repository.Task) {
	data := map[string]interface{}{
		"task_number":        task.Number,
		"repository":         task.Repository,
		"title":              task.Title,
		"required_role":      task.RequiredRole,
		"required_expertise": task.RequiredExpertise,
		"priority":           task.Priority,
		"requester_role":     tc.agentInfo.Role,
		"reason":             "expertise_gap",
	}

	opts := pubsub.MessageOptions{
@@ -285,10 +311,69 @@ func (tc *TaskCoordinator) requestTaskCollaboration(task *repository.Task) {
	}
}

// initializeExecutionEngine sets up the AI-powered task execution engine
func (tc *TaskCoordinator) initializeExecutionEngine() error {
	// Create AI provider factory
	aiFactory := ai.NewProviderFactory()

	// Load AI configuration from config file
	configPath := "configs/models.yaml"
	configLoader := ai.NewConfigLoader(configPath, "production")
	_, err := configLoader.LoadConfig()
	if err != nil {
		return fmt.Errorf("failed to load AI config: %w", err)
	}

	// Initialize the factory with the loaded configuration
	// For now, we'll use a simplified initialization
	// In a complete implementation, the factory would have an Initialize method

	// Create task execution engine
	tc.executionEngine = execution.NewTaskExecutionEngine()

	// Configure execution engine
	engineConfig := &execution.EngineConfig{
		AIProviderFactory:  aiFactory,
		DefaultTimeout:     5 * time.Minute,
		MaxConcurrentTasks: tc.agentInfo.MaxTasks,
		EnableMetrics:      true,
		LogLevel:           "info",
		SandboxDefaults: &execution.SandboxConfig{
			Type:         "docker",
			Image:        "alpine:latest",
			Architecture: "amd64",
			Resources: execution.ResourceLimits{
				MemoryLimit:  512 * 1024 * 1024, // 512MB
				CPULimit:     1.0,
				ProcessLimit: 50,
				FileLimit:    1024,
			},
			Security: execution.SecurityPolicy{
				ReadOnlyRoot:     false,
				NoNewPrivileges:  true,
				AllowNetworking:  true,
				IsolateNetwork:   false,
				IsolateProcess:   true,
				DropCapabilities: []string{"NET_ADMIN", "SYS_ADMIN"},
			},
			WorkingDir: "/workspace",
			Timeout:    5 * time.Minute,
		},
	}

	err = tc.executionEngine.Initialize(tc.ctx, engineConfig)
	if err != nil {
		return fmt.Errorf("failed to initialize execution engine: %w", err)
	}

	fmt.Printf("✅ Task execution engine initialized successfully\n")
	return nil
}

// executeTask executes a claimed task
func (tc *TaskCoordinator) executeTask(activeTask *ActiveTask) {
	taskKey := fmt.Sprintf("%s:%d", activeTask.Task.Repository, activeTask.Task.Number)

	// Update status
	tc.taskLock.Lock()
	activeTask.Status = "working"
@@ -297,49 +382,59 @@ func (tc *TaskCoordinator) executeTask(activeTask *ActiveTask) {
	// Announce work start
	tc.announceTaskProgress(activeTask.Task, "started")

	// Execute task using AI-powered execution engine
	var taskResult *repository.TaskResult

	if tc.executionEngine != nil {
		// Use real AI-powered execution
		executionResult, err := tc.executeTaskWithAI(activeTask)
		if err != nil {
			fmt.Printf("⚠️ AI execution failed for task %s #%d: %v\n",
				activeTask.Task.Repository, activeTask.Task.Number, err)

			// Fall back to mock execution
			taskResult = tc.executeMockTask(activeTask)
		} else {
			// Convert execution result to task result
			taskResult = tc.convertExecutionResult(activeTask, executionResult)
		}
	} else {
		// Fall back to mock execution
		fmt.Printf("📝 Using mock execution for task %s #%d (engine not available)\n",
			activeTask.Task.Repository, activeTask.Task.Number)
		taskResult = tc.executeMockTask(activeTask)
	}

	err := activeTask.Provider.CompleteTask(activeTask.Task, taskResult)
	if err != nil {
		fmt.Printf("❌ Failed to complete task %s #%d: %v\n", activeTask.Task.Repository, activeTask.Task.Number, err)

		// Update status to failed
		tc.taskLock.Lock()
		activeTask.Status = "failed"
		activeTask.Results = map[string]interface{}{"error": err.Error()}
		tc.taskLock.Unlock()

		return
	}

	// Update status and remove from active tasks
	tc.taskLock.Lock()
	activeTask.Status = "completed"
	activeTask.Results = taskResult.Metadata
	delete(tc.activeTasks, taskKey)
	tc.agentInfo.CurrentTasks = len(tc.activeTasks)
	tc.taskLock.Unlock()

	if tc.taskTracker != nil {
		tc.taskTracker.RemoveTask(taskKey)
	}

	// Log completion
	tc.hlog.Append(logging.TaskCompleted, map[string]interface{}{
		"task_number": activeTask.Task.Number,
		"repository":  activeTask.Task.Repository,
		"duration":    time.Since(activeTask.ClaimedAt).Seconds(),
		"results":     taskResult.Metadata,
	})

	// Announce completion
@@ -348,6 +443,200 @@ func (tc *TaskCoordinator) executeTask(activeTask *ActiveTask) {
	fmt.Printf("✅ Completed task %s #%d\n", activeTask.Task.Repository, activeTask.Task.Number)
}

// executeTaskWithAI executes a task using the AI-powered execution engine
func (tc *TaskCoordinator) executeTaskWithAI(activeTask *ActiveTask) (*execution.TaskExecutionResult, error) {
	// Convert repository task to execution request
	executionRequest := &execution.TaskExecutionRequest{
		ID:          fmt.Sprintf("%s:%d", activeTask.Task.Repository, activeTask.Task.Number),
		Type:        tc.determineTaskType(activeTask.Task),
		Description: tc.buildTaskDescription(activeTask.Task),
		Context:     tc.buildTaskContext(activeTask.Task),
		Requirements: &execution.TaskRequirements{
			AIModel:       "", // Let the engine choose based on role
			SandboxType:   "docker",
			RequiredTools: []string{"git", "curl"},
			EnvironmentVars: map[string]string{
				"TASK_ID":    fmt.Sprintf("%d", activeTask.Task.Number),
				"REPOSITORY": activeTask.Task.Repository,
				"AGENT_ID":   tc.agentInfo.ID,
				"AGENT_ROLE": tc.agentInfo.Role,
			},
		},
		Timeout: 10 * time.Minute, // Allow longer timeout for complex tasks
	}

	// Execute the task
	return tc.executionEngine.ExecuteTask(tc.ctx, executionRequest)
}

// executeMockTask provides fallback mock execution
func (tc *TaskCoordinator) executeMockTask(activeTask *ActiveTask) *repository.TaskResult {
	// Simulate work time based on task complexity
	workTime := 5 * time.Second
	if strings.Contains(strings.ToLower(activeTask.Task.Title), "complex") {
		workTime = 15 * time.Second
	}

	fmt.Printf("🕐 Mock execution for task %s #%d (simulating %v)\n",
		activeTask.Task.Repository, activeTask.Task.Number, workTime)

	time.Sleep(workTime)

	results := map[string]interface{}{
		"status":          "completed",
		"execution_type":  "mock",
		"completion_time": time.Now().Format(time.RFC3339),
		"agent_id":        tc.agentInfo.ID,
		"agent_role":      tc.agentInfo.Role,
		"simulated_work":  workTime.String(),
	}

	return &repository.TaskResult{
		Success:  true,
		Message:  "Task completed successfully (mock execution)",
		Metadata: results,
	}
}

// convertExecutionResult converts an execution result to a task result
func (tc *TaskCoordinator) convertExecutionResult(activeTask *ActiveTask, result *execution.TaskExecutionResult) *repository.TaskResult {
	// Build result metadata
	metadata := map[string]interface{}{
		"status":            "completed",
		"execution_type":    "ai_powered",
		"completion_time":   time.Now().Format(time.RFC3339),
		"agent_id":          tc.agentInfo.ID,
		"agent_role":        tc.agentInfo.Role,
		"task_id":           result.TaskID,
		"duration":          result.Metrics.Duration.String(),
		"ai_provider_time":  result.Metrics.AIProviderTime.String(),
		"sandbox_time":      result.Metrics.SandboxTime.String(),
		"commands_executed": result.Metrics.CommandsExecuted,
		"files_generated":   result.Metrics.FilesGenerated,
	}

	// Add execution metadata if available
	if result.Metadata != nil {
		metadata["ai_metadata"] = result.Metadata
	}

	// Add resource usage if available
	if result.Metrics.ResourceUsage != nil {
		metadata["resource_usage"] = map[string]interface{}{
			"cpu_usage":      result.Metrics.ResourceUsage.CPUUsage,
			"memory_usage":   result.Metrics.ResourceUsage.MemoryUsage,
			"memory_percent": result.Metrics.ResourceUsage.MemoryPercent,
		}
	}

	// Handle artifacts
	if len(result.Artifacts) > 0 {
		artifactsList := make([]map[string]interface{}, len(result.Artifacts))
		for i, artifact := range result.Artifacts {
			artifactsList[i] = map[string]interface{}{
				"name":       artifact.Name,
				"type":       artifact.Type,
				"size":       artifact.Size,
				"created_at": artifact.CreatedAt.Format(time.RFC3339),
			}
		}
		metadata["artifacts"] = artifactsList
	}

	// Determine success based on execution result
	success := result.Success
	message := "Task completed successfully with AI execution"

	if !success {
		message = fmt.Sprintf("Task failed: %s", result.ErrorMessage)
	}

	return &repository.TaskResult{
		Success:  success,
		Message:  message,
		Metadata: metadata,
	}
}

// determineTaskType analyzes a task to determine its execution type
func (tc *TaskCoordinator) determineTaskType(task *repository.Task) string {
	title := strings.ToLower(task.Title)
	description := strings.ToLower(task.Body)

	// Check for common task type keywords
	if strings.Contains(title, "bug") || strings.Contains(title, "fix") {
		return "bug_fix"
	}
	if strings.Contains(title, "feature") || strings.Contains(title, "implement") {
		return "feature_development"
	}
	if strings.Contains(title, "test") || strings.Contains(description, "test") {
		return "testing"
	}
	if strings.Contains(title, "doc") || strings.Contains(description, "documentation") {
		return "documentation"
	}
	if strings.Contains(title, "refactor") || strings.Contains(description, "refactor") {
		return "refactoring"
	}
	if strings.Contains(title, "review") || strings.Contains(description, "review") {
		return "code_review"
	}

	// Default to general development task
	return "development"
}

// buildTaskDescription creates a comprehensive description for AI execution
func (tc *TaskCoordinator) buildTaskDescription(task *repository.Task) string {
	var description strings.Builder

	description.WriteString(fmt.Sprintf("Task: %s\n\n", task.Title))

	if task.Body != "" {
		description.WriteString(fmt.Sprintf("Description:\n%s\n\n", task.Body))
	}

	description.WriteString(fmt.Sprintf("Repository: %s\n", task.Repository))
	description.WriteString(fmt.Sprintf("Task Number: %d\n", task.Number))

	if len(task.RequiredExpertise) > 0 {
		description.WriteString(fmt.Sprintf("Required Expertise: %v\n", task.RequiredExpertise))
	}

	if len(task.Labels) > 0 {
		description.WriteString(fmt.Sprintf("Labels: %v\n", task.Labels))
	}

	description.WriteString("\nPlease analyze this task and provide appropriate commands or code to complete it.")

	return description.String()
}

// buildTaskContext creates context information for AI execution
func (tc *TaskCoordinator) buildTaskContext(task *repository.Task) map[string]interface{} {
	context := map[string]interface{}{
		"repository":         task.Repository,
		"task_number":        task.Number,
		"task_title":         task.Title,
		"required_role":      task.RequiredRole,
		"required_expertise": task.RequiredExpertise,
		"labels":             task.Labels,
		"agent_info": map[string]interface{}{
			"id":        tc.agentInfo.ID,
			"role":      tc.agentInfo.Role,
			"expertise": tc.agentInfo.Expertise,
		},
	}

	// Add any additional metadata from the task
	if task.Metadata != nil {
		context["task_metadata"] = task.Metadata
	}

	return context
}

// announceAgentRole announces this agent's role and capabilities
func (tc *TaskCoordinator) announceAgentRole() {
	data := map[string]interface{}{
@@ -378,19 +667,19 @@ func (tc *TaskCoordinator) announceAgentRole() {
// announceTaskClaim announces that this agent has claimed a task
func (tc *TaskCoordinator) announceTaskClaim(task *repository.Task) {
	data := map[string]interface{}{
		"task_number":          task.Number,
		"repository":           task.Repository,
		"title":                task.Title,
		"agent_id":             tc.agentInfo.ID,
		"agent_role":           tc.agentInfo.Role,
		"claim_time":           time.Now().Format(time.RFC3339),
		"estimated_completion": time.Now().Add(time.Hour).Format(time.RFC3339),
	}

	opts := pubsub.MessageOptions{
		FromRole: tc.agentInfo.Role,
		Priority: "medium",
		ThreadID: fmt.Sprintf("task-%s-%d", task.Repository, task.Number),
	}

	err := tc.pubsub.PublishRoleBasedMessage(pubsub.TaskProgress, data, opts)
@@ -463,15 +752,15 @@ func (tc *TaskCoordinator) handleTaskHelpRequest(msg pubsub.Message, from peer.I
		}
	}

	if canHelp && tc.agentInfo.CurrentTasks < tc.agentInfo.MaxTasks {
		// Offer help
		responseData := map[string]interface{}{
			"agent_id":     tc.agentInfo.ID,
			"agent_role":   tc.agentInfo.Role,
			"expertise":    tc.agentInfo.Expertise,
			"availability": tc.agentInfo.MaxTasks - tc.agentInfo.CurrentTasks,
			"offer_type":   "collaboration",
			"response_to":  msg.Data,
		}

		opts := pubsub.MessageOptions{
@@ -480,34 +769,34 @@ func (tc *TaskCoordinator) handleTaskHelpRequest(msg pubsub.Message, from peer.I
			ThreadID: msg.ThreadID,
		}

		err := tc.pubsub.PublishRoleBasedMessage(pubsub.TaskHelpResponse, responseData, opts)
		if err != nil {
			fmt.Printf("⚠️ Failed to offer help: %v\n", err)
		} else {
			fmt.Printf("🤝 Offered help for task collaboration\n")
		}

		// Also reflect the help offer into the HMMM per-issue room (best-effort)
		if tc.hmmmRouter != nil {
			if tn, ok := msg.Data["task_number"].(float64); ok {
				issueID := int64(tn)
				hmsg := hmmm.Message{
					Version:   1,
					Type:      "meta_msg",
					IssueID:   issueID,
					ThreadID:  fmt.Sprintf("issue-%d", issueID),
					MsgID:     uuid.New().String(),
					NodeID:    tc.nodeID,
					HopCount:  0,
					Timestamp: time.Now().UTC(),
					Message:   fmt.Sprintf("Help offer from %s (availability %d)", tc.agentInfo.Role, tc.agentInfo.MaxTasks-tc.agentInfo.CurrentTasks),
				}
				if err := tc.hmmmRouter.Publish(tc.ctx, hmsg); err != nil {
					fmt.Printf("⚠️ Failed to reflect help into HMMM: %v\n", err)
				}
			}
		}
	}
}

// handleExpertiseRequest handles requests for specific expertise

@@ -11,15 +11,15 @@ WORKDIR /build
# Copy go mod files first (for better caching)
COPY go.mod go.sum ./

# Download dependencies
RUN go mod download

# Copy source code
COPY . .

# Build the CHORUS binary with mod mode
RUN CGO_ENABLED=0 GOOS=linux go build \
    -mod=mod \
    -ldflags='-w -s -extldflags "-static"' \
    -o chorus \
    ./cmd/chorus

38	docker/bootstrap.json	Normal file
@@ -0,0 +1,38 @@
{
  "metadata": {
    "generated_at": "2024-12-19T10:00:00Z",
    "cluster_id": "production-cluster",
    "version": "1.0.0",
    "notes": "Bootstrap configuration for CHORUS scaling - managed by WHOOSH"
  },
  "peers": [
    {
      "address": "/ip4/10.0.1.10/tcp/9000/p2p/12D3KooWExample1234567890abcdef",
      "priority": 100,
      "region": "us-east-1",
      "roles": ["admin", "stable"],
      "enabled": true
    },
    {
      "address": "/ip4/10.0.1.11/tcp/9000/p2p/12D3KooWExample1234567890abcde2",
      "priority": 90,
      "region": "us-east-1",
      "roles": ["worker", "stable"],
      "enabled": true
    },
    {
      "address": "/ip4/10.0.2.10/tcp/9000/p2p/12D3KooWExample1234567890abcde3",
      "priority": 80,
      "region": "us-west-2",
      "roles": ["worker", "stable"],
      "enabled": true
    },
    {
      "address": "/ip4/10.0.3.10/tcp/9000/p2p/12D3KooWExample1234567890abcde4",
      "priority": 70,
      "region": "eu-central-1",
      "roles": ["worker"],
      "enabled": false
    }
  ]
}
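The bootstrap file pairs each peer with a priority and an `enabled` flag. A minimal Go sketch of how a node could consume it — filter out disabled peers and dial in descending priority order (the `Peer`/`Bootstrap` types and `enabledPeers` helper are assumptions modeled on the JSON above, not CHORUS's actual loader):

```go
package main

import (
	"encoding/json"
	"fmt"
	"sort"
)

// Peer mirrors one entry in docker/bootstrap.json (assumed shape).
type Peer struct {
	Address  string   `json:"address"`
	Priority int      `json:"priority"`
	Region   string   `json:"region"`
	Roles    []string `json:"roles"`
	Enabled  bool     `json:"enabled"`
}

// Bootstrap mirrors the top-level document.
type Bootstrap struct {
	Peers []Peer `json:"peers"`
}

// enabledPeers returns the enabled peers sorted by descending priority,
// i.e. the order a joining node would try them.
func enabledPeers(raw []byte) ([]Peer, error) {
	var b Bootstrap
	if err := json.Unmarshal(raw, &b); err != nil {
		return nil, err
	}
	var out []Peer
	for _, p := range b.Peers {
		if p.Enabled {
			out = append(out, p)
		}
	}
	sort.Slice(out, func(i, j int) bool { return out[i].Priority > out[j].Priority })
	return out, nil
}

func main() {
	raw := []byte(`{"peers":[
		{"address":"/ip4/10.0.1.11/tcp/9000","priority":90,"enabled":true},
		{"address":"/ip4/10.0.3.10/tcp/9000","priority":70,"enabled":false},
		{"address":"/ip4/10.0.1.10/tcp/9000","priority":100,"enabled":true}]}`)
	peers, err := enabledPeers(raw)
	if err != nil {
		panic(err)
	}
	for _, p := range peers {
		fmt.Println(p.Priority, p.Address)
	}
}
```

Sorting by priority keeps the stable admin peers first, so the eu-central-1 peer above (disabled) is skipped entirely and the us-east-1 peers are dialed before lower-priority ones.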
@@ -2,7 +2,7 @@ version: "3.9"

services:
  chorus:
    image: anthonyrawlins/chorus:backbeat-v2.0.1
    image: anthonyrawlins/chorus:latest

    # REQUIRED: License configuration (CHORUS will not start without this)
    environment:
@@ -15,20 +15,39 @@ services:
      - CHORUS_AGENT_ID=${CHORUS_AGENT_ID:-} # Auto-generated if not provided
      - CHORUS_SPECIALIZATION=${CHORUS_SPECIALIZATION:-general_developer}
      - CHORUS_MAX_TASKS=${CHORUS_MAX_TASKS:-3}
      - CHORUS_CAPABILITIES=${CHORUS_CAPABILITIES:-general_development,task_coordination}
      - CHORUS_CAPABILITIES=general_development,task_coordination,admin_election

      # Network configuration
      - CHORUS_API_PORT=8080
      - CHORUS_HEALTH_PORT=8081
      - CHORUS_P2P_PORT=9000
      - CHORUS_BIND_ADDRESS=0.0.0.0

      # Scaling optimizations (as per WHOOSH issue #7)
      - CHORUS_MDNS_ENABLED=false # Disabled for container/swarm environments
      - CHORUS_DIALS_PER_SEC=5 # Rate limit outbound connections to prevent storms
      - CHORUS_MAX_CONCURRENT_DHT=16 # Limit concurrent DHT queries

      # Election stability windows (Medium-risk fix 2.1)
      - CHORUS_ELECTION_MIN_TERM=30s # Minimum time between elections to prevent churn
      - CHORUS_LEADER_MIN_TERM=45s # Minimum time before challenging healthy leader

      # Assignment system for runtime configuration (Medium-risk fix 2.2)
      - ASSIGN_URL=${ASSIGN_URL:-} # Optional: WHOOSH assignment endpoint
      - TASK_SLOT=${TASK_SLOT:-} # Optional: Task slot identifier
      - TASK_ID=${TASK_ID:-} # Optional: Task identifier
      - NODE_ID=${NODE_ID:-} # Optional: Node identifier

      # Bootstrap pool configuration (supports JSON and CSV)
      - BOOTSTRAP_JSON=/config/bootstrap.json # Optional: JSON bootstrap config
      - CHORUS_BOOTSTRAP_PEERS=${CHORUS_BOOTSTRAP_PEERS:-} # CSV fallback

      # AI configuration - Provider selection
      - CHORUS_AI_PROVIDER=${CHORUS_AI_PROVIDER:-resetdata}

      # ResetData configuration (default provider)
      - RESETDATA_BASE_URL=${RESETDATA_BASE_URL:-https://models.au-syd.resetdata.ai/v1}
      - RESETDATA_API_KEY=${RESETDATA_API_KEY:?RESETDATA_API_KEY is required for resetdata provider}
      - RESETDATA_API_KEY_FILE=/run/secrets/resetdata_api_key
      - RESETDATA_MODEL=${RESETDATA_MODEL:-meta/llama-3.1-8b-instruct}

      # Ollama configuration (alternative provider)
@@ -56,12 +75,18 @@ services:
    # Docker secrets for sensitive configuration
    secrets:
      - chorus_license_id
      - resetdata_api_key

    # Configuration files
    configs:
      - source: chorus_bootstrap
        target: /config/bootstrap.json

    # Persistent data storage
    volumes:
      - chorus_data:/app/data
      # Mount prompts directory read-only for role YAMLs and defaults.md
      - ../prompts:/etc/chorus/prompts:ro
      - /rust/containers/WHOOSH/prompts:/etc/chorus/prompts:ro

    # Network ports
    ports:
@@ -70,7 +95,7 @@ services:
    # Container resource limits
    deploy:
      mode: replicated
      replicas: ${CHORUS_REPLICAS:-1}
      replicas: ${CHORUS_REPLICAS:-9}
      update_config:
        parallelism: 1
        delay: 10s
@@ -90,7 +115,7 @@ services:
          memory: 128M
      placement:
        constraints:
          - node.hostname != rosewood
          - node.hostname != acacia
        preferences:
          - spread: node.hostname
    # CHORUS is internal-only, no Traefik labels needed
@@ -120,7 +145,7 @@ services:
      start_period: 10s

  whoosh:
    image: anthonyrawlins/whoosh:backbeat-v2.1.0
    image: anthonyrawlins/whoosh:scaling-v1.0.0
    ports:
      - target: 8080
        published: 8800
@@ -163,6 +188,18 @@ services:
      WHOOSH_REDIS_PORT: 6379
      WHOOSH_REDIS_PASSWORD_FILE: /run/secrets/redis_password
      WHOOSH_REDIS_DATABASE: 0

      # Scaling system configuration
      WHOOSH_SCALING_KACHING_URL: "https://kaching.chorus.services"
      WHOOSH_SCALING_BACKBEAT_URL: "http://backbeat-pulse:8080"
      WHOOSH_SCALING_CHORUS_URL: "http://chorus:9000"

      # BACKBEAT integration configuration (temporarily disabled)
      WHOOSH_BACKBEAT_ENABLED: "false"
      WHOOSH_BACKBEAT_CLUSTER_ID: "chorus-production"
      WHOOSH_BACKBEAT_AGENT_ID: "whoosh"
      WHOOSH_BACKBEAT_NATS_URL: "nats://backbeat-nats:4222"

    secrets:
      - whoosh_db_password
      - gitea_token
@@ -170,6 +207,8 @@ services:
      - jwt_secret
      - service_tokens
      - redis_password
    volumes:
      - /var/run/docker.sock:/var/run/docker.sock
    deploy:
      replicas: 2
      restart_policy:
@@ -190,6 +229,8 @@ services:
      # monitor: 60s
      # order: stop-first
      placement:
        constraints:
          - node.hostname != acacia
        preferences:
          - spread: node.hostname
      resources:
@@ -201,14 +242,16 @@ services:
          cpus: '0.25'
      labels:
        - traefik.enable=true
        - traefik.docker.network=tengig
        - traefik.http.routers.whoosh.rule=Host(`whoosh.chorus.services`)
        - traefik.http.routers.whoosh.tls=true
        - traefik.http.routers.whoosh.tls.certresolver=letsencrypt
        - traefik.http.routers.whoosh.tls.certresolver=letsencryptresolver
        - traefik.http.routers.photoprism.entrypoints=web,web-secured
        - traefik.http.services.whoosh.loadbalancer.server.port=8080
        - traefik.http.middlewares.whoosh-auth.basicauth.users=admin:$$2y$$10$$example_hash
        - traefik.http.services.photoprism.loadbalancer.passhostheader=true
        - traefik.http.middlewares.whoosh-auth.basicauth.users=admin:$2y$10$example_hash
    networks:
      - tengig
      - whoosh-backend
      - chorus_net
    healthcheck:
      test: ["CMD", "/app/whoosh", "--health-check"]
@@ -246,14 +289,13 @@ services:
          memory: 256M
          cpus: '0.5'
    networks:
      - whoosh-backend
      - chorus_net
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U whoosh"]
      test: ["CMD-SHELL", "pg_isready -h localhost -p 5432 -U whoosh -d whoosh"]
      interval: 30s
      timeout: 10s
      retries: 5
      start_period: 30s
      start_period: 40s


  redis:
@@ -281,7 +323,6 @@ services:
          memory: 64M
          cpus: '0.1'
    networks:
      - whoosh-backend
      - chorus_net
    healthcheck:
      test: ["CMD", "sh", "-c", "redis-cli --no-auth-warning -a $$(cat /run/secrets/redis_password) ping"]
@@ -299,6 +340,66 @@ services:


  prometheus:
    image: prom/prometheus:latest
    command:
      - '--config.file=/etc/prometheus/prometheus.yml'
      - '--storage.tsdb.path=/prometheus'
      - '--web.console.libraries=/usr/share/prometheus/console_libraries'
      - '--web.console.templates=/usr/share/prometheus/consoles'
    volumes:
      - /rust/containers/CHORUS/monitoring/prometheus/prometheus.yml:/etc/prometheus/prometheus.yml:ro
      - /rust/containers/CHORUS/monitoring/prometheus:/prometheus
    ports:
      - "9099:9090" # Expose Prometheus UI
    deploy:
      replicas: 1
      labels:
        - traefik.enable=true
        - traefik.http.routers.prometheus.rule=Host(`prometheus.chorus.services`)
        - traefik.http.routers.prometheus.entrypoints=web,web-secured
        - traefik.http.routers.prometheus.tls=true
        - traefik.http.routers.prometheus.tls.certresolver=letsencryptresolver
        - traefik.http.services.prometheus.loadbalancer.server.port=9090
    networks:
      - chorus_net
      - tengig
    healthcheck:
      test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://localhost:9090/-/ready"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 10s

  grafana:
    image: grafana/grafana:latest
    user: "1000:1000"
    environment:
      - GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_ADMIN_PASSWORD:-admin} # Use a strong password in production
      - GF_SERVER_ROOT_URL=https://grafana.chorus.services
    volumes:
      - /rust/containers/CHORUS/monitoring/grafana:/var/lib/grafana
    ports:
      - "3300:3000" # Expose Grafana UI
    deploy:
      replicas: 1
      labels:
        - traefik.enable=true
        - traefik.http.routers.grafana.rule=Host(`grafana.chorus.services`)
        - traefik.http.routers.grafana.entrypoints=web,web-secured
        - traefik.http.routers.grafana.tls=true
        - traefik.http.routers.grafana.tls.certresolver=letsencryptresolver
        - traefik.http.services.grafana.loadbalancer.server.port=3000
    networks:
      - chorus_net
      - tengig
    healthcheck:
      test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://localhost:3000/api/health"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 10s

  # BACKBEAT Pulse Service - Leader-elected tempo broadcaster
  # REQ: BACKBEAT-REQ-001 - Single BeatFrame publisher per cluster
  # REQ: BACKBEAT-OPS-001 - One replica prefers leadership
@@ -344,8 +445,6 @@ services:
      placement:
        preferences:
          - spread: node.hostname
        constraints:
          - node.hostname != rosewood # Avoid intermittent gaming PC
      resources:
        limits:
          memory: 256M
@@ -413,8 +512,6 @@ services:
      placement:
        preferences:
          - spread: node.hostname
        constraints:
          - node.hostname != rosewood
      resources:
        limits:
          memory: 512M # Larger for window aggregation
@@ -447,7 +544,6 @@ services:
  backbeat-nats:
    image: nats:2.9-alpine
    command: ["--jetstream"]

    deploy:
      replicas: 1
      restart_policy:
@@ -458,8 +554,6 @@ services:
      placement:
        preferences:
          - spread: node.hostname
        constraints:
          - node.hostname != rosewood
      resources:
        limits:
          memory: 256M
@@ -467,10 +561,8 @@ services:
        reservations:
          memory: 128M
          cpus: '0.25'

    networks:
      - chorus_net

    # Container logging
    logging:
      driver: "json-file"
@@ -484,6 +576,24 @@ services:

# Persistent volumes
volumes:
  prometheus_data:
    driver: local
    driver_opts:
      type: none
      o: bind
      device: /rust/containers/CHORUS/monitoring/prometheus
  prometheus_config:
    driver: local
    driver_opts:
      type: none
      o: bind
      device: /rust/containers/CHORUS/monitoring/prometheus
  grafana_data:
    driver: local
    driver_opts:
      type: none
      o: bind
      device: /rust/containers/CHORUS/monitoring/grafana
  chorus_data:
    driver: local
  whoosh_postgres_data:
@@ -505,23 +615,22 @@ networks:
  tengig:
    external: true

  whoosh-backend:
    driver: overlay
    attachable: false

  chorus_net:
    driver: overlay
    attachable: true
    ipam:
      config:
        - subnet: 10.201.0.0/24


configs:
  chorus_bootstrap:
    file: ./bootstrap.json

secrets:
  chorus_license_id:
    external: true
    name: chorus_license_id
  resetdata_api_key:
    external: true
    name: resetdata_api_key
  whoosh_db_password:
    external: true
    name: whoosh_db_password
docs/decisions/2025-02-16-shhh-sentinel-foundation.md (new file, 30 lines)
@@ -0,0 +1,30 @@
# Decision Record: Establish SHHH Sentinel Foundations

- **Date:** 2025-02-16
- **Status:** Accepted
- **Context:** CHORUS roadmap Phase 1 requires a secrets sentinel (`pkg/shhh`) before we wire COOEE/WHOOSH telemetry and audit plumbing. The runtime previously emitted placeholder TODOs and logged sensitive payloads without guard rails.

## Problem
- We lacked a reusable component to detect and redact secrets prior to log/telemetry fan-out.
- Without a dedicated sentinel we could not attach audit sinks or surface metrics for redaction events, blocking roadmap item `SEC-SHHH`.

## Decision
- Introduce `pkg/shhh` as the SHHH sentinel with:
  - Curated default rules (API keys, bearer/OAuth tokens, private key PEM blocks, OpenAI secrets).
  - Extensible configuration for custom regex rules and per-rule severity/tags.
  - Optional audit sink and statistics collection for integration with COOEE/WHOOSH pipelines.
  - Helpers to redact free-form text and `map[string]any` payloads used by our logging pipeline.
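The redaction behaviour described above can be sketched in Go. The rule set, function names, and signatures here are illustrative assumptions, not the actual `pkg/shhh` API:

```go
package main

import (
	"fmt"
	"regexp"
)

// Rule pairs a name with a detection pattern. Hypothetical shape,
// mirroring the curated-default-rules idea in the decision above.
type Rule struct {
	Name    string
	Pattern *regexp.Regexp
}

// defaultRules covers two of the high-signal leak classes mentioned
// above (bearer tokens, OpenAI-style keys); the real set is larger.
func defaultRules() []Rule {
	return []Rule{
		{"bearer-token", regexp.MustCompile(`(?i)bearer\s+[A-Za-z0-9._\-]+`)},
		{"openai-key", regexp.MustCompile(`sk-[A-Za-z0-9]{20,}`)},
	}
}

// Redact replaces every match in free-form text with [REDACTED].
func Redact(text string, rules []Rule) string {
	for _, r := range rules {
		text = r.Pattern.ReplaceAllString(text, "[REDACTED]")
	}
	return text
}

// RedactMap applies Redact to every string value of a payload,
// matching the map[string]any helper described above.
func RedactMap(payload map[string]any, rules []Rule) {
	for k, v := range payload {
		if s, ok := v.(string); ok {
			payload[k] = Redact(s, rules)
		}
	}
}

func main() {
	fmt.Println(Redact("Authorization: Bearer abc123.def", defaultRules()))
}
```

A real sentinel would additionally emit findings (rule name, location, hash) to the audit sink instead of only rewriting the text.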

## Rationale
- Starting with a focused set of high-signal rules gives immediate coverage for the most damaging leak classes without delaying the larger SLURP/SHHH workstreams.
- The API mirrors other CHORUS subsystems (options, config structs, stats snapshots), so existing operators can plug in metrics/audits without bespoke glue.
- Providing deterministic findings/locations simplifies future enforcement (e.g., WHOOSH UI badges, COOEE replay) while keeping the implementation lean.

## Impact
- Runtime components can now instantiate SHHH and guarantee `[REDACTED]` placeholders for sensitive fields.
- Audit/event plumbing can be wired incrementally; hashes are emitted for replay without storing raw secrets.
- Future roadmap tasks (policy-driven rules, replay, UCXL evidence) can extend `pkg/shhh` rather than implementing ad-hoc redaction in each subsystem.

## Related Work
- Roadmap: `docs/progress/CHORUS-WHOOSH-roadmap.md` (Phase 1.2 `SEC-SHHH`).
- README coverage gap noted in the `README.md` table (SHHH not implemented).
docs/development/task-execution-engine-plan.md (new file, 435 lines)
@@ -0,0 +1,435 @@
# CHORUS Task Execution Engine Development Plan

## Overview
This plan outlines the development of a comprehensive task execution engine for CHORUS agents, replacing the current mock implementation with a fully functional system that can execute real work according to agent roles and specializations.

## Current State Analysis

### What's Implemented ✅
- **Task Coordinator Framework** (`coordinator/task_coordinator.go`): Full task management lifecycle with role-based assignment, collaboration requests, and HMMM integration
- **Agent Role System**: Role announcements, capability broadcasting, and expertise matching
- **P2P Infrastructure**: Nodes can discover each other and communicate via pubsub
- **Health Monitoring**: Comprehensive health checks and graceful shutdown

### Critical Gaps Identified ❌
- **Task Execution Engine**: `executeTask()` only has a 10-second sleep simulation - no actual work performed
- **Repository Integration**: Mock providers only - no real GitHub/GitLab task pulling
- **Agent-to-Task Binding**: Task discovery relies on WHOOSH, but agents don't connect to real work
- **Role-Based Execution**: Agents announce roles but don't execute tasks according to their specialization
- **AI Integration**: No LLM/reasoning integration for task completion

## Architecture Requirements

### Model and Provider Abstraction
The execution engine must support multiple AI model providers and execution environments:

**Model Provider Types:**
- **Local Ollama**: Default for most roles (llama3.1:8b, codellama, etc.)
- **OpenAI API**: For specialized models (chatgpt-5, gpt-4o, etc.)
- **ResetData API**: For testing and fallback (llama3.1:8b via LaaS)
- **Custom Endpoints**: Support for other provider APIs

**Role-Model Mapping:**
- Each role has a default model configuration
- Specialized roles may require specific models/providers
- Model selection is transparent to execution logic
- Support for MCP calls and tool usage regardless of provider

### Execution Environment Abstraction
Tasks must execute in secure, isolated environments while maintaining transparency:

**Sandbox Types:**
- **Docker Containers**: Isolated execution environment per task
- **Specialized VMs**: For tasks requiring full OS isolation
- **Process Sandboxing**: Lightweight isolation for simple tasks

**Transparency Requirements:**
- The model perceives it is working on a local repository
- Development tools are available within the sandbox
- File system operations work normally from the model's perspective
- Network access is controlled but transparent
- Resource limits are enforced but invisible

## Development Plan

### Phase 1: Model Provider Abstraction Layer

#### 1.1 Create Provider Interface
```go
// pkg/ai/provider.go
type ModelProvider interface {
	ExecuteTask(ctx context.Context, request *TaskRequest) (*TaskResponse, error)
	SupportsMCP() bool
	SupportsTools() bool
	GetCapabilities() []string
}
```
#### 1.2 Implement Provider Types
- **OllamaProvider**: Local model execution
- **OpenAIProvider**: OpenAI API integration
- **ResetDataProvider**: ResetData LaaS integration
- **ProviderFactory**: Creates appropriate provider based on model config
#### 1.3 Role-Model Configuration
```yaml
# Config structure for role-model mapping
roles:
  developer:
    default_model: "codellama:13b"
    provider: "ollama"
    fallback_model: "llama3.1:8b"
    fallback_provider: "resetdata"

  architect:
    default_model: "gpt-4o"
    provider: "openai"
    fallback_model: "llama3.1:8b"
    fallback_provider: "ollama"
```
### Phase 2: Execution Environment Abstraction

#### 2.1 Create Sandbox Interface
```go
// pkg/execution/sandbox.go
type ExecutionSandbox interface {
	Initialize(ctx context.Context, config *SandboxConfig) error
	ExecuteCommand(ctx context.Context, cmd *Command) (*CommandResult, error)
	CopyFiles(ctx context.Context, source, dest string) error
	Cleanup() error
}
```
#### 2.2 Implement Sandbox Types
- **DockerSandbox**: Container-based isolation
- **VMSandbox**: Full VM isolation for sensitive tasks
- **ProcessSandbox**: Lightweight process-based isolation

#### 2.3 Repository Mounting
- Clone the repository into the sandbox environment
- Mount it as a local filesystem from the model's perspective
- Implement secure file I/O operations
- Handle git operations within the sandbox

### Phase 3: Core Task Execution Engine

#### 3.1 Replace Mock Implementation
Replace the current simulation in `coordinator/task_coordinator.go:314`:

```go
// Current mock implementation
time.Sleep(10 * time.Second) // Simulate work

// New implementation
result, err := tc.executionEngine.ExecuteTask(ctx, &TaskExecutionRequest{
	Task:          activeTask.Task,
	Agent:         tc.agentInfo,
	Sandbox:       sandboxConfig,
	ModelProvider: providerConfig,
})
```

#### 3.2 Task Execution Strategies
Create role-specific execution patterns:

- **DeveloperStrategy**: Code implementation, bug fixes, feature development
- **ReviewerStrategy**: Code review, quality analysis, test coverage assessment
- **ArchitectStrategy**: System design, technical decision making
- **TesterStrategy**: Test creation, validation, quality assurance

#### 3.3 Execution Workflow
1. **Task Analysis**: Parse task requirements and complexity
2. **Environment Setup**: Initialize the appropriate sandbox
3. **Repository Preparation**: Clone and mount the repository
4. **Model Selection**: Choose the appropriate model/provider
5. **Task Execution**: Run the role-specific execution strategy
6. **Result Validation**: Verify output quality and completeness
7. **Cleanup**: Tear down the sandbox and collect artifacts
### Phase 4: Repository Provider Implementation

#### 4.1 Real Repository Integration
Replace `MockTaskProvider` with actual implementations:
- **GiteaProvider**: Integration with the Gitea API
- **GitHubProvider**: GitHub API integration
- **GitLabProvider**: GitLab API integration

#### 4.2 Task Lifecycle Management
- Task claiming and status updates
- Progress reporting back to repositories
- Artifact attachment (patches, documentation, etc.)
- Automated PR/MR creation for completed tasks
### Phase 5: AI Integration and Tool Support

#### 5.1 LLM Integration
- Context-aware task analysis based on repository content
- Code generation and problem-solving capabilities
- Natural language processing for task descriptions
- Multi-step reasoning for complex tasks

#### 5.2 Tool Integration
- MCP server connectivity within the sandbox
- Development tool access (compilers, linters, formatters)
- Testing framework integration
- Documentation generation tools

#### 5.3 Quality Assurance
- Automated testing of generated code
- Code quality metrics and analysis
- Security vulnerability scanning
- Performance impact assessment

### Phase 6: Testing and Validation

#### 6.1 Unit Testing
- Provider abstraction layer testing
- Sandbox isolation verification
- Task execution strategy validation
- Error handling and recovery testing

#### 6.2 Integration Testing
- End-to-end task execution workflows
- Agent-to-WHOOSH communication testing
- Multi-provider failover scenarios
- Concurrent task execution testing

#### 6.3 Security Testing
- Sandbox escape prevention
- Resource limit enforcement
- Network isolation validation
- Secrets and credential protection

### Phase 7: Production Deployment

#### 7.1 Configuration Management
- Environment-specific model configurations
- Sandbox resource limit definitions
- Provider API key management
- Monitoring and logging setup

#### 7.2 Monitoring and Observability
- Task execution metrics and dashboards
- Performance monitoring and alerting
- Resource utilization tracking
- Error rate and success metrics

## Implementation Priorities

### Critical Path (Weeks 1-2)
1. Model Provider Abstraction Layer
2. Basic Docker Sandbox Implementation
3. Replace Mock Task Execution
4. Role-Based Execution Strategies

### High Priority (Weeks 3-4)
5. Real Repository Provider Implementation
6. AI Integration with Ollama/OpenAI
7. MCP Tool Integration
8. Basic Testing Framework

### Medium Priority (Weeks 5-6)
9. Advanced Sandbox Types (VM, Process)
10. Quality Assurance Pipeline
11. Comprehensive Testing Suite
12. Performance Optimization

### Future Enhancements
- Multi-language model support
- Advanced reasoning capabilities
- Distributed task execution
- Machine learning model fine-tuning

## Success Metrics

- **Task Completion Rate**: >90% of assigned tasks successfully completed
- **Code Quality**: Generated code passes all existing tests and linting
- **Security**: Zero sandbox escapes or security violations
- **Performance**: Task execution time within acceptable bounds
- **Reliability**: <5% execution failure rate due to engine issues

## Risk Mitigation

### Security Risks
- Sandbox escape → multiple isolation layers, security audits
- Credential exposure → secure credential management, rotation
- Resource exhaustion → resource limits, monitoring, auto-scaling

### Technical Risks
- Model provider outages → multi-provider failover, local fallbacks
- Execution failures → robust error handling, retry mechanisms
- Performance bottlenecks → profiling, optimization, horizontal scaling

### Integration Risks
- WHOOSH compatibility → extensive integration testing, versioning
- Repository provider changes → provider abstraction, API versioning
- Model compatibility → provider abstraction, capability detection

This plan addresses the core limitation (CHORUS agents currently lack real task execution capabilities) while building a robust, secure, and scalable execution engine suitable for production deployment.

## Implementation Roadmap

### Development Standards & Workflow

**Semantic Versioning Strategy:**
- **Patch (0.N.X)**: Bug fixes, small improvements, documentation updates
- **Minor (0.N.0)**: New features, phase completions, non-breaking changes
- **Major (N.0.0)**: Breaking changes, major architectural shifts

**Git Workflow:**
1. **Branch Creation**: `git checkout -b feature/phase-N-description`
2. **Development**: Implement with frequent commits using the conventional commit format
3. **Testing**: Run the full test suite with `make test` before the PR
4. **Code Review**: Create a PR with a detailed description and test results
5. **Integration**: Squash-merge to main after approval
6. **Release**: Tag with `git tag v0.N.0` and update the Makefile version

**Quality Gates:**
Each phase must meet these criteria before merge:
- ✅ Unit tests with >80% coverage
- ✅ Integration tests for external dependencies
- ✅ Security review for new attack surfaces
- ✅ Performance benchmarks within acceptable bounds
- ✅ Documentation updates (code comments + README)
- ✅ Backward compatibility verification

### Phase-by-Phase Implementation

#### Phase 1: Model Provider Abstraction (v0.2.0)
**Branch:** `feature/phase-1-model-providers`
**Duration:** 3-5 days
**Deliverables:**
```
pkg/ai/
├── provider.go       # Core provider interface & request/response types
├── ollama.go         # Local Ollama model integration
├── openai.go         # OpenAI API client wrapper
├── resetdata.go      # ResetData LaaS integration
├── factory.go        # Provider factory with auto-selection
└── provider_test.go  # Comprehensive provider tests

configs/
└── models.yaml       # Role-model mapping configuration
```

**Key Features:**
- Abstract AI providers behind a unified interface
- Support multiple providers with automatic failover
- Configuration-driven model selection per agent role
- Proper error handling and retry logic

#### Phase 2: Execution Environment Abstraction (v0.3.0)
**Branch:** `feature/phase-2-execution-sandbox`
**Duration:** 5-7 days
**Deliverables:**
```
pkg/execution/
├── sandbox.go       # Core sandbox interface & types
├── docker.go        # Docker container implementation
├── security.go      # Security policies & enforcement
├── resources.go     # Resource monitoring & limits
└── sandbox_test.go  # Sandbox security & isolation tests
```

**Key Features:**
- Docker-based task isolation with transparent repository access
- Resource limits (CPU, memory, network, disk) with monitoring
- Security boundary enforcement and escape prevention
- Clean teardown and artifact collection
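One way the Docker isolation and resource limits above could fit together is by building the `docker run` argument list per task. The flags shown (`--memory`, `--cpus`, `--network`, `--read-only`, `-v`, `-w`) are real docker CLI flags; the helper shape and default values are illustrative assumptions:

```go
package main

import "fmt"

// Limits holds per-task resource caps; values here are examples only.
type Limits struct {
	Memory  string // e.g. "512m"
	CPUs    string // e.g. "1.0"
	Network string // e.g. "none" for full egress isolation
}

// DockerRunArgs builds arguments for a sandboxed task container.
func DockerRunArgs(image, repoHostPath string, l Limits) []string {
	return []string{
		"run", "--rm",
		"--memory", l.Memory,
		"--cpus", l.CPUs,
		"--network", l.Network,
		"--read-only",                         // immutable root filesystem
		"-v", repoHostPath + ":/workspace:rw", // repo appears as a local checkout
		"-w", "/workspace",
		image,
	}
}

func main() {
	args := DockerRunArgs("chorus-task:latest", "/tmp/repo",
		Limits{Memory: "512m", CPUs: "1.0", Network: "none"})
	fmt.Println(args)
}
```

The writable `/workspace` bind mount against a read-only root is what makes the repository look like a normal local checkout from inside the sandbox.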
#### Phase 3: Core Task Execution Engine (v0.4.0)
**Branch:** `feature/phase-3-task-execution`
**Duration:** 7-10 days
**Modified Files:**
- `coordinator/task_coordinator.go:314` - Replace mock with real execution
- `pkg/repository/types.go` - Extend interfaces for execution context

**New Files:**
```
pkg/strategies/
├── developer.go  # Code implementation & bug fixes
├── reviewer.go   # Code review & quality analysis
├── architect.go  # System design & tech decisions
└── tester.go     # Test creation & validation

pkg/engine/
├── executor.go    # Main execution orchestrator
├── workflow.go    # 7-step execution workflow
└── validation.go  # Result quality verification
```

**Key Features:**
- Real task execution replacing the 10-second sleep simulation
- Role-specific execution strategies with appropriate tooling
- Integration between AI providers, sandboxes, and the task lifecycle
- Comprehensive result validation and quality metrics

#### Phase 4: Repository Provider Implementation (v0.5.0)
**Branch:** `feature/phase-4-real-providers`
**Duration:** 10-14 days
**Deliverables:**
```
pkg/providers/
├── gitea.go          # Gitea API integration (primary)
├── github.go         # GitHub API integration
├── gitlab.go         # GitLab API integration
└── provider_test.go  # API integration tests
```

**Key Features:**
- Replace MockTaskProvider with production implementations
- Task claiming, status updates, and progress reporting via APIs
- Automated PR/MR creation with proper branch management
- Repository-specific configuration and credential management

### Testing Strategy

**Unit Testing:**
- Each provider/sandbox implementation has a dedicated test suite
- Mock external dependencies (APIs, Docker, etc.) for isolated testing
- Property-based testing for core interfaces
- Error-condition and edge-case coverage

**Integration Testing:**
- End-to-end task execution workflows
- Multi-provider failover scenarios
- Agent-to-WHOOSH communication validation
- Concurrent task execution under load

**Security Testing:**
- Sandbox escape prevention validation
- Resource exhaustion protection
- Network isolation verification
- Secrets and credential protection audits

### Deployment & Monitoring

**Configuration Management:**
- Environment-specific model configurations
- Sandbox resource limits per environment
- Provider API credentials via secure secret management
- Feature flags for gradual rollout

**Observability:**
- Task execution metrics (completion rate, duration, success/failure)
- Resource utilization tracking (CPU, memory, network per task)
- Error rate monitoring with alerting thresholds
- Performance dashboards for capacity planning

### Risk Mitigation

**Technical Risks:**
- **Provider Outages**: Multi-provider failover with health checks
- **Resource Exhaustion**: Strict limits with monitoring and auto-scaling
- **Execution Failures**: Retry mechanisms with exponential backoff
**Security Risks:**
- **Sandbox Escapes**: Multiple isolation layers and regular security audits
- **Credential Exposure**: Secure rotation and least-privilege access
- **Data Exfiltration**: Network isolation and egress monitoring

**Integration Risks:**
- **API Changes**: Provider abstraction with versioning support
- **Performance Degradation**: Comprehensive benchmarking at each phase
- **Compatibility Issues**: Extensive integration testing with existing systems
docs/progress/CHORUS-WHOOSH-roadmap.md (new file, 70 lines)
@@ -0,0 +1,70 @@
# CHORUS / WHOOSH Roadmap

_Last updated: 2025-02-15_

This roadmap translates the development plan into phased milestones with suggested sequencing and exit criteria. Durations are approximate and assume parallel work streams where practical.

## Phase 0 – Kick-off & Scoping (Week 0)
- Confirm owners and staffing for SLURP, SHHH, COOEE, WHOOSH, UCXL, and KACHING work streams.
- Finalize engineering briefs for each deliverable; align with plan in `CHORUS-WHOOSH-development-plan.md`.
- Stand up tracking board (Kanban/Sprint) with milestone tags introduced below.

**Exit Criteria**
- Owners assigned and briefs approved.
- Roadmap milestones added to tracking tooling.

## Phase 1 – Security Substrate Foundations (Weeks 1–4)
- **1.1 SLURP Core (Weeks 1–3)**
  - Implement storage/resolver/temporal components and leader integration (ticket group `SEC-SLURP`).
  - Ship integration tests covering admin-only operations and failover.
- **1.2 SHHH Sentinel (Weeks 2–4)**
  - Build `pkg/shhh`, integrate with COOEE/WHOOSH logging, add audit metrics (`SEC-SHHH`).
- **1.3 COOEE Mesh Monitoring (Weeks 3–4)**
  - Validate enrolment payloads, instrument mesh health, document ops runbook (`SEC-COOEE`).

**Exit Criteria**
- SLURP passes integration suite with real context resolution.
- SHHH redaction events visible in metrics/logs; regression tests in place.
- COOEE dashboards/reporting operational; runbook published.

## Phase 2 – WHOOSH Data Path & Telemetry (Weeks 4–8)
- **2.1 Persistence & API Hardening (Weeks 4–6)**
  - Replace mock handlers with Postgres-backed endpoints (`WHOOSH-API`).
- **2.2 Analysis Ingestion (Weeks 5–7)**
  - Pipeline real Gitea/n8n analysis into composer/monitor (`WHOOSH-ANALYSIS`).
- **2.3 Deployment Telemetry (Weeks 6–8)**
  - Persist deployment results, emit telemetry, surface status in UI (`WHOOSH-OBS`).
- **2.4 Composer Enhancements (Weeks 7–8)**
  - Add LLM skill analysis with fallback heuristics; evaluation harness (`WHOOSH-COMP`).

**Exit Criteria**
- WHOOSH API/UI reflects live database state.
- Analysis-derived data present in team formation/deployment flows.
- Telemetry events available for KACHING integration.

## Phase 3 – Cross-Cutting Governance & Tooling (Weeks 8–12)
- **3.1 UCXL Spec & Validator (Weeks 8–10)**
  - Publish Spec 1.0, ship validator CLI with CI coverage (`UCXL-SPEC`).
- **3.2 KACHING Telemetry (Weeks 9–11)**
  - Instrument CHORUS runtime & WHOOSH orchestrator, deploy ingestion/aggregation jobs (`KACHING-TELEM`).
- **3.3 Governance Tooling (Weeks 10–12)**
  - Deliver DR templates, signed assertions workflow, scope-aware RUSTLE views (`GOV-TOOLS`).

**Exit Criteria**
- UCXL validator integrated into CI for CHORUS/WHOOSH/RUSTLE.
- KACHING receives events and triggers quota/budget alerts.
- Governance docs/tooling published; RUSTLE displays redacted context correctly.

## Phase 4 – Stabilization & Launch Readiness (Weeks 12–14)
- Regression testing across CHORUS/WHOOSH/UCXL/KACHING.
- Security & compliance review for SHHH and telemetry pipelines.
- Rollout plan: staged deployment, rollback procedures, support playbooks.

**Exit Criteria**
- All milestone tickets closed with QA sign-off.
- Production readiness review approved; launch window scheduled.

## Tracking & Reporting
- Weekly status sync covering milestone burndown, risks, and cross-team blockers.
- Metrics dashboard to include: SLURP leader uptime, SHHH redaction counts, COOEE peer health, WHOOSH deployment success rate, UCXL validation pass rate, KACHING alert volume.
- Maintain Decision Records for key architecture/security choices at relevant UCXL addresses.
35
go.mod
@@ -1,6 +1,6 @@
module chorus

go 1.23
go 1.23.0

toolchain go1.24.5

@@ -8,6 +8,9 @@ require (
	filippo.io/age v1.2.1
	github.com/blevesearch/bleve/v2 v2.5.3
	github.com/chorus-services/backbeat v0.0.0-00010101000000-000000000000
	github.com/docker/docker v28.4.0+incompatible
	github.com/docker/go-connections v0.6.0
	github.com/docker/go-units v0.5.0
	github.com/go-redis/redis/v8 v8.11.5
	github.com/google/uuid v1.6.0
	github.com/gorilla/mux v1.8.1
@@ -21,12 +24,15 @@ require (
	github.com/prometheus/client_golang v1.19.1
	github.com/robfig/cron/v3 v3.0.1
	github.com/sashabaranov/go-openai v1.41.1
	github.com/stretchr/testify v1.10.0
	github.com/sony/gobreaker v0.5.0
	github.com/stretchr/testify v1.11.1
	github.com/syndtr/goleveldb v1.0.0
	golang.org/x/crypto v0.24.0
	gopkg.in/yaml.v3 v3.0.1
)

require (
	github.com/Microsoft/go-winio v0.6.2 // indirect
	github.com/RoaringBitmap/roaring/v2 v2.4.5 // indirect
	github.com/benbjohnson/clock v1.3.5 // indirect
	github.com/beorn7/perks v1.0.1 // indirect
@@ -50,16 +56,19 @@ require (
	github.com/blevesearch/zapx/v16 v16.2.4 // indirect
	github.com/cespare/xxhash/v2 v2.2.0 // indirect
	github.com/containerd/cgroups v1.1.0 // indirect
	github.com/containerd/errdefs v1.0.0 // indirect
	github.com/containerd/errdefs/pkg v0.3.0 // indirect
	github.com/coreos/go-systemd/v22 v22.5.0 // indirect
	github.com/davecgh/go-spew v1.1.1 // indirect
	github.com/davidlazar/go-crypto v0.0.0-20200604182044-b73af7476f6c // indirect
	github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
	github.com/docker/go-units v0.5.0 // indirect
	github.com/distribution/reference v0.6.0 // indirect
	github.com/elastic/gosigar v0.14.2 // indirect
	github.com/felixge/httpsnoop v1.0.4 // indirect
	github.com/flynn/noise v1.0.0 // indirect
	github.com/francoispqt/gojay v1.2.13 // indirect
	github.com/go-logr/logr v1.2.4 // indirect
	github.com/go-logr/logr v1.4.3 // indirect
	github.com/go-logr/stdr v1.2.2 // indirect
	github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
	github.com/godbus/dbus/v5 v5.1.0 // indirect
@@ -104,6 +113,7 @@ require (
	github.com/mikioh/tcpinfo v0.0.0-20190314235526-30a79bb1804b // indirect
	github.com/mikioh/tcpopt v0.0.0-20190314235656-172688c1accc // indirect
	github.com/minio/sha256-simd v1.0.1 // indirect
	github.com/moby/docker-image-spec v1.3.1 // indirect
	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
	github.com/modern-go/reflect2 v1.0.2 // indirect
	github.com/mr-tron/base58 v1.2.0 // indirect
@@ -120,6 +130,8 @@ require (
	github.com/nats-io/nkeys v0.4.7 // indirect
	github.com/nats-io/nuid v1.0.1 // indirect
	github.com/onsi/ginkgo/v2 v2.13.0 // indirect
	github.com/opencontainers/go-digest v1.0.0 // indirect
	github.com/opencontainers/image-spec v1.1.1 // indirect
	github.com/opencontainers/runtime-spec v1.1.0 // indirect
	github.com/opentracing/opentracing-go v1.2.0 // indirect
	github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 // indirect
@@ -138,9 +150,11 @@ require (
	github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1 // indirect
	go.etcd.io/bbolt v1.4.0 // indirect
	go.opencensus.io v0.24.0 // indirect
	go.opentelemetry.io/otel v1.16.0 // indirect
	go.opentelemetry.io/otel/metric v1.16.0 // indirect
	go.opentelemetry.io/otel/trace v1.16.0 // indirect
	go.opentelemetry.io/auto/sdk v1.1.0 // indirect
	go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect
	go.opentelemetry.io/otel v1.38.0 // indirect
	go.opentelemetry.io/otel/metric v1.38.0 // indirect
	go.opentelemetry.io/otel/trace v1.38.0 // indirect
	go.uber.org/dig v1.17.1 // indirect
	go.uber.org/fx v1.20.1 // indirect
	go.uber.org/mock v0.3.0 // indirect
@@ -150,13 +164,12 @@ require (
	golang.org/x/mod v0.18.0 // indirect
	golang.org/x/net v0.26.0 // indirect
	golang.org/x/sync v0.10.0 // indirect
	golang.org/x/sys v0.29.0 // indirect
	golang.org/x/sys v0.35.0 // indirect
	golang.org/x/text v0.16.0 // indirect
	golang.org/x/tools v0.22.0 // indirect
	gonum.org/v1/gonum v0.13.0 // indirect
	google.golang.org/protobuf v1.33.0 // indirect
	gopkg.in/yaml.v3 v3.0.1 // indirect
	google.golang.org/protobuf v1.34.2 // indirect
	lukechampine.com/blake3 v1.2.1 // indirect
)

replace github.com/chorus-services/backbeat => /home/tony/chorus/project-queues/active/BACKBEAT/backbeat/prototype
replace github.com/chorus-services/backbeat => ../BACKBEAT/backbeat/prototype
40
go.sum
@@ -12,6 +12,8 @@ filippo.io/age v1.2.1 h1:X0TZjehAZylOIj4DubWYU1vWQxv9bJpo+Uu2/LGhi1o=
filippo.io/age v1.2.1/go.mod h1:JL9ew2lTN+Pyft4RiNGguFfOpewKwSHm5ayKD/A4004=
git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/RoaringBitmap/roaring/v2 v2.4.5 h1:uGrrMreGjvAtTBobc0g5IrW1D5ldxDQYe2JW2gggRdg=
github.com/RoaringBitmap/roaring/v2 v2.4.5/go.mod h1:FiJcsfkGje/nZBZgCu0ZxCPOKD/hVXDS2dXi7/eUFE0=
github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c=
@@ -72,6 +74,10 @@ github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGX
github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327/go.mod h1:ZJeTFisyysqgcCdecO57Dj79RfL0LNeGiFUqLYQRYLE=
github.com/containerd/cgroups v1.1.0 h1:v8rEWFl6EoqHB+swVNjVoCJE8o3jX7e8nqBGPLaDFBM=
github.com/containerd/cgroups v1.1.0/go.mod h1:6ppBcbh/NOOUU+dMKrykgaBnK9lCIBxHqJDGwsa1mIw=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/coreos/go-systemd/v22 v22.1.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk=
github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
@@ -89,6 +95,12 @@ github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etly
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v28.4.0+incompatible h1:KVC7bz5zJY/4AZe/78BIvCnPsLaC9T/zh72xnlrTTOk=
github.com/docker/docker v28.4.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
@@ -100,6 +112,8 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc=
github.com/flynn/noise v1.0.0 h1:DlTHqmzmvcEiKj+4RYo/imoswx/4r6iBlCMfVtrMXpQ=
github.com/flynn/noise v1.0.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
@@ -116,6 +130,8 @@ github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
@@ -307,6 +323,8 @@ github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1/go.mod h1:pD8Rv
github.com/minio/sha256-simd v0.1.1-0.20190913151208-6de447530771/go.mod h1:B5e1o+1/KgNmWrSQK08Y6Z1Vb5pwIktudl0J58iy0KM=
github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM=
github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -361,6 +379,10 @@ github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xl
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/onsi/gomega v1.27.10 h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI=
github.com/onsi/gomega v1.27.10/go.mod h1:RsS8tutOdbdgzbPtzzATp12yT7kM5I5aElG3evPbQ0M=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
github.com/opencontainers/runtime-spec v1.1.0 h1:HHUyrt9mwHUjtasSbXSMvs4cyFxh+Bll4AjJ9odEGpg=
github.com/opencontainers/runtime-spec v1.1.0/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
@@ -437,6 +459,8 @@ github.com/smartystreets/assertions v1.2.0 h1:42S6lae5dvLc7BrLu/0ugRtcFVjoJNMC/N
github.com/smartystreets/assertions v1.2.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo=
github.com/smartystreets/goconvey v1.7.2 h1:9RBaZCeXEQ3UselpuwUQHltGVXvdwm6cv1hgR6gDIPg=
github.com/smartystreets/goconvey v1.7.2/go.mod h1:Vw0tHAZW6lzCRk3xgdin6fKYcG+G3Pg9vgXWeJpQFMM=
github.com/sony/gobreaker v0.5.0 h1:dRCvqm0P490vZPmy7ppEk2qCnCieBooFJ+YoXGYB+yg=
github.com/sony/gobreaker v0.5.0/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY=
github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE=
github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA=
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
@@ -454,6 +478,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE=
github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ=
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
@@ -473,12 +499,22 @@ go.etcd.io/bbolt v1.4.0/go.mod h1:AsD+OCi/qPN1giOX1aiLAha3o1U8rAz65bvN4j0sRuk=
go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA=
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0/go.mod h1:h06DGIukJOevXaj/xrNjhi/2098RZzcLTbc0jDAUbsg=
go.opentelemetry.io/otel v1.16.0 h1:Z7GVAX/UkAXPKsy94IU+i6thsQS4nb7LviLpnaNeW8s=
go.opentelemetry.io/otel v1.16.0/go.mod h1:vl0h9NUa1D5s1nv3A5vZOYWn8av4K8Ml6JDeHrT/bx4=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel/metric v1.16.0 h1:RbrpwVG1Hfv85LgnZ7+txXioPDoh6EdbZHo26Q3hqOo=
go.opentelemetry.io/otel/metric v1.16.0/go.mod h1:QE47cpOmkwipPiefDwo2wDzwJrlfxxNYodqc4xnGCo4=
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
go.opentelemetry.io/otel/trace v1.16.0 h1:8JRpaObFoW0pxuVPapkgH8UhHQj+bJW8jJsCZEu5MQs=
go.opentelemetry.io/otel/trace v1.16.0/go.mod h1:Yt9vYq1SdNz3xdjZZK7wcXv1qv2pwLkqr2QVwea0ef0=
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
@@ -588,6 +624,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA=
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
@@ -659,6 +697,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
340
internal/licensing/license_gate.go
Normal file
@@ -0,0 +1,340 @@
package licensing

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
	"sync/atomic"
	"time"

	"github.com/sony/gobreaker"
)

// LicenseGate provides burst-proof license validation with caching and circuit breaker
type LicenseGate struct {
	config     LicenseConfig
	cache      atomic.Value // stores cachedLease
	breaker    *gobreaker.CircuitBreaker
	graceUntil atomic.Value // stores time.Time
	httpClient *http.Client
}

// cachedLease represents a cached license lease with expiry
type cachedLease struct {
	LeaseToken string    `json:"lease_token"`
	ExpiresAt  time.Time `json:"expires_at"`
	ClusterID  string    `json:"cluster_id"`
	Valid      bool      `json:"valid"`
	CachedAt   time.Time `json:"cached_at"`
}

// LeaseRequest represents a cluster lease request
type LeaseRequest struct {
	ClusterID         string `json:"cluster_id"`
	RequestedReplicas int    `json:"requested_replicas"`
	DurationMinutes   int    `json:"duration_minutes"`
}

// LeaseResponse represents a cluster lease response
type LeaseResponse struct {
	LeaseToken  string    `json:"lease_token"`
	MaxReplicas int       `json:"max_replicas"`
	ExpiresAt   time.Time `json:"expires_at"`
	ClusterID   string    `json:"cluster_id"`
	LeaseID     string    `json:"lease_id"`
}

// LeaseValidationRequest represents a lease validation request
type LeaseValidationRequest struct {
	LeaseToken string `json:"lease_token"`
	ClusterID  string `json:"cluster_id"`
	AgentID    string `json:"agent_id"`
}

// LeaseValidationResponse represents a lease validation response
type LeaseValidationResponse struct {
	Valid             bool      `json:"valid"`
	RemainingReplicas int       `json:"remaining_replicas"`
	ExpiresAt         time.Time `json:"expires_at"`
}

// NewLicenseGate creates a new license gate with circuit breaker and caching
func NewLicenseGate(config LicenseConfig) *LicenseGate {
	// Circuit breaker settings optimized for license validation
	breakerSettings := gobreaker.Settings{
		Name:        "license-validation",
		MaxRequests: 3,                // Allow 3 requests in half-open state
		Interval:    60 * time.Second, // Reset failure count every minute
		Timeout:     30 * time.Second, // Stay open for 30 seconds
		ReadyToTrip: func(counts gobreaker.Counts) bool {
			// Trip after 3 consecutive failures
			return counts.ConsecutiveFailures >= 3
		},
		OnStateChange: func(name string, from gobreaker.State, to gobreaker.State) {
			fmt.Printf("🔌 License validation circuit breaker: %s -> %s\n", from, to)
		},
	}

	gate := &LicenseGate{
		config:     config,
		breaker:    gobreaker.NewCircuitBreaker(breakerSettings),
		httpClient: &http.Client{Timeout: 10 * time.Second},
	}

	// Initialize grace period
	gate.graceUntil.Store(time.Now().Add(90 * time.Second))

	return gate
}

// ValidNow checks if the cached lease is currently valid
func (c *cachedLease) ValidNow() bool {
	if !c.Valid {
		return false
	}
	// Consider lease invalid 2 minutes before actual expiry for safety margin
	return time.Now().Before(c.ExpiresAt.Add(-2 * time.Minute))
}

// loadCachedLease safely loads the cached lease
func (g *LicenseGate) loadCachedLease() *cachedLease {
	if cached := g.cache.Load(); cached != nil {
		if lease, ok := cached.(*cachedLease); ok {
			return lease
		}
	}
	return &cachedLease{Valid: false}
}

// storeLease safely stores a lease in the cache
func (g *LicenseGate) storeLease(lease *cachedLease) {
	lease.CachedAt = time.Now()
	g.cache.Store(lease)
}

// isInGracePeriod checks if we're still in the grace period
func (g *LicenseGate) isInGracePeriod() bool {
	if graceUntil := g.graceUntil.Load(); graceUntil != nil {
		if grace, ok := graceUntil.(time.Time); ok {
			return time.Now().Before(grace)
		}
	}
	return false
}

// extendGracePeriod extends the grace period on successful validation
func (g *LicenseGate) extendGracePeriod() {
	g.graceUntil.Store(time.Now().Add(90 * time.Second))
}

// Validate validates the license using cache, lease system, and circuit breaker
func (g *LicenseGate) Validate(ctx context.Context, agentID string) error {
	// Check cached lease first
	if lease := g.loadCachedLease(); lease.ValidNow() {
		return g.validateCachedLease(ctx, lease, agentID)
	}

	// Try to get/renew lease through circuit breaker
	_, err := g.breaker.Execute(func() (interface{}, error) {
		lease, err := g.requestOrRenewLease(ctx)
		if err != nil {
			return nil, err
		}

		// Validate the new lease
		if err := g.validateLease(ctx, lease, agentID); err != nil {
			return nil, err
		}

		// Store successful lease
		g.storeLease(&cachedLease{
			LeaseToken: lease.LeaseToken,
			ExpiresAt:  lease.ExpiresAt,
			ClusterID:  lease.ClusterID,
			Valid:      true,
		})

		return nil, nil
	})

	if err != nil {
		// If we're in grace period, allow startup but log warning
		if g.isInGracePeriod() {
			fmt.Printf("⚠️ License validation failed but in grace period: %v\n", err)
			return nil
		}
		return fmt.Errorf("license validation failed: %w", err)
	}

	// Extend grace period on successful validation
	g.extendGracePeriod()
	return nil
}

// validateCachedLease validates using cached lease token
func (g *LicenseGate) validateCachedLease(ctx context.Context, lease *cachedLease, agentID string) error {
	validation := LeaseValidationRequest{
		LeaseToken: lease.LeaseToken,
		ClusterID:  g.config.ClusterID,
		AgentID:    agentID,
	}

	url := fmt.Sprintf("%s/api/v1/licenses/validate-lease", strings.TrimSuffix(g.config.KachingURL, "/"))

	reqBody, err := json.Marshal(validation)
	if err != nil {
		return fmt.Errorf("failed to marshal lease validation request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(reqBody)))
	if err != nil {
		return fmt.Errorf("failed to create lease validation request: %w", err)
	}

	req.Header.Set("Content-Type", "application/json")

	resp, err := g.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("lease validation request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// If validation fails, invalidate cache
		lease.Valid = false
		g.storeLease(lease)
		return fmt.Errorf("lease validation failed with status %d", resp.StatusCode)
	}

	var validationResp LeaseValidationResponse
	if err := json.NewDecoder(resp.Body).Decode(&validationResp); err != nil {
		return fmt.Errorf("failed to decode lease validation response: %w", err)
	}

	if !validationResp.Valid {
		// If validation fails, invalidate cache
		lease.Valid = false
		g.storeLease(lease)
		return fmt.Errorf("lease token is invalid")
	}

	return nil
}

// requestOrRenewLease requests a new cluster lease or renews existing one
func (g *LicenseGate) requestOrRenewLease(ctx context.Context) (*LeaseResponse, error) {
	// For now, request a new lease (TODO: implement renewal logic)
	leaseReq := LeaseRequest{
		ClusterID:         g.config.ClusterID,
		RequestedReplicas: 1,  // Start with single replica
		DurationMinutes:   60, // 1 hour lease
	}

	url := fmt.Sprintf("%s/api/v1/licenses/%s/cluster-lease",
		strings.TrimSuffix(g.config.KachingURL, "/"), g.config.LicenseID)

	reqBody, err := json.Marshal(leaseReq)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal lease request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(reqBody)))
	if err != nil {
		return nil, fmt.Errorf("failed to create lease request: %w", err)
	}

	req.Header.Set("Content-Type", "application/json")

	resp, err := g.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("lease request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusTooManyRequests {
		return nil, fmt.Errorf("rate limited by KACHING, retry after: %s", resp.Header.Get("Retry-After"))
	}

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("lease request failed with status %d", resp.StatusCode)
	}

	var leaseResp LeaseResponse
	if err := json.NewDecoder(resp.Body).Decode(&leaseResp); err != nil {
		return nil, fmt.Errorf("failed to decode lease response: %w", err)
	}

	return &leaseResp, nil
}

// validateLease validates a lease token
func (g *LicenseGate) validateLease(ctx context.Context, lease *LeaseResponse, agentID string) error {
	validation := LeaseValidationRequest{
		LeaseToken: lease.LeaseToken,
		ClusterID:  lease.ClusterID,
		AgentID:    agentID,
	}

	return g.validateLeaseRequest(ctx, validation)
}

// validateLeaseRequest performs the actual lease validation HTTP request
func (g *LicenseGate) validateLeaseRequest(ctx context.Context, validation LeaseValidationRequest) error {
	url := fmt.Sprintf("%s/api/v1/licenses/validate-lease", strings.TrimSuffix(g.config.KachingURL, "/"))

	reqBody, err := json.Marshal(validation)
	if err != nil {
		return fmt.Errorf("failed to marshal lease validation request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(reqBody)))
	if err != nil {
		return fmt.Errorf("failed to create lease validation request: %w", err)
	}

	req.Header.Set("Content-Type", "application/json")

	resp, err := g.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("lease validation request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("lease validation failed with status %d", resp.StatusCode)
	}

	var validationResp LeaseValidationResponse
	if err := json.NewDecoder(resp.Body).Decode(&validationResp); err != nil {
		return fmt.Errorf("failed to decode lease validation response: %w", err)
	}

	if !validationResp.Valid {
		return fmt.Errorf("lease token is invalid")
	}

	return nil
}

// GetCacheStats returns cache statistics for monitoring
func (g *LicenseGate) GetCacheStats() map[string]interface{} {
	lease := g.loadCachedLease()
	stats := map[string]interface{}{
		"cache_valid":     lease.Valid,
		"cache_hit":       lease.ValidNow(),
		"expires_at":      lease.ExpiresAt,
		"cached_at":       lease.CachedAt,
		"in_grace_period": g.isInGracePeriod(),
		"breaker_state":   g.breaker.State().String(),
	}

	if grace := g.graceUntil.Load(); grace != nil {
		if graceTime, ok := grace.(time.Time); ok {
			stats["grace_until"] = graceTime
		}
	}

	return stats
}
|
||||
@@ -2,6 +2,7 @@ package licensing

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
@@ -21,35 +22,60 @@ type LicenseConfig struct {
}

// Validator handles license validation with KACHING
// Enhanced with license gate for burst-proof validation
type Validator struct {
	config     LicenseConfig
	kachingURL string
	client     *http.Client
	gate       *LicenseGate // License gate for scaling support
}

// NewValidator creates a new license validator with enhanced scaling support
func NewValidator(config LicenseConfig) *Validator {
	kachingURL := config.KachingURL
	if kachingURL == "" {
		kachingURL = DefaultKachingURL
	}

	validator := &Validator{
		config:     config,
		kachingURL: kachingURL,
		client: &http.Client{
			Timeout: LicenseTimeout,
		},
	}

	// Initialize license gate for scaling support
	validator.gate = NewLicenseGate(config)

	return validator
}

// Validate performs license validation with KACHING license authority
// CRITICAL: CHORUS will not start without valid license validation
// Enhanced with caching, circuit breaker, and lease token support
func (v *Validator) Validate() error {
	return v.ValidateWithContext(context.Background())
}

// ValidateWithContext performs license validation with context and agent ID
func (v *Validator) ValidateWithContext(ctx context.Context) error {
	if v.config.LicenseID == "" || v.config.ClusterID == "" {
		return fmt.Errorf("license ID and cluster ID are required")
	}

	// Use enhanced license gate for validation
	agentID := "default-agent" // TODO: Get from config/environment
	if err := v.gate.Validate(ctx, agentID); err != nil {
		// Fallback to legacy validation for backward compatibility
		fmt.Printf("⚠️ License gate validation failed, trying legacy validation: %v\n", err)
		return v.validateLegacy()
	}

	return nil
}
// validateLegacy performs the original license validation (for fallback)
func (v *Validator) validateLegacy() error {
	// Prepare validation request
	request := map[string]interface{}{
		"license_id": v.config.LicenseID,
@@ -66,7 +92,7 @@ func (v *Validator) Validate() error {
		return fmt.Errorf("failed to marshal license request: %w", err)
	}

	// Call KACHING license authority
	licenseURL := fmt.Sprintf("%s/v1/license/activate", v.kachingURL)
	resp, err := v.client.Post(licenseURL, "application/json", bytes.NewReader(requestBody))
	if err != nil {
@@ -1,6 +1,7 @@
package logging

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
@@ -8,6 +9,7 @@ import (
	"sync"
	"time"

	"chorus/pkg/shhh"
	"github.com/libp2p/go-libp2p/core/peer"
)

@@ -23,12 +25,14 @@ type HypercoreLog struct {
	entries []LogEntry
	mutex   sync.RWMutex
	peerID  peer.ID

	// Verification chain
	headHash string

	// Replication
	replicators map[peer.ID]*Replicator

	redactor *shhh.Sentinel
}

// LogEntry represents a single entry in the distributed log
@@ -48,12 +52,12 @@ type LogType string

const (
	// Bzzz coordination logs
	TaskAnnounced LogType = "task_announced"
	TaskClaimed   LogType = "task_claimed"
	TaskProgress  LogType = "task_progress"
	TaskCompleted LogType = "task_completed"
	TaskFailed    LogType = "task_failed"

	// HMMM meta-discussion logs
	PlanProposed    LogType = "plan_proposed"
	ObjectionRaised LogType = "objection_raised"
@@ -65,17 +69,17 @@ const (
	TaskHelpReceived LogType = "task_help_received"

	// System logs
	PeerJoined      LogType = "peer_joined"
	PeerLeft        LogType = "peer_left"
	CapabilityBcast LogType = "capability_broadcast"
	NetworkEvent    LogType = "network_event"
)

// Replicator handles log replication with other peers
type Replicator struct {
	peerID        peer.ID
	lastSyncIndex uint64
	connected     bool
}
// NewHypercoreLog creates a new distributed log for a peer
@@ -88,6 +92,13 @@ func NewHypercoreLog(peerID peer.ID) *HypercoreLog {
	}
}

// SetRedactor wires the SHHH sentinel so log payloads are sanitized before persistence.
func (h *HypercoreLog) SetRedactor(redactor *shhh.Sentinel) {
	h.mutex.Lock()
	defer h.mutex.Unlock()
	h.redactor = redactor
}

// AppendString is a convenience method for string log types (to match interface)
func (h *HypercoreLog) AppendString(logType string, data map[string]interface{}) error {
	_, err := h.Append(LogType(logType), data)
@@ -98,38 +109,40 @@ func (h *HypercoreLog) AppendString(logType string, data map[string]interface{})
func (h *HypercoreLog) Append(logType LogType, data map[string]interface{}) (*LogEntry, error) {
	h.mutex.Lock()
	defer h.mutex.Unlock()

	index := uint64(len(h.entries))

	sanitized := h.redactData(logType, data)

	entry := LogEntry{
		Index:     index,
		Timestamp: time.Now(),
		Author:    h.peerID.String(),
		Type:      logType,
		Data:      sanitized,
		PrevHash:  h.headHash,
	}

	// Calculate hash
	entryHash, err := h.calculateEntryHash(entry)
	if err != nil {
		return nil, fmt.Errorf("failed to calculate entry hash: %w", err)
	}
	entry.Hash = entryHash

	// Add simple signature (in production, use proper cryptographic signatures)
	entry.Signature = h.createSignature(entry)

	// Append to log
	h.entries = append(h.entries, entry)
	h.headHash = entryHash

	fmt.Printf("📝 Log entry appended: %s [%d] by %s\n",
		logType, index, h.peerID.ShortString())

	// Trigger replication to connected peers
	go h.replicateEntry(entry)

	return &entry, nil
}
@@ -137,11 +150,11 @@ func (h *HypercoreLog) Append(logType LogType, data map[string]interface{}) (*Lo
func (h *HypercoreLog) Get(index uint64) (*LogEntry, error) {
	h.mutex.RLock()
	defer h.mutex.RUnlock()

	if index >= uint64(len(h.entries)) {
		return nil, fmt.Errorf("entry %d not found", index)
	}

	return &h.entries[index], nil
}

@@ -149,7 +162,7 @@ func (h *HypercoreLog) Get(index uint64) (*LogEntry, error) {
func (h *HypercoreLog) Length() uint64 {
	h.mutex.RLock()
	defer h.mutex.RUnlock()

	return uint64(len(h.entries))
}

@@ -157,22 +170,22 @@ func (h *HypercoreLog) Length() uint64 {
func (h *HypercoreLog) GetRange(start, end uint64) ([]LogEntry, error) {
	h.mutex.RLock()
	defer h.mutex.RUnlock()

	if start >= uint64(len(h.entries)) {
		return nil, fmt.Errorf("start index %d out of range", start)
	}

	if end > uint64(len(h.entries)) {
		end = uint64(len(h.entries))
	}

	if start > end {
		return nil, fmt.Errorf("invalid range: start %d > end %d", start, end)
	}

	result := make([]LogEntry, end-start)
	copy(result, h.entries[start:end])

	return result, nil
}

@@ -180,14 +193,14 @@ func (h *HypercoreLog) GetRange(start, end uint64) ([]LogEntry, error) {
func (h *HypercoreLog) GetEntriesByType(logType LogType) ([]LogEntry, error) {
	h.mutex.RLock()
	defer h.mutex.RUnlock()

	var result []LogEntry
	for _, entry := range h.entries {
		if entry.Type == logType {
			result = append(result, entry)
		}
	}

	return result, nil
}

@@ -195,14 +208,14 @@ func (h *HypercoreLog) GetEntriesByType(logType LogType) ([]LogEntry, error) {
func (h *HypercoreLog) GetEntriesByAuthor(author string) ([]LogEntry, error) {
	h.mutex.RLock()
	defer h.mutex.RUnlock()

	var result []LogEntry
	for _, entry := range h.entries {
		if entry.Author == author {
			result = append(result, entry)
		}
	}

	return result, nil
}

@@ -210,20 +223,20 @@ func (h *HypercoreLog) GetEntriesByAuthor(author string) ([]LogEntry, error) {
func (h *HypercoreLog) GetRecentEntries(count int) ([]LogEntry, error) {
	h.mutex.RLock()
	defer h.mutex.RUnlock()

	totalEntries := len(h.entries)
	if count <= 0 || totalEntries == 0 {
		return []LogEntry{}, nil
	}

	start := 0
	if totalEntries > count {
		start = totalEntries - count
	}

	result := make([]LogEntry, totalEntries-start)
	copy(result, h.entries[start:])

	return result, nil
}

@@ -231,14 +244,14 @@ func (h *HypercoreLog) GetRecentEntries(count int) ([]LogEntry, error) {
func (h *HypercoreLog) GetEntriesSince(sinceIndex uint64) ([]LogEntry, error) {
	h.mutex.RLock()
	defer h.mutex.RUnlock()

	if sinceIndex >= uint64(len(h.entries)) {
		return []LogEntry{}, nil
	}

	result := make([]LogEntry, len(h.entries)-int(sinceIndex))
	copy(result, h.entries[sinceIndex:])

	return result, nil
}
@@ -246,27 +259,27 @@ func (h *HypercoreLog) GetEntriesSince(sinceIndex uint64) ([]LogEntry, error) {
func (h *HypercoreLog) VerifyIntegrity() error {
	h.mutex.RLock()
	defer h.mutex.RUnlock()

	var prevHash string
	for i, entry := range h.entries {
		// Verify previous hash link
		if entry.PrevHash != prevHash {
			return fmt.Errorf("integrity error at entry %d: prev_hash mismatch", i)
		}

		// Verify entry hash
		calculatedHash, err := h.calculateEntryHash(entry)
		if err != nil {
			return fmt.Errorf("failed to calculate hash for entry %d: %w", i, err)
		}

		if entry.Hash != calculatedHash {
			return fmt.Errorf("integrity error at entry %d: hash mismatch", i)
		}

		prevHash = entry.Hash
	}

	return nil
}
@@ -274,13 +287,13 @@ func (h *HypercoreLog) VerifyIntegrity() error {
func (h *HypercoreLog) AddReplicator(peerID peer.ID) {
	h.mutex.Lock()
	defer h.mutex.Unlock()

	h.replicators[peerID] = &Replicator{
		peerID:        peerID,
		lastSyncIndex: 0,
		connected:     true,
	}

	fmt.Printf("🔄 Added replicator: %s\n", peerID.ShortString())
}

@@ -288,7 +301,7 @@ func (h *HypercoreLog) AddReplicator(peerID peer.ID) {
func (h *HypercoreLog) RemoveReplicator(peerID peer.ID) {
	h.mutex.Lock()
	defer h.mutex.Unlock()

	delete(h.replicators, peerID)
	fmt.Printf("🔄 Removed replicator: %s\n", peerID.ShortString())
}
@@ -303,10 +316,10 @@ func (h *HypercoreLog) replicateEntry(entry LogEntry) {
		}
	}
	h.mutex.RUnlock()

	for _, replicator := range replicators {
		// In a real implementation, this would send the entry over the network
		fmt.Printf("🔄 Replicating entry %d to %s\n",
			entry.Index, replicator.peerID.ShortString())
	}
}
@@ -322,16 +335,74 @@ func (h *HypercoreLog) calculateEntryHash(entry LogEntry) (string, error) {
		Data:     entry.Data,
		PrevHash: entry.PrevHash,
	}

	entryBytes, err := json.Marshal(entryForHash)
	if err != nil {
		return "", err
	}

	hash := sha256.Sum256(entryBytes)
	return hex.EncodeToString(hash[:]), nil
}

func (h *HypercoreLog) redactData(logType LogType, data map[string]interface{}) map[string]interface{} {
	cloned := cloneLogMap(data)
	if cloned == nil {
		return nil
	}
	if h.redactor != nil {
		labels := map[string]string{
			"source":   "hypercore",
			"log_type": string(logType),
		}
		h.redactor.RedactMapWithLabels(context.Background(), cloned, labels)
	}
	return cloned
}

func cloneLogMap(in map[string]interface{}) map[string]interface{} {
	if in == nil {
		return nil
	}
	out := make(map[string]interface{}, len(in))
	for k, v := range in {
		out[k] = cloneLogValue(v)
	}
	return out
}
// @goal: CHORUS-REQ-001 - Fix duplicate type case compilation error
// WHY: Go 1.18+ treats interface{} and any as identical types, causing duplicate case errors
func cloneLogValue(v interface{}) interface{} {
	switch tv := v.(type) {
	case map[string]any:
		// map[string]any is identical to map[string]interface{}, so cloneLogMap applies directly
		return cloneLogMap(tv)
	case []any:
		return cloneLogSlice(tv)
	case []string:
		return append([]string(nil), tv...)
	default:
		return tv
	}
}

func cloneLogSlice(in []interface{}) []interface{} {
	out := make([]interface{}, len(in))
	for i, val := range in {
		out[i] = cloneLogValue(val)
	}
	return out
}
// createSignature creates a simplified signature for the entry
func (h *HypercoreLog) createSignature(entry LogEntry) string {
	// In production, this would use proper cryptographic signatures
@@ -345,21 +416,21 @@ func (h *HypercoreLog) createSignature(entry LogEntry) string {
func (h *HypercoreLog) GetStats() map[string]interface{} {
	h.mutex.RLock()
	defer h.mutex.RUnlock()

	typeCount := make(map[LogType]int)
	authorCount := make(map[string]int)

	for _, entry := range h.entries {
		typeCount[entry.Type]++
		authorCount[entry.Author]++
	}

	return map[string]interface{}{
		"total_entries":     len(h.entries),
		"head_hash":         h.headHash,
		"replicators":       len(h.replicators),
		"entries_by_type":   typeCount,
		"entries_by_author": authorCount,
		"peer_id":           h.peerID.String(),
	}
}
@@ -2,9 +2,11 @@ package runtime

import (
	"context"
	"fmt"
	"time"

	"chorus/internal/logging"
	"chorus/pkg/dht"
	"chorus/pkg/health"
	"chorus/pkg/shutdown"
	"chorus/pubsub"
@@ -43,37 +45,37 @@ func (r *SharedRuntime) StartAgentMode() error {

	// === Comprehensive Health Monitoring & Graceful Shutdown ===
	shutdownManager := shutdown.NewManager(30*time.Second, &simpleLogger{logger: r.Logger})

	healthManager := health.NewManager(r.Node.ID().ShortString(), AppVersion, &simpleLogger{logger: r.Logger})
	healthManager.SetShutdownManager(shutdownManager)

	// Register health checks
	r.setupHealthChecks(healthManager)

	// Register components for graceful shutdown
	r.setupGracefulShutdown(shutdownManager, healthManager)

	// Start health monitoring
	if err := healthManager.Start(); err != nil {
		return err
	}
	r.HealthManager = healthManager
	r.Logger.Info("❤️ Health monitoring started")

	// Start health HTTP server
	if err := healthManager.StartHTTPServer(r.Config.Network.HealthPort); err != nil {
		r.Logger.Error("❌ Failed to start health HTTP server: %v", err)
	} else {
		r.Logger.Info("🏥 Health endpoints available at http://localhost:%d/health", r.Config.Network.HealthPort)
	}

	// Start shutdown manager
	shutdownManager.Start()
	r.ShutdownManager = shutdownManager
	r.Logger.Info("🛡️ Graceful shutdown manager started")

	r.Logger.Info("✅ CHORUS agent system fully operational with health monitoring")

	// Wait for graceful shutdown
	shutdownManager.Wait()
	r.Logger.Info("✅ CHORUS agent system shutdown completed")
@@ -90,7 +92,7 @@ func (r *SharedRuntime) announceAvailability() {
	currentTasks := r.TaskTracker.GetActiveTasks()
	maxTasks := r.TaskTracker.GetMaxTasks()
	isAvailable := len(currentTasks) < maxTasks

	status := "ready"
	if len(currentTasks) >= maxTasks {
		status = "busy"
@@ -99,13 +101,13 @@
	}

	availability := map[string]interface{}{
		"node_id":            r.Node.ID().ShortString(),
		"available_for_work": isAvailable,
		"current_tasks":      len(currentTasks),
		"max_tasks":          maxTasks,
		"last_activity":      time.Now().Unix(),
		"status":             status,
		"timestamp":          time.Now().Unix(),
	}
	if err := r.PubSub.PublishBzzzMessage(pubsub.AvailabilityBcast, availability); err != nil {
		r.Logger.Error("❌ Failed to announce availability: %v", err)
@@ -126,16 +128,79 @@ func (r *SharedRuntime) statusReporter() {

// announceCapabilitiesOnChange announces capabilities when they change
func (r *SharedRuntime) announceCapabilitiesOnChange() {
	if r.PubSub == nil {
		r.Logger.Warn("⚠️ Capability broadcast skipped: PubSub not initialized")
		return
	}

	r.Logger.Info("📢 Broadcasting agent capabilities to network")

	activeTaskCount := 0
	if r.TaskTracker != nil {
		activeTaskCount = len(r.TaskTracker.GetActiveTasks())
	}

	announcement := map[string]interface{}{
		"agent_id":       r.Config.Agent.ID,
		"node_id":        r.Node.ID().ShortString(),
		"version":        AppVersion,
		"capabilities":   r.Config.Agent.Capabilities,
		"expertise":      r.Config.Agent.Expertise,
		"models":         r.Config.Agent.Models,
		"specialization": r.Config.Agent.Specialization,
		"max_tasks":      r.Config.Agent.MaxTasks,
		"current_tasks":  activeTaskCount,
		"timestamp":      time.Now().Unix(),
		"availability":   "ready",
	}

	if err := r.PubSub.PublishBzzzMessage(pubsub.CapabilityBcast, announcement); err != nil {
		r.Logger.Error("❌ Failed to broadcast capabilities: %v", err)
		return
	}

	r.Logger.Info("✅ Capabilities broadcast published")

	// TODO: Watch for live capability changes (role updates, model changes) and re-broadcast
}

// announceRoleOnStartup announces role when the agent starts
func (r *SharedRuntime) announceRoleOnStartup() {
	role := r.Config.Agent.Role
	if role == "" {
		r.Logger.Info("🎭 No agent role configured; skipping role announcement")
		return
	}
	if r.PubSub == nil {
		r.Logger.Warn("⚠️ Role announcement skipped: PubSub not initialized")
		return
	}

	r.Logger.Info("🎭 Announcing agent role to collaboration mesh")

	announcement := map[string]interface{}{
		"agent_id":       r.Config.Agent.ID,
		"node_id":        r.Node.ID().ShortString(),
		"role":           role,
		"expertise":      r.Config.Agent.Expertise,
		"capabilities":   r.Config.Agent.Capabilities,
		"reports_to":     r.Config.Agent.ReportsTo,
		"specialization": r.Config.Agent.Specialization,
		"timestamp":      time.Now().Unix(),
	}

	opts := pubsub.MessageOptions{
		FromRole: role,
		Priority: "medium",
		ThreadID: fmt.Sprintf("role:%s", role),
	}

	if err := r.PubSub.PublishRoleBasedMessage(pubsub.RoleAnnouncement, announcement, opts); err != nil {
		r.Logger.Error("❌ Failed to announce role: %v", err)
		return
	}

	r.Logger.Info("✅ Role announcement published")
}

func (r *SharedRuntime) setupHealthChecks(healthManager *health.Manager) {
@@ -151,31 +216,108 @@
		Checker: func(ctx context.Context) health.CheckResult {
			healthInfo := r.BackbeatIntegration.GetHealth()
			connected, _ := healthInfo["connected"].(bool)

			result := health.CheckResult{
				Healthy:   connected,
				Details:   healthInfo,
				Timestamp: time.Now(),
			}

			if connected {
				result.Message = "BACKBEAT integration healthy and connected"
			} else {
				result.Message = "BACKBEAT integration not connected"
			}

			return result
		},
	}
	healthManager.RegisterCheck(backbeatCheck)
	}

	// Register enhanced health instrumentation when core subsystems are available
	if r.PubSub == nil {
		r.Logger.Warn("⚠️ Skipping enhanced health checks: PubSub not initialized")
		return
	}
	if r.ElectionManager == nil {
		r.Logger.Warn("⚠️ Skipping enhanced health checks: election manager not ready")
		return
	}

	var replication *dht.ReplicationManager
	if r.DHTNode != nil {
		replication = r.DHTNode.ReplicationManager()
	}

	enhanced := health.NewEnhancedHealthChecks(
		healthManager,
		r.ElectionManager,
		r.DHTNode,
		r.PubSub,
		replication,
		&simpleLogger{logger: r.Logger},
	)

	r.EnhancedHealth = enhanced
	r.Logger.Info("🩺 Enhanced health checks registered")
}

func (r *SharedRuntime) setupGracefulShutdown(shutdownManager *shutdown.Manager, healthManager *health.Manager) {
	if shutdownManager == nil {
		r.Logger.Warn("⚠️ Shutdown manager not initialized; graceful teardown skipped")
		return
	}

	if r.HTTPServer != nil {
		httpComponent := shutdown.NewGenericComponent("http-api-server", 10, true).
			SetShutdownFunc(func(ctx context.Context) error {
				return r.HTTPServer.Stop()
			})
		shutdownManager.Register(httpComponent)
	}

	if healthManager != nil {
		healthComponent := shutdown.NewGenericComponent("health-manager", 15, true).
			SetShutdownFunc(func(ctx context.Context) error {
				return healthManager.Stop()
			})
		shutdownManager.Register(healthComponent)
	}

	if r.UCXIServer != nil {
		ucxiComponent := shutdown.NewGenericComponent("ucxi-server", 20, true).
			SetShutdownFunc(func(ctx context.Context) error {
				return r.UCXIServer.Stop()
			})
		shutdownManager.Register(ucxiComponent)
	}

	if r.PubSub != nil {
		shutdownManager.Register(shutdown.NewPubSubComponent("pubsub", r.PubSub.Close, 30))
	}

	if r.DHTNode != nil {
		dhtComponent := shutdown.NewGenericComponent("dht-node", 35, true).
			SetCloser(r.DHTNode.Close)
		shutdownManager.Register(dhtComponent)
	}

	if r.Node != nil {
		shutdownManager.Register(shutdown.NewP2PNodeComponent("p2p-node", r.Node.Close, 40))
	}

	if r.ElectionManager != nil {
		shutdownManager.Register(shutdown.NewElectionManagerComponent("election-manager", r.ElectionManager.Stop, 45))
	}

	if r.BackbeatIntegration != nil {
		backbeatComponent := shutdown.NewGenericComponent("backbeat-integration", 50, true).
			SetShutdownFunc(func(ctx context.Context) error {
				return r.BackbeatIntegration.Stop()
			})
		shutdownManager.Register(backbeatComponent)
	}

	r.Logger.Info("🛡️ Graceful shutdown components registered")
}
|
||||
|
||||
@@ -21,8 +21,10 @@ import (
|
||||
"chorus/pkg/dht"
|
||||
"chorus/pkg/election"
|
||||
"chorus/pkg/health"
|
||||
"chorus/pkg/shutdown"
|
||||
"chorus/pkg/metrics"
|
||||
"chorus/pkg/prompt"
|
||||
"chorus/pkg/shhh"
|
||||
"chorus/pkg/shutdown"
|
||||
"chorus/pkg/ucxi"
|
||||
"chorus/pkg/ucxl"
|
||||
"chorus/pubsub"
|
||||
@@ -31,9 +33,12 @@ import (
|
||||
"github.com/multiformats/go-multiaddr"
|
||||
)
|
||||
|
||||
const (
|
||||
AppName = "CHORUS"
|
||||
AppVersion = "0.1.0-dev"
|
||||
// Build information - set by main package
|
||||
var (
|
||||
AppName = "CHORUS"
|
||||
AppVersion = "0.1.0-dev"
|
||||
AppCommitHash = "unknown"
|
||||
AppBuildDate = "unknown"
|
||||
)
|
||||
|
||||
// SimpleLogger provides basic logging implementation
|
||||
@@ -53,8 +58,8 @@ func (l *SimpleLogger) Error(msg string, args ...interface{}) {
|
||||
|
||||
// SimpleTaskTracker tracks active tasks for availability reporting
|
||||
type SimpleTaskTracker struct {
|
||||
maxTasks int
|
||||
activeTasks map[string]bool
|
||||
maxTasks int
|
||||
activeTasks map[string]bool
|
||||
decisionPublisher *ucxl.DecisionPublisher
|
||||
}
|
||||
|
||||
@@ -80,7 +85,7 @@ func (t *SimpleTaskTracker) AddTask(taskID string) {
|
||||
// RemoveTask marks a task as completed and publishes decision if publisher available
|
||||
func (t *SimpleTaskTracker) RemoveTask(taskID string) {
|
||||
delete(t.activeTasks, taskID)
|
||||
|
||||
|
||||
// Publish task completion decision if publisher is available
|
||||
if t.decisionPublisher != nil {
|
||||
t.publishTaskCompletion(taskID, true, "Task completed successfully", nil)
|
||||
@@ -92,7 +97,7 @@ func (t *SimpleTaskTracker) publishTaskCompletion(taskID string, success bool, s
|
||||
if t.decisionPublisher == nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
if err := t.decisionPublisher.PublishTaskCompletion(taskID, success, summary, filesModified); err != nil {
|
||||
fmt.Printf("⚠️ Failed to publish task completion for %s: %v\n", taskID, err)
|
||||
} else {
|
||||
@@ -102,37 +107,41 @@ func (t *SimpleTaskTracker) publishTaskCompletion(taskID string, success bool, s

// SharedRuntime contains all the shared P2P infrastructure components
type SharedRuntime struct {
    Config              *config.Config
    RuntimeConfig       *config.RuntimeConfig
    Logger              *SimpleLogger
    Context             context.Context
    Cancel              context.CancelFunc
    Node                *p2p.Node
    PubSub              *pubsub.PubSub
    HypercoreLog        *logging.HypercoreLog
    MDNSDiscovery       *discovery.MDNSDiscovery
    BackbeatIntegration *backbeat.Integration
    DHTNode             *dht.LibP2PDHT
    EncryptedStorage    *dht.EncryptedDHTStorage
    DecisionPublisher   *ucxl.DecisionPublisher
    ElectionManager     *election.ElectionManager
    TaskCoordinator     *coordinator.TaskCoordinator
    HTTPServer          *api.HTTPServer
    UCXIServer          *ucxi.Server
    HealthManager       *health.Manager
    EnhancedHealth      *health.EnhancedHealthChecks
    ShutdownManager     *shutdown.Manager
    TaskTracker         *SimpleTaskTracker
    Metrics             *metrics.CHORUSMetrics
    Shhh                *shhh.Sentinel
}

// Initialize sets up all shared P2P infrastructure components
func Initialize(appMode string) (*SharedRuntime, error) {
    runtime := &SharedRuntime{}
    runtime.Logger = &SimpleLogger{}

    ctx, cancel := context.WithCancel(context.Background())
    runtime.Context = ctx
    runtime.Cancel = cancel

    runtime.Logger.Info("🎭 Starting CHORUS v%s (build: %s, %s) - Container-First P2P Task Coordination", AppVersion, AppCommitHash, AppBuildDate)
    runtime.Logger.Info("📦 Container deployment - Mode: %s", appMode)

    // Load configuration from environment (no config files in containers)
@@ -142,8 +151,30 @@ func Initialize(appMode string) (*SharedRuntime, error) {
        return nil, fmt.Errorf("configuration error: %v", err)
    }
    runtime.Config = cfg

    runtime.Logger.Info("✅ Configuration loaded successfully")

    // Initialize runtime configuration with assignment support
    runtime.RuntimeConfig = config.NewRuntimeConfig(cfg)

    // Load assignment if ASSIGN_URL is configured
    if assignURL := os.Getenv("ASSIGN_URL"); assignURL != "" {
        runtime.Logger.Info("📡 Loading assignment from WHOOSH: %s", assignURL)

        ctx, cancel := context.WithTimeout(runtime.Context, 10*time.Second)
        if err := runtime.RuntimeConfig.LoadAssignment(ctx, assignURL); err != nil {
            runtime.Logger.Warn("⚠️ Failed to load assignment (continuing with base config): %v", err)
        } else {
            runtime.Logger.Info("✅ Assignment loaded successfully")
        }
        cancel()

        // Start reload handler for SIGHUP
        runtime.RuntimeConfig.StartReloadHandler(runtime.Context, assignURL)
        runtime.Logger.Info("📡 SIGHUP reload handler started for assignment updates")
    } else {
        runtime.Logger.Info("⚪ No ASSIGN_URL configured, using static configuration")
    }
    runtime.Logger.Info("🤖 Agent ID: %s", cfg.Agent.ID)
    runtime.Logger.Info("🎯 Specialization: %s", cfg.Agent.Specialization)

@@ -166,6 +197,21 @@ func Initialize(appMode string) (*SharedRuntime, error) {
    }
    runtime.Logger.Info("✅ AI provider configured successfully")

    // Initialize metrics collector
    runtime.Metrics = metrics.NewCHORUSMetrics(nil)

    // Initialize SHHH sentinel
    sentinel, err := shhh.NewSentinel(
        shhh.Config{},
        shhh.WithFindingObserver(runtime.handleShhhFindings),
    )
    if err != nil {
        return nil, fmt.Errorf("failed to initialize SHHH sentinel: %v", err)
    }
    sentinel.SetAuditSink(&shhhAuditSink{logger: runtime.Logger})
    runtime.Shhh = sentinel
    runtime.Logger.Info("🛡️ SHHH sentinel initialized")

    // Initialize BACKBEAT integration
    var backbeatIntegration *backbeat.Integration
    backbeatIntegration, err = backbeat.NewIntegration(cfg, cfg.Agent.ID, runtime.Logger)
@@ -198,6 +244,9 @@ func Initialize(appMode string) (*SharedRuntime, error) {

    // Initialize Hypercore-style logger for P2P coordination
    hlog := logging.NewHypercoreLog(node.ID())
    if runtime.Shhh != nil {
        hlog.SetRedactor(runtime.Shhh)
    }
    hlog.Append(logging.PeerJoined, map[string]interface{}{"status": "started"})
    runtime.HypercoreLog = hlog
    runtime.Logger.Info("📝 Hypercore logger initialized")
@@ -214,8 +263,11 @@ func Initialize(appMode string) (*SharedRuntime, error) {
    if err != nil {
        return nil, fmt.Errorf("failed to create PubSub: %v", err)
    }
    if runtime.Shhh != nil {
        ps.SetRedactor(runtime.Shhh)
    }
    runtime.PubSub = ps

    runtime.Logger.Info("📡 PubSub system initialized")

    // Join role-based topics if role is configured
@@ -257,6 +309,7 @@ func (r *SharedRuntime) Cleanup() {

    if r.MDNSDiscovery != nil {
        r.MDNSDiscovery.Close()
        r.Logger.Info("🔍 mDNS discovery closed")
    }

    if r.PubSub != nil {
@@ -294,12 +347,12 @@ func (r *SharedRuntime) Cleanup() {
func (r *SharedRuntime) initializeElectionSystem() error {
    // === Admin Election System ===
    electionManager := election.NewElectionManager(r.Context, r.Config, r.Node.Host(), r.PubSub, r.Node.ID().ShortString())

    // Set election callbacks with BACKBEAT integration
    electionManager.SetCallbacks(
        func(oldAdmin, newAdmin string) {
            r.Logger.Info("👑 Admin changed: %s -> %s", oldAdmin, newAdmin)

            // Track admin change with BACKBEAT if available
            if r.BackbeatIntegration != nil {
                operationID := fmt.Sprintf("admin-change-%d", time.Now().Unix())
@@ -311,7 +364,7 @@ func (r *SharedRuntime) initializeElectionSystem() error {
                    r.BackbeatIntegration.CompleteP2POperation(operationID, 1)
                }
            }

            // If this node becomes admin, enable SLURP functionality
            if newAdmin == r.Node.ID().ShortString() {
                r.Logger.Info("🎯 This node is now admin - enabling SLURP functionality")
@@ -324,12 +377,12 @@ func (r *SharedRuntime) initializeElectionSystem() error {
        },
        func(winner string) {
            r.Logger.Info("🏆 Election completed, winner: %s", winner)

            // Track election completion with BACKBEAT if available
            if r.BackbeatIntegration != nil {
                operationID := fmt.Sprintf("election-completed-%d", time.Now().Unix())
                if err := r.BackbeatIntegration.StartP2POperation(operationID, "election", 1, map[string]interface{}{
                    "winner":  winner,
                    "node_id": r.Node.ID().ShortString(),
                }); err == nil {
                    r.BackbeatIntegration.CompleteP2POperation(operationID, 1)
@@ -337,22 +390,22 @@ func (r *SharedRuntime) initializeElectionSystem() error {
            }
        },
    )

    if err := electionManager.Start(); err != nil {
        return fmt.Errorf("failed to start election manager: %v", err)
    }
    r.ElectionManager = electionManager
    r.Logger.Info("✅ Election manager started with automated heartbeat management")

    return nil
}
func (r *SharedRuntime) initializeDHTStorage() error {
    // === DHT Storage and Decision Publishing ===
    var dhtNode *dht.LibP2PDHT
    var encryptedStorage *dht.EncryptedDHTStorage
    var decisionPublisher *ucxl.DecisionPublisher

    if r.Config.V2.DHT.Enabled {
        // Create DHT
        var err error
@@ -361,14 +414,14 @@ func (r *SharedRuntime) initializeDHTStorage() error {
            r.Logger.Warn("⚠️ Failed to create DHT: %v", err)
        } else {
            r.Logger.Info("🕸️ DHT initialized")

            // Bootstrap DHT with BACKBEAT tracking
            if r.BackbeatIntegration != nil {
                operationID := fmt.Sprintf("dht-bootstrap-%d", time.Now().Unix())
                if err := r.BackbeatIntegration.StartP2POperation(operationID, "dht_bootstrap", 4, nil); err == nil {
                    r.BackbeatIntegration.UpdateP2POperationPhase(operationID, backbeat.PhaseConnecting, 0)
                }

                if err := dhtNode.Bootstrap(); err != nil {
                    r.Logger.Warn("⚠️ DHT bootstrap failed: %v", err)
                    r.BackbeatIntegration.FailP2POperation(operationID, err.Error())
@@ -380,22 +433,34 @@ func (r *SharedRuntime) initializeDHTStorage() error {
                    r.Logger.Warn("⚠️ DHT bootstrap failed: %v", err)
                }
            }

            // Connect to bootstrap peers (with assignment override support)
            bootstrapPeers := r.RuntimeConfig.GetBootstrapPeers()
            if len(bootstrapPeers) == 0 {
                bootstrapPeers = r.Config.V2.DHT.BootstrapPeers
            }

            // Apply join stagger if configured
            joinStagger := r.RuntimeConfig.GetJoinStagger()
            if joinStagger > 0 {
                r.Logger.Info("⏱️ Applying join stagger delay: %v", joinStagger)
                time.Sleep(joinStagger)
            }

            for _, addrStr := range bootstrapPeers {
                addr, err := multiaddr.NewMultiaddr(addrStr)
                if err != nil {
                    r.Logger.Warn("⚠️ Invalid bootstrap address %s: %v", addrStr, err)
                    continue
                }

                // Extract peer info from multiaddr
                info, err := peer.AddrInfoFromP2pAddr(addr)
                if err != nil {
                    r.Logger.Warn("⚠️ Failed to parse peer info from %s: %v", addrStr, err)
                    continue
                }

                // Track peer discovery with BACKBEAT if available
                if r.BackbeatIntegration != nil {
                    operationID := fmt.Sprintf("peer-discovery-%d", time.Now().Unix())
@@ -403,7 +468,7 @@ func (r *SharedRuntime) initializeDHTStorage() error {
                        "peer_addr": addrStr,
                    }); err == nil {
                        r.BackbeatIntegration.UpdateP2POperationPhase(operationID, backbeat.PhaseConnecting, 0)

                        if err := r.Node.Host().Connect(r.Context, *info); err != nil {
                            r.Logger.Warn("⚠️ Failed to connect to bootstrap peer %s: %v", addrStr, err)
                            r.BackbeatIntegration.FailP2POperation(operationID, err.Error())
@@ -420,20 +485,20 @@ func (r *SharedRuntime) initializeDHTStorage() error {
                }
            }
        }

        // Initialize encrypted storage
        encryptedStorage = dht.NewEncryptedDHTStorage(
            r.Context,
            r.Node.Host(),
            dhtNode,
            r.Config,
            r.Node.ID().ShortString(),
        )

        // Start cache cleanup
        encryptedStorage.StartCacheCleanup(5 * time.Minute)
        r.Logger.Info("🔐 Encrypted DHT storage initialized")

        // Initialize decision publisher
        decisionPublisher = ucxl.NewDecisionPublisher(
            r.Context,
@@ -451,11 +516,24 @@ func (r *SharedRuntime) initializeDHTStorage() error {
    r.DHTNode = dhtNode
    r.EncryptedStorage = encryptedStorage
    r.DecisionPublisher = decisionPublisher

    return nil
}

func (r *SharedRuntime) initializeServices() error {
    // Create simple task tracker ahead of coordinator so broadcasts stay accurate
    taskTracker := &SimpleTaskTracker{
        maxTasks:    r.Config.Agent.MaxTasks,
        activeTasks: make(map[string]bool),
    }

    // Connect decision publisher to task tracker if available
    if r.DecisionPublisher != nil {
        taskTracker.decisionPublisher = r.DecisionPublisher
        r.Logger.Info("📤 Task completion decisions will be published to DHT")
    }
    r.TaskTracker = taskTracker

    // === Task Coordination Integration ===
    taskCoordinator := coordinator.NewTaskCoordinator(
        r.Context,
@@ -464,8 +542,9 @@ func (r *SharedRuntime) initializeServices() error {
        r.Config,
        r.Node.ID().ShortString(),
        nil, // HMMM router placeholder
        taskTracker,
    )

    taskCoordinator.Start()
    r.TaskCoordinator = taskCoordinator
    r.Logger.Info("✅ Task coordination system active")
@@ -487,14 +566,14 @@ func (r *SharedRuntime) initializeServices() error {
    if storageDir == "" {
        storageDir = filepath.Join(os.TempDir(), "chorus-ucxi-storage")
    }

    storage, err := ucxi.NewBasicContentStorage(storageDir)
    if err != nil {
        r.Logger.Warn("⚠️ Failed to create UCXI storage: %v", err)
    } else {
        resolver := ucxi.NewBasicAddressResolver(r.Node.ID().ShortString())
        resolver.SetDefaultTTL(r.Config.UCXL.Resolution.CacheTTL)

        ucxiConfig := ucxi.ServerConfig{
            Port:     r.Config.UCXL.Server.Port,
            BasePath: r.Config.UCXL.Server.BasePath,
@@ -502,7 +581,7 @@ func (r *SharedRuntime) initializeServices() error {
            Storage: storage,
            Logger:  ucxi.SimpleLogger{},
        }

        ucxiServer = ucxi.NewServer(ucxiConfig)
        go func() {
            r.Logger.Info("🔗 UCXI server starting on :%d", r.Config.UCXL.Server.Port)
@@ -515,35 +594,41 @@ func (r *SharedRuntime) initializeServices() error {
        r.Logger.Info("⚪ UCXI server disabled")
    }
    r.UCXIServer = ucxiServer

    return nil
}

func (r *SharedRuntime) handleShhhFindings(ctx context.Context, findings []shhh.Finding) {
    if r == nil || r.Metrics == nil {
        return
    }
    for _, finding := range findings {
        r.Metrics.IncrementSHHHFindings(finding.Rule, string(finding.Severity), finding.Count)
    }
}

type shhhAuditSink struct {
    logger *SimpleLogger
}

func (s *shhhAuditSink) RecordRedaction(_ context.Context, event shhh.AuditEvent) {
    if s == nil || s.logger == nil {
        return
    }
    s.logger.Warn("🔒 SHHH redaction applied (rule=%s severity=%s path=%s)", event.Rule, event.Severity, event.Path)
}

// initializeAIProvider configures the reasoning engine with the appropriate AI provider
func initializeAIProvider(cfg *config.Config, logger *SimpleLogger) error {
    // Set the AI provider
    reasoning.SetAIProvider(cfg.AI.Provider)

    // Configure the selected provider
    switch cfg.AI.Provider {
    case "resetdata":
        if cfg.AI.ResetData.APIKey == "" {
            return fmt.Errorf("RESETDATA_API_KEY environment variable is required for resetdata provider")
        }

        resetdataConfig := reasoning.ResetDataConfig{
            BaseURL: cfg.AI.ResetData.BaseURL,
            APIKey:  cfg.AI.ResetData.APIKey,
@@ -551,19 +636,19 @@ func initializeAIProvider(cfg *config.Config, logger *SimpleLogger) error {
            Timeout: cfg.AI.ResetData.Timeout,
        }
        reasoning.SetResetDataConfig(resetdataConfig)
        logger.Info("🌐 ResetData AI provider configured - Endpoint: %s, Model: %s",
            cfg.AI.ResetData.BaseURL, cfg.AI.ResetData.Model)

    case "ollama":
        reasoning.SetOllamaEndpoint(cfg.AI.Ollama.Endpoint)
        logger.Info("🦙 Ollama AI provider configured - Endpoint: %s", cfg.AI.Ollama.Endpoint)

    default:
        logger.Warn("⚠️ Unknown AI provider '%s', defaulting to resetdata", cfg.AI.Provider)
        if cfg.AI.ResetData.APIKey == "" {
            return fmt.Errorf("RESETDATA_API_KEY environment variable is required for default resetdata provider")
        }

        resetdataConfig := reasoning.ResetDataConfig{
            BaseURL: cfg.AI.ResetData.BaseURL,
            APIKey:  cfg.AI.ResetData.APIKey,
@@ -573,7 +658,7 @@ func initializeAIProvider(cfg *config.Config, logger *SimpleLogger) error {
        }
        reasoning.SetResetDataConfig(resetdataConfig)
        reasoning.SetAIProvider("resetdata")
    }

    // Configure model selection
    reasoning.SetModelConfig(
        cfg.Agent.Models,

@@ -9,25 +9,31 @@ type Config struct {
    // Network configuration
    ListenAddresses []string
    NetworkID       string

    // Discovery configuration
    EnableMDNS     bool
    MDNSServiceTag string

    // DHT configuration
    EnableDHT         bool
    DHTBootstrapPeers []string
    DHTMode           string // "client", "server", "auto"
    DHTProtocolPrefix string

    // Connection limits and rate limiting
    MaxConnections     int
    MaxPeersPerIP      int
    ConnectionTimeout  time.Duration
    LowWatermark       int // Connection manager low watermark
    HighWatermark      int // Connection manager high watermark
    DialsPerSecond     int // Dial rate limiting
    MaxConcurrentDials int // Maximum concurrent outbound dials
    MaxConcurrentDHT   int // Maximum concurrent DHT queries
    JoinStaggerMS      int // Join stagger delay in milliseconds

    // Security configuration
    EnableSecurity bool

    // Pubsub configuration
    EnablePubsub bool
    BzzzTopic    string // Task coordination topic
@@ -47,25 +53,31 @@ func DefaultConfig() *Config {
            "/ip6/::/tcp/3333",
        },
        NetworkID: "CHORUS-network",

        // Discovery settings - mDNS disabled for Swarm by default
        EnableMDNS:     false, // Disabled for container environments
        MDNSServiceTag: "CHORUS-peer-discovery",

        // DHT settings (disabled by default for local development)
        EnableDHT:         false,
        DHTBootstrapPeers: []string{},
        DHTMode:           "auto",
        DHTProtocolPrefix: "/CHORUS",

        // Connection limits and rate limiting for scaling
        MaxConnections:     50,
        MaxPeersPerIP:      3,
        ConnectionTimeout:  30 * time.Second,
        LowWatermark:       32,  // Keep at least 32 connections
        HighWatermark:      128, // Trim above 128 connections
        DialsPerSecond:     5,   // Limit outbound dials to prevent storms
        MaxConcurrentDials: 10,  // Maximum concurrent outbound dials
        MaxConcurrentDHT:   16,  // Maximum concurrent DHT queries
        JoinStaggerMS:      0,   // No stagger by default (set by assignment)

        // Security enabled by default
        EnableSecurity: true,

        // Pubsub for coordination and meta-discussion
        EnablePubsub: true,
        BzzzTopic:    "CHORUS/coordination/v1",
@@ -164,4 +176,34 @@ func WithDHTProtocolPrefix(prefix string) Option {
    return func(c *Config) {
        c.DHTProtocolPrefix = prefix
    }
}

// WithConnectionManager sets connection manager watermarks
func WithConnectionManager(low, high int) Option {
    return func(c *Config) {
        c.LowWatermark = low
        c.HighWatermark = high
    }
}

// WithDialRateLimit sets the dial rate limiting
func WithDialRateLimit(dialsPerSecond, maxConcurrent int) Option {
    return func(c *Config) {
        c.DialsPerSecond = dialsPerSecond
        c.MaxConcurrentDials = maxConcurrent
    }
}

// WithDHTRateLimit sets the DHT query rate limiting
func WithDHTRateLimit(maxConcurrentDHT int) Option {
    return func(c *Config) {
        c.MaxConcurrentDHT = maxConcurrentDHT
    }
}

// WithJoinStagger sets the join stagger delay in milliseconds
func WithJoinStagger(delayMS int) Option {
    return func(c *Config) {
        c.JoinStaggerMS = delayMS
    }
}
11 p2p/node.go

@@ -6,16 +6,17 @@ import (
    "time"

    "chorus/pkg/dht"

    "github.com/libp2p/go-libp2p"
    kaddht "github.com/libp2p/go-libp2p-kad-dht"
    "github.com/libp2p/go-libp2p/core/host"
    "github.com/libp2p/go-libp2p/core/peer"
    "github.com/libp2p/go-libp2p/p2p/security/noise"
    "github.com/libp2p/go-libp2p/p2p/transport/tcp"
    "github.com/multiformats/go-multiaddr"
)

// Node represents a CHORUS P2P node
type Node struct {
    host host.Host
    ctx  context.Context
@@ -157,9 +158,9 @@ func (n *Node) startBackgroundTasks() {
// logConnectionStatus logs the current connection status
func (n *Node) logConnectionStatus() {
    peers := n.Peers()
    fmt.Printf("🐝 Bzzz Node Status - ID: %s, Connected Peers: %d\n",
        n.ID().ShortString(), len(peers))

    if len(peers) > 0 {
        fmt.Printf("  Connected to: ")
        for i, p := range peers {
@@ -197,4 +198,4 @@ func (n *Node) Close() error {
    }
    n.cancel()
    return n.host.Close()
}

329 pkg/ai/config.go Normal file

@@ -0,0 +1,329 @@
package ai

import (
    "fmt"
    "os"
    "strings"
    "time"

    "gopkg.in/yaml.v3"
)

// ModelConfig represents the complete model configuration loaded from YAML
type ModelConfig struct {
    Providers        map[string]ProviderConfig `yaml:"providers" json:"providers"`
    DefaultProvider  string                    `yaml:"default_provider" json:"default_provider"`
    FallbackProvider string                    `yaml:"fallback_provider" json:"fallback_provider"`
    Roles            map[string]RoleConfig     `yaml:"roles" json:"roles"`
    Environments     map[string]EnvConfig      `yaml:"environments" json:"environments"`
    ModelPreferences map[string]TaskPreference `yaml:"model_preferences" json:"model_preferences"`
}

// EnvConfig represents environment-specific configuration overrides
type EnvConfig struct {
    DefaultProvider  string `yaml:"default_provider" json:"default_provider"`
    FallbackProvider string `yaml:"fallback_provider" json:"fallback_provider"`
}

// TaskPreference represents preferred models for specific task types
type TaskPreference struct {
    PreferredModels  []string `yaml:"preferred_models" json:"preferred_models"`
    MinContextTokens int      `yaml:"min_context_tokens" json:"min_context_tokens"`
}

// ConfigLoader loads and manages AI provider configurations
type ConfigLoader struct {
    configPath  string
    environment string
}

// NewConfigLoader creates a new configuration loader
func NewConfigLoader(configPath, environment string) *ConfigLoader {
    return &ConfigLoader{
        configPath:  configPath,
        environment: environment,
    }
}

// LoadConfig loads the complete configuration from the YAML file
func (c *ConfigLoader) LoadConfig() (*ModelConfig, error) {
    data, err := os.ReadFile(c.configPath)
    if err != nil {
        return nil, fmt.Errorf("failed to read config file %s: %w", c.configPath, err)
    }

    // Expand environment variables in the config
    configData := c.expandEnvVars(string(data))

    var config ModelConfig
    if err := yaml.Unmarshal([]byte(configData), &config); err != nil {
        return nil, fmt.Errorf("failed to parse config file %s: %w", c.configPath, err)
    }

    // Apply environment-specific overrides
    if c.environment != "" {
        c.applyEnvironmentOverrides(&config)
    }

    // Validate the configuration
    if err := c.validateConfig(&config); err != nil {
        return nil, fmt.Errorf("invalid configuration: %w", err)
    }

    return &config, nil
}

// LoadProviderFactory creates a provider factory from the configuration
func (c *ConfigLoader) LoadProviderFactory() (*ProviderFactory, error) {
    config, err := c.LoadConfig()
    if err != nil {
        return nil, err
    }

    factory := NewProviderFactory()

    // Register all providers
    for name, providerConfig := range config.Providers {
        if err := factory.RegisterProvider(name, providerConfig); err != nil {
            // Log a warning but continue with the other providers
            fmt.Printf("Warning: Failed to register provider %s: %v\n", name, err)
            continue
        }
    }

    // Set up role mapping
    roleMapping := RoleModelMapping{
        DefaultProvider:  config.DefaultProvider,
        FallbackProvider: config.FallbackProvider,
        Roles:            config.Roles,
    }
    factory.SetRoleMapping(roleMapping)

    return factory, nil
}

// expandEnvVars expands ${VAR}-style environment variables in the configuration
// (the bare $VAR form is not handled)
func (c *ConfigLoader) expandEnvVars(config string) string {
    expanded := config

    // Handle the ${VAR} pattern
    for {
        start := strings.Index(expanded, "${")
        if start == -1 {
            break
        }
        end := strings.Index(expanded[start:], "}")
        if end == -1 {
            break
        }
        end += start

        varName := expanded[start+2 : end]
        varValue := os.Getenv(varName)
        expanded = expanded[:start] + varValue + expanded[end+1:]
    }

    return expanded
}
// applyEnvironmentOverrides applies environment-specific configuration overrides
func (c *ConfigLoader) applyEnvironmentOverrides(config *ModelConfig) {
    envConfig, exists := config.Environments[c.environment]
    if !exists {
        return
    }

    // Override default and fallback providers if specified
    if envConfig.DefaultProvider != "" {
        config.DefaultProvider = envConfig.DefaultProvider
    }
    if envConfig.FallbackProvider != "" {
        config.FallbackProvider = envConfig.FallbackProvider
    }
}

// validateConfig validates the loaded configuration
func (c *ConfigLoader) validateConfig(config *ModelConfig) error {
    // Check that the default provider exists
    if config.DefaultProvider != "" {
        if _, exists := config.Providers[config.DefaultProvider]; !exists {
            return fmt.Errorf("default_provider '%s' not found in providers", config.DefaultProvider)
        }
    }

    // Check that the fallback provider exists
    if config.FallbackProvider != "" {
        if _, exists := config.Providers[config.FallbackProvider]; !exists {
            return fmt.Errorf("fallback_provider '%s' not found in providers", config.FallbackProvider)
        }
    }

    // Validate each provider configuration
    for name, providerConfig := range config.Providers {
        if err := c.validateProviderConfig(name, providerConfig); err != nil {
            return fmt.Errorf("invalid provider config '%s': %w", name, err)
        }
    }

    // Validate role configurations
    for roleName, roleConfig := range config.Roles {
        if err := c.validateRoleConfig(roleName, roleConfig, config.Providers); err != nil {
            return fmt.Errorf("invalid role config '%s': %w", roleName, err)
        }
    }

    return nil
}

// validateProviderConfig validates a single provider configuration
func (c *ConfigLoader) validateProviderConfig(name string, config ProviderConfig) error {
    // Check required fields
    if config.Type == "" {
        return fmt.Errorf("type is required")
    }

    // Validate provider type
    validTypes := []string{"ollama", "openai", "resetdata"}
    typeValid := false
    for _, validType := range validTypes {
        if config.Type == validType {
            typeValid = true
            break
        }
    }
    if !typeValid {
        return fmt.Errorf("invalid provider type '%s', must be one of: %s",
            config.Type, strings.Join(validTypes, ", "))
    }

    // Check endpoint for all types
    if config.Endpoint == "" {
        return fmt.Errorf("endpoint is required")
    }

    // Check API key for providers that require it
    if (config.Type == "openai" || config.Type == "resetdata") && config.APIKey == "" {
        return fmt.Errorf("api_key is required for %s provider", config.Type)
    }

    // Check default model
    if config.DefaultModel == "" {
        return fmt.Errorf("default_model is required")
    }

    // Validate timeout (note: config is passed by value here, so this
    // default only applies to the local copy, not the stored config)
    if config.Timeout == 0 {
        config.Timeout = 300 * time.Second // Set default
    }

    // Validate temperature range
    if config.Temperature < 0 || config.Temperature > 2.0 {
        return fmt.Errorf("temperature must be between 0 and 2.0")
    }

    // Validate max tokens (same caveat: default set on the local copy)
    if config.MaxTokens <= 0 {
        config.MaxTokens = 4096 // Set default
    }

    return nil
}

// validateRoleConfig validates a role configuration
func (c *ConfigLoader) validateRoleConfig(roleName string, config RoleConfig, providers map[string]ProviderConfig) error {
    // Check that the provider exists
    if config.Provider != "" {
        if _, exists := providers[config.Provider]; !exists {
            return fmt.Errorf("provider '%s' not found", config.Provider)
        }
    }

    // Check that the fallback provider exists if specified
    if config.FallbackProvider != "" {
        if _, exists := providers[config.FallbackProvider]; !exists {
            return fmt.Errorf("fallback_provider '%s' not found", config.FallbackProvider)
        }
    }

    // Validate temperature range
    if config.Temperature < 0 || config.Temperature > 2.0 {
        return fmt.Errorf("temperature must be between 0 and 2.0")
    }

    // Validate max tokens
    if config.MaxTokens < 0 {
        return fmt.Errorf("max_tokens cannot be negative")
    }

    return nil
}

// GetProviderForTaskType returns the best provider for a specific task type
func (c *ConfigLoader) GetProviderForTaskType(config *ModelConfig, factory *ProviderFactory, taskType string) (ModelProvider, ProviderConfig, error) {
    // Check if we have preferences for this task type
    if preference, exists := config.ModelPreferences[taskType]; exists {
        // Try each preferred model in order
        for _, modelName := range preference.PreferredModels {
            for providerName, provider := range factory.providers {
                capabilities := provider.GetCapabilities()
                for _, supportedModel := range capabilities.SupportedModels {
                    if supportedModel == modelName && factory.isProviderHealthy(providerName) {
                        providerConfig := factory.configs[providerName]
                        providerConfig.DefaultModel = modelName

                        // Ensure minimum context if specified
                        if preference.MinContextTokens > providerConfig.MaxTokens {
                            providerConfig.MaxTokens = preference.MinContextTokens
                        }

                        return provider, providerConfig, nil
                    }
                }
            }
        }
    }

    // Fall back to default provider selection
    if config.DefaultProvider != "" {
        provider, err := factory.GetProvider(config.DefaultProvider)
        if err != nil {
            return nil, ProviderConfig{}, err
        }
        return provider, factory.configs[config.DefaultProvider], nil
    }

    return nil, ProviderConfig{}, NewProviderError(ErrProviderNotFound, "no suitable provider found for task type "+taskType)
}
|
||||
|
||||
// DefaultConfigPath returns the default path for the model configuration file
|
||||
func DefaultConfigPath() string {
|
||||
// Try environment variable first
|
||||
if path := os.Getenv("CHORUS_MODEL_CONFIG"); path != "" {
|
||||
return path
|
||||
}
|
||||
|
||||
// Try relative to current working directory
|
||||
if _, err := os.Stat("configs/models.yaml"); err == nil {
|
||||
return "configs/models.yaml"
|
||||
}
|
||||
|
||||
// Try relative to executable
|
||||
if _, err := os.Stat("./configs/models.yaml"); err == nil {
|
||||
return "./configs/models.yaml"
|
||||
}
|
||||
|
||||
// Default fallback
|
||||
return "configs/models.yaml"
|
||||
}
|
||||
|
||||
// GetEnvironment returns the current environment (from env var or default)
|
||||
func GetEnvironment() string {
|
||||
if env := os.Getenv("CHORUS_ENVIRONMENT"); env != "" {
|
||||
return env
|
||||
}
|
||||
if env := os.Getenv("NODE_ENV"); env != "" {
|
||||
return env
|
||||
}
|
||||
return "development" // default
|
||||
}
|
||||
596
pkg/ai/config_test.go
Normal file
@@ -0,0 +1,596 @@
package ai

import (
	"os"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestNewConfigLoader(t *testing.T) {
	loader := NewConfigLoader("test.yaml", "development")

	assert.Equal(t, "test.yaml", loader.configPath)
	assert.Equal(t, "development", loader.environment)
}

func TestConfigLoaderExpandEnvVars(t *testing.T) {
	loader := NewConfigLoader("", "")

	// Set test environment variables
	os.Setenv("TEST_VAR", "test_value")
	os.Setenv("ANOTHER_VAR", "another_value")
	defer func() {
		os.Unsetenv("TEST_VAR")
		os.Unsetenv("ANOTHER_VAR")
	}()

	tests := []struct {
		name     string
		input    string
		expected string
	}{
		{
			name:     "single variable",
			input:    "endpoint: ${TEST_VAR}",
			expected: "endpoint: test_value",
		},
		{
			name:     "multiple variables",
			input:    "endpoint: ${TEST_VAR}/api\nkey: ${ANOTHER_VAR}",
			expected: "endpoint: test_value/api\nkey: another_value",
		},
		{
			name:     "no variables",
			input:    "endpoint: http://localhost",
			expected: "endpoint: http://localhost",
		},
		{
			name:     "undefined variable",
			input:    "endpoint: ${UNDEFINED_VAR}",
			expected: "endpoint: ",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := loader.expandEnvVars(tt.input)
			assert.Equal(t, tt.expected, result)
		})
	}
}

func TestConfigLoaderApplyEnvironmentOverrides(t *testing.T) {
	loader := NewConfigLoader("", "production")

	config := &ModelConfig{
		DefaultProvider:  "ollama",
		FallbackProvider: "resetdata",
		Environments: map[string]EnvConfig{
			"production": {
				DefaultProvider:  "openai",
				FallbackProvider: "ollama",
			},
			"development": {
				DefaultProvider:  "ollama",
				FallbackProvider: "mock",
			},
		},
	}

	loader.applyEnvironmentOverrides(config)

	assert.Equal(t, "openai", config.DefaultProvider)
	assert.Equal(t, "ollama", config.FallbackProvider)
}

func TestConfigLoaderApplyEnvironmentOverridesNoMatch(t *testing.T) {
	loader := NewConfigLoader("", "testing")

	config := &ModelConfig{
		DefaultProvider:  "ollama",
		FallbackProvider: "resetdata",
		Environments: map[string]EnvConfig{
			"production": {
				DefaultProvider: "openai",
			},
		},
	}

	original := *config
	loader.applyEnvironmentOverrides(config)

	// Should remain unchanged
	assert.Equal(t, original.DefaultProvider, config.DefaultProvider)
	assert.Equal(t, original.FallbackProvider, config.FallbackProvider)
}

func TestConfigLoaderValidateConfig(t *testing.T) {
	loader := NewConfigLoader("", "")

	tests := []struct {
		name      string
		config    *ModelConfig
		expectErr bool
		errMsg    string
	}{
		{
			name: "valid config",
			config: &ModelConfig{
				DefaultProvider:  "test",
				FallbackProvider: "backup",
				Providers: map[string]ProviderConfig{
					"test": {
						Type:         "ollama",
						Endpoint:     "http://localhost:11434",
						DefaultModel: "llama2",
					},
					"backup": {
						Type:         "resetdata",
						Endpoint:     "https://api.resetdata.ai",
						APIKey:       "key",
						DefaultModel: "llama2",
					},
				},
				Roles: map[string]RoleConfig{
					"developer": {
						Provider: "test",
					},
				},
			},
			expectErr: false,
		},
		{
			name: "default provider not found",
			config: &ModelConfig{
				DefaultProvider: "nonexistent",
				Providers: map[string]ProviderConfig{
					"test": {
						Type:         "ollama",
						Endpoint:     "http://localhost:11434",
						DefaultModel: "llama2",
					},
				},
			},
			expectErr: true,
			errMsg:    "default_provider 'nonexistent' not found",
		},
		{
			name: "fallback provider not found",
			config: &ModelConfig{
				FallbackProvider: "nonexistent",
				Providers: map[string]ProviderConfig{
					"test": {
						Type:         "ollama",
						Endpoint:     "http://localhost:11434",
						DefaultModel: "llama2",
					},
				},
			},
			expectErr: true,
			errMsg:    "fallback_provider 'nonexistent' not found",
		},
		{
			name: "invalid provider config",
			config: &ModelConfig{
				Providers: map[string]ProviderConfig{
					"invalid": {
						Type: "invalid_type",
					},
				},
			},
			expectErr: true,
			errMsg:    "invalid provider config 'invalid'",
		},
		{
			name: "invalid role config",
			config: &ModelConfig{
				Providers: map[string]ProviderConfig{
					"test": {
						Type:         "ollama",
						Endpoint:     "http://localhost:11434",
						DefaultModel: "llama2",
					},
				},
				Roles: map[string]RoleConfig{
					"developer": {
						Provider: "nonexistent",
					},
				},
			},
			expectErr: true,
			errMsg:    "invalid role config 'developer'",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := loader.validateConfig(tt.config)

			if tt.expectErr {
				require.Error(t, err)
				assert.Contains(t, err.Error(), tt.errMsg)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}

func TestConfigLoaderValidateProviderConfig(t *testing.T) {
	loader := NewConfigLoader("", "")

	tests := []struct {
		name      string
		config    ProviderConfig
		expectErr bool
		errMsg    string
	}{
		{
			name: "valid ollama config",
			config: ProviderConfig{
				Type:         "ollama",
				Endpoint:     "http://localhost:11434",
				DefaultModel: "llama2",
				Temperature:  0.7,
				MaxTokens:    4096,
			},
			expectErr: false,
		},
		{
			name: "valid openai config",
			config: ProviderConfig{
				Type:         "openai",
				Endpoint:     "https://api.openai.com/v1",
				APIKey:       "test-key",
				DefaultModel: "gpt-4",
			},
			expectErr: false,
		},
		{
			name: "missing type",
			config: ProviderConfig{
				Endpoint: "http://localhost",
			},
			expectErr: true,
			errMsg:    "type is required",
		},
		{
			name: "invalid type",
			config: ProviderConfig{
				Type:     "invalid",
				Endpoint: "http://localhost",
			},
			expectErr: true,
			errMsg:    "invalid provider type 'invalid'",
		},
		{
			name: "missing endpoint",
			config: ProviderConfig{
				Type: "ollama",
			},
			expectErr: true,
			errMsg:    "endpoint is required",
		},
		{
			name: "openai missing api key",
			config: ProviderConfig{
				Type:         "openai",
				Endpoint:     "https://api.openai.com/v1",
				DefaultModel: "gpt-4",
			},
			expectErr: true,
			errMsg:    "api_key is required for openai provider",
		},
		{
			name: "missing default model",
			config: ProviderConfig{
				Type:     "ollama",
				Endpoint: "http://localhost:11434",
			},
			expectErr: true,
			errMsg:    "default_model is required",
		},
		{
			name: "invalid temperature",
			config: ProviderConfig{
				Type:         "ollama",
				Endpoint:     "http://localhost:11434",
				DefaultModel: "llama2",
				Temperature:  3.0, // Too high
			},
			expectErr: true,
			errMsg:    "temperature must be between 0 and 2.0",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := loader.validateProviderConfig("test", tt.config)

			if tt.expectErr {
				require.Error(t, err)
				assert.Contains(t, err.Error(), tt.errMsg)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}

func TestConfigLoaderValidateRoleConfig(t *testing.T) {
	loader := NewConfigLoader("", "")

	providers := map[string]ProviderConfig{
		"test": {
			Type: "ollama",
		},
		"backup": {
			Type: "resetdata",
		},
	}

	tests := []struct {
		name      string
		config    RoleConfig
		expectErr bool
		errMsg    string
	}{
		{
			name: "valid role config",
			config: RoleConfig{
				Provider:    "test",
				Model:       "llama2",
				Temperature: 0.7,
				MaxTokens:   4096,
			},
			expectErr: false,
		},
		{
			name: "provider not found",
			config: RoleConfig{
				Provider: "nonexistent",
			},
			expectErr: true,
			errMsg:    "provider 'nonexistent' not found",
		},
		{
			name: "fallback provider not found",
			config: RoleConfig{
				Provider:         "test",
				FallbackProvider: "nonexistent",
			},
			expectErr: true,
			errMsg:    "fallback_provider 'nonexistent' not found",
		},
		{
			name: "invalid temperature",
			config: RoleConfig{
				Provider:    "test",
				Temperature: -1.0,
			},
			expectErr: true,
			errMsg:    "temperature must be between 0 and 2.0",
		},
		{
			name: "invalid max tokens",
			config: RoleConfig{
				Provider:  "test",
				MaxTokens: -100,
			},
			expectErr: true,
			errMsg:    "max_tokens cannot be negative",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := loader.validateRoleConfig("test-role", tt.config, providers)

			if tt.expectErr {
				require.Error(t, err)
				assert.Contains(t, err.Error(), tt.errMsg)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}

func TestConfigLoaderLoadConfig(t *testing.T) {
	// Create a temporary config file
	configContent := `
providers:
  test:
    type: ollama
    endpoint: http://localhost:11434
    default_model: llama2
    temperature: 0.7

default_provider: test
fallback_provider: test

roles:
  developer:
    provider: test
    model: codellama
`

	tmpFile, err := os.CreateTemp("", "test-config-*.yaml")
	require.NoError(t, err)
	defer os.Remove(tmpFile.Name())

	_, err = tmpFile.WriteString(configContent)
	require.NoError(t, err)
	tmpFile.Close()

	loader := NewConfigLoader(tmpFile.Name(), "")
	config, err := loader.LoadConfig()

	require.NoError(t, err)
	assert.Equal(t, "test", config.DefaultProvider)
	assert.Equal(t, "test", config.FallbackProvider)
	assert.Len(t, config.Providers, 1)
	assert.Contains(t, config.Providers, "test")
	assert.Equal(t, "ollama", config.Providers["test"].Type)
	assert.Len(t, config.Roles, 1)
	assert.Contains(t, config.Roles, "developer")
	assert.Equal(t, "codellama", config.Roles["developer"].Model)
}

func TestConfigLoaderLoadConfigWithEnvVars(t *testing.T) {
	// Set environment variables
	os.Setenv("TEST_ENDPOINT", "http://test.example.com")
	os.Setenv("TEST_MODEL", "test-model")
	defer func() {
		os.Unsetenv("TEST_ENDPOINT")
		os.Unsetenv("TEST_MODEL")
	}()

	configContent := `
providers:
  test:
    type: ollama
    endpoint: ${TEST_ENDPOINT}
    default_model: ${TEST_MODEL}

default_provider: test
`

	tmpFile, err := os.CreateTemp("", "test-config-*.yaml")
	require.NoError(t, err)
	defer os.Remove(tmpFile.Name())

	_, err = tmpFile.WriteString(configContent)
	require.NoError(t, err)
	tmpFile.Close()

	loader := NewConfigLoader(tmpFile.Name(), "")
	config, err := loader.LoadConfig()

	require.NoError(t, err)
	assert.Equal(t, "http://test.example.com", config.Providers["test"].Endpoint)
	assert.Equal(t, "test-model", config.Providers["test"].DefaultModel)
}

func TestConfigLoaderLoadConfigFileNotFound(t *testing.T) {
	loader := NewConfigLoader("nonexistent.yaml", "")
	_, err := loader.LoadConfig()

	require.Error(t, err)
	assert.Contains(t, err.Error(), "failed to read config file")
}

func TestConfigLoaderLoadConfigInvalidYAML(t *testing.T) {
	// Create a file with invalid YAML
	tmpFile, err := os.CreateTemp("", "invalid-config-*.yaml")
	require.NoError(t, err)
	defer os.Remove(tmpFile.Name())

	_, err = tmpFile.WriteString("invalid: yaml: content: [")
	require.NoError(t, err)
	tmpFile.Close()

	loader := NewConfigLoader(tmpFile.Name(), "")
	_, err = loader.LoadConfig()

	require.Error(t, err)
	assert.Contains(t, err.Error(), "failed to parse config file")
}

func TestDefaultConfigPath(t *testing.T) {
	// Test with environment variable
	os.Setenv("CHORUS_MODEL_CONFIG", "/custom/path/models.yaml")
	defer os.Unsetenv("CHORUS_MODEL_CONFIG")

	path := DefaultConfigPath()
	assert.Equal(t, "/custom/path/models.yaml", path)

	// Test without environment variable
	os.Unsetenv("CHORUS_MODEL_CONFIG")
	path = DefaultConfigPath()
	assert.Equal(t, "configs/models.yaml", path)
}

func TestGetEnvironment(t *testing.T) {
	// Test with CHORUS_ENVIRONMENT
	os.Setenv("CHORUS_ENVIRONMENT", "production")
	defer os.Unsetenv("CHORUS_ENVIRONMENT")

	env := GetEnvironment()
	assert.Equal(t, "production", env)

	// Test with NODE_ENV fallback
	os.Unsetenv("CHORUS_ENVIRONMENT")
	os.Setenv("NODE_ENV", "staging")
	defer os.Unsetenv("NODE_ENV")

	env = GetEnvironment()
	assert.Equal(t, "staging", env)

	// Test default
	os.Unsetenv("CHORUS_ENVIRONMENT")
	os.Unsetenv("NODE_ENV")

	env = GetEnvironment()
	assert.Equal(t, "development", env)
}

func TestModelConfig(t *testing.T) {
	config := ModelConfig{
		Providers: map[string]ProviderConfig{
			"test": {
				Type:         "ollama",
				Endpoint:     "http://localhost:11434",
				DefaultModel: "llama2",
			},
		},
		DefaultProvider:  "test",
		FallbackProvider: "test",
		Roles: map[string]RoleConfig{
			"developer": {
				Provider: "test",
				Model:    "codellama",
			},
		},
		Environments: map[string]EnvConfig{
			"production": {
				DefaultProvider: "openai",
			},
		},
		ModelPreferences: map[string]TaskPreference{
			"code_generation": {
				PreferredModels:  []string{"codellama", "gpt-4"},
				MinContextTokens: 8192,
			},
		},
	}

	assert.Len(t, config.Providers, 1)
	assert.Len(t, config.Roles, 1)
	assert.Len(t, config.Environments, 1)
	assert.Len(t, config.ModelPreferences, 1)
}

func TestEnvConfig(t *testing.T) {
	envConfig := EnvConfig{
		DefaultProvider:  "openai",
		FallbackProvider: "ollama",
	}

	assert.Equal(t, "openai", envConfig.DefaultProvider)
	assert.Equal(t, "ollama", envConfig.FallbackProvider)
}

func TestTaskPreference(t *testing.T) {
	pref := TaskPreference{
		PreferredModels:  []string{"gpt-4", "codellama:13b"},
		MinContextTokens: 8192,
	}

	assert.Len(t, pref.PreferredModels, 2)
	assert.Equal(t, 8192, pref.MinContextTokens)
	assert.Contains(t, pref.PreferredModels, "gpt-4")
}
392
pkg/ai/factory.go
Normal file
@@ -0,0 +1,392 @@
package ai

import (
	"context"
	"fmt"
	"time"
)

// ProviderFactory creates and manages AI model providers
type ProviderFactory struct {
	configs         map[string]ProviderConfig // provider name -> config
	providers       map[string]ModelProvider  // provider name -> instance
	roleMapping     RoleModelMapping          // role-based model selection
	healthChecks    map[string]bool           // provider name -> health status
	lastHealthCheck map[string]time.Time      // provider name -> last check time

	CreateProvider func(config ProviderConfig) (ModelProvider, error) // provider creation function
}

// NewProviderFactory creates a new provider factory
func NewProviderFactory() *ProviderFactory {
	factory := &ProviderFactory{
		configs:         make(map[string]ProviderConfig),
		providers:       make(map[string]ModelProvider),
		healthChecks:    make(map[string]bool),
		lastHealthCheck: make(map[string]time.Time),
	}
	factory.CreateProvider = factory.defaultCreateProvider
	return factory
}

// RegisterProvider registers a provider configuration
func (f *ProviderFactory) RegisterProvider(name string, config ProviderConfig) error {
	// Validate the configuration
	provider, err := f.CreateProvider(config)
	if err != nil {
		return fmt.Errorf("failed to create provider %s: %w", name, err)
	}

	if err := provider.ValidateConfig(); err != nil {
		return fmt.Errorf("invalid configuration for provider %s: %w", name, err)
	}

	f.configs[name] = config
	f.providers[name] = provider
	f.healthChecks[name] = true
	f.lastHealthCheck[name] = time.Now()

	return nil
}

// SetRoleMapping sets the role-to-model mapping configuration
func (f *ProviderFactory) SetRoleMapping(mapping RoleModelMapping) {
	f.roleMapping = mapping
}

// GetProvider returns a provider by name
func (f *ProviderFactory) GetProvider(name string) (ModelProvider, error) {
	provider, exists := f.providers[name]
	if !exists {
		return nil, NewProviderError(ErrProviderNotFound, fmt.Sprintf("provider %s not found", name))
	}

	// Check health status
	if !f.isProviderHealthy(name) {
		return nil, NewProviderError(ErrProviderUnavailable, fmt.Sprintf("provider %s is unhealthy", name))
	}

	return provider, nil
}

// GetProviderForRole returns the best provider for a specific agent role
func (f *ProviderFactory) GetProviderForRole(role string) (ModelProvider, ProviderConfig, error) {
	// Get role configuration
	roleConfig, exists := f.roleMapping.Roles[role]
	if !exists {
		// Fall back to default provider
		if f.roleMapping.DefaultProvider != "" {
			return f.getProviderWithFallback(f.roleMapping.DefaultProvider, f.roleMapping.FallbackProvider)
		}
		return nil, ProviderConfig{}, NewProviderError(ErrProviderNotFound, fmt.Sprintf("no provider configured for role %s", role))
	}

	// Try primary provider first
	provider, config, err := f.getProviderWithFallback(roleConfig.Provider, roleConfig.FallbackProvider)
	if err != nil {
		// Try role fallback
		if roleConfig.FallbackProvider != "" {
			return f.getProviderWithFallback(roleConfig.FallbackProvider, f.roleMapping.FallbackProvider)
		}
		// Try global fallback
		if f.roleMapping.FallbackProvider != "" {
			return f.getProviderWithFallback(f.roleMapping.FallbackProvider, "")
		}
		return nil, ProviderConfig{}, err
	}

	// Merge role-specific configuration
	mergedConfig := f.mergeRoleConfig(config, roleConfig)
	return provider, mergedConfig, nil
}

// GetProviderForTask returns the best provider for a specific task
func (f *ProviderFactory) GetProviderForTask(request *TaskRequest) (ModelProvider, ProviderConfig, error) {
	// Check if a specific model is requested
	if request.ModelName != "" {
		// Find a provider that supports the requested model
		for name, provider := range f.providers {
			capabilities := provider.GetCapabilities()
			for _, supportedModel := range capabilities.SupportedModels {
				if supportedModel == request.ModelName {
					if f.isProviderHealthy(name) {
						config := f.configs[name]
						config.DefaultModel = request.ModelName // Override default model
						return provider, config, nil
					}
				}
			}
		}
		return nil, ProviderConfig{}, NewProviderError(ErrModelNotSupported, fmt.Sprintf("model %s not available", request.ModelName))
	}

	// Use role-based selection
	return f.GetProviderForRole(request.AgentRole)
}

// ListProviders returns all registered provider names
func (f *ProviderFactory) ListProviders() []string {
	var names []string
	for name := range f.providers {
		names = append(names, name)
	}
	return names
}

// ListHealthyProviders returns only healthy provider names
func (f *ProviderFactory) ListHealthyProviders() []string {
	var names []string
	for name := range f.providers {
		if f.isProviderHealthy(name) {
			names = append(names, name)
		}
	}
	return names
}

// GetProviderInfo returns information about all registered providers
func (f *ProviderFactory) GetProviderInfo() map[string]ProviderInfo {
	info := make(map[string]ProviderInfo)
	for name, provider := range f.providers {
		providerInfo := provider.GetProviderInfo()
		providerInfo.Name = name // Override with registered name
		info[name] = providerInfo
	}
	return info
}

// HealthCheck performs health checks on all providers
func (f *ProviderFactory) HealthCheck(ctx context.Context) map[string]error {
	results := make(map[string]error)

	for name, provider := range f.providers {
		err := f.checkProviderHealth(ctx, name, provider)
		results[name] = err
		f.healthChecks[name] = (err == nil)
		f.lastHealthCheck[name] = time.Now()
	}

	return results
}

// GetHealthStatus returns the current health status of all providers
func (f *ProviderFactory) GetHealthStatus() map[string]ProviderHealth {
	status := make(map[string]ProviderHealth)

	for name, provider := range f.providers {
		status[name] = ProviderHealth{
			Name:         name,
			Healthy:      f.healthChecks[name],
			LastCheck:    f.lastHealthCheck[name],
			ProviderInfo: provider.GetProviderInfo(),
			Capabilities: provider.GetCapabilities(),
		}
	}

	return status
}

// StartHealthCheckRoutine starts a background health check routine
func (f *ProviderFactory) StartHealthCheckRoutine(ctx context.Context, interval time.Duration) {
	if interval == 0 {
		interval = 5 * time.Minute // Default to 5 minutes
	}

	ticker := time.NewTicker(interval)
	go func() {
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				healthCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
				f.HealthCheck(healthCtx)
				cancel()
			}
		}
	}()
}

// defaultCreateProvider creates a provider instance based on configuration
func (f *ProviderFactory) defaultCreateProvider(config ProviderConfig) (ModelProvider, error) {
	switch config.Type {
	case "ollama":
		return NewOllamaProvider(config), nil
	case "openai":
		return NewOpenAIProvider(config), nil
	case "resetdata":
		return NewResetDataProvider(config), nil
	default:
		return nil, NewProviderError(ErrProviderNotFound, fmt.Sprintf("unknown provider type: %s", config.Type))
	}
}

// getProviderWithFallback attempts to get a provider with fallback support
func (f *ProviderFactory) getProviderWithFallback(primaryName, fallbackName string) (ModelProvider, ProviderConfig, error) {
	// Try primary provider
	if primaryName != "" {
		if provider, exists := f.providers[primaryName]; exists && f.isProviderHealthy(primaryName) {
			return provider, f.configs[primaryName], nil
		}
	}

	// Try fallback provider
	if fallbackName != "" {
		if provider, exists := f.providers[fallbackName]; exists && f.isProviderHealthy(fallbackName) {
			return provider, f.configs[fallbackName], nil
		}
	}

	if primaryName != "" {
		return nil, ProviderConfig{}, NewProviderError(ErrProviderUnavailable, fmt.Sprintf("provider %s and fallback %s are unavailable", primaryName, fallbackName))
	}

	return nil, ProviderConfig{}, NewProviderError(ErrProviderNotFound, "no provider specified")
}

// mergeRoleConfig merges role-specific configuration with provider configuration
func (f *ProviderFactory) mergeRoleConfig(baseConfig ProviderConfig, roleConfig RoleConfig) ProviderConfig {
	merged := baseConfig

	// Override model if specified in role config
	if roleConfig.Model != "" {
		merged.DefaultModel = roleConfig.Model
	}

	// Override temperature if specified
	if roleConfig.Temperature > 0 {
		merged.Temperature = roleConfig.Temperature
	}

	// Override max tokens if specified
	if roleConfig.MaxTokens > 0 {
		merged.MaxTokens = roleConfig.MaxTokens
	}

	// Override tool settings
	if roleConfig.EnableTools {
		merged.EnableTools = roleConfig.EnableTools
	}
	if roleConfig.EnableMCP {
		merged.EnableMCP = roleConfig.EnableMCP
	}

	// Merge MCP servers
	if len(roleConfig.MCPServers) > 0 {
		merged.MCPServers = append(merged.MCPServers, roleConfig.MCPServers...)
	}

	return merged
}

// isProviderHealthy checks if a provider is currently healthy
func (f *ProviderFactory) isProviderHealthy(name string) bool {
	healthy, exists := f.healthChecks[name]
	if !exists {
		return false
	}

	// Check if the health check is too old (consider unhealthy if >10 minutes old)
	lastCheck, exists := f.lastHealthCheck[name]
	if !exists || time.Since(lastCheck) > 10*time.Minute {
		return false
	}

	return healthy
}

// checkProviderHealth performs a health check on a specific provider
func (f *ProviderFactory) checkProviderHealth(ctx context.Context, name string, provider ModelProvider) error {
	// Create a minimal health check request
	healthRequest := &TaskRequest{
		TaskID:          "health-check",
		AgentID:         "health-checker",
		AgentRole:       "system",
		Repository:      "health-check",
		TaskTitle:       "Health Check",
		TaskDescription: "Simple health check task",
		ModelName:       "", // Use default
		MaxTokens:       50, // Minimal response
		EnableTools:     false,
	}

	// Set a short timeout for health checks
	healthCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()

	_, err := provider.ExecuteTask(healthCtx, healthRequest)
	return err
}

// ProviderHealth represents the health status of a provider
type ProviderHealth struct {
	Name         string               `json:"name"`
	Healthy      bool                 `json:"healthy"`
	LastCheck    time.Time            `json:"last_check"`
	ProviderInfo ProviderInfo         `json:"provider_info"`
	Capabilities ProviderCapabilities `json:"capabilities"`
}

// DefaultProviderFactory creates a factory with common provider configurations
func DefaultProviderFactory() *ProviderFactory {
	factory := NewProviderFactory()

	// Register default Ollama provider
	ollamaConfig := ProviderConfig{
		Type:          "ollama",
		Endpoint:      "http://localhost:11434",
		DefaultModel:  "llama3.1:8b",
		Temperature:   0.7,
		MaxTokens:     4096,
		Timeout:       300 * time.Second,
		RetryAttempts: 3,
		RetryDelay:    2 * time.Second,
		EnableTools:   true,
		EnableMCP:     true,
	}
	factory.RegisterProvider("ollama", ollamaConfig)

	// Set default role mapping
	defaultMapping := RoleModelMapping{
		DefaultProvider:  "ollama",
		FallbackProvider: "ollama",
		Roles: map[string]RoleConfig{
			"developer": {
				Provider:     "ollama",
				Model:        "codellama:13b",
				Temperature:  0.3,
				MaxTokens:    8192,
				EnableTools:  true,
				EnableMCP:    true,
				SystemPrompt: "You are an expert software developer focused on writing clean, maintainable, and well-tested code.",
			},
			"reviewer": {
				Provider:     "ollama",
				Model:        "llama3.1:8b",
				Temperature:  0.2,
				MaxTokens:    6144,
				EnableTools:  true,
				SystemPrompt: "You are a thorough code reviewer focused on quality, security, and best practices.",
			},
			"architect": {
				Provider:     "ollama",
				Model:        "llama3.1:13b",
				Temperature:  0.5,
				MaxTokens:    8192,
				EnableTools:  true,
				SystemPrompt: "You are a senior software architect focused on system design and technical decision making.",
			},
			"tester": {
				Provider:     "ollama",
				Model:        "codellama:7b",
				Temperature:  0.3,
				MaxTokens:    6144,
				EnableTools:  true,
				SystemPrompt: "You are a QA engineer focused on comprehensive testing and quality assurance.",
			},
		},
	}
	factory.SetRoleMapping(defaultMapping)

	return factory
}
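The `mergeRoleConfig` method in factory.go applies role-level overrides only when a field is set (non-zero), so unset role fields inherit the provider defaults. A self-contained sketch of that override pattern, using simplified stand-in types rather than the package's own:

```go
package main

import "fmt"

// cfg is a simplified stand-in for the provider/role config types.
type cfg struct {
	Model       string
	Temperature float64
	MaxTokens   int
}

// merge mirrors mergeRoleConfig: each role field replaces the base
// value only when it was explicitly set (non-empty / positive).
func merge(base, role cfg) cfg {
	merged := base
	if role.Model != "" {
		merged.Model = role.Model
	}
	if role.Temperature > 0 {
		merged.Temperature = role.Temperature
	}
	if role.MaxTokens > 0 {
		merged.MaxTokens = role.MaxTokens
	}
	return merged
}

func main() {
	base := cfg{Model: "llama3.1:8b", Temperature: 0.7, MaxTokens: 4096}
	role := cfg{Model: "codellama:13b", MaxTokens: 8192} // Temperature unset, inherits 0.7
	fmt.Printf("%+v\n", merge(base, role))
}
```

A side effect of using zero values as "unset": a role genuinely wanting `Temperature: 0` cannot express it, which is worth keeping in mind when extending `RoleConfig`.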
516
pkg/ai/factory_test.go
Normal file
@@ -0,0 +1,516 @@
package ai

import (
    "context"
    "testing"
    "time"

    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
)

func TestNewProviderFactory(t *testing.T) {
    factory := NewProviderFactory()

    assert.NotNil(t, factory)
    assert.Empty(t, factory.configs)
    assert.Empty(t, factory.providers)
    assert.Empty(t, factory.healthChecks)
    assert.Empty(t, factory.lastHealthCheck)
}

func TestProviderFactoryRegisterProvider(t *testing.T) {
    factory := NewProviderFactory()

    // Create a valid mock provider config (since validation will be called)
    config := ProviderConfig{
        Type:         "mock",
        Endpoint:     "mock://localhost",
        DefaultModel: "test-model",
        Temperature:  0.7,
        MaxTokens:    4096,
        Timeout:      300 * time.Second,
    }

    // Override CreateProvider to return our mock
    originalCreate := factory.CreateProvider
    factory.CreateProvider = func(config ProviderConfig) (ModelProvider, error) {
        return NewMockProvider("test-provider"), nil
    }
    defer func() { factory.CreateProvider = originalCreate }()

    err := factory.RegisterProvider("test", config)
    require.NoError(t, err)

    // Verify provider was registered
    assert.Len(t, factory.providers, 1)
    assert.Contains(t, factory.providers, "test")
    assert.True(t, factory.healthChecks["test"])
}

func TestProviderFactoryRegisterProviderValidationFailure(t *testing.T) {
    factory := NewProviderFactory()

    // Create a mock provider that will fail validation
    config := ProviderConfig{
        Type:         "mock",
        Endpoint:     "mock://localhost",
        DefaultModel: "test-model",
    }

    // Override CreateProvider to return a failing mock
    factory.CreateProvider = func(config ProviderConfig) (ModelProvider, error) {
        mock := NewMockProvider("failing-provider")
        mock.shouldFail = true // This will make ValidateConfig fail
        return mock, nil
    }

    err := factory.RegisterProvider("failing", config)
    require.Error(t, err)
    assert.Contains(t, err.Error(), "invalid configuration")

    // Verify provider was not registered
    assert.Empty(t, factory.providers)
}

func TestProviderFactoryGetProvider(t *testing.T) {
    factory := NewProviderFactory()
    mockProvider := NewMockProvider("test-provider")

    // Manually add provider and mark as healthy
    factory.providers["test"] = mockProvider
    factory.healthChecks["test"] = true
    factory.lastHealthCheck["test"] = time.Now()

    provider, err := factory.GetProvider("test")
    require.NoError(t, err)
    assert.Equal(t, mockProvider, provider)
}

func TestProviderFactoryGetProviderNotFound(t *testing.T) {
    factory := NewProviderFactory()

    _, err := factory.GetProvider("nonexistent")
    require.Error(t, err)
    assert.IsType(t, &ProviderError{}, err)

    providerErr := err.(*ProviderError)
    assert.Equal(t, "PROVIDER_NOT_FOUND", providerErr.Code)
}

func TestProviderFactoryGetProviderUnhealthy(t *testing.T) {
    factory := NewProviderFactory()
    mockProvider := NewMockProvider("test-provider")

    // Add provider but mark as unhealthy
    factory.providers["test"] = mockProvider
    factory.healthChecks["test"] = false
    factory.lastHealthCheck["test"] = time.Now()

    _, err := factory.GetProvider("test")
    require.Error(t, err)
    assert.IsType(t, &ProviderError{}, err)

    providerErr := err.(*ProviderError)
    assert.Equal(t, "PROVIDER_UNAVAILABLE", providerErr.Code)
}

func TestProviderFactorySetRoleMapping(t *testing.T) {
    factory := NewProviderFactory()

    mapping := RoleModelMapping{
        DefaultProvider:  "test",
        FallbackProvider: "backup",
        Roles: map[string]RoleConfig{
            "developer": {
                Provider: "test",
                Model:    "dev-model",
            },
        },
    }

    factory.SetRoleMapping(mapping)

    assert.Equal(t, mapping, factory.roleMapping)
}

func TestProviderFactoryGetProviderForRole(t *testing.T) {
    factory := NewProviderFactory()

    // Set up providers
    devProvider := NewMockProvider("dev-provider")
    backupProvider := NewMockProvider("backup-provider")

    factory.providers["dev"] = devProvider
    factory.providers["backup"] = backupProvider
    factory.healthChecks["dev"] = true
    factory.healthChecks["backup"] = true
    factory.lastHealthCheck["dev"] = time.Now()
    factory.lastHealthCheck["backup"] = time.Now()

    factory.configs["dev"] = ProviderConfig{
        Type:         "mock",
        DefaultModel: "dev-model",
        Temperature:  0.7,
    }

    factory.configs["backup"] = ProviderConfig{
        Type:         "mock",
        DefaultModel: "backup-model",
        Temperature:  0.8,
    }

    // Set up role mapping
    mapping := RoleModelMapping{
        DefaultProvider:  "backup",
        FallbackProvider: "backup",
        Roles: map[string]RoleConfig{
            "developer": {
                Provider:    "dev",
                Model:       "custom-dev-model",
                Temperature: 0.3,
            },
        },
    }
    factory.SetRoleMapping(mapping)

    provider, config, err := factory.GetProviderForRole("developer")
    require.NoError(t, err)

    assert.Equal(t, devProvider, provider)
    assert.Equal(t, "custom-dev-model", config.DefaultModel)
    assert.Equal(t, float32(0.3), config.Temperature)
}

func TestProviderFactoryGetProviderForRoleWithFallback(t *testing.T) {
    factory := NewProviderFactory()

    // Set up only the backup provider (primary is missing)
    backupProvider := NewMockProvider("backup-provider")
    factory.providers["backup"] = backupProvider
    factory.healthChecks["backup"] = true
    factory.lastHealthCheck["backup"] = time.Now()
    factory.configs["backup"] = ProviderConfig{Type: "mock", DefaultModel: "backup-model"}

    // Set up role mapping with a primary provider that doesn't exist
    mapping := RoleModelMapping{
        DefaultProvider:  "backup",
        FallbackProvider: "backup",
        Roles: map[string]RoleConfig{
            "developer": {
                Provider:         "nonexistent",
                FallbackProvider: "backup",
            },
        },
    }
    factory.SetRoleMapping(mapping)

    provider, config, err := factory.GetProviderForRole("developer")
    require.NoError(t, err)

    assert.Equal(t, backupProvider, provider)
    assert.Equal(t, "backup-model", config.DefaultModel)
}

func TestProviderFactoryGetProviderForRoleNotFound(t *testing.T) {
    factory := NewProviderFactory()

    // No providers registered and no default
    mapping := RoleModelMapping{
        Roles: make(map[string]RoleConfig),
    }
    factory.SetRoleMapping(mapping)

    _, _, err := factory.GetProviderForRole("nonexistent")
    require.Error(t, err)
    assert.IsType(t, &ProviderError{}, err)
}

func TestProviderFactoryGetProviderForTask(t *testing.T) {
    factory := NewProviderFactory()

    // Set up a provider that supports a specific model
    mockProvider := NewMockProvider("test-provider")
    mockProvider.capabilities.SupportedModels = []string{"specific-model", "another-model"}

    factory.providers["test"] = mockProvider
    factory.healthChecks["test"] = true
    factory.lastHealthCheck["test"] = time.Now()
    factory.configs["test"] = ProviderConfig{Type: "mock", DefaultModel: "default-model"}

    request := &TaskRequest{
        TaskID:    "test-123",
        AgentRole: "developer",
        ModelName: "specific-model", // Request a specific model
    }

    provider, config, err := factory.GetProviderForTask(request)
    require.NoError(t, err)

    assert.Equal(t, mockProvider, provider)
    assert.Equal(t, "specific-model", config.DefaultModel) // Should override the default
}

func TestProviderFactoryGetProviderForTaskModelNotSupported(t *testing.T) {
    factory := NewProviderFactory()

    mockProvider := NewMockProvider("test-provider")
    mockProvider.capabilities.SupportedModels = []string{"model-1", "model-2"}

    factory.providers["test"] = mockProvider
    factory.healthChecks["test"] = true
    factory.lastHealthCheck["test"] = time.Now()

    request := &TaskRequest{
        TaskID:    "test-123",
        AgentRole: "developer",
        ModelName: "unsupported-model",
    }

    _, _, err := factory.GetProviderForTask(request)
    require.Error(t, err)
    assert.IsType(t, &ProviderError{}, err)

    providerErr := err.(*ProviderError)
    assert.Equal(t, "MODEL_NOT_SUPPORTED", providerErr.Code)
}

func TestProviderFactoryListProviders(t *testing.T) {
    factory := NewProviderFactory()

    // Add some mock providers
    factory.providers["provider1"] = NewMockProvider("provider1")
    factory.providers["provider2"] = NewMockProvider("provider2")
    factory.providers["provider3"] = NewMockProvider("provider3")

    providers := factory.ListProviders()

    assert.Len(t, providers, 3)
    assert.Contains(t, providers, "provider1")
    assert.Contains(t, providers, "provider2")
    assert.Contains(t, providers, "provider3")
}

func TestProviderFactoryListHealthyProviders(t *testing.T) {
    factory := NewProviderFactory()

    // Add providers with different health states
    factory.providers["healthy1"] = NewMockProvider("healthy1")
    factory.providers["healthy2"] = NewMockProvider("healthy2")
    factory.providers["unhealthy"] = NewMockProvider("unhealthy")

    factory.healthChecks["healthy1"] = true
    factory.healthChecks["healthy2"] = true
    factory.healthChecks["unhealthy"] = false

    factory.lastHealthCheck["healthy1"] = time.Now()
    factory.lastHealthCheck["healthy2"] = time.Now()
    factory.lastHealthCheck["unhealthy"] = time.Now()

    healthyProviders := factory.ListHealthyProviders()

    assert.Len(t, healthyProviders, 2)
    assert.Contains(t, healthyProviders, "healthy1")
    assert.Contains(t, healthyProviders, "healthy2")
    assert.NotContains(t, healthyProviders, "unhealthy")
}

func TestProviderFactoryGetProviderInfo(t *testing.T) {
    factory := NewProviderFactory()

    mock1 := NewMockProvider("mock1")
    mock2 := NewMockProvider("mock2")

    factory.providers["provider1"] = mock1
    factory.providers["provider2"] = mock2

    info := factory.GetProviderInfo()

    assert.Len(t, info, 2)
    assert.Contains(t, info, "provider1")
    assert.Contains(t, info, "provider2")

    // Verify that the name is overridden with the registered name
    assert.Equal(t, "provider1", info["provider1"].Name)
    assert.Equal(t, "provider2", info["provider2"].Name)
}

func TestProviderFactoryHealthCheck(t *testing.T) {
    factory := NewProviderFactory()

    // Add a healthy and an unhealthy provider
    healthyProvider := NewMockProvider("healthy")
    unhealthyProvider := NewMockProvider("unhealthy")
    unhealthyProvider.shouldFail = true

    factory.providers["healthy"] = healthyProvider
    factory.providers["unhealthy"] = unhealthyProvider

    ctx := context.Background()
    results := factory.HealthCheck(ctx)

    assert.Len(t, results, 2)
    assert.NoError(t, results["healthy"])
    assert.Error(t, results["unhealthy"])

    // Verify health states were updated
    assert.True(t, factory.healthChecks["healthy"])
    assert.False(t, factory.healthChecks["unhealthy"])
}

func TestProviderFactoryGetHealthStatus(t *testing.T) {
    factory := NewProviderFactory()

    mockProvider := NewMockProvider("test")
    factory.providers["test"] = mockProvider

    now := time.Now()
    factory.healthChecks["test"] = true
    factory.lastHealthCheck["test"] = now

    status := factory.GetHealthStatus()

    assert.Len(t, status, 1)
    assert.Contains(t, status, "test")

    testStatus := status["test"]
    assert.Equal(t, "test", testStatus.Name)
    assert.True(t, testStatus.Healthy)
    assert.Equal(t, now, testStatus.LastCheck)
}

func TestProviderFactoryIsProviderHealthy(t *testing.T) {
    factory := NewProviderFactory()

    // Test healthy provider
    factory.healthChecks["healthy"] = true
    factory.lastHealthCheck["healthy"] = time.Now()
    assert.True(t, factory.isProviderHealthy("healthy"))

    // Test unhealthy provider
    factory.healthChecks["unhealthy"] = false
    factory.lastHealthCheck["unhealthy"] = time.Now()
    assert.False(t, factory.isProviderHealthy("unhealthy"))

    // Test provider with old health check (should be considered unhealthy)
    factory.healthChecks["stale"] = true
    factory.lastHealthCheck["stale"] = time.Now().Add(-15 * time.Minute)
    assert.False(t, factory.isProviderHealthy("stale"))

    // Test non-existent provider
    assert.False(t, factory.isProviderHealthy("nonexistent"))
}

func TestProviderFactoryMergeRoleConfig(t *testing.T) {
    factory := NewProviderFactory()

    baseConfig := ProviderConfig{
        Type:         "test",
        DefaultModel: "base-model",
        Temperature:  0.7,
        MaxTokens:    4096,
        EnableTools:  false,
        EnableMCP:    false,
        MCPServers:   []string{"base-server"},
    }

    roleConfig := RoleConfig{
        Model:       "role-model",
        Temperature: 0.3,
        MaxTokens:   8192,
        EnableTools: true,
        EnableMCP:   true,
        MCPServers:  []string{"role-server"},
    }

    merged := factory.mergeRoleConfig(baseConfig, roleConfig)

    assert.Equal(t, "role-model", merged.DefaultModel)
    assert.Equal(t, float32(0.3), merged.Temperature)
    assert.Equal(t, 8192, merged.MaxTokens)
    assert.True(t, merged.EnableTools)
    assert.True(t, merged.EnableMCP)
    assert.Len(t, merged.MCPServers, 2) // Should be merged
    assert.Contains(t, merged.MCPServers, "base-server")
    assert.Contains(t, merged.MCPServers, "role-server")
}
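The merge semantics `TestProviderFactoryMergeRoleConfig` asserts — role-level values override the base config, while MCP server lists are concatenated rather than replaced — can be sketched with simplified local types (hypothetical stand-ins, not the package's real `ProviderConfig`/`RoleConfig`):

```go
package main

import "fmt"

// Minimal stand-ins for the base and role configs (illustration only).
type base struct {
    Model      string
    MaxTokens  int
    MCPServers []string
}

type override struct {
    Model      string
    MaxTokens  int
    MCPServers []string
}

// merge applies role overrides on top of the base: non-zero scalar
// fields win, and MCP server lists are appended (union of both).
func merge(b base, o override) base {
    if o.Model != "" {
        b.Model = o.Model
    }
    if o.MaxTokens > 0 {
        b.MaxTokens = o.MaxTokens
    }
    b.MCPServers = append(b.MCPServers, o.MCPServers...)
    return b
}

func main() {
    m := merge(
        base{Model: "base-model", MaxTokens: 4096, MCPServers: []string{"base-server"}},
        override{Model: "role-model", MaxTokens: 8192, MCPServers: []string{"role-server"}},
    )
    fmt.Println(m.Model, m.MaxTokens, len(m.MCPServers)) // role-model 8192 2
}
```

Appending rather than replacing the server list means a role can add capabilities without losing the base deployment's MCP servers.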

func TestDefaultProviderFactory(t *testing.T) {
    factory := DefaultProviderFactory()

    // Should have at least the default ollama provider
    providers := factory.ListProviders()
    assert.Contains(t, providers, "ollama")

    // Should have role mappings configured
    assert.NotEmpty(t, factory.roleMapping.Roles)
    assert.Contains(t, factory.roleMapping.Roles, "developer")
    assert.Contains(t, factory.roleMapping.Roles, "reviewer")

    // Test getting a provider for the developer role
    _, config, err := factory.GetProviderForRole("developer")
    require.NoError(t, err)
    assert.Equal(t, "codellama:13b", config.DefaultModel)
    assert.Equal(t, float32(0.3), config.Temperature)
}

func TestProviderFactoryCreateProvider(t *testing.T) {
    factory := NewProviderFactory()

    tests := []struct {
        name      string
        config    ProviderConfig
        expectErr bool
    }{
        {
            name: "ollama provider",
            config: ProviderConfig{
                Type:         "ollama",
                Endpoint:     "http://localhost:11434",
                DefaultModel: "llama2",
            },
            expectErr: false,
        },
        {
            name: "openai provider",
            config: ProviderConfig{
                Type:         "openai",
                Endpoint:     "https://api.openai.com/v1",
                APIKey:       "test-key",
                DefaultModel: "gpt-4",
            },
            expectErr: false,
        },
        {
            name: "resetdata provider",
            config: ProviderConfig{
                Type:         "resetdata",
                Endpoint:     "https://api.resetdata.ai",
                APIKey:       "test-key",
                DefaultModel: "llama2",
            },
            expectErr: false,
        },
        {
            name: "unknown provider",
            config: ProviderConfig{
                Type: "unknown",
            },
            expectErr: true,
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            provider, err := factory.CreateProvider(tt.config)

            if tt.expectErr {
                assert.Error(t, err)
                assert.Nil(t, provider)
            } else {
                assert.NoError(t, err)
                assert.NotNil(t, provider)
            }
        })
    }
}
433  pkg/ai/ollama.go  Normal file
@@ -0,0 +1,433 @@
package ai

import (
    "bytes"
    "context"
    "encoding/json"
    "fmt"
    "io"
    "net/http"
    "strings"
    "time"
)

// OllamaProvider implements ModelProvider for local Ollama instances
type OllamaProvider struct {
    config     ProviderConfig
    httpClient *http.Client
}

// OllamaRequest represents a request to the Ollama API
type OllamaRequest struct {
    Model    string                 `json:"model"`
    Prompt   string                 `json:"prompt,omitempty"`
    Messages []OllamaMessage        `json:"messages,omitempty"`
    Stream   bool                   `json:"stream"`
    Format   string                 `json:"format,omitempty"`
    Options  map[string]interface{} `json:"options,omitempty"`
    System   string                 `json:"system,omitempty"`
    Template string                 `json:"template,omitempty"`
    Context  []int                  `json:"context,omitempty"`
    Raw      bool                   `json:"raw,omitempty"`
}

// OllamaMessage represents a message in the Ollama chat format
type OllamaMessage struct {
    Role    string `json:"role"` // system, user, assistant
    Content string `json:"content"`
}

// OllamaResponse represents a response from the Ollama API
type OllamaResponse struct {
    Model              string        `json:"model"`
    CreatedAt          time.Time     `json:"created_at"`
    Message            OllamaMessage `json:"message,omitempty"`
    Response           string        `json:"response,omitempty"`
    Done               bool          `json:"done"`
    Context            []int         `json:"context,omitempty"`
    TotalDuration      int64         `json:"total_duration,omitempty"`
    LoadDuration       int64         `json:"load_duration,omitempty"`
    PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
    PromptEvalDuration int64         `json:"prompt_eval_duration,omitempty"`
    EvalCount          int           `json:"eval_count,omitempty"`
    EvalDuration       int64         `json:"eval_duration,omitempty"`
}

// OllamaModelsResponse represents the response from the /api/tags endpoint
type OllamaModelsResponse struct {
    Models []OllamaModel `json:"models"`
}

// OllamaModel represents a model in Ollama
type OllamaModel struct {
    Name       string             `json:"name"`
    ModifiedAt time.Time          `json:"modified_at"`
    Size       int64              `json:"size"`
    Digest     string             `json:"digest"`
    Details    OllamaModelDetails `json:"details,omitempty"`
}

// OllamaModelDetails provides detailed model information
type OllamaModelDetails struct {
    Format            string   `json:"format"`
    Family            string   `json:"family"`
    Families          []string `json:"families,omitempty"`
    ParameterSize     string   `json:"parameter_size"`
    QuantizationLevel string   `json:"quantization_level"`
}

// NewOllamaProvider creates a new Ollama provider instance
func NewOllamaProvider(config ProviderConfig) *OllamaProvider {
    timeout := config.Timeout
    if timeout == 0 {
        timeout = 300 * time.Second // 5 minutes default for task execution
    }

    return &OllamaProvider{
        config: config,
        httpClient: &http.Client{
            Timeout: timeout,
        },
    }
}

// ExecuteTask implements the ModelProvider interface for Ollama
func (p *OllamaProvider) ExecuteTask(ctx context.Context, request *TaskRequest) (*TaskResponse, error) {
    startTime := time.Now()

    // Build the prompt from task context
    prompt, err := p.buildTaskPrompt(request)
    if err != nil {
        return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to build prompt: %v", err))
    }

    // Prepare the Ollama request
    ollamaReq := OllamaRequest{
        Model:  p.selectModel(request.ModelName),
        Stream: false,
        Options: map[string]interface{}{
            "temperature": p.getTemperature(request.Temperature),
            "num_predict": p.getMaxTokens(request.MaxTokens),
        },
    }

    // Use chat format for better conversation handling
    ollamaReq.Messages = []OllamaMessage{
        {
            Role:    "system",
            Content: p.getSystemPrompt(request),
        },
        {
            Role:    "user",
            Content: prompt,
        },
    }

    // Execute the request
    response, err := p.makeRequest(ctx, "/api/chat", ollamaReq)
    if err != nil {
        return nil, err
    }

    endTime := time.Now()

    // Parse the response and extract actions
    actions, artifacts := p.parseResponseForActions(response.Message.Content, request)

    return &TaskResponse{
        Success:   true,
        TaskID:    request.TaskID,
        AgentID:   request.AgentID,
        ModelUsed: response.Model,
        Provider:  "ollama",
        Response:  response.Message.Content,
        Actions:   actions,
        Artifacts: artifacts,
        StartTime: startTime,
        EndTime:   endTime,
        Duration:  endTime.Sub(startTime),
        TokensUsed: TokenUsage{
            PromptTokens:     response.PromptEvalCount,
            CompletionTokens: response.EvalCount,
            TotalTokens:      response.PromptEvalCount + response.EvalCount,
        },
    }, nil
}

// GetCapabilities returns Ollama provider capabilities
func (p *OllamaProvider) GetCapabilities() ProviderCapabilities {
    return ProviderCapabilities{
        SupportsMCP:       p.config.EnableMCP,
        SupportsTools:     p.config.EnableTools,
        SupportsStreaming: true,
        SupportsFunctions: false, // Ollama doesn't support function calling natively
        MaxTokens:         p.config.MaxTokens,
        SupportedModels:   p.getSupportedModels(),
        SupportsImages:    true, // Many Ollama models support images
        SupportsFiles:     true,
    }
}

// ValidateConfig validates the Ollama provider configuration
func (p *OllamaProvider) ValidateConfig() error {
    if p.config.Endpoint == "" {
        return NewProviderError(ErrInvalidConfiguration, "endpoint is required for Ollama provider")
    }

    if p.config.DefaultModel == "" {
        return NewProviderError(ErrInvalidConfiguration, "default_model is required for Ollama provider")
    }

    // Test the connection to Ollama
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()

    if err := p.testConnection(ctx); err != nil {
        return NewProviderError(ErrProviderUnavailable, fmt.Sprintf("failed to connect to Ollama: %v", err))
    }

    return nil
}

// GetProviderInfo returns information about the Ollama provider
func (p *OllamaProvider) GetProviderInfo() ProviderInfo {
    return ProviderInfo{
        Name:           "Ollama",
        Type:           "ollama",
        Version:        "1.0.0",
        Endpoint:       p.config.Endpoint,
        DefaultModel:   p.config.DefaultModel,
        RequiresAPIKey: false,
        RateLimit:      0, // No rate limit for local Ollama
    }
}

// buildTaskPrompt constructs a comprehensive prompt for task execution
func (p *OllamaProvider) buildTaskPrompt(request *TaskRequest) (string, error) {
    var prompt strings.Builder

    prompt.WriteString(fmt.Sprintf("You are a %s agent working on a task in the repository: %s\n\n",
        request.AgentRole, request.Repository))

    prompt.WriteString(fmt.Sprintf("**Task Title:** %s\n", request.TaskTitle))
    prompt.WriteString(fmt.Sprintf("**Task Description:**\n%s\n\n", request.TaskDescription))

    if len(request.TaskLabels) > 0 {
        prompt.WriteString(fmt.Sprintf("**Labels:** %s\n", strings.Join(request.TaskLabels, ", ")))
    }

    prompt.WriteString(fmt.Sprintf("**Priority:** %d/10\n", request.Priority))
    prompt.WriteString(fmt.Sprintf("**Complexity:** %d/10\n\n", request.Complexity))

    if request.WorkingDirectory != "" {
        prompt.WriteString(fmt.Sprintf("**Working Directory:** %s\n", request.WorkingDirectory))
    }

    if len(request.RepositoryFiles) > 0 {
        prompt.WriteString("**Relevant Files:**\n")
        for _, file := range request.RepositoryFiles {
            prompt.WriteString(fmt.Sprintf("- %s\n", file))
        }
        prompt.WriteString("\n")
    }

    // Add role-specific instructions
    prompt.WriteString(p.getRoleSpecificInstructions(request.AgentRole))

    prompt.WriteString("\nPlease analyze the task and provide a detailed plan for implementation. ")
    prompt.WriteString("If you need to make changes to files, describe the specific changes needed. ")
    prompt.WriteString("If you need to run commands, specify the exact commands to execute.")

    return prompt.String(), nil
}

// getRoleSpecificInstructions returns instructions specific to the agent role
func (p *OllamaProvider) getRoleSpecificInstructions(role string) string {
    switch strings.ToLower(role) {
    case "developer":
        return `As a developer agent, focus on:
- Implementing code changes to address the task requirements
- Following best practices for the programming language
- Writing clean, maintainable, and well-documented code
- Ensuring proper error handling and edge case coverage
- Running appropriate tests to validate your changes`

    case "reviewer":
        return `As a reviewer agent, focus on:
- Analyzing code quality and adherence to best practices
- Identifying potential bugs, security issues, or performance problems
- Suggesting improvements for maintainability and readability
- Validating test coverage and test quality
- Ensuring documentation is accurate and complete`

    case "architect":
        return `As an architect agent, focus on:
- Designing system architecture and component interactions
- Making technology stack and framework decisions
- Defining interfaces and API contracts
- Considering scalability, performance, and security implications
- Creating architectural documentation and diagrams`

    case "tester":
        return `As a tester agent, focus on:
- Creating comprehensive test cases and test plans
- Implementing unit, integration, and end-to-end tests
- Identifying edge cases and potential failure scenarios
- Setting up test automation and CI/CD integration
- Validating functionality against requirements`

    default:
        return `As an AI agent, focus on:
- Understanding the task requirements thoroughly
- Providing a clear and actionable implementation plan
- Following software development best practices
- Ensuring your work is well-documented and maintainable`
    }
}

// selectModel chooses the appropriate model for the request
func (p *OllamaProvider) selectModel(requestedModel string) string {
    if requestedModel != "" {
        return requestedModel
    }
    return p.config.DefaultModel
}

// getTemperature returns the temperature setting for the request
func (p *OllamaProvider) getTemperature(requestTemp float32) float32 {
    if requestTemp > 0 {
        return requestTemp
    }
    if p.config.Temperature > 0 {
        return p.config.Temperature
    }
    return 0.7 // Default temperature
}

// getMaxTokens returns the max tokens setting for the request
func (p *OllamaProvider) getMaxTokens(requestTokens int) int {
    if requestTokens > 0 {
        return requestTokens
    }
    if p.config.MaxTokens > 0 {
        return p.config.MaxTokens
    }
    return 4096 // Default max tokens
}
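`getTemperature` and `getMaxTokens` share the same request → config → built-in default fallback chain. Extracted as a standalone sketch (the generic `pick` helper is hypothetical, introduced here for illustration):

```go
package main

import "fmt"

// pick returns the first positive value in priority order:
// per-request override, then configured value, then built-in default.
func pick(request, configured, fallback float32) float32 {
    if request > 0 {
        return request
    }
    if configured > 0 {
        return configured
    }
    return fallback
}

func main() {
    fmt.Println(pick(0.3, 0.7, 0.7)) // 0.3 (request wins)
    fmt.Println(pick(0, 0.9, 0.7))   // 0.9 (config wins)
    fmt.Println(pick(0, 0, 0.7))     // 0.7 (built-in default)
}
```

Note that the zero value doubles as "unset" here, so a caller cannot explicitly request a temperature of exactly 0.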

// getSystemPrompt constructs the system prompt
func (p *OllamaProvider) getSystemPrompt(request *TaskRequest) string {
    if request.SystemPrompt != "" {
        return request.SystemPrompt
    }

    return fmt.Sprintf(`You are an AI assistant specializing in software development tasks.
You are currently working as a %s agent in the CHORUS autonomous agent system.

Your capabilities include:
- Analyzing code and repository structures
- Implementing features and fixing bugs
- Writing and reviewing code in multiple programming languages
- Creating tests and documentation
- Following software development best practices

Always provide detailed, actionable responses with specific implementation steps.`, request.AgentRole)
}

// makeRequest makes an HTTP request to the Ollama API
func (p *OllamaProvider) makeRequest(ctx context.Context, endpoint string, request interface{}) (*OllamaResponse, error) {
    requestJSON, err := json.Marshal(request)
    if err != nil {
        return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to marshal request: %v", err))
    }

    url := strings.TrimSuffix(p.config.Endpoint, "/") + endpoint
    req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(requestJSON))
    if err != nil {
        return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to create request: %v", err))
    }

    req.Header.Set("Content-Type", "application/json")

    // Add custom headers if configured
    for key, value := range p.config.CustomHeaders {
        req.Header.Set(key, value)
    }

    resp, err := p.httpClient.Do(req)
    if err != nil {
        return nil, NewProviderError(ErrProviderUnavailable, fmt.Sprintf("request failed: %v", err))
    }
    defer resp.Body.Close()

    body, err := io.ReadAll(resp.Body)
    if err != nil {
        return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to read response: %v", err))
    }

    if resp.StatusCode != http.StatusOK {
        return nil, NewProviderError(ErrTaskExecutionFailed,
            fmt.Sprintf("API request failed with status %d: %s", resp.StatusCode, string(body)))
    }

    var ollamaResp OllamaResponse
    if err := json.Unmarshal(body, &ollamaResp); err != nil {
        return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to parse response: %v", err))
    }

    return &ollamaResp, nil
}

// testConnection tests the connection to Ollama
func (p *OllamaProvider) testConnection(ctx context.Context) error {
    url := strings.TrimSuffix(p.config.Endpoint, "/") + "/api/tags"
    req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
    if err != nil {
        return err
    }

    resp, err := p.httpClient.Do(req)
    if err != nil {
        return err
    }
    defer resp.Body.Close()

    return nil
}

// getSupportedModels returns a list of supported models (would normally query Ollama)
func (p *OllamaProvider) getSupportedModels() []string {
    // In a real implementation, this would query the /api/tags endpoint
    return []string{
        "llama3.1:8b", "llama3.1:13b", "llama3.1:70b",
        "codellama:7b", "codellama:13b", "codellama:34b",
        "mistral:7b", "mixtral:8x7b",
        "qwen2:7b", "gemma:7b",
    }
}
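The comment in `getSupportedModels` notes that a real implementation would query `/api/tags`. The parsing half of that (the payload shape matching the `OllamaModelsResponse` type defined above) can be sketched with a canned body; the sample JSON and `modelNames` helper are illustrative, not part of the package:

```go
package main

import (
    "encoding/json"
    "fmt"
)

// tagsResponse mirrors the shape of Ollama's /api/tags payload
// (only the model names are extracted here).
type tagsResponse struct {
    Models []struct {
        Name string `json:"name"`
    } `json:"models"`
}

// modelNames parses a /api/tags response body into a list of model names.
func modelNames(body []byte) ([]string, error) {
    var resp tagsResponse
    if err := json.Unmarshal(body, &resp); err != nil {
        return nil, err
    }
    names := make([]string, 0, len(resp.Models))
    for _, m := range resp.Models {
        names = append(names, m.Name)
    }
    return names, nil
}

func main() {
    // Canned sample body; a dynamic implementation would GET
    // <endpoint>/api/tags and feed the response body in here.
    sample := []byte(`{"models":[{"name":"llama3.1:8b"},{"name":"codellama:7b"}]}`)
    names, err := modelNames(sample)
    fmt.Println(names, err) // [llama3.1:8b codellama:7b] <nil>
}
```

Wiring this into `getSupportedModels` would keep the capability list in sync with what the local Ollama instance has actually pulled, instead of the hardcoded slice.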

// parseResponseForActions extracts actions and artifacts from the response
func (p *OllamaProvider) parseResponseForActions(response string, request *TaskRequest) ([]TaskAction, []Artifact) {
    var actions []TaskAction
    var artifacts []Artifact

    // This is a simplified implementation - in reality, you'd parse the response
    // to extract specific actions like file changes, commands to run, etc.

    // For now, just create a basic action indicating task analysis
    action := TaskAction{
        Type:      "task_analysis",
        Target:    request.TaskTitle,
        Content:   response,
        Result:    "Task analyzed successfully",
        Success:   true,
        Timestamp: time.Now(),
        Metadata: map[string]interface{}{
            "agent_role": request.AgentRole,
            "repository": request.Repository,
        },
    }
    actions = append(actions, action)

    return actions, artifacts
}
518
pkg/ai/openai.go
Normal file
@@ -0,0 +1,518 @@
package ai

import (
	"context"
	"fmt"
	"strings"
	"time"

	"github.com/sashabaranov/go-openai"
)

// OpenAIProvider implements ModelProvider for OpenAI API
type OpenAIProvider struct {
	config ProviderConfig
	client *openai.Client
}

// NewOpenAIProvider creates a new OpenAI provider instance
func NewOpenAIProvider(config ProviderConfig) *OpenAIProvider {
	client := openai.NewClient(config.APIKey)

	// Use custom endpoint if specified
	if config.Endpoint != "" && config.Endpoint != "https://api.openai.com/v1" {
		clientConfig := openai.DefaultConfig(config.APIKey)
		clientConfig.BaseURL = config.Endpoint
		client = openai.NewClientWithConfig(clientConfig)
	}

	return &OpenAIProvider{
		config: config,
		client: client,
	}
}

// ExecuteTask implements the ModelProvider interface for OpenAI
func (p *OpenAIProvider) ExecuteTask(ctx context.Context, request *TaskRequest) (*TaskResponse, error) {
	startTime := time.Now()

	// Build messages for the chat completion
	messages, err := p.buildChatMessages(request)
	if err != nil {
		return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to build messages: %v", err))
	}

	// Prepare the chat completion request
	chatReq := openai.ChatCompletionRequest{
		Model:       p.selectModel(request.ModelName),
		Messages:    messages,
		Temperature: p.getTemperature(request.Temperature),
		MaxTokens:   p.getMaxTokens(request.MaxTokens),
		Stream:      false,
	}

	// Add tools if enabled and supported
	if p.config.EnableTools && request.EnableTools {
		chatReq.Tools = p.getToolDefinitions(request)
		chatReq.ToolChoice = "auto"
	}

	// Execute the chat completion
	resp, err := p.client.CreateChatCompletion(ctx, chatReq)
	if err != nil {
		return nil, p.handleOpenAIError(err)
	}

	endTime := time.Now()

	// Process the response
	if len(resp.Choices) == 0 {
		return nil, NewProviderError(ErrTaskExecutionFailed, "no response choices returned from OpenAI")
	}

	choice := resp.Choices[0]
	responseText := choice.Message.Content

	// Process tool calls if present
	var actions []TaskAction
	var artifacts []Artifact

	if len(choice.Message.ToolCalls) > 0 {
		toolActions, toolArtifacts := p.processToolCalls(choice.Message.ToolCalls, request)
		actions = append(actions, toolActions...)
		artifacts = append(artifacts, toolArtifacts...)
	}

	// Parse response for additional actions
	responseActions, responseArtifacts := p.parseResponseForActions(responseText, request)
	actions = append(actions, responseActions...)
	artifacts = append(artifacts, responseArtifacts...)

	return &TaskResponse{
		Success:   true,
		TaskID:    request.TaskID,
		AgentID:   request.AgentID,
		ModelUsed: resp.Model,
		Provider:  "openai",
		Response:  responseText,
		Actions:   actions,
		Artifacts: artifacts,
		StartTime: startTime,
		EndTime:   endTime,
		Duration:  endTime.Sub(startTime),
		TokensUsed: TokenUsage{
			PromptTokens:     resp.Usage.PromptTokens,
			CompletionTokens: resp.Usage.CompletionTokens,
			TotalTokens:      resp.Usage.TotalTokens,
		},
	}, nil
}

// GetCapabilities returns OpenAI provider capabilities
func (p *OpenAIProvider) GetCapabilities() ProviderCapabilities {
	return ProviderCapabilities{
		SupportsMCP:       p.config.EnableMCP,
		SupportsTools:     true, // OpenAI supports function calling
		SupportsStreaming: true,
		SupportsFunctions: true,
		MaxTokens:         p.getModelMaxTokens(p.config.DefaultModel),
		SupportedModels:   p.getSupportedModels(),
		SupportsImages:    p.modelSupportsImages(p.config.DefaultModel),
		SupportsFiles:     true,
	}
}

// ValidateConfig validates the OpenAI provider configuration
func (p *OpenAIProvider) ValidateConfig() error {
	if p.config.APIKey == "" {
		return NewProviderError(ErrAPIKeyRequired, "API key is required for OpenAI provider")
	}

	if p.config.DefaultModel == "" {
		return NewProviderError(ErrInvalidConfiguration, "default_model is required for OpenAI provider")
	}

	// Test the API connection with a minimal request
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	if err := p.testConnection(ctx); err != nil {
		return NewProviderError(ErrProviderUnavailable, fmt.Sprintf("failed to connect to OpenAI: %v", err))
	}

	return nil
}

// GetProviderInfo returns information about the OpenAI provider
func (p *OpenAIProvider) GetProviderInfo() ProviderInfo {
	endpoint := p.config.Endpoint
	if endpoint == "" {
		endpoint = "https://api.openai.com/v1"
	}

	return ProviderInfo{
		Name:           "OpenAI",
		Type:           "openai",
		Version:        "1.0.0",
		Endpoint:       endpoint,
		DefaultModel:   p.config.DefaultModel,
		RequiresAPIKey: true,
		RateLimit:      10000, // Approximate RPM for paid accounts
	}
}

// buildChatMessages constructs messages for the OpenAI chat completion
func (p *OpenAIProvider) buildChatMessages(request *TaskRequest) ([]openai.ChatCompletionMessage, error) {
	var messages []openai.ChatCompletionMessage

	// System message
	systemPrompt := p.getSystemPrompt(request)
	if systemPrompt != "" {
		messages = append(messages, openai.ChatCompletionMessage{
			Role:    openai.ChatMessageRoleSystem,
			Content: systemPrompt,
		})
	}

	// User message with task details
	userPrompt, err := p.buildTaskPrompt(request)
	if err != nil {
		return nil, err
	}

	messages = append(messages, openai.ChatCompletionMessage{
		Role:    openai.ChatMessageRoleUser,
		Content: userPrompt,
	})

	return messages, nil
}

// buildTaskPrompt constructs a comprehensive prompt for task execution
func (p *OpenAIProvider) buildTaskPrompt(request *TaskRequest) (string, error) {
	var prompt strings.Builder

	prompt.WriteString(fmt.Sprintf("You are working as a %s agent on the following task:\n\n",
		request.AgentRole))

	prompt.WriteString(fmt.Sprintf("**Repository:** %s\n", request.Repository))
	prompt.WriteString(fmt.Sprintf("**Task:** %s\n", request.TaskTitle))
	prompt.WriteString(fmt.Sprintf("**Description:**\n%s\n\n", request.TaskDescription))

	if len(request.TaskLabels) > 0 {
		prompt.WriteString(fmt.Sprintf("**Labels:** %s\n", strings.Join(request.TaskLabels, ", ")))
	}

	prompt.WriteString(fmt.Sprintf("**Priority:** %d/10 | **Complexity:** %d/10\n\n",
		request.Priority, request.Complexity))

	if request.WorkingDirectory != "" {
		prompt.WriteString(fmt.Sprintf("**Working Directory:** %s\n", request.WorkingDirectory))
	}

	if len(request.RepositoryFiles) > 0 {
		prompt.WriteString("**Relevant Files:**\n")
		for _, file := range request.RepositoryFiles {
			prompt.WriteString(fmt.Sprintf("- %s\n", file))
		}
		prompt.WriteString("\n")
	}

	// Add role-specific guidance
	prompt.WriteString(p.getRoleSpecificGuidance(request.AgentRole))

	prompt.WriteString("\nAnalyze this task and provide a detailed implementation plan. ")
	if request.EnableTools {
		prompt.WriteString("Use the available tools to make concrete changes or gather information as needed. ")
	}
	prompt.WriteString("Be specific about what needs to be done and how to accomplish it.")

	return prompt.String(), nil
}

// getRoleSpecificGuidance returns guidance specific to the agent role
func (p *OpenAIProvider) getRoleSpecificGuidance(role string) string {
	switch strings.ToLower(role) {
	case "developer":
		return `**Developer Guidelines:**
- Write clean, maintainable, and well-documented code
- Follow language-specific best practices and conventions
- Implement proper error handling and validation
- Create or update tests to cover your changes
- Consider performance and security implications`

	case "reviewer":
		return `**Code Review Guidelines:**
- Analyze code quality, readability, and maintainability
- Check for bugs, security vulnerabilities, and performance issues
- Verify test coverage and quality
- Ensure documentation is accurate and complete
- Suggest improvements and alternatives`

	case "architect":
		return `**Architecture Guidelines:**
- Design scalable and maintainable system architecture
- Make informed technology and framework decisions
- Define clear interfaces and API contracts
- Consider security, performance, and scalability requirements
- Document architectural decisions and rationale`

	case "tester":
		return `**Testing Guidelines:**
- Create comprehensive test plans and test cases
- Implement unit, integration, and end-to-end tests
- Identify edge cases and potential failure scenarios
- Set up test automation and continuous integration
- Validate functionality against requirements`

	default:
		return `**General Guidelines:**
- Understand requirements thoroughly before implementation
- Follow software development best practices
- Provide clear documentation and explanations
- Consider maintainability and future extensibility`
	}
}

// getToolDefinitions returns tool definitions for OpenAI function calling
func (p *OpenAIProvider) getToolDefinitions(request *TaskRequest) []openai.Tool {
	var tools []openai.Tool

	// File operations tool
	tools = append(tools, openai.Tool{
		Type: openai.ToolTypeFunction,
		Function: &openai.FunctionDefinition{
			Name:        "file_operation",
			Description: "Create, read, update, or delete files in the repository",
			Parameters: map[string]interface{}{
				"type": "object",
				"properties": map[string]interface{}{
					"operation": map[string]interface{}{
						"type":        "string",
						"enum":        []string{"create", "read", "update", "delete"},
						"description": "The file operation to perform",
					},
					"path": map[string]interface{}{
						"type":        "string",
						"description": "The file path relative to the repository root",
					},
					"content": map[string]interface{}{
						"type":        "string",
						"description": "The file content (for create/update operations)",
					},
				},
				"required": []string{"operation", "path"},
			},
		},
	})

	// Command execution tool
	tools = append(tools, openai.Tool{
		Type: openai.ToolTypeFunction,
		Function: &openai.FunctionDefinition{
			Name:        "execute_command",
			Description: "Execute shell commands in the repository working directory",
			Parameters: map[string]interface{}{
				"type": "object",
				"properties": map[string]interface{}{
					"command": map[string]interface{}{
						"type":        "string",
						"description": "The shell command to execute",
					},
					"working_dir": map[string]interface{}{
						"type":        "string",
						"description": "Working directory for command execution (optional)",
					},
				},
				"required": []string{"command"},
			},
		},
	})

	return tools
}

// processToolCalls handles OpenAI function calls
func (p *OpenAIProvider) processToolCalls(toolCalls []openai.ToolCall, request *TaskRequest) ([]TaskAction, []Artifact) {
	var actions []TaskAction
	var artifacts []Artifact

	for _, toolCall := range toolCalls {
		action := TaskAction{
			Type:      "function_call",
			Target:    toolCall.Function.Name,
			Content:   toolCall.Function.Arguments,
			Timestamp: time.Now(),
			Metadata: map[string]interface{}{
				"tool_call_id": toolCall.ID,
				"function":     toolCall.Function.Name,
			},
		}

		// In a real implementation, you would actually execute these tool calls
		// For now, just mark them as successful
		action.Result = fmt.Sprintf("Function call %s processed", toolCall.Function.Name)
		action.Success = true

		actions = append(actions, action)
	}

	return actions, artifacts
}

// selectModel chooses the appropriate OpenAI model
func (p *OpenAIProvider) selectModel(requestedModel string) string {
	if requestedModel != "" {
		return requestedModel
	}
	return p.config.DefaultModel
}

// getTemperature returns the temperature setting
func (p *OpenAIProvider) getTemperature(requestTemp float32) float32 {
	if requestTemp > 0 {
		return requestTemp
	}
	if p.config.Temperature > 0 {
		return p.config.Temperature
	}
	return 0.7 // Default temperature
}

// getMaxTokens returns the max tokens setting
func (p *OpenAIProvider) getMaxTokens(requestTokens int) int {
	if requestTokens > 0 {
		return requestTokens
	}
	if p.config.MaxTokens > 0 {
		return p.config.MaxTokens
	}
	return 4096 // Default max tokens
}

// getSystemPrompt constructs the system prompt
func (p *OpenAIProvider) getSystemPrompt(request *TaskRequest) string {
	if request.SystemPrompt != "" {
		return request.SystemPrompt
	}

	return fmt.Sprintf(`You are an expert AI assistant specializing in software development.
You are currently operating as a %s agent in the CHORUS autonomous development system.

Your capabilities:
- Code analysis, implementation, and optimization
- Software architecture and design patterns
- Testing strategies and implementation
- Documentation and technical writing
- DevOps and deployment practices

Always provide thorough, actionable responses with specific implementation details.
When using tools, explain your reasoning and the expected outcomes.`, request.AgentRole)
}

// getModelMaxTokens returns the maximum tokens for a specific model
func (p *OpenAIProvider) getModelMaxTokens(model string) int {
	switch model {
	case "gpt-4o", "gpt-4o-2024-05-13":
		return 128000
	case "gpt-4-turbo", "gpt-4-turbo-2024-04-09":
		return 128000
	case "gpt-4", "gpt-4-0613":
		return 8192
	case "gpt-3.5-turbo", "gpt-3.5-turbo-0125":
		return 16385
	default:
		return 4096 // Conservative default
	}
}

// modelSupportsImages checks if a model supports image inputs
func (p *OpenAIProvider) modelSupportsImages(model string) bool {
	visionModels := []string{"gpt-4o", "gpt-4o-2024-05-13", "gpt-4-turbo", "gpt-4-vision-preview"}
	for _, visionModel := range visionModels {
		if strings.Contains(model, visionModel) {
			return true
		}
	}
	return false
}

// getSupportedModels returns a list of supported OpenAI models
func (p *OpenAIProvider) getSupportedModels() []string {
	return []string{
		"gpt-4o", "gpt-4o-2024-05-13",
		"gpt-4-turbo", "gpt-4-turbo-2024-04-09",
		"gpt-4", "gpt-4-0613",
		"gpt-3.5-turbo", "gpt-3.5-turbo-0125",
	}
}

// testConnection tests the OpenAI API connection
func (p *OpenAIProvider) testConnection(ctx context.Context) error {
	// Simple test request to verify API key and connection
	_, err := p.client.ListModels(ctx)
	return err
}

// handleOpenAIError converts OpenAI errors to provider errors
func (p *OpenAIProvider) handleOpenAIError(err error) *ProviderError {
	errStr := err.Error()

	if strings.Contains(errStr, "rate limit") {
		return &ProviderError{
			Code:      "RATE_LIMIT_EXCEEDED",
			Message:   "OpenAI API rate limit exceeded",
			Details:   errStr,
			Retryable: true,
		}
	}

	if strings.Contains(errStr, "quota") {
		return &ProviderError{
			Code:      "QUOTA_EXCEEDED",
			Message:   "OpenAI API quota exceeded",
			Details:   errStr,
			Retryable: false,
		}
	}

	if strings.Contains(errStr, "invalid_api_key") {
		return &ProviderError{
			Code:      "INVALID_API_KEY",
			Message:   "Invalid OpenAI API key",
			Details:   errStr,
			Retryable: false,
		}
	}

	return &ProviderError{
		Code:      "API_ERROR",
		Message:   "OpenAI API error",
		Details:   errStr,
		Retryable: true,
	}
}

// parseResponseForActions extracts actions from the response text
func (p *OpenAIProvider) parseResponseForActions(response string, request *TaskRequest) ([]TaskAction, []Artifact) {
	var actions []TaskAction
	var artifacts []Artifact

	// Create a basic task analysis action
	action := TaskAction{
		Type:      "task_analysis",
		Target:    request.TaskTitle,
		Content:   response,
		Result:    "Task analyzed by OpenAI model",
		Success:   true,
		Timestamp: time.Now(),
		Metadata: map[string]interface{}{
			"agent_role": request.AgentRole,
			"repository": request.Repository,
			"model":      p.config.DefaultModel,
		},
	}
	actions = append(actions, action)

	return actions, artifacts
}
211
pkg/ai/provider.go
Normal file
@@ -0,0 +1,211 @@
package ai

import (
	"context"
	"time"
)

// ModelProvider defines the interface for AI model providers
type ModelProvider interface {
	// ExecuteTask executes a task using the AI model
	ExecuteTask(ctx context.Context, request *TaskRequest) (*TaskResponse, error)

	// GetCapabilities returns the capabilities supported by this provider
	GetCapabilities() ProviderCapabilities

	// ValidateConfig validates the provider configuration
	ValidateConfig() error

	// GetProviderInfo returns information about this provider
	GetProviderInfo() ProviderInfo
}

// TaskRequest represents a request to execute a task
type TaskRequest struct {
	// Task context and metadata
	TaskID          string   `json:"task_id"`
	AgentID         string   `json:"agent_id"`
	AgentRole       string   `json:"agent_role"`
	Repository      string   `json:"repository"`
	TaskTitle       string   `json:"task_title"`
	TaskDescription string   `json:"task_description"`
	TaskLabels      []string `json:"task_labels"`
	Priority        int      `json:"priority"`
	Complexity      int      `json:"complexity"`

	// Model configuration
	ModelName    string  `json:"model_name"`
	Temperature  float32 `json:"temperature,omitempty"`
	MaxTokens    int     `json:"max_tokens,omitempty"`
	SystemPrompt string  `json:"system_prompt,omitempty"`

	// Execution context
	WorkingDirectory string                 `json:"working_directory"`
	RepositoryFiles  []string               `json:"repository_files,omitempty"`
	Context          map[string]interface{} `json:"context,omitempty"`

	// Tool and MCP configuration
	EnableTools  bool     `json:"enable_tools"`
	MCPServers   []string `json:"mcp_servers,omitempty"`
	AllowedTools []string `json:"allowed_tools,omitempty"`
}

// TaskResponse represents the response from task execution
type TaskResponse struct {
	// Execution results
	Success   bool   `json:"success"`
	TaskID    string `json:"task_id"`
	AgentID   string `json:"agent_id"`
	ModelUsed string `json:"model_used"`
	Provider  string `json:"provider"`

	// Response content
	Response  string       `json:"response"`
	Reasoning string       `json:"reasoning,omitempty"`
	Actions   []TaskAction `json:"actions,omitempty"`
	Artifacts []Artifact   `json:"artifacts,omitempty"`

	// Metadata
	StartTime  time.Time     `json:"start_time"`
	EndTime    time.Time     `json:"end_time"`
	Duration   time.Duration `json:"duration"`
	TokensUsed TokenUsage    `json:"tokens_used,omitempty"`

	// Error information
	Error     string `json:"error,omitempty"`
	ErrorCode string `json:"error_code,omitempty"`
	Retryable bool   `json:"retryable,omitempty"`
}

// TaskAction represents an action taken during task execution
type TaskAction struct {
	Type      string                 `json:"type"`    // file_create, file_edit, command_run, etc.
	Target    string                 `json:"target"`  // file path, command, etc.
	Content   string                 `json:"content"` // file content, command args, etc.
	Result    string                 `json:"result"`  // execution result
	Success   bool                   `json:"success"`
	Timestamp time.Time              `json:"timestamp"`
	Metadata  map[string]interface{} `json:"metadata,omitempty"`
}

// Artifact represents a file or output artifact from task execution
type Artifact struct {
	Name      string    `json:"name"`
	Type      string    `json:"type"` // file, patch, log, etc.
	Path      string    `json:"path"` // relative path in repository
	Content   string    `json:"content"`
	Size      int64     `json:"size"`
	CreatedAt time.Time `json:"created_at"`
	Checksum  string    `json:"checksum"`
}

// TokenUsage represents token consumption for the request
type TokenUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// ProviderCapabilities defines what a provider supports
type ProviderCapabilities struct {
	SupportsMCP       bool     `json:"supports_mcp"`
	SupportsTools     bool     `json:"supports_tools"`
	SupportsStreaming bool     `json:"supports_streaming"`
	SupportsFunctions bool     `json:"supports_functions"`
	MaxTokens         int      `json:"max_tokens"`
	SupportedModels   []string `json:"supported_models"`
	SupportsImages    bool     `json:"supports_images"`
	SupportsFiles     bool     `json:"supports_files"`
}

// ProviderInfo contains metadata about the provider
type ProviderInfo struct {
	Name           string `json:"name"`
	Type           string `json:"type"` // ollama, openai, resetdata
	Version        string `json:"version"`
	Endpoint       string `json:"endpoint"`
	DefaultModel   string `json:"default_model"`
	RequiresAPIKey bool   `json:"requires_api_key"`
	RateLimit      int    `json:"rate_limit"` // requests per minute
}

// ProviderConfig contains configuration for a specific provider
type ProviderConfig struct {
	Type          string                 `yaml:"type" json:"type"` // ollama, openai, resetdata
	Endpoint      string                 `yaml:"endpoint" json:"endpoint"`
	APIKey        string                 `yaml:"api_key" json:"api_key,omitempty"`
	DefaultModel  string                 `yaml:"default_model" json:"default_model"`
	Temperature   float32                `yaml:"temperature" json:"temperature"`
	MaxTokens     int                    `yaml:"max_tokens" json:"max_tokens"`
	Timeout       time.Duration          `yaml:"timeout" json:"timeout"`
	RetryAttempts int                    `yaml:"retry_attempts" json:"retry_attempts"`
	RetryDelay    time.Duration          `yaml:"retry_delay" json:"retry_delay"`
	EnableTools   bool                   `yaml:"enable_tools" json:"enable_tools"`
	EnableMCP     bool                   `yaml:"enable_mcp" json:"enable_mcp"`
	MCPServers    []string               `yaml:"mcp_servers" json:"mcp_servers,omitempty"`
	CustomHeaders map[string]string      `yaml:"custom_headers" json:"custom_headers,omitempty"`
	ExtraParams   map[string]interface{} `yaml:"extra_params" json:"extra_params,omitempty"`
}

// RoleModelMapping defines model selection based on agent role
type RoleModelMapping struct {
	DefaultProvider  string                `yaml:"default_provider" json:"default_provider"`
	FallbackProvider string                `yaml:"fallback_provider" json:"fallback_provider"`
	Roles            map[string]RoleConfig `yaml:"roles" json:"roles"`
}

// RoleConfig defines model configuration for a specific role
type RoleConfig struct {
	Provider         string   `yaml:"provider" json:"provider"`
	Model            string   `yaml:"model" json:"model"`
	Temperature      float32  `yaml:"temperature" json:"temperature"`
	MaxTokens        int      `yaml:"max_tokens" json:"max_tokens"`
	SystemPrompt     string   `yaml:"system_prompt" json:"system_prompt"`
	FallbackProvider string   `yaml:"fallback_provider" json:"fallback_provider"`
	FallbackModel    string   `yaml:"fallback_model" json:"fallback_model"`
	EnableTools      bool     `yaml:"enable_tools" json:"enable_tools"`
	EnableMCP        bool     `yaml:"enable_mcp" json:"enable_mcp"`
	AllowedTools     []string `yaml:"allowed_tools" json:"allowed_tools,omitempty"`
	MCPServers       []string `yaml:"mcp_servers" json:"mcp_servers,omitempty"`
}

// Common error types
var (
	ErrProviderNotFound     = &ProviderError{Code: "PROVIDER_NOT_FOUND", Message: "Provider not found"}
	ErrModelNotSupported    = &ProviderError{Code: "MODEL_NOT_SUPPORTED", Message: "Model not supported by provider"}
	ErrAPIKeyRequired       = &ProviderError{Code: "API_KEY_REQUIRED", Message: "API key required for provider"}
	ErrRateLimitExceeded    = &ProviderError{Code: "RATE_LIMIT_EXCEEDED", Message: "Rate limit exceeded"}
	ErrProviderUnavailable  = &ProviderError{Code: "PROVIDER_UNAVAILABLE", Message: "Provider temporarily unavailable"}
	ErrInvalidConfiguration = &ProviderError{Code: "INVALID_CONFIGURATION", Message: "Invalid provider configuration"}
	ErrTaskExecutionFailed  = &ProviderError{Code: "TASK_EXECUTION_FAILED", Message: "Task execution failed"}
)

// ProviderError represents provider-specific errors
type ProviderError struct {
	Code      string `json:"code"`
	Message   string `json:"message"`
	Details   string `json:"details,omitempty"`
	Retryable bool   `json:"retryable"`
}

func (e *ProviderError) Error() string {
	if e.Details != "" {
		return e.Message + ": " + e.Details
	}
	return e.Message
}

// IsRetryable returns whether the error is retryable
func (e *ProviderError) IsRetryable() bool {
	return e.Retryable
}

// NewProviderError creates a new provider error with details
func NewProviderError(base *ProviderError, details string) *ProviderError {
	return &ProviderError{
		Code:      base.Code,
		Message:   base.Message,
		Details:   details,
		Retryable: base.Retryable,
	}
}
446
pkg/ai/provider_test.go
Normal file
@@ -0,0 +1,446 @@
package ai

import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// MockProvider implements ModelProvider for testing
type MockProvider struct {
	name         string
	capabilities ProviderCapabilities
	shouldFail   bool
	response     *TaskResponse
	executeFunc  func(ctx context.Context, request *TaskRequest) (*TaskResponse, error)
}

func NewMockProvider(name string) *MockProvider {
	return &MockProvider{
		name: name,
		capabilities: ProviderCapabilities{
			SupportsMCP:       true,
			SupportsTools:     true,
			SupportsStreaming: true,
			SupportsFunctions: false,
			MaxTokens:         4096,
			SupportedModels:   []string{"test-model", "test-model-2"},
			SupportsImages:    false,
			SupportsFiles:     true,
		},
		response: &TaskResponse{
			Success:  true,
			Response: "Mock response",
		},
	}
}

func (m *MockProvider) ExecuteTask(ctx context.Context, request *TaskRequest) (*TaskResponse, error) {
	if m.executeFunc != nil {
		return m.executeFunc(ctx, request)
	}

	if m.shouldFail {
		return nil, NewProviderError(ErrTaskExecutionFailed, "mock execution failed")
	}

	response := *m.response // Copy the response
	response.TaskID = request.TaskID
	response.AgentID = request.AgentID
	response.Provider = m.name
	response.StartTime = time.Now()
	response.EndTime = time.Now().Add(100 * time.Millisecond)
	response.Duration = response.EndTime.Sub(response.StartTime)

	return &response, nil
}

func (m *MockProvider) GetCapabilities() ProviderCapabilities {
	return m.capabilities
}

func (m *MockProvider) ValidateConfig() error {
	if m.shouldFail {
		return NewProviderError(ErrInvalidConfiguration, "mock config validation failed")
	}
	return nil
}

func (m *MockProvider) GetProviderInfo() ProviderInfo {
	return ProviderInfo{
		Name:           m.name,
		Type:           "mock",
		Version:        "1.0.0",
		Endpoint:       "mock://localhost",
		DefaultModel:   "test-model",
		RequiresAPIKey: false,
		RateLimit:      0,
	}
}

func TestProviderError(t *testing.T) {
	tests := []struct {
		name      string
		err       *ProviderError
		expected  string
		retryable bool
	}{
		{
			name:      "simple error",
			err:       ErrProviderNotFound,
			expected:  "Provider not found",
			retryable: false,
		},
		{
			name:      "error with details",
			err:       NewProviderError(ErrRateLimitExceeded, "API rate limit of 1000/hour exceeded"),
			expected:  "Rate limit exceeded: API rate limit of 1000/hour exceeded",
			retryable: false,
		},
		{
			name: "retryable error",
			err: &ProviderError{
				Code:      "TEMPORARY_ERROR",
				Message:   "Temporary failure",
				Retryable: true,
			},
			expected:  "Temporary failure",
			retryable: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.expected, tt.err.Error())
			assert.Equal(t, tt.retryable, tt.err.IsRetryable())
		})
	}
}

func TestTaskRequest(t *testing.T) {
	request := &TaskRequest{
		TaskID:          "test-task-123",
		AgentID:         "agent-456",
		AgentRole:       "developer",
		Repository:      "test/repo",
		TaskTitle:       "Test Task",
		TaskDescription: "A test task for unit testing",
		TaskLabels:      []string{"bug", "urgent"},
		Priority:        8,
		Complexity:      6,
		ModelName:       "test-model",
		Temperature:     0.7,
		MaxTokens:       4096,
		EnableTools:     true,
	}

	// Validate required fields
	assert.NotEmpty(t, request.TaskID)
	assert.NotEmpty(t, request.AgentID)
	assert.NotEmpty(t, request.AgentRole)
	assert.NotEmpty(t, request.Repository)
	assert.NotEmpty(t, request.TaskTitle)
	assert.Greater(t, request.Priority, 0)
	assert.Greater(t, request.Complexity, 0)
}

func TestTaskResponse(t *testing.T) {
	startTime := time.Now()
	endTime := startTime.Add(2 * time.Second)

	response := &TaskResponse{
		Success:   true,
		TaskID:    "test-task-123",
		AgentID:   "agent-456",
		ModelUsed: "test-model",
		Provider:  "mock",
		Response:  "Task completed successfully",
		Actions: []TaskAction{
			{
				Type:      "file_create",
				Target:    "test.go",
				Content:   "package main",
				Result:    "File created",
				Success:   true,
				Timestamp: time.Now(),
			},
		},
		Artifacts: []Artifact{
			{
				Name:      "test.go",
				Type:      "file",
				Path:      "./test.go",
				Content:   "package main",
				Size:      12,
				CreatedAt: time.Now(),
			},
		},
		StartTime: startTime,
		EndTime:   endTime,
		Duration:  endTime.Sub(startTime),
||||
TokensUsed: TokenUsage{
|
||||
PromptTokens: 50,
|
||||
CompletionTokens: 100,
|
||||
TotalTokens: 150,
|
||||
},
|
||||
}
|
||||
|
||||
// Validate response structure
|
||||
assert.True(t, response.Success)
|
||||
assert.NotEmpty(t, response.TaskID)
|
||||
assert.NotEmpty(t, response.Provider)
|
||||
assert.Len(t, response.Actions, 1)
|
||||
assert.Len(t, response.Artifacts, 1)
|
||||
assert.Equal(t, 2*time.Second, response.Duration)
|
||||
assert.Equal(t, 150, response.TokensUsed.TotalTokens)
|
||||
}
|
||||
|
||||
func TestTaskAction(t *testing.T) {
|
||||
action := TaskAction{
|
||||
Type: "file_edit",
|
||||
Target: "main.go",
|
||||
Content: "updated content",
|
||||
Result: "File updated successfully",
|
||||
Success: true,
|
||||
Timestamp: time.Now(),
|
||||
Metadata: map[string]interface{}{
|
||||
"line_count": 42,
|
||||
"backup": true,
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, "file_edit", action.Type)
|
||||
assert.True(t, action.Success)
|
||||
assert.NotNil(t, action.Metadata)
|
||||
assert.Equal(t, 42, action.Metadata["line_count"])
|
||||
}
|
||||
|
||||
func TestArtifact(t *testing.T) {
|
||||
artifact := Artifact{
|
||||
Name: "output.log",
|
||||
Type: "log",
|
||||
Path: "/tmp/output.log",
|
||||
Content: "Log content here",
|
||||
Size: 16,
|
||||
CreatedAt: time.Now(),
|
||||
Checksum: "abc123",
|
||||
}
|
||||
|
||||
assert.Equal(t, "output.log", artifact.Name)
|
||||
assert.Equal(t, "log", artifact.Type)
|
||||
assert.Equal(t, int64(16), artifact.Size)
|
||||
assert.NotEmpty(t, artifact.Checksum)
|
||||
}
|
||||
|
||||
func TestProviderCapabilities(t *testing.T) {
|
||||
capabilities := ProviderCapabilities{
|
||||
SupportsMCP: true,
|
||||
SupportsTools: true,
|
||||
SupportsStreaming: false,
|
||||
SupportsFunctions: true,
|
||||
MaxTokens: 8192,
|
||||
SupportedModels: []string{"gpt-4", "gpt-3.5-turbo"},
|
||||
SupportsImages: true,
|
||||
SupportsFiles: true,
|
||||
}
|
||||
|
||||
assert.True(t, capabilities.SupportsMCP)
|
||||
assert.True(t, capabilities.SupportsTools)
|
||||
assert.False(t, capabilities.SupportsStreaming)
|
||||
assert.Equal(t, 8192, capabilities.MaxTokens)
|
||||
assert.Len(t, capabilities.SupportedModels, 2)
|
||||
}
|
||||
|
||||
func TestProviderInfo(t *testing.T) {
|
||||
info := ProviderInfo{
|
||||
Name: "Test Provider",
|
||||
Type: "test",
|
||||
Version: "1.0.0",
|
||||
Endpoint: "https://api.test.com",
|
||||
DefaultModel: "test-model",
|
||||
RequiresAPIKey: true,
|
||||
RateLimit: 1000,
|
||||
}
|
||||
|
||||
assert.Equal(t, "Test Provider", info.Name)
|
||||
assert.True(t, info.RequiresAPIKey)
|
||||
assert.Equal(t, 1000, info.RateLimit)
|
||||
}
|
||||
|
||||
func TestProviderConfig(t *testing.T) {
|
||||
config := ProviderConfig{
|
||||
Type: "test",
|
||||
Endpoint: "https://api.test.com",
|
||||
APIKey: "test-key",
|
||||
DefaultModel: "test-model",
|
||||
Temperature: 0.7,
|
||||
MaxTokens: 4096,
|
||||
Timeout: 300 * time.Second,
|
||||
RetryAttempts: 3,
|
||||
RetryDelay: 2 * time.Second,
|
||||
EnableTools: true,
|
||||
EnableMCP: true,
|
||||
}
|
||||
|
||||
assert.Equal(t, "test", config.Type)
|
||||
assert.Equal(t, float32(0.7), config.Temperature)
|
||||
assert.Equal(t, 4096, config.MaxTokens)
|
||||
assert.Equal(t, 300*time.Second, config.Timeout)
|
||||
assert.True(t, config.EnableTools)
|
||||
}
|
||||
|
||||
func TestRoleConfig(t *testing.T) {
|
||||
roleConfig := RoleConfig{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4",
|
||||
Temperature: 0.3,
|
||||
MaxTokens: 8192,
|
||||
SystemPrompt: "You are a helpful assistant",
|
||||
FallbackProvider: "ollama",
|
||||
FallbackModel: "llama2",
|
||||
EnableTools: true,
|
||||
EnableMCP: false,
|
||||
AllowedTools: []string{"file_ops", "code_analysis"},
|
||||
MCPServers: []string{"file-server"},
|
||||
}
|
||||
|
||||
assert.Equal(t, "openai", roleConfig.Provider)
|
||||
assert.Equal(t, "gpt-4", roleConfig.Model)
|
||||
assert.Equal(t, float32(0.3), roleConfig.Temperature)
|
||||
assert.Len(t, roleConfig.AllowedTools, 2)
|
||||
assert.True(t, roleConfig.EnableTools)
|
||||
assert.False(t, roleConfig.EnableMCP)
|
||||
}
|
||||
|
||||
func TestRoleModelMapping(t *testing.T) {
|
||||
mapping := RoleModelMapping{
|
||||
DefaultProvider: "ollama",
|
||||
FallbackProvider: "openai",
|
||||
Roles: map[string]RoleConfig{
|
||||
"developer": {
|
||||
Provider: "ollama",
|
||||
Model: "codellama",
|
||||
Temperature: 0.3,
|
||||
},
|
||||
"reviewer": {
|
||||
Provider: "openai",
|
||||
Model: "gpt-4",
|
||||
Temperature: 0.2,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, "ollama", mapping.DefaultProvider)
|
||||
assert.Len(t, mapping.Roles, 2)
|
||||
|
||||
devConfig, exists := mapping.Roles["developer"]
|
||||
require.True(t, exists)
|
||||
assert.Equal(t, "codellama", devConfig.Model)
|
||||
assert.Equal(t, float32(0.3), devConfig.Temperature)
|
||||
}
|
||||
|
||||
func TestTokenUsage(t *testing.T) {
|
||||
usage := TokenUsage{
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 200,
|
||||
TotalTokens: 300,
|
||||
}
|
||||
|
||||
assert.Equal(t, 100, usage.PromptTokens)
|
||||
assert.Equal(t, 200, usage.CompletionTokens)
|
||||
assert.Equal(t, 300, usage.TotalTokens)
|
||||
assert.Equal(t, usage.PromptTokens+usage.CompletionTokens, usage.TotalTokens)
|
||||
}
|
||||
|
||||
func TestMockProviderExecuteTask(t *testing.T) {
|
||||
provider := NewMockProvider("test-provider")
|
||||
|
||||
request := &TaskRequest{
|
||||
TaskID: "test-123",
|
||||
AgentID: "agent-456",
|
||||
AgentRole: "developer",
|
||||
Repository: "test/repo",
|
||||
TaskTitle: "Test Task",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
response, err := provider.ExecuteTask(ctx, request)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, response.Success)
|
||||
assert.Equal(t, "test-123", response.TaskID)
|
||||
assert.Equal(t, "agent-456", response.AgentID)
|
||||
assert.Equal(t, "test-provider", response.Provider)
|
||||
assert.NotEmpty(t, response.Response)
|
||||
}
|
||||
|
||||
func TestMockProviderFailure(t *testing.T) {
|
||||
provider := NewMockProvider("failing-provider")
|
||||
provider.shouldFail = true
|
||||
|
||||
request := &TaskRequest{
|
||||
TaskID: "test-123",
|
||||
AgentID: "agent-456",
|
||||
AgentRole: "developer",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := provider.ExecuteTask(ctx, request)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.IsType(t, &ProviderError{}, err)
|
||||
|
||||
providerErr := err.(*ProviderError)
|
||||
assert.Equal(t, "TASK_EXECUTION_FAILED", providerErr.Code)
|
||||
}
|
||||
|
||||
func TestMockProviderCustomExecuteFunc(t *testing.T) {
|
||||
provider := NewMockProvider("custom-provider")
|
||||
|
||||
// Set custom execution function
|
||||
provider.executeFunc = func(ctx context.Context, request *TaskRequest) (*TaskResponse, error) {
|
||||
return &TaskResponse{
|
||||
Success: true,
|
||||
TaskID: request.TaskID,
|
||||
Response: "Custom response: " + request.TaskTitle,
|
||||
Provider: "custom-provider",
|
||||
}, nil
|
||||
}
|
||||
|
||||
request := &TaskRequest{
|
||||
TaskID: "test-123",
|
||||
TaskTitle: "Custom Task",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
response, err := provider.ExecuteTask(ctx, request)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Custom response: Custom Task", response.Response)
|
||||
}
|
||||
|
||||
func TestMockProviderCapabilities(t *testing.T) {
|
||||
provider := NewMockProvider("test-provider")
|
||||
|
||||
capabilities := provider.GetCapabilities()
|
||||
|
||||
assert.True(t, capabilities.SupportsMCP)
|
||||
assert.True(t, capabilities.SupportsTools)
|
||||
assert.Equal(t, 4096, capabilities.MaxTokens)
|
||||
assert.Len(t, capabilities.SupportedModels, 2)
|
||||
assert.Contains(t, capabilities.SupportedModels, "test-model")
|
||||
}
|
||||
|
||||
func TestMockProviderInfo(t *testing.T) {
|
||||
provider := NewMockProvider("test-provider")
|
||||
|
||||
info := provider.GetProviderInfo()
|
||||
|
||||
assert.Equal(t, "test-provider", info.Name)
|
||||
assert.Equal(t, "mock", info.Type)
|
||||
assert.Equal(t, "test-model", info.DefaultModel)
|
||||
assert.False(t, info.RequiresAPIKey)
|
||||
}
|
||||
500
pkg/ai/resetdata.go
Normal file
@@ -0,0 +1,500 @@
package ai

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
)

// ResetDataProvider implements ModelProvider for ResetData LaaS API
type ResetDataProvider struct {
	config     ProviderConfig
	httpClient *http.Client
}

// ResetDataRequest represents a request to ResetData LaaS API
type ResetDataRequest struct {
	Model       string             `json:"model"`
	Messages    []ResetDataMessage `json:"messages"`
	Stream      bool               `json:"stream"`
	Temperature float32            `json:"temperature,omitempty"`
	MaxTokens   int                `json:"max_tokens,omitempty"`
	Stop        []string           `json:"stop,omitempty"`
	TopP        float32            `json:"top_p,omitempty"`
}

// ResetDataMessage represents a message in the ResetData format
type ResetDataMessage struct {
	Role    string `json:"role"` // system, user, assistant
	Content string `json:"content"`
}

// ResetDataResponse represents a response from ResetData LaaS API
type ResetDataResponse struct {
	ID      string            `json:"id"`
	Object  string            `json:"object"`
	Created int64             `json:"created"`
	Model   string            `json:"model"`
	Choices []ResetDataChoice `json:"choices"`
	Usage   ResetDataUsage    `json:"usage"`
}

// ResetDataChoice represents a choice in the response
type ResetDataChoice struct {
	Index        int              `json:"index"`
	Message      ResetDataMessage `json:"message"`
	FinishReason string           `json:"finish_reason"`
}

// ResetDataUsage represents token usage information
type ResetDataUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// ResetDataModelsResponse represents the available-models response
type ResetDataModelsResponse struct {
	Object string           `json:"object"`
	Data   []ResetDataModel `json:"data"`
}

// ResetDataModel represents a model in ResetData
type ResetDataModel struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

// NewResetDataProvider creates a new ResetData provider instance
func NewResetDataProvider(config ProviderConfig) *ResetDataProvider {
	timeout := config.Timeout
	if timeout == 0 {
		timeout = 300 * time.Second // 5 minutes default for task execution
	}

	return &ResetDataProvider{
		config: config,
		httpClient: &http.Client{
			Timeout: timeout,
		},
	}
}

// ExecuteTask implements the ModelProvider interface for ResetData
func (p *ResetDataProvider) ExecuteTask(ctx context.Context, request *TaskRequest) (*TaskResponse, error) {
	startTime := time.Now()

	// Build messages for the chat completion
	messages, err := p.buildChatMessages(request)
	if err != nil {
		return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to build messages: %v", err))
	}

	// Prepare the ResetData request
	resetDataReq := ResetDataRequest{
		Model:       p.selectModel(request.ModelName),
		Messages:    messages,
		Stream:      false,
		Temperature: p.getTemperature(request.Temperature),
		MaxTokens:   p.getMaxTokens(request.MaxTokens),
	}

	// Execute the request
	response, err := p.makeRequest(ctx, "/v1/chat/completions", resetDataReq)
	if err != nil {
		return nil, err
	}

	endTime := time.Now()

	// Process the response
	if len(response.Choices) == 0 {
		return nil, NewProviderError(ErrTaskExecutionFailed, "no response choices returned from ResetData")
	}

	choice := response.Choices[0]
	responseText := choice.Message.Content

	// Parse response for actions and artifacts
	actions, artifacts := p.parseResponseForActions(responseText, request)

	return &TaskResponse{
		Success:   true,
		TaskID:    request.TaskID,
		AgentID:   request.AgentID,
		ModelUsed: response.Model,
		Provider:  "resetdata",
		Response:  responseText,
		Actions:   actions,
		Artifacts: artifacts,
		StartTime: startTime,
		EndTime:   endTime,
		Duration:  endTime.Sub(startTime),
		TokensUsed: TokenUsage{
			PromptTokens:     response.Usage.PromptTokens,
			CompletionTokens: response.Usage.CompletionTokens,
			TotalTokens:      response.Usage.TotalTokens,
		},
	}, nil
}

// GetCapabilities returns ResetData provider capabilities
func (p *ResetDataProvider) GetCapabilities() ProviderCapabilities {
	return ProviderCapabilities{
		SupportsMCP:       p.config.EnableMCP,
		SupportsTools:     p.config.EnableTools,
		SupportsStreaming: true,
		SupportsFunctions: false, // ResetData LaaS doesn't support function calling
		MaxTokens:         p.config.MaxTokens,
		SupportedModels:   p.getSupportedModels(),
		SupportsImages:    false, // Most ResetData models don't support images
		SupportsFiles:     true,
	}
}

// ValidateConfig validates the ResetData provider configuration
func (p *ResetDataProvider) ValidateConfig() error {
	if p.config.APIKey == "" {
		return NewProviderError(ErrAPIKeyRequired, "API key is required for ResetData provider")
	}

	if p.config.Endpoint == "" {
		return NewProviderError(ErrInvalidConfiguration, "endpoint is required for ResetData provider")
	}

	if p.config.DefaultModel == "" {
		return NewProviderError(ErrInvalidConfiguration, "default_model is required for ResetData provider")
	}

	// Test the API connection
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	if err := p.testConnection(ctx); err != nil {
		return NewProviderError(ErrProviderUnavailable, fmt.Sprintf("failed to connect to ResetData: %v", err))
	}

	return nil
}

// GetProviderInfo returns information about the ResetData provider
func (p *ResetDataProvider) GetProviderInfo() ProviderInfo {
	return ProviderInfo{
		Name:           "ResetData",
		Type:           "resetdata",
		Version:        "1.0.0",
		Endpoint:       p.config.Endpoint,
		DefaultModel:   p.config.DefaultModel,
		RequiresAPIKey: true,
		RateLimit:      600, // 10 requests per second typical limit
	}
}

// buildChatMessages constructs messages for the ResetData chat completion
func (p *ResetDataProvider) buildChatMessages(request *TaskRequest) ([]ResetDataMessage, error) {
	var messages []ResetDataMessage

	// System message
	systemPrompt := p.getSystemPrompt(request)
	if systemPrompt != "" {
		messages = append(messages, ResetDataMessage{
			Role:    "system",
			Content: systemPrompt,
		})
	}

	// User message with task details
	userPrompt, err := p.buildTaskPrompt(request)
	if err != nil {
		return nil, err
	}

	messages = append(messages, ResetDataMessage{
		Role:    "user",
		Content: userPrompt,
	})

	return messages, nil
}

// buildTaskPrompt constructs a comprehensive prompt for task execution
func (p *ResetDataProvider) buildTaskPrompt(request *TaskRequest) (string, error) {
	var prompt strings.Builder

	prompt.WriteString(fmt.Sprintf("Acting as a %s agent, analyze and work on this task:\n\n",
		request.AgentRole))

	prompt.WriteString(fmt.Sprintf("**Repository:** %s\n", request.Repository))
	prompt.WriteString(fmt.Sprintf("**Task Title:** %s\n", request.TaskTitle))
	prompt.WriteString(fmt.Sprintf("**Description:**\n%s\n\n", request.TaskDescription))

	if len(request.TaskLabels) > 0 {
		prompt.WriteString(fmt.Sprintf("**Labels:** %s\n", strings.Join(request.TaskLabels, ", ")))
	}

	prompt.WriteString(fmt.Sprintf("**Priority:** %d/10 | **Complexity:** %d/10\n\n",
		request.Priority, request.Complexity))

	if request.WorkingDirectory != "" {
		prompt.WriteString(fmt.Sprintf("**Working Directory:** %s\n", request.WorkingDirectory))
	}

	if len(request.RepositoryFiles) > 0 {
		prompt.WriteString("**Relevant Files:**\n")
		for _, file := range request.RepositoryFiles {
			prompt.WriteString(fmt.Sprintf("- %s\n", file))
		}
		prompt.WriteString("\n")
	}

	// Add role-specific instructions
	prompt.WriteString(p.getRoleSpecificInstructions(request.AgentRole))

	prompt.WriteString("\nProvide a detailed analysis and implementation plan. ")
	prompt.WriteString("Include specific steps, code changes, and any commands that need to be executed. ")
	prompt.WriteString("Focus on delivering actionable results that address the task requirements completely.")

	return prompt.String(), nil
}

// getRoleSpecificInstructions returns instructions specific to the agent role
func (p *ResetDataProvider) getRoleSpecificInstructions(role string) string {
	switch strings.ToLower(role) {
	case "developer":
		return `**Developer Focus Areas:**
- Implement robust, well-tested code solutions
- Follow coding standards and best practices
- Ensure proper error handling and edge case coverage
- Write clear documentation and comments
- Consider performance, security, and maintainability`

	case "reviewer":
		return `**Code Review Focus Areas:**
- Evaluate code quality, style, and best practices
- Identify potential bugs, security issues, and performance bottlenecks
- Check test coverage and test quality
- Verify documentation completeness and accuracy
- Suggest refactoring and improvement opportunities`

	case "architect":
		return `**Architecture Focus Areas:**
- Design scalable and maintainable system components
- Make informed decisions about technologies and patterns
- Define clear interfaces and integration points
- Consider scalability, security, and performance requirements
- Document architectural decisions and trade-offs`

	case "tester":
		return `**Testing Focus Areas:**
- Design comprehensive test strategies and test cases
- Implement automated tests at multiple levels
- Identify edge cases and failure scenarios
- Set up continuous testing and quality assurance
- Validate requirements and acceptance criteria`

	default:
		return `**General Focus Areas:**
- Understand requirements and constraints thoroughly
- Apply software engineering best practices
- Provide clear, actionable recommendations
- Consider long-term maintainability and extensibility`
	}
}

// selectModel chooses the appropriate ResetData model
func (p *ResetDataProvider) selectModel(requestedModel string) string {
	if requestedModel != "" {
		return requestedModel
	}
	return p.config.DefaultModel
}

// getTemperature returns the temperature setting
func (p *ResetDataProvider) getTemperature(requestTemp float32) float32 {
	if requestTemp > 0 {
		return requestTemp
	}
	if p.config.Temperature > 0 {
		return p.config.Temperature
	}
	return 0.7 // Default temperature
}

// getMaxTokens returns the max tokens setting
func (p *ResetDataProvider) getMaxTokens(requestTokens int) int {
	if requestTokens > 0 {
		return requestTokens
	}
	if p.config.MaxTokens > 0 {
		return p.config.MaxTokens
	}
	return 4096 // Default max tokens
}

// getSystemPrompt constructs the system prompt
func (p *ResetDataProvider) getSystemPrompt(request *TaskRequest) string {
	if request.SystemPrompt != "" {
		return request.SystemPrompt
	}

	return fmt.Sprintf(`You are an expert software development AI assistant working as a %s agent
in the CHORUS autonomous development system.

Your expertise includes:
- Software architecture and design patterns
- Code implementation across multiple programming languages
- Testing strategies and quality assurance
- DevOps and deployment practices
- Security and performance optimization

Provide detailed, practical solutions with specific implementation steps.
Focus on delivering high-quality, production-ready results.`, request.AgentRole)
}

// makeRequest makes an HTTP request to the ResetData API
func (p *ResetDataProvider) makeRequest(ctx context.Context, endpoint string, request interface{}) (*ResetDataResponse, error) {
	requestJSON, err := json.Marshal(request)
	if err != nil {
		return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to marshal request: %v", err))
	}

	url := strings.TrimSuffix(p.config.Endpoint, "/") + endpoint
	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(requestJSON))
	if err != nil {
		return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to create request: %v", err))
	}

	// Set required headers
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+p.config.APIKey)

	// Add custom headers if configured
	for key, value := range p.config.CustomHeaders {
		req.Header.Set(key, value)
	}

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, NewProviderError(ErrProviderUnavailable, fmt.Sprintf("request failed: %v", err))
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to read response: %v", err))
	}

	if resp.StatusCode != http.StatusOK {
		return nil, p.handleHTTPError(resp.StatusCode, body)
	}

	var resetDataResp ResetDataResponse
	if err := json.Unmarshal(body, &resetDataResp); err != nil {
		return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to parse response: %v", err))
	}

	return &resetDataResp, nil
}

// testConnection tests the connection to the ResetData API
func (p *ResetDataProvider) testConnection(ctx context.Context) error {
	url := strings.TrimSuffix(p.config.Endpoint, "/") + "/v1/models"
	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
	if err != nil {
		return err
	}

	req.Header.Set("Authorization", "Bearer "+p.config.APIKey)

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("API test failed with status %d: %s", resp.StatusCode, string(body))
	}

	return nil
}

// getSupportedModels returns a list of supported ResetData models
func (p *ResetDataProvider) getSupportedModels() []string {
	// Common models available through ResetData LaaS
	return []string{
		"llama3.1:8b", "llama3.1:70b",
		"mistral:7b", "mixtral:8x7b",
		"qwen2:7b", "qwen2:72b",
		"gemma:7b", "gemma2:9b",
		"codellama:7b", "codellama:13b",
	}
}

// handleHTTPError converts HTTP errors to provider errors
func (p *ResetDataProvider) handleHTTPError(statusCode int, body []byte) *ProviderError {
	bodyStr := string(body)

	switch statusCode {
	case http.StatusUnauthorized:
		return &ProviderError{
			Code:      "UNAUTHORIZED",
			Message:   "Invalid ResetData API key",
			Details:   bodyStr,
			Retryable: false,
		}
	case http.StatusTooManyRequests:
		return &ProviderError{
			Code:      "RATE_LIMIT_EXCEEDED",
			Message:   "ResetData API rate limit exceeded",
			Details:   bodyStr,
			Retryable: true,
		}
	case http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable:
		return &ProviderError{
			Code:      "SERVICE_UNAVAILABLE",
			Message:   "ResetData API service unavailable",
			Details:   bodyStr,
			Retryable: true,
		}
	default:
		return &ProviderError{
			Code:      "API_ERROR",
			Message:   fmt.Sprintf("ResetData API error (status %d)", statusCode),
			Details:   bodyStr,
			Retryable: true,
		}
	}
}

// parseResponseForActions extracts actions from the response text
func (p *ResetDataProvider) parseResponseForActions(response string, request *TaskRequest) ([]TaskAction, []Artifact) {
	var actions []TaskAction
	var artifacts []Artifact

	// Create a basic task analysis action
	action := TaskAction{
		Type:      "task_analysis",
		Target:    request.TaskTitle,
		Content:   response,
		Result:    "Task analyzed by ResetData model",
		Success:   true,
		Timestamp: time.Now(),
		Metadata: map[string]interface{}{
			"agent_role": request.AgentRole,
			"repository": request.Repository,
			"model":      p.config.DefaultModel,
		},
	}
	actions = append(actions, action)

	return actions, artifacts
}
353
pkg/bootstrap/pool_manager.go
Normal file
@@ -0,0 +1,353 @@
package bootstrap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/host"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/multiformats/go-multiaddr"
|
||||
)
|
||||
|
||||
// BootstrapPool manages a pool of bootstrap peers for DHT joining
|
||||
type BootstrapPool struct {
|
||||
peers []peer.AddrInfo
|
||||
dialsPerSecond int
|
||||
maxConcurrent int
|
||||
staggerDelay time.Duration
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// BootstrapConfig represents the JSON configuration for bootstrap peers
|
||||
type BootstrapConfig struct {
|
||||
Peers []BootstrapPeer `json:"peers"`
|
||||
Meta BootstrapMeta `json:"meta,omitempty"`
|
||||
}
|
||||
|
||||
// BootstrapPeer represents a single bootstrap peer
|
||||
type BootstrapPeer struct {
|
||||
ID string `json:"id"` // Peer ID
|
||||
Addresses []string `json:"addresses"` // Multiaddresses
|
||||
Priority int `json:"priority"` // Priority (higher = more likely to be selected)
|
||||
Healthy bool `json:"healthy"` // Health status
|
||||
LastSeen string `json:"last_seen"` // Last seen timestamp
|
||||
}
|
||||
|
||||
// BootstrapMeta contains metadata about the bootstrap configuration
|
||||
type BootstrapMeta struct {
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
Version int `json:"version"`
|
||||
ClusterID string `json:"cluster_id"`
|
||||
TotalPeers int `json:"total_peers"`
|
||||
HealthyPeers int `json:"healthy_peers"`
|
||||
}
|
||||
|
||||
// BootstrapSubset represents a subset of peers assigned to a replica
|
||||
type BootstrapSubset struct {
|
||||
Peers []peer.AddrInfo `json:"peers"`
|
||||
StaggerDelayMS int `json:"stagger_delay_ms"`
|
||||
AssignedAt time.Time `json:"assigned_at"`
|
||||
}
|
||||
|
||||
// NewBootstrapPool creates a new bootstrap pool manager
|
||||
func NewBootstrapPool(dialsPerSecond, maxConcurrent int, staggerMS int) *BootstrapPool {
|
||||
return &BootstrapPool{
|
||||
peers: []peer.AddrInfo{},
|
||||
dialsPerSecond: dialsPerSecond,
|
||||
maxConcurrent: maxConcurrent,
|
||||
staggerDelay: time.Duration(staggerMS) * time.Millisecond,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// LoadFromFile loads bootstrap configuration from a JSON file
|
||||
func (bp *BootstrapPool) LoadFromFile(filePath string) error {
|
||||
if filePath == "" {
|
||||
return nil // No file configured
|
||||
}
|
||||
|
||||
data, err := ioutil.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read bootstrap file %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
return bp.loadFromJSON(data)
|
||||
}
|
||||
|
||||
// LoadFromURL loads bootstrap configuration from a URL (WHOOSH endpoint)
func (bp *BootstrapPool) LoadFromURL(ctx context.Context, url string) error {
	if url == "" {
		return nil // No URL configured
	}

	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
	if err != nil {
		return fmt.Errorf("failed to create bootstrap request: %w", err)
	}

	resp, err := bp.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("bootstrap request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("bootstrap request failed with status %d", resp.StatusCode)
	}

	data, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("failed to read bootstrap response: %w", err)
	}

	return bp.loadFromJSON(data)
}

// loadFromJSON parses JSON bootstrap configuration
func (bp *BootstrapPool) loadFromJSON(data []byte) error {
	var config BootstrapConfig
	if err := json.Unmarshal(data, &config); err != nil {
		return fmt.Errorf("failed to parse bootstrap JSON: %w", err)
	}

	// Convert bootstrap peers to AddrInfo
	var peers []peer.AddrInfo
	for _, bsPeer := range config.Peers {
		// Only include healthy peers
		if !bsPeer.Healthy {
			continue
		}

		// Parse peer ID
		peerID, err := peer.Decode(bsPeer.ID)
		if err != nil {
			fmt.Printf("⚠️ Invalid peer ID %s: %v\n", bsPeer.ID, err)
			continue
		}

		// Parse multiaddresses
		var addrs []multiaddr.Multiaddr
		for _, addrStr := range bsPeer.Addresses {
			addr, err := multiaddr.NewMultiaddr(addrStr)
			if err != nil {
				fmt.Printf("⚠️ Invalid multiaddress %s: %v\n", addrStr, err)
				continue
			}
			addrs = append(addrs, addr)
		}

		if len(addrs) > 0 {
			peers = append(peers, peer.AddrInfo{
				ID:    peerID,
				Addrs: addrs,
			})
		}
	}

	bp.peers = peers
	fmt.Printf("📋 Loaded %d healthy bootstrap peers from configuration\n", len(peers))

	return nil
}

// LoadFromEnvironment loads bootstrap configuration from environment variables
func (bp *BootstrapPool) LoadFromEnvironment() error {
	// Try loading from file first
	if bootstrapFile := os.Getenv("BOOTSTRAP_JSON"); bootstrapFile != "" {
		if err := bp.LoadFromFile(bootstrapFile); err != nil {
			fmt.Printf("⚠️ Failed to load bootstrap from file: %v\n", err)
		} else {
			return nil // Successfully loaded from file
		}
	}

	// Try loading from URL
	if bootstrapURL := os.Getenv("BOOTSTRAP_URL"); bootstrapURL != "" {
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()

		if err := bp.LoadFromURL(ctx, bootstrapURL); err != nil {
			fmt.Printf("⚠️ Failed to load bootstrap from URL: %v\n", err)
		} else {
			return nil // Successfully loaded from URL
		}
	}

	// Fall back to the legacy environment variable
	if bootstrapPeersEnv := os.Getenv("CHORUS_BOOTSTRAP_PEERS"); bootstrapPeersEnv != "" {
		return bp.loadFromLegacyEnv(bootstrapPeersEnv)
	}

	return nil // No bootstrap configuration found
}

// loadFromLegacyEnv loads from a comma-separated multiaddress list
func (bp *BootstrapPool) loadFromLegacyEnv(peersEnv string) error {
	peerStrs := strings.Split(peersEnv, ",")
	var peers []peer.AddrInfo

	for _, peerStr := range peerStrs {
		peerStr = strings.TrimSpace(peerStr)
		if peerStr == "" {
			continue
		}

		// Parse multiaddress
		addr, err := multiaddr.NewMultiaddr(peerStr)
		if err != nil {
			fmt.Printf("⚠️ Invalid bootstrap peer %s: %v\n", peerStr, err)
			continue
		}

		// Extract peer info
		info, err := peer.AddrInfoFromP2pAddr(addr)
		if err != nil {
			fmt.Printf("⚠️ Failed to parse peer info from %s: %v\n", peerStr, err)
			continue
		}

		peers = append(peers, *info)
	}

	bp.peers = peers
	fmt.Printf("📋 Loaded %d bootstrap peers from legacy environment\n", len(peers))

	return nil
}

// GetSubset returns a subset of bootstrap peers for a replica
func (bp *BootstrapPool) GetSubset(count int) BootstrapSubset {
	if len(bp.peers) == 0 {
		return BootstrapSubset{
			Peers:          []peer.AddrInfo{},
			StaggerDelayMS: 0,
			AssignedAt:     time.Now(),
		}
	}

	// Ensure count doesn't exceed available peers
	if count > len(bp.peers) {
		count = len(bp.peers)
	}

	// Randomly select peers from the pool
	selectedPeers := make([]peer.AddrInfo, 0, count)
	indices := rand.Perm(len(bp.peers))

	for i := 0; i < count; i++ {
		selectedPeers = append(selectedPeers, bp.peers[indices[i]])
	}

	// Generate a random stagger delay (0 to configured max)
	staggerMS := 0
	if bp.staggerDelay > 0 {
		staggerMS = rand.Intn(int(bp.staggerDelay.Milliseconds()))
	}

	return BootstrapSubset{
		Peers:          selectedPeers,
		StaggerDelayMS: staggerMS,
		AssignedAt:     time.Now(),
	}
}

// ConnectWithRateLimit connects to bootstrap peers with rate limiting
func (bp *BootstrapPool) ConnectWithRateLimit(ctx context.Context, h host.Host, subset BootstrapSubset) error {
	if len(subset.Peers) == 0 {
		return nil // No peers to connect to
	}

	// Apply stagger delay
	if subset.StaggerDelayMS > 0 {
		delay := time.Duration(subset.StaggerDelayMS) * time.Millisecond
		fmt.Printf("⏱️ Applying join stagger delay: %v\n", delay)

		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(delay):
			// Continue after delay
		}
	}

	// Create rate limiter for dials (guard against a zero or negative rate,
	// which would otherwise panic in NewTicker)
	dialsPerSecond := bp.dialsPerSecond
	if dialsPerSecond <= 0 {
		dialsPerSecond = 1
	}
	ticker := time.NewTicker(time.Second / time.Duration(dialsPerSecond))
	defer ticker.Stop()

	// Semaphore limiting concurrent dials
	semaphore := make(chan struct{}, bp.maxConcurrent)

	// Connect to each peer with rate limiting
	for i, peerInfo := range subset.Peers {
		// Wait for the rate limiter
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
		}

		// Acquire a semaphore slot
		select {
		case <-ctx.Done():
			return ctx.Err()
		case semaphore <- struct{}{}:
		}

		// Connect to the peer in a goroutine
		go func(info peer.AddrInfo, index int) {
			defer func() { <-semaphore }() // Release the semaphore slot

			dialCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
			defer cancel()

			if err := h.Connect(dialCtx, info); err != nil {
				fmt.Printf("⚠️ Failed to connect to bootstrap peer %s (%d/%d): %v\n",
					info.ID.ShortString(), index+1, len(subset.Peers), err)
			} else {
				fmt.Printf("🔗 Connected to bootstrap peer %s (%d/%d)\n",
					info.ID.ShortString(), index+1, len(subset.Peers))
			}
		}(peerInfo, i)
	}

	// Wait for all in-flight dials to finish: once every semaphore slot has
	// been re-acquired, no dial goroutine can still be running
	for i := 0; i < bp.maxConcurrent; i++ {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case semaphore <- struct{}{}:
		}
	}

	return nil
}

// GetPeerCount returns the number of available bootstrap peers
func (bp *BootstrapPool) GetPeerCount() int {
	return len(bp.peers)
}

// GetPeers returns all bootstrap peers (for debugging)
func (bp *BootstrapPool) GetPeers() []peer.AddrInfo {
	return bp.peers
}

// GetStats returns bootstrap pool statistics
func (bp *BootstrapPool) GetStats() map[string]interface{} {
	return map[string]interface{}{
		"peer_count":       len(bp.peers),
		"dials_per_second": bp.dialsPerSecond,
		"max_concurrent":   bp.maxConcurrent,
		"stagger_delay_ms": bp.staggerDelay.Milliseconds(),
	}
}
517
pkg/config/assignment.go
Normal file
@@ -0,0 +1,517 @@
package config

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"os/signal"
	"strings"
	"sync"
	"syscall"
	"time"
)

// RuntimeConfig manages runtime configuration with assignment overrides
type RuntimeConfig struct {
	Base     *Config           `json:"base"`
	Override *AssignmentConfig `json:"override"`
	mu       sync.RWMutex
	reloadCh chan struct{}
}

// AssignmentConfig represents a runtime assignment from WHOOSH
type AssignmentConfig struct {
	// Assignment metadata
	AssignmentID string    `json:"assignment_id"`
	TaskSlot     string    `json:"task_slot"`
	TaskID       string    `json:"task_id"`
	ClusterID    string    `json:"cluster_id"`
	AssignedAt   time.Time `json:"assigned_at"`
	ExpiresAt    time.Time `json:"expires_at,omitempty"`

	// Agent configuration overrides
	Agent   *AgentConfig   `json:"agent,omitempty"`
	Network *NetworkConfig `json:"network,omitempty"`
	AI      *AIConfig      `json:"ai,omitempty"`
	Logging *LoggingConfig `json:"logging,omitempty"`

	// Bootstrap configuration for scaling
	BootstrapPeers []string `json:"bootstrap_peers,omitempty"`
	JoinStagger    int      `json:"join_stagger_ms,omitempty"`

	// Runtime capabilities
	RuntimeCapabilities []string `json:"runtime_capabilities,omitempty"`

	// Key derivation for encryption
	RoleKey       string `json:"role_key,omitempty"`
	ClusterSecret string `json:"cluster_secret,omitempty"`

	// Custom fields
	Custom map[string]interface{} `json:"custom,omitempty"`
}

// AssignmentRequest represents a request for assignment from WHOOSH
type AssignmentRequest struct {
	ClusterID string    `json:"cluster_id"`
	TaskSlot  string    `json:"task_slot,omitempty"`
	TaskID    string    `json:"task_id,omitempty"`
	AgentID   string    `json:"agent_id"`
	NodeID    string    `json:"node_id"`
	Timestamp time.Time `json:"timestamp"`
}

// NewRuntimeConfig creates a new runtime configuration manager
func NewRuntimeConfig(baseConfig *Config) *RuntimeConfig {
	return &RuntimeConfig{
		Base:     baseConfig,
		Override: nil,
		reloadCh: make(chan struct{}, 1),
	}
}

// Get returns the effective configuration value, with the override taking precedence
func (rc *RuntimeConfig) Get(field string) interface{} {
	rc.mu.RLock()
	defer rc.mu.RUnlock()

	// Try the override first
	if rc.Override != nil {
		if value := rc.getFromAssignment(field); value != nil {
			return value
		}
	}

	// Fall back to the base configuration
	return rc.getFromBase(field)
}

// GetConfig returns a merged configuration with overrides applied
func (rc *RuntimeConfig) GetConfig() *Config {
	rc.mu.RLock()
	defer rc.mu.RUnlock()

	if rc.Override == nil {
		return rc.Base
	}

	// Create a copy of the base config
	merged := *rc.Base

	// Apply overrides
	if rc.Override.Agent != nil {
		rc.mergeAgentConfig(&merged.Agent, rc.Override.Agent)
	}
	if rc.Override.Network != nil {
		rc.mergeNetworkConfig(&merged.Network, rc.Override.Network)
	}
	if rc.Override.AI != nil {
		rc.mergeAIConfig(&merged.AI, rc.Override.AI)
	}
	if rc.Override.Logging != nil {
		rc.mergeLoggingConfig(&merged.Logging, rc.Override.Logging)
	}

	return &merged
}

// LoadAssignment fetches an assignment from WHOOSH and applies it
func (rc *RuntimeConfig) LoadAssignment(ctx context.Context, assignURL string) error {
	if assignURL == "" {
		return nil // No assignment URL configured
	}

	// Build the assignment request
	agentID := rc.Base.Agent.ID
	if agentID == "" {
		agentID = "unknown"
	}

	req := AssignmentRequest{
		ClusterID: rc.Base.License.ClusterID,
		TaskSlot:  os.Getenv("TASK_SLOT"),
		TaskID:    os.Getenv("TASK_ID"),
		AgentID:   agentID,
		NodeID:    os.Getenv("NODE_ID"),
		Timestamp: time.Now(),
	}

	// Make the HTTP request to WHOOSH
	assignment, err := rc.fetchAssignment(ctx, assignURL, req)
	if err != nil {
		return fmt.Errorf("failed to fetch assignment: %w", err)
	}

	// Apply the assignment
	rc.mu.Lock()
	rc.Override = assignment
	rc.mu.Unlock()

	return nil
}

// StartReloadHandler starts a signal handler for SIGHUP configuration reloads
func (rc *RuntimeConfig) StartReloadHandler(ctx context.Context, assignURL string) {
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGHUP)

	go func() {
		reload := func() {
			if err := rc.LoadAssignment(ctx, assignURL); err != nil {
				fmt.Printf("❌ Failed to reload assignment: %v\n", err)
			} else {
				fmt.Println("✅ Assignment configuration reloaded successfully")
			}
		}

		for {
			select {
			case <-ctx.Done():
				return
			case <-sigCh:
				fmt.Println("📡 Received SIGHUP, reloading assignment configuration...")
				reload()
			case <-rc.reloadCh:
				// Manual reload trigger
				reload()
			}
		}
	}()
}

// Reload triggers a manual configuration reload
func (rc *RuntimeConfig) Reload() {
	select {
	case rc.reloadCh <- struct{}{}:
	default:
		// Channel full, reload already pending
	}
}

// fetchAssignment makes an HTTP request to the WHOOSH assignment API
func (rc *RuntimeConfig) fetchAssignment(ctx context.Context, assignURL string, req AssignmentRequest) (*AssignmentConfig, error) {
	// Build query parameters
	queryParams := fmt.Sprintf("?cluster_id=%s&agent_id=%s&node_id=%s",
		req.ClusterID, req.AgentID, req.NodeID)

	if req.TaskSlot != "" {
		queryParams += "&task_slot=" + req.TaskSlot
	}
	if req.TaskID != "" {
		queryParams += "&task_id=" + req.TaskID
	}

	// Create the HTTP request
	httpReq, err := http.NewRequestWithContext(ctx, "GET", assignURL+queryParams, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create assignment request: %w", err)
	}

	httpReq.Header.Set("Accept", "application/json")
	httpReq.Header.Set("User-Agent", "CHORUS-Agent/0.1.0")

	// Make the request with a timeout
	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("assignment request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusNotFound {
		// No assignment available
		return nil, nil
	}

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("assignment request failed with status %d: %s", resp.StatusCode, string(body))
	}

	// Parse the assignment response
	var assignment AssignmentConfig
	if err := json.NewDecoder(resp.Body).Decode(&assignment); err != nil {
		return nil, fmt.Errorf("failed to decode assignment response: %w", err)
	}

	return &assignment, nil
}

// Helper methods for getting values from different sources
func (rc *RuntimeConfig) getFromAssignment(field string) interface{} {
	if rc.Override == nil {
		return nil
	}

	// Simple field mapping - a fuller implementation would use reflection
	// or a more sophisticated field mapping system
	switch field {
	case "agent.id":
		if rc.Override.Agent != nil && rc.Override.Agent.ID != "" {
			return rc.Override.Agent.ID
		}
	case "agent.role":
		if rc.Override.Agent != nil && rc.Override.Agent.Role != "" {
			return rc.Override.Agent.Role
		}
	case "agent.capabilities":
		if len(rc.Override.RuntimeCapabilities) > 0 {
			return rc.Override.RuntimeCapabilities
		}
	case "bootstrap_peers":
		if len(rc.Override.BootstrapPeers) > 0 {
			return rc.Override.BootstrapPeers
		}
	case "join_stagger":
		if rc.Override.JoinStagger > 0 {
			return rc.Override.JoinStagger
		}
	}

	// Check custom fields
	if rc.Override.Custom != nil {
		if val, exists := rc.Override.Custom[field]; exists {
			return val
		}
	}

	return nil
}

func (rc *RuntimeConfig) getFromBase(field string) interface{} {
	// Simple field mapping for the base config
	switch field {
	case "agent.id":
		return rc.Base.Agent.ID
	case "agent.role":
		return rc.Base.Agent.Role
	case "agent.capabilities":
		return rc.Base.Agent.Capabilities
	default:
		return nil
	}
}

// Helper methods for merging configuration sections
func (rc *RuntimeConfig) mergeAgentConfig(base *AgentConfig, override *AgentConfig) {
	if override.ID != "" {
		base.ID = override.ID
	}
	if override.Specialization != "" {
		base.Specialization = override.Specialization
	}
	if override.MaxTasks > 0 {
		base.MaxTasks = override.MaxTasks
	}
	if len(override.Capabilities) > 0 {
		base.Capabilities = override.Capabilities
	}
	if len(override.Models) > 0 {
		base.Models = override.Models
	}
	if override.Role != "" {
		base.Role = override.Role
	}
	if override.Project != "" {
		base.Project = override.Project
	}
	if len(override.Expertise) > 0 {
		base.Expertise = override.Expertise
	}
	if override.ReportsTo != "" {
		base.ReportsTo = override.ReportsTo
	}
	if len(override.Deliverables) > 0 {
		base.Deliverables = override.Deliverables
	}
	if override.ModelSelectionWebhook != "" {
		base.ModelSelectionWebhook = override.ModelSelectionWebhook
	}
	if override.DefaultReasoningModel != "" {
		base.DefaultReasoningModel = override.DefaultReasoningModel
	}
}

func (rc *RuntimeConfig) mergeNetworkConfig(base *NetworkConfig, override *NetworkConfig) {
	if override.P2PPort > 0 {
		base.P2PPort = override.P2PPort
	}
	if override.APIPort > 0 {
		base.APIPort = override.APIPort
	}
	if override.HealthPort > 0 {
		base.HealthPort = override.HealthPort
	}
	if override.BindAddr != "" {
		base.BindAddr = override.BindAddr
	}
}

func (rc *RuntimeConfig) mergeAIConfig(base *AIConfig, override *AIConfig) {
	if override.Provider != "" {
		base.Provider = override.Provider
	}
	// Merge Ollama config if present
	if override.Ollama.Endpoint != "" {
		base.Ollama.Endpoint = override.Ollama.Endpoint
	}
	if override.Ollama.Timeout > 0 {
		base.Ollama.Timeout = override.Ollama.Timeout
	}
	// Merge ResetData config if present
	if override.ResetData.BaseURL != "" {
		base.ResetData.BaseURL = override.ResetData.BaseURL
	}
}

func (rc *RuntimeConfig) mergeLoggingConfig(base *LoggingConfig, override *LoggingConfig) {
	if override.Level != "" {
		base.Level = override.Level
	}
	if override.Format != "" {
		base.Format = override.Format
	}
}

// BootstrapConfig represents JSON bootstrap configuration
type BootstrapConfig struct {
	Peers    []BootstrapPeer `json:"peers"`
	Metadata BootstrapMeta   `json:"metadata,omitempty"`
}

// BootstrapPeer represents a single bootstrap peer
type BootstrapPeer struct {
	Address  string   `json:"address"`
	Priority int      `json:"priority,omitempty"`
	Region   string   `json:"region,omitempty"`
	Roles    []string `json:"roles,omitempty"`
	Enabled  bool     `json:"enabled"`
}

// BootstrapMeta contains metadata about the bootstrap configuration
type BootstrapMeta struct {
	GeneratedAt time.Time `json:"generated_at,omitempty"`
	ClusterID   string    `json:"cluster_id,omitempty"`
	Version     string    `json:"version,omitempty"`
	Notes       string    `json:"notes,omitempty"`
}

// GetBootstrapPeers returns bootstrap peers with assignment-override and JSON-config support
func (rc *RuntimeConfig) GetBootstrapPeers() []string {
	rc.mu.RLock()
	defer rc.mu.RUnlock()

	// First priority: assignment override from WHOOSH
	if rc.Override != nil && len(rc.Override.BootstrapPeers) > 0 {
		return rc.Override.BootstrapPeers
	}

	// Second priority: JSON bootstrap configuration
	if jsonPeers := rc.loadBootstrapJSON(); len(jsonPeers) > 0 {
		return jsonPeers
	}

	// Third priority: environment variable (CSV format)
	if bootstrapEnv := os.Getenv("CHORUS_BOOTSTRAP_PEERS"); bootstrapEnv != "" {
		peers := strings.Split(bootstrapEnv, ",")
		// Trim whitespace from each peer
		for i, peer := range peers {
			peers[i] = strings.TrimSpace(peer)
		}
		return peers
	}

	return []string{}
}

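The three-level precedence in `GetBootstrapPeers` can be isolated for testing. The sketch below is a stdlib-only restatement (the name `resolvePeers` and its parameters are illustrative, not CHORUS APIs): an assignment override wins over the JSON config, which wins over the CSV environment variable.

```go
package main

import (
	"fmt"
	"strings"
)

// resolvePeers applies the bootstrap-peer precedence: WHOOSH assignment
// override, then JSON-config peers, then a comma-separated env value.
func resolvePeers(override, fromJSON []string, csvEnv string) []string {
	if len(override) > 0 {
		return override
	}
	if len(fromJSON) > 0 {
		return fromJSON
	}
	if csvEnv != "" {
		parts := strings.Split(csvEnv, ",")
		for i, p := range parts {
			parts[i] = strings.TrimSpace(p)
		}
		return parts
	}
	return []string{}
}

func main() {
	// No override or JSON config: fall through to the CSV env value.
	peers := resolvePeers(nil, nil, " /ip4/10.0.0.1/tcp/9000 , /ip4/10.0.0.2/tcp/9000")
	fmt.Println(peers) // prints: [/ip4/10.0.0.1/tcp/9000 /ip4/10.0.0.2/tcp/9000]
}
```

Separating the precedence logic from the I/O (HTTP fetch, file read, `os.Getenv`) is what makes it checkable without a running WHOOSH endpoint.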
// loadBootstrapJSON loads bootstrap peers from a JSON file
func (rc *RuntimeConfig) loadBootstrapJSON() []string {
	jsonPath := os.Getenv("BOOTSTRAP_JSON")
	if jsonPath == "" {
		return nil
	}

	// Check if the file exists
	if _, err := os.Stat(jsonPath); os.IsNotExist(err) {
		return nil
	}

	// Read and parse the JSON file
	data, err := os.ReadFile(jsonPath)
	if err != nil {
		fmt.Printf("⚠️ Failed to read bootstrap JSON file %s: %v\n", jsonPath, err)
		return nil
	}

	var config BootstrapConfig
	if err := json.Unmarshal(data, &config); err != nil {
		fmt.Printf("⚠️ Failed to parse bootstrap JSON file %s: %v\n", jsonPath, err)
		return nil
	}

	// Extract enabled peer addresses, sorted by priority
	var peers []string
	enabledPeers := make([]BootstrapPeer, 0, len(config.Peers))

	// Filter enabled peers
	for _, peer := range config.Peers {
		if peer.Enabled && peer.Address != "" {
			enabledPeers = append(enabledPeers, peer)
		}
	}

	// Sort by priority, higher priority first (simple exchange sort)
	for i := 0; i < len(enabledPeers)-1; i++ {
		for j := i + 1; j < len(enabledPeers); j++ {
			if enabledPeers[j].Priority > enabledPeers[i].Priority {
				enabledPeers[i], enabledPeers[j] = enabledPeers[j], enabledPeers[i]
			}
		}
	}

	// Extract addresses
	for _, peer := range enabledPeers {
		peers = append(peers, peer.Address)
	}

	if len(peers) > 0 {
		fmt.Printf("📋 Loaded %d bootstrap peers from JSON: %s\n", len(peers), jsonPath)
	}

	return peers
}

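The filter-then-sort step of `loadBootstrapJSON` can be written more idiomatically with `sort.SliceStable`. The sketch below is a stdlib-only standalone version (the struct mirrors the JSON config type, and `selectAddresses` is an illustrative name, not a CHORUS function): keep enabled peers that have an address, then order by descending priority.

```go
package main

import (
	"fmt"
	"sort"
)

// BootstrapPeer mirrors the JSON bootstrap config type used in the diff.
type BootstrapPeer struct {
	Address  string
	Priority int
	Enabled  bool
}

// selectAddresses filters enabled peers with non-empty addresses and
// returns their addresses sorted by descending priority.
func selectAddresses(peers []BootstrapPeer) []string {
	enabled := make([]BootstrapPeer, 0, len(peers))
	for _, p := range peers {
		if p.Enabled && p.Address != "" {
			enabled = append(enabled, p)
		}
	}
	// Stable sort preserves file order among equal priorities.
	sort.SliceStable(enabled, func(i, j int) bool {
		return enabled[i].Priority > enabled[j].Priority
	})
	addrs := make([]string, 0, len(enabled))
	for _, p := range enabled {
		addrs = append(addrs, p.Address)
	}
	return addrs
}

func main() {
	peers := []BootstrapPeer{
		{Address: "/dns4/a/tcp/9000", Priority: 1, Enabled: true},
		{Address: "/dns4/b/tcp/9000", Priority: 5, Enabled: true},
		{Address: "/dns4/c/tcp/9000", Priority: 3, Enabled: false}, // filtered out
	}
	fmt.Println(selectAddresses(peers)) // prints: [/dns4/b/tcp/9000 /dns4/a/tcp/9000]
}
```

Unlike the nested-loop exchange sort in the diff, `sort.SliceStable` is O(n log n) and keeps the original order of peers with equal priority.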
// GetJoinStagger returns the join stagger delay with assignment-override support
func (rc *RuntimeConfig) GetJoinStagger() time.Duration {
	rc.mu.RLock()
	defer rc.mu.RUnlock()

	if rc.Override != nil && rc.Override.JoinStagger > 0 {
		return time.Duration(rc.Override.JoinStagger) * time.Millisecond
	}

	// Fall back to the environment variable
	if staggerEnv := os.Getenv("CHORUS_JOIN_STAGGER_MS"); staggerEnv != "" {
		if ms, err := time.ParseDuration(staggerEnv + "ms"); err == nil {
			return ms
		}
	}

	return 0
}

// GetAssignmentInfo returns the current assignment metadata
func (rc *RuntimeConfig) GetAssignmentInfo() *AssignmentConfig {
	rc.mu.RLock()
	defer rc.mu.RUnlock()

	if rc.Override == nil {
		return nil
	}

	// Return a copy to prevent external modification
	assignment := *rc.Override
	return &assignment
}

@@ -28,17 +28,18 @@ type Config struct {
 
 // AgentConfig defines agent-specific settings
 type AgentConfig struct {
-	ID                    string   `yaml:"id"`
-	Specialization        string   `yaml:"specialization"`
-	MaxTasks              int      `yaml:"max_tasks"`
-	Capabilities          []string `yaml:"capabilities"`
-	Models                []string `yaml:"models"`
-	Role                  string   `yaml:"role"`
-	Expertise             []string `yaml:"expertise"`
-	ReportsTo             string   `yaml:"reports_to"`
-	Deliverables          []string `yaml:"deliverables"`
-	ModelSelectionWebhook string   `yaml:"model_selection_webhook"`
-	DefaultReasoningModel string   `yaml:"default_reasoning_model"`
+	ID                    string   `yaml:"id"`
+	Specialization        string   `yaml:"specialization"`
+	MaxTasks              int      `yaml:"max_tasks"`
+	Capabilities          []string `yaml:"capabilities"`
+	Models                []string `yaml:"models"`
+	Role                  string   `yaml:"role"`
+	Project               string   `yaml:"project"`
+	Expertise             []string `yaml:"expertise"`
+	ReportsTo             string   `yaml:"reports_to"`
+	Deliverables          []string `yaml:"deliverables"`
+	ModelSelectionWebhook string   `yaml:"model_selection_webhook"`
+	DefaultReasoningModel string   `yaml:"default_reasoning_model"`
 }
 
 // NetworkConfig defines network and API settings
@@ -65,9 +66,9 @@ type LicenseConfig struct {
 
 // AIConfig defines AI service settings
 type AIConfig struct {
-	Provider  string          `yaml:"provider"`
-	Ollama    OllamaConfig    `yaml:"ollama"`
-	ResetData ResetDataConfig `yaml:"resetdata"`
+	Provider  string          `yaml:"provider"`
+	Ollama    OllamaConfig    `yaml:"ollama"`
+	ResetData ResetDataConfig `yaml:"resetdata"`
 }
 
 // OllamaConfig defines Ollama-specific settings
@@ -78,10 +79,10 @@ type OllamaConfig struct {
 
 // ResetDataConfig defines ResetData LLM service settings
 type ResetDataConfig struct {
-	BaseURL string        `yaml:"base_url"`
-	APIKey  string        `yaml:"api_key"`
-	Model   string        `yaml:"model"`
-	Timeout time.Duration `yaml:"timeout"`
+	BaseURL string        `yaml:"base_url"`
+	APIKey  string        `yaml:"api_key"`
+	Model   string        `yaml:"model"`
+	Timeout time.Duration `yaml:"timeout"`
 }
 
 // LoggingConfig defines logging settings
@@ -99,13 +100,14 @@ type V2Config struct {
 type DHTConfig struct {
 	Enabled        bool     `yaml:"enabled"`
 	BootstrapPeers []string `yaml:"bootstrap_peers"`
 	MDNSEnabled    bool     `yaml:"mdns_enabled"`
 }
 
 // UCXLConfig defines UCXL protocol settings
 type UCXLConfig struct {
-	Enabled bool          `yaml:"enabled"`
-	Server  ServerConfig  `yaml:"server"`
-	Storage StorageConfig `yaml:"storage"`
+	Enabled    bool             `yaml:"enabled"`
+	Server     ServerConfig     `yaml:"server"`
+	Storage    StorageConfig    `yaml:"storage"`
+	Resolution ResolutionConfig `yaml:"resolution"`
 }
@@ -133,25 +135,26 @@ type SlurpConfig struct {
 
 // WHOOSHAPIConfig defines WHOOSH API integration settings
 type WHOOSHAPIConfig struct {
-	URL     string `yaml:"url"`
-	BaseURL string `yaml:"base_url"`
-	Token   string `yaml:"token"`
-	Enabled bool   `yaml:"enabled"`
+	URL     string `yaml:"url"`
+	BaseURL string `yaml:"base_url"`
+	Token   string `yaml:"token"`
+	Enabled bool   `yaml:"enabled"`
 }
 
 // LoadFromEnvironment loads configuration from environment variables
 func LoadFromEnvironment() (*Config, error) {
 	cfg := &Config{
 		Agent: AgentConfig{
-			ID:             getEnvOrDefault("CHORUS_AGENT_ID", ""),
-			Specialization: getEnvOrDefault("CHORUS_SPECIALIZATION", "general_developer"),
-			MaxTasks:       getEnvIntOrDefault("CHORUS_MAX_TASKS", 3),
-			Capabilities:   getEnvArrayOrDefault("CHORUS_CAPABILITIES", []string{"general_development", "task_coordination"}),
-			Models:         getEnvArrayOrDefault("CHORUS_MODELS", []string{"meta/llama-3.1-8b-instruct"}),
-			Role:           getEnvOrDefault("CHORUS_ROLE", ""),
-			Expertise:      getEnvArrayOrDefault("CHORUS_EXPERTISE", []string{}),
-			ReportsTo:      getEnvOrDefault("CHORUS_REPORTS_TO", ""),
-			Deliverables:   getEnvArrayOrDefault("CHORUS_DELIVERABLES", []string{}),
+			ID:                    getEnvOrDefault("CHORUS_AGENT_ID", ""),
+			Specialization:        getEnvOrDefault("CHORUS_SPECIALIZATION", "general_developer"),
+			MaxTasks:              getEnvIntOrDefault("CHORUS_MAX_TASKS", 3),
+			Capabilities:          getEnvArrayOrDefault("CHORUS_CAPABILITIES", []string{"general_development", "task_coordination"}),
+			Models:                getEnvArrayOrDefault("CHORUS_MODELS", []string{"meta/llama-3.1-8b-instruct"}),
+			Role:                  getEnvOrDefault("CHORUS_ROLE", ""),
+			Project:               getEnvOrDefault("CHORUS_PROJECT", "chorus"),
+			Expertise:             getEnvArrayOrDefault("CHORUS_EXPERTISE", []string{}),
+			ReportsTo:             getEnvOrDefault("CHORUS_REPORTS_TO", ""),
+			Deliverables:          getEnvArrayOrDefault("CHORUS_DELIVERABLES", []string{}),
 			ModelSelectionWebhook: getEnvOrDefault("CHORUS_MODEL_SELECTION_WEBHOOK", ""),
 			DefaultReasoningModel: getEnvOrDefault("CHORUS_DEFAULT_REASONING_MODEL", "meta/llama-3.1-8b-instruct"),
 		},
@@ -177,7 +180,7 @@ func LoadFromEnvironment() (*Config, error) {
 		},
 		ResetData: ResetDataConfig{
 			BaseURL: getEnvOrDefault("RESETDATA_BASE_URL", "https://models.au-syd.resetdata.ai/v1"),
-			APIKey:  os.Getenv("RESETDATA_API_KEY"),
+			APIKey:  getEnvOrFileContent("RESETDATA_API_KEY", "RESETDATA_API_KEY_FILE"),
 			Model:   getEnvOrDefault("RESETDATA_MODEL", "meta/llama-3.1-8b-instruct"),
 			Timeout: getEnvDurationOrDefault("RESETDATA_TIMEOUT", 30*time.Second),
 		},
@@ -190,6 +193,7 @@ func LoadFromEnvironment() (*Config, error) {
 		DHT: DHTConfig{
 			Enabled:        getEnvBoolOrDefault("CHORUS_DHT_ENABLED", true),
+			BootstrapPeers: getEnvArrayOrDefault("CHORUS_BOOTSTRAP_PEERS", []string{}),
 			MDNSEnabled:    getEnvBoolOrDefault("CHORUS_MDNS_ENABLED", true),
 		},
 	},
 	UCXL: UCXLConfig{
@@ -214,10 +218,10 @@ func LoadFromEnvironment() (*Config, error) {
 	AuditLogging: getEnvBoolOrDefault("CHORUS_AUDIT_LOGGING", true),
 	AuditPath:    getEnvOrDefault("CHORUS_AUDIT_PATH", "/tmp/chorus-audit.log"),
 	ElectionConfig: ElectionConfig{
-		DiscoveryTimeout: getEnvDurationOrDefault("CHORUS_DISCOVERY_TIMEOUT", 10*time.Second),
-		HeartbeatTimeout: getEnvDurationOrDefault("CHORUS_HEARTBEAT_TIMEOUT", 30*time.Second),
-		ElectionTimeout:  getEnvDurationOrDefault("CHORUS_ELECTION_TIMEOUT", 60*time.Second),
-		DiscoveryBackoff: getEnvDurationOrDefault("CHORUS_DISCOVERY_BACKOFF", 5*time.Second),
+		DiscoveryTimeout: getEnvDurationOrDefault("CHORUS_DISCOVERY_TIMEOUT", 15*time.Second),
+		HeartbeatTimeout: getEnvDurationOrDefault("CHORUS_HEARTBEAT_TIMEOUT", 30*time.Second),
+		ElectionTimeout:  getEnvDurationOrDefault("CHORUS_ELECTION_TIMEOUT", 60*time.Second),
+		DiscoveryBackoff: getEnvDurationOrDefault("CHORUS_DISCOVERY_BACKOFF", 5*time.Second),
 		LeadershipScoring: &LeadershipScoring{
 			UptimeWeight:     0.4,
 			CapabilityWeight: 0.3,
@@ -247,7 +251,7 @@ func (c *Config) Validate() error {
	if c.License.LicenseID == "" {
		return fmt.Errorf("CHORUS_LICENSE_ID is required")
	}

	if c.Agent.ID == "" {
		// Auto-generate agent ID if not provided
		hostname, _ := os.Hostname()
@@ -258,7 +262,7 @@ func (c *Config) Validate() error {
		c.Agent.ID = fmt.Sprintf("chorus-%s", hostname)
	}

	return nil
}
@@ -329,14 +333,14 @@ func getEnvOrFileContent(envKey, fileEnvKey string) string {
	if value := os.Getenv(envKey); value != "" {
		return value
	}

	// Then try reading from file path specified in fileEnvKey
	if filePath := os.Getenv(fileEnvKey); filePath != "" {
		if content, err := ioutil.ReadFile(filePath); err == nil {
			return strings.TrimSpace(string(content))
		}
	}

	return ""
}
@@ -360,4 +364,18 @@ func LoadConfig(configPath string) (*Config, error) {
func SaveConfig(cfg *Config, configPath string) error {
	// For containers, configuration is environment-based, so this is a no-op
	return nil
}

// LoadRuntimeConfig loads configuration with runtime assignment support
func LoadRuntimeConfig() (*RuntimeConfig, error) {
	// Load base configuration from environment
	baseConfig, err := LoadFromEnvironment()
	if err != nil {
		return nil, fmt.Errorf("failed to load base configuration: %w", err)
	}

	// Create runtime configuration manager
	runtimeConfig := NewRuntimeConfig(baseConfig)

	return runtimeConfig, nil
}
@@ -41,10 +41,16 @@ type HybridUCXLConfig struct {
}

type DiscoveryConfig struct {
	MDNSEnabled      bool          `env:"CHORUS_MDNS_ENABLED" default:"true" json:"mdns_enabled" yaml:"mdns_enabled"`
	DHTDiscovery     bool          `env:"CHORUS_DHT_DISCOVERY" default:"false" json:"dht_discovery" yaml:"dht_discovery"`
	AnnounceInterval time.Duration `env:"CHORUS_ANNOUNCE_INTERVAL" default:"30s" json:"announce_interval" yaml:"announce_interval"`
	ServiceName      string        `env:"CHORUS_SERVICE_NAME" default:"CHORUS" json:"service_name" yaml:"service_name"`

	// Rate limiting for scaling (as per WHOOSH issue #7)
	DialsPerSecond     int `env:"CHORUS_DIALS_PER_SEC" default:"5" json:"dials_per_second" yaml:"dials_per_second"`
	MaxConcurrentDHT   int `env:"CHORUS_MAX_CONCURRENT_DHT" default:"16" json:"max_concurrent_dht" yaml:"max_concurrent_dht"`
	MaxConcurrentDials int `env:"CHORUS_MAX_CONCURRENT_DIALS" default:"10" json:"max_concurrent_dials" yaml:"max_concurrent_dials"`
	JoinStaggerMS      int `env:"CHORUS_JOIN_STAGGER_MS" default:"0" json:"join_stagger_ms" yaml:"join_stagger_ms"`
}

type MonitoringConfig struct {
@@ -79,10 +85,16 @@ func LoadHybridConfig() (*HybridConfig, error) {

	// Load Discovery configuration
	config.Discovery = DiscoveryConfig{
		MDNSEnabled:      getEnvBool("CHORUS_MDNS_ENABLED", true),
		DHTDiscovery:     getEnvBool("CHORUS_DHT_DISCOVERY", false),
		AnnounceInterval: getEnvDuration("CHORUS_ANNOUNCE_INTERVAL", 30*time.Second),
		ServiceName:      getEnvString("CHORUS_SERVICE_NAME", "CHORUS"),

		// Rate limiting for scaling (as per WHOOSH issue #7)
		DialsPerSecond:     getEnvInt("CHORUS_DIALS_PER_SEC", 5),
		MaxConcurrentDHT:   getEnvInt("CHORUS_MAX_CONCURRENT_DHT", 16),
		MaxConcurrentDials: getEnvInt("CHORUS_MAX_CONCURRENT_DIALS", 10),
		JoinStaggerMS:      getEnvInt("CHORUS_JOIN_STAGGER_MS", 0),
	}

	// Load Monitoring configuration
@@ -12,27 +12,27 @@ const (

// SecurityConfig defines security-related configuration
type SecurityConfig struct {
	KeyRotationDays int            `yaml:"key_rotation_days"`
	AuditLogging    bool           `yaml:"audit_logging"`
	AuditPath       string         `yaml:"audit_path"`
	ElectionConfig  ElectionConfig `yaml:"election"`
}

// ElectionConfig defines election timing and behavior settings
type ElectionConfig struct {
	DiscoveryTimeout  time.Duration      `yaml:"discovery_timeout"`
	HeartbeatTimeout  time.Duration      `yaml:"heartbeat_timeout"`
	ElectionTimeout   time.Duration      `yaml:"election_timeout"`
	DiscoveryBackoff  time.Duration      `yaml:"discovery_backoff"`
	LeadershipScoring *LeadershipScoring `yaml:"leadership_scoring,omitempty"`
}

// LeadershipScoring defines weights for election scoring
type LeadershipScoring struct {
	UptimeWeight     float64 `yaml:"uptime_weight"`
	CapabilityWeight float64 `yaml:"capability_weight"`
	ExperienceWeight float64 `yaml:"experience_weight"`
	LoadWeight       float64 `yaml:"load_weight"`
}

// AgeKeyPair represents an Age encryption key pair
@@ -43,14 +43,14 @@ type AgeKeyPair struct {

// RoleDefinition represents a role configuration
type RoleDefinition struct {
	Name           string      `yaml:"name"`
	Description    string      `yaml:"description"`
	Capabilities   []string    `yaml:"capabilities"`
	AccessLevel    string      `yaml:"access_level"`
	AuthorityLevel string      `yaml:"authority_level"`
	Keys           *AgeKeyPair `yaml:"keys,omitempty"`
	AgeKeys        *AgeKeyPair `yaml:"age_keys,omitempty"`   // Legacy field name
	CanDecrypt     []string    `yaml:"can_decrypt,omitempty"` // Roles this role can decrypt
}

// GetPredefinedRoles returns the predefined roles for the system
@@ -65,7 +65,7 @@ func GetPredefinedRoles() map[string]*RoleDefinition {
			CanDecrypt: []string{"project_manager", "backend_developer", "frontend_developer", "devops_engineer", "security_engineer"},
		},
		"backend_developer": {
			Name:         "backend_developer",
			Description:  "Backend development and API work",
			Capabilities: []string{"backend", "api", "database"},
			AccessLevel:  "medium",
@@ -90,12 +90,52 @@ func GetPredefinedRoles() map[string]*RoleDefinition {
		},
		"security_engineer": {
			Name:           "security_engineer",
			Description:    "Security oversight and hardening",
			Capabilities:   []string{"security", "audit", "compliance"},
			AccessLevel:    "high",
			AuthorityLevel: AuthorityAdmin,
			CanDecrypt:     []string{"security_engineer", "project_manager", "backend_developer", "frontend_developer", "devops_engineer"},
		},
		"security_expert": {
			Name:           "security_expert",
			Description:    "Advanced security analysis and policy work",
			Capabilities:   []string{"security", "policy", "response"},
			AccessLevel:    "high",
			AuthorityLevel: AuthorityAdmin,
			CanDecrypt:     []string{"security_expert", "security_engineer", "project_manager"},
		},
		"senior_software_architect": {
			Name:           "senior_software_architect",
			Description:    "Architecture governance and system design",
			Capabilities:   []string{"architecture", "design", "coordination"},
			AccessLevel:    "high",
			AuthorityLevel: AuthorityAdmin,
			CanDecrypt:     []string{"senior_software_architect", "project_manager", "backend_developer", "frontend_developer"},
		},
		"qa_engineer": {
			Name:           "qa_engineer",
			Description:    "Quality assurance and testing",
			Capabilities:   []string{"testing", "validation"},
			AccessLevel:    "medium",
			AuthorityLevel: AuthorityFull,
			CanDecrypt:     []string{"qa_engineer", "backend_developer", "frontend_developer"},
		},
		"readonly_user": {
			Name:           "readonly_user",
			Description:    "Read-only observer with audit access",
			Capabilities:   []string{"observation"},
			AccessLevel:    "low",
			AuthorityLevel: AuthorityReadOnly,
			CanDecrypt:     []string{"readonly_user"},
		},
		"suggestion_only_role": {
			Name:           "suggestion_only_role",
			Description:    "Can propose suggestions but not execute",
			Capabilities:   []string{"recommendation"},
			AccessLevel:    "low",
			AuthorityLevel: AuthoritySuggestion,
			CanDecrypt:     []string{"suggestion_only_role"},
		},
	}
}

@@ -106,16 +146,16 @@ func (c *Config) CanDecryptRole(targetRole string) (bool, error) {
	if !exists {
		return false, nil
	}

	targetRoleDef, exists := roles[targetRole]
	if !exists {
		return false, nil
	}

	// Simple access level check
	currentLevel := getAccessLevelValue(currentRole.AccessLevel)
	targetLevel := getAccessLevelValue(targetRoleDef.AccessLevel)

	return currentLevel >= targetLevel, nil
}

@@ -130,4 +170,4 @@ func getAccessLevelValue(level string) int {
	default:
		return 0
	}
}
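`CanDecryptRole` reduces to an ordinal comparison of access levels. A minimal standalone sketch; the numeric mapping below is assumed (the `high`/`medium`/`low` cases of `getAccessLevelValue` fall outside this hunk, only the `default: return 0` branch is visible):

```go
package main

import "fmt"

// getAccessLevelValue: higher numbers mean broader access
// (assumed mapping, matching the visible default of 0 for unknown levels).
func getAccessLevelValue(level string) int {
	switch level {
	case "high":
		return 3
	case "medium":
		return 2
	case "low":
		return 1
	default:
		return 0
	}
}

// canDecrypt mirrors the CanDecryptRole comparison: an agent may decrypt
// a target role's data when its own level is at least the target's.
func canDecrypt(currentLevel, targetLevel string) bool {
	return getAccessLevelValue(currentLevel) >= getAccessLevelValue(targetLevel)
}

func main() {
	fmt.Println(canDecrypt("high", "medium")) // high covers medium
	fmt.Println(canDecrypt("low", "high"))    // low cannot read high
}
```

Note this check ignores the per-role `CanDecrypt` lists defined above; the level comparison alone is coarser than the list-based policy.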
306	pkg/crypto/key_derivation.go	Normal file
@@ -0,0 +1,306 @@
package crypto

import (
	"crypto/sha256"
	"fmt"
	"io"

	"filippo.io/age"
	"filippo.io/age/armor"
	"golang.org/x/crypto/hkdf"
)

// KeyDerivationManager handles cluster-scoped key derivation for DHT encryption
type KeyDerivationManager struct {
	clusterRootKey []byte
	clusterID      string
}

// DerivedKeySet contains keys derived for a specific role/scope
type DerivedKeySet struct {
	RoleKey      []byte               // Role-specific key
	NodeKey      []byte               // Node-specific key for this instance
	AGEIdentity  *age.X25519Identity  // AGE identity for encryption/decryption
	AGERecipient *age.X25519Recipient // AGE recipient for encryption
}

// NewKeyDerivationManager creates a new key derivation manager
func NewKeyDerivationManager(clusterRootKey []byte, clusterID string) *KeyDerivationManager {
	return &KeyDerivationManager{
		clusterRootKey: clusterRootKey,
		clusterID:      clusterID,
	}
}

// NewKeyDerivationManagerFromSeed creates a manager from a seed string
func NewKeyDerivationManagerFromSeed(seed, clusterID string) *KeyDerivationManager {
	// Use HKDF to derive a consistent root key from the seed
	kdf := hkdf.New(sha256.New, []byte(seed), []byte(clusterID), []byte("CHORUS-cluster-root"))

	rootKey := make([]byte, 32)
	if _, err := io.ReadFull(kdf, rootKey); err != nil {
		panic(fmt.Errorf("failed to derive cluster root key: %w", err))
	}

	return &KeyDerivationManager{
		clusterRootKey: rootKey,
		clusterID:      clusterID,
	}
}
// DeriveRoleKeys derives encryption keys for a specific role and agent
func (kdm *KeyDerivationManager) DeriveRoleKeys(role, agentID string) (*DerivedKeySet, error) {
	if kdm.clusterRootKey == nil {
		return nil, fmt.Errorf("cluster root key not initialized")
	}

	// Derive role-specific key
	roleKey, err := kdm.deriveKey(fmt.Sprintf("role-%s", role), 32)
	if err != nil {
		return nil, fmt.Errorf("failed to derive role key: %w", err)
	}

	// Derive node-specific key from role key and agent ID
	nodeKey, err := kdm.deriveKeyFromParent(roleKey, fmt.Sprintf("node-%s", agentID), 32)
	if err != nil {
		return nil, fmt.Errorf("failed to derive node key: %w", err)
	}

	// Generate AGE identity from node key
	ageIdentity, err := kdm.generateAGEIdentityFromKey(nodeKey)
	if err != nil {
		return nil, fmt.Errorf("failed to generate AGE identity: %w", err)
	}

	ageRecipient := ageIdentity.Recipient()

	return &DerivedKeySet{
		RoleKey:      roleKey,
		NodeKey:      nodeKey,
		AGEIdentity:  ageIdentity,
		AGERecipient: ageRecipient,
	}, nil
}

// DeriveClusterWideKeys derives keys that are shared across the entire cluster for a role
func (kdm *KeyDerivationManager) DeriveClusterWideKeys(role string) (*DerivedKeySet, error) {
	if kdm.clusterRootKey == nil {
		return nil, fmt.Errorf("cluster root key not initialized")
	}

	// Derive role-specific key
	roleKey, err := kdm.deriveKey(fmt.Sprintf("role-%s", role), 32)
	if err != nil {
		return nil, fmt.Errorf("failed to derive role key: %w", err)
	}

	// For cluster-wide keys, use a deterministic "cluster" identifier
	clusterNodeKey, err := kdm.deriveKeyFromParent(roleKey, "cluster-shared", 32)
	if err != nil {
		return nil, fmt.Errorf("failed to derive cluster node key: %w", err)
	}

	// Generate AGE identity from cluster node key
	ageIdentity, err := kdm.generateAGEIdentityFromKey(clusterNodeKey)
	if err != nil {
		return nil, fmt.Errorf("failed to generate AGE identity: %w", err)
	}

	ageRecipient := ageIdentity.Recipient()

	return &DerivedKeySet{
		RoleKey:      roleKey,
		NodeKey:      clusterNodeKey,
		AGEIdentity:  ageIdentity,
		AGERecipient: ageRecipient,
	}, nil
}
// deriveKey derives a key from the cluster root key using HKDF
func (kdm *KeyDerivationManager) deriveKey(info string, length int) ([]byte, error) {
	kdf := hkdf.New(sha256.New, kdm.clusterRootKey, []byte(kdm.clusterID), []byte(info))

	key := make([]byte, length)
	if _, err := io.ReadFull(kdf, key); err != nil {
		return nil, fmt.Errorf("HKDF key derivation failed: %w", err)
	}

	return key, nil
}

// deriveKeyFromParent derives a key from a parent key using HKDF
func (kdm *KeyDerivationManager) deriveKeyFromParent(parentKey []byte, info string, length int) ([]byte, error) {
	kdf := hkdf.New(sha256.New, parentKey, []byte(kdm.clusterID), []byte(info))

	key := make([]byte, length)
	if _, err := io.ReadFull(kdf, key); err != nil {
		return nil, fmt.Errorf("HKDF key derivation failed: %w", err)
	}

	return key, nil
}
// generateAGEIdentityFromKey generates an AGE identity from a key.
// NOTE: age.GenerateX25519Identity() is random, so the returned identity
// is NOT deterministic with respect to the input key: two nodes deriving
// the same key material will still end up with different identities.
func (kdm *KeyDerivationManager) generateAGEIdentityFromKey(key []byte) (*age.X25519Identity, error) {
	if len(key) < 32 {
		return nil, fmt.Errorf("key must be at least 32 bytes")
	}

	// Use the first 32 bytes as the private key seed
	// (currently unused by the random generation below).
	var privKey [32]byte
	copy(privKey[:], key[:32])

	// TODO: Implement deterministic key derivation when the age API allows it.
	identity, err := age.GenerateX25519Identity()
	if err != nil {
		return nil, fmt.Errorf("failed to create AGE identity: %w", err)
	}

	return identity, nil
}
// EncryptForRole encrypts data for a specific role (all nodes in that role can decrypt)
func (kdm *KeyDerivationManager) EncryptForRole(data []byte, role string) ([]byte, error) {
	// Get cluster-wide keys for the role
	keySet, err := kdm.DeriveClusterWideKeys(role)
	if err != nil {
		return nil, fmt.Errorf("failed to derive cluster keys: %w", err)
	}

	// Encrypt using AGE
	var encrypted []byte
	buf := &writeBuffer{data: &encrypted}
	armorWriter := armor.NewWriter(buf)

	ageWriter, err := age.Encrypt(armorWriter, keySet.AGERecipient)
	if err != nil {
		return nil, fmt.Errorf("failed to create age writer: %w", err)
	}

	if _, err := ageWriter.Write(data); err != nil {
		return nil, fmt.Errorf("failed to write encrypted data: %w", err)
	}

	if err := ageWriter.Close(); err != nil {
		return nil, fmt.Errorf("failed to close age writer: %w", err)
	}

	if err := armorWriter.Close(); err != nil {
		return nil, fmt.Errorf("failed to close armor writer: %w", err)
	}

	return encrypted, nil
}

// DecryptForRole decrypts data encrypted for a specific role
func (kdm *KeyDerivationManager) DecryptForRole(encryptedData []byte, role, agentID string) ([]byte, error) {
	// Try cluster-wide keys first
	clusterKeys, err := kdm.DeriveClusterWideKeys(role)
	if err != nil {
		return nil, fmt.Errorf("failed to derive cluster keys: %w", err)
	}

	if decrypted, err := kdm.decryptWithIdentity(encryptedData, clusterKeys.AGEIdentity); err == nil {
		return decrypted, nil
	}

	// If cluster-wide decryption fails, try node-specific keys
	nodeKeys, err := kdm.DeriveRoleKeys(role, agentID)
	if err != nil {
		return nil, fmt.Errorf("failed to derive node keys: %w", err)
	}

	return kdm.decryptWithIdentity(encryptedData, nodeKeys.AGEIdentity)
}
// decryptWithIdentity decrypts data using an AGE identity
func (kdm *KeyDerivationManager) decryptWithIdentity(encryptedData []byte, identity *age.X25519Identity) ([]byte, error) {
	armorReader := armor.NewReader(newReadBuffer(encryptedData))

	ageReader, err := age.Decrypt(armorReader, identity)
	if err != nil {
		return nil, fmt.Errorf("failed to decrypt: %w", err)
	}

	decrypted, err := io.ReadAll(ageReader)
	if err != nil {
		return nil, fmt.Errorf("failed to read decrypted data: %w", err)
	}

	return decrypted, nil
}

// GetRoleRecipients returns AGE recipients for all nodes in a role (for multi-recipient encryption)
func (kdm *KeyDerivationManager) GetRoleRecipients(role string, agentIDs []string) ([]*age.X25519Recipient, error) {
	var recipients []*age.X25519Recipient

	// Add cluster-wide recipient
	clusterKeys, err := kdm.DeriveClusterWideKeys(role)
	if err != nil {
		return nil, fmt.Errorf("failed to derive cluster keys: %w", err)
	}
	recipients = append(recipients, clusterKeys.AGERecipient)

	// Add node-specific recipients
	for _, agentID := range agentIDs {
		nodeKeys, err := kdm.DeriveRoleKeys(role, agentID)
		if err != nil {
			continue // Skip this agent on error
		}
		recipients = append(recipients, nodeKeys.AGERecipient)
	}

	return recipients, nil
}
// GetKeySetStats returns statistics about derived key sets
func (kdm *KeyDerivationManager) GetKeySetStats(role, agentID string) map[string]interface{} {
	stats := map[string]interface{}{
		"cluster_id": kdm.clusterID,
		"role":       role,
		"agent_id":   agentID,
	}

	// Try to derive keys and add fingerprint info
	if keySet, err := kdm.DeriveRoleKeys(role, agentID); err == nil {
		stats["node_key_length"] = len(keySet.NodeKey)
		stats["role_key_length"] = len(keySet.RoleKey)
		stats["age_recipient"] = keySet.AGERecipient.String()
	}

	return stats
}

// Helper types for AGE encryption/decryption

type writeBuffer struct {
	data *[]byte
}

func (w *writeBuffer) Write(p []byte) (n int, err error) {
	*w.data = append(*w.data, p...)
	return len(p), nil
}

type readBuffer struct {
	data []byte
	pos  int
}

func newReadBuffer(data []byte) *readBuffer {
	return &readBuffer{data: data, pos: 0}
}

func (r *readBuffer) Read(p []byte) (n int, err error) {
	if r.pos >= len(r.data) {
		return 0, io.EOF
	}

	n = copy(p, r.data[r.pos:])
	r.pos += n
	return n, nil
}
213	pkg/dht/dht.go
@@ -6,33 +6,34 @@ import (
	"crypto/sha256"
	"sync"
	"time"

	"github.com/ipfs/go-cid"
	dht "github.com/libp2p/go-libp2p-kad-dht"
	"github.com/libp2p/go-libp2p/core/host"
	"github.com/libp2p/go-libp2p/core/peer"
	"github.com/libp2p/go-libp2p/core/protocol"
	"github.com/libp2p/go-libp2p/core/routing"
	"github.com/multiformats/go-multiaddr"
	"github.com/multiformats/go-multihash"
)

// LibP2PDHT provides distributed hash table functionality for CHORUS peer discovery
type LibP2PDHT struct {
	host      host.Host
	kdht      *dht.IpfsDHT
	ctx       context.Context
	cancel    context.CancelFunc
	config    *Config
	startTime time.Time

	// Bootstrap state
	bootstrapped   bool
	bootstrapMutex sync.RWMutex

	// Peer management
	knownPeers map[peer.ID]*PeerInfo
	peersMutex sync.RWMutex

	// Replication management
	replicationManager *ReplicationManager
}
@@ -41,30 +42,32 @@ type LibP2PDHT struct {
type Config struct {
	// Bootstrap nodes for initial DHT discovery
	BootstrapPeers []multiaddr.Multiaddr

	// Protocol prefix for CHORUS DHT
	ProtocolPrefix string

	// Bootstrap timeout
	BootstrapTimeout time.Duration

	// Peer discovery interval
	DiscoveryInterval time.Duration

	// DHT mode (client, server, auto)
	Mode dht.ModeOpt

	// Enable automatic bootstrap
	AutoBootstrap bool
}

const defaultProviderResultLimit = 20

// PeerInfo holds information about discovered peers
type PeerInfo struct {
	ID           peer.ID
	Addresses    []multiaddr.Multiaddr
	Agent        string
	Role         string
	LastSeen     time.Time
	Capabilities []string
}
@@ -74,23 +77,28 @@ func DefaultConfig() *Config {
		ProtocolPrefix:    "/CHORUS",
		BootstrapTimeout:  30 * time.Second,
		DiscoveryInterval: 60 * time.Second,
		Mode:              dht.ModeAuto,
		AutoBootstrap:     true,
	}
}

// NewDHT is a backward compatible helper that delegates to NewLibP2PDHT.
func NewDHT(ctx context.Context, host host.Host, opts ...Option) (*LibP2PDHT, error) {
	return NewLibP2PDHT(ctx, host, opts...)
}

// NewLibP2PDHT creates a new LibP2PDHT instance
func NewLibP2PDHT(ctx context.Context, host host.Host, opts ...Option) (*LibP2PDHT, error) {
	config := DefaultConfig()
	for _, opt := range opts {
		opt(config)
	}

	// Create context with cancellation
	dhtCtx, cancel := context.WithCancel(ctx)

	// Create Kademlia DHT
	kdht, err := dht.New(dhtCtx, host,
		dht.Mode(config.Mode),
		dht.ProtocolPrefix(protocol.ID(config.ProtocolPrefix)),
	)
@@ -98,22 +106,23 @@ func NewLibP2PDHT(ctx context.Context, host host.Host, opts ...Option) (*LibP2PD
		cancel()
		return nil, fmt.Errorf("failed to create DHT: %w", err)
	}

	d := &LibP2PDHT{
		host:       host,
		kdht:       kdht,
		ctx:        dhtCtx,
		cancel:     cancel,
		config:     config,
		startTime:  time.Now(),
		knownPeers: make(map[peer.ID]*PeerInfo),
	}

	// Initialize replication manager
	d.replicationManager = NewReplicationManager(dhtCtx, kdht, DefaultReplicationConfig())

	// Start background processes
	go d.startBackgroundTasks()

	return d, nil
}
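`NewLibP2PDHT` follows the functional-options pattern: defaults first, then each `Option` mutates the config (`WithAutoBootstrap` and friends). A self-contained sketch of the pattern with hypothetical names, independent of libp2p:

```go
package main

import "fmt"

type config struct {
	autoBootstrap bool
	protocol      string
}

// option mutates a config; each With* constructor captures its argument.
type option func(*config)

func withAutoBootstrap(auto bool) option {
	return func(c *config) { c.autoBootstrap = auto }
}

func withProtocol(p string) option {
	return func(c *config) { c.protocol = p }
}

// newConfig applies defaults, then each caller-supplied option in order.
func newConfig(opts ...option) *config {
	c := &config{autoBootstrap: true, protocol: "/CHORUS"} // defaults
	for _, opt := range opts {
		opt(c)
	}
	return c
}

func main() {
	c := newConfig(withAutoBootstrap(false))
	fmt.Println(c.autoBootstrap, c.protocol)
}
```

The pattern keeps the constructor signature stable while new knobs are added, which is why later options like the rate-limiting fields can land without breaking callers.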
@@ -178,25 +187,25 @@ func WithAutoBootstrap(auto bool) Option {
func (d *LibP2PDHT) Bootstrap() error {
	d.bootstrapMutex.Lock()
	defer d.bootstrapMutex.Unlock()

	if d.bootstrapped {
		return nil
	}

	// Connect to bootstrap peers
	if len(d.config.BootstrapPeers) == 0 {
		// Use default IPFS bootstrap peers if none configured
		d.config.BootstrapPeers = dht.DefaultBootstrapPeers
	}

	// Bootstrap the DHT
	bootstrapCtx, cancel := context.WithTimeout(d.ctx, d.config.BootstrapTimeout)
	defer cancel()

	if err := d.kdht.Bootstrap(bootstrapCtx); err != nil {
		return fmt.Errorf("DHT bootstrap failed: %w", err)
	}

	// Connect to bootstrap peers
	var connected int
	for _, peerAddr := range d.config.BootstrapPeers {
@@ -204,7 +213,7 @@ func (d *LibP2PDHT) Bootstrap() error {
		if err != nil {
			continue
		}

		connectCtx, cancel := context.WithTimeout(d.ctx, 10*time.Second)
		if err := d.host.Connect(connectCtx, *addrInfo); err != nil {
			cancel()
@@ -213,11 +222,11 @@ func (d *LibP2PDHT) Bootstrap() error {
		cancel()
		connected++
	}

	if connected == 0 {
		return fmt.Errorf("failed to connect to any bootstrap peers")
	}

	d.bootstrapped = true
	return nil
}
@@ -233,13 +242,13 @@ func (d *LibP2PDHT) IsBootstrapped() bool {
func (d *LibP2PDHT) keyToCID(key string) (cid.Cid, error) {
	// Hash the key
	hash := sha256.Sum256([]byte(key))

	// Create multihash
	mh, err := multihash.EncodeName(hash[:], "sha2-256")
	if err != nil {
		return cid.Undef, err
	}

	// Create CID
	return cid.NewCidV1(cid.Raw, mh), nil
}

@@ -249,13 +258,13 @@ func (d *LibP2PDHT) Provide(ctx context.Context, key string) error {
	if !d.IsBootstrapped() {
		return fmt.Errorf("DHT not bootstrapped")
	}

	// Convert key to CID
	keyCID, err := d.keyToCID(key)
	if err != nil {
		return fmt.Errorf("failed to create CID from key: %w", err)
	}

	return d.kdht.Provide(ctx, keyCID, true)
}
@@ -264,31 +273,32 @@ func (d *LibP2PDHT) FindProviders(ctx context.Context, key string, limit int) ([
	if !d.IsBootstrapped() {
		return nil, fmt.Errorf("DHT not bootstrapped")
	}

	// Convert key to CID
	keyCID, err := d.keyToCID(key)
	if err != nil {
		return nil, fmt.Errorf("failed to create CID from key: %w", err)
	}

	maxProviders := limit
	if maxProviders <= 0 {
		maxProviders = defaultProviderResultLimit
	}

	providerCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	providersChan := d.kdht.FindProvidersAsync(providerCtx, keyCID, maxProviders)
	providers := make([]peer.AddrInfo, 0, maxProviders)

	for providerInfo := range providersChan {
		providers = append(providers, providerInfo)
		if limit > 0 && len(providers) >= limit {
			cancel()
			break
		}
	}

	return providers, nil
}
@@ -297,7 +307,7 @@ func (d *LibP2PDHT) PutValue(ctx context.Context, key string, value []byte) erro
	if !d.IsBootstrapped() {
		return fmt.Errorf("DHT not bootstrapped")
	}

	return d.kdht.PutValue(ctx, key, value)
}

@@ -306,7 +316,7 @@ func (d *LibP2PDHT) GetValue(ctx context.Context, key string) ([]byte, error) {
	if !d.IsBootstrapped() {
		return nil, fmt.Errorf("DHT not bootstrapped")
	}

	return d.kdht.GetValue(ctx, key)
}

@@ -315,7 +325,7 @@ func (d *LibP2PDHT) FindPeer(ctx context.Context, peerID peer.ID) (peer.AddrInfo
	if !d.IsBootstrapped() {
		return peer.AddrInfo{}, fmt.Errorf("DHT not bootstrapped")
	}

	return d.kdht.FindPeer(ctx, peerID)
}
@@ -329,14 +339,30 @@ func (d *LibP2PDHT) GetConnectedPeers() []peer.ID {
	return d.kdht.Host().Network().Peers()
}

// GetStats reports basic runtime statistics for the DHT
func (d *LibP2PDHT) GetStats() DHTStats {
	stats := DHTStats{
		TotalPeers: len(d.GetConnectedPeers()),
		Uptime:     time.Since(d.startTime),
	}

	if d.replicationManager != nil {
		if metrics := d.replicationManager.GetMetrics(); metrics != nil {
			stats.TotalKeys = int(metrics.TotalKeys)
		}
	}

	return stats
}

// RegisterPeer registers a peer with capability information
func (d *LibP2PDHT) RegisterPeer(peerID peer.ID, agent, role string, capabilities []string) {
	d.peersMutex.Lock()
	defer d.peersMutex.Unlock()

	// Get peer addresses from host
	peerInfo := d.host.Peerstore().PeerInfo(peerID)

	d.knownPeers[peerID] = &PeerInfo{
		ID:        peerID,
		Addresses: peerInfo.Addrs,
@@ -351,12 +377,12 @@ func (d *LibP2PDHT) RegisterPeer(peerID peer.ID, agent, role string, capabilitie
func (d *LibP2PDHT) GetKnownPeers() map[peer.ID]*PeerInfo {
	d.peersMutex.RLock()
	defer d.peersMutex.RUnlock()

	result := make(map[peer.ID]*PeerInfo)
	for id, info := range d.knownPeers {
		result[id] = info
	}

	return result
}
@@ -371,7 +397,7 @@ func (d *LibP2PDHT) FindPeersByRole(ctx context.Context, role string) ([]*PeerIn
		}
	}
	d.peersMutex.RUnlock()

	// Also search DHT for role-based keys
	roleKey := fmt.Sprintf("CHORUS:role:%s", role)
	providers, err := d.FindProviders(ctx, roleKey, 10)
@@ -379,11 +405,11 @@ func (d *LibP2PDHT) FindPeersByRole(ctx context.Context, role string) ([]*PeerIn
		// Return local peers even if DHT search fails
		return localPeers, nil
	}

	// Convert providers to PeerInfo
	var result []*PeerInfo
	result = append(result, localPeers...)

	for _, provider := range providers {
		// Skip if we already have this peer
		found := false
@@ -402,7 +428,7 @@ func (d *LibP2PDHT) FindPeersByRole(ctx context.Context, role string) ([]*PeerIn
			})
		}
	}

	return result, nil
}
@@ -424,10 +450,10 @@ func (d *LibP2PDHT) startBackgroundTasks() {
	if d.config.AutoBootstrap {
		go d.autoBootstrap()
	}

	// Start periodic peer discovery
	go d.periodicDiscovery()

	// Start peer cleanup
	go d.peerCleanup()
}
@@ -436,7 +462,7 @@ func (d *LibP2PDHT) startBackgroundTasks() {
func (d *LibP2PDHT) autoBootstrap() {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-d.ctx.Done():
@@ -456,7 +482,7 @@ func (d *LibP2PDHT) autoBootstrap() {
func (d *LibP2PDHT) periodicDiscovery() {
	ticker := time.NewTicker(d.config.DiscoveryInterval)
	defer ticker.Stop()

	for {
		select {
		case <-d.ctx.Done():
@@ -473,13 +499,13 @@ func (d *LibP2PDHT) periodicDiscovery() {
func (d *LibP2PDHT) performDiscovery() {
	ctx, cancel := context.WithTimeout(d.ctx, 30*time.Second)
	defer cancel()

	// Look for general CHORUS peers
	providers, err := d.FindProviders(ctx, "CHORUS:peer", 10)
	if err != nil {
		return
	}

	// Update known peers
	d.peersMutex.Lock()
	for _, provider := range providers {
@@ -498,7 +524,7 @@ func (d *LibP2PDHT) performDiscovery() {
func (d *LibP2PDHT) peerCleanup() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-d.ctx.Done():
@@ -513,9 +539,9 @@ func (d *LibP2PDHT) peerCleanup() {
func (d *LibP2PDHT) cleanupStalePeers() {
	d.peersMutex.Lock()
	defer d.peersMutex.Unlock()

	staleThreshold := time.Now().Add(-time.Hour) // 1 hour threshold

	for peerID, peerInfo := range d.knownPeers {
		if peerInfo.LastSeen.Before(staleThreshold) {
			// Check if peer is still connected
@@ -526,7 +552,7 @@ func (d *LibP2PDHT) cleanupStalePeers() {
				break
			}
		}

		if !connected {
			delete(d.knownPeers, peerID)
		}
@@ -589,11 +615,11 @@ func (d *LibP2PDHT) EnableReplication(config *ReplicationConfig) error {
	if d.replicationManager != nil {
		return fmt.Errorf("replication already enabled")
	}

	if config == nil {
		config = DefaultReplicationConfig()
	}

	d.replicationManager = NewReplicationManager(d.ctx, d.kdht, config)
	return nil
}
@@ -603,11 +629,11 @@ func (d *LibP2PDHT) DisableReplication() error {
	if d.replicationManager == nil {
		return nil
	}

	if err := d.replicationManager.Stop(); err != nil {
		return fmt.Errorf("failed to stop replication manager: %w", err)
	}

	d.replicationManager = nil
	return nil
}
@@ -617,13 +643,18 @@ func (d *LibP2PDHT) IsReplicationEnabled() bool {
	return d.replicationManager != nil
}

// ReplicationManager returns the underlying replication manager if enabled.
func (d *LibP2PDHT) ReplicationManager() *ReplicationManager {
	return d.replicationManager
}

// Close shuts down the DHT
func (d *LibP2PDHT) Close() error {
	// Stop replication manager first
	if d.replicationManager != nil {
		d.replicationManager.Stop()
	}

	d.cancel()
	return d.kdht.Close()
}
@@ -633,10 +664,10 @@ func (d *LibP2PDHT) RefreshRoutingTable() error {
	if !d.IsBootstrapped() {
		return fmt.Errorf("DHT not bootstrapped")
	}

	// RefreshRoutingTable() returns a channel with errors, not a direct error
	errChan := d.kdht.RefreshRoutingTable()

	// Wait for the first error (if any) from the channel
	select {
	case err := <-errChan:
@@ -654,4 +685,4 @@ func (d *LibP2PDHT) GetDHTSize() int {
// Host returns the underlying libp2p host
func (d *LibP2PDHT) Host() host.Host {
	return d.host
}
@@ -2,546 +2,155 @@ package dht

import (
	"context"
	"strings"
	"testing"
	"time"

	"github.com/libp2p/go-libp2p"
	"github.com/libp2p/go-libp2p/core/host"
	libp2p "github.com/libp2p/go-libp2p"
	dhtmode "github.com/libp2p/go-libp2p-kad-dht"
	"github.com/libp2p/go-libp2p/core/test"
	dht "github.com/libp2p/go-libp2p-kad-dht"
	"github.com/multiformats/go-multiaddr"
)

type harness struct {
	ctx  context.Context
	host libp2pHost
	dht  *LibP2PDHT
}

type libp2pHost interface {
	Close() error
}

func newHarness(t *testing.T, opts ...Option) *harness {
	t.Helper()

	ctx, cancel := context.WithCancel(context.Background())

	host, err := libp2p.New(libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0"))
	if err != nil {
		cancel()
		t.Fatalf("failed to create libp2p host: %v", err)
	}

	options := append([]Option{WithAutoBootstrap(false)}, opts...)
	d, err := NewLibP2PDHT(ctx, host, options...)
	if err != nil {
		host.Close()
		cancel()
		t.Fatalf("failed to create DHT: %v", err)
	}

	t.Cleanup(func() {
		d.Close()
		host.Close()
		cancel()
	})

	return &harness{ctx: ctx, host: host, dht: d}
}
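Note the teardown order registered in `newHarness`: the DHT (constructed last) is closed first, then the host, then the context is cancelled — reverse of construction, so nothing is torn down while a dependent still uses it. A small sketch of that LIFO-teardown idea with plain functions; `stack` and `closeAll` are hypothetical names for illustration only:

```go
package main

import "fmt"

// stack collects cleanup functions and runs them in reverse registration
// order, the same discipline the t.Cleanup body above follows by hand.
type stack struct{ closers []func() }

func (s *stack) push(name string, log *[]string) {
	s.closers = append(s.closers, func() { *log = append(*log, "close "+name) })
}

func (s *stack) closeAll() {
	for i := len(s.closers) - 1; i >= 0; i-- {
		s.closers[i]()
	}
}

func main() {
	var log []string
	s := &stack{}
	s.push("host", &log) // constructed first, closed last
	s.push("dht", &log)  // constructed last, closed first
	s.closeAll()
	fmt.Println(log) // [close dht close host]
}
```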
func TestDefaultConfig(t *testing.T) {
	config := DefaultConfig()

	if config.ProtocolPrefix != "/CHORUS" {
		t.Errorf("expected protocol prefix '/CHORUS', got %s", config.ProtocolPrefix)
	cfg := DefaultConfig()

	if cfg.ProtocolPrefix != "/CHORUS" {
		t.Fatalf("expected protocol prefix '/CHORUS', got %s", cfg.ProtocolPrefix)
	}

	if config.BootstrapTimeout != 30*time.Second {
		t.Errorf("expected bootstrap timeout 30s, got %v", config.BootstrapTimeout)

	if cfg.BootstrapTimeout != 30*time.Second {
		t.Fatalf("expected bootstrap timeout 30s, got %v", cfg.BootstrapTimeout)
	}

	if config.Mode != dht.ModeAuto {
		t.Errorf("expected mode auto, got %v", config.Mode)

	if cfg.Mode != dhtmode.ModeAuto {
		t.Fatalf("expected mode auto, got %v", cfg.Mode)
	}

	if !config.AutoBootstrap {
		t.Error("expected auto bootstrap to be enabled")

	if !cfg.AutoBootstrap {
		t.Fatal("expected auto bootstrap to be enabled")
	}
}
func TestNewDHT(t *testing.T) {
	ctx := context.Background()

	// Create a test host
	host, err := libp2p.New()
	if err != nil {
		t.Fatalf("failed to create test host: %v", err)
	}
	defer host.Close()

	// Test with default options
	d, err := NewDHT(ctx, host)
	if err != nil {
		t.Fatalf("failed to create DHT: %v", err)
	}
	defer d.Close()

	if d.host != host {
		t.Error("host not set correctly")
	}

	if d.config.ProtocolPrefix != "/CHORUS" {
		t.Errorf("expected protocol prefix '/CHORUS', got %s", d.config.ProtocolPrefix)
	}
}

func TestDHTWithOptions(t *testing.T) {
	ctx := context.Background()

	host, err := libp2p.New()
	if err != nil {
		t.Fatalf("failed to create test host: %v", err)
	}
	defer host.Close()

	// Test with custom options
	d, err := NewDHT(ctx, host,
func TestWithOptionsOverridesDefaults(t *testing.T) {
	h := newHarness(t,
		WithProtocolPrefix("/custom"),
		WithMode(dht.ModeClient),
		WithBootstrapTimeout(60*time.Second),
		WithDiscoveryInterval(120*time.Second),
		WithAutoBootstrap(false),
		WithDiscoveryInterval(2*time.Minute),
		WithBootstrapTimeout(45*time.Second),
		WithMode(dhtmode.ModeClient),
		WithAutoBootstrap(true),
	)
	if err != nil {
		t.Fatalf("failed to create DHT: %v", err)

	cfg := h.dht.config

	if cfg.ProtocolPrefix != "/custom" {
		t.Fatalf("expected protocol prefix '/custom', got %s", cfg.ProtocolPrefix)
	}
	defer d.Close()

	if d.config.ProtocolPrefix != "/custom" {
		t.Errorf("expected protocol prefix '/custom', got %s", d.config.ProtocolPrefix)

	if cfg.DiscoveryInterval != 2*time.Minute {
		t.Fatalf("expected discovery interval 2m, got %v", cfg.DiscoveryInterval)
	}

	if d.config.Mode != dht.ModeClient {
		t.Errorf("expected mode client, got %v", d.config.Mode)

	if cfg.BootstrapTimeout != 45*time.Second {
		t.Fatalf("expected bootstrap timeout 45s, got %v", cfg.BootstrapTimeout)
	}

	if d.config.BootstrapTimeout != 60*time.Second {
		t.Errorf("expected bootstrap timeout 60s, got %v", d.config.BootstrapTimeout)

	if cfg.Mode != dhtmode.ModeClient {
		t.Fatalf("expected mode client, got %v", cfg.Mode)
	}

	if d.config.DiscoveryInterval != 120*time.Second {
		t.Errorf("expected discovery interval 120s, got %v", d.config.DiscoveryInterval)
	}

	if d.config.AutoBootstrap {
		t.Error("expected auto bootstrap to be disabled")

	if !cfg.AutoBootstrap {
		t.Fatal("expected auto bootstrap to remain enabled")
	}
}

func TestWithBootstrapPeersFromStrings(t *testing.T) {
	ctx := context.Background()

	host, err := libp2p.New()
	if err != nil {
		t.Fatalf("failed to create test host: %v", err)
	}
	defer host.Close()

	bootstrapAddrs := []string{
		"/ip4/127.0.0.1/tcp/4001/p2p/QmTest1",
		"/ip4/127.0.0.1/tcp/4002/p2p/QmTest2",
	}

	d, err := NewDHT(ctx, host, WithBootstrapPeersFromStrings(bootstrapAddrs))
	if err != nil {
		t.Fatalf("failed to create DHT: %v", err)
	}
	defer d.Close()

	if len(d.config.BootstrapPeers) != 2 {
		t.Errorf("expected 2 bootstrap peers, got %d", len(d.config.BootstrapPeers))
	}
}
func TestProvideRequiresBootstrap(t *testing.T) {
	h := newHarness(t)

func TestWithBootstrapPeersFromStringsInvalid(t *testing.T) {
	ctx := context.Background()

	host, err := libp2p.New()
	if err != nil {
		t.Fatalf("failed to create test host: %v", err)
	err := h.dht.Provide(h.ctx, "key")
	if err == nil {
		t.Fatal("expected Provide to fail when not bootstrapped")
	}
	defer host.Close()

	// Include invalid addresses - they should be filtered out
	bootstrapAddrs := []string{
		"/ip4/127.0.0.1/tcp/4001/p2p/QmTest1", // valid
		"invalid-address",                     // invalid
		"/ip4/127.0.0.1/tcp/4002/p2p/QmTest2", // valid
	}

	d, err := NewDHT(ctx, host, WithBootstrapPeersFromStrings(bootstrapAddrs))
	if err != nil {
		t.Fatalf("failed to create DHT: %v", err)
	}
	defer d.Close()

	// Should have filtered out the invalid address
	if len(d.config.BootstrapPeers) != 2 {
		t.Errorf("expected 2 valid bootstrap peers, got %d", len(d.config.BootstrapPeers))
	}
}

func TestBootstrapWithoutPeers(t *testing.T) {
	ctx := context.Background()

	host, err := libp2p.New()
	if err != nil {
		t.Fatalf("failed to create test host: %v", err)
	}
	defer host.Close()

	d, err := NewDHT(ctx, host, WithAutoBootstrap(false))
	if err != nil {
		t.Fatalf("failed to create DHT: %v", err)
	}
	defer d.Close()

	// Bootstrap should use default IPFS peers when none configured
	err = d.Bootstrap()
	// This might fail in test environment without network access, but should not panic
	if err != nil {
		// Expected in test environment
		t.Logf("Bootstrap failed as expected in test environment: %v", err)
	}
}

func TestIsBootstrapped(t *testing.T) {
	ctx := context.Background()

	host, err := libp2p.New()
	if err != nil {
		t.Fatalf("failed to create test host: %v", err)
	}
	defer host.Close()

	d, err := NewDHT(ctx, host, WithAutoBootstrap(false))
	if err != nil {
		t.Fatalf("failed to create DHT: %v", err)
	}
	defer d.Close()

	// Should not be bootstrapped initially
	if d.IsBootstrapped() {
		t.Error("DHT should not be bootstrapped initially")
	if !strings.Contains(err.Error(), "not bootstrapped") {
		t.Fatalf("expected error to indicate bootstrap requirement, got %v", err)
	}
}
func TestRegisterPeer(t *testing.T) {
	ctx := context.Background()

	host, err := libp2p.New()
	if err != nil {
		t.Fatalf("failed to create test host: %v", err)
	}
	defer host.Close()

	d, err := NewDHT(ctx, host)
	if err != nil {
		t.Fatalf("failed to create DHT: %v", err)
	}
	defer d.Close()

	h := newHarness(t)

	peerID := test.RandPeerIDFatal(t)
	agent := "claude"
	role := "frontend"
	capabilities := []string{"react", "javascript"}

	d.RegisterPeer(peerID, agent, role, capabilities)

	knownPeers := d.GetKnownPeers()
	if len(knownPeers) != 1 {
		t.Errorf("expected 1 known peer, got %d", len(knownPeers))

	h.dht.RegisterPeer(peerID, "apollo", "platform", []string{"go"})

	peers := h.dht.GetKnownPeers()

	info, ok := peers[peerID]
	if !ok {
		t.Fatalf("expected peer to be tracked")
	}

	peerInfo, exists := knownPeers[peerID]
	if !exists {
		t.Error("peer not found in known peers")

	if info.Agent != "apollo" {
		t.Fatalf("expected agent apollo, got %s", info.Agent)
	}

	if peerInfo.Agent != agent {
		t.Errorf("expected agent %s, got %s", agent, peerInfo.Agent)

	if info.Role != "platform" {
		t.Fatalf("expected role platform, got %s", info.Role)
	}

	if peerInfo.Role != role {
		t.Errorf("expected role %s, got %s", role, peerInfo.Role)
	}

	if len(peerInfo.Capabilities) != len(capabilities) {
		t.Errorf("expected %d capabilities, got %d", len(capabilities), len(peerInfo.Capabilities))

	if len(info.Capabilities) != 1 || info.Capabilities[0] != "go" {
		t.Fatalf("expected capability go, got %v", info.Capabilities)
	}
}

func TestGetConnectedPeers(t *testing.T) {
	ctx := context.Background()

	host, err := libp2p.New()
	if err != nil {
		t.Fatalf("failed to create test host: %v", err)
func TestGetStatsProvidesUptime(t *testing.T) {
	h := newHarness(t)

	stats := h.dht.GetStats()

	if stats.TotalPeers != 0 {
		t.Fatalf("expected zero peers, got %d", stats.TotalPeers)
	}
	defer host.Close()

	d, err := NewDHT(ctx, host)
	if err != nil {
		t.Fatalf("failed to create DHT: %v", err)
	}
	defer d.Close()

	// Initially should have no connected peers
	peers := d.GetConnectedPeers()
	if len(peers) != 0 {
		t.Errorf("expected 0 connected peers, got %d", len(peers))

	if stats.Uptime < 0 {
		t.Fatalf("expected non-negative uptime, got %v", stats.Uptime)
	}
}
func TestPutAndGetValue(t *testing.T) {
	ctx := context.Background()

	host, err := libp2p.New()
	if err != nil {
		t.Fatalf("failed to create test host: %v", err)
	}
	defer host.Close()

	d, err := NewDHT(ctx, host, WithAutoBootstrap(false))
	if err != nil {
		t.Fatalf("failed to create DHT: %v", err)
	}
	defer d.Close()

	// Test without bootstrap (should fail)
	key := "test-key"
	value := []byte("test-value")

	err = d.PutValue(ctx, key, value)
	if err == nil {
		t.Error("PutValue should fail when DHT not bootstrapped")
	}

	_, err = d.GetValue(ctx, key)
	if err == nil {
		t.Error("GetValue should fail when DHT not bootstrapped")
	}
}

func TestProvideAndFindProviders(t *testing.T) {
	ctx := context.Background()

	host, err := libp2p.New()
	if err != nil {
		t.Fatalf("failed to create test host: %v", err)
	}
	defer host.Close()

	d, err := NewDHT(ctx, host, WithAutoBootstrap(false))
	if err != nil {
		t.Fatalf("failed to create DHT: %v", err)
	}
	defer d.Close()

	// Test without bootstrap (should fail)
	key := "test-service"

	err = d.Provide(ctx, key)
	if err == nil {
		t.Error("Provide should fail when DHT not bootstrapped")
	}

	_, err = d.FindProviders(ctx, key, 10)
	if err == nil {
		t.Error("FindProviders should fail when DHT not bootstrapped")
	}
}

func TestFindPeer(t *testing.T) {
	ctx := context.Background()

	host, err := libp2p.New()
	if err != nil {
		t.Fatalf("failed to create test host: %v", err)
	}
	defer host.Close()

	d, err := NewDHT(ctx, host, WithAutoBootstrap(false))
	if err != nil {
		t.Fatalf("failed to create DHT: %v", err)
	}
	defer d.Close()

	// Test without bootstrap (should fail)
	peerID := test.RandPeerIDFatal(t)

	_, err = d.FindPeer(ctx, peerID)
	if err == nil {
		t.Error("FindPeer should fail when DHT not bootstrapped")
	}
}

func TestFindPeersByRole(t *testing.T) {
	ctx := context.Background()

	host, err := libp2p.New()
	if err != nil {
		t.Fatalf("failed to create test host: %v", err)
	}
	defer host.Close()

	d, err := NewDHT(ctx, host, WithAutoBootstrap(false))
	if err != nil {
		t.Fatalf("failed to create DHT: %v", err)
	}
	defer d.Close()

	// Register some local peers
	peerID1 := test.RandPeerIDFatal(t)
	peerID2 := test.RandPeerIDFatal(t)

	d.RegisterPeer(peerID1, "claude", "frontend", []string{"react"})
	d.RegisterPeer(peerID2, "claude", "backend", []string{"go"})

	// Find frontend peers
	frontendPeers, err := d.FindPeersByRole(ctx, "frontend")
	if err != nil {
		t.Fatalf("failed to find peers by role: %v", err)
	}

	if len(frontendPeers) != 1 {
		t.Errorf("expected 1 frontend peer, got %d", len(frontendPeers))
	}

	if frontendPeers[0].ID != peerID1 {
		t.Error("wrong peer returned for frontend role")
	}

	// Find all peers with wildcard
	allPeers, err := d.FindPeersByRole(ctx, "*")
	if err != nil {
		t.Fatalf("failed to find all peers: %v", err)
	}

	if len(allPeers) != 2 {
		t.Errorf("expected 2 peers with wildcard, got %d", len(allPeers))
	}
}

func TestAnnounceRole(t *testing.T) {
	ctx := context.Background()

	host, err := libp2p.New()
	if err != nil {
		t.Fatalf("failed to create test host: %v", err)
	}
	defer host.Close()

	d, err := NewDHT(ctx, host, WithAutoBootstrap(false))
	if err != nil {
		t.Fatalf("failed to create DHT: %v", err)
	}
	defer d.Close()

	// Should fail when not bootstrapped
	err = d.AnnounceRole(ctx, "frontend")
	if err == nil {
		t.Error("AnnounceRole should fail when DHT not bootstrapped")
	}
}

func TestAnnounceCapability(t *testing.T) {
	ctx := context.Background()

	host, err := libp2p.New()
	if err != nil {
		t.Fatalf("failed to create test host: %v", err)
	}
	defer host.Close()

	d, err := NewDHT(ctx, host, WithAutoBootstrap(false))
	if err != nil {
		t.Fatalf("failed to create DHT: %v", err)
	}
	defer d.Close()

	// Should fail when not bootstrapped
	err = d.AnnounceCapability(ctx, "react")
	if err == nil {
		t.Error("AnnounceCapability should fail when DHT not bootstrapped")
	}
}

func TestGetRoutingTable(t *testing.T) {
	ctx := context.Background()

	host, err := libp2p.New()
	if err != nil {
		t.Fatalf("failed to create test host: %v", err)
	}
	defer host.Close()

	d, err := NewDHT(ctx, host)
	if err != nil {
		t.Fatalf("failed to create DHT: %v", err)
	}
	defer d.Close()

	rt := d.GetRoutingTable()
	if rt == nil {
		t.Error("routing table should not be nil")
	}
}

func TestGetDHTSize(t *testing.T) {
	ctx := context.Background()

	host, err := libp2p.New()
	if err != nil {
		t.Fatalf("failed to create test host: %v", err)
	}
	defer host.Close()

	d, err := NewDHT(ctx, host)
	if err != nil {
		t.Fatalf("failed to create DHT: %v", err)
	}
	defer d.Close()

	size := d.GetDHTSize()
	// Should be 0 or small initially
	if size < 0 {
		t.Errorf("DHT size should be non-negative, got %d", size)
	}
}

func TestRefreshRoutingTable(t *testing.T) {
	ctx := context.Background()

	host, err := libp2p.New()
	if err != nil {
		t.Fatalf("failed to create test host: %v", err)
	}
	defer host.Close()

	d, err := NewDHT(ctx, host, WithAutoBootstrap(false))
	if err != nil {
		t.Fatalf("failed to create DHT: %v", err)
	}
	defer d.Close()

	// Should fail when not bootstrapped
	err = d.RefreshRoutingTable()
	if err == nil {
		t.Error("RefreshRoutingTable should fail when DHT not bootstrapped")
	}
}

func TestHost(t *testing.T) {
	ctx := context.Background()

	host, err := libp2p.New()
	if err != nil {
		t.Fatalf("failed to create test host: %v", err)
	}
	defer host.Close()

	d, err := NewDHT(ctx, host)
	if err != nil {
		t.Fatalf("failed to create DHT: %v", err)
	}
	defer d.Close()

	if d.Host() != host {
		t.Error("Host() should return the same host instance")
	}
}

func TestClose(t *testing.T) {
	ctx := context.Background()

	host, err := libp2p.New()
	if err != nil {
		t.Fatalf("failed to create test host: %v", err)
	}
	defer host.Close()

	d, err := NewDHT(ctx, host)
	if err != nil {
		t.Fatalf("failed to create DHT: %v", err)
	}

	// Should close without error
	err = d.Close()
	if err != nil {
		t.Errorf("Close() failed: %v", err)
	}
}
@@ -2,559 +2,155 @@ package dht

import (
	"context"
	"strings"
	"testing"
	"time"

	"chorus/pkg/config"
)

// TestDHTSecurityPolicyEnforcement tests security policy enforcement in DHT operations
func TestDHTSecurityPolicyEnforcement(t *testing.T) {
	ctx := context.Background()

	testCases := []struct {
		name          string
		currentRole   string
		operation     string
		ucxlAddress   string
		contentType   string
		expectSuccess bool
		expectedError string
	}{
		// Store operation tests
type securityTestCase struct {
	name          string
	role          string
	address       string
	contentType   string
	expectSuccess bool
	expectErrHint string
}

func newTestEncryptedStorage(cfg *config.Config) *EncryptedDHTStorage {
	return &EncryptedDHTStorage{
		ctx:     context.Background(),
		config:  cfg,
		nodeID:  "test-node",
		cache:   make(map[string]*CachedEntry),
		metrics: &StorageMetrics{LastUpdate: time.Now()},
	}
}

func TestCheckStoreAccessPolicy(t *testing.T) {
	cases := []securityTestCase{
		{
			name:        "admin_can_store_all_content",
			currentRole: "admin",
			operation:   "store",
			ucxlAddress: "agent1:admin:system:security_audit",
			name:          "backend developer can store",
			role:          "backend_developer",
			address:       "agent1:backend_developer:api:endpoint",
			contentType:   "decision",
			expectSuccess: true,
		},
		{
			name:        "backend_developer_can_store_backend_content",
			currentRole: "backend_developer",
			operation:   "store",
			ucxlAddress: "agent1:backend_developer:api:endpoint_design",
			contentType: "suggestion",
			name:          "project manager can store",
			role:          "project_manager",
			address:       "agent1:project_manager:plan:milestone",
			contentType:   "decision",
			expectSuccess: true,
		},
		{
			name:          "readonly_role_cannot_store",
			currentRole:   "readonly_user",
			operation:     "store",
			ucxlAddress:   "agent1:readonly_user:project:observation",
			contentType:   "suggestion",
			expectSuccess: false,
			expectedError: "read-only authority",
			name:          "read only user cannot store",
			role:          "readonly_user",
			address:       "agent1:readonly_user:note:observation",
			contentType:   "note",
			expectSuccess: false,
			expectErrHint: "read-only authority",
		},
		{
			name:          "unknown_role_cannot_store",
			currentRole:   "invalid_role",
			operation:     "store",
			ucxlAddress:   "agent1:invalid_role:project:task",
			contentType:   "decision",
			expectSuccess: false,
			expectedError: "unknown creator role",
		},

		// Retrieve operation tests
		{
			name:          "any_valid_role_can_retrieve",
			currentRole:   "qa_engineer",
			operation:     "retrieve",
			ucxlAddress:   "agent1:backend_developer:api:test_data",
			expectSuccess: true,
		},
		{
			name:          "unknown_role_cannot_retrieve",
			currentRole:   "nonexistent_role",
			operation:     "retrieve",
			ucxlAddress:   "agent1:backend_developer:api:test_data",
			expectSuccess: false,
			expectedError: "unknown current role",
		},

		// Announce operation tests
		{
			name:          "coordination_role_can_announce",
			currentRole:   "senior_software_architect",
			operation:     "announce",
			ucxlAddress:   "agent1:senior_software_architect:architecture:blueprint",
			expectSuccess: true,
		},
		{
			name:          "decision_role_can_announce",
			currentRole:   "security_expert",
			operation:     "announce",
			ucxlAddress:   "agent1:security_expert:security:policy",
			expectSuccess: true,
		},
		{
			name:          "suggestion_role_cannot_announce",
			currentRole:   "suggestion_only_role",
			operation:     "announce",
			ucxlAddress:   "agent1:suggestion_only_role:project:idea",
			expectSuccess: false,
			expectedError: "lacks authority",
		},
		{
			name:          "readonly_role_cannot_announce",
			currentRole:   "readonly_user",
			operation:     "announce",
			ucxlAddress:   "agent1:readonly_user:project:observation",
			expectSuccess: false,
			expectedError: "lacks authority",
			name:          "unknown role rejected",
			role:          "ghost_role",
			address:       "agent1:ghost_role:context",
			contentType:   "decision",
			expectSuccess: false,
			expectErrHint: "unknown creator role",
		},
	}

	for _, tc := range testCases {
	cfg := &config.Config{Agent: config.AgentConfig{}}
	eds := newTestEncryptedStorage(cfg)

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Create test configuration
			cfg := &config.Config{
				Agent: config.AgentConfig{
					ID:   "test-agent",
					Role: tc.currentRole,
				},
				Security: config.SecurityConfig{
					KeyRotationDays: 90,
					AuditLogging:    true,
					AuditPath:       "/tmp/test-security-audit.log",
				},
			}

			// Create mock encrypted storage
			eds := createMockEncryptedStorage(ctx, cfg)

			var err error
			switch tc.operation {
			case "store":
				err = eds.checkStoreAccessPolicy(tc.currentRole, tc.ucxlAddress, tc.contentType)
			case "retrieve":
				err = eds.checkRetrieveAccessPolicy(tc.currentRole, tc.ucxlAddress)
			case "announce":
				err = eds.checkAnnounceAccessPolicy(tc.currentRole, tc.ucxlAddress)
			}

			if tc.expectSuccess {
				if err != nil {
					t.Errorf("Expected %s operation to succeed for role %s, but got error: %v",
						tc.operation, tc.currentRole, err)
				}
			} else {
				if err == nil {
					t.Errorf("Expected %s operation to fail for role %s, but it succeeded",
						tc.operation, tc.currentRole)
				}
				if tc.expectedError != "" && !containsSubstring(err.Error(), tc.expectedError) {
					t.Errorf("Expected error to contain '%s', got '%s'", tc.expectedError, err.Error())
				}
			}
			err := eds.checkStoreAccessPolicy(tc.role, tc.address, tc.contentType)
			verifySecurityExpectation(t, tc.expectSuccess, tc.expectErrHint, err)
		})
	}
}
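The tests here follow the table-driven pattern: a slice of cases with `expectSuccess`/`expectErrHint` fields, and one verification helper shared across them. A standalone sketch of that structure; `checkStoreAccess` and `verify` are hypothetical simplifications of the real policy check and of `verifySecurityExpectation`, not the CHORUS implementation.

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// checkStoreAccess is a toy stand-in for checkStoreAccessPolicy: only the
// role gating is sketched, using roles that appear in the tests above.
func checkStoreAccess(role string) error {
	switch role {
	case "admin", "backend_developer", "project_manager":
		return nil
	case "readonly_user":
		return errors.New("role has read-only authority")
	default:
		return errors.New("unknown creator role: " + role)
	}
}

// verify mirrors the expectSuccess/expectErrHint contract of the tests:
// success cases need a nil error; failure cases need an error whose
// message contains the hint.
func verify(expectSuccess bool, expectErrHint string, err error) bool {
	if expectSuccess {
		return err == nil
	}
	return err != nil && strings.Contains(err.Error(), expectErrHint)
}

func main() {
	cases := []struct {
		role          string
		expectSuccess bool
		expectErrHint string
	}{
		{"backend_developer", true, ""},
		{"readonly_user", false, "read-only authority"},
		{"ghost_role", false, "unknown creator role"},
	}
	for _, tc := range cases {
		fmt.Println(tc.role, verify(tc.expectSuccess, tc.expectErrHint, checkStoreAccess(tc.role)))
	}
}
```

Keeping the verification in one helper means every case in every table fails with the same diagnostics, which is why the refactor above replaced the inline if/else blocks with `verifySecurityExpectation`.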
|
||||
func TestCheckRetrieveAccessPolicy(t *testing.T) {
	cases := []securityTestCase{
		{
			name:          "qa engineer allowed",
			role:          "qa_engineer",
			address:       "agent1:backend_developer:api:tests",
			expectSuccess: true,
		},
		{
			name:          "unknown role rejected",
			role:          "unknown",
			address:       "agent1:backend_developer:api:tests",
			expectSuccess: false,
			expectErrHint: "unknown current role",
		},
	}

	cfg := &config.Config{Agent: config.AgentConfig{}}
	eds := newTestEncryptedStorage(cfg)

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := eds.checkRetrieveAccessPolicy(tc.role, tc.address)
			verifySecurityExpectation(t, tc.expectSuccess, tc.expectErrHint, err)
		})
	}
}
func TestCheckAnnounceAccessPolicy(t *testing.T) {
	cases := []securityTestCase{
		{
			name:          "architect can announce",
			role:          "senior_software_architect",
			address:       "agent1:senior_software_architect:architecture:proposal",
			expectSuccess: true,
		},
		{
			name:          "suggestion role cannot announce",
			role:          "suggestion_only_role",
			address:       "agent1:suggestion_only_role:idea",
			expectSuccess: false,
			expectErrHint: "lacks authority",
		},
		{
			name:          "unknown role rejected",
			role:          "mystery",
			address:       "agent1:mystery:topic",
			expectSuccess: false,
			expectErrHint: "unknown current role",
		},
	}

	cfg := &config.Config{Agent: config.AgentConfig{}}
	eds := newTestEncryptedStorage(cfg)

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := eds.checkAnnounceAccessPolicy(tc.role, tc.address)
			verifySecurityExpectation(t, tc.expectSuccess, tc.expectErrHint, err)
		})
	}
}
// verifySecurityExpectation asserts that a policy check either succeeded or
// failed with an error containing the expected hint.
func verifySecurityExpectation(t *testing.T, expectSuccess bool, hint string, err error) {
	t.Helper()

	if expectSuccess {
		if err != nil {
			t.Fatalf("expected success, got error: %v", err)
		}
		return
	}

	if err == nil {
		t.Fatal("expected error but got success")
	}

	if hint != "" && !strings.Contains(err.Error(), hint) {
		t.Fatalf("expected error to contain %q, got %q", hint, err.Error())
	}
}
// TestSecurityMetrics tests security-related metrics
func TestSecurityMetrics(t *testing.T) {
	ctx := context.Background()

	cfg := &config.Config{
		Agent: config.AgentConfig{
			ID:   "test-agent",
			Role: "backend_developer",
		},
		Security: config.SecurityConfig{
			KeyRotationDays: 90,
			AuditLogging:    true,
			AuditPath:       "/tmp/test-metrics.log",
		},
	}

	eds := createMockEncryptedStorage(ctx, cfg)

	// Simulate some operations to generate metrics
	for i := 0; i < 5; i++ {
		eds.metrics.StoredItems++
		eds.metrics.RetrievedItems++
		eds.metrics.EncryptionOps++
		eds.metrics.DecryptionOps++
	}

	metrics := eds.GetMetrics()

	expectedMetrics := map[string]int64{
		"stored_items":    5,
		"retrieved_items": 5,
		"encryption_ops":  5,
		"decryption_ops":  5,
	}

	for metricName, expectedValue := range expectedMetrics {
		if actualValue, ok := metrics[metricName]; !ok {
			t.Errorf("Expected metric %s to be present in metrics", metricName)
		} else if actualValue != expectedValue {
			t.Errorf("Expected %s to be %d, got %v", metricName, expectedValue, actualValue)
		}
	}
}

// Helper functions

func createMockEncryptedStorage(ctx context.Context, cfg *config.Config) *EncryptedDHTStorage {
	return &EncryptedDHTStorage{
		ctx:    ctx,
		config: cfg,
		nodeID: "test-node-id",
		cache:  make(map[string]*CachedEntry),
		metrics: &StorageMetrics{
			LastUpdate: time.Now(),
		},
	}
}

func containsSubstring(str, substr string) bool {
	if len(substr) == 0 {
		return true
	}
	if len(str) < len(substr) {
		return false
	}
	for i := 0; i <= len(str)-len(substr); i++ {
		if str[i:i+len(substr)] == substr {
			return true
		}
	}
	return false
}
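The helper above is a hand-rolled substring scan; it behaves the same as `strings.Contains` from the standard library. A minimal standalone check of that equivalence (the inputs are illustrative, not taken from the test suite):

```go
package main

import (
	"fmt"
	"strings"
)

// containsSubstring is the naive substring scan reproduced standalone:
// slide a window of len(substr) across str and compare each slice.
func containsSubstring(str, substr string) bool {
	if len(substr) == 0 {
		return true
	}
	if len(str) < len(substr) {
		return false
	}
	for i := 0; i <= len(str)-len(substr); i++ {
		if str[i:i+len(substr)] == substr {
			return true
		}
	}
	return false
}

func main() {
	inputs := []struct{ s, sub string }{
		{"unknown creator role", "creator"},
		{"read-only authority", "write"},
		{"anything", ""}, // empty substring always matches
	}
	for _, in := range inputs {
		got := containsSubstring(in.s, in.sub)
		want := strings.Contains(in.s, in.sub)
		fmt.Printf("%q/%q -> %v (strings.Contains agrees: %v)\n", in.s, in.sub, got, got == want)
	}
}
```

In practice the standard-library call is preferable; the helper only exists to keep the test file dependency-free.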

// Benchmarks for security performance

func BenchmarkSecurityPolicyChecks(b *testing.B) {
	ctx := context.Background()
	cfg := &config.Config{
		Agent: config.AgentConfig{
			ID:   "bench-agent",
			Role: "backend_developer",
		},
		Security: config.SecurityConfig{
			KeyRotationDays: 90,
			AuditLogging:    true,
			AuditPath:       "/tmp/bench-security.log",
		},
	}

	eds := createMockEncryptedStorage(ctx, cfg)

	b.ResetTimer()

	b.Run("store_policy_check", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			eds.checkStoreAccessPolicy("backend_developer", "test:address", "content")
		}
	})

	b.Run("retrieve_policy_check", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			eds.checkRetrieveAccessPolicy("backend_developer", "test:address")
		}
	})

	b.Run("announce_policy_check", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			eds.checkAnnounceAccessPolicy("senior_software_architect", "test:address")
		}
	})
}

func BenchmarkAuditOperations(b *testing.B) {
	ctx := context.Background()
	cfg := &config.Config{
		Agent: config.AgentConfig{
			ID:   "bench-agent",
			Role: "backend_developer",
		},
		Security: config.SecurityConfig{
			KeyRotationDays: 90,
			AuditLogging:    true,
			AuditPath:       "/tmp/bench-audit.log",
		},
	}

	eds := createMockEncryptedStorage(ctx, cfg)

	b.ResetTimer()

	b.Run("store_audit", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			eds.auditStoreOperation("test:address", "backend_developer", "content", 1024, true, "")
		}
	})

	b.Run("retrieve_audit", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			eds.auditRetrieveOperation("test:address", "backend_developer", true, "")
		}
	})

	b.Run("announce_audit", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			eds.auditAnnounceOperation("test:address", "backend_developer", true, "")
		}
	})
}
@@ -1,14 +1,117 @@
package dht

import (
	"context"
	"errors"
	"fmt"

	"chorus/pkg/config"
	libp2p "github.com/libp2p/go-libp2p"
	"github.com/libp2p/go-libp2p/core/host"
	"github.com/libp2p/go-libp2p/core/peer"
	"github.com/libp2p/go-libp2p/p2p/security/noise"
	"github.com/libp2p/go-libp2p/p2p/transport/tcp"
	"github.com/multiformats/go-multiaddr"
)

// RealDHT wraps a libp2p-based DHT to satisfy the generic DHT interface.
type RealDHT struct {
	cancel context.CancelFunc
	host   host.Host
	dht    *LibP2PDHT
}

// NewRealDHT creates a new real DHT implementation backed by libp2p.
// It replaces the earlier stub that returned "real DHT implementation not yet available".
func NewRealDHT(cfg *config.HybridConfig) (DHT, error) {
	if cfg == nil {
		cfg = &config.HybridConfig{}
	}

	ctx, cancel := context.WithCancel(context.Background())

	listenAddr, err := multiaddr.NewMultiaddr("/ip4/0.0.0.0/tcp/0")
	if err != nil {
		cancel()
		return nil, fmt.Errorf("failed to create listen address: %w", err)
	}

	host, err := libp2p.New(
		libp2p.ListenAddrs(listenAddr),
		libp2p.Security(noise.ID, noise.New),
		libp2p.Transport(tcp.NewTCPTransport),
		libp2p.DefaultMuxers,
		libp2p.EnableRelay(),
	)
	if err != nil {
		cancel()
		return nil, fmt.Errorf("failed to create libp2p host: %w", err)
	}

	opts := []Option{
		WithProtocolPrefix("/CHORUS"),
	}

	if nodes := cfg.GetDHTBootstrapNodes(); len(nodes) > 0 {
		opts = append(opts, WithBootstrapPeersFromStrings(nodes))
	}

	libp2pDHT, err := NewLibP2PDHT(ctx, host, opts...)
	if err != nil {
		host.Close()
		cancel()
		return nil, fmt.Errorf("failed to initialize libp2p DHT: %w", err)
	}

	if err := libp2pDHT.Bootstrap(); err != nil {
		libp2pDHT.Close()
		host.Close()
		cancel()
		return nil, fmt.Errorf("failed to bootstrap DHT: %w", err)
	}

	return &RealDHT{
		cancel: cancel,
		host:   host,
		dht:    libp2pDHT,
	}, nil
}

// PutValue stores a value in the DHT.
func (r *RealDHT) PutValue(ctx context.Context, key string, value []byte) error {
	return r.dht.PutValue(ctx, key, value)
}

// GetValue retrieves a value from the DHT.
func (r *RealDHT) GetValue(ctx context.Context, key string) ([]byte, error) {
	return r.dht.GetValue(ctx, key)
}

// Provide announces that this node can provide the given key.
func (r *RealDHT) Provide(ctx context.Context, key string) error {
	return r.dht.Provide(ctx, key)
}

// FindProviders locates peers that can provide the specified key.
func (r *RealDHT) FindProviders(ctx context.Context, key string, limit int) ([]peer.AddrInfo, error) {
	return r.dht.FindProviders(ctx, key, limit)
}

// GetStats exposes runtime metrics for the real DHT.
func (r *RealDHT) GetStats() DHTStats {
	return r.dht.GetStats()
}

// Close releases resources associated with the DHT.
func (r *RealDHT) Close() error {
	r.cancel()

	var errs []error
	if err := r.dht.Close(); err != nil {
		errs = append(errs, err)
	}
	if err := r.host.Close(); err != nil {
		errs = append(errs, err)
	}

	return errors.Join(errs...)
}
@@ -2,159 +2,106 @@ package dht

import (
	"context"
	"fmt"
	"testing"
	"time"
)

// newReplicationManagerForTest builds a manager with background timers
// disabled so tests stay deterministic.
func newReplicationManagerForTest(t *testing.T) *ReplicationManager {
	t.Helper()

	cfg := &ReplicationConfig{
		ReplicationFactor:         3,
		ReprovideInterval:         time.Hour,
		CleanupInterval:           time.Hour,
		ProviderTTL:               30 * time.Minute,
		MaxProvidersPerKey:        5,
		EnableAutoReplication:     false,
		EnableReprovide:           false,
		MaxConcurrentReplications: 1,
	}

	rm := NewReplicationManager(context.Background(), nil, cfg)
	t.Cleanup(func() {
		if rm.reprovideTimer != nil {
			rm.reprovideTimer.Stop()
		}
		if rm.cleanupTimer != nil {
			rm.cleanupTimer.Stop()
		}
		rm.cancel()
	})
	return rm
}

func TestAddContentRegistersKey(t *testing.T) {
	rm := newReplicationManagerForTest(t)

	if err := rm.AddContent("ucxl://example/path", 512, 1); err != nil {
		t.Fatalf("expected AddContent to succeed, got error: %v", err)
	}

	rm.keysMutex.RLock()
	record, ok := rm.contentKeys["ucxl://example/path"]
	rm.keysMutex.RUnlock()

	if !ok {
		t.Fatal("expected content key to be registered")
	}

	if record.Size != 512 {
		t.Fatalf("expected size 512, got %d", record.Size)
	}
}

func TestRemoveContentClearsTracking(t *testing.T) {
	rm := newReplicationManagerForTest(t)

	if err := rm.AddContent("ucxl://example/path", 512, 1); err != nil {
		t.Fatalf("AddContent returned error: %v", err)
	}

	if err := rm.RemoveContent("ucxl://example/path"); err != nil {
		t.Fatalf("RemoveContent returned error: %v", err)
	}

	rm.keysMutex.RLock()
	_, exists := rm.contentKeys["ucxl://example/path"]
	rm.keysMutex.RUnlock()

	if exists {
		t.Fatal("expected content key to be removed")
	}
}

func TestGetReplicationStatusReturnsCopy(t *testing.T) {
	rm := newReplicationManagerForTest(t)

	if err := rm.AddContent("ucxl://example/path", 512, 1); err != nil {
		t.Fatalf("AddContent returned error: %v", err)
	}

	status, err := rm.GetReplicationStatus("ucxl://example/path")
	if err != nil {
		t.Fatalf("GetReplicationStatus returned error: %v", err)
	}

	if status.Key != "ucxl://example/path" {
		t.Fatalf("expected status key to match, got %s", status.Key)
	}

	// Mutating status should not affect internal state
	status.HealthyProviders = 99
	internal, _ := rm.GetReplicationStatus("ucxl://example/path")
	if internal.HealthyProviders == 99 {
		t.Fatal("expected GetReplicationStatus to return a copy")
	}
}

func TestGetMetricsReturnsSnapshot(t *testing.T) {
	rm := newReplicationManagerForTest(t)

	metrics := rm.GetMetrics()
	if metrics == rm.metrics {
		t.Fatal("expected GetMetrics to return a copy of metrics")
	}
}

// TestLibP2PDHTReplication tests DHT replication functionality
func TestLibP2PDHTReplication(t *testing.T) {
	// This would normally require a real libp2p setup
	// For now, just test the interface methods exist

	// Mock test - in a real implementation, you'd set up actual libp2p hosts
	t.Log("DHT replication interface methods are implemented")

	// Example of how the replication would be used:
	// 1. Add content for replication
	// 2. Content gets automatically provided to the DHT
	// 3. Other nodes can discover this node as a provider
	// 4. Periodic reproviding ensures content availability
	// 5. Replication metrics track system health
}

// TestReplicationConfig tests replication configuration
func TestReplicationConfig(t *testing.T) {
	config := DefaultReplicationConfig()

	// Test default values
	if config.ReplicationFactor != 3 {
		t.Errorf("Expected default replication factor 3, got %d", config.ReplicationFactor)
	}

	if config.ReprovideInterval != 12*time.Hour {
		t.Errorf("Expected default reprovide interval 12h, got %v", config.ReprovideInterval)
	}

	if !config.EnableAutoReplication {
		t.Error("Expected auto replication to be enabled by default")
	}

	if !config.EnableReprovide {
		t.Error("Expected reprovide to be enabled by default")
	}
}

// TestProviderInfo tests provider information tracking
func TestProviderInfo(t *testing.T) {
	// Test distance calculation
	key := []byte("test-key")
	peerID := "test-peer-id"

	distance := calculateDistance(key, []byte(peerID))

	// Distance should be non-zero for different inputs
	if distance == 0 {
		t.Error("Expected non-zero distance for different inputs")
	}

	t.Logf("Distance between key and peer: %d", distance)
}

// TestReplicationMetrics tests metrics collection
func TestReplicationMetrics(t *testing.T) {
	ctx := context.Background()
	mockDHT := NewMockDHTInterface()
	rm := NewReplicationManager(ctx, mockDHT.Mock(), DefaultReplicationConfig())
	defer rm.Stop()

	// Add some content
	for i := 0; i < 3; i++ {
		key := fmt.Sprintf("test-key-%d", i)
		rm.AddContent(key, int64(1000+i*100), i+1)
	}

	metrics := rm.GetMetrics()

	if metrics.TotalKeys != 3 {
		t.Errorf("Expected 3 total keys, got %d", metrics.TotalKeys)
	}

	t.Logf("Replication metrics: %+v", metrics)
}

File diff suppressed because it is too large
@@ -2,451 +2,185 @@ package election

import (
	"context"
	"encoding/json"
	"testing"
	"time"

	"chorus/pkg/config"
	pubsubpkg "chorus/pubsub"
	libp2p "github.com/libp2p/go-libp2p"
)

// newTestElectionManager wires a real libp2p host and PubSub instance so the
// election manager exercises the same code paths used in production.
func newTestElectionManager(t *testing.T) *ElectionManager {
	t.Helper()

	ctx, cancel := context.WithCancel(context.Background())

	host, err := libp2p.New(libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0"))
	if err != nil {
		cancel()
		t.Fatalf("failed to create libp2p host: %v", err)
	}

	ps, err := pubsubpkg.NewPubSub(ctx, host, "", "")
	if err != nil {
		host.Close()
		cancel()
		t.Fatalf("failed to create pubsub: %v", err)
	}

	cfg := &config.Config{
		Agent: config.AgentConfig{
			ID:             host.ID().String(),
			Role:           "context_admin",
			Capabilities:   []string{"admin_election", "context_curation"},
			Models:         []string{"meta/llama-3.1-8b-instruct"},
			Specialization: "coordination",
		},
		Security: config.SecurityConfig{},
	}

	em := NewElectionManager(ctx, cfg, host, ps, host.ID().String())

	t.Cleanup(func() {
		em.Stop()
		ps.Close()
		host.Close()
		cancel()
	})

	return em
}

func TestNewElectionManagerInitialState(t *testing.T) {
	em := newTestElectionManager(t)

	if em.state != StateIdle {
		t.Fatalf("expected initial state %q, got %q", StateIdle, em.state)
	}

	if em.currentTerm != 0 {
		t.Fatalf("expected initial term 0, got %d", em.currentTerm)
	}

	if em.nodeID == "" {
		t.Fatal("expected nodeID to be populated")
	}
}

func TestElectionManagerCanBeAdmin(t *testing.T) {
	em := newTestElectionManager(t)

	if !em.canBeAdmin() {
		t.Fatal("expected node to qualify for admin election")
	}

	em.config.Agent.Capabilities = []string{"runtime_support"}
	if em.canBeAdmin() {
		t.Fatal("expected node without admin capabilities to be ineligible")
	}
}
func TestElectionManager_Vote(t *testing.T) {
	cfg := &config.Config{
		Agent: config.AgentConfig{
			ID: "test-node",
		},
	}

	em := NewElectionManager(cfg)

	// Add a candidate first
	candidate := &AdminCandidate{
		NodeID:       "candidate-1",
		Term:         1,
		Score:        0.8,
		Capabilities: []string{"admin"},
		LastSeen:     time.Now(),
	}

	em.mu.Lock()
	em.candidates["candidate-1"] = candidate
	em.mu.Unlock()

	// Vote for the candidate
	err := em.Vote("candidate-1")
	if err != nil {
		t.Fatalf("Failed to vote: %v", err)
	}

	// Verify vote was recorded
	em.mu.RLock()
	vote, exists := em.votes[em.nodeID]
	em.mu.RUnlock()

	if !exists {
		t.Error("Expected to find our vote after voting")
	}

	if vote != "candidate-1" {
		t.Errorf("Expected vote to be for 'candidate-1', got %s", vote)
	}
}

func TestElectionManager_VoteInvalidCandidate(t *testing.T) {
	cfg := &config.Config{
		Agent: config.AgentConfig{
			ID: "test-node",
		},
	}

	em := NewElectionManager(cfg)

	// Try to vote for non-existent candidate
	err := em.Vote("non-existent")
	if err == nil {
		t.Error("Expected error when voting for non-existent candidate")
	}
}

func TestElectionManager_AddCandidate(t *testing.T) {
	cfg := &config.Config{
		Agent: config.AgentConfig{
			ID: "test-node",
		},
	}

	em := NewElectionManager(cfg)

	candidate := &AdminCandidate{
		NodeID:       "new-candidate",
		Term:         1,
		Score:        0.7,
		Capabilities: []string{"admin", "leader"},
		LastSeen:     time.Now(),
	}

	err := em.AddCandidate(candidate)
	if err != nil {
		t.Fatalf("Failed to add candidate: %v", err)
	}

	// Verify candidate was added
	em.mu.RLock()
	stored, exists := em.candidates["new-candidate"]
	em.mu.RUnlock()

	if !exists {
		t.Error("Expected to find added candidate")
	}

	if stored.NodeID != "new-candidate" {
		t.Errorf("Expected stored candidate NodeID to be 'new-candidate', got %s", stored.NodeID)
	}

	if stored.Score != 0.7 {
		t.Errorf("Expected stored candidate score to be 0.7, got %f", stored.Score)
	}
}
func TestElectionManager_FindElectionWinner(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Agent: config.AgentConfig{
|
||||
ID: "test-node",
|
||||
},
|
||||
}
|
||||
|
||||
em := NewElectionManager(cfg)
|
||||
|
||||
// Add candidates with different scores
|
||||
candidates := []*AdminCandidate{
|
||||
{
|
||||
NodeID: "candidate-1",
|
||||
Term: 1,
|
||||
Score: 0.6,
|
||||
Capabilities: []string{"admin"},
|
||||
LastSeen: time.Now(),
|
||||
},
|
||||
{
|
||||
NodeID: "candidate-2",
|
||||
Term: 1,
|
||||
Score: 0.8,
|
||||
Capabilities: []string{"admin", "leader"},
|
||||
LastSeen: time.Now(),
|
||||
},
|
||||
{
|
||||
NodeID: "candidate-3",
|
||||
Term: 1,
|
||||
Score: 0.7,
|
||||
Capabilities: []string{"admin"},
|
||||
LastSeen: time.Now(),
|
||||
},
|
||||
}

	em.mu.Lock()
	for _, candidate := range candidates {
		em.candidates[candidate.NodeID] = candidate
	}

	// Add some votes
	em.votes["voter-1"] = "candidate-2"
	em.votes["voter-2"] = "candidate-2"
	em.votes["voter-3"] = "candidate-1"
	em.mu.Unlock()

	// Find the winner
	winner := em.findElectionWinner()

	if winner == nil {
		t.Fatal("Expected findElectionWinner to return a winner")
	}

	// candidate-2 should win with the most votes (2 votes)
	if winner.NodeID != "candidate-2" {
		t.Errorf("Expected winner to be 'candidate-2', got %s", winner.NodeID)
	}
}

func TestElectionManager_FindElectionWinnerNoVotes(t *testing.T) {
	cfg := &config.Config{
		Agent: config.AgentConfig{
			ID: "test-node",
		},
	}

	em := NewElectionManager(cfg)

	// Add candidates but no votes - should fall back to highest score
	candidates := []*AdminCandidate{
		{
			NodeID:       "candidate-1",
			Term:         1,
			Score:        0.6,
			Capabilities: []string{"admin"},
			LastSeen:     time.Now(),
		},
		{
			NodeID:       "candidate-2",
			Term:         1,
			Score:        0.9, // Highest score
			Capabilities: []string{"admin", "leader"},
			LastSeen:     time.Now(),
		},
	}

	em.mu.Lock()
	for _, candidate := range candidates {
		em.candidates[candidate.NodeID] = candidate
	}
	em.mu.Unlock()

	// Find the winner without any votes
	winner := em.findElectionWinner()

	if winner == nil {
		t.Fatal("Expected findElectionWinner to return a winner")
	}

	// candidate-2 should win with the highest score
	if winner.NodeID != "candidate-2" {
		t.Errorf("Expected winner to be 'candidate-2' (highest score), got %s", winner.NodeID)
	}
}

func TestElectionManager_HandleElectionVote(t *testing.T) {
	cfg := &config.Config{
		Agent: config.AgentConfig{
			ID: "test-node",
		},
	}

	em := NewElectionManager(cfg)

	// Add a candidate first
	candidate := &AdminCandidate{
		NodeID:       "candidate-1",
		Term:         1,
		Score:        0.8,
		Capabilities: []string{"admin"},
		LastSeen:     time.Now(),
	}

	em.mu.Lock()
	em.candidates["candidate-1"] = candidate
	em.mu.Unlock()

	// Create a vote message
	msg := ElectionMessage{
		Type:   MessageTypeVote,
		NodeID: "voter-1",
		Data: map[string]interface{}{
			"candidate": "candidate-1",
		},
	}

	// Handle the vote
	em.handleElectionVote(msg)

	// Verify the vote was recorded
	em.mu.RLock()
	vote, exists := em.votes["voter-1"]
	em.mu.RUnlock()

	if !exists {
		t.Error("Expected vote to be recorded after handling vote message")
	}

	if vote != "candidate-1" {
		t.Errorf("Expected recorded vote to be for 'candidate-1', got %s", vote)
	}
}

func TestElectionManager_HandleElectionVoteInvalidData(t *testing.T) {
	cfg := &config.Config{
		Agent: config.AgentConfig{
			ID: "test-node",
		},
	}

	em := NewElectionManager(cfg)

	// Create a vote message with invalid data
	msg := ElectionMessage{
		Type:   MessageTypeVote,
		NodeID: "voter-1",
		Data:   "invalid-data", // Should be map[string]interface{}
	}

	// Handle the vote - should not crash
	em.handleElectionVote(msg)

	// Verify no vote was recorded
	em.mu.RLock()
	_, exists := em.votes["voter-1"]
	em.mu.RUnlock()

	if exists {
		t.Error("Expected no vote to be recorded with invalid data")
	}
}

func TestElectionManager_CompleteElection(t *testing.T) {
	cfg := &config.Config{
		Agent: config.AgentConfig{
			ID: "test-node",
		},
	}

	em := NewElectionManager(cfg)

	// Set up election state
	em.mu.Lock()
	em.state = StateCandidate
	em.currentTerm = 1
	em.currentAdmin = em.nodeID
	em.mu.Unlock()

	// Add a candidate
	candidate := &AdminCandidate{
		NodeID:       "winner",
		Term:         1,
		Score:        0.9,
		Capabilities: []string{"admin", "leader"},
		LastSeen:     time.Now(),
	}

	em.mu.Lock()
	em.candidates["winner"] = candidate
	em.mu.Unlock()

	// Complete the election
	em.CompleteElection()

	// Verify state reset
	em.mu.RLock()
	state := em.state
	em.mu.RUnlock()

	if state != StateIdle {
		t.Errorf("Expected state to be StateIdle after completing election, got %v", state)
	}
}

func TestElectionManager_Concurrency(t *testing.T) {
	cfg := &config.Config{
		Agent: config.AgentConfig{
			ID: "test-node",
		},
	}

	em := NewElectionManager(cfg)

	// Test concurrent access to vote and candidate operations
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	// Add a candidate
	candidate := &AdminCandidate{
		NodeID:       "candidate-1",
		Term:         1,
		Score:        0.8,
		Capabilities: []string{"admin"},
		LastSeen:     time.Now(),
	}

	err := em.AddCandidate(candidate)
	if err != nil {
		t.Fatalf("Failed to add candidate: %v", err)
	}

	// Run concurrent operations
	done := make(chan bool, 2)

	// Concurrent voting
	go func() {
		defer func() { done <- true }()
		for i := 0; i < 10; i++ {
			select {
			case <-ctx.Done():
				return
			default:
				em.Vote("candidate-1") // Ignore errors in concurrent test
				time.Sleep(10 * time.Millisecond)
			}
		}
	}()

	// Concurrent state checking
	go func() {
		defer func() { done <- true }()
		for i := 0; i < 10; i++ {
			select {
			case <-ctx.Done():
				return
			default:
				em.findElectionWinner() // Just check for races
				time.Sleep(10 * time.Millisecond)
			}
		}
	}()

	// Wait for completion
	for i := 0; i < 2; i++ {
		select {
		case <-done:
		case <-ctx.Done():
			t.Fatal("Concurrent test timed out")
		}
	}
}

1020  pkg/execution/docker.go  Normal file
File diff suppressed because it is too large  Load Diff

482  pkg/execution/docker_test.go  Normal file
@@ -0,0 +1,482 @@
package execution

import (
	"context"
	"os"
	"path/filepath"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestNewDockerSandbox(t *testing.T) {
	sandbox := NewDockerSandbox()

	assert.NotNil(t, sandbox)
	assert.NotNil(t, sandbox.environment)
	assert.Empty(t, sandbox.containerID)
}

func TestDockerSandbox_Initialize(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping Docker integration test in short mode")
	}

	sandbox := NewDockerSandbox()
	ctx := context.Background()

	// Create a minimal configuration
	config := &SandboxConfig{
		Type:         "docker",
		Image:        "alpine:latest",
		Architecture: "amd64",
		Resources: ResourceLimits{
			MemoryLimit:  512 * 1024 * 1024, // 512MB
			CPULimit:     1.0,
			ProcessLimit: 50,
			FileLimit:    1024,
		},
		Security: SecurityPolicy{
			ReadOnlyRoot:     false,
			NoNewPrivileges:  true,
			AllowNetworking:  false,
			IsolateNetwork:   true,
			IsolateProcess:   true,
			DropCapabilities: []string{"ALL"},
		},
		Environment: map[string]string{
			"TEST_VAR": "test_value",
		},
		WorkingDir: "/workspace",
		Timeout:    30 * time.Second,
	}

	err := sandbox.Initialize(ctx, config)
	if err != nil {
		t.Skipf("Docker not available or image pull failed: %v", err)
	}
	defer sandbox.Cleanup()

	// Verify the sandbox is initialized
	assert.NotEmpty(t, sandbox.containerID)
	assert.Equal(t, config, sandbox.config)
	assert.Equal(t, StatusRunning, sandbox.info.Status)
	assert.Equal(t, "docker", sandbox.info.Type)
}

func TestDockerSandbox_ExecuteCommand(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping Docker integration test in short mode")
	}

	sandbox := setupTestSandbox(t)
	defer sandbox.Cleanup()

	ctx := context.Background()

	tests := []struct {
		name           string
		cmd            *Command
		expectedExit   int
		expectedOutput string
		shouldError    bool
	}{
		{
			name: "simple echo command",
			cmd: &Command{
				Executable: "echo",
				Args:       []string{"hello world"},
			},
			expectedExit:   0,
			expectedOutput: "hello world\n",
		},
		{
			name: "command with environment",
			cmd: &Command{
				Executable:  "sh",
				Args:        []string{"-c", "echo $TEST_VAR"},
				Environment: map[string]string{"TEST_VAR": "custom_value"},
			},
			expectedExit:   0,
			expectedOutput: "custom_value\n",
		},
		{
			name: "failing command",
			cmd: &Command{
				Executable: "sh",
				Args:       []string{"-c", "exit 1"},
			},
			expectedExit: 1,
		},
		{
			name: "command with timeout",
			cmd: &Command{
				Executable: "sleep",
				Args:       []string{"2"},
				Timeout:    1 * time.Second,
			},
			shouldError: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := sandbox.ExecuteCommand(ctx, tt.cmd)

			if tt.shouldError {
				assert.Error(t, err)
				return
			}

			require.NoError(t, err)
			assert.Equal(t, tt.expectedExit, result.ExitCode)
			assert.Equal(t, tt.expectedExit == 0, result.Success)

			if tt.expectedOutput != "" {
				assert.Equal(t, tt.expectedOutput, result.Stdout)
			}

			assert.NotZero(t, result.Duration)
			assert.False(t, result.StartTime.IsZero())
			assert.False(t, result.EndTime.IsZero())
		})
	}
}

func TestDockerSandbox_FileOperations(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping Docker integration test in short mode")
	}

	sandbox := setupTestSandbox(t)
	defer sandbox.Cleanup()

	ctx := context.Background()

	// Test WriteFile
	testContent := []byte("Hello, Docker sandbox!")
	testPath := "/tmp/test_file.txt"

	err := sandbox.WriteFile(ctx, testPath, testContent, 0644)
	require.NoError(t, err)

	// Test ReadFile
	readContent, err := sandbox.ReadFile(ctx, testPath)
	require.NoError(t, err)
	assert.Equal(t, testContent, readContent)

	// Test ListFiles
	files, err := sandbox.ListFiles(ctx, "/tmp")
	require.NoError(t, err)
	assert.NotEmpty(t, files)

	// Find our test file
	var testFile *FileInfo
	for _, file := range files {
		if file.Name == "test_file.txt" {
			testFile = &file
			break
		}
	}

	require.NotNil(t, testFile)
	assert.Equal(t, "test_file.txt", testFile.Name)
	assert.Equal(t, int64(len(testContent)), testFile.Size)
	assert.False(t, testFile.IsDir)
}

func TestDockerSandbox_CopyFiles(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping Docker integration test in short mode")
	}

	sandbox := setupTestSandbox(t)
	defer sandbox.Cleanup()

	ctx := context.Background()

	// Create a temporary file on the host
	tempDir := t.TempDir()
	hostFile := filepath.Join(tempDir, "host_file.txt")
	hostContent := []byte("Content from host")

	err := os.WriteFile(hostFile, hostContent, 0644)
	require.NoError(t, err)

	// Copy from host to container
	containerPath := "container:/tmp/copied_file.txt"
	err = sandbox.CopyFiles(ctx, hostFile, containerPath)
	require.NoError(t, err)

	// Verify the file exists in the container
	readContent, err := sandbox.ReadFile(ctx, "/tmp/copied_file.txt")
	require.NoError(t, err)
	assert.Equal(t, hostContent, readContent)

	// Copy from container back to host
	hostDestFile := filepath.Join(tempDir, "copied_back.txt")
	err = sandbox.CopyFiles(ctx, "container:/tmp/copied_file.txt", hostDestFile)
	require.NoError(t, err)

	// Verify the file exists on the host
	backContent, err := os.ReadFile(hostDestFile)
	require.NoError(t, err)
	assert.Equal(t, hostContent, backContent)
}

func TestDockerSandbox_Environment(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping Docker integration test in short mode")
	}

	sandbox := setupTestSandbox(t)
	defer sandbox.Cleanup()

	// Test getting the initial environment
	env := sandbox.GetEnvironment()
	assert.Equal(t, "test_value", env["TEST_VAR"])

	// Test setting additional environment variables
	newEnv := map[string]string{
		"NEW_VAR": "new_value",
		"PATH":    "/custom/path",
	}

	err := sandbox.SetEnvironment(newEnv)
	require.NoError(t, err)

	// Verify the environment is updated
	env = sandbox.GetEnvironment()
	assert.Equal(t, "new_value", env["NEW_VAR"])
	assert.Equal(t, "/custom/path", env["PATH"])
	assert.Equal(t, "test_value", env["TEST_VAR"]) // Original should still be there
}

func TestDockerSandbox_WorkingDirectory(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping Docker integration test in short mode")
	}

	sandbox := setupTestSandbox(t)
	defer sandbox.Cleanup()

	// Test getting the initial working directory
	workDir := sandbox.GetWorkingDirectory()
	assert.Equal(t, "/workspace", workDir)

	// Test setting the working directory
	newWorkDir := "/tmp"
	err := sandbox.SetWorkingDirectory(newWorkDir)
	require.NoError(t, err)

	// Verify the working directory is updated
	workDir = sandbox.GetWorkingDirectory()
	assert.Equal(t, newWorkDir, workDir)
}

func TestDockerSandbox_ResourceUsage(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping Docker integration test in short mode")
	}

	sandbox := setupTestSandbox(t)
	defer sandbox.Cleanup()

	ctx := context.Background()

	// Get resource usage
	usage, err := sandbox.GetResourceUsage(ctx)
	require.NoError(t, err)

	// Verify the usage structure
	assert.NotNil(t, usage)
	assert.False(t, usage.Timestamp.IsZero())
	assert.GreaterOrEqual(t, usage.CPUUsage, 0.0)
	assert.GreaterOrEqual(t, usage.MemoryUsage, int64(0))
	assert.GreaterOrEqual(t, usage.MemoryPercent, 0.0)
}

func TestDockerSandbox_GetInfo(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping Docker integration test in short mode")
	}

	sandbox := setupTestSandbox(t)
	defer sandbox.Cleanup()

	info := sandbox.GetInfo()

	assert.NotEmpty(t, info.ID)
	assert.Contains(t, info.Name, "chorus-sandbox")
	assert.Equal(t, "docker", info.Type)
	assert.Equal(t, StatusRunning, info.Status)
	assert.Equal(t, "docker", info.Runtime)
	assert.Equal(t, "alpine:latest", info.Image)
	assert.False(t, info.CreatedAt.IsZero())
	assert.False(t, info.StartedAt.IsZero())
}

func TestDockerSandbox_Cleanup(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping Docker integration test in short mode")
	}

	sandbox := setupTestSandbox(t)

	// Verify the sandbox is running
	assert.Equal(t, StatusRunning, sandbox.info.Status)
	assert.NotEmpty(t, sandbox.containerID)

	// Cleanup
	err := sandbox.Cleanup()
	require.NoError(t, err)

	// Verify the sandbox is destroyed
	assert.Equal(t, StatusDestroyed, sandbox.info.Status)
}

func TestDockerSandbox_SecurityPolicies(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping Docker integration test in short mode")
	}

	sandbox := NewDockerSandbox()
	ctx := context.Background()

	// Create a configuration with strict security policies
	config := &SandboxConfig{
		Type:         "docker",
		Image:        "alpine:latest",
		Architecture: "amd64",
		Resources: ResourceLimits{
			MemoryLimit:  256 * 1024 * 1024, // 256MB
			CPULimit:     0.5,
			ProcessLimit: 10,
			FileLimit:    256,
		},
		Security: SecurityPolicy{
			ReadOnlyRoot:     true,
			NoNewPrivileges:  true,
			AllowNetworking:  false,
			IsolateNetwork:   true,
			IsolateProcess:   true,
			DropCapabilities: []string{"ALL"},
			RunAsUser:        "1000",
			RunAsGroup:       "1000",
			TmpfsPaths:       []string{"/tmp", "/var/tmp"},
			MaskedPaths:      []string{"/proc/kcore", "/proc/keys"},
			ReadOnlyPaths:    []string{"/etc"},
		},
		WorkingDir: "/workspace",
		Timeout:    30 * time.Second,
	}

	err := sandbox.Initialize(ctx, config)
	if err != nil {
		t.Skipf("Docker not available or security policies not supported: %v", err)
	}
	defer sandbox.Cleanup()

	// Test that we can't write to the read-only filesystem
	result, err := sandbox.ExecuteCommand(ctx, &Command{
		Executable: "touch",
		Args:       []string{"/test_readonly"},
	})
	require.NoError(t, err)
	assert.NotEqual(t, 0, result.ExitCode) // Should fail due to read-only root

	// Test that tmpfs is writable
	result, err = sandbox.ExecuteCommand(ctx, &Command{
		Executable: "touch",
		Args:       []string{"/tmp/test_tmpfs"},
	})
	require.NoError(t, err)
	assert.Equal(t, 0, result.ExitCode) // Should succeed on tmpfs
}

// setupTestSandbox creates a basic Docker sandbox for testing
func setupTestSandbox(t *testing.T) *DockerSandbox {
	sandbox := NewDockerSandbox()
	ctx := context.Background()

	config := &SandboxConfig{
		Type:         "docker",
		Image:        "alpine:latest",
		Architecture: "amd64",
		Resources: ResourceLimits{
			MemoryLimit:  512 * 1024 * 1024, // 512MB
			CPULimit:     1.0,
			ProcessLimit: 50,
			FileLimit:    1024,
		},
		Security: SecurityPolicy{
			ReadOnlyRoot:     false,
			NoNewPrivileges:  true,
			AllowNetworking:  true, // Allow networking for easier testing
			IsolateNetwork:   false,
			IsolateProcess:   true,
			DropCapabilities: []string{"NET_ADMIN", "SYS_ADMIN"},
		},
		Environment: map[string]string{
			"TEST_VAR": "test_value",
		},
		WorkingDir: "/workspace",
		Timeout:    30 * time.Second,
	}

	err := sandbox.Initialize(ctx, config)
	if err != nil {
		t.Skipf("Docker not available: %v", err)
	}

	return sandbox
}

// Benchmark tests
func BenchmarkDockerSandbox_ExecuteCommand(b *testing.B) {
	if testing.Short() {
		b.Skip("Skipping Docker benchmark in short mode")
	}

	sandbox := &DockerSandbox{}
	ctx := context.Background()

	// Setup minimal config for benchmarking
	config := &SandboxConfig{
		Type:         "docker",
		Image:        "alpine:latest",
		Architecture: "amd64",
		Resources: ResourceLimits{
			MemoryLimit:  256 * 1024 * 1024,
			CPULimit:     1.0,
			ProcessLimit: 50,
		},
		Security: SecurityPolicy{
			NoNewPrivileges: true,
			AllowNetworking: true,
		},
		WorkingDir: "/workspace",
		Timeout:    10 * time.Second,
	}

	err := sandbox.Initialize(ctx, config)
	if err != nil {
		b.Skipf("Docker not available: %v", err)
	}
	defer sandbox.Cleanup()

	cmd := &Command{
		Executable: "echo",
		Args:       []string{"benchmark test"},
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, err := sandbox.ExecuteCommand(ctx, cmd)
		if err != nil {
			b.Fatalf("Command execution failed: %v", err)
		}
	}
}

494  pkg/execution/engine.go  Normal file
@@ -0,0 +1,494 @@
|
||||
package execution
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"chorus/pkg/ai"
|
||||
)
|
||||
|
||||
// TaskExecutionEngine provides AI-powered task execution with isolated sandboxes
|
||||
type TaskExecutionEngine interface {
|
||||
ExecuteTask(ctx context.Context, request *TaskExecutionRequest) (*TaskExecutionResult, error)
|
||||
Initialize(ctx context.Context, config *EngineConfig) error
|
||||
Shutdown() error
|
||||
GetMetrics() *EngineMetrics
|
||||
}
|
||||
|
||||
// TaskExecutionRequest represents a task to be executed
|
||||
type TaskExecutionRequest struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Context map[string]interface{} `json:"context,omitempty"`
|
||||
Requirements *TaskRequirements `json:"requirements,omitempty"`
|
||||
Timeout time.Duration `json:"timeout,omitempty"`
|
||||
}
|
||||
|
||||
// TaskRequirements specifies execution environment needs
|
||||
type TaskRequirements struct {
|
||||
AIModel string `json:"ai_model,omitempty"`
|
||||
SandboxType string `json:"sandbox_type,omitempty"`
|
||||
RequiredTools []string `json:"required_tools,omitempty"`
|
||||
EnvironmentVars map[string]string `json:"environment_vars,omitempty"`
|
||||
ResourceLimits *ResourceLimits `json:"resource_limits,omitempty"`
|
||||
SecurityPolicy *SecurityPolicy `json:"security_policy,omitempty"`
|
||||
}
|
||||
|
||||
// TaskExecutionResult contains the results of task execution
|
||||
type TaskExecutionResult struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Success bool `json:"success"`
|
||||
Output string `json:"output"`
|
||||
ErrorMessage string `json:"error_message,omitempty"`
|
||||
Artifacts []TaskArtifact `json:"artifacts,omitempty"`
|
||||
Metrics *ExecutionMetrics `json:"metrics"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// TaskArtifact represents a file or data produced during execution
|
||||
type TaskArtifact struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Path string `json:"path,omitempty"`
|
||||
Content []byte `json:"content,omitempty"`
|
||||
Size int64 `json:"size"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ExecutionMetrics tracks resource usage and performance
|
||||
type ExecutionMetrics struct {
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
Duration time.Duration `json:"duration"`
|
||||
AIProviderTime time.Duration `json:"ai_provider_time"`
|
||||
SandboxTime time.Duration `json:"sandbox_time"`
|
||||
ResourceUsage *ResourceUsage `json:"resource_usage,omitempty"`
|
||||
CommandsExecuted int `json:"commands_executed"`
|
||||
FilesGenerated int `json:"files_generated"`
|
||||
}
|
||||
|
||||
// EngineConfig configures the task execution engine
|
||||
type EngineConfig struct {
|
||||
AIProviderFactory *ai.ProviderFactory `json:"-"`
|
||||
SandboxDefaults *SandboxConfig `json:"sandbox_defaults"`
|
||||
DefaultTimeout time.Duration `json:"default_timeout"`
|
||||
MaxConcurrentTasks int `json:"max_concurrent_tasks"`
|
||||
EnableMetrics bool `json:"enable_metrics"`
|
||||
LogLevel string `json:"log_level"`
|
||||
}
|
||||
|
||||
// EngineMetrics tracks overall engine performance
|
||||
type EngineMetrics struct {
|
||||
TasksExecuted int64 `json:"tasks_executed"`
|
||||
TasksSuccessful int64 `json:"tasks_successful"`
|
||||
TasksFailed int64 `json:"tasks_failed"`
|
||||
AverageTime time.Duration `json:"average_time"`
|
||||
TotalExecutionTime time.Duration `json:"total_execution_time"`
|
||||
ActiveTasks int `json:"active_tasks"`
|
||||
}
|
||||
|
||||
// DefaultTaskExecutionEngine implements the TaskExecutionEngine interface
|
||||
type DefaultTaskExecutionEngine struct {
|
||||
config *EngineConfig
|
||||
aiFactory *ai.ProviderFactory
|
||||
metrics *EngineMetrics
|
||||
activeTasks map[string]context.CancelFunc
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// NewTaskExecutionEngine creates a new task execution engine
|
||||
func NewTaskExecutionEngine() *DefaultTaskExecutionEngine {
|
||||
return &DefaultTaskExecutionEngine{
|
||||
metrics: &EngineMetrics{},
|
||||
activeTasks: make(map[string]context.CancelFunc),
|
||||
logger: log.Default(),
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize configures and prepares the execution engine
|
||||
func (e *DefaultTaskExecutionEngine) Initialize(ctx context.Context, config *EngineConfig) error {
|
||||
if config == nil {
|
||||
return fmt.Errorf("engine config cannot be nil")
|
||||
}
|
||||
|
||||
if config.AIProviderFactory == nil {
|
||||
return fmt.Errorf("AI provider factory is required")
|
||||
}
|
||||
|
||||
e.config = config
|
||||
e.aiFactory = config.AIProviderFactory
|
||||
|
||||
// Set default values
|
||||
	if e.config.DefaultTimeout == 0 {
		e.config.DefaultTimeout = 5 * time.Minute
	}
	if e.config.MaxConcurrentTasks == 0 {
		e.config.MaxConcurrentTasks = 10
	}

	e.logger.Printf("TaskExecutionEngine initialized with %d max concurrent tasks", e.config.MaxConcurrentTasks)
	return nil
}

// ExecuteTask executes a task using AI providers and isolated sandboxes
func (e *DefaultTaskExecutionEngine) ExecuteTask(ctx context.Context, request *TaskExecutionRequest) (*TaskExecutionResult, error) {
	if e.config == nil {
		return nil, fmt.Errorf("engine not initialized")
	}

	startTime := time.Now()

	// Create task context with timeout
	timeout := request.Timeout
	if timeout == 0 {
		timeout = e.config.DefaultTimeout
	}

	taskCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	// Track active task
	e.activeTasks[request.ID] = cancel
	defer delete(e.activeTasks, request.ID)

	e.metrics.ActiveTasks++
	defer func() { e.metrics.ActiveTasks-- }()

	result := &TaskExecutionResult{
		TaskID:  request.ID,
		Metrics: &ExecutionMetrics{StartTime: startTime},
	}

	// Execute the task
	err := e.executeTaskInternal(taskCtx, request, result)

	// Update metrics
	result.Metrics.EndTime = time.Now()
	result.Metrics.Duration = result.Metrics.EndTime.Sub(result.Metrics.StartTime)

	e.metrics.TasksExecuted++
	e.metrics.TotalExecutionTime += result.Metrics.Duration

	if err != nil {
		result.Success = false
		result.ErrorMessage = err.Error()
		e.metrics.TasksFailed++
		e.logger.Printf("Task %s failed: %v", request.ID, err)
	} else {
		result.Success = true
		e.metrics.TasksSuccessful++
		e.logger.Printf("Task %s completed successfully in %v", request.ID, result.Metrics.Duration)
	}

	e.metrics.AverageTime = e.metrics.TotalExecutionTime / time.Duration(e.metrics.TasksExecuted)

	return result, err
}

// executeTaskInternal performs the actual task execution
func (e *DefaultTaskExecutionEngine) executeTaskInternal(ctx context.Context, request *TaskExecutionRequest, result *TaskExecutionResult) error {
	// Step 1: Determine AI model and get provider
	aiStartTime := time.Now()

	role := e.determineRoleFromTask(request)
	provider, providerConfig, err := e.aiFactory.GetProviderForRole(role)
	if err != nil {
		return fmt.Errorf("failed to get AI provider for role %s: %w", role, err)
	}

	// Step 2: Create AI request
	aiRequest := &ai.TaskRequest{
		TaskID:          request.ID,
		TaskTitle:       request.Type,
		TaskDescription: request.Description,
		Context:         request.Context,
		ModelName:       providerConfig.DefaultModel,
		AgentRole:       role,
	}

	// Step 3: Get AI response
	aiResponse, err := provider.ExecuteTask(ctx, aiRequest)
	if err != nil {
		return fmt.Errorf("AI provider execution failed: %w", err)
	}

	result.Metrics.AIProviderTime = time.Since(aiStartTime)

	// Step 4: Parse AI response for executable commands
	commands, artifacts, err := e.parseAIResponse(aiResponse)
	if err != nil {
		return fmt.Errorf("failed to parse AI response: %w", err)
	}

	// Step 5: Execute commands in sandbox if needed
	if len(commands) > 0 {
		sandboxStartTime := time.Now()

		sandboxResult, err := e.executeSandboxCommands(ctx, request, commands)
		if err != nil {
			return fmt.Errorf("sandbox execution failed: %w", err)
		}

		result.Metrics.SandboxTime = time.Since(sandboxStartTime)
		result.Metrics.CommandsExecuted = len(commands)
		result.Metrics.ResourceUsage = sandboxResult.ResourceUsage

		// Merge sandbox artifacts
		artifacts = append(artifacts, sandboxResult.Artifacts...)
	}

	// Step 6: Process results and artifacts
	result.Output = e.formatOutput(aiResponse, artifacts)
	result.Artifacts = artifacts
	result.Metrics.FilesGenerated = len(artifacts)

	// Add metadata
	result.Metadata = map[string]interface{}{
		"ai_provider": providerConfig.Type,
		"ai_model":    providerConfig.DefaultModel,
		"role":        role,
		"commands":    len(commands),
	}

	return nil
}

// determineRoleFromTask analyzes the task to determine the appropriate AI role
func (e *DefaultTaskExecutionEngine) determineRoleFromTask(request *TaskExecutionRequest) string {
	taskType := strings.ToLower(request.Type)
	description := strings.ToLower(request.Description)

	// Determine role based on task type and description keywords
	if strings.Contains(taskType, "code") || strings.Contains(description, "program") ||
		strings.Contains(description, "script") || strings.Contains(description, "function") {
		return "developer"
	}

	if strings.Contains(taskType, "analysis") || strings.Contains(description, "analyze") ||
		strings.Contains(description, "review") {
		return "analyst"
	}

	if strings.Contains(taskType, "test") || strings.Contains(description, "test") {
		return "tester"
	}

	// Default to general purpose
	return "general"
}

// parseAIResponse extracts executable commands and artifacts from AI response
func (e *DefaultTaskExecutionEngine) parseAIResponse(response *ai.TaskResponse) ([]string, []TaskArtifact, error) {
	var commands []string
	var artifacts []TaskArtifact

	// Parse response content for commands and files.
	// This is a simplified parser - in reality it would need more sophisticated parsing.

	if len(response.Actions) > 0 {
		for _, action := range response.Actions {
			switch action.Type {
			case "command", "command_run":
				// Extract command from content or target
				if action.Content != "" {
					commands = append(commands, action.Content)
				} else if action.Target != "" {
					commands = append(commands, action.Target)
				}
			case "file", "file_create", "file_edit":
				// Create artifact from file action
				if action.Target != "" && action.Content != "" {
					artifact := TaskArtifact{
						Name:      action.Target,
						Type:      "file",
						Content:   []byte(action.Content),
						Size:      int64(len(action.Content)),
						CreatedAt: time.Now(),
					}
					artifacts = append(artifacts, artifact)
				}
			}
		}
	}

	return commands, artifacts, nil
}
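Two details of the dispatch above are easy to miss: command actions prefer the inline `Content` and only fall back to `Target`, while file actions are silently dropped unless both a name and a body are present. A self-contained sketch of that behavior (the `Action` type and `extract` helper are illustrative stand-ins, not the package's API):

```go
package main

import "fmt"

// Action is an illustrative stand-in for the AI provider's action result.
type Action struct {
	Type    string
	Target  string
	Content string
}

// extract mirrors the parser's dispatch: commands prefer inline content over
// the target field; file actions need both a name and a body to become artifacts.
func extract(actions []Action) (commands []string, files map[string]string) {
	files = map[string]string{}
	for _, a := range actions {
		switch a.Type {
		case "command", "command_run":
			if a.Content != "" {
				commands = append(commands, a.Content)
			} else if a.Target != "" {
				commands = append(commands, a.Target)
			}
		case "file", "file_create", "file_edit":
			if a.Target != "" && a.Content != "" {
				files[a.Target] = a.Content
			}
		}
	}
	return commands, files
}

func main() {
	cmds, files := extract([]Action{
		{Type: "command", Content: "ls -la"},
		{Type: "command_run", Target: "echo hi"}, // falls back to Target
		{Type: "file_create", Target: "a.txt", Content: "hello"},
		{Type: "file", Target: "empty.txt"}, // dropped: no content
	})
	fmt.Println(cmds)  // [ls -la echo hi]
	fmt.Println(files) // map[a.txt:hello]
}
```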

// SandboxExecutionResult contains results from sandbox command execution
type SandboxExecutionResult struct {
	Output        string
	Artifacts     []TaskArtifact
	ResourceUsage *ResourceUsage
}

// executeSandboxCommands runs commands in an isolated sandbox
func (e *DefaultTaskExecutionEngine) executeSandboxCommands(ctx context.Context, request *TaskExecutionRequest, commands []string) (*SandboxExecutionResult, error) {
	// Create sandbox configuration
	sandboxConfig := e.createSandboxConfig(request)

	// Initialize sandbox
	sandbox := NewDockerSandbox()
	err := sandbox.Initialize(ctx, sandboxConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to initialize sandbox: %w", err)
	}
	defer sandbox.Cleanup()

	var outputs []string
	var artifacts []TaskArtifact

	// Execute each command
	for _, cmdStr := range commands {
		cmd := &Command{
			Executable: "/bin/sh",
			Args:       []string{"-c", cmdStr},
			WorkingDir: "/workspace",
			Timeout:    30 * time.Second,
		}

		cmdResult, err := sandbox.ExecuteCommand(ctx, cmd)
		if err != nil {
			return nil, fmt.Errorf("command execution failed: %w", err)
		}

		outputs = append(outputs, fmt.Sprintf("$ %s\n%s", cmdStr, cmdResult.Stdout))

		if cmdResult.ExitCode != 0 {
			outputs = append(outputs, fmt.Sprintf("Error (exit %d): %s", cmdResult.ExitCode, cmdResult.Stderr))
		}
	}

	// Get resource usage
	resourceUsage, _ := sandbox.GetResourceUsage(ctx)

	// Collect any generated files as artifacts
	files, err := sandbox.ListFiles(ctx, "/workspace")
	if err == nil {
		for _, file := range files {
			if !file.IsDir && file.Size > 0 {
				content, err := sandbox.ReadFile(ctx, "/workspace/"+file.Name)
				if err == nil {
					artifact := TaskArtifact{
						Name:      file.Name,
						Type:      "generated_file",
						Content:   content,
						Size:      file.Size,
						CreatedAt: file.ModTime,
					}
					artifacts = append(artifacts, artifact)
				}
			}
		}
	}

	return &SandboxExecutionResult{
		Output:        strings.Join(outputs, "\n"),
		Artifacts:     artifacts,
		ResourceUsage: resourceUsage,
	}, nil
}

// createSandboxConfig creates a sandbox configuration from task requirements
func (e *DefaultTaskExecutionEngine) createSandboxConfig(request *TaskExecutionRequest) *SandboxConfig {
	config := &SandboxConfig{
		Type:         "docker",
		Image:        "alpine:latest",
		Architecture: "amd64",
		WorkingDir:   "/workspace",
		Timeout:      5 * time.Minute,
		Environment:  make(map[string]string),
	}

	// Apply defaults from engine config
	if e.config.SandboxDefaults != nil {
		if e.config.SandboxDefaults.Image != "" {
			config.Image = e.config.SandboxDefaults.Image
		}
		if e.config.SandboxDefaults.Resources.MemoryLimit > 0 {
			config.Resources = e.config.SandboxDefaults.Resources
		}
		if e.config.SandboxDefaults.Security.NoNewPrivileges {
			config.Security = e.config.SandboxDefaults.Security
		}
	}

	// Apply task-specific requirements
	if request.Requirements != nil {
		if request.Requirements.SandboxType != "" {
			config.Type = request.Requirements.SandboxType
		}

		if request.Requirements.EnvironmentVars != nil {
			for k, v := range request.Requirements.EnvironmentVars {
				config.Environment[k] = v
			}
		}

		if request.Requirements.ResourceLimits != nil {
			config.Resources = *request.Requirements.ResourceLimits
		}

		if request.Requirements.SecurityPolicy != nil {
			config.Security = *request.Requirements.SecurityPolicy
		}
	}

	return config
}

// formatOutput creates a formatted output string from AI response and artifacts
func (e *DefaultTaskExecutionEngine) formatOutput(aiResponse *ai.TaskResponse, artifacts []TaskArtifact) string {
	var output strings.Builder

	output.WriteString("AI Response:\n")
	output.WriteString(aiResponse.Response)
	output.WriteString("\n\n")

	if len(artifacts) > 0 {
		output.WriteString("Generated Artifacts:\n")
		for _, artifact := range artifacts {
			output.WriteString(fmt.Sprintf("- %s (%s, %d bytes)\n",
				artifact.Name, artifact.Type, artifact.Size))
		}
	}

	return output.String()
}

// GetMetrics returns current engine metrics
func (e *DefaultTaskExecutionEngine) GetMetrics() *EngineMetrics {
	return e.metrics
}

// Shutdown gracefully shuts down the execution engine
func (e *DefaultTaskExecutionEngine) Shutdown() error {
	e.logger.Printf("Shutting down TaskExecutionEngine...")

	// Cancel all active tasks
	for taskID, cancel := range e.activeTasks {
		e.logger.Printf("Canceling active task: %s", taskID)
		cancel()
	}

	// Wait for tasks to finish (with timeout)
	shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	for len(e.activeTasks) > 0 {
		select {
		case <-shutdownCtx.Done():
			e.logger.Printf("Shutdown timeout reached, %d tasks may still be active", len(e.activeTasks))
			return nil
		case <-time.After(100 * time.Millisecond):
			// Continue waiting
		}
	}

	e.logger.Printf("TaskExecutionEngine shutdown complete")
	return nil
}
599
pkg/execution/engine_test.go
Normal file
@@ -0,0 +1,599 @@
package execution

import (
	"context"
	"testing"
	"time"

	"chorus/pkg/ai"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
)

// MockProvider implements ai.ModelProvider for testing
type MockProvider struct {
	mock.Mock
}

func (m *MockProvider) ExecuteTask(ctx context.Context, request *ai.TaskRequest) (*ai.TaskResponse, error) {
	args := m.Called(ctx, request)
	return args.Get(0).(*ai.TaskResponse), args.Error(1)
}

func (m *MockProvider) GetCapabilities() ai.ProviderCapabilities {
	args := m.Called()
	return args.Get(0).(ai.ProviderCapabilities)
}

func (m *MockProvider) ValidateConfig() error {
	args := m.Called()
	return args.Error(0)
}

func (m *MockProvider) GetProviderInfo() ai.ProviderInfo {
	args := m.Called()
	return args.Get(0).(ai.ProviderInfo)
}

// MockProviderFactory for testing
type MockProviderFactory struct {
	mock.Mock
	provider ai.ModelProvider
	config   ai.ProviderConfig
}

func (m *MockProviderFactory) GetProviderForRole(role string) (ai.ModelProvider, ai.ProviderConfig, error) {
	args := m.Called(role)
	return args.Get(0).(ai.ModelProvider), args.Get(1).(ai.ProviderConfig), args.Error(2)
}

func (m *MockProviderFactory) GetProvider(name string) (ai.ModelProvider, error) {
	args := m.Called(name)
	return args.Get(0).(ai.ModelProvider), args.Error(1)
}

func (m *MockProviderFactory) ListProviders() []string {
	args := m.Called()
	return args.Get(0).([]string)
}

func (m *MockProviderFactory) GetHealthStatus() map[string]bool {
	args := m.Called()
	return args.Get(0).(map[string]bool)
}

func TestNewTaskExecutionEngine(t *testing.T) {
	engine := NewTaskExecutionEngine()

	assert.NotNil(t, engine)
	assert.NotNil(t, engine.metrics)
	assert.NotNil(t, engine.activeTasks)
	assert.NotNil(t, engine.logger)
}

func TestTaskExecutionEngine_Initialize(t *testing.T) {
	engine := NewTaskExecutionEngine()

	tests := []struct {
		name        string
		config      *EngineConfig
		expectError bool
	}{
		{
			name:        "nil config",
			config:      nil,
			expectError: true,
		},
		{
			name: "missing AI factory",
			config: &EngineConfig{
				DefaultTimeout: 1 * time.Minute,
			},
			expectError: true,
		},
		{
			name: "valid config",
			config: &EngineConfig{
				AIProviderFactory: &MockProviderFactory{},
				DefaultTimeout:    1 * time.Minute,
			},
			expectError: false,
		},
		{
			name: "config with defaults",
			config: &EngineConfig{
				AIProviderFactory: &MockProviderFactory{},
			},
			expectError: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := engine.Initialize(context.Background(), tt.config)

			if tt.expectError {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
				assert.Equal(t, tt.config, engine.config)

				// Check defaults are set
				if tt.config.DefaultTimeout == 0 {
					assert.Equal(t, 5*time.Minute, engine.config.DefaultTimeout)
				}
				if tt.config.MaxConcurrentTasks == 0 {
					assert.Equal(t, 10, engine.config.MaxConcurrentTasks)
				}
			}
		})
	}
}

func TestTaskExecutionEngine_ExecuteTask_SimpleResponse(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test in short mode")
	}

	engine := NewTaskExecutionEngine()

	// Setup mock AI provider
	mockProvider := &MockProvider{}
	mockFactory := &MockProviderFactory{}

	// Configure mock responses
	mockProvider.On("ExecuteTask", mock.Anything, mock.Anything).Return(
		&ai.TaskResponse{
			TaskID:   "test-123",
			Content:  "Task completed successfully",
			Success:  true,
			Actions:  []ai.ActionResult{},
			Metadata: map[string]interface{}{},
		}, nil)

	mockFactory.On("GetProviderForRole", "general").Return(
		mockProvider,
		ai.ProviderConfig{
			Provider: "mock",
			Model:    "test-model",
		},
		nil)

	config := &EngineConfig{
		AIProviderFactory: mockFactory,
		DefaultTimeout:    30 * time.Second,
		EnableMetrics:     true,
	}

	err := engine.Initialize(context.Background(), config)
	require.NoError(t, err)

	// Execute simple task (no sandbox commands)
	request := &TaskExecutionRequest{
		ID:          "test-123",
		Type:        "analysis",
		Description: "Analyze the given data",
		Context:     map[string]interface{}{"data": "sample data"},
	}

	ctx := context.Background()
	result, err := engine.ExecuteTask(ctx, request)

	require.NoError(t, err)
	assert.True(t, result.Success)
	assert.Equal(t, "test-123", result.TaskID)
	assert.Contains(t, result.Output, "Task completed successfully")
	assert.NotNil(t, result.Metrics)
	assert.False(t, result.Metrics.StartTime.IsZero())
	assert.False(t, result.Metrics.EndTime.IsZero())
	assert.Greater(t, result.Metrics.Duration, time.Duration(0))

	// Verify mocks were called
	mockProvider.AssertCalled(t, "ExecuteTask", mock.Anything, mock.Anything)
	mockFactory.AssertCalled(t, "GetProviderForRole", "general")
}

func TestTaskExecutionEngine_ExecuteTask_WithCommands(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping Docker integration test in short mode")
	}

	engine := NewTaskExecutionEngine()

	// Setup mock AI provider with commands
	mockProvider := &MockProvider{}
	mockFactory := &MockProviderFactory{}

	// Configure mock to return commands
	mockProvider.On("ExecuteTask", mock.Anything, mock.Anything).Return(
		&ai.TaskResponse{
			TaskID:  "test-456",
			Content: "Executing commands",
			Success: true,
			Actions: []ai.ActionResult{
				{
					Type: "command",
					Content: map[string]interface{}{
						"command": "echo 'Hello World'",
					},
				},
				{
					Type: "file",
					Content: map[string]interface{}{
						"name":    "test.txt",
						"content": "Test file content",
					},
				},
			},
			Metadata: map[string]interface{}{},
		}, nil)

	mockFactory.On("GetProviderForRole", "developer").Return(
		mockProvider,
		ai.ProviderConfig{
			Provider: "mock",
			Model:    "test-model",
		},
		nil)

	config := &EngineConfig{
		AIProviderFactory: mockFactory,
		DefaultTimeout:    1 * time.Minute,
		SandboxDefaults: &SandboxConfig{
			Type:  "docker",
			Image: "alpine:latest",
			Resources: ResourceLimits{
				MemoryLimit: 256 * 1024 * 1024,
				CPULimit:    0.5,
			},
			Security: SecurityPolicy{
				NoNewPrivileges: true,
				AllowNetworking: false,
			},
		},
	}

	err := engine.Initialize(context.Background(), config)
	require.NoError(t, err)

	// Execute task with commands
	request := &TaskExecutionRequest{
		ID:          "test-456",
		Type:        "code_generation",
		Description: "Generate a simple script",
		Timeout:     2 * time.Minute,
	}

	ctx := context.Background()
	result, err := engine.ExecuteTask(ctx, request)

	if err != nil {
		// If Docker is not available, skip this test
		t.Skipf("Docker not available for sandbox testing: %v", err)
	}

	require.NoError(t, err)
	assert.True(t, result.Success)
	assert.Equal(t, "test-456", result.TaskID)
	assert.NotEmpty(t, result.Output)
	assert.GreaterOrEqual(t, len(result.Artifacts), 1) // At least the file artifact
	assert.Equal(t, 1, result.Metrics.CommandsExecuted)
	assert.Greater(t, result.Metrics.SandboxTime, time.Duration(0))

	// Check artifacts
	var foundTestFile bool
	for _, artifact := range result.Artifacts {
		if artifact.Name == "test.txt" {
			foundTestFile = true
			assert.Equal(t, "file", artifact.Type)
			assert.Equal(t, "Test file content", string(artifact.Content))
		}
	}
	assert.True(t, foundTestFile, "Expected test.txt artifact not found")
}

func TestTaskExecutionEngine_DetermineRoleFromTask(t *testing.T) {
	engine := NewTaskExecutionEngine()

	tests := []struct {
		name         string
		request      *TaskExecutionRequest
		expectedRole string
	}{
		{
			name: "code task",
			request: &TaskExecutionRequest{
				Type:        "code_generation",
				Description: "Write a function to sort array",
			},
			expectedRole: "developer",
		},
		{
			name: "analysis task",
			request: &TaskExecutionRequest{
				Type:        "analysis",
				Description: "Analyze the performance metrics",
			},
			expectedRole: "analyst",
		},
		{
			name: "test task",
			request: &TaskExecutionRequest{
				Type:        "testing",
				Description: "Write tests for the function",
			},
			expectedRole: "tester",
		},
		{
			name: "program task by description",
			request: &TaskExecutionRequest{
				Type:        "general",
				Description: "Create a program that processes data",
			},
			expectedRole: "developer",
		},
		{
			name: "review task by description",
			request: &TaskExecutionRequest{
				Type:        "general",
				Description: "Review the code quality",
			},
			expectedRole: "analyst",
		},
		{
			name: "general task",
			request: &TaskExecutionRequest{
				Type:        "documentation",
				Description: "Write user documentation",
			},
			expectedRole: "general",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			role := engine.determineRoleFromTask(tt.request)
			assert.Equal(t, tt.expectedRole, role)
		})
	}
}

func TestTaskExecutionEngine_ParseAIResponse(t *testing.T) {
	engine := NewTaskExecutionEngine()

	tests := []struct {
		name              string
		response          *ai.TaskResponse
		expectedCommands  int
		expectedArtifacts int
	}{
		{
			name: "response with commands and files",
			response: &ai.TaskResponse{
				Actions: []ai.ActionResult{
					{
						Type: "command",
						Content: map[string]interface{}{
							"command": "ls -la",
						},
					},
					{
						Type: "command",
						Content: map[string]interface{}{
							"command": "echo 'test'",
						},
					},
					{
						Type: "file",
						Content: map[string]interface{}{
							"name":    "script.sh",
							"content": "#!/bin/bash\necho 'Hello'",
						},
					},
				},
			},
			expectedCommands:  2,
			expectedArtifacts: 1,
		},
		{
			name: "response with no actions",
			response: &ai.TaskResponse{
				Actions: []ai.ActionResult{},
			},
			expectedCommands:  0,
			expectedArtifacts: 0,
		},
		{
			name: "response with unknown action types",
			response: &ai.TaskResponse{
				Actions: []ai.ActionResult{
					{
						Type: "unknown",
						Content: map[string]interface{}{
							"data": "some data",
						},
					},
				},
			},
			expectedCommands:  0,
			expectedArtifacts: 0,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			commands, artifacts, err := engine.parseAIResponse(tt.response)

			require.NoError(t, err)
			assert.Len(t, commands, tt.expectedCommands)
			assert.Len(t, artifacts, tt.expectedArtifacts)

			// Validate artifact content if present
			for _, artifact := range artifacts {
				assert.NotEmpty(t, artifact.Name)
				assert.NotEmpty(t, artifact.Type)
				assert.Greater(t, artifact.Size, int64(0))
				assert.False(t, artifact.CreatedAt.IsZero())
			}
		})
	}
}

func TestTaskExecutionEngine_CreateSandboxConfig(t *testing.T) {
	engine := NewTaskExecutionEngine()

	// Initialize with default config
	config := &EngineConfig{
		AIProviderFactory: &MockProviderFactory{},
		SandboxDefaults: &SandboxConfig{
			Image: "ubuntu:20.04",
			Resources: ResourceLimits{
				MemoryLimit: 1024 * 1024 * 1024,
				CPULimit:    2.0,
			},
			Security: SecurityPolicy{
				NoNewPrivileges: true,
			},
		},
	}
	engine.Initialize(context.Background(), config)

	tests := []struct {
		name     string
		request  *TaskExecutionRequest
		validate func(t *testing.T, config *SandboxConfig)
	}{
		{
			name: "basic request uses defaults",
			request: &TaskExecutionRequest{
				ID:          "test",
				Type:        "general",
				Description: "test task",
			},
			validate: func(t *testing.T, config *SandboxConfig) {
				assert.Equal(t, "ubuntu:20.04", config.Image)
				assert.Equal(t, int64(1024*1024*1024), config.Resources.MemoryLimit)
				assert.Equal(t, 2.0, config.Resources.CPULimit)
				assert.True(t, config.Security.NoNewPrivileges)
			},
		},
		{
			name: "request with custom requirements",
			request: &TaskExecutionRequest{
				ID:          "test",
				Type:        "custom",
				Description: "custom task",
				Requirements: &TaskRequirements{
					SandboxType: "container",
					EnvironmentVars: map[string]string{
						"ENV_VAR": "test_value",
					},
					ResourceLimits: &ResourceLimits{
						MemoryLimit: 512 * 1024 * 1024,
						CPULimit:    1.0,
					},
					SecurityPolicy: &SecurityPolicy{
						ReadOnlyRoot: true,
					},
				},
			},
			validate: func(t *testing.T, config *SandboxConfig) {
				assert.Equal(t, "container", config.Type)
				assert.Equal(t, "test_value", config.Environment["ENV_VAR"])
				assert.Equal(t, int64(512*1024*1024), config.Resources.MemoryLimit)
				assert.Equal(t, 1.0, config.Resources.CPULimit)
				assert.True(t, config.Security.ReadOnlyRoot)
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			sandboxConfig := engine.createSandboxConfig(tt.request)
			tt.validate(t, sandboxConfig)
		})
	}
}

func TestTaskExecutionEngine_GetMetrics(t *testing.T) {
	engine := NewTaskExecutionEngine()

	metrics := engine.GetMetrics()

	assert.NotNil(t, metrics)
	assert.Equal(t, int64(0), metrics.TasksExecuted)
	assert.Equal(t, int64(0), metrics.TasksSuccessful)
	assert.Equal(t, int64(0), metrics.TasksFailed)
}

func TestTaskExecutionEngine_Shutdown(t *testing.T) {
	engine := NewTaskExecutionEngine()

	// Initialize engine
	config := &EngineConfig{
		AIProviderFactory: &MockProviderFactory{},
	}
	err := engine.Initialize(context.Background(), config)
	require.NoError(t, err)

	// Add a mock active task
	ctx, cancel := context.WithCancel(context.Background())
	engine.activeTasks["test-task"] = cancel

	// Shutdown should cancel active tasks
	err = engine.Shutdown()
	assert.NoError(t, err)

	// Verify task was cleaned up
	select {
	case <-ctx.Done():
		// Expected - task was canceled
	default:
		t.Error("Expected task context to be canceled")
	}
}

// Benchmark tests
func BenchmarkTaskExecutionEngine_ExecuteSimpleTask(b *testing.B) {
	engine := NewTaskExecutionEngine()

	// Setup mock AI provider
	mockProvider := &MockProvider{}
	mockFactory := &MockProviderFactory{}

	mockProvider.On("ExecuteTask", mock.Anything, mock.Anything).Return(
		&ai.TaskResponse{
			TaskID:  "bench",
			Content: "Benchmark task completed",
			Success: true,
			Actions: []ai.ActionResult{},
		}, nil)

	mockFactory.On("GetProviderForRole", mock.Anything).Return(
		mockProvider,
		ai.ProviderConfig{Provider: "mock", Model: "test"},
		nil)

	config := &EngineConfig{
		AIProviderFactory: mockFactory,
		DefaultTimeout:    30 * time.Second,
	}

	engine.Initialize(context.Background(), config)

	request := &TaskExecutionRequest{
		ID:          "bench",
		Type:        "benchmark",
		Description: "Benchmark task",
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, err := engine.ExecuteTask(context.Background(), request)
		if err != nil {
			b.Fatalf("Task execution failed: %v", err)
		}
	}
}
415
pkg/execution/sandbox.go
Normal file
415
pkg/execution/sandbox.go
Normal file
@@ -0,0 +1,415 @@
package execution

import (
	"context"
	"io"
	"time"
)

// ExecutionSandbox defines the interface for isolated task execution environments
type ExecutionSandbox interface {
	// Initialize sets up the sandbox environment
	Initialize(ctx context.Context, config *SandboxConfig) error

	// ExecuteCommand runs a command within the sandbox
	ExecuteCommand(ctx context.Context, cmd *Command) (*CommandResult, error)

	// CopyFiles copies files between host and sandbox
	CopyFiles(ctx context.Context, source, dest string) error

	// WriteFile writes content to a file in the sandbox
	WriteFile(ctx context.Context, path string, content []byte, mode uint32) error

	// ReadFile reads content from a file in the sandbox
	ReadFile(ctx context.Context, path string) ([]byte, error)

	// ListFiles lists files in a directory within the sandbox
	ListFiles(ctx context.Context, path string) ([]FileInfo, error)

	// GetWorkingDirectory returns the current working directory in the sandbox
	GetWorkingDirectory() string

	// SetWorkingDirectory changes the working directory in the sandbox
	SetWorkingDirectory(path string) error

	// GetEnvironment returns environment variables in the sandbox
	GetEnvironment() map[string]string

	// SetEnvironment sets environment variables in the sandbox
	SetEnvironment(env map[string]string) error

	// GetResourceUsage returns current resource usage statistics
	GetResourceUsage(ctx context.Context) (*ResourceUsage, error)

	// Cleanup destroys the sandbox and cleans up resources
	Cleanup() error

	// GetInfo returns information about the sandbox
	GetInfo() SandboxInfo
}

// SandboxConfig represents configuration for a sandbox environment
type SandboxConfig struct {
	// Sandbox type and runtime
	Type         string `json:"type"`         // docker, vm, process
	Image        string `json:"image"`        // Container/VM image
	Runtime      string `json:"runtime"`      // docker, containerd, etc.
	Architecture string `json:"architecture"` // amd64, arm64

	// Resource limits
	Resources ResourceLimits `json:"resources"`

	// Security settings
	Security SecurityPolicy `json:"security"`

	// Repository configuration
	Repository RepositoryConfig `json:"repository"`

	// Network settings
	Network NetworkConfig `json:"network"`

	// Environment settings
	Environment map[string]string `json:"environment"`
	WorkingDir  string            `json:"working_dir"`

	// Tool and service access
	Tools      []string `json:"tools"`       // Available tools in sandbox
	MCPServers []string `json:"mcp_servers"` // MCP servers to connect to

	// Execution settings
	Timeout      time.Duration `json:"timeout"`       // Maximum execution time
	CleanupDelay time.Duration `json:"cleanup_delay"` // Delay before cleanup

	// Metadata
	Labels      map[string]string `json:"labels"`
	Annotations map[string]string `json:"annotations"`
}

// Command represents a command to execute in the sandbox
type Command struct {
	// Command specification
	Executable  string            `json:"executable"`
	Args        []string          `json:"args"`
	WorkingDir  string            `json:"working_dir"`
	Environment map[string]string `json:"environment"`

	// Input/Output
	Stdin        io.Reader `json:"-"`
	StdinContent string    `json:"stdin_content"`

	// Execution settings
	Timeout time.Duration `json:"timeout"`
	User    string        `json:"user"`

	// Security settings
	AllowNetwork  bool     `json:"allow_network"`
	AllowWrite    bool     `json:"allow_write"`
	RestrictPaths []string `json:"restrict_paths"`
}

// CommandResult represents the result of command execution
type CommandResult struct {
	// Exit information
	ExitCode int  `json:"exit_code"`
	Success  bool `json:"success"`

	// Output
	Stdout   string `json:"stdout"`
	Stderr   string `json:"stderr"`
	Combined string `json:"combined"`

	// Timing
	StartTime time.Time     `json:"start_time"`
	EndTime   time.Time     `json:"end_time"`
	Duration  time.Duration `json:"duration"`

	// Resource usage during execution
	ResourceUsage ResourceUsage `json:"resource_usage"`

	// Error information
	Error  string `json:"error,omitempty"`
	Signal string `json:"signal,omitempty"`

	// Metadata
	ProcessID int                    `json:"process_id,omitempty"`
	Metadata  map[string]interface{} `json:"metadata,omitempty"`
}

// FileInfo represents information about a file in the sandbox
type FileInfo struct {
	Name        string    `json:"name"`
	Path        string    `json:"path"`
	Size        int64     `json:"size"`
	Mode        uint32    `json:"mode"`
	ModTime     time.Time `json:"mod_time"`
	IsDir       bool      `json:"is_dir"`
	Owner       string    `json:"owner"`
	Group       string    `json:"group"`
	Permissions string    `json:"permissions"`
}

// ResourceLimits defines resource constraints for the sandbox
type ResourceLimits struct {
	// CPU limits
	CPULimit float64 `json:"cpu_limit"` // CPU cores (e.g., 1.5)
|
||||
CPURequest float64 `json:"cpu_request"` // CPU cores requested
|
||||
|
||||
// Memory limits
|
||||
MemoryLimit int64 `json:"memory_limit"` // Bytes
|
||||
MemoryRequest int64 `json:"memory_request"` // Bytes
|
||||
|
||||
// Storage limits
|
||||
DiskLimit int64 `json:"disk_limit"` // Bytes
|
||||
DiskRequest int64 `json:"disk_request"` // Bytes
|
||||
|
||||
// Network limits
|
||||
NetworkInLimit int64 `json:"network_in_limit"` // Bytes/sec
|
||||
NetworkOutLimit int64 `json:"network_out_limit"` // Bytes/sec
|
||||
|
||||
// Process limits
|
||||
ProcessLimit int `json:"process_limit"` // Max processes
|
||||
FileLimit int `json:"file_limit"` // Max open files
|
||||
|
||||
// Time limits
|
||||
WallTimeLimit time.Duration `json:"wall_time_limit"` // Max wall clock time
|
||||
CPUTimeLimit time.Duration `json:"cpu_time_limit"` // Max CPU time
|
||||
}
|
||||
|
||||
// SecurityPolicy defines security constraints and policies
|
||||
type SecurityPolicy struct {
|
||||
// Container security
|
||||
RunAsUser string `json:"run_as_user"`
|
||||
RunAsGroup string `json:"run_as_group"`
|
||||
ReadOnlyRoot bool `json:"read_only_root"`
|
||||
NoNewPrivileges bool `json:"no_new_privileges"`
|
||||
|
||||
// Capabilities
|
||||
AddCapabilities []string `json:"add_capabilities"`
|
||||
DropCapabilities []string `json:"drop_capabilities"`
|
||||
|
||||
// SELinux/AppArmor
|
||||
SELinuxContext string `json:"selinux_context"`
|
||||
AppArmorProfile string `json:"apparmor_profile"`
|
||||
SeccompProfile string `json:"seccomp_profile"`
|
||||
|
||||
// Network security
|
||||
AllowNetworking bool `json:"allow_networking"`
|
||||
AllowedHosts []string `json:"allowed_hosts"`
|
||||
BlockedHosts []string `json:"blocked_hosts"`
|
||||
AllowedPorts []int `json:"allowed_ports"`
|
||||
|
||||
// File system security
|
||||
ReadOnlyPaths []string `json:"read_only_paths"`
|
||||
MaskedPaths []string `json:"masked_paths"`
|
||||
TmpfsPaths []string `json:"tmpfs_paths"`
|
||||
|
||||
// Resource protection
|
||||
PreventEscalation bool `json:"prevent_escalation"`
|
||||
IsolateNetwork bool `json:"isolate_network"`
|
||||
IsolateProcess bool `json:"isolate_process"`
|
||||
|
||||
// Monitoring
|
||||
EnableAuditLog bool `json:"enable_audit_log"`
|
||||
LogSecurityEvents bool `json:"log_security_events"`
|
||||
}
|
||||
|
||||
// RepositoryConfig defines how the repository is mounted in the sandbox
|
||||
type RepositoryConfig struct {
|
||||
// Repository source
|
||||
URL string `json:"url"`
|
||||
Branch string `json:"branch"`
|
||||
CommitHash string `json:"commit_hash"`
|
||||
LocalPath string `json:"local_path"`
|
||||
|
||||
// Mount configuration
|
||||
MountPoint string `json:"mount_point"` // Path in sandbox
|
||||
ReadOnly bool `json:"read_only"`
|
||||
|
||||
// Git configuration
|
||||
GitConfig GitConfig `json:"git_config"`
|
||||
|
||||
// File filters
|
||||
IncludeFiles []string `json:"include_files"` // Glob patterns
|
||||
ExcludeFiles []string `json:"exclude_files"` // Glob patterns
|
||||
|
||||
// Access permissions
|
||||
Permissions string `json:"permissions"` // rwx format
|
||||
Owner string `json:"owner"`
|
||||
Group string `json:"group"`
|
||||
}
|
||||
|
||||
// GitConfig defines Git configuration within the sandbox
|
||||
type GitConfig struct {
|
||||
UserName string `json:"user_name"`
|
||||
UserEmail string `json:"user_email"`
|
||||
SigningKey string `json:"signing_key"`
|
||||
ConfigValues map[string]string `json:"config_values"`
|
||||
}
|
||||
|
||||
// NetworkConfig defines network settings for the sandbox
|
||||
type NetworkConfig struct {
|
||||
// Network isolation
|
||||
Isolated bool `json:"isolated"` // No network access
|
||||
Bridge string `json:"bridge"` // Network bridge
|
||||
|
||||
// DNS settings
|
||||
DNSServers []string `json:"dns_servers"`
|
||||
DNSSearch []string `json:"dns_search"`
|
||||
|
||||
// Proxy settings
|
||||
HTTPProxy string `json:"http_proxy"`
|
||||
HTTPSProxy string `json:"https_proxy"`
|
||||
NoProxy string `json:"no_proxy"`
|
||||
|
||||
// Port mappings
|
||||
PortMappings []PortMapping `json:"port_mappings"`
|
||||
|
||||
// Bandwidth limits
|
||||
IngressLimit int64 `json:"ingress_limit"` // Bytes/sec
|
||||
EgressLimit int64 `json:"egress_limit"` // Bytes/sec
|
||||
}
|
||||
|
||||
// PortMapping defines port forwarding configuration
|
||||
type PortMapping struct {
|
||||
HostPort int `json:"host_port"`
|
||||
ContainerPort int `json:"container_port"`
|
||||
Protocol string `json:"protocol"` // tcp, udp
|
||||
}
|
||||
|
||||
// ResourceUsage represents current resource consumption
|
||||
type ResourceUsage struct {
|
||||
// Timestamp of measurement
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
|
||||
// CPU usage
|
||||
CPUUsage float64 `json:"cpu_usage"` // Percentage
|
||||
CPUTime time.Duration `json:"cpu_time"` // Total CPU time
|
||||
|
||||
// Memory usage
|
||||
MemoryUsage int64 `json:"memory_usage"` // Bytes
|
||||
MemoryPercent float64 `json:"memory_percent"` // Percentage of limit
|
||||
MemoryPeak int64 `json:"memory_peak"` // Peak usage
|
||||
|
||||
// Disk usage
|
||||
DiskUsage int64 `json:"disk_usage"` // Bytes
|
||||
DiskReads int64 `json:"disk_reads"` // Read operations
|
||||
DiskWrites int64 `json:"disk_writes"` // Write operations
|
||||
|
||||
// Network usage
|
||||
NetworkIn int64 `json:"network_in"` // Bytes received
|
||||
NetworkOut int64 `json:"network_out"` // Bytes sent
|
||||
|
||||
// Process information
|
||||
ProcessCount int `json:"process_count"` // Active processes
|
||||
ThreadCount int `json:"thread_count"` // Active threads
|
||||
FileHandles int `json:"file_handles"` // Open file handles
|
||||
|
||||
// Runtime information
|
||||
Uptime time.Duration `json:"uptime"` // Sandbox uptime
|
||||
}
|
||||
|
||||
// SandboxInfo provides information about a sandbox instance
|
||||
type SandboxInfo struct {
|
||||
// Identification
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
|
||||
// Status
|
||||
Status SandboxStatus `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
StartedAt time.Time `json:"started_at"`
|
||||
|
||||
// Runtime information
|
||||
Runtime string `json:"runtime"`
|
||||
Image string `json:"image"`
|
||||
Platform string `json:"platform"`
|
||||
|
||||
// Network information
|
||||
IPAddress string `json:"ip_address"`
|
||||
MACAddress string `json:"mac_address"`
|
||||
Hostname string `json:"hostname"`
|
||||
|
||||
// Resource information
|
||||
AllocatedResources ResourceLimits `json:"allocated_resources"`
|
||||
|
||||
// Configuration
|
||||
Config SandboxConfig `json:"config"`
|
||||
|
||||
// Metadata
|
||||
Labels map[string]string `json:"labels"`
|
||||
Annotations map[string]string `json:"annotations"`
|
||||
}
|
||||
|
||||
// SandboxStatus represents the current status of a sandbox
|
||||
type SandboxStatus string
|
||||
|
||||
const (
|
||||
StatusCreating SandboxStatus = "creating"
|
||||
StatusStarting SandboxStatus = "starting"
|
||||
StatusRunning SandboxStatus = "running"
|
||||
StatusPaused SandboxStatus = "paused"
|
||||
StatusStopping SandboxStatus = "stopping"
|
||||
StatusStopped SandboxStatus = "stopped"
|
||||
StatusFailed SandboxStatus = "failed"
|
||||
StatusDestroyed SandboxStatus = "destroyed"
|
||||
)
|
||||
|
||||
// Common sandbox errors
|
||||
var (
|
||||
ErrSandboxNotFound = &SandboxError{Code: "SANDBOX_NOT_FOUND", Message: "Sandbox not found"}
|
||||
ErrSandboxAlreadyExists = &SandboxError{Code: "SANDBOX_ALREADY_EXISTS", Message: "Sandbox already exists"}
|
||||
ErrSandboxNotRunning = &SandboxError{Code: "SANDBOX_NOT_RUNNING", Message: "Sandbox is not running"}
|
||||
ErrSandboxInitFailed = &SandboxError{Code: "SANDBOX_INIT_FAILED", Message: "Sandbox initialization failed"}
|
||||
ErrCommandExecutionFailed = &SandboxError{Code: "COMMAND_EXECUTION_FAILED", Message: "Command execution failed"}
|
||||
ErrResourceLimitExceeded = &SandboxError{Code: "RESOURCE_LIMIT_EXCEEDED", Message: "Resource limit exceeded"}
|
||||
ErrSecurityViolation = &SandboxError{Code: "SECURITY_VIOLATION", Message: "Security policy violation"}
|
||||
ErrFileOperationFailed = &SandboxError{Code: "FILE_OPERATION_FAILED", Message: "File operation failed"}
|
||||
ErrNetworkAccessDenied = &SandboxError{Code: "NETWORK_ACCESS_DENIED", Message: "Network access denied"}
|
||||
ErrTimeoutExceeded = &SandboxError{Code: "TIMEOUT_EXCEEDED", Message: "Execution timeout exceeded"}
|
||||
)
|
||||
|
||||
// SandboxError represents sandbox-specific errors
|
||||
type SandboxError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Details string `json:"details,omitempty"`
|
||||
Retryable bool `json:"retryable"`
|
||||
Cause error `json:"-"`
|
||||
}
|
||||
|
||||
func (e *SandboxError) Error() string {
|
||||
if e.Details != "" {
|
||||
return e.Message + ": " + e.Details
|
||||
}
|
||||
return e.Message
|
||||
}
|
||||
|
||||
func (e *SandboxError) Unwrap() error {
|
||||
return e.Cause
|
||||
}
|
||||
|
||||
func (e *SandboxError) IsRetryable() bool {
|
||||
return e.Retryable
|
||||
}
|
||||
|
||||
// NewSandboxError creates a new sandbox error with details
|
||||
func NewSandboxError(base *SandboxError, details string) *SandboxError {
|
||||
return &SandboxError{
|
||||
Code: base.Code,
|
||||
Message: base.Message,
|
||||
Details: details,
|
||||
Retryable: base.Retryable,
|
||||
}
|
||||
}
|
||||
|
||||
// NewSandboxErrorWithCause creates a new sandbox error with an underlying cause
|
||||
func NewSandboxErrorWithCause(base *SandboxError, details string, cause error) *SandboxError {
|
||||
return &SandboxError{
|
||||
Code: base.Code,
|
||||
Message: base.Message,
|
||||
Details: details,
|
||||
Retryable: base.Retryable,
|
||||
Cause: cause,
|
||||
}
|
||||
}
|
||||
639
pkg/execution/sandbox_test.go
Normal file
@@ -0,0 +1,639 @@
package execution

import (
	"context"
	"errors"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestSandboxError(t *testing.T) {
	tests := []struct {
		name      string
		err       *SandboxError
		expected  string
		retryable bool
	}{
		{
			name:      "simple error",
			err:       ErrSandboxNotFound,
			expected:  "Sandbox not found",
			retryable: false,
		},
		{
			name:      "error with details",
			err:       NewSandboxError(ErrResourceLimitExceeded, "Memory limit of 1GB exceeded"),
			expected:  "Resource limit exceeded: Memory limit of 1GB exceeded",
			retryable: false,
		},
		{
			name: "retryable error",
			err: &SandboxError{
				Code:      "TEMPORARY_FAILURE",
				Message:   "Temporary network failure",
				Retryable: true,
			},
			expected:  "Temporary network failure",
			retryable: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.expected, tt.err.Error())
			assert.Equal(t, tt.retryable, tt.err.IsRetryable())
		})
	}
}

func TestSandboxErrorUnwrap(t *testing.T) {
	baseErr := errors.New("underlying error")
	sandboxErr := NewSandboxErrorWithCause(ErrCommandExecutionFailed, "command failed", baseErr)

	unwrapped := sandboxErr.Unwrap()
	assert.Equal(t, baseErr, unwrapped)
}

func TestSandboxConfig(t *testing.T) {
	config := &SandboxConfig{
		Type:         "docker",
		Image:        "alpine:latest",
		Runtime:      "docker",
		Architecture: "amd64",
		Resources: ResourceLimits{
			MemoryLimit:   1024 * 1024 * 1024, // 1GB
			MemoryRequest: 512 * 1024 * 1024,  // 512MB
			CPULimit:      2.0,
			CPURequest:    1.0,
			DiskLimit:     10 * 1024 * 1024 * 1024, // 10GB
			ProcessLimit:  100,
			FileLimit:     1024,
			WallTimeLimit: 30 * time.Minute,
			CPUTimeLimit:  10 * time.Minute,
		},
		Security: SecurityPolicy{
			RunAsUser:         "1000",
			RunAsGroup:        "1000",
			ReadOnlyRoot:      true,
			NoNewPrivileges:   true,
			AddCapabilities:   []string{"NET_BIND_SERVICE"},
			DropCapabilities:  []string{"ALL"},
			SELinuxContext:    "unconfined_u:unconfined_r:container_t:s0",
			AppArmorProfile:   "docker-default",
			SeccompProfile:    "runtime/default",
			AllowNetworking:   false,
			AllowedHosts:      []string{"api.example.com"},
			BlockedHosts:      []string{"malicious.com"},
			AllowedPorts:      []int{80, 443},
			ReadOnlyPaths:     []string{"/etc", "/usr"},
			MaskedPaths:       []string{"/proc/kcore", "/proc/keys"},
			TmpfsPaths:        []string{"/tmp", "/var/tmp"},
			PreventEscalation: true,
			IsolateNetwork:    true,
			IsolateProcess:    true,
			EnableAuditLog:    true,
			LogSecurityEvents: true,
		},
		Repository: RepositoryConfig{
			URL:        "https://github.com/example/repo.git",
			Branch:     "main",
			LocalPath:  "/home/user/repo",
			MountPoint: "/workspace",
			ReadOnly:   false,
			GitConfig: GitConfig{
				UserName:  "Test User",
				UserEmail: "test@example.com",
				ConfigValues: map[string]string{
					"core.autocrlf": "input",
				},
			},
			IncludeFiles: []string{"*.go", "*.md"},
			ExcludeFiles: []string{"*.tmp", "*.log"},
			Permissions:  "755",
			Owner:        "user",
			Group:        "user",
		},
		Network: NetworkConfig{
			Isolated:   false,
			Bridge:     "docker0",
			DNSServers: []string{"8.8.8.8", "1.1.1.1"},
			DNSSearch:  []string{"example.com"},
			HTTPProxy:  "http://proxy:8080",
			HTTPSProxy: "http://proxy:8080",
			NoProxy:    "localhost,127.0.0.1",
			PortMappings: []PortMapping{
				{HostPort: 8080, ContainerPort: 80, Protocol: "tcp"},
			},
			IngressLimit: 1024 * 1024, // 1MB/s
			EgressLimit:  2048 * 1024, // 2MB/s
		},
		Environment: map[string]string{
			"NODE_ENV": "test",
			"DEBUG":    "true",
		},
		WorkingDir:   "/workspace",
		Tools:        []string{"git", "node", "npm"},
		MCPServers:   []string{"file-server", "web-server"},
		Timeout:      5 * time.Minute,
		CleanupDelay: 30 * time.Second,
		Labels: map[string]string{
			"app":     "chorus",
			"version": "1.0.0",
		},
		Annotations: map[string]string{
			"description": "Test sandbox configuration",
		},
	}

	// Validate required fields
	assert.NotEmpty(t, config.Type)
	assert.NotEmpty(t, config.Image)
	assert.NotEmpty(t, config.Architecture)

	// Validate resource limits
	assert.Greater(t, config.Resources.MemoryLimit, int64(0))
	assert.Greater(t, config.Resources.CPULimit, 0.0)

	// Validate security policy
	assert.NotEmpty(t, config.Security.RunAsUser)
	assert.True(t, config.Security.NoNewPrivileges)
	assert.NotEmpty(t, config.Security.DropCapabilities)

	// Validate repository config
	assert.NotEmpty(t, config.Repository.MountPoint)
	assert.NotEmpty(t, config.Repository.GitConfig.UserName)

	// Validate network config
	assert.NotEmpty(t, config.Network.DNSServers)
	assert.Len(t, config.Network.PortMappings, 1)

	// Validate timeouts
	assert.Greater(t, config.Timeout, time.Duration(0))
	assert.Greater(t, config.CleanupDelay, time.Duration(0))
}

func TestCommand(t *testing.T) {
	cmd := &Command{
		Executable:    "python3",
		Args:          []string{"-c", "print('hello world')"},
		WorkingDir:    "/workspace",
		Environment:   map[string]string{"PYTHONPATH": "/custom/path"},
		StdinContent:  "input data",
		Timeout:       30 * time.Second,
		User:          "1000",
		AllowNetwork:  true,
		AllowWrite:    true,
		RestrictPaths: []string{"/etc", "/usr"},
	}

	// Validate command structure
	assert.Equal(t, "python3", cmd.Executable)
	assert.Len(t, cmd.Args, 2)
	assert.Equal(t, "/workspace", cmd.WorkingDir)
	assert.Equal(t, "/custom/path", cmd.Environment["PYTHONPATH"])
	assert.Equal(t, "input data", cmd.StdinContent)
	assert.Equal(t, 30*time.Second, cmd.Timeout)
	assert.True(t, cmd.AllowNetwork)
	assert.True(t, cmd.AllowWrite)
	assert.Len(t, cmd.RestrictPaths, 2)
}

func TestCommandResult(t *testing.T) {
	startTime := time.Now()
	endTime := startTime.Add(2 * time.Second)

	result := &CommandResult{
		ExitCode:  0,
		Success:   true,
		Stdout:    "Standard output",
		Stderr:    "Standard error",
		Combined:  "Combined output",
		StartTime: startTime,
		EndTime:   endTime,
		Duration:  endTime.Sub(startTime),
		ResourceUsage: ResourceUsage{
			CPUUsage:    25.5,
			MemoryUsage: 1024 * 1024, // 1MB
		},
		ProcessID: 12345,
		Metadata: map[string]interface{}{
			"container_id": "abc123",
			"image":        "alpine:latest",
		},
	}

	// Validate result structure
	assert.Equal(t, 0, result.ExitCode)
	assert.True(t, result.Success)
	assert.Equal(t, "Standard output", result.Stdout)
	assert.Equal(t, "Standard error", result.Stderr)
	assert.Equal(t, 2*time.Second, result.Duration)
	assert.Equal(t, 25.5, result.ResourceUsage.CPUUsage)
	assert.Equal(t, int64(1024*1024), result.ResourceUsage.MemoryUsage)
	assert.Equal(t, 12345, result.ProcessID)
	assert.Equal(t, "abc123", result.Metadata["container_id"])
}

func TestFileInfo(t *testing.T) {
	modTime := time.Now()

	fileInfo := FileInfo{
		Name:        "test.txt",
		Path:        "/workspace/test.txt",
		Size:        1024,
		Mode:        0644,
		ModTime:     modTime,
		IsDir:       false,
		Owner:       "user",
		Group:       "user",
		Permissions: "-rw-r--r--",
	}

	// Validate file info structure
	assert.Equal(t, "test.txt", fileInfo.Name)
	assert.Equal(t, "/workspace/test.txt", fileInfo.Path)
	assert.Equal(t, int64(1024), fileInfo.Size)
	assert.Equal(t, uint32(0644), fileInfo.Mode)
	assert.Equal(t, modTime, fileInfo.ModTime)
	assert.False(t, fileInfo.IsDir)
	assert.Equal(t, "user", fileInfo.Owner)
	assert.Equal(t, "user", fileInfo.Group)
	assert.Equal(t, "-rw-r--r--", fileInfo.Permissions)
}

func TestResourceLimits(t *testing.T) {
	limits := ResourceLimits{
		CPULimit:        2.5,
		CPURequest:      1.0,
		MemoryLimit:     2 * 1024 * 1024 * 1024,  // 2GB
		MemoryRequest:   1 * 1024 * 1024 * 1024,  // 1GB
		DiskLimit:       50 * 1024 * 1024 * 1024, // 50GB
		DiskRequest:     10 * 1024 * 1024 * 1024, // 10GB
		NetworkInLimit:  10 * 1024 * 1024,        // 10MB/s
		NetworkOutLimit: 5 * 1024 * 1024,         // 5MB/s
		ProcessLimit:    200,
		FileLimit:       2048,
		WallTimeLimit:   1 * time.Hour,
		CPUTimeLimit:    30 * time.Minute,
	}

	// Validate resource limits
	assert.Equal(t, 2.5, limits.CPULimit)
	assert.Equal(t, 1.0, limits.CPURequest)
	assert.Equal(t, int64(2*1024*1024*1024), limits.MemoryLimit)
	assert.Equal(t, int64(1*1024*1024*1024), limits.MemoryRequest)
	assert.Equal(t, int64(50*1024*1024*1024), limits.DiskLimit)
	assert.Equal(t, 200, limits.ProcessLimit)
	assert.Equal(t, 2048, limits.FileLimit)
	assert.Equal(t, 1*time.Hour, limits.WallTimeLimit)
	assert.Equal(t, 30*time.Minute, limits.CPUTimeLimit)
}

func TestResourceUsage(t *testing.T) {
	timestamp := time.Now()

	usage := ResourceUsage{
		Timestamp:     timestamp,
		CPUUsage:      75.5,
		CPUTime:       15 * time.Minute,
		MemoryUsage:   512 * 1024 * 1024, // 512MB
		MemoryPercent: 25.0,
		MemoryPeak:    768 * 1024 * 1024,      // 768MB
		DiskUsage:     1 * 1024 * 1024 * 1024, // 1GB
		DiskReads:     1000,
		DiskWrites:    500,
		NetworkIn:     10 * 1024 * 1024, // 10MB
		NetworkOut:    5 * 1024 * 1024,  // 5MB
		ProcessCount:  25,
		ThreadCount:   100,
		FileHandles:   50,
		Uptime:        2 * time.Hour,
	}

	// Validate resource usage
	assert.Equal(t, timestamp, usage.Timestamp)
	assert.Equal(t, 75.5, usage.CPUUsage)
	assert.Equal(t, 15*time.Minute, usage.CPUTime)
	assert.Equal(t, int64(512*1024*1024), usage.MemoryUsage)
	assert.Equal(t, 25.0, usage.MemoryPercent)
	assert.Equal(t, int64(768*1024*1024), usage.MemoryPeak)
	assert.Equal(t, 25, usage.ProcessCount)
	assert.Equal(t, 100, usage.ThreadCount)
	assert.Equal(t, 50, usage.FileHandles)
	assert.Equal(t, 2*time.Hour, usage.Uptime)
}

func TestSandboxInfo(t *testing.T) {
	createdAt := time.Now()
	startedAt := createdAt.Add(5 * time.Second)

	info := SandboxInfo{
		ID:         "sandbox-123",
		Name:       "test-sandbox",
		Type:       "docker",
		Status:     StatusRunning,
		CreatedAt:  createdAt,
		StartedAt:  startedAt,
		Runtime:    "docker",
		Image:      "alpine:latest",
		Platform:   "linux/amd64",
		IPAddress:  "172.17.0.2",
		MACAddress: "02:42:ac:11:00:02",
		Hostname:   "sandbox-123",
		AllocatedResources: ResourceLimits{
			MemoryLimit: 1024 * 1024 * 1024, // 1GB
			CPULimit:    2.0,
		},
		Labels: map[string]string{
			"app": "chorus",
		},
		Annotations: map[string]string{
			"creator": "test",
		},
	}

	// Validate sandbox info
	assert.Equal(t, "sandbox-123", info.ID)
	assert.Equal(t, "test-sandbox", info.Name)
	assert.Equal(t, "docker", info.Type)
	assert.Equal(t, StatusRunning, info.Status)
	assert.Equal(t, createdAt, info.CreatedAt)
	assert.Equal(t, startedAt, info.StartedAt)
	assert.Equal(t, "docker", info.Runtime)
	assert.Equal(t, "alpine:latest", info.Image)
	assert.Equal(t, "172.17.0.2", info.IPAddress)
	assert.Equal(t, "chorus", info.Labels["app"])
	assert.Equal(t, "test", info.Annotations["creator"])
}

func TestSandboxStatus(t *testing.T) {
	statuses := []SandboxStatus{
		StatusCreating,
		StatusStarting,
		StatusRunning,
		StatusPaused,
		StatusStopping,
		StatusStopped,
		StatusFailed,
		StatusDestroyed,
	}

	expectedStatuses := []string{
		"creating",
		"starting",
		"running",
		"paused",
		"stopping",
		"stopped",
		"failed",
		"destroyed",
	}

	for i, status := range statuses {
		assert.Equal(t, expectedStatuses[i], string(status))
	}
}

func TestPortMapping(t *testing.T) {
	mapping := PortMapping{
		HostPort:      8080,
		ContainerPort: 80,
		Protocol:      "tcp",
	}

	assert.Equal(t, 8080, mapping.HostPort)
	assert.Equal(t, 80, mapping.ContainerPort)
	assert.Equal(t, "tcp", mapping.Protocol)
}

func TestGitConfig(t *testing.T) {
	config := GitConfig{
		UserName:   "Test User",
		UserEmail:  "test@example.com",
		SigningKey: "ABC123",
		ConfigValues: map[string]string{
			"core.autocrlf":      "input",
			"pull.rebase":        "true",
			"init.defaultBranch": "main",
		},
	}

	assert.Equal(t, "Test User", config.UserName)
	assert.Equal(t, "test@example.com", config.UserEmail)
	assert.Equal(t, "ABC123", config.SigningKey)
	assert.Equal(t, "input", config.ConfigValues["core.autocrlf"])
	assert.Equal(t, "true", config.ConfigValues["pull.rebase"])
	assert.Equal(t, "main", config.ConfigValues["init.defaultBranch"])
}

// MockSandbox implements ExecutionSandbox for testing
type MockSandbox struct {
	id            string
	status        SandboxStatus
	workingDir    string
	environment   map[string]string
	shouldFail    bool
	commandResult *CommandResult
	files         []FileInfo
	resourceUsage *ResourceUsage
}

func NewMockSandbox() *MockSandbox {
	return &MockSandbox{
		id:          "mock-sandbox-123",
		status:      StatusStopped,
		workingDir:  "/workspace",
		environment: make(map[string]string),
		files:       []FileInfo{},
		commandResult: &CommandResult{
			Success:  true,
			ExitCode: 0,
			Stdout:   "mock output",
		},
		resourceUsage: &ResourceUsage{
			CPUUsage:    10.0,
			MemoryUsage: 100 * 1024 * 1024, // 100MB
		},
	}
}

func (m *MockSandbox) Initialize(ctx context.Context, config *SandboxConfig) error {
	if m.shouldFail {
		return NewSandboxError(ErrSandboxInitFailed, "mock initialization failed")
	}
	m.status = StatusRunning
	return nil
}

func (m *MockSandbox) ExecuteCommand(ctx context.Context, cmd *Command) (*CommandResult, error) {
	if m.shouldFail {
		return nil, NewSandboxError(ErrCommandExecutionFailed, "mock command execution failed")
	}
	return m.commandResult, nil
}

func (m *MockSandbox) CopyFiles(ctx context.Context, source, dest string) error {
	if m.shouldFail {
		return NewSandboxError(ErrFileOperationFailed, "mock file copy failed")
	}
	return nil
}

func (m *MockSandbox) WriteFile(ctx context.Context, path string, content []byte, mode uint32) error {
	if m.shouldFail {
		return NewSandboxError(ErrFileOperationFailed, "mock file write failed")
	}
	return nil
}

func (m *MockSandbox) ReadFile(ctx context.Context, path string) ([]byte, error) {
	if m.shouldFail {
		return nil, NewSandboxError(ErrFileOperationFailed, "mock file read failed")
	}
	return []byte("mock file content"), nil
}

func (m *MockSandbox) ListFiles(ctx context.Context, path string) ([]FileInfo, error) {
	if m.shouldFail {
		return nil, NewSandboxError(ErrFileOperationFailed, "mock file list failed")
	}
	return m.files, nil
}

func (m *MockSandbox) GetWorkingDirectory() string {
	return m.workingDir
}

func (m *MockSandbox) SetWorkingDirectory(path string) error {
	if m.shouldFail {
		return NewSandboxError(ErrFileOperationFailed, "mock set working directory failed")
	}
	m.workingDir = path
	return nil
}

func (m *MockSandbox) GetEnvironment() map[string]string {
	env := make(map[string]string)
	for k, v := range m.environment {
		env[k] = v
	}
	return env
}

func (m *MockSandbox) SetEnvironment(env map[string]string) error {
	if m.shouldFail {
		return NewSandboxError(ErrFileOperationFailed, "mock set environment failed")
	}
	for k, v := range env {
		m.environment[k] = v
	}
	return nil
}

func (m *MockSandbox) GetResourceUsage(ctx context.Context) (*ResourceUsage, error) {
	if m.shouldFail {
		return nil, NewSandboxError(ErrSandboxInitFailed, "mock resource usage failed")
	}
	return m.resourceUsage, nil
}

func (m *MockSandbox) Cleanup() error {
	if m.shouldFail {
		return NewSandboxError(ErrSandboxInitFailed, "mock cleanup failed")
	}
	m.status = StatusDestroyed
	return nil
}

func (m *MockSandbox) GetInfo() SandboxInfo {
	return SandboxInfo{
		ID:     m.id,
		Status: m.status,
		Type:   "mock",
	}
}

func TestMockSandbox(t *testing.T) {
	sandbox := NewMockSandbox()
	ctx := context.Background()

	// Test initialization
	err := sandbox.Initialize(ctx, &SandboxConfig{})
	require.NoError(t, err)
	assert.Equal(t, StatusRunning, sandbox.status)

	// Test command execution
	result, err := sandbox.ExecuteCommand(ctx, &Command{})
	require.NoError(t, err)
	assert.True(t, result.Success)
	assert.Equal(t, "mock output", result.Stdout)

	// Test file operations
	err = sandbox.WriteFile(ctx, "/test.txt", []byte("test"), 0644)
	require.NoError(t, err)

	content, err := sandbox.ReadFile(ctx, "/test.txt")
	require.NoError(t, err)
	assert.Equal(t, []byte("mock file content"), content)

	files, err := sandbox.ListFiles(ctx, "/")
	require.NoError(t, err)
	assert.Empty(t, files) // Mock returns empty list by default

	// Test environment
	env := sandbox.GetEnvironment()
	assert.Empty(t, env)

	err = sandbox.SetEnvironment(map[string]string{"TEST": "value"})
	require.NoError(t, err)

	env = sandbox.GetEnvironment()
	assert.Equal(t, "value", env["TEST"])

	// Test resource usage
	usage, err := sandbox.GetResourceUsage(ctx)
	require.NoError(t, err)
	assert.Equal(t, 10.0, usage.CPUUsage)

	// Test cleanup
	err = sandbox.Cleanup()
	require.NoError(t, err)
	assert.Equal(t, StatusDestroyed, sandbox.status)
}

func TestMockSandboxFailure(t *testing.T) {
	sandbox := NewMockSandbox()
	sandbox.shouldFail = true
	ctx := context.Background()

	// All operations should fail when shouldFail is true
	err := sandbox.Initialize(ctx, &SandboxConfig{})
	assert.Error(t, err)

	_, err = sandbox.ExecuteCommand(ctx, &Command{})
	assert.Error(t, err)

	err = sandbox.WriteFile(ctx, "/test.txt", []byte("test"), 0644)
	assert.Error(t, err)

	_, err = sandbox.ReadFile(ctx, "/test.txt")
	assert.Error(t, err)

	_, err = sandbox.ListFiles(ctx, "/")
	assert.Error(t, err)

	err = sandbox.SetWorkingDirectory("/tmp")
	assert.Error(t, err)

	err = sandbox.SetEnvironment(map[string]string{"TEST": "value"})
	assert.Error(t, err)

	_, err = sandbox.GetResourceUsage(ctx)
	assert.Error(t, err)

	err = sandbox.Cleanup()
	assert.Error(t, err)
}
@@ -179,9 +179,11 @@ func (ehc *EnhancedHealthChecks) registerHealthChecks() {
		ehc.manager.RegisterCheck(ehc.createEnhancedPubSubCheck())
	}

	if ehc.config.EnableDHTProbes {
		ehc.manager.RegisterCheck(ehc.createEnhancedDHTCheck())
	}
	// Temporarily disable DHT health check to prevent shutdown issues
	// TODO: Fix DHT configuration and re-enable this check
	// if ehc.config.EnableDHTProbes {
	//	ehc.manager.RegisterCheck(ehc.createEnhancedDHTCheck())
	// }

	if ehc.config.EnableElectionProbes {
		ehc.manager.RegisterCheck(ehc.createElectionHealthCheck())

@@ -290,7 +292,7 @@ func (ehc *EnhancedHealthChecks) createElectionHealthCheck() *HealthCheck {
	return &HealthCheck{
		Name:        "election-health",
		Description: "Election system health and leadership stability check",
		Enabled:     true,
		Enabled:     false, // Temporarily disabled to prevent shutdown loops
		Critical:    false,
		Interval:    ehc.config.ElectionProbeInterval,
		Timeout:     ehc.config.ElectionProbeTimeout,
@@ -2,27 +2,26 @@ package metrics

import (
	"context"
	"fmt"
	"log"
	"net/http"
	"sync"
	"time"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"github.com/prometheus/client_golang/prometheus/promhttp"
)

// CHORUSMetrics provides comprehensive Prometheus metrics for the CHORUS system
type CHORUSMetrics struct {
	registry   *prometheus.Registry
	httpServer *http.Server

	registry   *prometheus.Registry
	httpServer *http.Server

	// System metrics
	systemInfo *prometheus.GaugeVec
	uptime     prometheus.Gauge
	buildInfo  *prometheus.GaugeVec

	systemInfo *prometheus.GaugeVec
	uptime     prometheus.Gauge
	buildInfo  *prometheus.GaugeVec

	// P2P metrics
	p2pConnectedPeers prometheus.Gauge
	p2pMessagesSent   *prometheus.CounterVec
@@ -30,95 +29,98 @@ type CHORUSMetrics struct {
	p2pMessageLatency     *prometheus.HistogramVec
	p2pConnectionDuration *prometheus.HistogramVec
	p2pPeerScore          *prometheus.GaugeVec

	// DHT metrics
	dhtPutOperations     *prometheus.CounterVec
	dhtGetOperations     *prometheus.CounterVec
	dhtOperationLatency  *prometheus.HistogramVec
	dhtProviderRecords   prometheus.Gauge
	dhtReplicationFactor *prometheus.GaugeVec
	dhtContentKeys       prometheus.Gauge
	dhtCacheHits         *prometheus.CounterVec
	dhtCacheMisses       *prometheus.CounterVec

	dhtPutOperations     *prometheus.CounterVec
	dhtGetOperations     *prometheus.CounterVec
	dhtOperationLatency  *prometheus.HistogramVec
	dhtProviderRecords   prometheus.Gauge
	dhtReplicationFactor *prometheus.GaugeVec
	dhtContentKeys       prometheus.Gauge
	dhtCacheHits         *prometheus.CounterVec
	dhtCacheMisses       *prometheus.CounterVec

	// PubSub metrics
	pubsubTopics         prometheus.Gauge
	pubsubSubscribers    *prometheus.GaugeVec
	pubsubMessages       *prometheus.CounterVec
	pubsubMessageLatency *prometheus.HistogramVec
	pubsubMessageSize    *prometheus.HistogramVec

	pubsubTopics         prometheus.Gauge
	pubsubSubscribers    *prometheus.GaugeVec
	pubsubMessages       *prometheus.CounterVec
	pubsubMessageLatency *prometheus.HistogramVec
	pubsubMessageSize    *prometheus.HistogramVec

	// Election metrics
	electionTerm       prometheus.Gauge
	electionState      *prometheus.GaugeVec
	heartbeatsSent     prometheus.Counter
	heartbeatsReceived prometheus.Counter
	leadershipChanges  prometheus.Counter
	leaderUptime       prometheus.Gauge
	electionLatency    prometheus.Histogram

	electionTerm       prometheus.Gauge
	electionState      *prometheus.GaugeVec
	heartbeatsSent     prometheus.Counter
	heartbeatsReceived prometheus.Counter
	leadershipChanges  prometheus.Counter
	leaderUptime       prometheus.Gauge
	electionLatency    prometheus.Histogram

	// Health metrics
	healthChecksPassed   *prometheus.CounterVec
	healthChecksFailed   *prometheus.CounterVec
	healthCheckDuration  *prometheus.HistogramVec
	systemHealthScore    prometheus.Gauge
	componentHealthScore *prometheus.GaugeVec

	healthChecksPassed   *prometheus.CounterVec
	healthChecksFailed   *prometheus.CounterVec
	healthCheckDuration  *prometheus.HistogramVec
	systemHealthScore    prometheus.Gauge
	componentHealthScore *prometheus.GaugeVec

	// Task metrics
	tasksActive       prometheus.Gauge
	tasksQueued       prometheus.Gauge
	tasksCompleted    *prometheus.CounterVec
	taskDuration      *prometheus.HistogramVec
	taskQueueWaitTime prometheus.Histogram

	tasksActive       prometheus.Gauge
	tasksQueued       prometheus.Gauge
	tasksCompleted    *prometheus.CounterVec
	taskDuration      *prometheus.HistogramVec
	taskQueueWaitTime prometheus.Histogram

	// SLURP metrics (context generation)
	slurpGenerated        *prometheus.CounterVec
	slurpGenerationTime   prometheus.Histogram
	slurpQueueLength      prometheus.Gauge
	slurpActiveJobs       prometheus.Gauge
	slurpLeadershipEvents prometheus.Counter

	// SHHH sentinel metrics
	shhhFindings *prometheus.CounterVec

	// UCXI metrics (protocol resolution)
	ucxiRequests          *prometheus.CounterVec
	ucxiResolutionLatency prometheus.Histogram
	ucxiCacheHits         prometheus.Counter
	ucxiCacheMisses       prometheus.Counter
	ucxiContentSize       prometheus.Histogram

	// Resource metrics
	cpuUsage        prometheus.Gauge
	memoryUsage     prometheus.Gauge
	diskUsage       *prometheus.GaugeVec
	networkBytesIn  prometheus.Counter
	networkBytesOut prometheus.Counter
	goroutines      prometheus.Gauge

	cpuUsage        prometheus.Gauge
	memoryUsage     prometheus.Gauge
	diskUsage       *prometheus.GaugeVec
	networkBytesIn  prometheus.Counter
	networkBytesOut prometheus.Counter
	goroutines      prometheus.Gauge

	// Error metrics
	errors *prometheus.CounterVec
	panics prometheus.Counter

	startTime time.Time
	mu        sync.RWMutex
	errors    *prometheus.CounterVec
	panics    prometheus.Counter

	startTime time.Time
	mu        sync.RWMutex
}

// MetricsConfig configures the metrics system
type MetricsConfig struct {
	// HTTP server config
	ListenAddr  string
	MetricsPath string

	ListenAddr  string
	MetricsPath string

	// Histogram buckets
	LatencyBuckets []float64
	SizeBuckets    []float64

	// Labels
	NodeID      string
	Version     string
	Environment string
	Cluster     string

	NodeID      string
	Version     string
	Environment string
	Cluster     string

	// Collection intervals
	SystemMetricsInterval time.Duration
	SystemMetricsInterval   time.Duration
	ResourceMetricsInterval time.Duration
}

@@ -143,20 +145,20 @@ func NewCHORUSMetrics(config *MetricsConfig) *CHORUSMetrics {
	if config == nil {
		config = DefaultMetricsConfig()
	}

	registry := prometheus.NewRegistry()

	metrics := &CHORUSMetrics{
		registry:  registry,
		startTime: time.Now(),
	}

	// Initialize all metrics
	metrics.initializeMetrics(config)

	// Register with custom registry
	metrics.registerMetrics()

	return metrics
}

@@ -170,14 +172,14 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
		},
		[]string{"node_id", "version", "go_version", "cluster", "environment"},
	)

	m.uptime = promauto.NewGauge(
		prometheus.GaugeOpts{
			Name: "chorus_uptime_seconds",
			Help: "System uptime in seconds",
		},
	)

	// P2P metrics
	m.p2pConnectedPeers = promauto.NewGauge(
		prometheus.GaugeOpts{
@@ -185,7 +187,7 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
			Help: "Number of connected P2P peers",
		},
	)

	m.p2pMessagesSent = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Name: "chorus_p2p_messages_sent_total",
@@ -193,7 +195,7 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
		},
		[]string{"message_type", "peer_id"},
	)

	m.p2pMessagesReceived = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Name: "chorus_p2p_messages_received_total",
@@ -201,7 +203,7 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
		},
		[]string{"message_type", "peer_id"},
	)

	m.p2pMessageLatency = promauto.NewHistogramVec(
		prometheus.HistogramOpts{
			Name: "chorus_p2p_message_latency_seconds",
@@ -210,7 +212,7 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
		},
		[]string{"message_type"},
	)

	// DHT metrics
	m.dhtPutOperations = promauto.NewCounterVec(
		prometheus.CounterOpts{
@@ -219,7 +221,7 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
		},
		[]string{"status"},
	)

	m.dhtGetOperations = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Name: "chorus_dht_get_operations_total",
@@ -227,7 +229,7 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
		},
		[]string{"status"},
	)

	m.dhtOperationLatency = promauto.NewHistogramVec(
		prometheus.HistogramOpts{
			Name: "chorus_dht_operation_latency_seconds",
@@ -236,21 +238,21 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
		},
		[]string{"operation", "status"},
	)

	m.dhtProviderRecords = promauto.NewGauge(
		prometheus.GaugeOpts{
			Name: "chorus_dht_provider_records",
			Help: "Number of DHT provider records",
		},
	)

	m.dhtContentKeys = promauto.NewGauge(
		prometheus.GaugeOpts{
			Name: "chorus_dht_content_keys",
			Help: "Number of DHT content keys",
		},
	)

	m.dhtReplicationFactor = promauto.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: "chorus_dht_replication_factor",
@@ -258,7 +260,7 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
		},
		[]string{"key_hash"},
	)

	// PubSub metrics
	m.pubsubTopics = promauto.NewGauge(
		prometheus.GaugeOpts{
@@ -266,7 +268,7 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
			Help: "Number of active PubSub topics",
		},
	)

	m.pubsubMessages = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Name: "chorus_pubsub_messages_total",
@@ -274,7 +276,7 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
		},
		[]string{"topic", "direction", "message_type"},
	)

	m.pubsubMessageLatency = promauto.NewHistogramVec(
		prometheus.HistogramOpts{
			Name: "chorus_pubsub_message_latency_seconds",
@@ -283,7 +285,7 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
		},
		[]string{"topic"},
	)

	// Election metrics
	m.electionTerm = promauto.NewGauge(
		prometheus.GaugeOpts{
@@ -291,7 +293,7 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
			Help: "Current election term",
		},
	)

	m.electionState = promauto.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: "chorus_election_state",
@@ -299,28 +301,28 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
		},
		[]string{"state"},
	)

	m.heartbeatsSent = promauto.NewCounter(
		prometheus.CounterOpts{
			Name: "chorus_heartbeats_sent_total",
			Help: "Total number of heartbeats sent",
		},
	)

	m.heartbeatsReceived = promauto.NewCounter(
		prometheus.CounterOpts{
			Name: "chorus_heartbeats_received_total",
			Help: "Total number of heartbeats received",
		},
	)

	m.leadershipChanges = promauto.NewCounter(
		prometheus.CounterOpts{
			Name: "chorus_leadership_changes_total",
			Help: "Total number of leadership changes",
		},
	)

	// Health metrics
	m.healthChecksPassed = promauto.NewCounterVec(
		prometheus.CounterOpts{
@@ -329,7 +331,7 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
		},
		[]string{"check_name"},
	)

	m.healthChecksFailed = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Name: "chorus_health_checks_failed_total",
@@ -337,14 +339,14 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
		},
		[]string{"check_name", "reason"},
	)

	m.systemHealthScore = promauto.NewGauge(
		prometheus.GaugeOpts{
			Name: "chorus_system_health_score",
			Help: "Overall system health score (0-1)",
		},
	)

	m.componentHealthScore = promauto.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: "chorus_component_health_score",
@@ -352,7 +354,7 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
		},
		[]string{"component"},
	)

	// Task metrics
	m.tasksActive = promauto.NewGauge(
		prometheus.GaugeOpts{
@@ -360,14 +362,14 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
			Help: "Number of active tasks",
		},
	)

	m.tasksQueued = promauto.NewGauge(
		prometheus.GaugeOpts{
			Name: "chorus_tasks_queued",
			Help: "Number of queued tasks",
		},
	)

	m.tasksCompleted = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Name: "chorus_tasks_completed_total",
@@ -375,7 +377,7 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
		},
		[]string{"status", "task_type"},
	)

	m.taskDuration = promauto.NewHistogramVec(
		prometheus.HistogramOpts{
			Name: "chorus_task_duration_seconds",
@@ -384,7 +386,7 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
		},
		[]string{"task_type", "status"},
	)

	// SLURP metrics
	m.slurpGenerated = promauto.NewCounterVec(
		prometheus.CounterOpts{
@@ -393,7 +395,7 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
		},
		[]string{"role", "status"},
	)

	m.slurpGenerationTime = promauto.NewHistogram(
		prometheus.HistogramOpts{
			Name: "chorus_slurp_generation_time_seconds",
@@ -401,14 +403,23 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
			Buckets: []float64{0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0, 120.0},
		},
	)

	m.slurpQueueLength = promauto.NewGauge(
		prometheus.GaugeOpts{
			Name: "chorus_slurp_queue_length",
			Help: "Length of SLURP generation queue",
		},
	)

	// SHHH metrics
	m.shhhFindings = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Name: "chorus_shhh_findings_total",
			Help: "Total number of SHHH redaction findings",
		},
		[]string{"rule", "severity"},
	)

	// UCXI metrics
	m.ucxiRequests = promauto.NewCounterVec(
		prometheus.CounterOpts{
@@ -417,7 +428,7 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
		},
		[]string{"method", "status"},
	)

	m.ucxiResolutionLatency = promauto.NewHistogram(
		prometheus.HistogramOpts{
			Name:    "chorus_ucxi_resolution_latency_seconds",
@@ -425,7 +436,7 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
			Buckets: config.LatencyBuckets,
		},
	)

	// Resource metrics
	m.cpuUsage = promauto.NewGauge(
		prometheus.GaugeOpts{
@@ -433,14 +444,14 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
			Help: "CPU usage ratio (0-1)",
		},
	)

	m.memoryUsage = promauto.NewGauge(
		prometheus.GaugeOpts{
			Name: "chorus_memory_usage_bytes",
			Help: "Memory usage in bytes",
		},
	)

	m.diskUsage = promauto.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: "chorus_disk_usage_ratio",
@@ -448,14 +459,14 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
		},
		[]string{"mount_point"},
	)

	m.goroutines = promauto.NewGauge(
		prometheus.GaugeOpts{
			Name: "chorus_goroutines",
			Help: "Number of goroutines",
		},
	)

	// Error metrics
	m.errors = promauto.NewCounterVec(
		prometheus.CounterOpts{
@@ -464,7 +475,7 @@ func (m *CHORUSMetrics) initializeMetrics(config *MetricsConfig) {
		},
		[]string{"component", "error_type"},
	)

	m.panics = promauto.NewCounter(
		prometheus.CounterOpts{
			Name: "chorus_panics_total",
@@ -482,31 +493,31 @@ func (m *CHORUSMetrics) registerMetrics() {
// StartServer starts the Prometheus metrics HTTP server
func (m *CHORUSMetrics) StartServer(config *MetricsConfig) error {
	mux := http.NewServeMux()

	// Use custom registry
	handler := promhttp.HandlerFor(m.registry, promhttp.HandlerOpts{
		EnableOpenMetrics: true,
	})
	mux.Handle(config.MetricsPath, handler)

	// Health endpoint
	mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("OK"))
	})

	m.httpServer = &http.Server{
		Addr:    config.ListenAddr,
		Handler: mux,
	}

	go func() {
		log.Printf("Starting metrics server on %s%s", config.ListenAddr, config.MetricsPath)
		if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			log.Printf("Metrics server error: %v", err)
		}
	}()

	return nil
}
@@ -656,6 +667,15 @@ func (m *CHORUSMetrics) SetSLURPQueueLength(length int) {
	m.slurpQueueLength.Set(float64(length))
}

// SHHH Metrics Methods

func (m *CHORUSMetrics) IncrementSHHHFindings(rule, severity string, count int) {
	if m == nil || m.shhhFindings == nil || count <= 0 {
		return
	}
	m.shhhFindings.WithLabelValues(rule, severity).Add(float64(count))
}

// UCXI Metrics Methods

func (m *CHORUSMetrics) IncrementUCXIRequests(method, status string) {
@@ -708,21 +728,21 @@ func (m *CHORUSMetrics) UpdateUptime() {
func (m *CHORUSMetrics) CollectMetrics(config *MetricsConfig) {
	systemTicker := time.NewTicker(config.SystemMetricsInterval)
	resourceTicker := time.NewTicker(config.ResourceMetricsInterval)

	go func() {
		defer systemTicker.Stop()
		defer resourceTicker.Stop()

		for {
			select {
			case <-systemTicker.C:
				m.UpdateUptime()
				// Collect other system metrics

			case <-resourceTicker.C:
				// Collect resource metrics (would integrate with actual system monitoring)
				// m.collectResourceMetrics()
			}
		}
	}()
}
}
261
pkg/providers/factory.go
Normal file
@@ -0,0 +1,261 @@
package providers

import (
	"fmt"
	"strings"

	"chorus/pkg/repository"
)

// ProviderFactory creates task providers for different repository types
type ProviderFactory struct {
	supportedProviders map[string]ProviderCreator
}

// ProviderCreator is a function that creates a provider from config
type ProviderCreator func(config *repository.Config) (repository.TaskProvider, error)

// NewProviderFactory creates a new provider factory with all supported providers
func NewProviderFactory() *ProviderFactory {
	factory := &ProviderFactory{
		supportedProviders: make(map[string]ProviderCreator),
	}

	// Register all supported providers
	factory.RegisterProvider("gitea", func(config *repository.Config) (repository.TaskProvider, error) {
		return NewGiteaProvider(config)
	})

	factory.RegisterProvider("github", func(config *repository.Config) (repository.TaskProvider, error) {
		return NewGitHubProvider(config)
	})

	factory.RegisterProvider("gitlab", func(config *repository.Config) (repository.TaskProvider, error) {
		return NewGitLabProvider(config)
	})

	factory.RegisterProvider("mock", func(config *repository.Config) (repository.TaskProvider, error) {
		return &repository.MockTaskProvider{}, nil
	})

	return factory
}

// RegisterProvider registers a new provider creator
func (f *ProviderFactory) RegisterProvider(providerType string, creator ProviderCreator) {
	f.supportedProviders[strings.ToLower(providerType)] = creator
}

// CreateProvider creates a task provider based on the configuration
func (f *ProviderFactory) CreateProvider(ctx interface{}, config *repository.Config) (repository.TaskProvider, error) {
	if config == nil {
		return nil, fmt.Errorf("configuration cannot be nil")
	}

	providerType := strings.ToLower(config.Provider)
	if providerType == "" {
		// Fall back to Type field if Provider is not set
		providerType = strings.ToLower(config.Type)
	}

	if providerType == "" {
		return nil, fmt.Errorf("provider type must be specified in config.Provider or config.Type")
	}

	creator, exists := f.supportedProviders[providerType]
	if !exists {
		return nil, fmt.Errorf("unsupported provider type: %s. Supported types: %v",
			providerType, f.GetSupportedTypes())
	}

	provider, err := creator(config)
	if err != nil {
		return nil, fmt.Errorf("failed to create %s provider: %w", providerType, err)
	}

	return provider, nil
}

// GetSupportedTypes returns a list of all supported provider types
func (f *ProviderFactory) GetSupportedTypes() []string {
	types := make([]string, 0, len(f.supportedProviders))
	for providerType := range f.supportedProviders {
		types = append(types, providerType)
	}
	return types
}

// SupportedProviders returns list of supported providers (alias for GetSupportedTypes)
func (f *ProviderFactory) SupportedProviders() []string {
	return f.GetSupportedTypes()
}

// ValidateConfig validates a provider configuration
func (f *ProviderFactory) ValidateConfig(config *repository.Config) error {
	if config == nil {
		return fmt.Errorf("configuration cannot be nil")
	}

	providerType := strings.ToLower(config.Provider)
	if providerType == "" {
		providerType = strings.ToLower(config.Type)
	}

	if providerType == "" {
		return fmt.Errorf("provider type must be specified")
	}

	// Check if provider type is supported
	if _, exists := f.supportedProviders[providerType]; !exists {
		return fmt.Errorf("unsupported provider type: %s", providerType)
	}

	// Provider-specific validation
	switch providerType {
	case "gitea":
		return f.validateGiteaConfig(config)
	case "github":
		return f.validateGitHubConfig(config)
	case "gitlab":
		return f.validateGitLabConfig(config)
	case "mock":
		return nil // Mock provider doesn't need validation
	default:
		return fmt.Errorf("validation not implemented for provider type: %s", providerType)
	}
}

// validateGiteaConfig validates Gitea-specific configuration
func (f *ProviderFactory) validateGiteaConfig(config *repository.Config) error {
	if config.BaseURL == "" {
		return fmt.Errorf("baseURL is required for Gitea provider")
	}
	if config.AccessToken == "" {
		return fmt.Errorf("accessToken is required for Gitea provider")
	}
	if config.Owner == "" {
		return fmt.Errorf("owner is required for Gitea provider")
	}
	if config.Repository == "" {
		return fmt.Errorf("repository is required for Gitea provider")
	}
	return nil
}

// validateGitHubConfig validates GitHub-specific configuration
func (f *ProviderFactory) validateGitHubConfig(config *repository.Config) error {
	if config.AccessToken == "" {
		return fmt.Errorf("accessToken is required for GitHub provider")
	}
	if config.Owner == "" {
		return fmt.Errorf("owner is required for GitHub provider")
	}
	if config.Repository == "" {
		return fmt.Errorf("repository is required for GitHub provider")
	}
	return nil
}

// validateGitLabConfig validates GitLab-specific configuration
func (f *ProviderFactory) validateGitLabConfig(config *repository.Config) error {
	if config.AccessToken == "" {
		return fmt.Errorf("accessToken is required for GitLab provider")
	}

	// GitLab requires either owner/repository or project_id in settings
	if config.Owner != "" && config.Repository != "" {
		return nil // owner/repo provided
	}

	if config.Settings != nil {
		if projectID, ok := config.Settings["project_id"].(string); ok && projectID != "" {
			return nil // project_id provided
		}
	}

	return fmt.Errorf("either owner/repository or project_id in settings is required for GitLab provider")
}

// GetProviderInfo returns information about a specific provider
func (f *ProviderFactory) GetProviderInfo(providerType string) (*ProviderInfo, error) {
	providerType = strings.ToLower(providerType)

	if _, exists := f.supportedProviders[providerType]; !exists {
		return nil, fmt.Errorf("unsupported provider type: %s", providerType)
	}

	switch providerType {
	case "gitea":
		return &ProviderInfo{
			Name:              "Gitea",
			Type:              "gitea",
			Description:       "Gitea self-hosted Git service provider",
			RequiredFields:    []string{"baseURL", "accessToken", "owner", "repository"},
			OptionalFields:    []string{"taskLabel", "inProgressLabel", "completedLabel", "baseBranch", "branchPrefix"},
			SupportedFeatures: []string{"issues", "labels", "comments", "assignments"},
			APIDocumentation:  "https://docs.gitea.io/en-us/api-usage/",
		}, nil

	case "github":
		return &ProviderInfo{
			Name:              "GitHub",
			Type:              "github",
			Description:       "GitHub cloud and enterprise Git service provider",
			RequiredFields:    []string{"accessToken", "owner", "repository"},
			OptionalFields:    []string{"taskLabel", "inProgressLabel", "completedLabel", "baseBranch", "branchPrefix"},
			SupportedFeatures: []string{"issues", "labels", "comments", "assignments", "projects"},
			APIDocumentation:  "https://docs.github.com/en/rest",
		}, nil

	case "gitlab":
		return &ProviderInfo{
			Name:              "GitLab",
			Type:              "gitlab",
			Description:       "GitLab cloud and self-hosted Git service provider",
			RequiredFields:    []string{"accessToken", "owner/repository OR project_id"},
			OptionalFields:    []string{"baseURL", "taskLabel", "inProgressLabel", "completedLabel", "baseBranch", "branchPrefix"},
			SupportedFeatures: []string{"issues", "labels", "notes", "assignments", "time_tracking", "milestones"},
			APIDocumentation:  "https://docs.gitlab.com/ee/api/",
		}, nil

	case "mock":
		return &ProviderInfo{
			Name:              "Mock Provider",
			Type:              "mock",
			Description:       "Mock provider for testing and development",
			RequiredFields:    []string{},
			OptionalFields:    []string{},
			SupportedFeatures: []string{"basic_operations"},
			APIDocumentation:  "Built-in mock for testing purposes",
		}, nil

	default:
		return nil, fmt.Errorf("provider info not available for: %s", providerType)
	}
}

// ProviderInfo contains metadata about a provider
type ProviderInfo struct {
	Name              string   `json:"name"`
	Type              string   `json:"type"`
	Description       string   `json:"description"`
	RequiredFields    []string `json:"required_fields"`
	OptionalFields    []string `json:"optional_fields"`
	SupportedFeatures []string `json:"supported_features"`
	APIDocumentation  string   `json:"api_documentation"`
}

// ListProviders returns detailed information about all supported providers
func (f *ProviderFactory) ListProviders() ([]*ProviderInfo, error) {
	providers := make([]*ProviderInfo, 0, len(f.supportedProviders))

	for providerType := range f.supportedProviders {
		info, err := f.GetProviderInfo(providerType)
		if err != nil {
			continue // Skip providers without info
		}
		providers = append(providers, info)
	}

	return providers, nil
}
617
pkg/providers/gitea.go
Normal file
@@ -0,0 +1,617 @@
package providers

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"chorus/pkg/repository"
)

// GiteaProvider implements TaskProvider for Gitea API
type GiteaProvider struct {
	config     *repository.Config
	httpClient *http.Client
	baseURL    string
	token      string
	owner      string
	repo       string
}

// NewGiteaProvider creates a new Gitea provider
func NewGiteaProvider(config *repository.Config) (*GiteaProvider, error) {
	if config.BaseURL == "" {
		return nil, fmt.Errorf("base URL is required for Gitea provider")
	}
	if config.AccessToken == "" {
		return nil, fmt.Errorf("access token is required for Gitea provider")
	}
	if config.Owner == "" {
		return nil, fmt.Errorf("owner is required for Gitea provider")
	}
	if config.Repository == "" {
		return nil, fmt.Errorf("repository name is required for Gitea provider")
	}

	// Ensure base URL has proper format
	baseURL := strings.TrimSuffix(config.BaseURL, "/")
	if !strings.HasPrefix(baseURL, "http") {
		baseURL = "https://" + baseURL
	}

	return &GiteaProvider{
		config:  config,
		baseURL: baseURL,
		token:   config.AccessToken,
		owner:   config.Owner,
		repo:    config.Repository,
		httpClient: &http.Client{
			Timeout: 30 * time.Second,
		},
	}, nil
}

// GiteaIssue represents a Gitea issue
type GiteaIssue struct {
	ID         int64            `json:"id"`
	Number     int              `json:"number"`
	Title      string           `json:"title"`
	Body       string           `json:"body"`
	State      string           `json:"state"`
	Labels     []GiteaLabel     `json:"labels"`
	CreatedAt  time.Time        `json:"created_at"`
	UpdatedAt  time.Time        `json:"updated_at"`
	Repository *GiteaRepository `json:"repository"`
	Assignee   *GiteaUser       `json:"assignee"`
	Assignees  []GiteaUser      `json:"assignees"`
}

// GiteaLabel represents a Gitea label
type GiteaLabel struct {
	ID    int64  `json:"id"`
	Name  string `json:"name"`
	Color string `json:"color"`
}

// GiteaRepository represents a Gitea repository
type GiteaRepository struct {
	ID       int64      `json:"id"`
	Name     string     `json:"name"`
	FullName string     `json:"full_name"`
	Owner    *GiteaUser `json:"owner"`
}

// GiteaUser represents a Gitea user
type GiteaUser struct {
|
||||
ID int64 `json:"id"`
|
||||
Username string `json:"username"`
|
||||
FullName string `json:"full_name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
// GiteaComment represents a Gitea issue comment
|
||||
type GiteaComment struct {
|
||||
ID int64 `json:"id"`
|
||||
Body string `json:"body"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
User *GiteaUser `json:"user"`
|
||||
}
|
||||
|
||||
// makeRequest makes an HTTP request to the Gitea API
|
||||
func (g *GiteaProvider) makeRequest(method, endpoint string, body interface{}) (*http.Response, error) {
|
||||
var reqBody io.Reader
|
||||
|
||||
if body != nil {
|
||||
jsonData, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
reqBody = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/v1%s", g.baseURL, endpoint)
|
||||
req, err := http.NewRequest(method, url, reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "token "+g.token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := g.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}

// GetTasks retrieves tasks (issues) from the Gitea repository
func (g *GiteaProvider) GetTasks(projectID int) ([]*repository.Task, error) {
	// Build query parameters
	params := url.Values{}
	params.Add("state", "open")
	params.Add("type", "issues")
	params.Add("sort", "created")
	params.Add("order", "desc")

	// Add task label filter if specified
	if g.config.TaskLabel != "" {
		params.Add("labels", g.config.TaskLabel)
	}

	endpoint := fmt.Sprintf("/repos/%s/%s/issues?%s", g.owner, g.repo, params.Encode())

	resp, err := g.makeRequest("GET", endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to get issues: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
	}

	var issues []GiteaIssue
	if err := json.NewDecoder(resp.Body).Decode(&issues); err != nil {
		return nil, fmt.Errorf("failed to decode issues: %w", err)
	}

	// Convert Gitea issues to repository tasks
	tasks := make([]*repository.Task, 0, len(issues))
	for _, issue := range issues {
		task := g.issueToTask(&issue)
		tasks = append(tasks, task)
	}

	return tasks, nil
}

// ClaimTask claims a task by assigning it to the agent and adding in-progress label
func (g *GiteaProvider) ClaimTask(taskNumber int, agentID string) (bool, error) {
	// First, get the current issue to check its state
	endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d", g.owner, g.repo, taskNumber)

	resp, err := g.makeRequest("GET", endpoint, nil)
	if err != nil {
		return false, fmt.Errorf("failed to get issue: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return false, fmt.Errorf("issue not found or not accessible")
	}

	var issue GiteaIssue
	if err := json.NewDecoder(resp.Body).Decode(&issue); err != nil {
		return false, fmt.Errorf("failed to decode issue: %w", err)
	}

	// Check if issue is already assigned
	if issue.Assignee != nil {
		return false, fmt.Errorf("issue is already assigned to %s", issue.Assignee.Username)
	}

	// Add in-progress label if specified
	if g.config.InProgressLabel != "" {
		err := g.addLabelToIssue(taskNumber, g.config.InProgressLabel)
		if err != nil {
			return false, fmt.Errorf("failed to add in-progress label: %w", err)
		}
	}

	// Add a comment indicating the task has been claimed
	comment := fmt.Sprintf("🤖 Task claimed by CHORUS agent `%s`\n\nThis task is now being processed automatically.", agentID)
	err = g.addCommentToIssue(taskNumber, comment)
	if err != nil {
		// Don't fail the claim if comment fails
		fmt.Printf("Warning: failed to add claim comment: %v\n", err)
	}

	return true, nil
}

// UpdateTaskStatus updates the status of a task
func (g *GiteaProvider) UpdateTaskStatus(task *repository.Task, status string, comment string) error {
	// Add a comment with the status update
	statusComment := fmt.Sprintf("**Status Update:** %s\n\n%s", status, comment)

	err := g.addCommentToIssue(task.Number, statusComment)
	if err != nil {
		return fmt.Errorf("failed to add status comment: %w", err)
	}

	return nil
}

// CompleteTask completes a task by updating status and adding completion comment
func (g *GiteaProvider) CompleteTask(task *repository.Task, result *repository.TaskResult) error {
	// Create completion comment with results
	var commentBuffer strings.Builder
	commentBuffer.WriteString("✅ **Task Completed Successfully**\n\n")
	commentBuffer.WriteString(fmt.Sprintf("**Result:** %s\n\n", result.Message))

	// Add metadata if available
	if result.Metadata != nil {
		commentBuffer.WriteString("**Execution Details:**\n")
		for key, value := range result.Metadata {
			commentBuffer.WriteString(fmt.Sprintf("- **%s:** %v\n", key, value))
		}
		commentBuffer.WriteString("\n")
	}

	commentBuffer.WriteString("🤖 Completed by CHORUS autonomous agent")

	// Add completion comment
	err := g.addCommentToIssue(task.Number, commentBuffer.String())
	if err != nil {
		return fmt.Errorf("failed to add completion comment: %w", err)
	}

	// Remove in-progress label and add completed label
	if g.config.InProgressLabel != "" {
		err := g.removeLabelFromIssue(task.Number, g.config.InProgressLabel)
		if err != nil {
			fmt.Printf("Warning: failed to remove in-progress label: %v\n", err)
		}
	}

	if g.config.CompletedLabel != "" {
		err := g.addLabelToIssue(task.Number, g.config.CompletedLabel)
		if err != nil {
			fmt.Printf("Warning: failed to add completed label: %v\n", err)
		}
	}

	// Close the issue if the task was successful
	if result.Success {
		err := g.closeIssue(task.Number)
		if err != nil {
			return fmt.Errorf("failed to close issue: %w", err)
		}
	}

	return nil
}

// GetTaskDetails retrieves detailed information about a specific task
func (g *GiteaProvider) GetTaskDetails(projectID int, taskNumber int) (*repository.Task, error) {
	endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d", g.owner, g.repo, taskNumber)

	resp, err := g.makeRequest("GET", endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to get issue: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("issue not found")
	}

	var issue GiteaIssue
	if err := json.NewDecoder(resp.Body).Decode(&issue); err != nil {
		return nil, fmt.Errorf("failed to decode issue: %w", err)
	}

	return g.issueToTask(&issue), nil
}

// ListAvailableTasks lists all available (unassigned) tasks
func (g *GiteaProvider) ListAvailableTasks(projectID int) ([]*repository.Task, error) {
	// Get all open issues without assignees
	params := url.Values{}
	params.Add("state", "open")
	params.Add("type", "issues")
	params.Add("assigned", "false") // Only unassigned issues

	if g.config.TaskLabel != "" {
		params.Add("labels", g.config.TaskLabel)
	}

	endpoint := fmt.Sprintf("/repos/%s/%s/issues?%s", g.owner, g.repo, params.Encode())

	resp, err := g.makeRequest("GET", endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to get available issues: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
	}

	var issues []GiteaIssue
	if err := json.NewDecoder(resp.Body).Decode(&issues); err != nil {
		return nil, fmt.Errorf("failed to decode issues: %w", err)
	}

	// Convert to tasks and filter out assigned ones
	tasks := make([]*repository.Task, 0, len(issues))
	for _, issue := range issues {
		// Skip assigned issues
		if issue.Assignee != nil || len(issue.Assignees) > 0 {
			continue
		}

		task := g.issueToTask(&issue)
		tasks = append(tasks, task)
	}

	return tasks, nil
}

// Helper methods

// issueToTask converts a Gitea issue to a repository Task
func (g *GiteaProvider) issueToTask(issue *GiteaIssue) *repository.Task {
	// Extract labels
	labels := make([]string, len(issue.Labels))
	for i, label := range issue.Labels {
		labels[i] = label.Name
	}

	// Calculate priority and complexity based on labels and content
	priority := g.calculatePriority(labels, issue.Title, issue.Body)
	complexity := g.calculateComplexity(labels, issue.Title, issue.Body)

	// Determine required role and expertise from labels
	requiredRole := g.determineRequiredRole(labels)
	requiredExpertise := g.determineRequiredExpertise(labels)

	return &repository.Task{
		Number:            issue.Number,
		Title:             issue.Title,
		Body:              issue.Body,
		Repository:        fmt.Sprintf("%s/%s", g.owner, g.repo),
		Labels:            labels,
		Priority:          priority,
		Complexity:        complexity,
		Status:            issue.State,
		CreatedAt:         issue.CreatedAt,
		UpdatedAt:         issue.UpdatedAt,
		RequiredRole:      requiredRole,
		RequiredExpertise: requiredExpertise,
		Metadata: map[string]interface{}{
			"gitea_id":   issue.ID,
			"provider":   "gitea",
			"repository": issue.Repository,
			"assignee":   issue.Assignee,
			"assignees":  issue.Assignees,
		},
	}
}

// calculatePriority determines task priority from labels and content
func (g *GiteaProvider) calculatePriority(labels []string, title, body string) int {
	priority := 5 // default

	for _, label := range labels {
		switch strings.ToLower(label) {
		case "priority:critical", "critical", "urgent":
			priority = 10
		case "priority:high", "high":
			priority = 8
		case "priority:medium", "medium":
			priority = 5
		case "priority:low", "low":
			priority = 2
		case "bug", "security", "hotfix":
			priority = max(priority, 7)
		}
	}

	// Boost priority for urgent keywords in title
	titleLower := strings.ToLower(title)
	if strings.Contains(titleLower, "urgent") || strings.Contains(titleLower, "critical") ||
		strings.Contains(titleLower, "hotfix") || strings.Contains(titleLower, "security") {
		priority = max(priority, 8)
	}

	return priority
}

// calculateComplexity estimates task complexity from labels and content
func (g *GiteaProvider) calculateComplexity(labels []string, title, body string) int {
	complexity := 3 // default

	for _, label := range labels {
		switch strings.ToLower(label) {
		case "complexity:high", "epic", "major":
			complexity = 8
		case "complexity:medium":
			complexity = 5
		case "complexity:low", "simple", "trivial":
			complexity = 2
		case "refactor", "architecture":
			complexity = max(complexity, 7)
		case "bug", "hotfix":
			complexity = max(complexity, 4)
		case "enhancement", "feature":
			complexity = max(complexity, 5)
		}
	}

	// Estimate complexity from body length
	bodyLength := len(strings.Fields(body))
	if bodyLength > 200 {
		complexity = max(complexity, 6)
	} else if bodyLength > 50 {
		complexity = max(complexity, 4)
	}

	return complexity
}

// determineRequiredRole determines what agent role is needed for this task
func (g *GiteaProvider) determineRequiredRole(labels []string) string {
	for _, label := range labels {
		switch strings.ToLower(label) {
		case "frontend", "ui", "ux", "css", "html", "javascript", "react", "vue":
			return "frontend-developer"
		case "backend", "api", "server", "database", "sql":
			return "backend-developer"
		case "devops", "infrastructure", "deployment", "docker", "kubernetes":
			return "devops-engineer"
		case "security", "authentication", "authorization":
			return "security-engineer"
		case "testing", "qa", "quality":
			return "tester"
		case "documentation", "docs":
			return "technical-writer"
		case "design", "mockup", "wireframe":
			return "designer"
		}
	}

	return "developer" // default role
}

// determineRequiredExpertise determines what expertise is needed
func (g *GiteaProvider) determineRequiredExpertise(labels []string) []string {
	expertise := make([]string, 0)
	expertiseMap := make(map[string]bool) // prevent duplicates

	for _, label := range labels {
		labelLower := strings.ToLower(label)

		// Programming languages
		languages := []string{"go", "python", "javascript", "typescript", "java", "rust", "c++", "php"}
		for _, lang := range languages {
			if strings.Contains(labelLower, lang) {
				if !expertiseMap[lang] {
					expertise = append(expertise, lang)
					expertiseMap[lang] = true
				}
			}
		}

		// Technologies and frameworks
		technologies := []string{"docker", "kubernetes", "react", "vue", "angular", "nodejs", "django", "flask", "spring"}
		for _, tech := range technologies {
			if strings.Contains(labelLower, tech) {
				if !expertiseMap[tech] {
					expertise = append(expertise, tech)
					expertiseMap[tech] = true
				}
			}
		}

		// Domain areas
		domains := []string{"frontend", "backend", "database", "security", "testing", "devops", "api"}
		for _, domain := range domains {
			if strings.Contains(labelLower, domain) {
				if !expertiseMap[domain] {
					expertise = append(expertise, domain)
					expertiseMap[domain] = true
				}
			}
		}
	}

	// Default expertise if none detected
	if len(expertise) == 0 {
		expertise = []string{"development", "programming"}
	}

	return expertise
}

// addLabelToIssue adds a label to an issue
func (g *GiteaProvider) addLabelToIssue(issueNumber int, labelName string) error {
	endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d/labels", g.owner, g.repo, issueNumber)

	body := map[string]interface{}{
		"labels": []string{labelName},
	}

	resp, err := g.makeRequest("POST", endpoint, body)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
		respBody, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("failed to add label (status %d): %s", resp.StatusCode, string(respBody))
	}

	return nil
}

// removeLabelFromIssue removes a label from an issue
func (g *GiteaProvider) removeLabelFromIssue(issueNumber int, labelName string) error {
	endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d/labels/%s", g.owner, g.repo, issueNumber, url.QueryEscape(labelName))

	resp, err := g.makeRequest("DELETE", endpoint, nil)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
		respBody, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("failed to remove label (status %d): %s", resp.StatusCode, string(respBody))
	}

	return nil
}

// addCommentToIssue adds a comment to an issue
func (g *GiteaProvider) addCommentToIssue(issueNumber int, comment string) error {
	endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d/comments", g.owner, g.repo, issueNumber)

	body := map[string]interface{}{
		"body": comment,
	}

	resp, err := g.makeRequest("POST", endpoint, body)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusCreated {
		respBody, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("failed to add comment (status %d): %s", resp.StatusCode, string(respBody))
	}

	return nil
}

// closeIssue closes an issue
func (g *GiteaProvider) closeIssue(issueNumber int) error {
	endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d", g.owner, g.repo, issueNumber)

	body := map[string]interface{}{
		"state": "closed",
	}

	resp, err := g.makeRequest("PATCH", endpoint, body)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
		respBody, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("failed to close issue (status %d): %s", resp.StatusCode, string(respBody))
	}

	return nil
}

// max returns the maximum of two integers
func max(a, b int) int {
	if a > b {
		return a
	}
	return b
}
732
pkg/providers/github.go
Normal file
@@ -0,0 +1,732 @@
package providers

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"chorus/pkg/repository"
)

// GitHubProvider implements TaskProvider for GitHub API
type GitHubProvider struct {
	config     *repository.Config
	httpClient *http.Client
	token      string
	owner      string
	repo       string
}

// NewGitHubProvider creates a new GitHub provider
func NewGitHubProvider(config *repository.Config) (*GitHubProvider, error) {
	if config.AccessToken == "" {
		return nil, fmt.Errorf("access token is required for GitHub provider")
	}
	if config.Owner == "" {
		return nil, fmt.Errorf("owner is required for GitHub provider")
	}
	if config.Repository == "" {
		return nil, fmt.Errorf("repository name is required for GitHub provider")
	}

	return &GitHubProvider{
		config: config,
		token:  config.AccessToken,
		owner:  config.Owner,
		repo:   config.Repository,
		httpClient: &http.Client{
			Timeout: 30 * time.Second,
		},
	}, nil
}

// GitHubIssue represents a GitHub issue
type GitHubIssue struct {
	ID          int64                 `json:"id"`
	Number      int                   `json:"number"`
	Title       string                `json:"title"`
	Body        string                `json:"body"`
	State       string                `json:"state"`
	Labels      []GitHubLabel         `json:"labels"`
	CreatedAt   time.Time             `json:"created_at"`
	UpdatedAt   time.Time             `json:"updated_at"`
	Repository  *GitHubRepository     `json:"repository,omitempty"`
	Assignee    *GitHubUser           `json:"assignee"`
	Assignees   []GitHubUser          `json:"assignees"`
	User        *GitHubUser           `json:"user"`
	PullRequest *GitHubPullRequestRef `json:"pull_request,omitempty"`
}

// GitHubLabel represents a GitHub label
type GitHubLabel struct {
	ID    int64  `json:"id"`
	Name  string `json:"name"`
	Color string `json:"color"`
}

// GitHubRepository represents a GitHub repository
type GitHubRepository struct {
	ID       int64       `json:"id"`
	Name     string      `json:"name"`
	FullName string      `json:"full_name"`
	Owner    *GitHubUser `json:"owner"`
}

// GitHubUser represents a GitHub user
type GitHubUser struct {
	ID        int64  `json:"id"`
	Login     string `json:"login"`
	Name      string `json:"name"`
	Email     string `json:"email"`
	AvatarURL string `json:"avatar_url"`
}

// GitHubPullRequestRef indicates if issue is a PR
type GitHubPullRequestRef struct {
	URL string `json:"url"`
}

// GitHubComment represents a GitHub issue comment
type GitHubComment struct {
	ID        int64       `json:"id"`
	Body      string      `json:"body"`
	CreatedAt time.Time   `json:"created_at"`
	User      *GitHubUser `json:"user"`
}

// makeRequest makes an HTTP request to the GitHub API
func (g *GitHubProvider) makeRequest(method, endpoint string, body interface{}) (*http.Response, error) {
	var reqBody io.Reader

	if body != nil {
		jsonData, err := json.Marshal(body)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal request body: %w", err)
		}
		reqBody = bytes.NewBuffer(jsonData)
	}

	url := fmt.Sprintf("https://api.github.com%s", endpoint)
	req, err := http.NewRequest(method, url, reqBody)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	req.Header.Set("Authorization", "token "+g.token)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/vnd.github.v3+json")
	req.Header.Set("User-Agent", "CHORUS-Agent/1.0")

	resp, err := g.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}

	return resp, nil
}

// GetTasks retrieves tasks (issues) from the GitHub repository
func (g *GitHubProvider) GetTasks(projectID int) ([]*repository.Task, error) {
	// Build query parameters
	params := url.Values{}
	params.Add("state", "open")
	params.Add("sort", "created")
	params.Add("direction", "desc")

	// Add task label filter if specified
	if g.config.TaskLabel != "" {
		params.Add("labels", g.config.TaskLabel)
	}

	endpoint := fmt.Sprintf("/repos/%s/%s/issues?%s", g.owner, g.repo, params.Encode())

	resp, err := g.makeRequest("GET", endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to get issues: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
	}

	var issues []GitHubIssue
	if err := json.NewDecoder(resp.Body).Decode(&issues); err != nil {
		return nil, fmt.Errorf("failed to decode issues: %w", err)
	}

	// Filter out pull requests (GitHub API includes PRs in issues endpoint)
	tasks := make([]*repository.Task, 0, len(issues))
	for _, issue := range issues {
		// Skip pull requests
		if issue.PullRequest != nil {
			continue
		}

		task := g.issueToTask(&issue)
		tasks = append(tasks, task)
	}

	return tasks, nil
}

// ClaimTask claims a task by assigning it to the agent and adding in-progress label
func (g *GitHubProvider) ClaimTask(taskNumber int, agentID string) (bool, error) {
	// First, get the current issue to check its state
	endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d", g.owner, g.repo, taskNumber)

	resp, err := g.makeRequest("GET", endpoint, nil)
	if err != nil {
		return false, fmt.Errorf("failed to get issue: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return false, fmt.Errorf("issue not found or not accessible")
	}

	var issue GitHubIssue
	if err := json.NewDecoder(resp.Body).Decode(&issue); err != nil {
		return false, fmt.Errorf("failed to decode issue: %w", err)
	}

	// Check if issue is already assigned
	if issue.Assignee != nil || len(issue.Assignees) > 0 {
		assigneeName := ""
		if issue.Assignee != nil {
			assigneeName = issue.Assignee.Login
		} else if len(issue.Assignees) > 0 {
			assigneeName = issue.Assignees[0].Login
		}
		return false, fmt.Errorf("issue is already assigned to %s", assigneeName)
	}

	// Add in-progress label if specified
	if g.config.InProgressLabel != "" {
		err := g.addLabelToIssue(taskNumber, g.config.InProgressLabel)
		if err != nil {
			return false, fmt.Errorf("failed to add in-progress label: %w", err)
		}
	}

	// Add a comment indicating the task has been claimed
	comment := fmt.Sprintf("🤖 **Task Claimed by CHORUS Agent**\n\nAgent ID: `%s`\nStatus: Processing\n\nThis task is now being handled automatically by the CHORUS autonomous agent system.", agentID)
	err = g.addCommentToIssue(taskNumber, comment)
	if err != nil {
		// Don't fail the claim if comment fails
		fmt.Printf("Warning: failed to add claim comment: %v\n", err)
	}

	return true, nil
}

// UpdateTaskStatus updates the status of a task
func (g *GitHubProvider) UpdateTaskStatus(task *repository.Task, status string, comment string) error {
	// Add a comment with the status update
	statusComment := fmt.Sprintf("📊 **Status Update: %s**\n\n%s\n\n---\n*Updated by CHORUS Agent*", status, comment)

	err := g.addCommentToIssue(task.Number, statusComment)
	if err != nil {
		return fmt.Errorf("failed to add status comment: %w", err)
	}

	return nil
}

// CompleteTask completes a task by updating status and adding completion comment
func (g *GitHubProvider) CompleteTask(task *repository.Task, result *repository.TaskResult) error {
	// Create completion comment with results
	var commentBuffer strings.Builder
	commentBuffer.WriteString("✅ **Task Completed Successfully**\n\n")
	commentBuffer.WriteString(fmt.Sprintf("**Result:** %s\n\n", result.Message))

	// Add metadata if available
	if result.Metadata != nil {
		commentBuffer.WriteString("## Execution Details\n\n")
		for key, value := range result.Metadata {
			// Format the metadata nicely
			switch key {
			case "duration":
				commentBuffer.WriteString(fmt.Sprintf("- ⏱️ **Duration:** %v\n", value))
			case "execution_type":
				commentBuffer.WriteString(fmt.Sprintf("- 🔧 **Execution Type:** %v\n", value))
			case "commands_executed":
				commentBuffer.WriteString(fmt.Sprintf("- 🖥️ **Commands Executed:** %v\n", value))
			case "files_generated":
				commentBuffer.WriteString(fmt.Sprintf("- 📄 **Files Generated:** %v\n", value))
			case "ai_provider":
				commentBuffer.WriteString(fmt.Sprintf("- 🤖 **AI Provider:** %v\n", value))
			case "ai_model":
				commentBuffer.WriteString(fmt.Sprintf("- 🧠 **AI Model:** %v\n", value))
			default:
				commentBuffer.WriteString(fmt.Sprintf("- **%s:** %v\n", key, value))
			}
		}
		commentBuffer.WriteString("\n")
	}

	commentBuffer.WriteString("---\n🤖 *Completed by CHORUS Autonomous Agent System*")

	// Add completion comment
	err := g.addCommentToIssue(task.Number, commentBuffer.String())
	if err != nil {
		return fmt.Errorf("failed to add completion comment: %w", err)
	}

	// Remove in-progress label and add completed label
	if g.config.InProgressLabel != "" {
		err := g.removeLabelFromIssue(task.Number, g.config.InProgressLabel)
		if err != nil {
			fmt.Printf("Warning: failed to remove in-progress label: %v\n", err)
		}
	}

	if g.config.CompletedLabel != "" {
		err := g.addLabelToIssue(task.Number, g.config.CompletedLabel)
		if err != nil {
			fmt.Printf("Warning: failed to add completed label: %v\n", err)
		}
	}

	// Close the issue if the task was successful
	if result.Success {
		err := g.closeIssue(task.Number)
		if err != nil {
			return fmt.Errorf("failed to close issue: %w", err)
		}
	}

	return nil
}

// GetTaskDetails retrieves detailed information about a specific task
func (g *GitHubProvider) GetTaskDetails(projectID int, taskNumber int) (*repository.Task, error) {
	endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d", g.owner, g.repo, taskNumber)

	resp, err := g.makeRequest("GET", endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to get issue: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("issue not found")
	}

	var issue GitHubIssue
	if err := json.NewDecoder(resp.Body).Decode(&issue); err != nil {
		return nil, fmt.Errorf("failed to decode issue: %w", err)
	}

	// Skip pull requests
	if issue.PullRequest != nil {
		return nil, fmt.Errorf("pull requests are not supported as tasks")
	}

	return g.issueToTask(&issue), nil
}

// ListAvailableTasks lists all available (unassigned) tasks
func (g *GitHubProvider) ListAvailableTasks(projectID int) ([]*repository.Task, error) {
	// GitHub doesn't have a direct "unassigned" filter, so we get open issues and filter
	params := url.Values{}
	params.Add("state", "open")
	params.Add("sort", "created")
	params.Add("direction", "desc")

	if g.config.TaskLabel != "" {
		params.Add("labels", g.config.TaskLabel)
	}

	endpoint := fmt.Sprintf("/repos/%s/%s/issues?%s", g.owner, g.repo, params.Encode())

	resp, err := g.makeRequest("GET", endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to get available issues: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
	}

	var issues []GitHubIssue
	if err := json.NewDecoder(resp.Body).Decode(&issues); err != nil {
		return nil, fmt.Errorf("failed to decode issues: %w", err)
	}

	// Filter out assigned issues and PRs
	tasks := make([]*repository.Task, 0, len(issues))
	for _, issue := range issues {
		// Skip pull requests
		if issue.PullRequest != nil {
			continue
		}

		// Skip assigned issues
|
||||
if issue.Assignee != nil || len(issue.Assignees) > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
task := g.issueToTask(&issue)
|
||||
tasks = append(tasks, task)
|
||||
}
|
||||
|
||||
return tasks, nil
|
||||
}
// Helper methods

// issueToTask converts a GitHub issue to a repository Task
func (g *GitHubProvider) issueToTask(issue *GitHubIssue) *repository.Task {
	// Extract labels
	labels := make([]string, len(issue.Labels))
	for i, label := range issue.Labels {
		labels[i] = label.Name
	}

	// Calculate priority and complexity based on labels and content
	priority := g.calculatePriority(labels, issue.Title, issue.Body)
	complexity := g.calculateComplexity(labels, issue.Title, issue.Body)

	// Determine required role and expertise from labels
	requiredRole := g.determineRequiredRole(labels)
	requiredExpertise := g.determineRequiredExpertise(labels)

	return &repository.Task{
		Number:            issue.Number,
		Title:             issue.Title,
		Body:              issue.Body,
		Repository:        fmt.Sprintf("%s/%s", g.owner, g.repo),
		Labels:            labels,
		Priority:          priority,
		Complexity:        complexity,
		Status:            issue.State,
		CreatedAt:         issue.CreatedAt,
		UpdatedAt:         issue.UpdatedAt,
		RequiredRole:      requiredRole,
		RequiredExpertise: requiredExpertise,
		Metadata: map[string]interface{}{
			"github_id":  issue.ID,
			"provider":   "github",
			"repository": issue.Repository,
			"assignee":   issue.Assignee,
			"assignees":  issue.Assignees,
			"user":       issue.User,
		},
	}
}

// calculatePriority determines task priority from labels and content
func (g *GitHubProvider) calculatePriority(labels []string, title, body string) int {
	priority := 5 // default

	for _, label := range labels {
		labelLower := strings.ToLower(label)
		switch {
		case strings.Contains(labelLower, "priority") && strings.Contains(labelLower, "critical"):
			priority = 10
		case strings.Contains(labelLower, "priority") && strings.Contains(labelLower, "high"):
			priority = 8
		case strings.Contains(labelLower, "priority") && strings.Contains(labelLower, "medium"):
			priority = 5
		case strings.Contains(labelLower, "priority") && strings.Contains(labelLower, "low"):
			priority = 2
		case labelLower == "critical" || labelLower == "urgent":
			priority = 10
		case labelLower == "high":
			priority = 8
		case labelLower == "bug" || labelLower == "security" || labelLower == "hotfix":
			priority = max(priority, 7)
		case labelLower == "enhancement" || labelLower == "feature":
			priority = max(priority, 5)
		case labelLower == "good first issue":
			priority = max(priority, 3)
		}
	}

	// Boost priority for urgent keywords in title
	titleLower := strings.ToLower(title)
	urgentKeywords := []string{"urgent", "critical", "hotfix", "security", "broken", "crash"}
	for _, keyword := range urgentKeywords {
		if strings.Contains(titleLower, keyword) {
			priority = max(priority, 8)
			break
		}
	}

	return priority
}

// calculateComplexity estimates task complexity from labels and content
func (g *GitHubProvider) calculateComplexity(labels []string, title, body string) int {
	complexity := 3 // default

	for _, label := range labels {
		labelLower := strings.ToLower(label)
		switch {
		case strings.Contains(labelLower, "complexity") && strings.Contains(labelLower, "high"):
			complexity = 8
		case strings.Contains(labelLower, "complexity") && strings.Contains(labelLower, "medium"):
			complexity = 5
		case strings.Contains(labelLower, "complexity") && strings.Contains(labelLower, "low"):
			complexity = 2
		case labelLower == "epic" || labelLower == "major":
			complexity = 8
		case labelLower == "refactor" || labelLower == "architecture":
			complexity = max(complexity, 7)
		case labelLower == "bug" || labelLower == "hotfix":
			complexity = max(complexity, 4)
		case labelLower == "enhancement" || labelLower == "feature":
			complexity = max(complexity, 5)
		case labelLower == "good first issue" || labelLower == "beginner":
			complexity = 2
		case labelLower == "documentation" || labelLower == "docs":
			complexity = max(complexity, 3)
		}
	}

	// Estimate complexity from body length and content
	bodyLength := len(strings.Fields(body))
	if bodyLength > 500 {
		complexity = max(complexity, 7)
	} else if bodyLength > 200 {
		complexity = max(complexity, 5)
	} else if bodyLength > 50 {
		complexity = max(complexity, 4)
	}

	// Look for complexity indicators in content
	bodyLower := strings.ToLower(body)
	complexityIndicators := []string{"refactor", "architecture", "breaking change", "migration", "redesign"}
	for _, indicator := range complexityIndicators {
		if strings.Contains(bodyLower, indicator) {
			complexity = max(complexity, 7)
			break
		}
	}

	return complexity
}
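The two heuristics above are order-sensitive: the first matching `case` wins per label, explicit `priority`/`complexity` labels overwrite the value outright, and category labels only raise a floor via `max`. A minimal standalone sketch of that pattern (the `prioritize` helper is illustrative, not part of the provider; it assumes Go 1.21+ for the built-in `max`):

```go
package main

import (
	"fmt"
	"strings"
)

// prioritize is an abridged, standalone copy of the label heuristic:
// explicit priority labels set the value outright, while category
// labels only raise a floor via Go 1.21's built-in max.
func prioritize(labels []string) int {
	priority := 5 // default, as in the provider
	for _, label := range labels {
		l := strings.ToLower(label)
		switch {
		case strings.Contains(l, "priority") && strings.Contains(l, "critical"):
			priority = 10
		case l == "bug" || l == "security" || l == "hotfix":
			priority = max(priority, 7)
		case l == "good first issue":
			priority = max(priority, 3)
		}
	}
	return priority
}

func main() {
	fmt.Println(prioritize([]string{"bug"}))                // 7: category label raises the floor
	fmt.Println(prioritize([]string{"priority::critical"})) // 10: explicit label wins
}
```

Because `max` only ever raises the value, a later low-signal label (like "good first issue") cannot demote an issue that an earlier label already marked urgent.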
// determineRequiredRole determines what agent role is needed for this task
func (g *GitHubProvider) determineRequiredRole(labels []string) string {
	roleKeywords := map[string]string{
		// Frontend
		"frontend":   "frontend-developer",
		"ui":         "frontend-developer",
		"ux":         "ui-ux-designer",
		"css":        "frontend-developer",
		"html":       "frontend-developer",
		"javascript": "frontend-developer",
		"react":      "frontend-developer",
		"vue":        "frontend-developer",
		"angular":    "frontend-developer",

		// Backend
		"backend":  "backend-developer",
		"api":      "backend-developer",
		"server":   "backend-developer",
		"database": "backend-developer",
		"sql":      "backend-developer",

		// DevOps
		"devops":         "devops-engineer",
		"infrastructure": "devops-engineer",
		"deployment":     "devops-engineer",
		"docker":         "devops-engineer",
		"kubernetes":     "devops-engineer",
		"ci/cd":          "devops-engineer",

		// Security
		"security":       "security-engineer",
		"authentication": "security-engineer",
		"authorization":  "security-engineer",
		"vulnerability":  "security-engineer",

		// Testing
		"testing": "tester",
		"qa":      "tester",
		"test":    "tester",

		// Documentation
		"documentation": "technical-writer",
		"docs":          "technical-writer",

		// Design
		"design":    "ui-ux-designer",
		"mockup":    "ui-ux-designer",
		"wireframe": "ui-ux-designer",
	}

	for _, label := range labels {
		labelLower := strings.ToLower(label)
		for keyword, role := range roleKeywords {
			if strings.Contains(labelLower, keyword) {
				return role
			}
		}
	}

	return "developer" // default role
}

// determineRequiredExpertise determines what expertise is needed
func (g *GitHubProvider) determineRequiredExpertise(labels []string) []string {
	expertise := make([]string, 0)
	expertiseMap := make(map[string]bool) // prevent duplicates

	expertiseKeywords := map[string][]string{
		// Programming languages
		"go":         {"go", "golang"},
		"python":     {"python"},
		"javascript": {"javascript", "js"},
		"typescript": {"typescript", "ts"},
		"java":       {"java"},
		"rust":       {"rust"},
		"c++":        {"c++", "cpp"},
		"c#":         {"c#", "csharp"},
		"php":        {"php"},
		"ruby":       {"ruby"},

		// Frontend technologies
		"react":   {"react"},
		"vue":     {"vue", "vuejs"},
		"angular": {"angular"},
		"svelte":  {"svelte"},

		// Backend frameworks
		"nodejs":  {"nodejs", "node.js", "node"},
		"django":  {"django"},
		"flask":   {"flask"},
		"spring":  {"spring"},
		"express": {"express"},

		// Databases
		"postgresql": {"postgresql", "postgres"},
		"mysql":      {"mysql"},
		"mongodb":    {"mongodb", "mongo"},
		"redis":      {"redis"},

		// DevOps tools
		"docker":     {"docker"},
		"kubernetes": {"kubernetes", "k8s"},
		"aws":        {"aws"},
		"azure":      {"azure"},
		"gcp":        {"gcp", "google cloud"},

		// Other technologies
		"graphql": {"graphql"},
		"rest":    {"rest", "restful"},
		"grpc":    {"grpc"},
	}

	for _, label := range labels {
		labelLower := strings.ToLower(label)
		for expertiseArea, keywords := range expertiseKeywords {
			for _, keyword := range keywords {
				if strings.Contains(labelLower, keyword) && !expertiseMap[expertiseArea] {
					expertise = append(expertise, expertiseArea)
					expertiseMap[expertiseArea] = true
					break
				}
			}
		}
	}

	// Default expertise if none detected
	if len(expertise) == 0 {
		expertise = []string{"development", "programming"}
	}

	return expertise
}
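The expertise lookup above pairs each area with several keyword aliases and uses a seen-set so an area is appended at most once, even when multiple aliases match across labels. A two-entry standalone sketch of that pattern (the `detectExpertise` helper and its tiny keyword table are illustrative, not the provider's full map):

```go
package main

import (
	"fmt"
	"strings"
)

// detectExpertise mirrors the alias-plus-seen-set pattern: each
// expertise area has several keyword aliases, and the seen map
// guarantees an area is appended at most once.
func detectExpertise(labels []string) []string {
	keywords := map[string][]string{
		"go":         {"go", "golang"},
		"kubernetes": {"kubernetes", "k8s"},
	}
	seen := make(map[string]bool)
	out := make([]string, 0)
	for _, label := range labels {
		l := strings.ToLower(label)
		for area, aliases := range keywords {
			for _, kw := range aliases {
				if strings.Contains(l, kw) && !seen[area] {
					out = append(out, area)
					seen[area] = true
					break
				}
			}
		}
	}
	if len(out) == 0 {
		out = []string{"development"} // same fallback idea as the provider
	}
	return out
}

func main() {
	// "lang/go" matches the "go" alias again, but the seen-set suppresses it.
	fmt.Println(len(detectExpertise([]string{"golang", "k8s", "lang/go"}))) // 2
}
```

One caveat carried over from the original: because Go map iteration order is randomized, the order of areas in the result (and the role chosen by `determineRequiredRole` when a label matches several keywords) is not deterministic across runs.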
// addLabelToIssue adds a label to an issue
func (g *GitHubProvider) addLabelToIssue(issueNumber int, labelName string) error {
	endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d/labels", g.owner, g.repo, issueNumber)

	body := []string{labelName}

	resp, err := g.makeRequest("POST", endpoint, body)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		respBody, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("failed to add label (status %d): %s", resp.StatusCode, string(respBody))
	}

	return nil
}

// removeLabelFromIssue removes a label from an issue
func (g *GitHubProvider) removeLabelFromIssue(issueNumber int, labelName string) error {
	endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d/labels/%s", g.owner, g.repo, issueNumber, url.QueryEscape(labelName))

	resp, err := g.makeRequest("DELETE", endpoint, nil)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
		respBody, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("failed to remove label (status %d): %s", resp.StatusCode, string(respBody))
	}

	return nil
}

// addCommentToIssue adds a comment to an issue
func (g *GitHubProvider) addCommentToIssue(issueNumber int, comment string) error {
	endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d/comments", g.owner, g.repo, issueNumber)

	body := map[string]interface{}{
		"body": comment,
	}

	resp, err := g.makeRequest("POST", endpoint, body)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusCreated {
		respBody, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("failed to add comment (status %d): %s", resp.StatusCode, string(respBody))
	}

	return nil
}

// closeIssue closes an issue
func (g *GitHubProvider) closeIssue(issueNumber int) error {
	endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d", g.owner, g.repo, issueNumber)

	body := map[string]interface{}{
		"state": "closed",
	}

	resp, err := g.makeRequest("PATCH", endpoint, body)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		respBody, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("failed to close issue (status %d): %s", resp.StatusCode, string(respBody))
	}

	return nil
}
781 pkg/providers/gitlab.go Normal file
@@ -0,0 +1,781 @@
package providers

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"chorus/pkg/repository"
)

// GitLabProvider implements TaskProvider for GitLab API
type GitLabProvider struct {
	config     *repository.Config
	httpClient *http.Client
	baseURL    string
	token      string
	projectID  string // GitLab uses project ID or namespace/project-name
}

// NewGitLabProvider creates a new GitLab provider
func NewGitLabProvider(config *repository.Config) (*GitLabProvider, error) {
	if config.AccessToken == "" {
		return nil, fmt.Errorf("access token is required for GitLab provider")
	}

	// Default to gitlab.com if no base URL provided
	baseURL := config.BaseURL
	if baseURL == "" {
		baseURL = "https://gitlab.com"
	}
	baseURL = strings.TrimSuffix(baseURL, "/")

	// Build project ID from owner/repo if provided, otherwise use settings
	var projectID string
	if config.Owner != "" && config.Repository != "" {
		projectID = url.QueryEscape(fmt.Sprintf("%s/%s", config.Owner, config.Repository))
	} else if projectIDSetting, ok := config.Settings["project_id"].(string); ok {
		projectID = projectIDSetting
	} else {
		return nil, fmt.Errorf("either owner/repository or project_id in settings is required for GitLab provider")
	}

	return &GitLabProvider{
		config:    config,
		baseURL:   baseURL,
		token:     config.AccessToken,
		projectID: projectID,
		httpClient: &http.Client{
			Timeout: 30 * time.Second,
		},
	}, nil
}

// GitLabIssue represents a GitLab issue
type GitLabIssue struct {
	ID          int              `json:"id"`
	IID         int              `json:"iid"` // Project-specific ID (what users see)
	Title       string           `json:"title"`
	Description string           `json:"description"`
	State       string           `json:"state"`
	Labels      []string         `json:"labels"`
	CreatedAt   time.Time        `json:"created_at"`
	UpdatedAt   time.Time        `json:"updated_at"`
	ProjectID   int              `json:"project_id"`
	Author      *GitLabUser      `json:"author"`
	Assignee    *GitLabUser      `json:"assignee"`
	Assignees   []GitLabUser     `json:"assignees"`
	WebURL      string           `json:"web_url"`
	TimeStats   *GitLabTimeStats `json:"time_stats,omitempty"`
}

// GitLabUser represents a GitLab user
type GitLabUser struct {
	ID        int    `json:"id"`
	Username  string `json:"username"`
	Name      string `json:"name"`
	Email     string `json:"email"`
	AvatarURL string `json:"avatar_url"`
}

// GitLabTimeStats represents time tracking statistics
type GitLabTimeStats struct {
	TimeEstimate        int    `json:"time_estimate"`
	TotalTimeSpent      int    `json:"total_time_spent"`
	HumanTimeEstimate   string `json:"human_time_estimate"`
	HumanTotalTimeSpent string `json:"human_total_time_spent"`
}

// GitLabNote represents a GitLab issue note (comment)
type GitLabNote struct {
	ID        int         `json:"id"`
	Body      string      `json:"body"`
	CreatedAt time.Time   `json:"created_at"`
	Author    *GitLabUser `json:"author"`
	System    bool        `json:"system"`
}

// GitLabProject represents a GitLab project
type GitLabProject struct {
	ID                int    `json:"id"`
	Name              string `json:"name"`
	NameWithNamespace string `json:"name_with_namespace"`
	PathWithNamespace string `json:"path_with_namespace"`
	WebURL            string `json:"web_url"`
}

// makeRequest makes an HTTP request to the GitLab API
func (g *GitLabProvider) makeRequest(method, endpoint string, body interface{}) (*http.Response, error) {
	var reqBody io.Reader

	if body != nil {
		jsonData, err := json.Marshal(body)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal request body: %w", err)
		}
		reqBody = bytes.NewBuffer(jsonData)
	}

	url := fmt.Sprintf("%s/api/v4%s", g.baseURL, endpoint)
	req, err := http.NewRequest(method, url, reqBody)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	req.Header.Set("Private-Token", g.token)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")

	resp, err := g.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}

	return resp, nil
}
// GetTasks retrieves tasks (issues) from the GitLab project
func (g *GitLabProvider) GetTasks(projectID int) ([]*repository.Task, error) {
	// Build query parameters
	params := url.Values{}
	params.Add("state", "opened")
	params.Add("sort", "created_desc")
	params.Add("per_page", "100") // GitLab default is 20

	// Add task label filter if specified
	if g.config.TaskLabel != "" {
		params.Add("labels", g.config.TaskLabel)
	}

	endpoint := fmt.Sprintf("/projects/%s/issues?%s", g.projectID, params.Encode())

	resp, err := g.makeRequest("GET", endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to get issues: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
	}

	var issues []GitLabIssue
	if err := json.NewDecoder(resp.Body).Decode(&issues); err != nil {
		return nil, fmt.Errorf("failed to decode issues: %w", err)
	}

	// Convert GitLab issues to repository tasks
	tasks := make([]*repository.Task, 0, len(issues))
	for _, issue := range issues {
		task := g.issueToTask(&issue)
		tasks = append(tasks, task)
	}

	return tasks, nil
}

// ClaimTask claims a task by assigning it to the agent and adding in-progress label
func (g *GitLabProvider) ClaimTask(taskNumber int, agentID string) (bool, error) {
	// First, get the current issue to check its state
	endpoint := fmt.Sprintf("/projects/%s/issues/%d", g.projectID, taskNumber)

	resp, err := g.makeRequest("GET", endpoint, nil)
	if err != nil {
		return false, fmt.Errorf("failed to get issue: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return false, fmt.Errorf("issue not found or not accessible")
	}

	var issue GitLabIssue
	if err := json.NewDecoder(resp.Body).Decode(&issue); err != nil {
		return false, fmt.Errorf("failed to decode issue: %w", err)
	}

	// Check if issue is already assigned
	if issue.Assignee != nil || len(issue.Assignees) > 0 {
		assigneeName := ""
		if issue.Assignee != nil {
			assigneeName = issue.Assignee.Username
		} else if len(issue.Assignees) > 0 {
			assigneeName = issue.Assignees[0].Username
		}
		return false, fmt.Errorf("issue is already assigned to %s", assigneeName)
	}

	// Add in-progress label if specified
	if g.config.InProgressLabel != "" {
		err := g.addLabelToIssue(taskNumber, g.config.InProgressLabel)
		if err != nil {
			return false, fmt.Errorf("failed to add in-progress label: %w", err)
		}
	}

	// Add a note indicating the task has been claimed
	comment := fmt.Sprintf("🤖 **Task Claimed by CHORUS Agent**\n\nAgent ID: `%s` \nStatus: Processing \n\nThis task is now being handled automatically by the CHORUS autonomous agent system.", agentID)
	err = g.addNoteToIssue(taskNumber, comment)
	if err != nil {
		// Don't fail the claim if note fails
		fmt.Printf("Warning: failed to add claim note: %v\n", err)
	}

	return true, nil
}

// UpdateTaskStatus updates the status of a task
func (g *GitLabProvider) UpdateTaskStatus(task *repository.Task, status string, comment string) error {
	// Add a note with the status update
	statusComment := fmt.Sprintf("📊 **Status Update: %s**\n\n%s\n\n---\n*Updated by CHORUS Agent*", status, comment)

	err := g.addNoteToIssue(task.Number, statusComment)
	if err != nil {
		return fmt.Errorf("failed to add status note: %w", err)
	}

	return nil
}

// CompleteTask completes a task by updating status and adding completion comment
func (g *GitLabProvider) CompleteTask(task *repository.Task, result *repository.TaskResult) error {
	// Create completion comment with results
	var commentBuffer strings.Builder
	commentBuffer.WriteString("✅ **Task Completed Successfully**\n\n")
	commentBuffer.WriteString(fmt.Sprintf("**Result:** %s\n\n", result.Message))

	// Add metadata if available
	if result.Metadata != nil {
		commentBuffer.WriteString("## Execution Details\n\n")
		for key, value := range result.Metadata {
			// Format the metadata nicely
			switch key {
			case "duration":
				commentBuffer.WriteString(fmt.Sprintf("- ⏱️ **Duration:** %v\n", value))
			case "execution_type":
				commentBuffer.WriteString(fmt.Sprintf("- 🔧 **Execution Type:** %v\n", value))
			case "commands_executed":
				commentBuffer.WriteString(fmt.Sprintf("- 🖥️ **Commands Executed:** %v\n", value))
			case "files_generated":
				commentBuffer.WriteString(fmt.Sprintf("- 📄 **Files Generated:** %v\n", value))
			case "ai_provider":
				commentBuffer.WriteString(fmt.Sprintf("- 🤖 **AI Provider:** %v\n", value))
			case "ai_model":
				commentBuffer.WriteString(fmt.Sprintf("- 🧠 **AI Model:** %v\n", value))
			default:
				commentBuffer.WriteString(fmt.Sprintf("- **%s:** %v\n", key, value))
			}
		}
		commentBuffer.WriteString("\n")
	}

	commentBuffer.WriteString("---\n🤖 *Completed by CHORUS Autonomous Agent System*")

	// Add completion note
	err := g.addNoteToIssue(task.Number, commentBuffer.String())
	if err != nil {
		return fmt.Errorf("failed to add completion note: %w", err)
	}

	// Remove in-progress label and add completed label
	if g.config.InProgressLabel != "" {
		err := g.removeLabelFromIssue(task.Number, g.config.InProgressLabel)
		if err != nil {
			fmt.Printf("Warning: failed to remove in-progress label: %v\n", err)
		}
	}

	if g.config.CompletedLabel != "" {
		err := g.addLabelToIssue(task.Number, g.config.CompletedLabel)
		if err != nil {
			fmt.Printf("Warning: failed to add completed label: %v\n", err)
		}
	}

	// Close the issue if the task was successful
	if result.Success {
		err := g.closeIssue(task.Number)
		if err != nil {
			return fmt.Errorf("failed to close issue: %w", err)
		}
	}

	return nil
}
// GetTaskDetails retrieves detailed information about a specific task
func (g *GitLabProvider) GetTaskDetails(projectID int, taskNumber int) (*repository.Task, error) {
	endpoint := fmt.Sprintf("/projects/%s/issues/%d", g.projectID, taskNumber)

	resp, err := g.makeRequest("GET", endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to get issue: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("issue not found")
	}

	var issue GitLabIssue
	if err := json.NewDecoder(resp.Body).Decode(&issue); err != nil {
		return nil, fmt.Errorf("failed to decode issue: %w", err)
	}

	return g.issueToTask(&issue), nil
}

// ListAvailableTasks lists all available (unassigned) tasks
func (g *GitLabProvider) ListAvailableTasks(projectID int) ([]*repository.Task, error) {
	// Get open issues without assignees
	params := url.Values{}
	params.Add("state", "opened")
	params.Add("assignee_id", "None") // GitLab filter for unassigned issues
	params.Add("sort", "created_desc")
	params.Add("per_page", "100")

	if g.config.TaskLabel != "" {
		params.Add("labels", g.config.TaskLabel)
	}

	endpoint := fmt.Sprintf("/projects/%s/issues?%s", g.projectID, params.Encode())

	resp, err := g.makeRequest("GET", endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to get available issues: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
	}

	var issues []GitLabIssue
	if err := json.NewDecoder(resp.Body).Decode(&issues); err != nil {
		return nil, fmt.Errorf("failed to decode issues: %w", err)
	}

	// Convert to tasks
	tasks := make([]*repository.Task, 0, len(issues))
	for _, issue := range issues {
		// Double-check that issue is truly unassigned
		if issue.Assignee != nil || len(issue.Assignees) > 0 {
			continue
		}

		task := g.issueToTask(&issue)
		tasks = append(tasks, task)
	}

	return tasks, nil
}

// Helper methods

// issueToTask converts a GitLab issue to a repository Task
func (g *GitLabProvider) issueToTask(issue *GitLabIssue) *repository.Task {
	// Calculate priority and complexity based on labels and content
	priority := g.calculatePriority(issue.Labels, issue.Title, issue.Description)
	complexity := g.calculateComplexity(issue.Labels, issue.Title, issue.Description)

	// Determine required role and expertise from labels
	requiredRole := g.determineRequiredRole(issue.Labels)
	requiredExpertise := g.determineRequiredExpertise(issue.Labels)

	// Extract project name from projectID
	repositoryName := strings.Replace(g.projectID, "%2F", "/", -1) // URL decode

	return &repository.Task{
		Number:            issue.IID, // Use IID (project-specific ID) not global ID
		Title:             issue.Title,
		Body:              issue.Description,
		Repository:        repositoryName,
		Labels:            issue.Labels,
		Priority:          priority,
		Complexity:        complexity,
		Status:            issue.State,
		CreatedAt:         issue.CreatedAt,
		UpdatedAt:         issue.UpdatedAt,
		RequiredRole:      requiredRole,
		RequiredExpertise: requiredExpertise,
		Metadata: map[string]interface{}{
			"gitlab_id":  issue.ID,
			"gitlab_iid": issue.IID,
			"provider":   "gitlab",
			"project_id": issue.ProjectID,
			"web_url":    issue.WebURL,
			"assignee":   issue.Assignee,
			"assignees":  issue.Assignees,
			"author":     issue.Author,
			"time_stats": issue.TimeStats,
		},
	}
}
|
||||
// calculatePriority determines task priority from labels and content
|
||||
func (g *GitLabProvider) calculatePriority(labels []string, title, body string) int {
|
||||
priority := 5 // default
|
||||
|
||||
for _, label := range labels {
|
||||
labelLower := strings.ToLower(label)
|
||||
switch {
|
||||
case strings.Contains(labelLower, "priority") && strings.Contains(labelLower, "critical"):
|
||||
priority = 10
|
||||
case strings.Contains(labelLower, "priority") && strings.Contains(labelLower, "high"):
|
||||
priority = 8
|
||||
case strings.Contains(labelLower, "priority") && strings.Contains(labelLower, "medium"):
|
||||
priority = 5
|
||||
case strings.Contains(labelLower, "priority") && strings.Contains(labelLower, "low"):
|
||||
priority = 2
|
||||
case labelLower == "critical" || labelLower == "urgent":
|
||||
priority = 10
|
||||
case labelLower == "high":
|
||||
priority = 8
|
||||
case labelLower == "bug" || labelLower == "security" || labelLower == "hotfix":
|
||||
priority = max(priority, 7)
|
||||
case labelLower == "enhancement" || labelLower == "feature":
|
||||
priority = max(priority, 5)
|
||||
case strings.Contains(labelLower, "milestone"):
|
||||
priority = max(priority, 6)
|
||||
}
|
||||
}
|
||||
|
||||
// Boost priority for urgent keywords in title
|
||||
titleLower := strings.ToLower(title)
|
||||
urgentKeywords := []string{"urgent", "critical", "hotfix", "security", "broken", "crash", "blocker"}
|
||||
for _, keyword := range urgentKeywords {
|
||||
if strings.Contains(titleLower, keyword) {
|
||||
priority = max(priority, 8)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return priority
|
||||
}
|
||||
|
// calculateComplexity estimates task complexity from labels and content
func (g *GitLabProvider) calculateComplexity(labels []string, title, body string) int {
	complexity := 3 // default

	for _, label := range labels {
		labelLower := strings.ToLower(label)
		switch {
		case strings.Contains(labelLower, "complexity") && strings.Contains(labelLower, "high"):
			complexity = 8
		case strings.Contains(labelLower, "complexity") && strings.Contains(labelLower, "medium"):
			complexity = 5
		case strings.Contains(labelLower, "complexity") && strings.Contains(labelLower, "low"):
			complexity = 2
		case labelLower == "epic" || labelLower == "major":
			complexity = 8
		case labelLower == "refactor" || labelLower == "architecture":
			complexity = max(complexity, 7)
		case labelLower == "bug" || labelLower == "hotfix":
			complexity = max(complexity, 4)
		case labelLower == "enhancement" || labelLower == "feature":
			complexity = max(complexity, 5)
		case strings.Contains(labelLower, "beginner") || strings.Contains(labelLower, "newcomer"):
			complexity = 2
		case labelLower == "documentation" || labelLower == "docs":
			complexity = max(complexity, 3)
		}
	}

	// Estimate complexity from body length and content
	bodyLength := len(strings.Fields(body))
	if bodyLength > 500 {
		complexity = max(complexity, 7)
	} else if bodyLength > 200 {
		complexity = max(complexity, 5)
	} else if bodyLength > 50 {
		complexity = max(complexity, 4)
	}

	// Look for complexity indicators in content
	bodyLower := strings.ToLower(body)
	complexityIndicators := []string{
		"refactor", "architecture", "breaking change", "migration",
		"redesign", "database schema", "api changes", "infrastructure",
	}
	for _, indicator := range complexityIndicators {
		if strings.Contains(bodyLower, indicator) {
			complexity = max(complexity, 7)
			break
		}
	}

	return complexity
}
// determineRequiredRole determines what agent role is needed for this task
func (g *GitLabProvider) determineRequiredRole(labels []string) string {
	roleKeywords := map[string]string{
		// Frontend
		"frontend":   "frontend-developer",
		"ui":         "frontend-developer",
		"ux":         "ui-ux-designer",
		"css":        "frontend-developer",
		"html":       "frontend-developer",
		"javascript": "frontend-developer",
		"react":      "frontend-developer",
		"vue":        "frontend-developer",
		"angular":    "frontend-developer",

		// Backend
		"backend":  "backend-developer",
		"api":      "backend-developer",
		"server":   "backend-developer",
		"database": "backend-developer",
		"sql":      "backend-developer",

		// DevOps
		"devops":         "devops-engineer",
		"infrastructure": "devops-engineer",
		"deployment":     "devops-engineer",
		"docker":         "devops-engineer",
		"kubernetes":     "devops-engineer",
		"ci/cd":          "devops-engineer",
		"pipeline":       "devops-engineer",

		// Security
		"security":       "security-engineer",
		"authentication": "security-engineer",
		"authorization":  "security-engineer",
		"vulnerability":  "security-engineer",

		// Testing
		"testing": "tester",
		"qa":      "tester",
		"test":    "tester",

		// Documentation
		"documentation": "technical-writer",
		"docs":          "technical-writer",

		// Design
		"design":    "ui-ux-designer",
		"mockup":    "ui-ux-designer",
		"wireframe": "ui-ux-designer",
	}

	for _, label := range labels {
		labelLower := strings.ToLower(label)
		for keyword, role := range roleKeywords {
			if strings.Contains(labelLower, keyword) {
				return role
			}
		}
	}

	return "developer" // default role
}
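One subtlety in `determineRequiredRole`: it ranges over a Go map, and map iteration order is randomized, so a label that matches several keywords (say `frontend-testing`, which contains both `frontend` and `test`) can resolve to a different role on different runs. A sketch of a deterministic variant that checks keywords in sorted order (helper name and the abbreviated keyword table are ours):

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// roleForLabel checks keywords in sorted order so ambiguous labels
// always resolve to the same role across runs.
func roleForLabel(label string, roleKeywords map[string]string) string {
	keys := make([]string, 0, len(roleKeywords))
	for k := range roleKeywords {
		keys = append(keys, k)
	}
	sort.Strings(keys) // fixed order, unlike ranging over the map directly

	labelLower := strings.ToLower(label)
	for _, kw := range keys {
		if strings.Contains(labelLower, kw) {
			return roleKeywords[kw]
		}
	}
	return "developer"
}

func main() {
	kw := map[string]string{
		"frontend": "frontend-developer",
		"test":     "tester",
	}
	// "frontend" sorts before "test", so the ambiguous label resolves
	// the same way every run.
	fmt.Println(roleForLabel("frontend-testing", kw))
	fmt.Println(roleForLabel("api", kw)) // no match: default role
}
```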

// determineRequiredExpertise determines what expertise is needed
func (g *GitLabProvider) determineRequiredExpertise(labels []string) []string {
	expertise := make([]string, 0)
	expertiseMap := make(map[string]bool) // prevent duplicates

	expertiseKeywords := map[string][]string{
		// Programming languages
		"go":         {"go", "golang"},
		"python":     {"python"},
		"javascript": {"javascript", "js"},
		"typescript": {"typescript", "ts"},
		"java":       {"java"},
		"rust":       {"rust"},
		"c++":        {"c++", "cpp"},
		"c#":         {"c#", "csharp"},
		"php":        {"php"},
		"ruby":       {"ruby"},

		// Frontend technologies
		"react":   {"react"},
		"vue":     {"vue", "vuejs"},
		"angular": {"angular"},
		"svelte":  {"svelte"},

		// Backend frameworks
		"nodejs":  {"nodejs", "node.js", "node"},
		"django":  {"django"},
		"flask":   {"flask"},
		"spring":  {"spring"},
		"express": {"express"},

		// Databases
		"postgresql": {"postgresql", "postgres"},
		"mysql":      {"mysql"},
		"mongodb":    {"mongodb", "mongo"},
		"redis":      {"redis"},

		// DevOps tools
		"docker":     {"docker"},
		"kubernetes": {"kubernetes", "k8s"},
		"aws":        {"aws"},
		"azure":      {"azure"},
		"gcp":        {"gcp", "google cloud"},
		"gitlab-ci":  {"gitlab-ci", "ci/cd"},

		// Other technologies
		"graphql": {"graphql"},
		"rest":    {"rest", "restful"},
		"grpc":    {"grpc"},
	}

	for _, label := range labels {
		labelLower := strings.ToLower(label)
		for expertiseArea, keywords := range expertiseKeywords {
			for _, keyword := range keywords {
				if strings.Contains(labelLower, keyword) && !expertiseMap[expertiseArea] {
					expertise = append(expertise, expertiseArea)
					expertiseMap[expertiseArea] = true
					break
				}
			}
		}
	}

	// Default expertise if none detected
	if len(expertise) == 0 {
		expertise = []string{"development", "programming"}
	}

	return expertise
}

// addLabelToIssue adds a label to an issue
func (g *GitLabProvider) addLabelToIssue(issueNumber int, labelName string) error {
	// First get the current labels
	endpoint := fmt.Sprintf("/projects/%s/issues/%d", g.projectID, issueNumber)
	resp, err := g.makeRequest("GET", endpoint, nil)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("failed to get current issue labels")
	}

	var issue GitLabIssue
	if err := json.NewDecoder(resp.Body).Decode(&issue); err != nil {
		return fmt.Errorf("failed to decode issue: %w", err)
	}

	// Add the new label to the existing labels
	labels := append(issue.Labels, labelName)

	// Update the issue with the new label set
	body := map[string]interface{}{
		"labels": strings.Join(labels, ","),
	}

	resp, err = g.makeRequest("PUT", endpoint, body)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		respBody, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("failed to add label (status %d): %s", resp.StatusCode, string(respBody))
	}

	return nil
}
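The GET-then-PUT cycle above is a read-modify-write: a label added by another client between the two requests would be silently dropped when the full label set is written back. The GitLab issue-update endpoint also accepts `add_labels`/`remove_labels` parameters that let the server merge the label set in one request, avoiding the race; verify support against your GitLab version's API docs before relying on them. A sketch of the one-shot request body (helper name is ours):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// addLabelBody builds a single-request alternative to GET+PUT: the server
// merges "add_labels" into the existing set instead of the client replacing
// the whole list. (Parameter availability depends on the GitLab version.)
func addLabelBody(label string) ([]byte, error) {
	return json.Marshal(map[string]interface{}{"add_labels": label})
}

func main() {
	b, _ := addLabelBody("chorus-assigned")
	fmt.Println(string(b))
}
```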

// removeLabelFromIssue removes a label from an issue
func (g *GitLabProvider) removeLabelFromIssue(issueNumber int, labelName string) error {
	// First get the current labels
	endpoint := fmt.Sprintf("/projects/%s/issues/%d", g.projectID, issueNumber)
	resp, err := g.makeRequest("GET", endpoint, nil)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("failed to get current issue labels")
	}

	var issue GitLabIssue
	if err := json.NewDecoder(resp.Body).Decode(&issue); err != nil {
		return fmt.Errorf("failed to decode issue: %w", err)
	}

	// Remove the specified label
	var newLabels []string
	for _, label := range issue.Labels {
		if label != labelName {
			newLabels = append(newLabels, label)
		}
	}

	// Update the issue with the new label set
	body := map[string]interface{}{
		"labels": strings.Join(newLabels, ","),
	}

	resp, err = g.makeRequest("PUT", endpoint, body)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		respBody, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("failed to remove label (status %d): %s", resp.StatusCode, string(respBody))
	}

	return nil
}

// addNoteToIssue adds a note (comment) to an issue
func (g *GitLabProvider) addNoteToIssue(issueNumber int, note string) error {
	endpoint := fmt.Sprintf("/projects/%s/issues/%d/notes", g.projectID, issueNumber)

	body := map[string]interface{}{
		"body": note,
	}

	resp, err := g.makeRequest("POST", endpoint, body)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusCreated {
		respBody, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("failed to add note (status %d): %s", resp.StatusCode, string(respBody))
	}

	return nil
}

// closeIssue closes an issue
func (g *GitLabProvider) closeIssue(issueNumber int) error {
	endpoint := fmt.Sprintf("/projects/%s/issues/%d", g.projectID, issueNumber)

	body := map[string]interface{}{
		"state_event": "close",
	}

	resp, err := g.makeRequest("PUT", endpoint, body)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		respBody, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("failed to close issue (status %d): %s", resp.StatusCode, string(respBody))
	}

	return nil
}
pkg/providers/provider_test.go (new file, 698 lines)
@@ -0,0 +1,698 @@
package providers

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"

	"chorus/pkg/repository"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
// Test Gitea Provider
func TestGiteaProvider_NewGiteaProvider(t *testing.T) {
	tests := []struct {
		name        string
		config      *repository.Config
		expectError bool
		errorMsg    string
	}{
		{
			name: "valid config",
			config: &repository.Config{
				BaseURL:     "https://gitea.example.com",
				AccessToken: "test-token",
				Owner:       "testowner",
				Repository:  "testrepo",
			},
			expectError: false,
		},
		{
			name: "missing base URL",
			config: &repository.Config{
				AccessToken: "test-token",
				Owner:       "testowner",
				Repository:  "testrepo",
			},
			expectError: true,
			errorMsg:    "base URL is required",
		},
		{
			name: "missing access token",
			config: &repository.Config{
				BaseURL:    "https://gitea.example.com",
				Owner:      "testowner",
				Repository: "testrepo",
			},
			expectError: true,
			errorMsg:    "access token is required",
		},
		{
			name: "missing owner",
			config: &repository.Config{
				BaseURL:     "https://gitea.example.com",
				AccessToken: "test-token",
				Repository:  "testrepo",
			},
			expectError: true,
			errorMsg:    "owner is required",
		},
		{
			name: "missing repository",
			config: &repository.Config{
				BaseURL:     "https://gitea.example.com",
				AccessToken: "test-token",
				Owner:       "testowner",
			},
			expectError: true,
			errorMsg:    "repository name is required",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			provider, err := NewGiteaProvider(tt.config)

			if tt.expectError {
				assert.Error(t, err)
				assert.Contains(t, err.Error(), tt.errorMsg)
				assert.Nil(t, provider)
			} else {
				assert.NoError(t, err)
				assert.NotNil(t, provider)
				assert.Equal(t, tt.config.AccessToken, provider.token)
				assert.Equal(t, tt.config.Owner, provider.owner)
				assert.Equal(t, tt.config.Repository, provider.repo)
			}
		})
	}
}
func TestGiteaProvider_GetTasks(t *testing.T) {
	// Create a mock Gitea server
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "GET", r.Method)
		assert.Contains(t, r.URL.Path, "/api/v1/repos/testowner/testrepo/issues")
		assert.Equal(t, "token test-token", r.Header.Get("Authorization"))

		// Mock response
		issues := []map[string]interface{}{
			{
				"id":     1,
				"number": 42,
				"title":  "Test Issue 1",
				"body":   "This is a test issue",
				"state":  "open",
				"labels": []map[string]interface{}{
					{"id": 1, "name": "bug", "color": "d73a4a"},
				},
				"created_at": "2023-01-01T12:00:00Z",
				"updated_at": "2023-01-01T12:00:00Z",
				"repository": map[string]interface{}{
					"id":        1,
					"name":      "testrepo",
					"full_name": "testowner/testrepo",
				},
				"assignee":  nil,
				"assignees": []interface{}{},
			},
		}

		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(issues)
	}))
	defer server.Close()

	config := &repository.Config{
		BaseURL:     server.URL,
		AccessToken: "test-token",
		Owner:       "testowner",
		Repository:  "testrepo",
	}

	provider, err := NewGiteaProvider(config)
	require.NoError(t, err)

	tasks, err := provider.GetTasks(1)
	require.NoError(t, err)

	assert.Len(t, tasks, 1)
	assert.Equal(t, 42, tasks[0].Number)
	assert.Equal(t, "Test Issue 1", tasks[0].Title)
	assert.Equal(t, "This is a test issue", tasks[0].Body)
	assert.Equal(t, "testowner/testrepo", tasks[0].Repository)
	assert.Equal(t, []string{"bug"}, tasks[0].Labels)
}
// Test GitHub Provider
func TestGitHubProvider_NewGitHubProvider(t *testing.T) {
	tests := []struct {
		name        string
		config      *repository.Config
		expectError bool
		errorMsg    string
	}{
		{
			name: "valid config",
			config: &repository.Config{
				AccessToken: "test-token",
				Owner:       "testowner",
				Repository:  "testrepo",
			},
			expectError: false,
		},
		{
			name: "missing access token",
			config: &repository.Config{
				Owner:      "testowner",
				Repository: "testrepo",
			},
			expectError: true,
			errorMsg:    "access token is required",
		},
		{
			name: "missing owner",
			config: &repository.Config{
				AccessToken: "test-token",
				Repository:  "testrepo",
			},
			expectError: true,
			errorMsg:    "owner is required",
		},
		{
			name: "missing repository",
			config: &repository.Config{
				AccessToken: "test-token",
				Owner:       "testowner",
			},
			expectError: true,
			errorMsg:    "repository name is required",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			provider, err := NewGitHubProvider(tt.config)

			if tt.expectError {
				assert.Error(t, err)
				assert.Contains(t, err.Error(), tt.errorMsg)
				assert.Nil(t, provider)
			} else {
				assert.NoError(t, err)
				assert.NotNil(t, provider)
				assert.Equal(t, tt.config.AccessToken, provider.token)
				assert.Equal(t, tt.config.Owner, provider.owner)
				assert.Equal(t, tt.config.Repository, provider.repo)
			}
		})
	}
}
func TestGitHubProvider_GetTasks(t *testing.T) {
	// Create a mock GitHub server
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "GET", r.Method)
		assert.Contains(t, r.URL.Path, "/repos/testowner/testrepo/issues")
		assert.Equal(t, "token test-token", r.Header.Get("Authorization"))

		// Mock response (GitHub API format)
		issues := []map[string]interface{}{
			{
				"id":     123456789,
				"number": 42,
				"title":  "Test GitHub Issue",
				"body":   "This is a test GitHub issue",
				"state":  "open",
				"labels": []map[string]interface{}{
					{"id": 1, "name": "enhancement", "color": "a2eeef"},
				},
				"created_at": "2023-01-01T12:00:00Z",
				"updated_at": "2023-01-01T12:00:00Z",
				"assignee":   nil,
				"assignees":  []interface{}{},
				"user": map[string]interface{}{
					"id":    1,
					"login": "testuser",
					"name":  "Test User",
				},
				"pull_request": nil, // Not a PR
			},
		}

		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(issues)
	}))
	defer server.Close()

	config := &repository.Config{
		AccessToken: "test-token",
		Owner:       "testowner",
		Repository:  "testrepo",
		BaseURL:     server.URL, // not consulted by the real GitHub provider; kept for symmetry with the other tests
	}

	provider, err := NewGitHubProvider(config)
	require.NoError(t, err)

	// The GitHub API base URL is not injectable yet, so we cannot point the
	// provider at the test server; this value documents the intended wiring.
	testProvider := &GitHubProvider{
		config:     config,
		token:      config.AccessToken,
		owner:      config.Owner,
		repo:       config.Repository,
		httpClient: provider.httpClient,
	}
	_ = testProvider // silence "declared and not used" until the base URL becomes injectable

	// Without an injectable base URL we can only verify construction here.
	assert.Equal(t, "test-token", provider.token)
	assert.Equal(t, "testowner", provider.owner)
	assert.Equal(t, "testrepo", provider.repo)
}
// Test GitLab Provider
func TestGitLabProvider_NewGitLabProvider(t *testing.T) {
	tests := []struct {
		name        string
		config      *repository.Config
		expectError bool
		errorMsg    string
	}{
		{
			name: "valid config with owner/repo",
			config: &repository.Config{
				AccessToken: "test-token",
				Owner:       "testowner",
				Repository:  "testrepo",
			},
			expectError: false,
		},
		{
			name: "valid config with project ID",
			config: &repository.Config{
				AccessToken: "test-token",
				Settings: map[string]interface{}{
					"project_id": "123",
				},
			},
			expectError: false,
		},
		{
			name: "missing access token",
			config: &repository.Config{
				Owner:      "testowner",
				Repository: "testrepo",
			},
			expectError: true,
			errorMsg:    "access token is required",
		},
		{
			name: "missing owner/repo and project_id",
			config: &repository.Config{
				AccessToken: "test-token",
			},
			expectError: true,
			errorMsg:    "either owner/repository or project_id",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			provider, err := NewGitLabProvider(tt.config)

			if tt.expectError {
				assert.Error(t, err)
				assert.Contains(t, err.Error(), tt.errorMsg)
				assert.Nil(t, provider)
			} else {
				assert.NoError(t, err)
				assert.NotNil(t, provider)
				assert.Equal(t, tt.config.AccessToken, provider.token)
			}
		})
	}
}
// Test Provider Factory
func TestProviderFactory_CreateProvider(t *testing.T) {
	factory := NewProviderFactory()

	tests := []struct {
		name         string
		config       *repository.Config
		expectedType string
		expectError  bool
	}{
		{
			name: "create gitea provider",
			config: &repository.Config{
				Provider:    "gitea",
				BaseURL:     "https://gitea.example.com",
				AccessToken: "test-token",
				Owner:       "testowner",
				Repository:  "testrepo",
			},
			expectedType: "*providers.GiteaProvider",
			expectError:  false,
		},
		{
			name: "create github provider",
			config: &repository.Config{
				Provider:    "github",
				AccessToken: "test-token",
				Owner:       "testowner",
				Repository:  "testrepo",
			},
			expectedType: "*providers.GitHubProvider",
			expectError:  false,
		},
		{
			name: "create gitlab provider",
			config: &repository.Config{
				Provider:    "gitlab",
				AccessToken: "test-token",
				Owner:       "testowner",
				Repository:  "testrepo",
			},
			expectedType: "*providers.GitLabProvider",
			expectError:  false,
		},
		{
			name: "create mock provider",
			config: &repository.Config{
				Provider: "mock",
			},
			expectedType: "*repository.MockTaskProvider",
			expectError:  false,
		},
		{
			name: "unsupported provider",
			config: &repository.Config{
				Provider: "unsupported",
			},
			expectError: true,
		},
		{
			name:        "nil config",
			config:      nil,
			expectError: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			provider, err := factory.CreateProvider(nil, tt.config)

			if tt.expectError {
				assert.Error(t, err)
				assert.Nil(t, provider)
			} else {
				assert.NoError(t, err)
				assert.NotNil(t, provider)
				// Note: we don't assert the exact type here to avoid reflection;
				// a non-nil provider is sufficient for this test.
			}
		})
	}
}
func TestProviderFactory_ValidateConfig(t *testing.T) {
	factory := NewProviderFactory()

	tests := []struct {
		name        string
		config      *repository.Config
		expectError bool
	}{
		{
			name: "valid gitea config",
			config: &repository.Config{
				Provider:    "gitea",
				BaseURL:     "https://gitea.example.com",
				AccessToken: "test-token",
				Owner:       "testowner",
				Repository:  "testrepo",
			},
			expectError: false,
		},
		{
			name: "invalid gitea config - missing baseURL",
			config: &repository.Config{
				Provider:    "gitea",
				AccessToken: "test-token",
				Owner:       "testowner",
				Repository:  "testrepo",
			},
			expectError: true,
		},
		{
			name: "valid github config",
			config: &repository.Config{
				Provider:    "github",
				AccessToken: "test-token",
				Owner:       "testowner",
				Repository:  "testrepo",
			},
			expectError: false,
		},
		{
			name: "invalid github config - missing token",
			config: &repository.Config{
				Provider:   "github",
				Owner:      "testowner",
				Repository: "testrepo",
			},
			expectError: true,
		},
		{
			name: "valid mock config",
			config: &repository.Config{
				Provider: "mock",
			},
			expectError: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := factory.ValidateConfig(tt.config)

			if tt.expectError {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
func TestProviderFactory_GetSupportedTypes(t *testing.T) {
	factory := NewProviderFactory()
	types := factory.GetSupportedTypes()

	assert.Contains(t, types, "gitea")
	assert.Contains(t, types, "github")
	assert.Contains(t, types, "gitlab")
	assert.Contains(t, types, "mock")
	assert.Len(t, types, 4)
}
func TestProviderFactory_GetProviderInfo(t *testing.T) {
	factory := NewProviderFactory()

	info, err := factory.GetProviderInfo("gitea")
	require.NoError(t, err)
	assert.Equal(t, "Gitea", info.Name)
	assert.Equal(t, "gitea", info.Type)
	assert.Contains(t, info.RequiredFields, "baseURL")
	assert.Contains(t, info.RequiredFields, "accessToken")

	// Test unsupported provider
	_, err = factory.GetProviderInfo("unsupported")
	assert.Error(t, err)
}
// Test priority and complexity calculation
func TestPriorityComplexityCalculation(t *testing.T) {
	provider := &GiteaProvider{} // these methods can be tested on any provider

	tests := []struct {
		name               string
		labels             []string
		title              string
		body               string
		expectedPriority   int
		expectedComplexity int
	}{
		{
			name:               "critical bug",
			labels:             []string{"critical", "bug"},
			title:              "Critical security vulnerability",
			body:               "This is a critical security issue that needs immediate attention",
			expectedPriority:   10,
			expectedComplexity: 7,
		},
		{
			name:               "simple enhancement",
			labels:             []string{"enhancement", "good first issue"},
			title:              "Add help text to button",
			body:               "Small UI improvement",
			expectedPriority:   5,
			expectedComplexity: 2,
		},
		{
			name:               "complex refactor",
			labels:             []string{"refactor", "epic"},
			title:              "Refactor authentication system",
			body:               string(make([]byte, 1000)), // long body
			expectedPriority:   5,
			expectedComplexity: 8,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			priority := provider.calculatePriority(tt.labels, tt.title, tt.body)
			complexity := provider.calculateComplexity(tt.labels, tt.title, tt.body)

			assert.Equal(t, tt.expectedPriority, priority)
			assert.Equal(t, tt.expectedComplexity, complexity)
		})
	}
}
// Test role determination
func TestRoleDetermination(t *testing.T) {
	provider := &GiteaProvider{}

	tests := []struct {
		name         string
		labels       []string
		expectedRole string
	}{
		{
			name:         "frontend task",
			labels:       []string{"frontend", "ui"},
			expectedRole: "frontend-developer",
		},
		{
			name:         "backend task",
			labels:       []string{"backend", "api"},
			expectedRole: "backend-developer",
		},
		{
			name:         "devops task",
			labels:       []string{"devops", "deployment"},
			expectedRole: "devops-engineer",
		},
		{
			name:         "security task",
			labels:       []string{"security", "vulnerability"},
			expectedRole: "security-engineer",
		},
		{
			name:         "testing task",
			labels:       []string{"testing", "qa"},
			expectedRole: "tester",
		},
		{
			name:         "documentation task",
			labels:       []string{"documentation"},
			expectedRole: "technical-writer",
		},
		{
			name:         "design task",
			labels:       []string{"design", "mockup"},
			expectedRole: "ui-ux-designer",
		},
		{
			name:         "generic task",
			labels:       []string{"bug"},
			expectedRole: "developer",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			role := provider.determineRequiredRole(tt.labels)
			assert.Equal(t, tt.expectedRole, role)
		})
	}
}
// Test expertise determination
func TestExpertiseDetermination(t *testing.T) {
	provider := &GiteaProvider{}

	tests := []struct {
		name              string
		labels            []string
		expectedExpertise []string
	}{
		{
			name:              "go programming",
			labels:            []string{"go", "backend"},
			expectedExpertise: []string{"backend"},
		},
		{
			name:              "react frontend",
			labels:            []string{"react", "javascript"},
			expectedExpertise: []string{"javascript"},
		},
		{
			name:              "docker devops",
			labels:            []string{"docker", "kubernetes"},
			expectedExpertise: []string{"docker", "kubernetes"},
		},
		{
			name:              "no specific labels",
			labels:            []string{"bug", "minor"},
			expectedExpertise: []string{"development", "programming"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			expertise := provider.determineRequiredExpertise(tt.labels)
			// Check that all expected expertise areas are present
			for _, expected := range tt.expectedExpertise {
				assert.Contains(t, expertise, expected)
			}
		})
	}
}
// Benchmark tests
func BenchmarkGiteaProvider_CalculatePriority(b *testing.B) {
	provider := &GiteaProvider{}
	labels := []string{"critical", "bug", "security"}
	title := "Critical security vulnerability in authentication"
	body := "This is a detailed description of a critical security vulnerability that affects user authentication and needs immediate attention."

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		provider.calculatePriority(labels, title, body)
	}
}

func BenchmarkProviderFactory_CreateProvider(b *testing.B) {
	factory := NewProviderFactory()
	config := &repository.Config{
		Provider:    "mock",
		AccessToken: "test-token",
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		provider, err := factory.CreateProvider(nil, config)
		if err != nil {
			b.Fatalf("Failed to create provider: %v", err)
		}
		_ = provider
	}
}
@@ -147,17 +147,28 @@ func (m *DefaultTaskMatcher) ScoreTaskForAgent(task *Task, agentInfo *AgentInfo)
 }
 
 // DefaultProviderFactory provides a default implementation of ProviderFactory
-type DefaultProviderFactory struct{}
+// This is now a wrapper around the real provider factory
+type DefaultProviderFactory struct {
+	factory ProviderFactory
+}
 
-// CreateProvider creates a task provider (stub implementation)
+// NewDefaultProviderFactory creates a new default provider factory
+func NewDefaultProviderFactory() *DefaultProviderFactory {
+	// This will be replaced by importing the providers factory
+	// For now, return a stub that creates mock providers
+	return &DefaultProviderFactory{}
+}
+
+// CreateProvider creates a task provider
 func (f *DefaultProviderFactory) CreateProvider(ctx interface{}, config *Config) (TaskProvider, error) {
-	// In a real implementation, this would create GitHub, GitLab, etc. providers
+	// For backward compatibility, fall back to mock if no real factory is available
+	// In production, this should be replaced with the real provider factory
 	return &MockTaskProvider{}, nil
 }
 
 // GetSupportedTypes returns supported repository types
 func (f *DefaultProviderFactory) GetSupportedTypes() []string {
-	return []string{"github", "gitlab", "mock"}
+	return []string{"github", "gitlab", "gitea", "mock"}
 }
 
 // SupportedProviders returns list of supported providers
pkg/shhh/doc.go (new file, 11 lines)
@@ -0,0 +1,11 @@
// Package shhh provides the CHORUS secrets sentinel responsible for detecting
// and redacting sensitive values before they leave the runtime. The sentinel
// focuses on predictable failure modes (log emission, telemetry fan-out,
// request forwarding) and offers a composable API for registering additional
// redaction rules, emitting audit events, and tracking operational metrics.
//
// The initial implementation focuses on high-signal secrets (API keys,
// bearer/OAuth tokens, private keys) so the runtime can start integrating
// SHHH into COOEE and WHOOSH logging immediately while the broader roadmap
// items (automated redaction replay, policy-driven rules) continue landing.
package shhh
130
pkg/shhh/rule.go
Normal file
@@ -0,0 +1,130 @@
package shhh

import (
	"crypto/sha256"
	"encoding/base64"
	"regexp"
	"sort"
	"strings"
)

type compiledRule struct {
	name        string
	regex       *regexp.Regexp
	replacement string
	severity    Severity
	tags        []string
}

type matchRecord struct {
	value string
}

func (r *compiledRule) apply(in string) (string, []matchRecord) {
	indices := r.regex.FindAllStringSubmatchIndex(in, -1)
	if len(indices) == 0 {
		return in, nil
	}

	var builder strings.Builder
	builder.Grow(len(in))

	matches := make([]matchRecord, 0, len(indices))
	last := 0
	for _, loc := range indices {
		start, end := loc[0], loc[1]
		builder.WriteString(in[last:start])
		replaced := r.regex.ExpandString(nil, r.replacement, in, loc)
		builder.Write(replaced)
		matches = append(matches, matchRecord{value: in[start:end]})
		last = end
	}
	builder.WriteString(in[last:])

	return builder.String(), matches
}

func buildDefaultRuleConfigs(placeholder string) []RuleConfig {
	if placeholder == "" {
		placeholder = "[REDACTED]"
	}
	return []RuleConfig{
		{
			Name:                "bearer-token",
			Pattern:             `(?i)(authorization\s*:\s*bearer\s+)([A-Za-z0-9\-._~+/]+=*)`,
			ReplacementTemplate: "$1" + placeholder,
			Severity:            SeverityMedium,
			Tags:                []string{"token", "http"},
		},
		{
			Name:                "api-key",
			Pattern:             `(?i)((?:api[_-]?key|token|secret|password)\s*[:=]\s*["']?)([A-Za-z0-9\-._~+/]{8,})(["']?)`,
			ReplacementTemplate: "$1" + placeholder + "$3",
			Severity:            SeverityHigh,
			Tags:                []string{"credentials"},
		},
		{
			Name:                "openai-secret",
			Pattern:             `(sk-[A-Za-z0-9]{20,})`,
			ReplacementTemplate: placeholder,
			Severity:            SeverityHigh,
			Tags:                []string{"llm", "api"},
		},
		{
			Name:                "oauth-refresh-token",
			Pattern:             `(?i)(refresh_token"?\s*[:=]\s*["']?)([A-Za-z0-9\-._~+/]{8,})(["']?)`,
			ReplacementTemplate: "$1" + placeholder + "$3",
			Severity:            SeverityMedium,
			Tags:                []string{"oauth"},
		},
		{
			Name:                "private-key-block",
			Pattern:             `(?s)(-----BEGIN [^-]+ PRIVATE KEY-----)[^-]+(-----END [^-]+ PRIVATE KEY-----)`,
			ReplacementTemplate: "$1\n" + placeholder + "\n$2",
			Severity:            SeverityHigh,
			Tags:                []string{"pem", "key"},
		},
	}
}

func compileRules(cfg Config, placeholder string) ([]*compiledRule, error) {
	configs := make([]RuleConfig, 0)
	if !cfg.DisableDefaultRules {
		configs = append(configs, buildDefaultRuleConfigs(placeholder)...)
	}
	configs = append(configs, cfg.CustomRules...)

	rules := make([]*compiledRule, 0, len(configs))
	for _, rc := range configs {
		if rc.Name == "" || rc.Pattern == "" {
			continue
		}
		replacement := rc.ReplacementTemplate
		if replacement == "" {
			replacement = placeholder
		}
		re, err := regexp.Compile(rc.Pattern)
		if err != nil {
			return nil, err
		}
		compiled := &compiledRule{
			name:        rc.Name,
			replacement: replacement,
			regex:       re,
			severity:    rc.Severity,
			tags:        append([]string(nil), rc.Tags...),
		}
		rules = append(rules, compiled)
	}

	sort.SliceStable(rules, func(i, j int) bool {
		return rules[i].name < rules[j].name
	})

	return rules, nil
}

func hashSecret(value string) string {
	sum := sha256.Sum256([]byte(value))
	return base64.RawStdEncoding.EncodeToString(sum[:])
}
407
pkg/shhh/sentinel.go
Normal file
@@ -0,0 +1,407 @@
package shhh

import (
	"context"
	"errors"
	"fmt"
	"sort"
	"sync"
)

// Option configures the sentinel during construction.
type Option func(*Sentinel)

// FindingObserver receives aggregated findings for each redaction operation.
type FindingObserver func(context.Context, []Finding)

// WithAuditSink attaches an audit sink for per-redaction events.
func WithAuditSink(sink AuditSink) Option {
	return func(s *Sentinel) {
		s.audit = sink
	}
}

// WithStats allows callers to supply a shared stats collector.
func WithStats(stats *Stats) Option {
	return func(s *Sentinel) {
		s.stats = stats
	}
}

// WithFindingObserver registers an observer that is invoked whenever redaction
// produces findings.
func WithFindingObserver(observer FindingObserver) Option {
	return func(s *Sentinel) {
		if observer == nil {
			return
		}
		s.observers = append(s.observers, observer)
	}
}

// Sentinel performs secret detection/redaction across text payloads.
type Sentinel struct {
	mu          sync.RWMutex
	enabled     bool
	placeholder string
	rules       []*compiledRule
	audit       AuditSink
	stats       *Stats
	observers   []FindingObserver
}

// NewSentinel creates a new secrets sentinel using the provided configuration.
func NewSentinel(cfg Config, opts ...Option) (*Sentinel, error) {
	placeholder := cfg.RedactionPlaceholder
	if placeholder == "" {
		placeholder = "[REDACTED]"
	}

	s := &Sentinel{
		enabled:     !cfg.Disabled,
		placeholder: placeholder,
		stats:       NewStats(),
	}
	for _, opt := range opts {
		opt(s)
	}
	if s.stats == nil {
		s.stats = NewStats()
	}

	rules, err := compileRules(cfg, placeholder)
	if err != nil {
		return nil, fmt.Errorf("compile SHHH rules: %w", err)
	}
	if len(rules) == 0 {
		return nil, errors.New("no SHHH rules configured")
	}
	s.rules = rules

	return s, nil
}

// Enabled reports whether the sentinel is actively redacting.
func (s *Sentinel) Enabled() bool {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.enabled
}

// Toggle enables or disables the sentinel at runtime.
func (s *Sentinel) Toggle(enabled bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.enabled = enabled
}

// SetAuditSink updates the audit sink at runtime.
func (s *Sentinel) SetAuditSink(sink AuditSink) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.audit = sink
}

// AddFindingObserver registers an observer after construction.
func (s *Sentinel) AddFindingObserver(observer FindingObserver) {
	if observer == nil {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.observers = append(s.observers, observer)
}

// StatsSnapshot returns a snapshot of the current counters.
func (s *Sentinel) StatsSnapshot() StatsSnapshot {
	s.mu.RLock()
	stats := s.stats
	s.mu.RUnlock()
	if stats == nil {
		return StatsSnapshot{}
	}
	return stats.Snapshot()
}

// RedactText scans the provided text and redacts any findings.
func (s *Sentinel) RedactText(ctx context.Context, text string, labels map[string]string) (string, []Finding) {
	s.mu.RLock()
	enabled := s.enabled
	rules := s.rules
	stats := s.stats
	audit := s.audit
	s.mu.RUnlock()

	if !enabled || len(rules) == 0 {
		return text, nil
	}
	if stats != nil {
		stats.IncScan()
	}

	aggregates := make(map[string]*findingAggregate)
	current := text
	path := derivePath(labels)

	for _, rule := range rules {
		redacted, matches := rule.apply(current)
		if len(matches) == 0 {
			continue
		}
		current = redacted
		if stats != nil {
			stats.AddFindings(rule.name, len(matches))
		}
		recordAggregate(aggregates, rule, path, len(matches))

		if audit != nil {
			metadata := cloneLabels(labels)
			for _, match := range matches {
				event := AuditEvent{
					Rule:     rule.name,
					Severity: rule.severity,
					Tags:     append([]string(nil), rule.tags...),
					Path:     path,
					Hash:     hashSecret(match.value),
					Metadata: metadata,
				}
				audit.RecordRedaction(ctx, event)
			}
		}
	}

	findings := flattenAggregates(aggregates)
	s.notifyObservers(ctx, findings)
	return current, findings
}

// RedactMap walks the map and redacts in-place. It returns the collected findings.
func (s *Sentinel) RedactMap(ctx context.Context, payload map[string]any) []Finding {
	return s.RedactMapWithLabels(ctx, payload, nil)
}

// RedactMapWithLabels allows callers to specify base labels that will be merged
// into metadata for nested structures.
func (s *Sentinel) RedactMapWithLabels(ctx context.Context, payload map[string]any, baseLabels map[string]string) []Finding {
	if payload == nil {
		return nil
	}

	aggregates := make(map[string]*findingAggregate)
	s.redactValue(ctx, payload, "", baseLabels, aggregates)
	findings := flattenAggregates(aggregates)
	s.notifyObservers(ctx, findings)
	return findings
}

func (s *Sentinel) redactValue(ctx context.Context, value any, path string, baseLabels map[string]string, agg map[string]*findingAggregate) {
	switch v := value.(type) {
	case map[string]interface{}:
		for key, val := range v {
			childPath := joinPath(path, key)
			switch typed := val.(type) {
			case string:
				labels := mergeLabels(baseLabels, childPath)
				redacted, findings := s.RedactText(ctx, typed, labels)
				if redacted != typed {
					v[key] = redacted
				}
				mergeAggregates(agg, findings)
			case fmt.Stringer:
				labels := mergeLabels(baseLabels, childPath)
				text := typed.String()
				redacted, findings := s.RedactText(ctx, text, labels)
				if redacted != text {
					v[key] = redacted
				}
				mergeAggregates(agg, findings)
			default:
				s.redactValue(ctx, typed, childPath, baseLabels, agg)
			}
		}
	case []interface{}:
		for idx, item := range v {
			childPath := indexPath(path, idx)
			switch typed := item.(type) {
			case string:
				labels := mergeLabels(baseLabels, childPath)
				redacted, findings := s.RedactText(ctx, typed, labels)
				if redacted != typed {
					v[idx] = redacted
				}
				mergeAggregates(agg, findings)
			case fmt.Stringer:
				labels := mergeLabels(baseLabels, childPath)
				text := typed.String()
				redacted, findings := s.RedactText(ctx, text, labels)
				if redacted != text {
					v[idx] = redacted
				}
				mergeAggregates(agg, findings)
			default:
				s.redactValue(ctx, typed, childPath, baseLabels, agg)
			}
		}
	case []string:
		for idx, item := range v {
			childPath := indexPath(path, idx)
			labels := mergeLabels(baseLabels, childPath)
			redacted, findings := s.RedactText(ctx, item, labels)
			if redacted != item {
				v[idx] = redacted
			}
			mergeAggregates(agg, findings)
		}
	}
}

func (s *Sentinel) notifyObservers(ctx context.Context, findings []Finding) {
	if len(findings) == 0 {
		return
	}
	findingsCopy := append([]Finding(nil), findings...)
	s.mu.RLock()
	observers := append([]FindingObserver(nil), s.observers...)
	s.mu.RUnlock()
	for _, observer := range observers {
		observer(ctx, findingsCopy)
	}
}

func mergeAggregates(dest map[string]*findingAggregate, findings []Finding) {
	for i := range findings {
		f := findings[i]
		agg := dest[f.Rule]
		if agg == nil {
			agg = &findingAggregate{
				rule:      f.Rule,
				severity:  f.Severity,
				tags:      append([]string(nil), f.Tags...),
				locations: make(map[string]int),
			}
			dest[f.Rule] = agg
		}
		agg.count += f.Count
		for _, loc := range f.Locations {
			agg.locations[loc.Path] += loc.Count
		}
	}
}

func recordAggregate(dest map[string]*findingAggregate, rule *compiledRule, path string, count int) {
	agg := dest[rule.name]
	if agg == nil {
		agg = &findingAggregate{
			rule:      rule.name,
			severity:  rule.severity,
			tags:      append([]string(nil), rule.tags...),
			locations: make(map[string]int),
		}
		dest[rule.name] = agg
	}
	agg.count += count
	if path != "" {
		agg.locations[path] += count
	}
}

func flattenAggregates(agg map[string]*findingAggregate) []Finding {
	if len(agg) == 0 {
		return nil
	}
	keys := make([]string, 0, len(agg))
	for key := range agg {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	findings := make([]Finding, 0, len(agg))
	for _, key := range keys {
		entry := agg[key]
		locations := make([]Location, 0, len(entry.locations))
		if len(entry.locations) > 0 {
			paths := make([]string, 0, len(entry.locations))
			for path := range entry.locations {
				paths = append(paths, path)
			}
			sort.Strings(paths)
			for _, path := range paths {
				locations = append(locations, Location{Path: path, Count: entry.locations[path]})
			}
		}
		findings = append(findings, Finding{
			Rule:      entry.rule,
			Severity:  entry.severity,
			Tags:      append([]string(nil), entry.tags...),
			Count:     entry.count,
			Locations: locations,
		})
	}
	return findings
}

func derivePath(labels map[string]string) string {
	if labels == nil {
		return ""
	}
	if path := labels["path"]; path != "" {
		return path
	}
	if path := labels["source"]; path != "" {
		return path
	}
	if path := labels["field"]; path != "" {
		return path
	}
	return ""
}

func cloneLabels(labels map[string]string) map[string]string {
	if len(labels) == 0 {
		return nil
	}
	clone := make(map[string]string, len(labels))
	for k, v := range labels {
		clone[k] = v
	}
	return clone
}

func joinPath(prefix, key string) string {
	if prefix == "" {
		return key
	}
	if key == "" {
		return prefix
	}
	return prefix + "." + key
}

func indexPath(prefix string, idx int) string {
	if prefix == "" {
		return fmt.Sprintf("[%d]", idx)
	}
	return fmt.Sprintf("%s[%d]", prefix, idx)
}

func mergeLabels(base map[string]string, path string) map[string]string {
	if base == nil && path == "" {
		return nil
	}
	labels := cloneLabels(base)
	if labels == nil {
		labels = make(map[string]string, 1)
	}
	if path != "" {
		labels["path"] = path
	}
	return labels
}

type findingAggregate struct {
	rule      string
	severity  Severity
	tags      []string
	count     int
	locations map[string]int
}
95
pkg/shhh/sentinel_test.go
Normal file
@@ -0,0 +1,95 @@
package shhh

import (
	"context"
	"testing"

	"github.com/stretchr/testify/require"
)

type recordingSink struct {
	events []AuditEvent
}

func (r *recordingSink) RecordRedaction(_ context.Context, event AuditEvent) {
	r.events = append(r.events, event)
}

func TestRedactText_DefaultRules(t *testing.T) {
	sentinel, err := NewSentinel(Config{})
	require.NoError(t, err)

	input := "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.secret"
	redacted, findings := sentinel.RedactText(context.Background(), input, map[string]string{"source": "http.request.headers.authorization"})

	require.Equal(t, "Authorization: Bearer [REDACTED]", redacted)
	require.Len(t, findings, 1)
	require.Equal(t, "bearer-token", findings[0].Rule)
	require.Equal(t, 1, findings[0].Count)
	require.NotEmpty(t, findings[0].Locations)

	snapshot := sentinel.StatsSnapshot()
	require.Equal(t, uint64(1), snapshot.TotalScans)
	require.Equal(t, uint64(1), snapshot.TotalFindings)
	require.Equal(t, uint64(1), snapshot.PerRuleFindings["bearer-token"])
}

func TestRedactMap_NestedStructures(t *testing.T) {
	sentinel, err := NewSentinel(Config{})
	require.NoError(t, err)

	payload := map[string]any{
		"config": map[string]any{
			"api_key": "API_KEY=1234567890ABCDEFG",
		},
		"tokens": []any{
			"sk-test1234567890ABCDEF",
			map[string]any{"refresh": "refresh_token=abcdef12345"},
		},
	}

	findings := sentinel.RedactMap(context.Background(), payload)
	require.NotEmpty(t, findings)

	config := payload["config"].(map[string]any)
	require.Equal(t, "API_KEY=[REDACTED]", config["api_key"])

	tokens := payload["tokens"].([]any)
	require.Equal(t, "[REDACTED]", tokens[0])

	inner := tokens[1].(map[string]any)
	require.Equal(t, "refresh_token=[REDACTED]", inner["refresh"])

	total := 0
	for _, finding := range findings {
		total += finding.Count
	}
	require.Equal(t, 3, total)
}

func TestAuditSinkReceivesEvents(t *testing.T) {
	sink := &recordingSink{}
	cfg := Config{
		DisableDefaultRules: true,
		CustomRules: []RuleConfig{
			{
				Name:                "custom-secret",
				Pattern:             `(secret\s*=\s*)([A-Za-z0-9]{6,})`,
				ReplacementTemplate: "$1[REDACTED]",
				Severity:            SeverityHigh,
			},
		},
	}

	sentinel, err := NewSentinel(cfg, WithAuditSink(sink))
	require.NoError(t, err)

	_, findings := sentinel.RedactText(context.Background(), "secret=mysecretvalue", map[string]string{"source": "test"})
	require.Len(t, findings, 1)
	require.Equal(t, 1, findings[0].Count)

	require.Len(t, sink.events, 1)
	require.Equal(t, "custom-secret", sink.events[0].Rule)
	require.NotEmpty(t, sink.events[0].Hash)
	require.Equal(t, "test", sink.events[0].Path)
}
60
pkg/shhh/stats.go
Normal file
@@ -0,0 +1,60 @@
package shhh

import (
	"sync"
	"sync/atomic"
)

// Stats tracks aggregate counts for the sentinel.
type Stats struct {
	totalScans    atomic.Uint64
	totalFindings atomic.Uint64
	perRule       sync.Map // string -> *atomic.Uint64
}

// NewStats constructs a Stats collector.
func NewStats() *Stats {
	return &Stats{}
}

// IncScan increments the total scan counter.
func (s *Stats) IncScan() {
	if s == nil {
		return
	}
	s.totalScans.Add(1)
}

// AddFindings records findings for a rule.
func (s *Stats) AddFindings(rule string, count int) {
	if s == nil || count <= 0 {
		return
	}
	s.totalFindings.Add(uint64(count))
	counterAny, _ := s.perRule.LoadOrStore(rule, new(atomic.Uint64))
	counter := counterAny.(*atomic.Uint64)
	counter.Add(uint64(count))
}

// Snapshot returns a point-in-time view of the counters.
func (s *Stats) Snapshot() StatsSnapshot {
	if s == nil {
		return StatsSnapshot{}
	}
	snapshot := StatsSnapshot{
		TotalScans:      s.totalScans.Load(),
		TotalFindings:   s.totalFindings.Load(),
		PerRuleFindings: make(map[string]uint64),
	}
	s.perRule.Range(func(key, value any) bool {
		name, ok := key.(string)
		if !ok {
			return true
		}
		if counter, ok := value.(*atomic.Uint64); ok {
			snapshot.PerRuleFindings[name] = counter.Load()
		}
		return true
	})
	return snapshot
}
73
pkg/shhh/types.go
Normal file
@@ -0,0 +1,73 @@
package shhh

import "context"

// Severity represents the criticality associated with a redaction finding.
type Severity string

const (
	// SeverityLow indicates low-impact findings (e.g. non-production credentials).
	SeverityLow Severity = "low"
	// SeverityMedium indicates medium-impact findings (e.g. access tokens).
	SeverityMedium Severity = "medium"
	// SeverityHigh indicates high-impact findings (e.g. private keys).
	SeverityHigh Severity = "high"
)

// RuleConfig defines a redaction rule that SHHH should enforce.
type RuleConfig struct {
	Name                string   `json:"name"`
	Pattern             string   `json:"pattern"`
	ReplacementTemplate string   `json:"replacement_template"`
	Severity            Severity `json:"severity"`
	Tags                []string `json:"tags"`
}

// Config controls sentinel behaviour.
type Config struct {
	// Disabled toggles redaction off entirely.
	Disabled bool `json:"disabled"`
	// RedactionPlaceholder overrides the default placeholder value.
	RedactionPlaceholder string `json:"redaction_placeholder"`
	// DisableDefaultRules disables the built-in curated rule set.
	DisableDefaultRules bool `json:"disable_default_rules"`
	// CustomRules allows callers to append bespoke redaction patterns.
	CustomRules []RuleConfig `json:"custom_rules"`
}

// Finding represents a single rule firing during redaction.
type Finding struct {
	Rule      string     `json:"rule"`
	Severity  Severity   `json:"severity"`
	Tags      []string   `json:"tags,omitempty"`
	Count     int        `json:"count"`
	Locations []Location `json:"locations,omitempty"`
}

// Location describes where a secret was found.
type Location struct {
	Path  string `json:"path"`
	Count int    `json:"count"`
}

// StatsSnapshot exposes aggregate counters for observability.
type StatsSnapshot struct {
	TotalScans      uint64            `json:"total_scans"`
	TotalFindings   uint64            `json:"total_findings"`
	PerRuleFindings map[string]uint64 `json:"per_rule_findings"`
}

// AuditEvent captures a single redaction occurrence for downstream sinks.
type AuditEvent struct {
	Rule     string            `json:"rule"`
	Severity Severity          `json:"severity"`
	Tags     []string          `json:"tags,omitempty"`
	Path     string            `json:"path,omitempty"`
	Hash     string            `json:"hash"`
	Metadata map[string]string `json:"metadata,omitempty"`
}

// AuditSink receives redaction events for long-term storage / replay.
type AuditSink interface {
	RecordRedaction(ctx context.Context, event AuditEvent)
}
@@ -13,11 +13,11 @@ import (

// DecisionPublisher handles publishing task completion decisions to encrypted DHT storage
type DecisionPublisher struct {
	ctx        context.Context
	config     *config.Config
	dhtStorage storage.UCXLStorage
	nodeID     string
	agentName  string
}

// NewDecisionPublisher creates a new decision publisher
@@ -39,28 +39,28 @@ func NewDecisionPublisher(

// TaskDecision represents a decision made by an agent upon task completion
type TaskDecision struct {
	Agent         string                 `json:"agent"`
	Role          string                 `json:"role"`
	Project       string                 `json:"project"`
	Task          string                 `json:"task"`
	Decision      string                 `json:"decision"`
	Context       map[string]interface{} `json:"context"`
	Timestamp     time.Time              `json:"timestamp"`
	Success       bool                   `json:"success"`
	ErrorMessage  string                 `json:"error_message,omitempty"`
	FilesModified []string               `json:"files_modified,omitempty"`
	LinesChanged  int                    `json:"lines_changed,omitempty"`
	TestResults   *TestResults           `json:"test_results,omitempty"`
	Dependencies  []string               `json:"dependencies,omitempty"`
	NextSteps     []string               `json:"next_steps,omitempty"`
}

// TestResults captures test execution results
type TestResults struct {
	Passed      int      `json:"passed"`
	Failed      int      `json:"failed"`
	Skipped     int      `json:"skipped"`
	Coverage    float64  `json:"coverage,omitempty"`
	FailedTests []string `json:"failed_tests,omitempty"`
}

@@ -74,7 +74,11 @@ func (dp *DecisionPublisher) PublishTaskDecision(decision *TaskDecision) error {
		decision.Role = dp.config.Agent.Role
	}
	if decision.Project == "" {
		if project := dp.config.Agent.Project; project != "" {
			decision.Project = project
		} else {
			decision.Project = "chorus"
		}
	}
	if decision.Timestamp.IsZero() {
		decision.Timestamp = time.Now()
@@ -173,16 +177,16 @@ func (dp *DecisionPublisher) PublishArchitecturalDecision(
	nextSteps []string,
) error {
	taskDecision := &TaskDecision{
		Task:      taskName,
		Decision:  decision,
		Success:   true,
		NextSteps: nextSteps,
		Context: map[string]interface{}{
			"decision_type": "architecture",
			"rationale":     rationale,
			"alternatives":  alternatives,
			"implications":  implications,
			"node_id":       dp.nodeID,
		},
	}

@@ -291,7 +295,7 @@ func (dp *DecisionPublisher) SubscribeToDecisions(
) error {
	// This is a placeholder for future pubsub implementation
	// For now, we'll implement a simple polling mechanism

	go func() {
		ticker := time.NewTicker(30 * time.Second)
		defer ticker.Stop()
@@ -341,10 +345,10 @@ func (dp *DecisionPublisher) PublishSystemStatus(
		Decision: status,
		Success:  dp.allHealthChecksPass(healthChecks),
		Context: map[string]interface{}{
			"decision_type": "system",
			"metrics":       metrics,
			"health_checks": healthChecks,
			"node_id":       dp.nodeID,
		},
	}

@@ -364,13 +368,17 @@ func (dp *DecisionPublisher) allHealthChecksPass(healthChecks map[string]bool) b
// GetPublisherMetrics returns metrics about the decision publisher
func (dp *DecisionPublisher) GetPublisherMetrics() map[string]interface{} {
	dhtMetrics := dp.dhtStorage.GetMetrics()

	project := dp.config.Agent.Project
	if project == "" {
		project = "chorus"
	}

	return map[string]interface{}{
		"node_id":      dp.nodeID,
		"agent_name":   dp.agentName,
		"current_role": dp.config.Agent.Role,
		"project":      project,
		"dht_metrics":  dhtMetrics,
		"last_publish": time.Now(), // This would be tracked in a real implementation
	}
}

443
pubsub/pubsub.go
@@ -8,9 +8,10 @@ import (
	"sync"
	"time"

	"chorus/pkg/shhh"
	pubsub "github.com/libp2p/go-libp2p-pubsub"
	"github.com/libp2p/go-libp2p/core/host"
	"github.com/libp2p/go-libp2p/core/peer"
)

// PubSub handles publish/subscribe messaging for Bzzz coordination and HMMM meta-discussion
@@ -19,36 +20,42 @@ type PubSub struct {
	host   host.Host
	ctx    context.Context
	cancel context.CancelFunc

	// Topic subscriptions
	chorusTopic  *pubsub.Topic
	hmmmTopic    *pubsub.Topic
	contextTopic *pubsub.Topic

	// Message subscriptions
	chorusSub  *pubsub.Subscription
	hmmmSub    *pubsub.Subscription
	contextSub *pubsub.Subscription

	// Dynamic topic management
	dynamicTopics      map[string]*pubsub.Topic
	dynamicTopicsMux   sync.RWMutex
	dynamicSubs        map[string]*pubsub.Subscription
	dynamicSubsMux     sync.RWMutex
	dynamicHandlers    map[string]func([]byte, peer.ID)
	dynamicHandlersMux sync.RWMutex

	// Configuration
	chorusTopicName  string
	hmmmTopicName    string
	contextTopicName string

	// External message handler for HMMM messages
	HmmmMessageHandler func(msg Message, from peer.ID)

	// External message handler for Context Feedback messages
	ContextFeedbackHandler func(msg Message, from peer.ID)

	// Hypercore-style logging
	hypercoreLog HypercoreLogger

	// SHHH sentinel
	redactor    *shhh.Sentinel
	redactorMux sync.RWMutex
}

// HypercoreLogger interface for dependency injection
@@ -62,45 +69,45 @@ type MessageType string

const (
	// Bzzz coordination messages
	TaskAnnouncement  MessageType = "task_announcement"
	TaskClaim         MessageType = "task_claim"
	TaskProgress      MessageType = "task_progress"
	TaskComplete      MessageType = "task_complete"
	CapabilityBcast   MessageType = "capability_broadcast"   // Only broadcast when capabilities change
	AvailabilityBcast MessageType = "availability_broadcast" // Regular availability status

	// HMMM meta-discussion messages
	MetaDiscussion       MessageType = "meta_discussion"       // Generic type for all discussion
	TaskHelpRequest      MessageType = "task_help_request"     // Request for assistance
	TaskHelpResponse     MessageType = "task_help_response"    // Response to a help request
	CoordinationRequest  MessageType = "coordination_request"  // Request for coordination
	CoordinationComplete MessageType = "coordination_complete" // Coordination session completed
	DependencyAlert      MessageType = "dependency_alert"      // Dependency detected
	EscalationTrigger    MessageType = "escalation_trigger"    // Human escalation needed

	// Role-based collaboration messages
	RoleAnnouncement   MessageType = "role_announcement"   // Agent announces its role and capabilities
	ExpertiseRequest   MessageType = "expertise_request"   // Request for specific expertise
	ExpertiseResponse  MessageType = "expertise_response"  // Response offering expertise
	StatusUpdate       MessageType = "status_update"       // Regular status updates from agents
	WorkAllocation     MessageType = "work_allocation"     // Allocation of work to specific roles
	RoleCollaboration  MessageType = "role_collaboration"  // Cross-role collaboration message
	MentorshipRequest  MessageType = "mentorship_request"  // Junior role requesting mentorship
	MentorshipResponse MessageType = "mentorship_response" // Senior role providing mentorship
	ProjectUpdate      MessageType = "project_update"      // Project-level status updates
	DeliverableReady   MessageType = "deliverable_ready"   // Notification that deliverable is complete

	// RL Context Curator feedback messages
	FeedbackEvent    MessageType = "feedback_event"    // Context feedback for RL learning
	ContextRequest   MessageType = "context_request"   // Request context from HCFS
	ContextResponse  MessageType = "context_response"  // Response with context data
	ContextUsage     MessageType = "context_usage"     // Report context usage patterns
	ContextRelevance MessageType = "context_relevance" // Report context relevance scoring

	// SLURP event integration messages
	SlurpEventGenerated MessageType = "slurp_event_generated" // HMMM consensus generated SLURP event
	SlurpEventAck       MessageType = "slurp_event_ack"       // Acknowledgment of SLURP event receipt
	SlurpContextUpdate  MessageType = "slurp_context_update"  // Context update from SLURP system
)

// Message represents a Bzzz/Antennae message
@@ -110,14 +117,14 @@ type Message struct {
	Timestamp time.Time              `json:"timestamp"`
	Data      map[string]interface{} `json:"data"`
	HopCount  int                    `json:"hop_count,omitempty"` // For Antennae hop limiting

	// Role-based collaboration fields
	FromRole          string   `json:"from_role,omitempty"`          // Role of sender
	ToRoles           []string `json:"to_roles,omitempty"`           // Target roles
	RequiredExpertise []string `json:"required_expertise,omitempty"` // Required expertise areas
	ProjectID         string   `json:"project_id,omitempty"`         // Associated project
	Priority          string   `json:"priority,omitempty"`           // Message priority (low, medium, high, urgent)
	ThreadID          string   `json:"thread_id,omitempty"`          // Conversation thread ID
}

// NewPubSub creates a new PubSub instance for Bzzz coordination and HMMM meta-discussion
@@ -150,16 +157,17 @@ func NewPubSubWithLogger(ctx context.Context, h host.Host, chorusTopic, hmmmTopi
	}

	p := &PubSub{
		ps:               ps,
		host:             h,
		ctx:              pubsubCtx,
		cancel:           cancel,
		chorusTopicName:  chorusTopic,
		hmmmTopicName:    hmmmTopic,
		contextTopicName: contextTopic,
		dynamicTopics:    make(map[string]*pubsub.Topic),
		dynamicSubs:      make(map[string]*pubsub.Subscription),
		dynamicHandlers:  make(map[string]func([]byte, peer.ID)),
		hypercoreLog:     logger,
	}

	// Join static topics
@@ -177,6 +185,13 @@ func NewPubSubWithLogger(ctx context.Context, h host.Host, chorusTopic, hmmmTopi
	return p, nil
}

// SetRedactor wires the SHHH sentinel so outbound messages are sanitized before publication.
func (p *PubSub) SetRedactor(redactor *shhh.Sentinel) {
	p.redactorMux.Lock()
	defer p.redactorMux.Unlock()
	p.redactor = redactor
}

// SetHmmmMessageHandler sets the handler for incoming HMMM messages.
func (p *PubSub) SetHmmmMessageHandler(handler func(msg Message, from peer.ID)) {
	p.HmmmMessageHandler = handler
@@ -231,15 +246,21 @@ func (p *PubSub) joinStaticTopics() error {
	return nil
}

// subscribeDynamicTopic joins a topic and optionally assigns a raw handler.
func (p *PubSub) subscribeDynamicTopic(topicName string, handler func([]byte, peer.ID)) error {
	if topicName == "" {
		return fmt.Errorf("topic name cannot be empty")
	}

	p.dynamicTopicsMux.RLock()
	_, exists := p.dynamicTopics[topicName]
	p.dynamicTopicsMux.RUnlock()

	if exists {
		p.dynamicHandlersMux.Lock()
		p.dynamicHandlers[topicName] = handler
		p.dynamicHandlersMux.Unlock()
		return nil
	}

	topic, err := p.ps.Join(topicName)
@@ -253,38 +274,68 @@ func (p *PubSub) JoinDynamicTopic(topicName string) error {
		return fmt.Errorf("failed to subscribe to dynamic topic %s: %w", topicName, err)
	}

	p.dynamicTopicsMux.Lock()
	if _, already := p.dynamicTopics[topicName]; already {
		p.dynamicTopicsMux.Unlock()
		sub.Cancel()
		topic.Close()
		p.dynamicHandlersMux.Lock()
		p.dynamicHandlers[topicName] = handler
		p.dynamicHandlersMux.Unlock()
		return nil
	}
	p.dynamicTopics[topicName] = topic
	p.dynamicTopicsMux.Unlock()

	p.dynamicSubsMux.Lock()
	p.dynamicSubs[topicName] = sub
	p.dynamicSubsMux.Unlock()

	p.dynamicHandlersMux.Lock()
	p.dynamicHandlers[topicName] = handler
	p.dynamicHandlersMux.Unlock()

	go p.handleDynamicMessages(topicName, sub)

	fmt.Printf("✅ Joined dynamic topic: %s\n", topicName)
	return nil
}

// JoinDynamicTopic joins a new topic for a specific task
func (p *PubSub) JoinDynamicTopic(topicName string) error {
	return p.subscribeDynamicTopic(topicName, nil)
}

// SubscribeRawTopic joins a topic and delivers raw payloads to the provided handler.
func (p *PubSub) SubscribeRawTopic(topicName string, handler func([]byte, peer.ID)) error {
	if handler == nil {
		return fmt.Errorf("handler cannot be nil")
	}
	return p.subscribeDynamicTopic(topicName, handler)
}

// JoinRoleBasedTopics joins topics based on role and expertise
func (p *PubSub) JoinRoleBasedTopics(role string, expertise []string, reportsTo []string) error {
	var topicsToJoin []string

	// Join role-specific topic
	if role != "" {
		roleTopic := fmt.Sprintf("CHORUS/roles/%s/v1", strings.ToLower(strings.ReplaceAll(role, " ", "_")))
		topicsToJoin = append(topicsToJoin, roleTopic)
	}

	// Join expertise-specific topics
	for _, exp := range expertise {
		expertiseTopic := fmt.Sprintf("CHORUS/expertise/%s/v1", strings.ToLower(strings.ReplaceAll(exp, " ", "_")))
		topicsToJoin = append(topicsToJoin, expertiseTopic)
	}

	// Join reporting hierarchy topics
	for _, supervisor := range reportsTo {
		supervisorTopic := fmt.Sprintf("CHORUS/hierarchy/%s/v1", strings.ToLower(strings.ReplaceAll(supervisor, " ", "_")))
		topicsToJoin = append(topicsToJoin, supervisorTopic)
	}

	// Join all identified topics
	for _, topicName := range topicsToJoin {
		if err := p.JoinDynamicTopic(topicName); err != nil {
@@ -292,7 +343,7 @@ func (p *PubSub) JoinRoleBasedTopics(role string, expertise []string, reportsTo
			continue
		}
	}

	fmt.Printf("🎯 Joined %d role-based topics for role: %s\n", len(topicsToJoin), role)
	return nil
}
@@ -302,7 +353,7 @@ func (p *PubSub) JoinProjectTopic(projectID string) error {
	if projectID == "" {
		return fmt.Errorf("project ID cannot be empty")
	}

	topicName := fmt.Sprintf("CHORUS/projects/%s/coordination/v1", projectID)
	return p.JoinDynamicTopic(topicName)
}
@@ -324,6 +375,10 @@ func (p *PubSub) LeaveDynamicTopic(topicName string) {
		delete(p.dynamicTopics, topicName)
	}

	p.dynamicHandlersMux.Lock()
	delete(p.dynamicHandlers, topicName)
	p.dynamicHandlersMux.Unlock()

	fmt.Printf("🗑️ Left dynamic topic: %s\n", topicName)
}
@@ -337,11 +392,12 @@ func (p *PubSub) PublishToDynamicTopic(topicName string, msgType MessageType, da
		return fmt.Errorf("not subscribed to dynamic topic: %s", topicName)
	}

	payload := p.sanitizePayload(topicName, msgType, data)
	msg := Message{
		Type:      msgType,
		From:      p.host.ID().String(),
		Timestamp: time.Now(),
		Data:      payload,
	}

	msgBytes, err := json.Marshal(msg)
@@ -356,34 +412,35 @@ func (p *PubSub) PublishToDynamicTopic(topicName string, msgType MessageType, da
// wrapping it in the CHORUS Message envelope. Intended for HMMM per-issue rooms
// or other modules that maintain their own schemas.
func (p *PubSub) PublishRaw(topicName string, payload []byte) error {
	// Dynamic topic
	p.dynamicTopicsMux.RLock()
	if topic, exists := p.dynamicTopics[topicName]; exists {
		p.dynamicTopicsMux.RUnlock()
		return topic.Publish(p.ctx, payload)
	}
	p.dynamicTopicsMux.RUnlock()

	// Static topics by name
	switch topicName {
	case p.chorusTopicName:
		return p.chorusTopic.Publish(p.ctx, payload)
	case p.hmmmTopicName:
		return p.hmmmTopic.Publish(p.ctx, payload)
	case p.contextTopicName:
		return p.contextTopic.Publish(p.ctx, payload)
	default:
		return fmt.Errorf("not subscribed to topic: %s", topicName)
	}
}

// PublishBzzzMessage publishes a message to the Bzzz coordination topic
func (p *PubSub) PublishBzzzMessage(msgType MessageType, data map[string]interface{}) error {
	payload := p.sanitizePayload(p.chorusTopicName, msgType, data)
	msg := Message{
		Type:      msgType,
		From:      p.host.ID().String(),
		Timestamp: time.Now(),
		Data:      payload,
	}

	msgBytes, err := json.Marshal(msg)
@@ -396,11 +453,12 @@ func (p *PubSub) PublishBzzzMessage(msgType MessageType, data map[string]interfa
// PublishHmmmMessage publishes a message to the HMMM meta-discussion topic
func (p *PubSub) PublishHmmmMessage(msgType MessageType, data map[string]interface{}) error {
	payload := p.sanitizePayload(p.hmmmTopicName, msgType, data)
	msg := Message{
		Type:      msgType,
		From:      p.host.ID().String(),
		Timestamp: time.Now(),
		Data:      payload,
	}

	msgBytes, err := json.Marshal(msg)
@@ -425,11 +483,12 @@ func (p *PubSub) SetAntennaeMessageHandler(handler func(msg Message, from peer.I
// PublishContextFeedbackMessage publishes a message to the Context Feedback topic
func (p *PubSub) PublishContextFeedbackMessage(msgType MessageType, data map[string]interface{}) error {
	payload := p.sanitizePayload(p.contextTopicName, msgType, data)
	msg := Message{
		Type:      msgType,
		From:      p.host.ID().String(),
		Timestamp: time.Now(),
		Data:      payload,
	}

	msgBytes, err := json.Marshal(msg)
@@ -442,11 +501,16 @@ func (p *PubSub) PublishContextFeedbackMessage(msgType MessageType, data map[str
// PublishRoleBasedMessage publishes a role-based collaboration message
func (p *PubSub) PublishRoleBasedMessage(msgType MessageType, data map[string]interface{}, opts MessageOptions) error {
	topicName := p.chorusTopicName
	if isRoleMessage(msgType) {
		topicName = p.hmmmTopicName
	}
	payload := p.sanitizePayload(topicName, msgType, data)
	msg := Message{
		Type:              msgType,
		From:              p.host.ID().String(),
		Timestamp:         time.Now(),
		Data:              payload,
		FromRole:          opts.FromRole,
		ToRoles:           opts.ToRoles,
		RequiredExpertise: opts.RequiredExpertise,
@@ -462,10 +526,8 @@ func (p *PubSub) PublishRoleBasedMessage(msgType MessageType, data map[string]in
	// Determine which topic to use based on message type
	var topic *pubsub.Topic
	switch {
	case isRoleMessage(msgType):
		topic = p.hmmmTopic // Use HMMM topic for role-based messages
	default:
		topic = p.chorusTopic // Default to Bzzz topic
@@ -492,14 +554,14 @@ func (p *PubSub) PublishSlurpContextUpdate(data map[string]interface{}) error {
// PublishSlurpIntegrationEvent publishes a generic SLURP integration event
func (p *PubSub) PublishSlurpIntegrationEvent(eventType string, discussionID string, slurpEvent map[string]interface{}) error {
	data := map[string]interface{}{
		"event_type":    eventType,
		"discussion_id": discussionID,
		"slurp_event":   slurpEvent,
		"timestamp":     time.Now(),
		"source":        "hmmm-slurp-integration",
		"peer_id":       p.host.ID().String(),
	}

	return p.PublishSlurpEventGenerated(data)
}
@@ -604,15 +666,23 @@ func (p *PubSub) handleContextFeedbackMessages() {
	}
}

// getDynamicHandler returns the raw handler for a topic if registered.
func (p *PubSub) getDynamicHandler(topicName string) func([]byte, peer.ID) {
	p.dynamicHandlersMux.RLock()
	handler := p.dynamicHandlers[topicName]
	p.dynamicHandlersMux.RUnlock()
	return handler
}

// handleDynamicMessages processes messages from a dynamic topic subscription
func (p *PubSub) handleDynamicMessages(topicName string, sub *pubsub.Subscription) {
	for {
		msg, err := sub.Next(p.ctx)
		if err != nil {
			if p.ctx.Err() != nil || err.Error() == "subscription cancelled" {
				return // Subscription was cancelled, exit handler
			}
			fmt.Printf("❌ Error receiving dynamic message on %s: %v\n", topicName, err)
			continue
		}
@@ -620,13 +690,18 @@ func (p *PubSub) handleDynamicMessages(sub *pubsub.Subscription) {
			continue
		}

		if handler := p.getDynamicHandler(topicName); handler != nil {
			handler(msg.Data, msg.ReceivedFrom)
			continue
		}

		var dynamicMsg Message
		if err := json.Unmarshal(msg.Data, &dynamicMsg); err != nil {
			fmt.Printf("❌ Failed to unmarshal dynamic message on %s: %v\n", topicName, err)
			continue
		}

		// Use the main HMMM handler for all dynamic messages without custom handlers
		if p.HmmmMessageHandler != nil {
			p.HmmmMessageHandler(dynamicMsg, msg.ReceivedFrom)
		}
@@ -636,7 +711,7 @@ func (p *PubSub) handleDynamicMessages(sub *pubsub.Subscription) {
// processBzzzMessage handles different types of Bzzz coordination messages
func (p *PubSub) processBzzzMessage(msg Message, from peer.ID) {
	fmt.Printf("🐝 Bzzz [%s] from %s: %v\n", msg.Type, from.ShortString(), msg.Data)

	// Log to hypercore if logger is available
	if p.hypercoreLog != nil {
		logData := map[string]interface{}{
@@ -647,7 +722,7 @@ func (p *PubSub) processBzzzMessage(msg Message, from peer.ID) {
			"data":  msg.Data,
			"topic": "CHORUS",
		}

		// Map pubsub message types to hypercore log types
		var logType string
		switch msg.Type {
@@ -666,7 +741,7 @@ func (p *PubSub) processBzzzMessage(msg Message, from peer.ID) {
		default:
			logType = "network_event"
		}

		if err := p.hypercoreLog.AppendString(logType, logData); err != nil {
			fmt.Printf("❌ Failed to log Bzzz message to hypercore: %v\n", err)
		}
@@ -675,9 +750,9 @@ func (p *PubSub) processBzzzMessage(msg Message, from peer.ID) {
// processHmmmMessage provides default handling for HMMM messages if no external handler is set
func (p *PubSub) processHmmmMessage(msg Message, from peer.ID) {
	fmt.Printf("🎯 Default HMMM Handler [%s] from %s: %v\n",
		msg.Type, from.ShortString(), msg.Data)

	// Log to hypercore if logger is available
	if p.hypercoreLog != nil {
		logData := map[string]interface{}{
@@ -694,7 +769,7 @@ func (p *PubSub) processHmmmMessage(msg Message, from peer.ID) {
			"priority":  msg.Priority,
			"thread_id": msg.ThreadID,
		}

		// Map pubsub message types to hypercore log types
		var logType string
		switch msg.Type {
@@ -717,7 +792,7 @@ func (p *PubSub) processHmmmMessage(msg Message, from peer.ID) {
		default:
			logType = "collaboration"
		}

		if err := p.hypercoreLog.AppendString(logType, logData); err != nil {
			fmt.Printf("❌ Failed to log HMMM message to hypercore: %v\n", err)
		}
@@ -726,25 +801,25 @@ func (p *PubSub) processHmmmMessage(msg Message, from peer.ID) {
// processContextFeedbackMessage provides default handling for context feedback messages if no external handler is set
func (p *PubSub) processContextFeedbackMessage(msg Message, from peer.ID) {
	fmt.Printf("🧠 Context Feedback [%s] from %s: %v\n",
		msg.Type, from.ShortString(), msg.Data)

	// Log to hypercore if logger is available
	if p.hypercoreLog != nil {
		logData := map[string]interface{}{
			"message_type": string(msg.Type),
			"from_peer":    from.String(),
			"from_short":   from.ShortString(),
			"timestamp":    msg.Timestamp,
			"data":         msg.Data,
			"topic":        "context_feedback",
			"from_role":    msg.FromRole,
			"to_roles":     msg.ToRoles,
			"project_id":   msg.ProjectID,
			"priority":     msg.Priority,
			"thread_id":    msg.ThreadID,
		}

		// Map context feedback message types to hypercore log types
		var logType string
		switch msg.Type {
@@ -757,17 +832,79 @@ func (p *PubSub) processContextFeedbackMessage(msg Message, from peer.ID) {
		default:
			logType = "context_feedback"
		}

		if err := p.hypercoreLog.AppendString(logType, logData); err != nil {
			fmt.Printf("❌ Failed to log Context Feedback message to hypercore: %v\n", err)
		}
	}
}

func (p *PubSub) sanitizePayload(topic string, msgType MessageType, data map[string]interface{}) map[string]interface{} {
	if data == nil {
		return nil
	}
	cloned := clonePayloadMap(data)
	p.redactorMux.RLock()
	redactor := p.redactor
	p.redactorMux.RUnlock()
	if redactor != nil {
		labels := map[string]string{
			"source":       "pubsub",
			"topic":        topic,
			"message_type": string(msgType),
		}
		redactor.RedactMapWithLabels(context.Background(), cloned, labels)
	}
	return cloned
}

func isRoleMessage(msgType MessageType) bool {
	switch msgType {
	case RoleAnnouncement, ExpertiseRequest, ExpertiseResponse, StatusUpdate,
		WorkAllocation, RoleCollaboration, MentorshipRequest, MentorshipResponse,
		ProjectUpdate, DeliverableReady:
		return true
	default:
		return false
	}
}

func clonePayloadMap(in map[string]interface{}) map[string]interface{} {
	if in == nil {
		return nil
	}
	out := make(map[string]interface{}, len(in))
	for k, v := range in {
		out[k] = clonePayloadValue(v)
	}
	return out
}

func clonePayloadValue(v interface{}) interface{} {
	switch tv := v.(type) {
	case map[string]interface{}:
		return clonePayloadMap(tv)
	case []interface{}:
		return clonePayloadSlice(tv)
	case []string:
		return append([]string(nil), tv...)
	default:
		return tv
	}
}

func clonePayloadSlice(in []interface{}) []interface{} {
	out := make([]interface{}, len(in))
	for i, val := range in {
		out[i] = clonePayloadValue(val)
	}
	return out
}

// Close shuts down the PubSub instance
func (p *PubSub) Close() error {
	p.cancel()

	if p.chorusSub != nil {
		p.chorusSub.Cancel()
	}
@@ -777,7 +914,7 @@ func (p *PubSub) Close() error {
	if p.contextSub != nil {
		p.contextSub.Cancel()
	}

	if p.chorusTopic != nil {
		p.chorusTopic.Close()
	}
@@ -787,7 +924,13 @@ func (p *PubSub) Close() error {
	if p.contextTopic != nil {
		p.contextTopic.Close()
	}

	p.dynamicSubsMux.Lock()
	for _, sub := range p.dynamicSubs {
		sub.Cancel()
	}
	p.dynamicSubsMux.Unlock()

	p.dynamicTopicsMux.Lock()
	for _, topic := range p.dynamicTopics {
		topic.Close()
1 vendor/github.com/Microsoft/go-winio/.gitattributes generated vendored Normal file
@@ -0,0 +1 @@
* text=auto eol=lf
10 vendor/github.com/Microsoft/go-winio/.gitignore generated vendored Normal file
@@ -0,0 +1,10 @@
.vscode/

*.exe

# testing
testdata

# go workspaces
go.work
go.work.sum
147
vendor/github.com/Microsoft/go-winio/.golangci.yml
generated
vendored
Normal file
147
vendor/github.com/Microsoft/go-winio/.golangci.yml
generated
vendored
Normal file
@@ -0,0 +1,147 @@
|
||||
linters:
|
||||
enable:
|
||||
# style
|
||||
- containedctx # struct contains a context
|
||||
- dupl # duplicate code
|
||||
- errname # erorrs are named correctly
|
||||
- nolintlint # "//nolint" directives are properly explained
|
||||
- revive # golint replacement
|
||||
- unconvert # unnecessary conversions
|
||||
- wastedassign
|
||||
|
||||
# bugs, performance, unused, etc ...
|
||||
- contextcheck # function uses a non-inherited context
|
||||
- errorlint # errors not wrapped for 1.13
|
||||
- exhaustive # check exhaustiveness of enum switch statements
|
||||
- gofmt # files are gofmt'ed
|
||||
- gosec # security
|
||||
- nilerr # returns nil even with non-nil error
|
||||
- thelper # test helpers without t.Helper()
|
||||
- unparam # unused function params
|
||||
|
||||
issues:
|
||||
exclude-dirs:
|
||||
- pkg/etw/sample
|
||||
|
||||
exclude-rules:
|
    # err is very often shadowed in nested scopes
    - linters:
        - govet
      text: '^shadow: declaration of "err" shadows declaration'

    # ignore long lines for skip autogen directives
    - linters:
        - revive
      text: "^line-length-limit: "
      source: "^//(go:generate|sys) "

    # TODO: remove after upgrading to go1.18
    # ignore comment spacing for nolint and sys directives
    - linters:
        - revive
      text: "^comment-spacings: no space between comment delimiter and comment text"
      source: "//(cspell:|nolint:|sys |todo)"

    # not on go 1.18 yet, so no any
    - linters:
        - revive
      text: "^use-any: since GO 1.18 'interface{}' can be replaced by 'any'"

    # allow unjustified ignores of error checks in defer statements
    - linters:
        - nolintlint
      text: "^directive `//nolint:errcheck` should provide explanation"
      source: '^\s*defer '

    # allow unjustified ignores of error lints for io.EOF
    - linters:
        - nolintlint
      text: "^directive `//nolint:errorlint` should provide explanation"
      source: '[=|!]= io.EOF'

linters-settings:
  exhaustive:
    default-signifies-exhaustive: true
  govet:
    enable-all: true
    disable:
      # struct order is often for Win32 compat
      # also, ignore pointer bytes/GC issues for now until performance becomes an issue
      - fieldalignment
  nolintlint:
    require-explanation: true
    require-specific: true
  revive:
    # revive is more configurable than staticcheck, so it is likely the preferred alternative
    # (once the perf issue is solved: https://github.com/golangci/golangci-lint/issues/2997)
    enable-all-rules: true
    # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md
    rules:
      # rules with required arguments
      - name: argument-limit
        disabled: true
      - name: banned-characters
        disabled: true
      - name: cognitive-complexity
        disabled: true
      - name: cyclomatic
        disabled: true
      - name: file-header
        disabled: true
      - name: function-length
        disabled: true
      - name: function-result-limit
        disabled: true
      - name: max-public-structs
        disabled: true
      # generally annoying rules
      - name: add-constant # complains about any and all strings and integers
        disabled: true
      - name: confusing-naming # we frequently use "Foo()" and "foo()" together
        disabled: true
      - name: flag-parameter # excessive, and a common idiom we use
        disabled: true
      - name: unhandled-error # warns over common fmt.Print* and io.Close; rely on errcheck instead
        disabled: true
      # general config
      - name: line-length-limit
        arguments:
          - 140
      - name: var-naming
        arguments:
          - []
          - - CID
            - CRI
            - CTRD
            - DACL
            - DLL
            - DOS
            - ETW
            - FSCTL
            - GCS
            - GMSA
            - HCS
            - HV
            - IO
            - LCOW
            - LDAP
            - LPAC
            - LTSC
            - MMIO
            - NT
            - OCI
            - PMEM
            - PWSH
            - RX
            - SACL
            - SID
            - SMB
            - TX
            - VHD
            - VHDX
            - VMID
            - VPCI
            - WCOW
            - WIM
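The first exclusion above silences `govet`'s shadow analyzer for the common Go idiom of redeclaring `err` with `:=` in a nested scope. A minimal standalone sketch (hypothetical function names, not from this repo) of the pattern the rule permits:

```go
package main

import (
	"fmt"
	"strconv"
)

// parsePair shows the nested "err" shadowing that govet's shadow check
// would flag and the exclusion above silences: the inner ":=" declares
// a new err that shadows the outer one.
func parsePair(a, b string) (int, error) {
	x, err := strconv.Atoi(a)
	if err != nil {
		return 0, err
	}
	if b != "" {
		y, err := strconv.Atoi(b) // shadows the outer err; allowed by the config
		if err != nil {
			return 0, err
		}
		return x + y, nil
	}
	return x, nil
}

func main() {
	n, _ := parsePair("2", "3")
	fmt.Println(n) // prints 5
}
```

Requiring an explanation on every such `:=` would be noisy, so the config opts out of this one diagnostic rather than disabling the shadow analyzer entirely.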
1 vendor/github.com/Microsoft/go-winio/CODEOWNERS generated vendored Normal file
@@ -0,0 +1 @@
* @microsoft/containerplat
22 vendor/github.com/Microsoft/go-winio/LICENSE generated vendored Normal file
@@ -0,0 +1,22 @@
The MIT License (MIT)

Copyright (c) 2015 Microsoft

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
89 vendor/github.com/Microsoft/go-winio/README.md generated vendored Normal file
@@ -0,0 +1,89 @@
# go-winio [![Build Status](https://github.com/microsoft/go-winio/actions/workflows/ci.yml/badge.svg)](https://github.com/microsoft/go-winio/actions/workflows/ci.yml)

This repository contains utilities for efficiently performing Win32 IO operations in
Go. Currently, this is focused on accessing named pipes and other file handles, and
for using named pipes as a net transport.

This code relies on IO completion ports to avoid blocking IO on system threads, allowing Go
to reuse the thread to schedule another goroutine. This limits support to Windows Vista and
newer operating systems. This is similar to the implementation of network sockets in Go's net
package.

Please see the LICENSE file for licensing information.

## Contributing

This project welcomes contributions and suggestions.
Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that
you have the right to, and actually do, grant us the rights to use your contribution.
For details, visit [Microsoft CLA](https://cla.microsoft.com).

When you submit a pull request, a CLA-bot will automatically determine whether you need to
provide a CLA and decorate the PR appropriately (e.g., label, comment).
Simply follow the instructions provided by the bot.
You will only need to do this once across all repos using our CLA.

Additionally, the pull request pipeline requires the following steps to be performed before
merging.

### Code Sign-Off

We require that contributors sign their commits using [`git commit --signoff`][git-commit-s]
to certify they either authored the work themselves or otherwise have permission to use it in this project.

A range of commits can be signed off using [`git rebase --signoff`][git-rebase-s].

Please see [the developer certificate](https://developercertificate.org) for more info,
as well as to make sure that you can attest to the rules listed.
Our CI uses the DCO GitHub app to ensure that all commits in a given PR are signed off.

### Linting

Code must pass a linting stage, which uses [`golangci-lint`][lint].
The linting settings are stored in [`.golangci.yaml`](./.golangci.yaml), and can be run
automatically with VSCode by adding the following to your workspace or folder settings:

```json
"go.lintTool": "golangci-lint",
"go.lintOnSave": "package",
```

Additional editor [integration options are also available][lint-ide].

Alternatively, `golangci-lint` can be [installed locally][lint-install] and run from the repo root:

```shell
# use . or specify a path to only lint a package
# to show all lint errors, use flags "--max-issues-per-linter=0 --max-same-issues=0"
> golangci-lint run ./...
```

### Go Generate

The pipeline checks that auto-generated code, produced via `go generate`, is up to date.

This can be done for the entire repo:

```shell
> go generate ./...
```

## Code of Conduct

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

## Special Thanks

Thanks to [natefinch][natefinch] for the inspiration for this library.
See [npipe](https://github.com/natefinch/npipe) for another named pipe implementation.

[lint]: https://golangci-lint.run/
[lint-ide]: https://golangci-lint.run/usage/integrations/#editor-integration
[lint-install]: https://golangci-lint.run/usage/install/#local-installation

[git-commit-s]: https://git-scm.com/docs/git-commit#Documentation/git-commit.txt--s
[git-rebase-s]: https://git-scm.com/docs/git-rebase#Documentation/git-rebase.txt---signoff

[natefinch]: https://github.com/natefinch
41 vendor/github.com/Microsoft/go-winio/SECURITY.md generated vendored Normal file
@@ -0,0 +1,41 @@
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.7 BLOCK -->

## Security

Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).

If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.

## Reporting Security Issues

**Please do not report security vulnerabilities through public GitHub issues.**

Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).

If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).

You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).

Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:

* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
* Full paths of source file(s) related to the manifestation of the issue
* The location of the affected source code (tag/branch/commit or direct URL)
* Any special configuration required to reproduce the issue
* Step-by-step instructions to reproduce the issue
* Proof-of-concept or exploit code (if possible)
* Impact of the issue, including how an attacker might exploit the issue

This information will help us triage your report more quickly.

If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.

## Preferred Languages

We prefer all communications to be in English.

## Policy

Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).

<!-- END MICROSOFT SECURITY.MD BLOCK -->
287 vendor/github.com/Microsoft/go-winio/backup.go generated vendored Normal file
@@ -0,0 +1,287 @@
//go:build windows
// +build windows

package winio

import (
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"os"
	"runtime"
	"unicode/utf16"

	"github.com/Microsoft/go-winio/internal/fs"
	"golang.org/x/sys/windows"
)

//sys backupRead(h windows.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupRead
//sys backupWrite(h windows.Handle, b []byte, bytesWritten *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupWrite

const (
	BackupData = uint32(iota + 1)
	BackupEaData
	BackupSecurity
	BackupAlternateData
	BackupLink
	BackupPropertyData
	BackupObjectId //revive:disable-line:var-naming ID, not Id
	BackupReparseData
	BackupSparseBlock
	BackupTxfsData
)

const (
	StreamSparseAttributes = uint32(8)
)

//nolint:revive // var-naming: ALL_CAPS
const (
	WRITE_DAC              = windows.WRITE_DAC
	WRITE_OWNER            = windows.WRITE_OWNER
	ACCESS_SYSTEM_SECURITY = windows.ACCESS_SYSTEM_SECURITY
)

// BackupHeader represents a backup stream of a file.
type BackupHeader struct {
	//revive:disable-next-line:var-naming ID, not Id
	Id         uint32 // The backup stream ID
	Attributes uint32 // Stream attributes
	Size       int64  // The size of the stream in bytes
	Name       string // The name of the stream (for BackupAlternateData only).
	Offset     int64  // The offset of the stream in the file (for BackupSparseBlock only).
}

type win32StreamID struct {
	StreamID   uint32
	Attributes uint32
	Size       uint64
	NameSize   uint32
}

// BackupStreamReader reads from a stream produced by the BackupRead Win32 API and produces a series
// of BackupHeader values.
type BackupStreamReader struct {
	r         io.Reader
	bytesLeft int64
}

// NewBackupStreamReader produces a BackupStreamReader from any io.Reader.
func NewBackupStreamReader(r io.Reader) *BackupStreamReader {
	return &BackupStreamReader{r, 0}
}

// Next returns the next backup stream and prepares for calls to Read(). It skips the remainder of the current stream if
// it was not completely read.
func (r *BackupStreamReader) Next() (*BackupHeader, error) {
	if r.bytesLeft > 0 { //nolint:nestif // todo: flatten this
		if s, ok := r.r.(io.Seeker); ok {
			// Make sure Seek on io.SeekCurrent sometimes succeeds
			// before trying the actual seek.
			if _, err := s.Seek(0, io.SeekCurrent); err == nil {
				if _, err = s.Seek(r.bytesLeft, io.SeekCurrent); err != nil {
					return nil, err
				}
				r.bytesLeft = 0
			}
		}
		if _, err := io.Copy(io.Discard, r); err != nil {
			return nil, err
		}
	}
	var wsi win32StreamID
	if err := binary.Read(r.r, binary.LittleEndian, &wsi); err != nil {
		return nil, err
	}
	hdr := &BackupHeader{
		Id:         wsi.StreamID,
		Attributes: wsi.Attributes,
		Size:       int64(wsi.Size),
	}
	if wsi.NameSize != 0 {
		name := make([]uint16, int(wsi.NameSize/2))
		if err := binary.Read(r.r, binary.LittleEndian, name); err != nil {
			return nil, err
		}
		hdr.Name = windows.UTF16ToString(name)
	}
	if wsi.StreamID == BackupSparseBlock {
		if err := binary.Read(r.r, binary.LittleEndian, &hdr.Offset); err != nil {
			return nil, err
		}
		hdr.Size -= 8
	}
	r.bytesLeft = hdr.Size
	return hdr, nil
}

// Read reads from the current backup stream.
func (r *BackupStreamReader) Read(b []byte) (int, error) {
	if r.bytesLeft == 0 {
		return 0, io.EOF
	}
	if int64(len(b)) > r.bytesLeft {
		b = b[:r.bytesLeft]
	}
	n, err := r.r.Read(b)
	r.bytesLeft -= int64(n)
	if err == io.EOF {
		err = io.ErrUnexpectedEOF
	} else if r.bytesLeft == 0 && err == nil {
		err = io.EOF
	}
	return n, err
}

// BackupStreamWriter writes a stream compatible with the BackupWrite Win32 API.
type BackupStreamWriter struct {
	w         io.Writer
	bytesLeft int64
}

// NewBackupStreamWriter produces a BackupStreamWriter on top of an io.Writer.
func NewBackupStreamWriter(w io.Writer) *BackupStreamWriter {
	return &BackupStreamWriter{w, 0}
}

// WriteHeader writes the next backup stream header and prepares for calls to Write().
func (w *BackupStreamWriter) WriteHeader(hdr *BackupHeader) error {
	if w.bytesLeft != 0 {
		return fmt.Errorf("missing %d bytes", w.bytesLeft)
	}
	name := utf16.Encode([]rune(hdr.Name))
	wsi := win32StreamID{
		StreamID:   hdr.Id,
		Attributes: hdr.Attributes,
		Size:       uint64(hdr.Size),
		NameSize:   uint32(len(name) * 2),
	}
	if hdr.Id == BackupSparseBlock {
		// Include space for the int64 block offset
		wsi.Size += 8
	}
	if err := binary.Write(w.w, binary.LittleEndian, &wsi); err != nil {
		return err
	}
	if len(name) != 0 {
		if err := binary.Write(w.w, binary.LittleEndian, name); err != nil {
			return err
		}
	}
	if hdr.Id == BackupSparseBlock {
		if err := binary.Write(w.w, binary.LittleEndian, hdr.Offset); err != nil {
			return err
		}
	}
	w.bytesLeft = hdr.Size
	return nil
}

// Write writes to the current backup stream.
func (w *BackupStreamWriter) Write(b []byte) (int, error) {
	if w.bytesLeft < int64(len(b)) {
		return 0, fmt.Errorf("too many bytes by %d", int64(len(b))-w.bytesLeft)
	}
	n, err := w.w.Write(b)
	w.bytesLeft -= int64(n)
	return n, err
}

// BackupFileReader provides an io.ReadCloser interface on top of the BackupRead Win32 API.
type BackupFileReader struct {
	f               *os.File
	includeSecurity bool
	ctx             uintptr
}

// NewBackupFileReader returns a new BackupFileReader from a file handle. If includeSecurity is true,
// Read will attempt to read the security descriptor of the file.
func NewBackupFileReader(f *os.File, includeSecurity bool) *BackupFileReader {
	r := &BackupFileReader{f, includeSecurity, 0}
	return r
}

// Read reads a backup stream from the file by calling the Win32 API BackupRead().
func (r *BackupFileReader) Read(b []byte) (int, error) {
	var bytesRead uint32
	err := backupRead(windows.Handle(r.f.Fd()), b, &bytesRead, false, r.includeSecurity, &r.ctx)
	if err != nil {
		return 0, &os.PathError{Op: "BackupRead", Path: r.f.Name(), Err: err}
	}
	runtime.KeepAlive(r.f)
	if bytesRead == 0 {
		return 0, io.EOF
	}
	return int(bytesRead), nil
}

// Close frees Win32 resources associated with the BackupFileReader. It does not close
// the underlying file.
func (r *BackupFileReader) Close() error {
	if r.ctx != 0 {
		_ = backupRead(windows.Handle(r.f.Fd()), nil, nil, true, false, &r.ctx)
		runtime.KeepAlive(r.f)
		r.ctx = 0
	}
	return nil
}

// BackupFileWriter provides an io.WriteCloser interface on top of the BackupWrite Win32 API.
type BackupFileWriter struct {
	f               *os.File
	includeSecurity bool
	ctx             uintptr
}

// NewBackupFileWriter returns a new BackupFileWriter from a file handle. If includeSecurity is true,
// Write() will attempt to restore the security descriptor from the stream.
func NewBackupFileWriter(f *os.File, includeSecurity bool) *BackupFileWriter {
	w := &BackupFileWriter{f, includeSecurity, 0}
	return w
}

// Write restores a portion of the file using the provided backup stream.
func (w *BackupFileWriter) Write(b []byte) (int, error) {
	var bytesWritten uint32
	err := backupWrite(windows.Handle(w.f.Fd()), b, &bytesWritten, false, w.includeSecurity, &w.ctx)
	if err != nil {
		return 0, &os.PathError{Op: "BackupWrite", Path: w.f.Name(), Err: err}
	}
	runtime.KeepAlive(w.f)
	if int(bytesWritten) != len(b) {
		return int(bytesWritten), errors.New("not all bytes could be written")
	}
	return len(b), nil
}

// Close frees Win32 resources associated with the BackupFileWriter. It does not
// close the underlying file.
func (w *BackupFileWriter) Close() error {
	if w.ctx != 0 {
		_ = backupWrite(windows.Handle(w.f.Fd()), nil, nil, true, false, &w.ctx)
		runtime.KeepAlive(w.f)
		w.ctx = 0
	}
	return nil
}

// OpenForBackup opens a file or directory, potentially skipping access checks if the backup
// or restore privileges have been acquired.
//
// If the file opened was a directory, it cannot be used with Readdir().
func OpenForBackup(path string, access uint32, share uint32, createmode uint32) (*os.File, error) {
	h, err := fs.CreateFile(path,
		fs.AccessMask(access),
		fs.FileShareMode(share),
		nil,
		fs.FileCreationDisposition(createmode),
		fs.FILE_FLAG_BACKUP_SEMANTICS|fs.FILE_FLAG_OPEN_REPARSE_POINT,
		0,
	)
	if err != nil {
		err = &os.PathError{Op: "open", Path: path, Err: err}
		return nil, err
	}
	return os.NewFile(uintptr(h), path), nil
}
22 vendor/github.com/Microsoft/go-winio/doc.go generated vendored Normal file
@@ -0,0 +1,22 @@
// This package provides utilities for efficiently performing Win32 IO operations in Go.
// Currently, this package provides support for general IO and management of
//   - named pipes
//   - files
//   - [Hyper-V sockets]
//
// This code is similar to Go's [net] package, and uses IO completion ports to avoid
// blocking IO on system threads, allowing Go to reuse the thread to schedule other goroutines.
//
// This limits support to Windows Vista and newer operating systems.
//
// Additionally, this package provides support for:
//   - creating and managing GUIDs
//   - writing to [ETW]
//   - opening and managing VHDs
//   - parsing [Windows Image files]
//   - auto-generating Win32 API code
//
// [Hyper-V sockets]: https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service
// [ETW]: https://docs.microsoft.com/en-us/windows-hardware/drivers/devtest/event-tracing-for-windows--etw-
// [Windows Image files]: https://docs.microsoft.com/en-us/windows-hardware/manufacture/desktop/work-with-windows-images
package winio
137 vendor/github.com/Microsoft/go-winio/ea.go generated vendored Normal file
@@ -0,0 +1,137 @@
package winio

import (
	"bytes"
	"encoding/binary"
	"errors"
)

type fileFullEaInformation struct {
	NextEntryOffset uint32
	Flags           uint8
	NameLength      uint8
	ValueLength     uint16
}

var (
	fileFullEaInformationSize = binary.Size(&fileFullEaInformation{})

	errInvalidEaBuffer = errors.New("invalid extended attribute buffer")
	errEaNameTooLarge  = errors.New("extended attribute name too large")
	errEaValueTooLarge = errors.New("extended attribute value too large")
)

// ExtendedAttribute represents a single Windows EA.
type ExtendedAttribute struct {
	Name  string
	Value []byte
	Flags uint8
}

func parseEa(b []byte) (ea ExtendedAttribute, nb []byte, err error) {
	var info fileFullEaInformation
	err = binary.Read(bytes.NewReader(b), binary.LittleEndian, &info)
	if err != nil {
		err = errInvalidEaBuffer
		return ea, nb, err
	}

	nameOffset := fileFullEaInformationSize
	nameLen := int(info.NameLength)
	valueOffset := nameOffset + int(info.NameLength) + 1
	valueLen := int(info.ValueLength)
	nextOffset := int(info.NextEntryOffset)
	if valueLen+valueOffset > len(b) || nextOffset < 0 || nextOffset > len(b) {
		err = errInvalidEaBuffer
		return ea, nb, err
	}

	ea.Name = string(b[nameOffset : nameOffset+nameLen])
	ea.Value = b[valueOffset : valueOffset+valueLen]
	ea.Flags = info.Flags
	if info.NextEntryOffset != 0 {
		nb = b[info.NextEntryOffset:]
	}
	return ea, nb, err
}

// DecodeExtendedAttributes decodes a list of EAs from a FILE_FULL_EA_INFORMATION
// buffer retrieved from BackupRead, ZwQueryEaFile, etc.
func DecodeExtendedAttributes(b []byte) (eas []ExtendedAttribute, err error) {
	for len(b) != 0 {
		ea, nb, err := parseEa(b)
		if err != nil {
			return nil, err
		}

		eas = append(eas, ea)
		b = nb
	}
	return eas, err
}

func writeEa(buf *bytes.Buffer, ea *ExtendedAttribute, last bool) error {
	if int(uint8(len(ea.Name))) != len(ea.Name) {
		return errEaNameTooLarge
	}
	if int(uint16(len(ea.Value))) != len(ea.Value) {
		return errEaValueTooLarge
	}
	entrySize := uint32(fileFullEaInformationSize + len(ea.Name) + 1 + len(ea.Value))
	withPadding := (entrySize + 3) &^ 3
	nextOffset := uint32(0)
	if !last {
		nextOffset = withPadding
	}
	info := fileFullEaInformation{
		NextEntryOffset: nextOffset,
		Flags:           ea.Flags,
		NameLength:      uint8(len(ea.Name)),
		ValueLength:     uint16(len(ea.Value)),
	}

	err := binary.Write(buf, binary.LittleEndian, &info)
	if err != nil {
		return err
	}

	_, err = buf.Write([]byte(ea.Name))
	if err != nil {
		return err
	}

	err = buf.WriteByte(0)
	if err != nil {
		return err
	}

	_, err = buf.Write(ea.Value)
	if err != nil {
		return err
	}

	_, err = buf.Write([]byte{0, 0, 0}[0 : withPadding-entrySize])
	if err != nil {
		return err
	}

	return nil
}

// EncodeExtendedAttributes encodes a list of EAs into a FILE_FULL_EA_INFORMATION
// buffer for use with BackupWrite, ZwSetEaFile, etc.
func EncodeExtendedAttributes(eas []ExtendedAttribute) ([]byte, error) {
	var buf bytes.Buffer
	for i := range eas {
		last := false
		if i == len(eas)-1 {
			last = true
		}

		err := writeEa(&buf, &eas[i], last)
		if err != nil {
			return nil, err
		}
	}
	return buf.Bytes(), nil
}
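The wire layout handled by `writeEa`/`parseEa` is: an 8-byte header, the name, a NUL terminator, the value, then zero padding to a 4-byte boundary. A standalone sketch of encoding a single entry (local struct copy; `encodeOne` is a hypothetical helper, and padding the whole buffer to a multiple of 4 matches `writeEa`'s `(entrySize + 3) &^ 3` only for one entry starting at offset 0):

```go
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

// Local copy of fileFullEaInformation: binary.Size of this struct is
// 4 + 1 + 1 + 2 = 8 bytes.
type eaHeader struct {
	NextEntryOffset uint32
	Flags           uint8
	NameLength      uint8
	ValueLength     uint16
}

// encodeOne serializes a single EA entry: header, name, NUL, value, padding.
func encodeOne(name string, value []byte) []byte {
	var buf bytes.Buffer
	hdr := eaHeader{NameLength: uint8(len(name)), ValueLength: uint16(len(value))}
	_ = binary.Write(&buf, binary.LittleEndian, &hdr) // bytes.Buffer writes cannot fail
	buf.WriteString(name)
	buf.WriteByte(0)
	buf.Write(value)
	for buf.Len()%4 != 0 { // pad to a 4-byte boundary, as writeEa does
		buf.WriteByte(0)
	}
	return buf.Bytes()
}

func main() {
	b := encodeOne("user.tag", []byte{1, 2, 3})
	fmt.Println(len(b)) // 8 header + 8 name + 1 NUL + 3 value = 20, already aligned
}
```

`NextEntryOffset` stays zero here because a final (or only) entry terminates the chain, which is exactly how `DecodeExtendedAttributes` knows when to stop.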
320
vendor/github.com/Microsoft/go-winio/file.go
generated
vendored
Normal file
320
vendor/github.com/Microsoft/go-winio/file.go
generated
vendored
Normal file
@@ -0,0 +1,320 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package winio
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
//sys cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) = CancelIoEx
|
||||
//sys createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) = CreateIoCompletionPort
|
||||
//sys getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
|
||||
//sys setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
|
||||
//sys wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
|
||||
|
||||
var (
|
||||
ErrFileClosed = errors.New("file has already been closed")
|
||||
ErrTimeout = &timeoutError{}
|
||||
)
|
||||
|
||||
type timeoutError struct{}
|
||||
|
||||
func (*timeoutError) Error() string { return "i/o timeout" }
|
||||
func (*timeoutError) Timeout() bool { return true }
|
||||
func (*timeoutError) Temporary() bool { return true }
|
||||
|
||||
type timeoutChan chan struct{}
|
||||
|
||||
var ioInitOnce sync.Once
|
||||
var ioCompletionPort windows.Handle
|
||||
|
||||
// ioResult contains the result of an asynchronous IO operation.
|
||||
type ioResult struct {
|
||||
bytes uint32
|
||||
err error
|
||||
}
|
||||
|
||||
// ioOperation represents an outstanding asynchronous Win32 IO.
|
||||
type ioOperation struct {
|
||||
o windows.Overlapped
|
||||
ch chan ioResult
|
||||
}
|
||||
|
||||
func initIO() {
|
||||
h, err := createIoCompletionPort(windows.InvalidHandle, 0, 0, 0xffffffff)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ioCompletionPort = h
|
||||
go ioCompletionProcessor(h)
|
||||
}
|
||||
|
||||
// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
|
||||
// It takes ownership of this handle and will close it if it is garbage collected.
|
||||
type win32File struct {
|
||||
handle windows.Handle
|
||||
wg sync.WaitGroup
|
||||
wgLock sync.RWMutex
|
||||
closing atomic.Bool
|
||||
socket bool
|
||||
readDeadline deadlineHandler
|
||||
writeDeadline deadlineHandler
|
||||
}
|
||||
|
||||
type deadlineHandler struct {
|
||||
setLock sync.Mutex
|
||||
channel timeoutChan
|
||||
channelLock sync.RWMutex
|
||||
timer *time.Timer
|
||||
timedout atomic.Bool
|
||||
}
|
||||
|
||||
// makeWin32File makes a new win32File from an existing file handle.
|
||||
func makeWin32File(h windows.Handle) (*win32File, error) {
|
||||
f := &win32File{handle: h}
|
||||
ioInitOnce.Do(initIO)
|
||||
_, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = setFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f.readDeadline.channel = make(timeoutChan)
|
||||
f.writeDeadline.channel = make(timeoutChan)
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// Deprecated: use NewOpenFile instead.
|
||||
func MakeOpenFile(h syscall.Handle) (io.ReadWriteCloser, error) {
|
||||
return NewOpenFile(windows.Handle(h))
|
||||
}
|
||||
|
||||
func NewOpenFile(h windows.Handle) (io.ReadWriteCloser, error) {
|
||||
// If we return the result of makeWin32File directly, it can result in an
|
||||
// interface-wrapped nil, rather than a nil interface value.
|
||||
f, err := makeWin32File(h)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return f, nil
|
||||
}

// closeHandle closes the resources associated with a Win32 handle.
func (f *win32File) closeHandle() {
	f.wgLock.Lock()
	// Atomically set that we are closing, releasing the resources only once.
	if !f.closing.Swap(true) {
		f.wgLock.Unlock()
		// cancel all IO and wait for it to complete
		_ = cancelIoEx(f.handle, nil)
		f.wg.Wait()
		// at this point, no new IO can start
		windows.Close(f.handle)
		f.handle = 0
	} else {
		f.wgLock.Unlock()
	}
}

// Close closes a win32File.
func (f *win32File) Close() error {
	f.closeHandle()
	return nil
}

// IsClosed checks if the file has been closed.
func (f *win32File) IsClosed() bool {
	return f.closing.Load()
}

// prepareIO prepares for a new IO operation.
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
func (f *win32File) prepareIO() (*ioOperation, error) {
	f.wgLock.RLock()
	if f.closing.Load() {
		f.wgLock.RUnlock()
		return nil, ErrFileClosed
	}
	f.wg.Add(1)
	f.wgLock.RUnlock()
	c := &ioOperation{}
	c.ch = make(chan ioResult)
	return c, nil
}

// ioCompletionProcessor processes completed async IOs forever.
func ioCompletionProcessor(h windows.Handle) {
	for {
		var bytes uint32
		var key uintptr
		var op *ioOperation
		err := getQueuedCompletionStatus(h, &bytes, &key, &op, windows.INFINITE)
		if op == nil {
			panic(err)
		}
		op.ch <- ioResult{bytes, err}
	}
}

// todo: helsaawy - create an asyncIO version that takes a context

// asyncIO processes the return value from ReadFile or WriteFile, blocking until
// the operation has actually completed.
func (f *win32File) asyncIO(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
	if err != windows.ERROR_IO_PENDING { //nolint:errorlint // err is Errno
		return int(bytes), err
	}

	if f.closing.Load() {
		_ = cancelIoEx(f.handle, &c.o)
	}

	var timeout timeoutChan
	if d != nil {
		d.channelLock.Lock()
		timeout = d.channel
		d.channelLock.Unlock()
	}

	var r ioResult
	select {
	case r = <-c.ch:
		err = r.err
		if err == windows.ERROR_OPERATION_ABORTED { //nolint:errorlint // err is Errno
			if f.closing.Load() {
				err = ErrFileClosed
			}
		} else if err != nil && f.socket {
			// err is from Win32. Query the overlapped structure to get the winsock error.
			var bytes, flags uint32
			err = wsaGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
		}
	case <-timeout:
		_ = cancelIoEx(f.handle, &c.o)
		r = <-c.ch
		err = r.err
		if err == windows.ERROR_OPERATION_ABORTED { //nolint:errorlint // err is Errno
			err = ErrTimeout
		}
	}

	// runtime.KeepAlive is needed, as c is passed via native
	// code to ioCompletionProcessor, c must remain alive
	// until the channel read is complete.
	// todo: (de)allocate *ioOperation via win32 heap functions, instead of needing to KeepAlive?
	runtime.KeepAlive(c)
	return int(r.bytes), err
}

// Read reads from a file handle.
func (f *win32File) Read(b []byte) (int, error) {
	c, err := f.prepareIO()
	if err != nil {
		return 0, err
	}
	defer f.wg.Done()

	if f.readDeadline.timedout.Load() {
		return 0, ErrTimeout
	}

	var bytes uint32
	err = windows.ReadFile(f.handle, b, &bytes, &c.o)
	n, err := f.asyncIO(c, &f.readDeadline, bytes, err)
	runtime.KeepAlive(b)

	// Handle EOF conditions.
	if err == nil && n == 0 && len(b) != 0 {
		return 0, io.EOF
	} else if err == windows.ERROR_BROKEN_PIPE { //nolint:errorlint // err is Errno
		return 0, io.EOF
	}
	return n, err
}

// Write writes to a file handle.
func (f *win32File) Write(b []byte) (int, error) {
	c, err := f.prepareIO()
	if err != nil {
		return 0, err
	}
	defer f.wg.Done()

	if f.writeDeadline.timedout.Load() {
		return 0, ErrTimeout
	}

	var bytes uint32
	err = windows.WriteFile(f.handle, b, &bytes, &c.o)
	n, err := f.asyncIO(c, &f.writeDeadline, bytes, err)
	runtime.KeepAlive(b)
	return n, err
}

func (f *win32File) SetReadDeadline(deadline time.Time) error {
	return f.readDeadline.set(deadline)
}

func (f *win32File) SetWriteDeadline(deadline time.Time) error {
	return f.writeDeadline.set(deadline)
}

func (f *win32File) Flush() error {
	return windows.FlushFileBuffers(f.handle)
}

func (f *win32File) Fd() uintptr {
	return uintptr(f.handle)
}

func (d *deadlineHandler) set(deadline time.Time) error {
	d.setLock.Lock()
	defer d.setLock.Unlock()

	if d.timer != nil {
		if !d.timer.Stop() {
			<-d.channel
		}
		d.timer = nil
	}
	d.timedout.Store(false)

	select {
	case <-d.channel:
		d.channelLock.Lock()
		d.channel = make(chan struct{})
		d.channelLock.Unlock()
	default:
	}

	if deadline.IsZero() {
		return nil
	}

	timeoutIO := func() {
		d.timedout.Store(true)
		close(d.channel)
	}

	now := time.Now()
	duration := deadline.Sub(now)
	if deadline.After(now) {
		// Deadline is in the future, set a timer to wait
		d.timer = time.AfterFunc(duration, timeoutIO)
	} else {
		// Deadline is in the past. Cancel all pending IO now.
		timeoutIO()
	}
	return nil
}
106 vendor/github.com/Microsoft/go-winio/fileinfo.go generated vendored Normal file
@@ -0,0 +1,106 @@
//go:build windows
// +build windows

package winio

import (
	"os"
	"runtime"
	"unsafe"

	"golang.org/x/sys/windows"
)

// FileBasicInfo contains file access time and file attributes information.
type FileBasicInfo struct {
	CreationTime, LastAccessTime, LastWriteTime, ChangeTime windows.Filetime
	FileAttributes                                          uint32
	_                                                       uint32 // padding
}

// alignedFileBasicInfo is a FileBasicInfo, but aligned to uint64 by containing
// uint64 rather than windows.Filetime. Filetime contains two uint32s. uint64
// alignment is necessary to pass this as FILE_BASIC_INFO.
type alignedFileBasicInfo struct {
	CreationTime, LastAccessTime, LastWriteTime, ChangeTime uint64
	FileAttributes                                          uint32
	_                                                       uint32 // padding
}

// GetFileBasicInfo retrieves times and attributes for a file.
func GetFileBasicInfo(f *os.File) (*FileBasicInfo, error) {
	bi := &alignedFileBasicInfo{}
	if err := windows.GetFileInformationByHandleEx(
		windows.Handle(f.Fd()),
		windows.FileBasicInfo,
		(*byte)(unsafe.Pointer(bi)),
		uint32(unsafe.Sizeof(*bi)),
	); err != nil {
		return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err}
	}
	runtime.KeepAlive(f)
	// Reinterpret the alignedFileBasicInfo as a FileBasicInfo so it matches the
	// public API of this module. The data may be unnecessarily aligned.
	return (*FileBasicInfo)(unsafe.Pointer(bi)), nil
}

// SetFileBasicInfo sets times and attributes for a file.
func SetFileBasicInfo(f *os.File, bi *FileBasicInfo) error {
	// Create an alignedFileBasicInfo based on a FileBasicInfo. The copy is
	// suitable to pass to SetFileInformationByHandle.
	biAligned := *(*alignedFileBasicInfo)(unsafe.Pointer(bi))
	if err := windows.SetFileInformationByHandle(
		windows.Handle(f.Fd()),
		windows.FileBasicInfo,
		(*byte)(unsafe.Pointer(&biAligned)),
		uint32(unsafe.Sizeof(biAligned)),
	); err != nil {
		return &os.PathError{Op: "SetFileInformationByHandle", Path: f.Name(), Err: err}
	}
	runtime.KeepAlive(f)
	return nil
}

// FileStandardInfo contains extended information for the file.
// FILE_STANDARD_INFO in WinBase.h
// https://docs.microsoft.com/en-us/windows/win32/api/winbase/ns-winbase-file_standard_info
type FileStandardInfo struct {
	AllocationSize, EndOfFile int64
	NumberOfLinks             uint32
	DeletePending, Directory  bool
}

// GetFileStandardInfo retrieves extended information for the file.
func GetFileStandardInfo(f *os.File) (*FileStandardInfo, error) {
	si := &FileStandardInfo{}
	if err := windows.GetFileInformationByHandleEx(windows.Handle(f.Fd()),
		windows.FileStandardInfo,
		(*byte)(unsafe.Pointer(si)),
		uint32(unsafe.Sizeof(*si))); err != nil {
		return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err}
	}
	runtime.KeepAlive(f)
	return si, nil
}

// FileIDInfo contains the volume serial number and file ID for a file. This pair should be
// unique on a system.
type FileIDInfo struct {
	VolumeSerialNumber uint64
	FileID             [16]byte
}

// GetFileID retrieves the unique (volume, file ID) pair for a file.
func GetFileID(f *os.File) (*FileIDInfo, error) {
	fileID := &FileIDInfo{}
	if err := windows.GetFileInformationByHandleEx(
		windows.Handle(f.Fd()),
		windows.FileIdInfo,
		(*byte)(unsafe.Pointer(fileID)),
		uint32(unsafe.Sizeof(*fileID)),
	); err != nil {
		return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err}
	}
	runtime.KeepAlive(f)
	return fileID, nil
}
582 vendor/github.com/Microsoft/go-winio/hvsock.go generated vendored Normal file
@@ -0,0 +1,582 @@
//go:build windows
// +build windows

package winio

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"time"
	"unsafe"

	"golang.org/x/sys/windows"

	"github.com/Microsoft/go-winio/internal/socket"
	"github.com/Microsoft/go-winio/pkg/guid"
)

const afHVSock = 34 // AF_HYPERV

// Well known Service and VM IDs
// https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service#vmid-wildcards

// HvsockGUIDWildcard is the wildcard VmId for accepting connections from all partitions.
func HvsockGUIDWildcard() guid.GUID { // 00000000-0000-0000-0000-000000000000
	return guid.GUID{}
}

// HvsockGUIDBroadcast is the wildcard VmId for broadcasting sends to all partitions.
func HvsockGUIDBroadcast() guid.GUID { // ffffffff-ffff-ffff-ffff-ffffffffffff
	return guid.GUID{
		Data1: 0xffffffff,
		Data2: 0xffff,
		Data3: 0xffff,
		Data4: [8]uint8{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
	}
}

// HvsockGUIDLoopback is the Loopback VmId for accepting connections to the same partition as the connector.
func HvsockGUIDLoopback() guid.GUID { // e0e16197-dd56-4a10-9195-5ee7a155a838
	return guid.GUID{
		Data1: 0xe0e16197,
		Data2: 0xdd56,
		Data3: 0x4a10,
		Data4: [8]uint8{0x91, 0x95, 0x5e, 0xe7, 0xa1, 0x55, 0xa8, 0x38},
	}
}

// HvsockGUIDSiloHost is the address of a silo's host partition:
//   - The silo host of a hosted silo is the utility VM.
//   - The silo host of a silo on a physical host is the physical host.
func HvsockGUIDSiloHost() guid.GUID { // 36bd0c5c-7276-4223-88ba-7d03b654c568
	return guid.GUID{
		Data1: 0x36bd0c5c,
		Data2: 0x7276,
		Data3: 0x4223,
		Data4: [8]byte{0x88, 0xba, 0x7d, 0x03, 0xb6, 0x54, 0xc5, 0x68},
	}
}

// HvsockGUIDChildren is the wildcard VmId for accepting connections from the connector's child partitions.
func HvsockGUIDChildren() guid.GUID { // 90db8b89-0d35-4f79-8ce9-49ea0ac8b7cd
	return guid.GUID{
		Data1: 0x90db8b89,
		Data2: 0xd35,
		Data3: 0x4f79,
		Data4: [8]uint8{0x8c, 0xe9, 0x49, 0xea, 0xa, 0xc8, 0xb7, 0xcd},
	}
}

// HvsockGUIDParent is the wildcard VmId for accepting connections from the connector's parent partition.
// Listening on this VmId accepts connection from:
//   - Inside silos: silo host partition.
//   - Inside hosted silo: host of the VM.
//   - Inside VM: VM host.
//   - Physical host: Not supported.
func HvsockGUIDParent() guid.GUID { // a42e7cda-d03f-480c-9cc2-a4de20abb878
	return guid.GUID{
		Data1: 0xa42e7cda,
		Data2: 0xd03f,
		Data3: 0x480c,
		Data4: [8]uint8{0x9c, 0xc2, 0xa4, 0xde, 0x20, 0xab, 0xb8, 0x78},
	}
}

// hvsockVsockServiceTemplate is the Service GUID used for the VSOCK protocol.
func hvsockVsockServiceTemplate() guid.GUID { // 00000000-facb-11e6-bd58-64006a7986d3
	return guid.GUID{
		Data2: 0xfacb,
		Data3: 0x11e6,
		Data4: [8]uint8{0xbd, 0x58, 0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3},
	}
}

// An HvsockAddr is an address for an AF_HYPERV socket.
type HvsockAddr struct {
	VMID      guid.GUID
	ServiceID guid.GUID
}

type rawHvsockAddr struct {
	Family    uint16
	_         uint16
	VMID      guid.GUID
	ServiceID guid.GUID
}

var _ socket.RawSockaddr = &rawHvsockAddr{}

// Network returns the address's network name, "hvsock".
func (*HvsockAddr) Network() string {
	return "hvsock"
}

func (addr *HvsockAddr) String() string {
	return fmt.Sprintf("%s:%s", &addr.VMID, &addr.ServiceID)
}

// VsockServiceID returns an hvsock service ID corresponding to the specified AF_VSOCK port.
func VsockServiceID(port uint32) guid.GUID {
	g := hvsockVsockServiceTemplate() // make a copy
	g.Data1 = port
	return g
}

func (addr *HvsockAddr) raw() rawHvsockAddr {
	return rawHvsockAddr{
		Family:    afHVSock,
		VMID:      addr.VMID,
		ServiceID: addr.ServiceID,
	}
}

func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) {
	addr.VMID = raw.VMID
	addr.ServiceID = raw.ServiceID
}

// Sockaddr returns a pointer to and the size of this struct.
//
// Implements the [socket.RawSockaddr] interface, and allows use in
// [socket.Bind] and [socket.ConnectEx].
func (r *rawHvsockAddr) Sockaddr() (unsafe.Pointer, int32, error) {
	return unsafe.Pointer(r), int32(unsafe.Sizeof(rawHvsockAddr{})), nil
}

// Sockaddr interface allows use with `sockets.Bind()` and `.ConnectEx()`.
func (r *rawHvsockAddr) FromBytes(b []byte) error {
	n := int(unsafe.Sizeof(rawHvsockAddr{}))

	if len(b) < n {
		return fmt.Errorf("got %d, want %d: %w", len(b), n, socket.ErrBufferSize)
	}

	copy(unsafe.Slice((*byte)(unsafe.Pointer(r)), n), b[:n])
	if r.Family != afHVSock {
		return fmt.Errorf("got %d, want %d: %w", r.Family, afHVSock, socket.ErrAddrFamily)
	}

	return nil
}

// HvsockListener is a socket listener for the AF_HYPERV address family.
type HvsockListener struct {
	sock *win32File
	addr HvsockAddr
}

var _ net.Listener = &HvsockListener{}

// HvsockConn is a connected socket of the AF_HYPERV address family.
type HvsockConn struct {
	sock          *win32File
	local, remote HvsockAddr
}

var _ net.Conn = &HvsockConn{}

func newHVSocket() (*win32File, error) {
	fd, err := windows.Socket(afHVSock, windows.SOCK_STREAM, 1)
	if err != nil {
		return nil, os.NewSyscallError("socket", err)
	}
	f, err := makeWin32File(fd)
	if err != nil {
		windows.Close(fd)
		return nil, err
	}
	f.socket = true
	return f, nil
}

// ListenHvsock listens for connections on the specified hvsock address.
func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) {
	l := &HvsockListener{addr: *addr}

	var sock *win32File
	sock, err = newHVSocket()
	if err != nil {
		return nil, l.opErr("listen", err)
	}
	defer func() {
		if err != nil {
			_ = sock.Close()
		}
	}()

	sa := addr.raw()
	err = socket.Bind(sock.handle, &sa)
	if err != nil {
		return nil, l.opErr("listen", os.NewSyscallError("socket", err))
	}
	err = windows.Listen(sock.handle, 16)
	if err != nil {
		return nil, l.opErr("listen", os.NewSyscallError("listen", err))
	}
	return &HvsockListener{sock: sock, addr: *addr}, nil
}

func (l *HvsockListener) opErr(op string, err error) error {
	return &net.OpError{Op: op, Net: "hvsock", Addr: &l.addr, Err: err}
}

// Addr returns the listener's network address.
func (l *HvsockListener) Addr() net.Addr {
	return &l.addr
}

// Accept waits for the next connection and returns it.
func (l *HvsockListener) Accept() (_ net.Conn, err error) {
	sock, err := newHVSocket()
	if err != nil {
		return nil, l.opErr("accept", err)
	}
	defer func() {
		if sock != nil {
			sock.Close()
		}
	}()
	c, err := l.sock.prepareIO()
	if err != nil {
		return nil, l.opErr("accept", err)
	}
	defer l.sock.wg.Done()

	// AcceptEx, per documentation, requires an extra 16 bytes per address.
	//
	// https://docs.microsoft.com/en-us/windows/win32/api/mswsock/nf-mswsock-acceptex
	const addrlen = uint32(16 + unsafe.Sizeof(rawHvsockAddr{}))
	var addrbuf [addrlen * 2]byte

	var bytes uint32
	err = windows.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0 /* rxdatalen */, addrlen, addrlen, &bytes, &c.o)
	if _, err = l.sock.asyncIO(c, nil, bytes, err); err != nil {
		return nil, l.opErr("accept", os.NewSyscallError("acceptex", err))
	}

	conn := &HvsockConn{
		sock: sock,
	}
	// The local address returned in the AcceptEx buffer is the same as the Listener socket's
	// address. However, the service GUID reported by GetSockName is different from the Listeners
	// socket, and is sometimes the same as the local address of the socket that dialed the
	// address, with the service GUID.Data1 incremented, but other times is different.
	// todo: does the local address matter? is the listener's address or the actual address appropriate?
	conn.local.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[0])))
	conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen])))

	// initialize the accepted socket and update its properties with those of the listening socket
	if err = windows.Setsockopt(sock.handle,
		windows.SOL_SOCKET, windows.SO_UPDATE_ACCEPT_CONTEXT,
		(*byte)(unsafe.Pointer(&l.sock.handle)), int32(unsafe.Sizeof(l.sock.handle))); err != nil {
		return nil, conn.opErr("accept", os.NewSyscallError("setsockopt", err))
	}

	sock = nil
	return conn, nil
}

// Close closes the listener, causing any pending Accept calls to fail.
func (l *HvsockListener) Close() error {
	return l.sock.Close()
}

// HvsockDialer configures and dials a Hyper-V Socket (ie, [HvsockConn]).
type HvsockDialer struct {
	// Deadline is the time the Dial operation must connect before erroring.
	Deadline time.Time

	// Retries is the number of additional connects to try if the connection times out, is refused,
	// or the host is unreachable
	Retries uint

	// RetryWait is the time to wait after a connection error to retry
	RetryWait time.Duration

	rt *time.Timer // redial wait timer
}

// Dial the Hyper-V socket at addr.
//
// See [HvsockDialer.Dial] for more information.
func Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) {
	return (&HvsockDialer{}).Dial(ctx, addr)
}

// Dial attempts to connect to the Hyper-V socket at addr, and returns a connection if successful.
// Will attempt (HvsockDialer).Retries if dialing fails, waiting (HvsockDialer).RetryWait between
// retries.
//
// Dialing can be cancelled either by providing (HvsockDialer).Deadline, or cancelling ctx.
func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) {
	op := "dial"
	// create the conn early to use opErr()
	conn = &HvsockConn{
		remote: *addr,
	}

	if !d.Deadline.IsZero() {
		var cancel context.CancelFunc
		ctx, cancel = context.WithDeadline(ctx, d.Deadline)
		defer cancel()
	}

	// preemptive timeout/cancellation check
	if err = ctx.Err(); err != nil {
		return nil, conn.opErr(op, err)
	}

	sock, err := newHVSocket()
	if err != nil {
		return nil, conn.opErr(op, err)
	}
	defer func() {
		if sock != nil {
			sock.Close()
		}
	}()

	sa := addr.raw()
	err = socket.Bind(sock.handle, &sa)
	if err != nil {
		return nil, conn.opErr(op, os.NewSyscallError("bind", err))
	}

	c, err := sock.prepareIO()
	if err != nil {
		return nil, conn.opErr(op, err)
	}
	defer sock.wg.Done()
	var bytes uint32
	for i := uint(0); i <= d.Retries; i++ {
		err = socket.ConnectEx(
			sock.handle,
			&sa,
			nil, // sendBuf
			0,   // sendDataLen
			&bytes,
			(*windows.Overlapped)(unsafe.Pointer(&c.o)))
		_, err = sock.asyncIO(c, nil, bytes, err)
		if i < d.Retries && canRedial(err) {
			if err = d.redialWait(ctx); err == nil {
				continue
			}
		}
		break
	}
	if err != nil {
		return nil, conn.opErr(op, os.NewSyscallError("connectex", err))
	}

	// update the connection properties, so shutdown can be used
	if err = windows.Setsockopt(
		sock.handle,
		windows.SOL_SOCKET,
		windows.SO_UPDATE_CONNECT_CONTEXT,
		nil, // optvalue
		0,   // optlen
	); err != nil {
		return nil, conn.opErr(op, os.NewSyscallError("setsockopt", err))
	}

	// get the local name
	var sal rawHvsockAddr
	err = socket.GetSockName(sock.handle, &sal)
	if err != nil {
		return nil, conn.opErr(op, os.NewSyscallError("getsockname", err))
	}
	conn.local.fromRaw(&sal)

	// one last check for timeout, since asyncIO doesn't check the context
	if err = ctx.Err(); err != nil {
		return nil, conn.opErr(op, err)
	}

	conn.sock = sock
	sock = nil

	return conn, nil
}

// redialWait waits before attempting to redial, resetting the timer as appropriate.
func (d *HvsockDialer) redialWait(ctx context.Context) (err error) {
	if d.RetryWait == 0 {
		return nil
	}

	if d.rt == nil {
		d.rt = time.NewTimer(d.RetryWait)
	} else {
		// should already be stopped and drained
		d.rt.Reset(d.RetryWait)
	}

	select {
	case <-ctx.Done():
	case <-d.rt.C:
		return nil
	}

	// stop and drain the timer
	if !d.rt.Stop() {
		<-d.rt.C
	}
	return ctx.Err()
}

// assumes error is a plain, unwrapped windows.Errno provided by direct syscall.
func canRedial(err error) bool {
	//nolint:errorlint // guaranteed to be an Errno
	switch err {
	case windows.WSAECONNREFUSED, windows.WSAENETUNREACH, windows.WSAETIMEDOUT,
		windows.ERROR_CONNECTION_REFUSED, windows.ERROR_CONNECTION_UNAVAIL:
		return true
	default:
		return false
	}
}

func (conn *HvsockConn) opErr(op string, err error) error {
	// translate from "file closed" to "socket closed"
	if errors.Is(err, ErrFileClosed) {
		err = socket.ErrSocketClosed
	}
	return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err}
}

func (conn *HvsockConn) Read(b []byte) (int, error) {
	c, err := conn.sock.prepareIO()
	if err != nil {
		return 0, conn.opErr("read", err)
	}
	defer conn.sock.wg.Done()
	buf := windows.WSABuf{Buf: &b[0], Len: uint32(len(b))}
	var flags, bytes uint32
	err = windows.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil)
	n, err := conn.sock.asyncIO(c, &conn.sock.readDeadline, bytes, err)
	if err != nil {
		var eno windows.Errno
		if errors.As(err, &eno) {
			err = os.NewSyscallError("wsarecv", eno)
		}
		return 0, conn.opErr("read", err)
	} else if n == 0 {
		err = io.EOF
	}
	return n, err
}

func (conn *HvsockConn) Write(b []byte) (int, error) {
	t := 0
	for len(b) != 0 {
		n, err := conn.write(b)
		if err != nil {
			return t + n, err
		}
		t += n
		b = b[n:]
	}
	return t, nil
}

func (conn *HvsockConn) write(b []byte) (int, error) {
	c, err := conn.sock.prepareIO()
	if err != nil {
		return 0, conn.opErr("write", err)
	}
	defer conn.sock.wg.Done()
	buf := windows.WSABuf{Buf: &b[0], Len: uint32(len(b))}
	var bytes uint32
	err = windows.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil)
	n, err := conn.sock.asyncIO(c, &conn.sock.writeDeadline, bytes, err)
	if err != nil {
		var eno windows.Errno
		if errors.As(err, &eno) {
			err = os.NewSyscallError("wsasend", eno)
		}
		return 0, conn.opErr("write", err)
	}
	return n, err
}

// Close closes the socket connection, failing any pending read or write calls.
func (conn *HvsockConn) Close() error {
	return conn.sock.Close()
}

func (conn *HvsockConn) IsClosed() bool {
	return conn.sock.IsClosed()
}

// shutdown disables sending or receiving on a socket.
func (conn *HvsockConn) shutdown(how int) error {
	if conn.IsClosed() {
		return socket.ErrSocketClosed
	}

	err := windows.Shutdown(conn.sock.handle, how)
	if err != nil {
		// If the connection was closed, shutdowns fail with "not connected"
		if errors.Is(err, windows.WSAENOTCONN) ||
			errors.Is(err, windows.WSAESHUTDOWN) {
			err = socket.ErrSocketClosed
		}
		return os.NewSyscallError("shutdown", err)
	}
	return nil
}

// CloseRead shuts down the read end of the socket, preventing future read operations.
func (conn *HvsockConn) CloseRead() error {
	err := conn.shutdown(windows.SHUT_RD)
	if err != nil {
		return conn.opErr("closeread", err)
	}
	return nil
}

// CloseWrite shuts down the write end of the socket, preventing future write operations and
// notifying the other endpoint that no more data will be written.
func (conn *HvsockConn) CloseWrite() error {
	err := conn.shutdown(windows.SHUT_WR)
	if err != nil {
		return conn.opErr("closewrite", err)
	}
	return nil
}

// LocalAddr returns the local address of the connection.
func (conn *HvsockConn) LocalAddr() net.Addr {
	return &conn.local
}

// RemoteAddr returns the remote address of the connection.
func (conn *HvsockConn) RemoteAddr() net.Addr {
	return &conn.remote
}

// SetDeadline implements the net.Conn SetDeadline method.
func (conn *HvsockConn) SetDeadline(t time.Time) error {
	// todo: implement `SetDeadline` for `win32File`
	if err := conn.SetReadDeadline(t); err != nil {
		return fmt.Errorf("set read deadline: %w", err)
	}
	if err := conn.SetWriteDeadline(t); err != nil {
		return fmt.Errorf("set write deadline: %w", err)
	}
	return nil
}

// SetReadDeadline implements the net.Conn SetReadDeadline method.
func (conn *HvsockConn) SetReadDeadline(t time.Time) error {
	return conn.sock.SetReadDeadline(t)
}

// SetWriteDeadline implements the net.Conn SetWriteDeadline method.
func (conn *HvsockConn) SetWriteDeadline(t time.Time) error {
	return conn.sock.SetWriteDeadline(t)
}
2 vendor/github.com/Microsoft/go-winio/internal/fs/doc.go generated vendored Normal file
@@ -0,0 +1,2 @@
// This package contains Win32 filesystem functionality.
package fs
262 vendor/github.com/Microsoft/go-winio/internal/fs/fs.go generated vendored Normal file
@@ -0,0 +1,262 @@
|
||||
//go:build windows
|
||||
|
||||
package fs
|
||||
|
||||
import (
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/Microsoft/go-winio/internal/stringbuffer"
|
||||
)
|
||||
|
||||
//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go fs.go
|
||||
|
||||
// https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew
|
||||
//sys CreateFile(name string, access AccessMask, mode FileShareMode, sa *windows.SecurityAttributes, createmode FileCreationDisposition, attrs FileFlagOrAttribute, templatefile windows.Handle) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateFileW
|
||||
|
||||
const NullHandle windows.Handle = 0
|
||||
|
||||
// AccessMask defines standard, specific, and generic rights.
|
||||
//
|
||||
// Used with CreateFile and NtCreateFile (and co.).
|
||||
//
|
||||
// Bitmask:
|
||||
// 3 3 2 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1
|
||||
// 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0
|
||||
// +---------------+---------------+-------------------------------+
|
||||
// |G|G|G|G|Resvd|A| StandardRights| SpecificRights |
|
||||
// |R|W|E|A| |S| | |
|
||||
// +-+-------------+---------------+-------------------------------+
|
||||
//
|
||||
// GR Generic Read
|
||||
// GW Generic Write
|
||||
// GE Generic Exectue
|
||||
// GA Generic All
|
||||
// Resvd Reserved
|
||||
// AS Access Security System
|
||||
//
|
||||
// https://learn.microsoft.com/en-us/windows/win32/secauthz/access-mask
|
||||
//
|
||||
// https://learn.microsoft.com/en-us/windows/win32/secauthz/generic-access-rights
|
||||
//
|
||||
// https://learn.microsoft.com/en-us/windows/win32/fileio/file-access-rights-constants
|
||||
type AccessMask = windows.ACCESS_MASK

//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
const (
	// Not actually any.
	//
	// For CreateFile: "query certain metadata such as file, directory, or device attributes without accessing that file or device"
	// https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew#parameters
	FILE_ANY_ACCESS AccessMask = 0

	GENERIC_READ           AccessMask = 0x8000_0000
	GENERIC_WRITE          AccessMask = 0x4000_0000
	GENERIC_EXECUTE        AccessMask = 0x2000_0000
	GENERIC_ALL            AccessMask = 0x1000_0000
	ACCESS_SYSTEM_SECURITY AccessMask = 0x0100_0000

	// Specific Object Access
	// from ntioapi.h

	FILE_READ_DATA      AccessMask = (0x0001) // file & pipe
	FILE_LIST_DIRECTORY AccessMask = (0x0001) // directory

	FILE_WRITE_DATA AccessMask = (0x0002) // file & pipe
	FILE_ADD_FILE   AccessMask = (0x0002) // directory

	FILE_APPEND_DATA          AccessMask = (0x0004) // file
	FILE_ADD_SUBDIRECTORY     AccessMask = (0x0004) // directory
	FILE_CREATE_PIPE_INSTANCE AccessMask = (0x0004) // named pipe

	FILE_READ_EA         AccessMask = (0x0008) // file & directory
	FILE_READ_PROPERTIES AccessMask = FILE_READ_EA

	FILE_WRITE_EA         AccessMask = (0x0010) // file & directory
	FILE_WRITE_PROPERTIES AccessMask = FILE_WRITE_EA

	FILE_EXECUTE  AccessMask = (0x0020) // file
	FILE_TRAVERSE AccessMask = (0x0020) // directory

	FILE_DELETE_CHILD AccessMask = (0x0040) // directory

	FILE_READ_ATTRIBUTES AccessMask = (0x0080) // all

	FILE_WRITE_ATTRIBUTES AccessMask = (0x0100) // all

	FILE_ALL_ACCESS      AccessMask = (STANDARD_RIGHTS_REQUIRED | SYNCHRONIZE | 0x1FF)
	FILE_GENERIC_READ    AccessMask = (STANDARD_RIGHTS_READ | FILE_READ_DATA | FILE_READ_ATTRIBUTES | FILE_READ_EA | SYNCHRONIZE)
	FILE_GENERIC_WRITE   AccessMask = (STANDARD_RIGHTS_WRITE | FILE_WRITE_DATA | FILE_WRITE_ATTRIBUTES | FILE_WRITE_EA | FILE_APPEND_DATA | SYNCHRONIZE)
	FILE_GENERIC_EXECUTE AccessMask = (STANDARD_RIGHTS_EXECUTE | FILE_READ_ATTRIBUTES | FILE_EXECUTE | SYNCHRONIZE)

	SPECIFIC_RIGHTS_ALL AccessMask = 0x0000FFFF

	// Standard Access
	// from ntseapi.h

	DELETE       AccessMask = 0x0001_0000
	READ_CONTROL AccessMask = 0x0002_0000
	WRITE_DAC    AccessMask = 0x0004_0000
	WRITE_OWNER  AccessMask = 0x0008_0000
	SYNCHRONIZE  AccessMask = 0x0010_0000

	STANDARD_RIGHTS_REQUIRED AccessMask = 0x000F_0000

	STANDARD_RIGHTS_READ    AccessMask = READ_CONTROL
	STANDARD_RIGHTS_WRITE   AccessMask = READ_CONTROL
	STANDARD_RIGHTS_EXECUTE AccessMask = READ_CONTROL

	STANDARD_RIGHTS_ALL AccessMask = 0x001F_0000
)
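The composite `FILE_GENERIC_*` values above are plain bitwise ORs of standard and specific rights. A minimal standalone sketch (constants copied locally from the table above rather than importing the package, so it runs on any platform) showing how a composite mask decomposes:

```go
package main

import "fmt"

// Constants copied from the AccessMask table above for a portable illustration.
const (
	READ_CONTROL         = 0x0002_0000
	SYNCHRONIZE          = 0x0010_0000
	FILE_READ_DATA       = 0x0001
	FILE_READ_ATTRIBUTES = 0x0080
	FILE_READ_EA         = 0x0008
	STANDARD_RIGHTS_READ = READ_CONTROL
	FILE_GENERIC_READ    = STANDARD_RIGHTS_READ | FILE_READ_DATA | FILE_READ_ATTRIBUTES | FILE_READ_EA | SYNCHRONIZE
)

func main() {
	// A composite mask contains a specific right iff the AND is non-zero.
	fmt.Println(FILE_GENERIC_READ&FILE_READ_DATA != 0) // true
	fmt.Printf("0x%08X\n", FILE_GENERIC_READ)          // 0x00120089
}
```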

type FileShareMode uint32

//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
const (
	FILE_SHARE_NONE        FileShareMode = 0x00
	FILE_SHARE_READ        FileShareMode = 0x01
	FILE_SHARE_WRITE       FileShareMode = 0x02
	FILE_SHARE_DELETE      FileShareMode = 0x04
	FILE_SHARE_VALID_FLAGS FileShareMode = 0x07
)

type FileCreationDisposition uint32

//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
const (
	// from winbase.h

	CREATE_NEW        FileCreationDisposition = 0x01
	CREATE_ALWAYS     FileCreationDisposition = 0x02
	OPEN_EXISTING     FileCreationDisposition = 0x03
	OPEN_ALWAYS       FileCreationDisposition = 0x04
	TRUNCATE_EXISTING FileCreationDisposition = 0x05
)

// Create disposition values for NtCreate*
type NTFileCreationDisposition uint32

//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
const (
	// From ntioapi.h

	FILE_SUPERSEDE           NTFileCreationDisposition = 0x00
	FILE_OPEN                NTFileCreationDisposition = 0x01
	FILE_CREATE              NTFileCreationDisposition = 0x02
	FILE_OPEN_IF             NTFileCreationDisposition = 0x03
	FILE_OVERWRITE           NTFileCreationDisposition = 0x04
	FILE_OVERWRITE_IF        NTFileCreationDisposition = 0x05
	FILE_MAXIMUM_DISPOSITION NTFileCreationDisposition = 0x05
)

// CreateFile and co. take flags or attributes together as one parameter.
// Define alias until we can use generics to allow both.
//
// https://learn.microsoft.com/en-us/windows/win32/fileio/file-attribute-constants
type FileFlagOrAttribute uint32

//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
const (
	// from winnt.h

	FILE_FLAG_WRITE_THROUGH       FileFlagOrAttribute = 0x8000_0000
	FILE_FLAG_OVERLAPPED          FileFlagOrAttribute = 0x4000_0000
	FILE_FLAG_NO_BUFFERING        FileFlagOrAttribute = 0x2000_0000
	FILE_FLAG_RANDOM_ACCESS       FileFlagOrAttribute = 0x1000_0000
	FILE_FLAG_SEQUENTIAL_SCAN     FileFlagOrAttribute = 0x0800_0000
	FILE_FLAG_DELETE_ON_CLOSE     FileFlagOrAttribute = 0x0400_0000
	FILE_FLAG_BACKUP_SEMANTICS    FileFlagOrAttribute = 0x0200_0000
	FILE_FLAG_POSIX_SEMANTICS     FileFlagOrAttribute = 0x0100_0000
	FILE_FLAG_OPEN_REPARSE_POINT  FileFlagOrAttribute = 0x0020_0000
	FILE_FLAG_OPEN_NO_RECALL      FileFlagOrAttribute = 0x0010_0000
	FILE_FLAG_FIRST_PIPE_INSTANCE FileFlagOrAttribute = 0x0008_0000
)

// NtCreate* functions take a dedicated CreateOptions parameter.
//
// https://learn.microsoft.com/en-us/windows/win32/api/Winternl/nf-winternl-ntcreatefile
//
// https://learn.microsoft.com/en-us/windows/win32/devnotes/nt-create-named-pipe-file
type NTCreateOptions uint32

//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
const (
	// From ntioapi.h

	FILE_DIRECTORY_FILE            NTCreateOptions = 0x0000_0001
	FILE_WRITE_THROUGH             NTCreateOptions = 0x0000_0002
	FILE_SEQUENTIAL_ONLY           NTCreateOptions = 0x0000_0004
	FILE_NO_INTERMEDIATE_BUFFERING NTCreateOptions = 0x0000_0008

	FILE_SYNCHRONOUS_IO_ALERT    NTCreateOptions = 0x0000_0010
	FILE_SYNCHRONOUS_IO_NONALERT NTCreateOptions = 0x0000_0020
	FILE_NON_DIRECTORY_FILE      NTCreateOptions = 0x0000_0040
	FILE_CREATE_TREE_CONNECTION  NTCreateOptions = 0x0000_0080

	FILE_COMPLETE_IF_OPLOCKED NTCreateOptions = 0x0000_0100
	FILE_NO_EA_KNOWLEDGE      NTCreateOptions = 0x0000_0200
	FILE_DISABLE_TUNNELING    NTCreateOptions = 0x0000_0400
	FILE_RANDOM_ACCESS        NTCreateOptions = 0x0000_0800

	FILE_DELETE_ON_CLOSE        NTCreateOptions = 0x0000_1000
	FILE_OPEN_BY_FILE_ID        NTCreateOptions = 0x0000_2000
	FILE_OPEN_FOR_BACKUP_INTENT NTCreateOptions = 0x0000_4000
	FILE_NO_COMPRESSION         NTCreateOptions = 0x0000_8000
)

type FileSQSFlag = FileFlagOrAttribute

//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
const (
	// from winbase.h

	SECURITY_ANONYMOUS      FileSQSFlag = FileSQSFlag(SecurityAnonymous << 16)
	SECURITY_IDENTIFICATION FileSQSFlag = FileSQSFlag(SecurityIdentification << 16)
	SECURITY_IMPERSONATION  FileSQSFlag = FileSQSFlag(SecurityImpersonation << 16)
	SECURITY_DELEGATION     FileSQSFlag = FileSQSFlag(SecurityDelegation << 16)

	SECURITY_SQOS_PRESENT     FileSQSFlag = 0x0010_0000
	SECURITY_VALID_SQOS_FLAGS FileSQSFlag = 0x001F_0000
)

// GetFinalPathNameByHandle flags
//
// https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-getfinalpathnamebyhandlew#parameters
type GetFinalPathFlag uint32

//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
const (
	GetFinalPathDefaultFlag GetFinalPathFlag = 0x0

	FILE_NAME_NORMALIZED GetFinalPathFlag = 0x0
	FILE_NAME_OPENED     GetFinalPathFlag = 0x8

	VOLUME_NAME_DOS  GetFinalPathFlag = 0x0
	VOLUME_NAME_GUID GetFinalPathFlag = 0x1
	VOLUME_NAME_NT   GetFinalPathFlag = 0x2
	VOLUME_NAME_NONE GetFinalPathFlag = 0x4
)

// GetFinalPathNameByHandle facilitates calling the Windows API GetFinalPathNameByHandle
// with the given handle and flags. It transparently takes care of creating a buffer of the
// correct size for the call.
//
// https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-getfinalpathnamebyhandlew
func GetFinalPathNameByHandle(h windows.Handle, flags GetFinalPathFlag) (string, error) {
	b := stringbuffer.NewWString()
	//TODO: can loop infinitely if Win32 keeps returning the same (or a larger) n?
	for {
		n, err := windows.GetFinalPathNameByHandle(h, b.Pointer(), b.Cap(), uint32(flags))
		if err != nil {
			return "", err
		}
		// If the buffer wasn't large enough, n will be the total size needed (including null terminator).
		// Resize and try again.
		if n > b.Cap() {
			b.ResizeTo(n)
			continue
		}
		// If the buffer is large enough, n will be the size not including the null terminator.
		// Convert to a Go string and return.
		return b.String(), nil
	}
}
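The grow-and-retry loop above relies on a common Win32 convention: when the buffer is too small, the API returns the required size (including the terminator); on success it returns the written length (excluding it). A portable sketch of that pattern, with a hypothetical `fakeAPI` standing in for the real Windows call:

```go
package main

import "fmt"

// fakeAPI mimics the Win32 size-reporting convention for illustration only.
func fakeAPI(want string, buf []uint16) uint32 {
	need := uint32(len(want) + 1) // required size includes the NUL terminator
	if uint32(len(buf)) < need {
		return need // too small: report total size needed
	}
	for i, r := range want {
		buf[i] = uint16(r)
	}
	return uint32(len(want)) // success: length excluding the NUL
}

// callWithRetry resizes and retries until the buffer is large enough,
// mirroring the loop in GetFinalPathNameByHandle.
func callWithRetry(want string) string {
	buf := make([]uint16, 4) // deliberately undersized first attempt
	for {
		n := fakeAPI(want, buf)
		if n > uint32(len(buf)) {
			buf = make([]uint16, n) // resize to the reported size and retry
			continue
		}
		out := make([]rune, n)
		for i := range out {
			out[i] = rune(buf[i])
		}
		return string(out)
	}
}

func main() {
	fmt.Println(callWithRetry(`\\?\C:\Windows`))
}
```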

12 vendor/github.com/Microsoft/go-winio/internal/fs/security.go generated vendored Normal file
@@ -0,0 +1,12 @@
package fs

// https://learn.microsoft.com/en-us/windows/win32/api/winnt/ne-winnt-security_impersonation_level
type SecurityImpersonationLevel int32 // C default enums underlying type is `int`, which is Go `int32`

// Impersonation levels
const (
	SecurityAnonymous      SecurityImpersonationLevel = 0
	SecurityIdentification SecurityImpersonationLevel = 1
	SecurityImpersonation  SecurityImpersonationLevel = 2
	SecurityDelegation     SecurityImpersonationLevel = 3
)

61 vendor/github.com/Microsoft/go-winio/internal/fs/zsyscall_windows.go generated vendored Normal file
@@ -0,0 +1,61 @@
//go:build windows

// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT.

package fs

import (
	"syscall"
	"unsafe"

	"golang.org/x/sys/windows"
)

var _ unsafe.Pointer

// Do the interface allocations only once for common
// Errno values.
const (
	errnoERROR_IO_PENDING = 997
)

var (
	errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
	errERROR_EINVAL     error = syscall.EINVAL
)

// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
	switch e {
	case 0:
		return errERROR_EINVAL
	case errnoERROR_IO_PENDING:
		return errERROR_IO_PENDING
	}
	return e
}

var (
	modkernel32 = windows.NewLazySystemDLL("kernel32.dll")

	procCreateFileW = modkernel32.NewProc("CreateFileW")
)

func CreateFile(name string, access AccessMask, mode FileShareMode, sa *windows.SecurityAttributes, createmode FileCreationDisposition, attrs FileFlagOrAttribute, templatefile windows.Handle) (handle windows.Handle, err error) {
	var _p0 *uint16
	_p0, err = syscall.UTF16PtrFromString(name)
	if err != nil {
		return
	}
	return _CreateFile(_p0, access, mode, sa, createmode, attrs, templatefile)
}

func _CreateFile(name *uint16, access AccessMask, mode FileShareMode, sa *windows.SecurityAttributes, createmode FileCreationDisposition, attrs FileFlagOrAttribute, templatefile windows.Handle) (handle windows.Handle, err error) {
	r0, _, e1 := syscall.SyscallN(procCreateFileW.Addr(), uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile))
	handle = windows.Handle(r0)
	if handle == windows.InvalidHandle {
		err = errnoErr(e1)
	}
	return
}

20 vendor/github.com/Microsoft/go-winio/internal/socket/rawaddr.go generated vendored Normal file
@@ -0,0 +1,20 @@
package socket

import (
	"unsafe"
)

// RawSockaddr allows structs to be used with [Bind] and [ConnectEx]. The
// struct must meet the Win32 sockaddr requirements specified here:
// https://docs.microsoft.com/en-us/windows/win32/winsock/sockaddr-2
//
// Specifically, the struct size must be at least as large as an int16 (unsigned short)
// for the address family.
type RawSockaddr interface {
	// Sockaddr returns a pointer to the RawSockaddr and its struct size, allowing
	// for the RawSockaddr's data to be overwritten by syscalls (if necessary).
	//
	// It is the caller's responsibility to validate that the values are valid; invalid
	// pointers or size can cause a panic.
	Sockaddr() (unsafe.Pointer, int32, error)
}

177 vendor/github.com/Microsoft/go-winio/internal/socket/socket.go generated vendored Normal file
@@ -0,0 +1,177 @@
//go:build windows

package socket

import (
	"errors"
	"fmt"
	"net"
	"sync"
	"syscall"
	"unsafe"

	"github.com/Microsoft/go-winio/pkg/guid"
	"golang.org/x/sys/windows"
)

//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go socket.go

//sys getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getsockname
//sys getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getpeername
//sys bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind

const socketError = uintptr(^uint32(0))

var (
	// todo(helsaawy): create custom error types to store the desired vs actual size and addr family?

	ErrBufferSize     = errors.New("buffer size")
	ErrAddrFamily     = errors.New("address family")
	ErrInvalidPointer = errors.New("invalid pointer")
	ErrSocketClosed   = fmt.Errorf("socket closed: %w", net.ErrClosed)
)

// todo(helsaawy): replace these with generics, ie: GetSockName[S RawSockaddr](s windows.Handle) (S, error)

// GetSockName writes the local address of socket s to the [RawSockaddr] rsa.
// If rsa is not large enough, [windows.WSAEFAULT] is returned.
func GetSockName(s windows.Handle, rsa RawSockaddr) error {
	ptr, l, err := rsa.Sockaddr()
	if err != nil {
		return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
	}

	// although getsockname returns WSAEFAULT if the buffer is too small, it does not set
	// &l to the correct size, so--apart from doubling the buffer repeatedly--there is no remedy
	return getsockname(s, ptr, &l)
}

// GetPeerName returns the remote address the socket is connected to.
//
// See [GetSockName] for more information.
func GetPeerName(s windows.Handle, rsa RawSockaddr) error {
	ptr, l, err := rsa.Sockaddr()
	if err != nil {
		return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
	}

	return getpeername(s, ptr, &l)
}

func Bind(s windows.Handle, rsa RawSockaddr) (err error) {
	ptr, l, err := rsa.Sockaddr()
	if err != nil {
		return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
	}

	return bind(s, ptr, l)
}

// "golang.org/x/sys/windows".ConnectEx and .Bind only accept internal implementations of
// their sockaddr interface, so they cannot be used with HvsockAddr.
// Replicate functionality here from
// https://cs.opensource.google/go/x/sys/+/master:windows/syscall_windows.go

// The function pointers to `AcceptEx`, `ConnectEx` and `GetAcceptExSockaddrs` must be loaded at
// runtime via a WSAIoctl call:
// https://docs.microsoft.com/en-us/windows/win32/api/Mswsock/nc-mswsock-lpfn_connectex#remarks

type runtimeFunc struct {
	id   guid.GUID
	once sync.Once
	addr uintptr
	err  error
}

func (f *runtimeFunc) Load() error {
	f.once.Do(func() {
		var s windows.Handle
		s, f.err = windows.Socket(windows.AF_INET, windows.SOCK_STREAM, windows.IPPROTO_TCP)
		if f.err != nil {
			return
		}
		defer windows.CloseHandle(s) //nolint:errcheck

		var n uint32
		f.err = windows.WSAIoctl(s,
			windows.SIO_GET_EXTENSION_FUNCTION_POINTER,
			(*byte)(unsafe.Pointer(&f.id)),
			uint32(unsafe.Sizeof(f.id)),
			(*byte)(unsafe.Pointer(&f.addr)),
			uint32(unsafe.Sizeof(f.addr)),
			&n,
			nil, // overlapped
			0,   // completionRoutine
		)
	})
	return f.err
}

var (
	// todo: add `AcceptEx` and `GetAcceptExSockaddrs`
	WSAID_CONNECTEX = guid.GUID{ //revive:disable-line:var-naming ALL_CAPS
		Data1: 0x25a207b9,
		Data2: 0xddf3,
		Data3: 0x4660,
		Data4: [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e},
	}

	connectExFunc = runtimeFunc{id: WSAID_CONNECTEX}
)

func ConnectEx(
	fd windows.Handle,
	rsa RawSockaddr,
	sendBuf *byte,
	sendDataLen uint32,
	bytesSent *uint32,
	overlapped *windows.Overlapped,
) error {
	if err := connectExFunc.Load(); err != nil {
		return fmt.Errorf("failed to load ConnectEx function pointer: %w", err)
	}
	ptr, n, err := rsa.Sockaddr()
	if err != nil {
		return err
	}
	return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped)
}

// BOOL LpfnConnectex(
//	[in]           SOCKET s,
//	[in]           const sockaddr *name,
//	[in]           int namelen,
//	[in, optional] PVOID lpSendBuffer,
//	[in]           DWORD dwSendDataLength,
//	[out]          LPDWORD lpdwBytesSent,
//	[in]           LPOVERLAPPED lpOverlapped
// )

func connectEx(
	s windows.Handle,
	name unsafe.Pointer,
	namelen int32,
	sendBuf *byte,
	sendDataLen uint32,
	bytesSent *uint32,
	overlapped *windows.Overlapped,
) (err error) {
	r1, _, e1 := syscall.SyscallN(connectExFunc.addr,
		uintptr(s),
		uintptr(name),
		uintptr(namelen),
		uintptr(unsafe.Pointer(sendBuf)),
		uintptr(sendDataLen),
		uintptr(unsafe.Pointer(bytesSent)),
		uintptr(unsafe.Pointer(overlapped)),
	)

	if r1 == 0 {
		if e1 != 0 {
			err = error(e1)
		} else {
			err = syscall.EINVAL
		}
	}
	return err
}
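The `runtimeFunc` type above implements once-only resolution of an extension-function pointer: `sync.Once` guarantees the WSAIoctl lookup runs exactly once, and both the address and any error are cached for every later caller. A portable sketch of that caching pattern (the `lazyProc` name and `resolve` callback are illustrative, not part of the package):

```go
package main

import (
	"fmt"
	"sync"
)

// lazyProc caches the result of a one-time address resolution,
// including the error, mirroring runtimeFunc.Load above.
type lazyProc struct {
	once sync.Once
	addr uintptr
	err  error
}

func (p *lazyProc) Load(resolve func() (uintptr, error)) error {
	p.once.Do(func() { p.addr, p.err = resolve() })
	return p.err
}

func main() {
	calls := 0
	p := &lazyProc{}
	resolve := func() (uintptr, error) { calls++; return 0xBEEF, nil }
	_ = p.Load(resolve)
	_ = p.Load(resolve) // second call is a no-op; cached result is returned
	fmt.Println(calls, p.addr == 0xBEEF) // 1 true
}
```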

69 vendor/github.com/Microsoft/go-winio/internal/socket/zsyscall_windows.go generated vendored Normal file
@@ -0,0 +1,69 @@
//go:build windows

// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT.

package socket

import (
	"syscall"
	"unsafe"

	"golang.org/x/sys/windows"
)

var _ unsafe.Pointer

// Do the interface allocations only once for common
// Errno values.
const (
	errnoERROR_IO_PENDING = 997
)

var (
	errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
	errERROR_EINVAL     error = syscall.EINVAL
)

// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
	switch e {
	case 0:
		return errERROR_EINVAL
	case errnoERROR_IO_PENDING:
		return errERROR_IO_PENDING
	}
	return e
}

var (
	modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")

	procbind        = modws2_32.NewProc("bind")
	procgetpeername = modws2_32.NewProc("getpeername")
	procgetsockname = modws2_32.NewProc("getsockname")
)

func bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) {
	r1, _, e1 := syscall.SyscallN(procbind.Addr(), uintptr(s), uintptr(name), uintptr(namelen))
	if r1 == socketError {
		err = errnoErr(e1)
	}
	return
}

func getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) {
	r1, _, e1 := syscall.SyscallN(procgetpeername.Addr(), uintptr(s), uintptr(name), uintptr(unsafe.Pointer(namelen)))
	if r1 == socketError {
		err = errnoErr(e1)
	}
	return
}

func getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) {
	r1, _, e1 := syscall.SyscallN(procgetsockname.Addr(), uintptr(s), uintptr(name), uintptr(unsafe.Pointer(namelen)))
	if r1 == socketError {
		err = errnoErr(e1)
	}
	return
}

132 vendor/github.com/Microsoft/go-winio/internal/stringbuffer/wstring.go generated vendored Normal file
@@ -0,0 +1,132 @@
package stringbuffer

import (
	"sync"
	"unicode/utf16"
)

// TODO: worth exporting and using in mkwinsyscall?

// MinWStringCap is the buffer size in the pool, chosen somewhat arbitrarily to accommodate
// large path strings:
// MAX_PATH (260) + size of volume GUID prefix (49) + null terminator = 310.
const MinWStringCap = 310

// use *[]uint16 since []uint16 creates an extra allocation where the slice header
// is copied to heap and then referenced via pointer in the interface header that sync.Pool
// stores.
var pathPool = sync.Pool{ // if go1.18+ adds Pool[T], use that to store []uint16 directly
	New: func() interface{} {
		b := make([]uint16, MinWStringCap)
		return &b
	},
}

func newBuffer() []uint16 { return *(pathPool.Get().(*[]uint16)) }

// freeBuffer copies the slice header data, and puts a pointer to that in the pool.
// This avoids taking a pointer to the slice header in WString, which can be set to nil.
func freeBuffer(b []uint16) { pathPool.Put(&b) }

// WString is a wide string buffer ([]uint16) meant for storing UTF-16 encoded strings
// for interacting with Win32 APIs.
// Sizes are specified as uint32 and not int.
//
// It is not thread safe.
type WString struct {
	// type-def allows casting to []uint16 directly, use struct to prevent that and allow adding fields in the future.

	// raw buffer
	b []uint16
}

// NewWString returns a [WString] allocated from a shared pool with an
// initial capacity of at least [MinWStringCap].
// Since the buffer may have been previously used, its contents are not guaranteed to be empty.
//
// The buffer should be freed via [WString.Free].
func NewWString() *WString {
	return &WString{
		b: newBuffer(),
	}
}

func (b *WString) Free() {
	if b.empty() {
		return
	}
	freeBuffer(b.b)
	b.b = nil
}

// ResizeTo grows the buffer to at least c and returns the new capacity, freeing the
// previous buffer back into the pool.
func (b *WString) ResizeTo(c uint32) uint32 {
	// already sufficient (or n is 0)
	if c <= b.Cap() {
		return b.Cap()
	}

	if c <= MinWStringCap {
		c = MinWStringCap
	}
	// allocate at-least double buffer size, as is done in [bytes.Buffer] and other places
	if c <= 2*b.Cap() {
		c = 2 * b.Cap()
	}

	b2 := make([]uint16, c)
	if !b.empty() {
		copy(b2, b.b)
		freeBuffer(b.b)
	}
	b.b = b2
	return c
}

// Buffer returns the underlying []uint16 buffer.
func (b *WString) Buffer() []uint16 {
	if b.empty() {
		return nil
	}
	return b.b
}

// Pointer returns a pointer to the first uint16 in the buffer.
// If [WString.Free] has already been called, the pointer will be nil.
func (b *WString) Pointer() *uint16 {
	if b.empty() {
		return nil
	}
	return &b.b[0]
}

// String returns the UTF-8 encoding of the UTF-16 string in the buffer.
//
// It assumes that the data is null-terminated.
func (b *WString) String() string {
	// Using [windows.UTF16ToString] would require importing "golang.org/x/sys/windows"
	// and would make this code Windows-only, which makes no sense.
	// So copy UTF16ToString code into here.
	// If other windows-specific code is added, switch to [windows.UTF16ToString]

	s := b.b
	for i, v := range s {
		if v == 0 {
			s = s[:i]
			break
		}
	}
	return string(utf16.Decode(s))
}

// Cap returns the underlying buffer capacity.
func (b *WString) Cap() uint32 {
	if b.empty() {
		return 0
	}
	return b.cap()
}

func (b *WString) cap() uint32 { return uint32(cap(b.b)) }
func (b *WString) empty() bool { return b == nil || b.cap() == 0 }
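`WString.String` truncates at the first NUL before decoding, because a pooled buffer can carry stale data past the terminator. The decode logic itself is portable Go; a minimal standalone sketch of the same truncate-then-decode step (the `utf16ToString` helper name is illustrative):

```go
package main

import (
	"fmt"
	"unicode/utf16"
)

// utf16ToString mirrors the truncate-at-NUL logic in WString.String:
// cut the slice at the first zero code unit, then decode UTF-16 to a Go string.
func utf16ToString(s []uint16) string {
	for i, v := range s {
		if v == 0 {
			s = s[:i]
			break
		}
	}
	return string(utf16.Decode(s))
}

func main() {
	// A reused pool buffer may hold stale units after the terminator.
	buf := []uint16{'C', ':', '\\', 0, 'x', 'x'}
	fmt.Println(utf16ToString(buf)) // C:\
}
```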

586 vendor/github.com/Microsoft/go-winio/pipe.go generated vendored Normal file
@@ -0,0 +1,586 @@
//go:build windows
// +build windows

package winio

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"runtime"
	"time"
	"unsafe"

	"golang.org/x/sys/windows"

	"github.com/Microsoft/go-winio/internal/fs"
)

//sys connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) = ConnectNamedPipe
//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateNamedPipeW
//sys disconnectNamedPipe(pipe windows.Handle) (err error) = DisconnectNamedPipe
//sys getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo
//sys getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
//sys ntCreateNamedPipeFile(pipe *windows.Handle, access ntAccessMask, oa *objectAttributes, iosb *ioStatusBlock, share ntFileShareMode, disposition ntFileCreationDisposition, options ntFileOptions, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntStatus) = ntdll.NtCreateNamedPipeFile
//sys rtlNtStatusToDosError(status ntStatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb
//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntStatus) = ntdll.RtlDosPathNameToNtPathName_U
//sys rtlDefaultNpAcl(dacl *uintptr) (status ntStatus) = ntdll.RtlDefaultNpAcl

type PipeConn interface {
	net.Conn
	Disconnect() error
	Flush() error
}

// type aliases for mkwinsyscall code
type (
	ntAccessMask              = fs.AccessMask
	ntFileShareMode           = fs.FileShareMode
	ntFileCreationDisposition = fs.NTFileCreationDisposition
	ntFileOptions             = fs.NTCreateOptions
)

type ioStatusBlock struct {
	Status, Information uintptr
}

// typedef struct _OBJECT_ATTRIBUTES {
//	ULONG           Length;
//	HANDLE          RootDirectory;
//	PUNICODE_STRING ObjectName;
//	ULONG           Attributes;
//	PVOID           SecurityDescriptor;
//	PVOID           SecurityQualityOfService;
// } OBJECT_ATTRIBUTES;
//
// https://learn.microsoft.com/en-us/windows/win32/api/ntdef/ns-ntdef-_object_attributes
type objectAttributes struct {
	Length             uintptr
	RootDirectory      uintptr
	ObjectName         *unicodeString
	Attributes         uintptr
	SecurityDescriptor *securityDescriptor
	SecurityQoS        uintptr
}

type unicodeString struct {
	Length        uint16
	MaximumLength uint16
	Buffer        uintptr
}

// typedef struct _SECURITY_DESCRIPTOR {
//	BYTE                        Revision;
//	BYTE                        Sbz1;
//	SECURITY_DESCRIPTOR_CONTROL Control;
//	PSID                        Owner;
//	PSID                        Group;
//	PACL                        Sacl;
//	PACL                        Dacl;
// } SECURITY_DESCRIPTOR, *PISECURITY_DESCRIPTOR;
//
// https://learn.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-security_descriptor
type securityDescriptor struct {
	Revision byte
	Sbz1     byte
	Control  uint16
	Owner    uintptr
	Group    uintptr
	Sacl     uintptr //revive:disable-line:var-naming SACL, not Sacl
	Dacl     uintptr //revive:disable-line:var-naming DACL, not Dacl
}

type ntStatus int32

func (status ntStatus) Err() error {
	if status >= 0 {
		return nil
	}
	return rtlNtStatusToDosError(status)
}

var (
	// ErrPipeListenerClosed is returned for pipe operations on listeners that have been closed.
	ErrPipeListenerClosed = net.ErrClosed

	errPipeWriteClosed = errors.New("pipe has been closed for write")
)

type win32Pipe struct {
	*win32File
	path string
}

var _ PipeConn = (*win32Pipe)(nil)

type win32MessageBytePipe struct {
	win32Pipe
	writeClosed bool
	readEOF     bool
}

type pipeAddress string

func (f *win32Pipe) LocalAddr() net.Addr {
	return pipeAddress(f.path)
}

func (f *win32Pipe) RemoteAddr() net.Addr {
	return pipeAddress(f.path)
}

func (f *win32Pipe) SetDeadline(t time.Time) error {
	if err := f.SetReadDeadline(t); err != nil {
		return err
	}
	return f.SetWriteDeadline(t)
}

func (f *win32Pipe) Disconnect() error {
	return disconnectNamedPipe(f.win32File.handle)
}

// CloseWrite closes the write side of a message pipe in byte mode.
func (f *win32MessageBytePipe) CloseWrite() error {
	if f.writeClosed {
		return errPipeWriteClosed
	}
	err := f.win32File.Flush()
	if err != nil {
		return err
	}
	_, err = f.win32File.Write(nil)
	if err != nil {
		return err
	}
	f.writeClosed = true
	return nil
}

// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
// they are used to implement CloseWrite().
func (f *win32MessageBytePipe) Write(b []byte) (int, error) {
	if f.writeClosed {
		return 0, errPipeWriteClosed
	}
	if len(b) == 0 {
		return 0, nil
	}
	return f.win32File.Write(b)
}

// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message
// mode pipe will return io.EOF, as will all subsequent reads.
func (f *win32MessageBytePipe) Read(b []byte) (int, error) {
	if f.readEOF {
		return 0, io.EOF
	}
	n, err := f.win32File.Read(b)
	if err == io.EOF { //nolint:errorlint
		// If this was the result of a zero-byte read, then
		// it is possible that the read was due to a zero-size
		// message. Since we are simulating CloseWrite with a
		// zero-byte message, ensure that all future Read() calls
		// also return EOF.
		f.readEOF = true
	} else if err == windows.ERROR_MORE_DATA { //nolint:errorlint // err is Errno
		// ERROR_MORE_DATA indicates that the pipe's read mode is message mode
		// and the message still has more bytes. Treat this as a success, since
		// this package presents all named pipes as byte streams.
		err = nil
	}
	return n, err
}

func (pipeAddress) Network() string {
	return "pipe"
}

func (s pipeAddress) String() string {
	return string(s)
}

// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout.
func tryDialPipe(ctx context.Context, path *string, access fs.AccessMask, impLevel PipeImpLevel) (windows.Handle, error) {
	for {
		select {
		case <-ctx.Done():
			return windows.Handle(0), ctx.Err()
		default:
			h, err := fs.CreateFile(*path,
				access,
				0,   // mode
				nil, // security attributes
				fs.OPEN_EXISTING,
				fs.FILE_FLAG_OVERLAPPED|fs.SECURITY_SQOS_PRESENT|fs.FileSQSFlag(impLevel),
				0, // template file handle
			)
			if err == nil {
				return h, nil
			}
			if err != windows.ERROR_PIPE_BUSY { //nolint:errorlint // err is Errno
				return h, &os.PathError{Err: err, Op: "open", Path: *path}
			}
			// Wait 10 msec and try again. This is a rather simplistic
			// view, as we always try each 10 milliseconds.
			time.Sleep(10 * time.Millisecond)
		}
	}
}
|
||||
|
||||
// DialPipe connects to a named pipe by path, timing out if the connection
// takes longer than the specified duration. If timeout is nil, then we use
// a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
	var absTimeout time.Time
	if timeout != nil {
		absTimeout = time.Now().Add(*timeout)
	} else {
		absTimeout = time.Now().Add(2 * time.Second)
	}
	ctx, cancel := context.WithDeadline(context.Background(), absTimeout)
	defer cancel()
	conn, err := DialPipeContext(ctx, path)
	if errors.Is(err, context.DeadlineExceeded) {
		return nil, ErrTimeout
	}
	return conn, err
}

// DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
// cancellation or timeout.
func DialPipeContext(ctx context.Context, path string) (net.Conn, error) {
	return DialPipeAccess(ctx, path, uint32(fs.GENERIC_READ|fs.GENERIC_WRITE))
}

// PipeImpLevel is an enumeration of impersonation levels that may be set
// when calling DialPipeAccessImpLevel.
type PipeImpLevel uint32

const (
	PipeImpLevelAnonymous      = PipeImpLevel(fs.SECURITY_ANONYMOUS)
	PipeImpLevelIdentification = PipeImpLevel(fs.SECURITY_IDENTIFICATION)
	PipeImpLevelImpersonation  = PipeImpLevel(fs.SECURITY_IMPERSONATION)
	PipeImpLevelDelegation     = PipeImpLevel(fs.SECURITY_DELEGATION)
)

// DialPipeAccess attempts to connect to a named pipe by `path` with `access` until `ctx`
// cancellation or timeout.
func DialPipeAccess(ctx context.Context, path string, access uint32) (net.Conn, error) {
	return DialPipeAccessImpLevel(ctx, path, access, PipeImpLevelAnonymous)
}

// DialPipeAccessImpLevel attempts to connect to a named pipe by `path` with
// `access` at `impLevel` until `ctx` cancellation or timeout. The other
// DialPipe* implementations use PipeImpLevelAnonymous.
func DialPipeAccessImpLevel(ctx context.Context, path string, access uint32, impLevel PipeImpLevel) (net.Conn, error) {
	var err error
	var h windows.Handle
	h, err = tryDialPipe(ctx, &path, fs.AccessMask(access), impLevel)
	if err != nil {
		return nil, err
	}

	var flags uint32
	err = getNamedPipeInfo(h, &flags, nil, nil, nil)
	if err != nil {
		return nil, err
	}

	f, err := makeWin32File(h)
	if err != nil {
		windows.Close(h)
		return nil, err
	}

	// If the pipe is in message mode, return a message byte pipe, which
	// supports CloseWrite().
	if flags&windows.PIPE_TYPE_MESSAGE != 0 {
		return &win32MessageBytePipe{
			win32Pipe: win32Pipe{win32File: f, path: path},
		}, nil
	}
	return &win32Pipe{win32File: f, path: path}, nil
}

type acceptResponse struct {
	f   *win32File
	err error
}

type win32PipeListener struct {
	firstHandle windows.Handle
	path        string
	config      PipeConfig
	acceptCh    chan (chan acceptResponse)
	closeCh     chan int
	doneCh      chan int
}

func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (windows.Handle, error) {
	path16, err := windows.UTF16FromString(path)
	if err != nil {
		return 0, &os.PathError{Op: "open", Path: path, Err: err}
	}

	var oa objectAttributes
	oa.Length = unsafe.Sizeof(oa)

	var ntPath unicodeString
	if err := rtlDosPathNameToNtPathName(&path16[0],
		&ntPath,
		0,
		0,
	).Err(); err != nil {
		return 0, &os.PathError{Op: "open", Path: path, Err: err}
	}
	defer windows.LocalFree(windows.Handle(ntPath.Buffer)) //nolint:errcheck
	oa.ObjectName = &ntPath
	oa.Attributes = windows.OBJ_CASE_INSENSITIVE

	// The security descriptor is only needed for the first pipe.
	if first {
		if sd != nil {
			//todo: does `sdb` need to be allocated on the heap, or can go allocate it?
			l := uint32(len(sd))
			sdb, err := windows.LocalAlloc(0, l)
			if err != nil {
				return 0, fmt.Errorf("LocalAlloc for security descriptor of length %d: %w", l, err)
			}
			defer windows.LocalFree(windows.Handle(sdb)) //nolint:errcheck
			copy((*[0xffff]byte)(unsafe.Pointer(sdb))[:], sd)
			oa.SecurityDescriptor = (*securityDescriptor)(unsafe.Pointer(sdb))
		} else {
			// Construct the default named pipe security descriptor.
			var dacl uintptr
			if err := rtlDefaultNpAcl(&dacl).Err(); err != nil {
				return 0, fmt.Errorf("getting default named pipe ACL: %w", err)
			}
			defer windows.LocalFree(windows.Handle(dacl)) //nolint:errcheck

			sdb := &securityDescriptor{
				Revision: 1,
				Control:  windows.SE_DACL_PRESENT,
				Dacl:     dacl,
			}
			oa.SecurityDescriptor = sdb
		}
	}

	typ := uint32(windows.FILE_PIPE_REJECT_REMOTE_CLIENTS)
	if c.MessageMode {
		typ |= windows.FILE_PIPE_MESSAGE_TYPE
	}

	disposition := fs.FILE_OPEN
	access := fs.GENERIC_READ | fs.GENERIC_WRITE | fs.SYNCHRONIZE
	if first {
		disposition = fs.FILE_CREATE
		// By not asking for read or write access, the named pipe file system
		// will put this pipe into an initially disconnected state, blocking
		// client connections until the next call with first == false.
		access = fs.SYNCHRONIZE
	}

	timeout := int64(-50 * 10000) // 50ms

	var (
		h    windows.Handle
		iosb ioStatusBlock
	)
	err = ntCreateNamedPipeFile(&h,
		access,
		&oa,
		&iosb,
		fs.FILE_SHARE_READ|fs.FILE_SHARE_WRITE,
		disposition,
		0,
		typ,
		0,
		0,
		0xffffffff,
		uint32(c.InputBufferSize),
		uint32(c.OutputBufferSize),
		&timeout).Err()
	if err != nil {
		return 0, &os.PathError{Op: "open", Path: path, Err: err}
	}

	runtime.KeepAlive(ntPath)
	return h, nil
}

func (l *win32PipeListener) makeServerPipe() (*win32File, error) {
	h, err := makeServerPipeHandle(l.path, nil, &l.config, false)
	if err != nil {
		return nil, err
	}
	f, err := makeWin32File(h)
	if err != nil {
		windows.Close(h)
		return nil, err
	}
	return f, nil
}

func (l *win32PipeListener) makeConnectedServerPipe() (*win32File, error) {
	p, err := l.makeServerPipe()
	if err != nil {
		return nil, err
	}

	// Wait for the client to connect.
	ch := make(chan error)
	go func(p *win32File) {
		ch <- connectPipe(p)
	}(p)

	select {
	case err = <-ch:
		if err != nil {
			p.Close()
			p = nil
		}
	case <-l.closeCh:
		// Abort the connect request by closing the handle.
		p.Close()
		p = nil
		err = <-ch
		if err == nil || err == ErrFileClosed { //nolint:errorlint // err is Errno
			err = ErrPipeListenerClosed
		}
	}
	return p, err
}

func (l *win32PipeListener) listenerRoutine() {
	closed := false
	for !closed {
		select {
		case <-l.closeCh:
			closed = true
		case responseCh := <-l.acceptCh:
			var (
				p   *win32File
				err error
			)
			for {
				p, err = l.makeConnectedServerPipe()
				// If the connection was immediately closed by the client, try
				// again.
				if err != windows.ERROR_NO_DATA { //nolint:errorlint // err is Errno
					break
				}
			}
			responseCh <- acceptResponse{p, err}
			closed = err == ErrPipeListenerClosed //nolint:errorlint // err is Errno
		}
	}
	windows.Close(l.firstHandle)
	l.firstHandle = 0
	// Notify Close() and Accept() callers that the handle has been closed.
	close(l.doneCh)
}

// PipeConfig contains configuration for the pipe listener.
type PipeConfig struct {
	// SecurityDescriptor contains a Windows security descriptor in SDDL format.
	SecurityDescriptor string

	// MessageMode determines whether the pipe is in byte or message mode. In either
	// case the pipe is read in byte mode by default. The only practical difference in
	// this implementation is that CloseWrite() is only supported for message mode pipes;
	// CloseWrite() is implemented as a zero-byte write, but zero-byte writes are only
	// transferred to the reader (and returned as io.EOF in this implementation)
	// when the pipe is in message mode.
	MessageMode bool

	// InputBufferSize specifies the size of the input buffer, in bytes.
	InputBufferSize int32

	// OutputBufferSize specifies the size of the output buffer, in bytes.
	OutputBufferSize int32
}

// ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe.
// The pipe must not already exist.
func ListenPipe(path string, c *PipeConfig) (net.Listener, error) {
	var (
		sd  []byte
		err error
	)
	if c == nil {
		c = &PipeConfig{}
	}
	if c.SecurityDescriptor != "" {
		sd, err = SddlToSecurityDescriptor(c.SecurityDescriptor)
		if err != nil {
			return nil, err
		}
	}
	h, err := makeServerPipeHandle(path, sd, c, true)
	if err != nil {
		return nil, err
	}
	l := &win32PipeListener{
		firstHandle: h,
		path:        path,
		config:      *c,
		acceptCh:    make(chan (chan acceptResponse)),
		closeCh:     make(chan int),
		doneCh:      make(chan int),
	}
	go l.listenerRoutine()
	return l, nil
}

func connectPipe(p *win32File) error {
	c, err := p.prepareIO()
	if err != nil {
		return err
	}
	defer p.wg.Done()

	err = connectNamedPipe(p.handle, &c.o)
	_, err = p.asyncIO(c, nil, 0, err)
	if err != nil && err != windows.ERROR_PIPE_CONNECTED { //nolint:errorlint // err is Errno
		return err
	}
	return nil
}

func (l *win32PipeListener) Accept() (net.Conn, error) {
	ch := make(chan acceptResponse)
	select {
	case l.acceptCh <- ch:
		response := <-ch
		err := response.err
		if err != nil {
			return nil, err
		}
		if l.config.MessageMode {
			return &win32MessageBytePipe{
				win32Pipe: win32Pipe{win32File: response.f, path: l.path},
			}, nil
		}
		return &win32Pipe{win32File: response.f, path: l.path}, nil
	case <-l.doneCh:
		return nil, ErrPipeListenerClosed
	}
}

func (l *win32PipeListener) Close() error {
	select {
	case l.closeCh <- 1:
		<-l.doneCh
	case <-l.doneCh:
	}
	return nil
}

func (l *win32PipeListener) Addr() net.Addr {
	return pipeAddress(l.path)
}
232
vendor/github.com/Microsoft/go-winio/pkg/guid/guid.go
generated
vendored
Normal file
@@ -0,0 +1,232 @@
// Package guid provides a GUID type. The backing structure for a GUID is
// identical to that used by the golang.org/x/sys/windows GUID type.
// There are two main binary encodings used for a GUID, the big-endian encoding,
// and the Windows (mixed-endian) encoding. See here for details:
// https://en.wikipedia.org/wiki/Universally_unique_identifier#Encoding
package guid

import (
	"crypto/rand"
	"crypto/sha1" //nolint:gosec // not used for secure application
	"encoding"
	"encoding/binary"
	"fmt"
	"strconv"
)

//go:generate go run golang.org/x/tools/cmd/stringer -type=Variant -trimprefix=Variant -linecomment

// Variant specifies which GUID variant (or "type") of the GUID. It determines
// how the entirety of the rest of the GUID is interpreted.
type Variant uint8

// The variants specified by RFC 4122 section 4.1.1.
const (
	// VariantUnknown specifies a GUID variant which does not conform to one of
	// the variant encodings specified in RFC 4122.
	VariantUnknown Variant = iota
	VariantNCS
	VariantRFC4122 // RFC 4122
	VariantMicrosoft
	VariantFuture
)

// Version specifies how the bits in the GUID were generated. For instance, a
// version 4 GUID is randomly generated, and a version 5 is generated from the
// hash of an input string.
type Version uint8

func (v Version) String() string {
	return strconv.FormatUint(uint64(v), 10)
}

var _ = (encoding.TextMarshaler)(GUID{})
var _ = (encoding.TextUnmarshaler)(&GUID{})

// NewV4 returns a new version 4 (pseudorandom) GUID, as defined by RFC 4122.
func NewV4() (GUID, error) {
	var b [16]byte
	if _, err := rand.Read(b[:]); err != nil {
		return GUID{}, err
	}

	g := FromArray(b)
	g.setVersion(4) // Version 4 means randomly generated.
	g.setVariant(VariantRFC4122)

	return g, nil
}

// NewV5 returns a new version 5 (generated from a string via SHA-1 hashing)
// GUID, as defined by RFC 4122. The RFC is unclear on the encoding of the name,
// and the sample code treats it as a series of bytes, so we do the same here.
//
// Some implementations, such as those found on Windows, treat the name as a
// big-endian UTF16 stream of bytes. If that is desired, the string can be
// encoded as such before being passed to this function.
func NewV5(namespace GUID, name []byte) (GUID, error) {
	b := sha1.New() //nolint:gosec // not used for secure application
	namespaceBytes := namespace.ToArray()
	b.Write(namespaceBytes[:])
	b.Write(name)

	a := [16]byte{}
	copy(a[:], b.Sum(nil))

	g := FromArray(a)
	g.setVersion(5) // Version 5 means generated from a string.
	g.setVariant(VariantRFC4122)

	return g, nil
}

func fromArray(b [16]byte, order binary.ByteOrder) GUID {
	var g GUID
	g.Data1 = order.Uint32(b[0:4])
	g.Data2 = order.Uint16(b[4:6])
	g.Data3 = order.Uint16(b[6:8])
	copy(g.Data4[:], b[8:16])
	return g
}

func (g GUID) toArray(order binary.ByteOrder) [16]byte {
	b := [16]byte{}
	order.PutUint32(b[0:4], g.Data1)
	order.PutUint16(b[4:6], g.Data2)
	order.PutUint16(b[6:8], g.Data3)
	copy(b[8:16], g.Data4[:])
	return b
}

// FromArray constructs a GUID from a big-endian encoding array of 16 bytes.
func FromArray(b [16]byte) GUID {
	return fromArray(b, binary.BigEndian)
}

// ToArray returns an array of 16 bytes representing the GUID in big-endian
// encoding.
func (g GUID) ToArray() [16]byte {
	return g.toArray(binary.BigEndian)
}

// FromWindowsArray constructs a GUID from a Windows encoding array of bytes.
func FromWindowsArray(b [16]byte) GUID {
	return fromArray(b, binary.LittleEndian)
}

// ToWindowsArray returns an array of 16 bytes representing the GUID in Windows
// encoding.
func (g GUID) ToWindowsArray() [16]byte {
	return g.toArray(binary.LittleEndian)
}

func (g GUID) String() string {
	return fmt.Sprintf(
		"%08x-%04x-%04x-%04x-%012x",
		g.Data1,
		g.Data2,
		g.Data3,
		g.Data4[:2],
		g.Data4[2:])
}

// FromString parses a string containing a GUID and returns the GUID. The only
// format currently supported is the `xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx`
// format.
func FromString(s string) (GUID, error) {
	if len(s) != 36 {
		return GUID{}, fmt.Errorf("invalid GUID %q", s)
	}
	if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' {
		return GUID{}, fmt.Errorf("invalid GUID %q", s)
	}

	var g GUID

	data1, err := strconv.ParseUint(s[0:8], 16, 32)
	if err != nil {
		return GUID{}, fmt.Errorf("invalid GUID %q", s)
	}
	g.Data1 = uint32(data1)

	data2, err := strconv.ParseUint(s[9:13], 16, 16)
	if err != nil {
		return GUID{}, fmt.Errorf("invalid GUID %q", s)
	}
	g.Data2 = uint16(data2)

	data3, err := strconv.ParseUint(s[14:18], 16, 16)
	if err != nil {
		return GUID{}, fmt.Errorf("invalid GUID %q", s)
	}
	g.Data3 = uint16(data3)

	for i, x := range []int{19, 21, 24, 26, 28, 30, 32, 34} {
		v, err := strconv.ParseUint(s[x:x+2], 16, 8)
		if err != nil {
			return GUID{}, fmt.Errorf("invalid GUID %q", s)
		}
		g.Data4[i] = uint8(v)
	}

	return g, nil
}

func (g *GUID) setVariant(v Variant) {
	d := g.Data4[0]
	switch v {
	case VariantNCS:
		d = (d & 0x7f)
	case VariantRFC4122:
		d = (d & 0x3f) | 0x80
	case VariantMicrosoft:
		d = (d & 0x1f) | 0xc0
	case VariantFuture:
		d = (d & 0x0f) | 0xe0
	case VariantUnknown:
		fallthrough
	default:
		panic(fmt.Sprintf("invalid variant: %d", v))
	}
	g.Data4[0] = d
}

// Variant returns the GUID variant, as defined in RFC 4122.
func (g GUID) Variant() Variant {
	b := g.Data4[0]
	if b&0x80 == 0 {
		return VariantNCS
	} else if b&0xc0 == 0x80 {
		return VariantRFC4122
	} else if b&0xe0 == 0xc0 {
		return VariantMicrosoft
	} else if b&0xe0 == 0xe0 {
		return VariantFuture
	}
	return VariantUnknown
}

func (g *GUID) setVersion(v Version) {
	g.Data3 = (g.Data3 & 0x0fff) | (uint16(v) << 12)
}

// Version returns the GUID version, as defined in RFC 4122.
func (g GUID) Version() Version {
	return Version((g.Data3 & 0xF000) >> 12)
}

// MarshalText returns the textual representation of the GUID.
func (g GUID) MarshalText() ([]byte, error) {
	return []byte(g.String()), nil
}

// UnmarshalText takes the textual representation of a GUID, and unmarshals it
// into this GUID.
func (g *GUID) UnmarshalText(text []byte) error {
	g2, err := FromString(string(text))
	if err != nil {
		return err
	}
	*g = g2
	return nil
}
16
vendor/github.com/Microsoft/go-winio/pkg/guid/guid_nonwindows.go
generated
vendored
Normal file
@@ -0,0 +1,16 @@
//go:build !windows
// +build !windows

package guid

// GUID represents a GUID/UUID. It has the same structure as
// golang.org/x/sys/windows.GUID so that it can be used with functions expecting
// that type. It is defined as its own type as that is only available to builds
// targeted at `windows`. The representation matches that used by native Windows
// code.
type GUID struct {
	Data1 uint32
	Data2 uint16
	Data3 uint16
	Data4 [8]byte
}
13
vendor/github.com/Microsoft/go-winio/pkg/guid/guid_windows.go
generated
vendored
Normal file
@@ -0,0 +1,13 @@
//go:build windows
// +build windows

package guid

import "golang.org/x/sys/windows"

// GUID represents a GUID/UUID. It has the same structure as
// golang.org/x/sys/windows.GUID so that it can be used with functions expecting
// that type. It is defined as its own type so that stringification and
// marshaling can be supported. The representation matches that used by native
// Windows code.
type GUID windows.GUID
27
vendor/github.com/Microsoft/go-winio/pkg/guid/variant_string.go
generated
vendored
Normal file
@@ -0,0 +1,27 @@
// Code generated by "stringer -type=Variant -trimprefix=Variant -linecomment"; DO NOT EDIT.

package guid

import "strconv"

func _() {
	// An "invalid array index" compiler error signifies that the constant values have changed.
	// Re-run the stringer command to generate them again.
	var x [1]struct{}
	_ = x[VariantUnknown-0]
	_ = x[VariantNCS-1]
	_ = x[VariantRFC4122-2]
	_ = x[VariantMicrosoft-3]
	_ = x[VariantFuture-4]
}

const _Variant_name = "UnknownNCSRFC 4122MicrosoftFuture"

var _Variant_index = [...]uint8{0, 7, 10, 18, 27, 33}

func (i Variant) String() string {
	if i >= Variant(len(_Variant_index)-1) {
		return "Variant(" + strconv.FormatInt(int64(i), 10) + ")"
	}
	return _Variant_name[_Variant_index[i]:_Variant_index[i+1]]
}
196
vendor/github.com/Microsoft/go-winio/privilege.go
generated
vendored
Normal file
@@ -0,0 +1,196 @@
//go:build windows
// +build windows

package winio

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"runtime"
	"sync"
	"unicode/utf16"

	"golang.org/x/sys/windows"
)

//sys adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) [true] = advapi32.AdjustTokenPrivileges
//sys impersonateSelf(level uint32) (err error) = advapi32.ImpersonateSelf
//sys revertToSelf() (err error) = advapi32.RevertToSelf
//sys openThreadToken(thread windows.Handle, accessMask uint32, openAsSelf bool, token *windows.Token) (err error) = advapi32.OpenThreadToken
//sys getCurrentThread() (h windows.Handle) = GetCurrentThread
//sys lookupPrivilegeValue(systemName string, name string, luid *uint64) (err error) = advapi32.LookupPrivilegeValueW
//sys lookupPrivilegeName(systemName string, luid *uint64, buffer *uint16, size *uint32) (err error) = advapi32.LookupPrivilegeNameW
//sys lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) = advapi32.LookupPrivilegeDisplayNameW

const (
	//revive:disable-next-line:var-naming ALL_CAPS
	SE_PRIVILEGE_ENABLED = windows.SE_PRIVILEGE_ENABLED

	//revive:disable-next-line:var-naming ALL_CAPS
	ERROR_NOT_ALL_ASSIGNED windows.Errno = windows.ERROR_NOT_ALL_ASSIGNED

	SeBackupPrivilege   = "SeBackupPrivilege"
	SeRestorePrivilege  = "SeRestorePrivilege"
	SeSecurityPrivilege = "SeSecurityPrivilege"
)

var (
	privNames     = make(map[string]uint64)
	privNameMutex sync.Mutex
)

// PrivilegeError represents an error enabling privileges.
type PrivilegeError struct {
	privileges []uint64
}

func (e *PrivilegeError) Error() string {
	s := "Could not enable privilege "
	if len(e.privileges) > 1 {
		s = "Could not enable privileges "
	}
	for i, p := range e.privileges {
		if i != 0 {
			s += ", "
		}
		s += `"`
		s += getPrivilegeName(p)
		s += `"`
	}
	return s
}

// RunWithPrivilege enables a single privilege for a function call.
func RunWithPrivilege(name string, fn func() error) error {
	return RunWithPrivileges([]string{name}, fn)
}

// RunWithPrivileges enables privileges for a function call.
func RunWithPrivileges(names []string, fn func() error) error {
	privileges, err := mapPrivileges(names)
	if err != nil {
		return err
	}
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()
	token, err := newThreadToken()
	if err != nil {
		return err
	}
	defer releaseThreadToken(token)
	err = adjustPrivileges(token, privileges, SE_PRIVILEGE_ENABLED)
	if err != nil {
		return err
	}
	return fn()
}

func mapPrivileges(names []string) ([]uint64, error) {
	privileges := make([]uint64, 0, len(names))
	privNameMutex.Lock()
	defer privNameMutex.Unlock()
	for _, name := range names {
		p, ok := privNames[name]
		if !ok {
			err := lookupPrivilegeValue("", name, &p)
			if err != nil {
				return nil, err
			}
			privNames[name] = p
		}
		privileges = append(privileges, p)
	}
	return privileges, nil
}

// EnableProcessPrivileges enables privileges globally for the process.
func EnableProcessPrivileges(names []string) error {
	return enableDisableProcessPrivilege(names, SE_PRIVILEGE_ENABLED)
}

// DisableProcessPrivileges disables privileges globally for the process.
func DisableProcessPrivileges(names []string) error {
	return enableDisableProcessPrivilege(names, 0)
}

func enableDisableProcessPrivilege(names []string, action uint32) error {
	privileges, err := mapPrivileges(names)
	if err != nil {
		return err
	}

	p := windows.CurrentProcess()
	var token windows.Token
	err = windows.OpenProcessToken(p, windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, &token)
	if err != nil {
		return err
	}

	defer token.Close()
	return adjustPrivileges(token, privileges, action)
}

func adjustPrivileges(token windows.Token, privileges []uint64, action uint32) error {
	var b bytes.Buffer
	_ = binary.Write(&b, binary.LittleEndian, uint32(len(privileges)))
	for _, p := range privileges {
		_ = binary.Write(&b, binary.LittleEndian, p)
		_ = binary.Write(&b, binary.LittleEndian, action)
	}
	prevState := make([]byte, b.Len())
	reqSize := uint32(0)
	success, err := adjustTokenPrivileges(token, false, &b.Bytes()[0], uint32(len(prevState)), &prevState[0], &reqSize)
	if !success {
		return err
	}
	if err == ERROR_NOT_ALL_ASSIGNED { //nolint:errorlint // err is Errno
		return &PrivilegeError{privileges}
	}
	return nil
}

func getPrivilegeName(luid uint64) string {
	var nameBuffer [256]uint16
	bufSize := uint32(len(nameBuffer))
	err := lookupPrivilegeName("", &luid, &nameBuffer[0], &bufSize)
	if err != nil {
		return fmt.Sprintf("<unknown privilege %d>", luid)
	}

	var displayNameBuffer [256]uint16
	displayBufSize := uint32(len(displayNameBuffer))
	var langID uint32
	err = lookupPrivilegeDisplayName("", &nameBuffer[0], &displayNameBuffer[0], &displayBufSize, &langID)
	if err != nil {
		return fmt.Sprintf("<unknown privilege %s>", string(utf16.Decode(nameBuffer[:bufSize])))
	}

	return string(utf16.Decode(displayNameBuffer[:displayBufSize]))
}

func newThreadToken() (windows.Token, error) {
	err := impersonateSelf(windows.SecurityImpersonation)
	if err != nil {
		return 0, err
	}

	var token windows.Token
	err = openThreadToken(getCurrentThread(), windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, false, &token)
	if err != nil {
		rerr := revertToSelf()
		if rerr != nil {
			panic(rerr)
		}
		return 0, err
	}
	return token, nil
}

func releaseThreadToken(h windows.Token) {
	err := revertToSelf()
	if err != nil {
		panic(err)
	}
	h.Close()
}
131
vendor/github.com/Microsoft/go-winio/reparse.go
generated
vendored
Normal file
@@ -0,0 +1,131 @@
//go:build windows
// +build windows

package winio

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"strings"
	"unicode/utf16"
	"unsafe"
)

const (
	reparseTagMountPoint = 0xA0000003
	reparseTagSymlink    = 0xA000000C
)

type reparseDataBuffer struct {
	ReparseTag           uint32
	ReparseDataLength    uint16
	Reserved             uint16
	SubstituteNameOffset uint16
	SubstituteNameLength uint16
	PrintNameOffset      uint16
	PrintNameLength      uint16
}

// ReparsePoint describes a Win32 symlink or mount point.
type ReparsePoint struct {
	Target       string
	IsMountPoint bool
}

// UnsupportedReparsePointError is returned when trying to decode a non-symlink or
// mount point reparse point.
type UnsupportedReparsePointError struct {
	Tag uint32
}

func (e *UnsupportedReparsePointError) Error() string {
	return fmt.Sprintf("unsupported reparse point %x", e.Tag)
}

// DecodeReparsePoint decodes a Win32 REPARSE_DATA_BUFFER structure containing either a symlink
// or a mount point.
func DecodeReparsePoint(b []byte) (*ReparsePoint, error) {
	tag := binary.LittleEndian.Uint32(b[0:4])
	return DecodeReparsePointData(tag, b[8:])
}

func DecodeReparsePointData(tag uint32, b []byte) (*ReparsePoint, error) {
	isMountPoint := false
	switch tag {
	case reparseTagMountPoint:
		isMountPoint = true
	case reparseTagSymlink:
	default:
		return nil, &UnsupportedReparsePointError{tag}
	}
	nameOffset := 8 + binary.LittleEndian.Uint16(b[4:6])
	if !isMountPoint {
		nameOffset += 4
	}
	nameLength := binary.LittleEndian.Uint16(b[6:8])
	name := make([]uint16, nameLength/2)
	err := binary.Read(bytes.NewReader(b[nameOffset:nameOffset+nameLength]), binary.LittleEndian, &name)
	if err != nil {
		return nil, err
	}
	return &ReparsePoint{string(utf16.Decode(name)), isMountPoint}, nil
}

func isDriveLetter(c byte) bool {
	return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
}

// EncodeReparsePoint encodes a Win32 REPARSE_DATA_BUFFER structure describing a symlink or
// mount point.
func EncodeReparsePoint(rp *ReparsePoint) []byte {
	// Generate an NT path and determine if this is a relative path.
	var ntTarget string
	relative := false
	if strings.HasPrefix(rp.Target, `\\?\`) {
		ntTarget = `\??\` + rp.Target[4:]
	} else if strings.HasPrefix(rp.Target, `\\`) {
		ntTarget = `\??\UNC\` + rp.Target[2:]
	} else if len(rp.Target) >= 2 && isDriveLetter(rp.Target[0]) && rp.Target[1] == ':' {
		ntTarget = `\??\` + rp.Target
	} else {
		ntTarget = rp.Target
		relative = true
	}

	// The paths must be NUL-terminated even though they are counted strings.
	target16 := utf16.Encode([]rune(rp.Target + "\x00"))
	ntTarget16 := utf16.Encode([]rune(ntTarget + "\x00"))

	size := int(unsafe.Sizeof(reparseDataBuffer{})) - 8
	size += len(ntTarget16)*2 + len(target16)*2

	tag := uint32(reparseTagMountPoint)
	if !rp.IsMountPoint {
		tag = reparseTagSymlink
		size += 4 // Add room for symlink flags
	}

	data := reparseDataBuffer{
		ReparseTag:           tag,
		ReparseDataLength:    uint16(size),
		SubstituteNameOffset: 0,
		SubstituteNameLength: uint16((len(ntTarget16) - 1) * 2),
		PrintNameOffset:      uint16(len(ntTarget16) * 2),
		PrintNameLength:      uint16((len(target16) - 1) * 2),
	}

	var b bytes.Buffer
	_ = binary.Write(&b, binary.LittleEndian, &data)
	if !rp.IsMountPoint {
		flags := uint32(0)
		if relative {
			flags |= 1
		}
		_ = binary.Write(&b, binary.LittleEndian, flags)
	}

	_ = binary.Write(&b, binary.LittleEndian, ntTarget16)
	_ = binary.Write(&b, binary.LittleEndian, target16)
	return b.Bytes()
}
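The first step of EncodeReparsePoint, mapping a Win32 target path to an NT path and detecting relative targets, can be sketched portably. The prefix rules below mirror the branches in reparse.go, but `ntPathForTarget` and its return shape are illustrative names introduced for this sketch, not part of go-winio.

```go
package main

import (
	"fmt"
	"strings"
)

// isDriveLetter reports whether c is an ASCII letter, as in reparse.go.
func isDriveLetter(c byte) bool {
	return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
}

// ntPathForTarget mirrors the prefix mapping in EncodeReparsePoint:
// \\?\X maps to \??\X, a \\server\share UNC path maps to \??\UNC\...,
// a drive-letter path gains the \??\ prefix, and anything else is kept
// as-is and flagged as relative (which sets the symlink flag bit).
func ntPathForTarget(target string) (ntTarget string, relative bool) {
	switch {
	case strings.HasPrefix(target, `\\?\`):
		return `\??\` + target[4:], false
	case strings.HasPrefix(target, `\\`):
		return `\??\UNC\` + target[2:], false
	case len(target) >= 2 && isDriveLetter(target[0]) && target[1] == ':':
		return `\??\` + target, false
	default:
		return target, true
	}
}

func main() {
	for _, t := range []string{`C:\data`, `\\srv\share\f`, `\\?\C:\long`, `..\sibling`} {
		nt, rel := ntPathForTarget(t)
		fmt.Printf("%-16s -> %-24s relative=%v\n", t, nt, rel)
	}
}
```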
133 vendor/github.com/Microsoft/go-winio/sd.go (generated, vendored, normal file)
@@ -0,0 +1,133 @@
//go:build windows
// +build windows

package winio

import (
	"errors"
	"fmt"
	"unsafe"

	"golang.org/x/sys/windows"
)

//sys lookupAccountName(systemName *uint16, accountName string, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) = advapi32.LookupAccountNameW
//sys lookupAccountSid(systemName *uint16, sid *byte, name *uint16, nameSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) = advapi32.LookupAccountSidW
//sys convertSidToStringSid(sid *byte, str **uint16) (err error) = advapi32.ConvertSidToStringSidW
//sys convertStringSidToSid(str *uint16, sid **byte) (err error) = advapi32.ConvertStringSidToSidW

type AccountLookupError struct {
	Name string
	Err  error
}

func (e *AccountLookupError) Error() string {
	if e.Name == "" {
		return "lookup account: empty account name specified"
	}
	var s string
	switch {
	case errors.Is(e.Err, windows.ERROR_INVALID_SID):
		s = "the security ID structure is invalid"
	case errors.Is(e.Err, windows.ERROR_NONE_MAPPED):
		s = "not found"
	default:
		s = e.Err.Error()
	}
	return "lookup account " + e.Name + ": " + s
}

func (e *AccountLookupError) Unwrap() error { return e.Err }

type SddlConversionError struct {
	Sddl string
	Err  error
}

func (e *SddlConversionError) Error() string {
	return "convert " + e.Sddl + ": " + e.Err.Error()
}

func (e *SddlConversionError) Unwrap() error { return e.Err }

// LookupSidByName looks up the SID of an account by name
//
//revive:disable-next-line:var-naming SID, not Sid
func LookupSidByName(name string) (sid string, err error) {
	if name == "" {
		return "", &AccountLookupError{name, windows.ERROR_NONE_MAPPED}
	}

	var sidSize, sidNameUse, refDomainSize uint32
	err = lookupAccountName(nil, name, nil, &sidSize, nil, &refDomainSize, &sidNameUse)
	if err != nil && err != windows.ERROR_INSUFFICIENT_BUFFER { //nolint:errorlint // err is Errno
		return "", &AccountLookupError{name, err}
	}
	sidBuffer := make([]byte, sidSize)
	refDomainBuffer := make([]uint16, refDomainSize)
	err = lookupAccountName(nil, name, &sidBuffer[0], &sidSize, &refDomainBuffer[0], &refDomainSize, &sidNameUse)
	if err != nil {
		return "", &AccountLookupError{name, err}
	}
	var strBuffer *uint16
	err = convertSidToStringSid(&sidBuffer[0], &strBuffer)
	if err != nil {
		return "", &AccountLookupError{name, err}
	}
	sid = windows.UTF16ToString((*[0xffff]uint16)(unsafe.Pointer(strBuffer))[:])
	_, _ = windows.LocalFree(windows.Handle(unsafe.Pointer(strBuffer)))
	return sid, nil
}

// LookupNameBySid looks up the name of an account by SID
//
//revive:disable-next-line:var-naming SID, not Sid
func LookupNameBySid(sid string) (name string, err error) {
	if sid == "" {
		return "", &AccountLookupError{sid, windows.ERROR_NONE_MAPPED}
	}

	sidBuffer, err := windows.UTF16PtrFromString(sid)
	if err != nil {
		return "", &AccountLookupError{sid, err}
	}

	var sidPtr *byte
	if err = convertStringSidToSid(sidBuffer, &sidPtr); err != nil {
		return "", &AccountLookupError{sid, err}
	}
	defer windows.LocalFree(windows.Handle(unsafe.Pointer(sidPtr))) //nolint:errcheck

	var nameSize, refDomainSize, sidNameUse uint32
	err = lookupAccountSid(nil, sidPtr, nil, &nameSize, nil, &refDomainSize, &sidNameUse)
	if err != nil && err != windows.ERROR_INSUFFICIENT_BUFFER { //nolint:errorlint // err is Errno
		return "", &AccountLookupError{sid, err}
	}

	nameBuffer := make([]uint16, nameSize)
	refDomainBuffer := make([]uint16, refDomainSize)
	err = lookupAccountSid(nil, sidPtr, &nameBuffer[0], &nameSize, &refDomainBuffer[0], &refDomainSize, &sidNameUse)
	if err != nil {
		return "", &AccountLookupError{sid, err}
	}

	name = windows.UTF16ToString(nameBuffer)
	return name, nil
}

func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
	sd, err := windows.SecurityDescriptorFromString(sddl)
	if err != nil {
		return nil, &SddlConversionError{Sddl: sddl, Err: err}
	}
	b := unsafe.Slice((*byte)(unsafe.Pointer(sd)), sd.Length())
	return b, nil
}

func SecurityDescriptorToSddl(sd []byte) (string, error) {
	if l := int(unsafe.Sizeof(windows.SECURITY_DESCRIPTOR{})); len(sd) < l {
		return "", fmt.Errorf("SecurityDescriptor (%d) smaller than expected (%d): %w", len(sd), l, windows.ERROR_INCORRECT_SIZE)
	}
	s := (*windows.SECURITY_DESCRIPTOR)(unsafe.Pointer(&sd[0]))
	return s.String(), nil
}
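LookupSidByName and LookupNameBySid both use the classic Win32 two-call buffer-sizing pattern: call once with nil buffers to learn the required sizes (expecting ERROR_INSUFFICIENT_BUFFER), allocate, then call again. A portable sketch of that pattern, with `fakeLookup` standing in for the Win32 call and `errInsufficientBuffer` standing in for windows.ERROR_INSUFFICIENT_BUFFER (both names invented for this sketch):

```go
package main

import (
	"errors"
	"fmt"
)

// errInsufficientBuffer stands in for windows.ERROR_INSUFFICIENT_BUFFER.
var errInsufficientBuffer = errors.New("insufficient buffer")

// fakeLookup stands in for a Win32-style call: it reports the required
// size via *size and fails when the buffer is too small, and otherwise
// fills the buffer.
func fakeLookup(name string, buf []byte, size *uint32) error {
	needed := uint32(len(name))
	if uint32(len(buf)) < needed {
		*size = needed
		return errInsufficientBuffer
	}
	copy(buf, name)
	*size = needed
	return nil
}

// lookupTwoCall is the shape used by LookupSidByName: probe with nil
// buffers, allocate to the reported size, then retry with real buffers.
func lookupTwoCall(name string) (string, error) {
	var size uint32
	err := fakeLookup(name, nil, &size)
	if err != nil && !errors.Is(err, errInsufficientBuffer) {
		return "", err
	}
	buf := make([]byte, size)
	if err := fakeLookup(name, buf, &size); err != nil {
		return "", err
	}
	return string(buf[:size]), nil
}

func main() {
	s, err := lookupTwoCall("S-1-5-18")
	fmt.Println(s, err)
}
```

Note that sd.go compares against ERROR_INSUFFICIENT_BUFFER with `!=` rather than errors.Is because the generated wrappers return a raw Errno, as the `//nolint:errorlint` comments point out.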
5 vendor/github.com/Microsoft/go-winio/syscall.go (generated, vendored, normal file)
@@ -0,0 +1,5 @@
//go:build windows

package winio

//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go ./*.go
378 vendor/github.com/Microsoft/go-winio/zsyscall_windows.go (generated, vendored, normal file)
@@ -0,0 +1,378 @@
//go:build windows

// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT.

package winio

import (
	"syscall"
	"unsafe"

	"golang.org/x/sys/windows"
)

var _ unsafe.Pointer

// Do the interface allocations only once for common
// Errno values.
const (
	errnoERROR_IO_PENDING = 997
)

var (
	errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
	errERROR_EINVAL     error = syscall.EINVAL
)

// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
	switch e {
	case 0:
		return errERROR_EINVAL
	case errnoERROR_IO_PENDING:
		return errERROR_IO_PENDING
	}
	return e
}

var (
	modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
	modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
	modntdll    = windows.NewLazySystemDLL("ntdll.dll")
	modws2_32   = windows.NewLazySystemDLL("ws2_32.dll")

	procAdjustTokenPrivileges              = modadvapi32.NewProc("AdjustTokenPrivileges")
	procConvertSidToStringSidW             = modadvapi32.NewProc("ConvertSidToStringSidW")
	procConvertStringSidToSidW             = modadvapi32.NewProc("ConvertStringSidToSidW")
	procImpersonateSelf                    = modadvapi32.NewProc("ImpersonateSelf")
	procLookupAccountNameW                 = modadvapi32.NewProc("LookupAccountNameW")
	procLookupAccountSidW                  = modadvapi32.NewProc("LookupAccountSidW")
	procLookupPrivilegeDisplayNameW        = modadvapi32.NewProc("LookupPrivilegeDisplayNameW")
	procLookupPrivilegeNameW               = modadvapi32.NewProc("LookupPrivilegeNameW")
	procLookupPrivilegeValueW              = modadvapi32.NewProc("LookupPrivilegeValueW")
	procOpenThreadToken                    = modadvapi32.NewProc("OpenThreadToken")
	procRevertToSelf                       = modadvapi32.NewProc("RevertToSelf")
	procBackupRead                         = modkernel32.NewProc("BackupRead")
	procBackupWrite                        = modkernel32.NewProc("BackupWrite")
	procCancelIoEx                         = modkernel32.NewProc("CancelIoEx")
	procConnectNamedPipe                   = modkernel32.NewProc("ConnectNamedPipe")
	procCreateIoCompletionPort             = modkernel32.NewProc("CreateIoCompletionPort")
	procCreateNamedPipeW                   = modkernel32.NewProc("CreateNamedPipeW")
	procDisconnectNamedPipe                = modkernel32.NewProc("DisconnectNamedPipe")
	procGetCurrentThread                   = modkernel32.NewProc("GetCurrentThread")
	procGetNamedPipeHandleStateW           = modkernel32.NewProc("GetNamedPipeHandleStateW")
	procGetNamedPipeInfo                   = modkernel32.NewProc("GetNamedPipeInfo")
	procGetQueuedCompletionStatus          = modkernel32.NewProc("GetQueuedCompletionStatus")
	procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes")
	procNtCreateNamedPipeFile              = modntdll.NewProc("NtCreateNamedPipeFile")
	procRtlDefaultNpAcl                    = modntdll.NewProc("RtlDefaultNpAcl")
	procRtlDosPathNameToNtPathName_U       = modntdll.NewProc("RtlDosPathNameToNtPathName_U")
	procRtlNtStatusToDosErrorNoTeb         = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb")
	procWSAGetOverlappedResult             = modws2_32.NewProc("WSAGetOverlappedResult")
)
func adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) {
	var _p0 uint32
	if releaseAll {
		_p0 = 1
	}
	r0, _, e1 := syscall.SyscallN(procAdjustTokenPrivileges.Addr(), uintptr(token), uintptr(_p0), uintptr(unsafe.Pointer(input)), uintptr(outputSize), uintptr(unsafe.Pointer(output)), uintptr(unsafe.Pointer(requiredSize)))
	success = r0 != 0
	if true {
		err = errnoErr(e1)
	}
	return
}

func convertSidToStringSid(sid *byte, str **uint16) (err error) {
	r1, _, e1 := syscall.SyscallN(procConvertSidToStringSidW.Addr(), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(str)))
	if r1 == 0 {
		err = errnoErr(e1)
	}
	return
}

func convertStringSidToSid(str *uint16, sid **byte) (err error) {
	r1, _, e1 := syscall.SyscallN(procConvertStringSidToSidW.Addr(), uintptr(unsafe.Pointer(str)), uintptr(unsafe.Pointer(sid)))
	if r1 == 0 {
		err = errnoErr(e1)
	}
	return
}

func impersonateSelf(level uint32) (err error) {
	r1, _, e1 := syscall.SyscallN(procImpersonateSelf.Addr(), uintptr(level))
	if r1 == 0 {
		err = errnoErr(e1)
	}
	return
}

func lookupAccountName(systemName *uint16, accountName string, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) {
	var _p0 *uint16
	_p0, err = syscall.UTF16PtrFromString(accountName)
	if err != nil {
		return
	}
	return _lookupAccountName(systemName, _p0, sid, sidSize, refDomain, refDomainSize, sidNameUse)
}

func _lookupAccountName(systemName *uint16, accountName *uint16, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) {
	r1, _, e1 := syscall.SyscallN(procLookupAccountNameW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(accountName)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(sidSize)), uintptr(unsafe.Pointer(refDomain)), uintptr(unsafe.Pointer(refDomainSize)), uintptr(unsafe.Pointer(sidNameUse)))
	if r1 == 0 {
		err = errnoErr(e1)
	}
	return
}

func lookupAccountSid(systemName *uint16, sid *byte, name *uint16, nameSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) {
	r1, _, e1 := syscall.SyscallN(procLookupAccountSidW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(nameSize)), uintptr(unsafe.Pointer(refDomain)), uintptr(unsafe.Pointer(refDomainSize)), uintptr(unsafe.Pointer(sidNameUse)))
	if r1 == 0 {
		err = errnoErr(e1)
	}
	return
}

func lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) {
	var _p0 *uint16
	_p0, err = syscall.UTF16PtrFromString(systemName)
	if err != nil {
		return
	}
	return _lookupPrivilegeDisplayName(_p0, name, buffer, size, languageId)
}

func _lookupPrivilegeDisplayName(systemName *uint16, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) {
	r1, _, e1 := syscall.SyscallN(procLookupPrivilegeDisplayNameW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(buffer)), uintptr(unsafe.Pointer(size)), uintptr(unsafe.Pointer(languageId)))
	if r1 == 0 {
		err = errnoErr(e1)
	}
	return
}

func lookupPrivilegeName(systemName string, luid *uint64, buffer *uint16, size *uint32) (err error) {
	var _p0 *uint16
	_p0, err = syscall.UTF16PtrFromString(systemName)
	if err != nil {
		return
	}
	return _lookupPrivilegeName(_p0, luid, buffer, size)
}

func _lookupPrivilegeName(systemName *uint16, luid *uint64, buffer *uint16, size *uint32) (err error) {
	r1, _, e1 := syscall.SyscallN(procLookupPrivilegeNameW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(luid)), uintptr(unsafe.Pointer(buffer)), uintptr(unsafe.Pointer(size)))
	if r1 == 0 {
		err = errnoErr(e1)
	}
	return
}

func lookupPrivilegeValue(systemName string, name string, luid *uint64) (err error) {
	var _p0 *uint16
	_p0, err = syscall.UTF16PtrFromString(systemName)
	if err != nil {
		return
	}
	var _p1 *uint16
	_p1, err = syscall.UTF16PtrFromString(name)
	if err != nil {
		return
	}
	return _lookupPrivilegeValue(_p0, _p1, luid)
}

func _lookupPrivilegeValue(systemName *uint16, name *uint16, luid *uint64) (err error) {
	r1, _, e1 := syscall.SyscallN(procLookupPrivilegeValueW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(luid)))
	if r1 == 0 {
		err = errnoErr(e1)
	}
	return
}
func openThreadToken(thread windows.Handle, accessMask uint32, openAsSelf bool, token *windows.Token) (err error) {
	var _p0 uint32
	if openAsSelf {
		_p0 = 1
	}
	r1, _, e1 := syscall.SyscallN(procOpenThreadToken.Addr(), uintptr(thread), uintptr(accessMask), uintptr(_p0), uintptr(unsafe.Pointer(token)))
	if r1 == 0 {
		err = errnoErr(e1)
	}
	return
}

func revertToSelf() (err error) {
	r1, _, e1 := syscall.SyscallN(procRevertToSelf.Addr())
	if r1 == 0 {
		err = errnoErr(e1)
	}
	return
}

func backupRead(h windows.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) {
	var _p0 *byte
	if len(b) > 0 {
		_p0 = &b[0]
	}
	var _p1 uint32
	if abort {
		_p1 = 1
	}
	var _p2 uint32
	if processSecurity {
		_p2 = 1
	}
	r1, _, e1 := syscall.SyscallN(procBackupRead.Addr(), uintptr(h), uintptr(unsafe.Pointer(_p0)), uintptr(len(b)), uintptr(unsafe.Pointer(bytesRead)), uintptr(_p1), uintptr(_p2), uintptr(unsafe.Pointer(context)))
	if r1 == 0 {
		err = errnoErr(e1)
	}
	return
}

func backupWrite(h windows.Handle, b []byte, bytesWritten *uint32, abort bool, processSecurity bool, context *uintptr) (err error) {
	var _p0 *byte
	if len(b) > 0 {
		_p0 = &b[0]
	}
	var _p1 uint32
	if abort {
		_p1 = 1
	}
	var _p2 uint32
	if processSecurity {
		_p2 = 1
	}
	r1, _, e1 := syscall.SyscallN(procBackupWrite.Addr(), uintptr(h), uintptr(unsafe.Pointer(_p0)), uintptr(len(b)), uintptr(unsafe.Pointer(bytesWritten)), uintptr(_p1), uintptr(_p2), uintptr(unsafe.Pointer(context)))
	if r1 == 0 {
		err = errnoErr(e1)
	}
	return
}

func cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) {
	r1, _, e1 := syscall.SyscallN(procCancelIoEx.Addr(), uintptr(file), uintptr(unsafe.Pointer(o)))
	if r1 == 0 {
		err = errnoErr(e1)
	}
	return
}

func connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) {
	r1, _, e1 := syscall.SyscallN(procConnectNamedPipe.Addr(), uintptr(pipe), uintptr(unsafe.Pointer(o)))
	if r1 == 0 {
		err = errnoErr(e1)
	}
	return
}

func createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) {
	r0, _, e1 := syscall.SyscallN(procCreateIoCompletionPort.Addr(), uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount))
	newport = windows.Handle(r0)
	if newport == 0 {
		err = errnoErr(e1)
	}
	return
}

func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
	var _p0 *uint16
	_p0, err = syscall.UTF16PtrFromString(name)
	if err != nil {
		return
	}
	return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa)
}

func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
	r0, _, e1 := syscall.SyscallN(procCreateNamedPipeW.Addr(), uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)))
	handle = windows.Handle(r0)
	if handle == windows.InvalidHandle {
		err = errnoErr(e1)
	}
	return
}

func disconnectNamedPipe(pipe windows.Handle) (err error) {
	r1, _, e1 := syscall.SyscallN(procDisconnectNamedPipe.Addr(), uintptr(pipe))
	if r1 == 0 {
		err = errnoErr(e1)
	}
	return
}

func getCurrentThread() (h windows.Handle) {
	r0, _, _ := syscall.SyscallN(procGetCurrentThread.Addr())
	h = windows.Handle(r0)
	return
}

func getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) {
	r1, _, e1 := syscall.SyscallN(procGetNamedPipeHandleStateW.Addr(), uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize))
	if r1 == 0 {
		err = errnoErr(e1)
	}
	return
}

func getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) {
	r1, _, e1 := syscall.SyscallN(procGetNamedPipeInfo.Addr(), uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)))
	if r1 == 0 {
		err = errnoErr(e1)
	}
	return
}

func getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) {
	r1, _, e1 := syscall.SyscallN(procGetQueuedCompletionStatus.Addr(), uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout))
	if r1 == 0 {
		err = errnoErr(e1)
	}
	return
}

func setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) {
	r1, _, e1 := syscall.SyscallN(procSetFileCompletionNotificationModes.Addr(), uintptr(h), uintptr(flags))
	if r1 == 0 {
		err = errnoErr(e1)
	}
	return
}

func ntCreateNamedPipeFile(pipe *windows.Handle, access ntAccessMask, oa *objectAttributes, iosb *ioStatusBlock, share ntFileShareMode, disposition ntFileCreationDisposition, options ntFileOptions, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntStatus) {
	r0, _, _ := syscall.SyscallN(procNtCreateNamedPipeFile.Addr(), uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)))
	status = ntStatus(r0)
	return
}

func rtlDefaultNpAcl(dacl *uintptr) (status ntStatus) {
	r0, _, _ := syscall.SyscallN(procRtlDefaultNpAcl.Addr(), uintptr(unsafe.Pointer(dacl)))
	status = ntStatus(r0)
	return
}

func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntStatus) {
	r0, _, _ := syscall.SyscallN(procRtlDosPathNameToNtPathName_U.Addr(), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(ntName)), uintptr(filePart), uintptr(reserved))
	status = ntStatus(r0)
	return
}

func rtlNtStatusToDosError(status ntStatus) (winerr error) {
	r0, _, _ := syscall.SyscallN(procRtlNtStatusToDosErrorNoTeb.Addr(), uintptr(status))
	if r0 != 0 {
		winerr = syscall.Errno(r0)
	}
	return
}

func wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) {
	var _p0 uint32
	if wait {
		_p0 = 1
	}
	r1, _, e1 := syscall.SyscallN(procWSAGetOverlappedResult.Addr(), uintptr(h), uintptr(unsafe.Pointer(o)), uintptr(unsafe.Pointer(bytes)), uintptr(_p0), uintptr(unsafe.Pointer(flags)))
	if r1 == 0 {
		err = errnoErr(e1)
	}
	return
}
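The generated errnoErr helper exists so that the hot syscall wrappers return pre-allocated error values instead of boxing a fresh Errno into an interface on every call. The idea can be demonstrated portably; the names below mirror the generated file, but this sketch does not touch any Windows API and runs on any platform.

```go
package main

import (
	"fmt"
	"syscall"
)

const errnoERROR_IO_PENDING = 997

// Boxing an Errno into the error interface once, at package init,
// means the common return values never allocate on the call path.
var (
	errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
	errERROR_EINVAL     error = syscall.EINVAL
)

// errnoErr returns the shared boxed values for the common cases and
// falls back to boxing e for everything else.
func errnoErr(e syscall.Errno) error {
	switch e {
	case 0:
		return errERROR_EINVAL
	case errnoERROR_IO_PENDING:
		return errERROR_IO_PENDING
	}
	return e
}

func main() {
	// Interface comparison shows the shared values are reused.
	fmt.Println(errnoErr(0) == errERROR_EINVAL)
	fmt.Println(errnoErr(errnoERROR_IO_PENDING) == errERROR_IO_PENDING)
}
```

A quirk worth noting: the zero-errno case maps to EINVAL because a wrapper only consults the errno when the underlying call already signaled failure, so errno 0 there means "failed with no further detail".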
191 vendor/github.com/containerd/errdefs/LICENSE (generated, vendored, normal file)
@@ -0,0 +1,191 @@
                                 Apache License
                           Version 2.0, January 2004
                        https://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
Copyright The containerd Authors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
https://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||