Compare commits
18 Commits
ef4bf1efe0
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 660bf7ee48 | |||
|
|
17673c38a6 | ||
|
|
9dbd361caf | ||
|
|
859e5e1e02 | ||
|
|
f010a0c8a2 | ||
|
|
d0973b2adf | ||
|
|
8d9b62daf3 | ||
|
|
d1252ade69 | ||
|
|
9fc9a2e3a2 | ||
|
|
14b5125c12 | ||
|
|
ea04378962 | ||
| d69766c83c | |||
| 237e8699eb | |||
| 1de8695736 | |||
| c30c6dc480 | |||
|
|
e523c4b543 | ||
|
|
26e4ef7d8b | ||
|
|
eb2e05ff84 |
@@ -15,14 +15,16 @@ RUN addgroup -g 1000 chorus && \
|
|||||||
RUN mkdir -p /app/data && \
|
RUN mkdir -p /app/data && \
|
||||||
chown -R chorus:chorus /app
|
chown -R chorus:chorus /app
|
||||||
|
|
||||||
# Copy pre-built binary
|
# Copy pre-built binary from build directory (ensure it exists and is the correct one)
|
||||||
COPY chorus-agent /app/chorus-agent
|
COPY build/chorus-agent /app/chorus-agent
|
||||||
RUN chmod +x /app/chorus-agent && chown chorus:chorus /app/chorus-agent
|
RUN chmod +x /app/chorus-agent && chown chorus:chorus /app/chorus-agent
|
||||||
|
|
||||||
# Switch to non-root user
|
# Switch to non-root user
|
||||||
USER chorus
|
USER chorus
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Note: Using correct chorus-agent binary built with 'make build-agent'
|
||||||
|
|
||||||
# Expose ports
|
# Expose ports
|
||||||
EXPOSE 8080 8081 9000
|
EXPOSE 8080 8081 9000
|
||||||
|
|
||||||
|
|||||||
43
Dockerfile.ubuntu
Normal file
43
Dockerfile.ubuntu
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
# CHORUS - Ubuntu-based Docker image for glibc compatibility
|
||||||
|
FROM ubuntu:22.04
|
||||||
|
|
||||||
|
# Install runtime dependencies
|
||||||
|
RUN apt-get update && apt-get install -y \
|
||||||
|
ca-certificates \
|
||||||
|
tzdata \
|
||||||
|
curl \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Create non-root user for security
|
||||||
|
RUN groupadd -g 1000 chorus && \
|
||||||
|
useradd -u 1000 -g chorus -s /bin/bash -d /home/chorus -m chorus
|
||||||
|
|
||||||
|
# Create application directories
|
||||||
|
RUN mkdir -p /app/data && \
|
||||||
|
chown -R chorus:chorus /app
|
||||||
|
|
||||||
|
# Copy pre-built binary from build directory
|
||||||
|
COPY build/chorus-agent /app/chorus-agent
|
||||||
|
RUN chmod +x /app/chorus-agent && chown chorus:chorus /app/chorus-agent
|
||||||
|
|
||||||
|
# Switch to non-root user
|
||||||
|
USER chorus
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Expose ports
|
||||||
|
EXPOSE 8080 8081 9000
|
||||||
|
|
||||||
|
# Health check
|
||||||
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||||
|
CMD curl -f http://localhost:8081/health || exit 1
|
||||||
|
|
||||||
|
# Set default environment variables
|
||||||
|
ENV LOG_LEVEL=info \
|
||||||
|
LOG_FORMAT=structured \
|
||||||
|
CHORUS_BIND_ADDRESS=0.0.0.0 \
|
||||||
|
CHORUS_API_PORT=8080 \
|
||||||
|
CHORUS_HEALTH_PORT=8081 \
|
||||||
|
CHORUS_P2P_PORT=9000
|
||||||
|
|
||||||
|
# Start CHORUS
|
||||||
|
ENTRYPOINT ["/app/chorus-agent"]
|
||||||
2
Makefile
2
Makefile
@@ -5,7 +5,7 @@
|
|||||||
BINARY_NAME_AGENT = chorus-agent
|
BINARY_NAME_AGENT = chorus-agent
|
||||||
BINARY_NAME_HAP = chorus-hap
|
BINARY_NAME_HAP = chorus-hap
|
||||||
BINARY_NAME_COMPAT = chorus
|
BINARY_NAME_COMPAT = chorus
|
||||||
VERSION ?= 0.1.0-dev
|
VERSION ?= 0.5.5
|
||||||
COMMIT_HASH ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
|
COMMIT_HASH ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
|
||||||
BUILD_DATE ?= $(shell date -u '+%Y-%m-%d_%H:%M:%S')
|
BUILD_DATE ?= $(shell date -u '+%Y-%m-%d_%H:%M:%S')
|
||||||
|
|
||||||
|
|||||||
35
README.md
35
README.md
@@ -8,7 +8,7 @@ CHORUS is the runtime that ties the CHORUS ecosystem together: libp2p mesh, DHT-
|
|||||||
| --- | --- | --- |
|
| --- | --- | --- |
|
||||||
| libp2p node + PubSub | ✅ Running | `internal/runtime/shared.go` spins up the mesh, hypercore logging, availability broadcasts. |
|
| libp2p node + PubSub | ✅ Running | `internal/runtime/shared.go` spins up the mesh, hypercore logging, availability broadcasts. |
|
||||||
| DHT + DecisionPublisher | ✅ Running | Encrypted storage wired through `pkg/dht`; decisions written via `ucxl.DecisionPublisher`. |
|
| DHT + DecisionPublisher | ✅ Running | Encrypted storage wired through `pkg/dht`; decisions written via `ucxl.DecisionPublisher`. |
|
||||||
| Election manager | ✅ Running | Admin election integrated with Backbeat; metrics exposed under `pkg/metrics`. |
|
| **Leader Election System** | ✅ **FULLY FUNCTIONAL** | **🎉 MILESTONE: Complete admin election with consensus, discovery protocol, heartbeats, and SLURP activation!** |
|
||||||
| SLURP (context intelligence) | 🚧 Stubbed | `pkg/slurp/slurp.go` contains TODOs for resolver, temporal graphs, intelligence. Leader integration scaffolding exists but uses placeholder IDs/request forwarding. |
|
| SLURP (context intelligence) | 🚧 Stubbed | `pkg/slurp/slurp.go` contains TODOs for resolver, temporal graphs, intelligence. Leader integration scaffolding exists but uses placeholder IDs/request forwarding. |
|
||||||
| SHHH (secrets sentinel) | 🚧 Sentinel live | `pkg/shhh` redacts hypercore + PubSub payloads with audit + metrics hooks (policy replay TBD). |
|
| SHHH (secrets sentinel) | 🚧 Sentinel live | `pkg/shhh` redacts hypercore + PubSub payloads with audit + metrics hooks (policy replay TBD). |
|
||||||
| HMMM routing | 🚧 Partial | PubSub topics join, but capability/role announcements and HMMM router wiring are placeholders (`internal/runtime/agent_support.go`). |
|
| HMMM routing | 🚧 Partial | PubSub topics join, but capability/role announcements and HMMM router wiring are placeholders (`internal/runtime/agent_support.go`). |
|
||||||
@@ -35,6 +35,39 @@ You’ll get a single agent container with:
|
|||||||
|
|
||||||
**Missing today:** SLURP context resolution, advanced SHHH policy replay, HMMM per-issue routing. Expect log warnings/TODOs for those paths.
|
**Missing today:** SLURP context resolution, advanced SHHH policy replay, HMMM per-issue routing. Expect log warnings/TODOs for those paths.
|
||||||
|
|
||||||
|
## 🎉 Leader Election System (NEW!)
|
||||||
|
|
||||||
|
CHORUS now features a complete, production-ready leader election system:
|
||||||
|
|
||||||
|
### Core Features
|
||||||
|
- **Consensus-based election** with weighted scoring (uptime, capabilities, resources)
|
||||||
|
- **Admin discovery protocol** for network-wide leader identification
|
||||||
|
- **Heartbeat system** with automatic failover (15-second intervals)
|
||||||
|
- **Concurrent election prevention** with randomized delays
|
||||||
|
- **SLURP activation** on elected admin nodes
|
||||||
|
|
||||||
|
### How It Works
|
||||||
|
1. **Bootstrap**: Nodes start in idle state, no admin known
|
||||||
|
2. **Discovery**: Nodes send discovery requests to find existing admin
|
||||||
|
3. **Election trigger**: If no admin found after grace period, trigger election
|
||||||
|
4. **Candidacy**: Eligible nodes announce themselves with capability scores
|
||||||
|
5. **Consensus**: Network selects winner based on highest score
|
||||||
|
6. **Leadership**: Winner starts heartbeats, activates SLURP functionality
|
||||||
|
7. **Monitoring**: Nodes continuously verify admin health via heartbeats
|
||||||
|
|
||||||
|
### Debugging
|
||||||
|
Use these log patterns to monitor election health:
|
||||||
|
```bash
|
||||||
|
# Monitor WHOAMI messages and leader identification
|
||||||
|
docker service logs CHORUS_chorus | grep "🤖 WHOAMI\|👑\|📡.*Discovered"
|
||||||
|
|
||||||
|
# Track election cycles
|
||||||
|
docker service logs CHORUS_chorus | grep "🗳️\|📢.*candidacy\|🏆.*winner"
|
||||||
|
|
||||||
|
# Watch discovery protocol
|
||||||
|
docker service logs CHORUS_chorus | grep "📩\|📤\|📥"
|
||||||
|
```
|
||||||
|
|
||||||
## Roadmap Highlights
|
## Roadmap Highlights
|
||||||
|
|
||||||
1. **Security substrate** – land SHHH sentinel, finish SLURP leader-only operations, validate COOEE enrolment (see roadmap Phase 1).
|
1. **Security substrate** – land SHHH sentinel, finish SLURP leader-only operations, validate COOEE enrolment (see roadmap Phase 1).
|
||||||
|
|||||||
@@ -9,10 +9,11 @@ import (
|
|||||||
|
|
||||||
"chorus/internal/logging"
|
"chorus/internal/logging"
|
||||||
"chorus/pubsub"
|
"chorus/pubsub"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HTTPServer provides HTTP API endpoints for Bzzz
|
// HTTPServer provides HTTP API endpoints for CHORUS
|
||||||
type HTTPServer struct {
|
type HTTPServer struct {
|
||||||
port int
|
port int
|
||||||
hypercoreLog *logging.HypercoreLog
|
hypercoreLog *logging.HypercoreLog
|
||||||
@@ -20,7 +21,7 @@ type HTTPServer struct {
|
|||||||
server *http.Server
|
server *http.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHTTPServer creates a new HTTP server for Bzzz API
|
// NewHTTPServer creates a new HTTP server for CHORUS API
|
||||||
func NewHTTPServer(port int, hlog *logging.HypercoreLog, ps *pubsub.PubSub) *HTTPServer {
|
func NewHTTPServer(port int, hlog *logging.HypercoreLog, ps *pubsub.PubSub) *HTTPServer {
|
||||||
return &HTTPServer{
|
return &HTTPServer{
|
||||||
port: port,
|
port: port,
|
||||||
@@ -32,38 +33,38 @@ func NewHTTPServer(port int, hlog *logging.HypercoreLog, ps *pubsub.PubSub) *HTT
|
|||||||
// Start starts the HTTP server
|
// Start starts the HTTP server
|
||||||
func (h *HTTPServer) Start() error {
|
func (h *HTTPServer) Start() error {
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
|
|
||||||
// Enable CORS for all routes
|
// Enable CORS for all routes
|
||||||
router.Use(func(next http.Handler) http.Handler {
|
router.Use(func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||||
|
|
||||||
if r.Method == "OPTIONS" {
|
if r.Method == "OPTIONS" {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
// API routes
|
// API routes
|
||||||
api := router.PathPrefix("/api").Subrouter()
|
api := router.PathPrefix("/api").Subrouter()
|
||||||
|
|
||||||
// Hypercore log endpoints
|
// Hypercore log endpoints
|
||||||
api.HandleFunc("/hypercore/logs", h.handleGetLogs).Methods("GET")
|
api.HandleFunc("/hypercore/logs", h.handleGetLogs).Methods("GET")
|
||||||
api.HandleFunc("/hypercore/logs/recent", h.handleGetRecentLogs).Methods("GET")
|
api.HandleFunc("/hypercore/logs/recent", h.handleGetRecentLogs).Methods("GET")
|
||||||
api.HandleFunc("/hypercore/logs/stats", h.handleGetLogStats).Methods("GET")
|
api.HandleFunc("/hypercore/logs/stats", h.handleGetLogStats).Methods("GET")
|
||||||
api.HandleFunc("/hypercore/logs/since/{index}", h.handleGetLogsSince).Methods("GET")
|
api.HandleFunc("/hypercore/logs/since/{index}", h.handleGetLogsSince).Methods("GET")
|
||||||
|
|
||||||
// Health check
|
// Health check
|
||||||
api.HandleFunc("/health", h.handleHealth).Methods("GET")
|
api.HandleFunc("/health", h.handleHealth).Methods("GET")
|
||||||
|
|
||||||
// Status endpoint
|
// Status endpoint
|
||||||
api.HandleFunc("/status", h.handleStatus).Methods("GET")
|
api.HandleFunc("/status", h.handleStatus).Methods("GET")
|
||||||
|
|
||||||
h.server = &http.Server{
|
h.server = &http.Server{
|
||||||
Addr: fmt.Sprintf(":%d", h.port),
|
Addr: fmt.Sprintf(":%d", h.port),
|
||||||
Handler: router,
|
Handler: router,
|
||||||
@@ -71,7 +72,7 @@ func (h *HTTPServer) Start() error {
|
|||||||
WriteTimeout: 15 * time.Second,
|
WriteTimeout: 15 * time.Second,
|
||||||
IdleTimeout: 60 * time.Second,
|
IdleTimeout: 60 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("🌐 Starting HTTP API server on port %d\n", h.port)
|
fmt.Printf("🌐 Starting HTTP API server on port %d\n", h.port)
|
||||||
return h.server.ListenAndServe()
|
return h.server.ListenAndServe()
|
||||||
}
|
}
|
||||||
@@ -87,16 +88,16 @@ func (h *HTTPServer) Stop() error {
|
|||||||
// handleGetLogs returns hypercore log entries
|
// handleGetLogs returns hypercore log entries
|
||||||
func (h *HTTPServer) handleGetLogs(w http.ResponseWriter, r *http.Request) {
|
func (h *HTTPServer) handleGetLogs(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
// Parse query parameters
|
// Parse query parameters
|
||||||
query := r.URL.Query()
|
query := r.URL.Query()
|
||||||
startStr := query.Get("start")
|
startStr := query.Get("start")
|
||||||
endStr := query.Get("end")
|
endStr := query.Get("end")
|
||||||
limitStr := query.Get("limit")
|
limitStr := query.Get("limit")
|
||||||
|
|
||||||
var start, end uint64
|
var start, end uint64
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if startStr != "" {
|
if startStr != "" {
|
||||||
start, err = strconv.ParseUint(startStr, 10, 64)
|
start, err = strconv.ParseUint(startStr, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -104,7 +105,7 @@ func (h *HTTPServer) handleGetLogs(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if endStr != "" {
|
if endStr != "" {
|
||||||
end, err = strconv.ParseUint(endStr, 10, 64)
|
end, err = strconv.ParseUint(endStr, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -114,7 +115,7 @@ func (h *HTTPServer) handleGetLogs(w http.ResponseWriter, r *http.Request) {
|
|||||||
} else {
|
} else {
|
||||||
end = h.hypercoreLog.Length()
|
end = h.hypercoreLog.Length()
|
||||||
}
|
}
|
||||||
|
|
||||||
var limit int = 100 // Default limit
|
var limit int = 100 // Default limit
|
||||||
if limitStr != "" {
|
if limitStr != "" {
|
||||||
limit, err = strconv.Atoi(limitStr)
|
limit, err = strconv.Atoi(limitStr)
|
||||||
@@ -122,7 +123,7 @@ func (h *HTTPServer) handleGetLogs(w http.ResponseWriter, r *http.Request) {
|
|||||||
limit = 100
|
limit = 100
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get log entries
|
// Get log entries
|
||||||
var entries []logging.LogEntry
|
var entries []logging.LogEntry
|
||||||
if endStr != "" || startStr != "" {
|
if endStr != "" || startStr != "" {
|
||||||
@@ -130,87 +131,87 @@ func (h *HTTPServer) handleGetLogs(w http.ResponseWriter, r *http.Request) {
|
|||||||
} else {
|
} else {
|
||||||
entries, err = h.hypercoreLog.GetRecentEntries(limit)
|
entries, err = h.hypercoreLog.GetRecentEntries(limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, fmt.Sprintf("Failed to get log entries: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to get log entries: %v", err), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response := map[string]interface{}{
|
response := map[string]interface{}{
|
||||||
"entries": entries,
|
"entries": entries,
|
||||||
"count": len(entries),
|
"count": len(entries),
|
||||||
"timestamp": time.Now().Unix(),
|
"timestamp": time.Now().Unix(),
|
||||||
"total": h.hypercoreLog.Length(),
|
"total": h.hypercoreLog.Length(),
|
||||||
}
|
}
|
||||||
|
|
||||||
json.NewEncoder(w).Encode(response)
|
json.NewEncoder(w).Encode(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleGetRecentLogs returns the most recent log entries
|
// handleGetRecentLogs returns the most recent log entries
|
||||||
func (h *HTTPServer) handleGetRecentLogs(w http.ResponseWriter, r *http.Request) {
|
func (h *HTTPServer) handleGetRecentLogs(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
// Parse limit parameter
|
// Parse limit parameter
|
||||||
query := r.URL.Query()
|
query := r.URL.Query()
|
||||||
limitStr := query.Get("limit")
|
limitStr := query.Get("limit")
|
||||||
|
|
||||||
limit := 50 // Default
|
limit := 50 // Default
|
||||||
if limitStr != "" {
|
if limitStr != "" {
|
||||||
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 && l <= 1000 {
|
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 && l <= 1000 {
|
||||||
limit = l
|
limit = l
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
entries, err := h.hypercoreLog.GetRecentEntries(limit)
|
entries, err := h.hypercoreLog.GetRecentEntries(limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, fmt.Sprintf("Failed to get recent entries: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to get recent entries: %v", err), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response := map[string]interface{}{
|
response := map[string]interface{}{
|
||||||
"entries": entries,
|
"entries": entries,
|
||||||
"count": len(entries),
|
"count": len(entries),
|
||||||
"timestamp": time.Now().Unix(),
|
"timestamp": time.Now().Unix(),
|
||||||
"total": h.hypercoreLog.Length(),
|
"total": h.hypercoreLog.Length(),
|
||||||
}
|
}
|
||||||
|
|
||||||
json.NewEncoder(w).Encode(response)
|
json.NewEncoder(w).Encode(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleGetLogsSince returns log entries since a given index
|
// handleGetLogsSince returns log entries since a given index
|
||||||
func (h *HTTPServer) handleGetLogsSince(w http.ResponseWriter, r *http.Request) {
|
func (h *HTTPServer) handleGetLogsSince(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
indexStr := vars["index"]
|
indexStr := vars["index"]
|
||||||
|
|
||||||
index, err := strconv.ParseUint(indexStr, 10, 64)
|
index, err := strconv.ParseUint(indexStr, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "Invalid index parameter", http.StatusBadRequest)
|
http.Error(w, "Invalid index parameter", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
entries, err := h.hypercoreLog.GetEntriesSince(index)
|
entries, err := h.hypercoreLog.GetEntriesSince(index)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, fmt.Sprintf("Failed to get entries since index: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to get entries since index: %v", err), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response := map[string]interface{}{
|
response := map[string]interface{}{
|
||||||
"entries": entries,
|
"entries": entries,
|
||||||
"count": len(entries),
|
"count": len(entries),
|
||||||
"since_index": index,
|
"since_index": index,
|
||||||
"timestamp": time.Now().Unix(),
|
"timestamp": time.Now().Unix(),
|
||||||
"total": h.hypercoreLog.Length(),
|
"total": h.hypercoreLog.Length(),
|
||||||
}
|
}
|
||||||
|
|
||||||
json.NewEncoder(w).Encode(response)
|
json.NewEncoder(w).Encode(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleGetLogStats returns statistics about the hypercore log
|
// handleGetLogStats returns statistics about the hypercore log
|
||||||
func (h *HTTPServer) handleGetLogStats(w http.ResponseWriter, r *http.Request) {
|
func (h *HTTPServer) handleGetLogStats(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
stats := h.hypercoreLog.GetStats()
|
stats := h.hypercoreLog.GetStats()
|
||||||
json.NewEncoder(w).Encode(stats)
|
json.NewEncoder(w).Encode(stats)
|
||||||
}
|
}
|
||||||
@@ -218,26 +219,26 @@ func (h *HTTPServer) handleGetLogStats(w http.ResponseWriter, r *http.Request) {
|
|||||||
// handleHealth returns health status
|
// handleHealth returns health status
|
||||||
func (h *HTTPServer) handleHealth(w http.ResponseWriter, r *http.Request) {
|
func (h *HTTPServer) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
health := map[string]interface{}{
|
health := map[string]interface{}{
|
||||||
"status": "healthy",
|
"status": "healthy",
|
||||||
"timestamp": time.Now().Unix(),
|
"timestamp": time.Now().Unix(),
|
||||||
"log_entries": h.hypercoreLog.Length(),
|
"log_entries": h.hypercoreLog.Length(),
|
||||||
}
|
}
|
||||||
|
|
||||||
json.NewEncoder(w).Encode(health)
|
json.NewEncoder(w).Encode(health)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleStatus returns detailed status information
|
// handleStatus returns detailed status information
|
||||||
func (h *HTTPServer) handleStatus(w http.ResponseWriter, r *http.Request) {
|
func (h *HTTPServer) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
status := map[string]interface{}{
|
status := map[string]interface{}{
|
||||||
"status": "running",
|
"status": "running",
|
||||||
"timestamp": time.Now().Unix(),
|
"timestamp": time.Now().Unix(),
|
||||||
"hypercore": h.hypercoreLog.GetStats(),
|
"hypercore": h.hypercoreLog.GetStats(),
|
||||||
"api_version": "1.0.0",
|
"api_version": "1.0.0",
|
||||||
}
|
}
|
||||||
|
|
||||||
json.NewEncoder(w).Encode(status)
|
json.NewEncoder(w).Encode(status)
|
||||||
}
|
}
|
||||||
|
|||||||
BIN
chorus-agent
Executable file
BIN
chorus-agent
Executable file
Binary file not shown.
@@ -8,12 +8,19 @@ import (
|
|||||||
"chorus/internal/runtime"
|
"chorus/internal/runtime"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Build-time variables set by ldflags
|
||||||
|
var (
|
||||||
|
version = "0.5.0-dev"
|
||||||
|
commitHash = "unknown"
|
||||||
|
buildDate = "unknown"
|
||||||
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Early CLI handling: print help/version without requiring env/config
|
// Early CLI handling: print help/version without requiring env/config
|
||||||
for _, a := range os.Args[1:] {
|
for _, a := range os.Args[1:] {
|
||||||
switch a {
|
switch a {
|
||||||
case "--help", "-h", "help":
|
case "--help", "-h", "help":
|
||||||
fmt.Printf("%s-agent %s\n\n", runtime.AppName, runtime.AppVersion)
|
fmt.Printf("%s-agent %s (build: %s, %s)\n\n", runtime.AppName, version, commitHash, buildDate)
|
||||||
fmt.Println("Usage:")
|
fmt.Println("Usage:")
|
||||||
fmt.Printf(" %s [--help] [--version]\n\n", filepath.Base(os.Args[0]))
|
fmt.Printf(" %s [--help] [--version]\n\n", filepath.Base(os.Args[0]))
|
||||||
fmt.Println("CHORUS Autonomous Agent - P2P Task Coordination")
|
fmt.Println("CHORUS Autonomous Agent - P2P Task Coordination")
|
||||||
@@ -46,11 +53,16 @@ func main() {
|
|||||||
fmt.Println(" - Health monitoring")
|
fmt.Println(" - Health monitoring")
|
||||||
return
|
return
|
||||||
case "--version", "-v":
|
case "--version", "-v":
|
||||||
fmt.Printf("%s-agent %s\n", runtime.AppName, runtime.AppVersion)
|
fmt.Printf("%s-agent %s (build: %s, %s)\n", runtime.AppName, version, commitHash, buildDate)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set dynamic build information
|
||||||
|
runtime.AppVersion = version
|
||||||
|
runtime.AppCommitHash = commitHash
|
||||||
|
runtime.AppBuildDate = buildDate
|
||||||
|
|
||||||
// Initialize shared P2P runtime
|
// Initialize shared P2P runtime
|
||||||
sharedRuntime, err := runtime.Initialize("agent")
|
sharedRuntime, err := runtime.Initialize("agent")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
372
configs/models.yaml
Normal file
372
configs/models.yaml
Normal file
@@ -0,0 +1,372 @@
|
|||||||
|
# CHORUS AI Provider and Model Configuration
|
||||||
|
# This file defines how different agent roles map to AI models and providers
|
||||||
|
|
||||||
|
# Global provider settings
|
||||||
|
providers:
|
||||||
|
# Local Ollama instance (default for most roles)
|
||||||
|
ollama:
|
||||||
|
type: ollama
|
||||||
|
endpoint: http://localhost:11434
|
||||||
|
default_model: llama3.1:8b
|
||||||
|
temperature: 0.7
|
||||||
|
max_tokens: 4096
|
||||||
|
timeout: 300s
|
||||||
|
retry_attempts: 3
|
||||||
|
retry_delay: 2s
|
||||||
|
enable_tools: true
|
||||||
|
enable_mcp: true
|
||||||
|
mcp_servers: []
|
||||||
|
|
||||||
|
# Ollama cluster nodes (for load balancing)
|
||||||
|
ollama_cluster:
|
||||||
|
type: ollama
|
||||||
|
endpoint: http://192.168.1.72:11434 # Primary node
|
||||||
|
default_model: llama3.1:8b
|
||||||
|
temperature: 0.7
|
||||||
|
max_tokens: 4096
|
||||||
|
timeout: 300s
|
||||||
|
retry_attempts: 3
|
||||||
|
retry_delay: 2s
|
||||||
|
enable_tools: true
|
||||||
|
enable_mcp: true
|
||||||
|
|
||||||
|
# OpenAI API (for advanced models)
|
||||||
|
openai:
|
||||||
|
type: openai
|
||||||
|
endpoint: https://api.openai.com/v1
|
||||||
|
api_key: ${OPENAI_API_KEY}
|
||||||
|
default_model: gpt-4o
|
||||||
|
temperature: 0.7
|
||||||
|
max_tokens: 4096
|
||||||
|
timeout: 120s
|
||||||
|
retry_attempts: 3
|
||||||
|
retry_delay: 5s
|
||||||
|
enable_tools: true
|
||||||
|
enable_mcp: true
|
||||||
|
|
||||||
|
# ResetData LaaS (fallback/testing)
|
||||||
|
resetdata:
|
||||||
|
type: resetdata
|
||||||
|
endpoint: ${RESETDATA_ENDPOINT}
|
||||||
|
api_key: ${RESETDATA_API_KEY}
|
||||||
|
default_model: llama3.1:8b
|
||||||
|
temperature: 0.7
|
||||||
|
max_tokens: 4096
|
||||||
|
timeout: 300s
|
||||||
|
retry_attempts: 3
|
||||||
|
retry_delay: 2s
|
||||||
|
enable_tools: false
|
||||||
|
enable_mcp: false
|
||||||
|
|
||||||
|
# Global fallback settings
|
||||||
|
default_provider: ollama
|
||||||
|
fallback_provider: resetdata
|
||||||
|
|
||||||
|
# Role-based model mappings
|
||||||
|
roles:
|
||||||
|
# Software Developer Agent
|
||||||
|
developer:
|
||||||
|
provider: ollama
|
||||||
|
model: codellama:13b
|
||||||
|
temperature: 0.3 # Lower temperature for more consistent code
|
||||||
|
max_tokens: 8192 # Larger context for code generation
|
||||||
|
system_prompt: |
|
||||||
|
You are an expert software developer agent in the CHORUS autonomous development system.
|
||||||
|
|
||||||
|
Your expertise includes:
|
||||||
|
- Writing clean, maintainable, and well-documented code
|
||||||
|
- Following language-specific best practices and conventions
|
||||||
|
- Implementing proper error handling and validation
|
||||||
|
- Creating comprehensive tests for your code
|
||||||
|
- Considering performance, security, and scalability
|
||||||
|
|
||||||
|
Always provide specific, actionable implementation steps with code examples.
|
||||||
|
Focus on delivering production-ready solutions that follow industry best practices.
|
||||||
|
fallback_provider: resetdata
|
||||||
|
fallback_model: codellama:7b
|
||||||
|
enable_tools: true
|
||||||
|
enable_mcp: true
|
||||||
|
allowed_tools:
|
||||||
|
- file_operation
|
||||||
|
- execute_command
|
||||||
|
- git_operations
|
||||||
|
- code_analysis
|
||||||
|
mcp_servers:
|
||||||
|
- file-server
|
||||||
|
- git-server
|
||||||
|
- code-tools
|
||||||
|
|
||||||
|
# Code Reviewer Agent
|
||||||
|
reviewer:
|
||||||
|
provider: ollama
|
||||||
|
model: llama3.1:8b
|
||||||
|
temperature: 0.2 # Very low temperature for consistent analysis
|
||||||
|
max_tokens: 6144
|
||||||
|
system_prompt: |
|
||||||
|
You are a thorough code reviewer agent in the CHORUS autonomous development system.
|
||||||
|
|
||||||
|
Your responsibilities include:
|
||||||
|
- Analyzing code quality, readability, and maintainability
|
||||||
|
- Identifying bugs, security vulnerabilities, and performance issues
|
||||||
|
- Checking test coverage and test quality
|
||||||
|
- Verifying documentation completeness and accuracy
|
||||||
|
- Suggesting improvements and refactoring opportunities
|
||||||
|
- Ensuring compliance with coding standards and best practices
|
||||||
|
|
||||||
|
Always provide constructive feedback with specific examples and suggestions for improvement.
|
||||||
|
Focus on both technical correctness and long-term maintainability.
|
||||||
|
fallback_provider: resetdata
|
||||||
|
fallback_model: llama3.1:8b
|
||||||
|
enable_tools: true
|
||||||
|
enable_mcp: true
|
||||||
|
allowed_tools:
|
||||||
|
- code_analysis
|
||||||
|
- security_scan
|
||||||
|
- test_coverage
|
||||||
|
- documentation_check
|
||||||
|
mcp_servers:
|
||||||
|
- code-analysis-server
|
||||||
|
- security-tools
|
||||||
|
|
||||||
|
# Software Architect Agent
|
||||||
|
architect:
|
||||||
|
provider: openai # Use OpenAI for complex architectural decisions
|
||||||
|
model: gpt-4o
|
||||||
|
temperature: 0.5 # Balanced creativity and consistency
|
||||||
|
max_tokens: 8192 # Large context for architectural discussions
|
||||||
|
system_prompt: |
|
||||||
|
You are a senior software architect agent in the CHORUS autonomous development system.
|
||||||
|
|
||||||
|
Your expertise includes:
|
||||||
|
- Designing scalable and maintainable system architectures
|
||||||
|
- Making informed decisions about technologies and frameworks
|
||||||
|
- Defining clear interfaces and API contracts
|
||||||
|
- Considering scalability, performance, and security requirements
|
||||||
|
- Creating architectural documentation and diagrams
|
||||||
|
- Evaluating trade-offs between different architectural approaches
|
||||||
|
|
||||||
|
Always provide well-reasoned architectural decisions with clear justifications.
|
||||||
|
Consider both immediate requirements and long-term evolution of the system.
|
||||||
|
fallback_provider: ollama
|
||||||
|
fallback_model: llama3.1:13b
|
||||||
|
enable_tools: true
|
||||||
|
enable_mcp: true
|
||||||
|
allowed_tools:
|
||||||
|
- architecture_analysis
|
||||||
|
- diagram_generation
|
||||||
|
- technology_research
|
||||||
|
- api_design
|
||||||
|
mcp_servers:
|
||||||
|
- architecture-tools
|
||||||
|
- diagram-server
|
||||||
|
|
||||||
|
# QA/Testing Agent
|
||||||
|
tester:
|
||||||
|
provider: ollama
|
||||||
|
model: codellama:7b # Smaller model, focused on test generation
|
||||||
|
temperature: 0.3
|
||||||
|
max_tokens: 6144
|
||||||
|
system_prompt: |
|
||||||
|
You are a quality assurance engineer agent in the CHORUS autonomous development system.
|
||||||
|
|
||||||
|
Your responsibilities include:
|
||||||
|
- Creating comprehensive test plans and test cases
|
||||||
|
- Implementing unit, integration, and end-to-end tests
|
||||||
|
- Identifying edge cases and potential failure scenarios
|
||||||
|
- Setting up test automation and continuous integration
|
||||||
|
- Validating functionality against requirements
|
||||||
|
- Performing security and performance testing
|
||||||
|
|
||||||
|
Always focus on thorough test coverage and quality assurance practices.
|
||||||
|
Ensure tests are maintainable, reliable, and provide meaningful feedback.
|
||||||
|
fallback_provider: resetdata
|
||||||
|
fallback_model: llama3.1:8b
|
||||||
|
enable_tools: true
|
||||||
|
enable_mcp: true
|
||||||
|
allowed_tools:
|
||||||
|
- test_generation
|
||||||
|
- test_execution
|
||||||
|
- coverage_analysis
|
||||||
|
- performance_testing
|
||||||
|
mcp_servers:
|
||||||
|
- testing-framework
|
||||||
|
- coverage-tools
|
||||||
|
|
||||||
|
# DevOps/Infrastructure Agent
|
||||||
|
devops:
|
||||||
|
provider: ollama_cluster
|
||||||
|
model: llama3.1:8b
|
||||||
|
temperature: 0.4
|
||||||
|
max_tokens: 6144
|
||||||
|
system_prompt: |
|
||||||
|
You are a DevOps engineer agent in the CHORUS autonomous development system.
|
||||||
|
|
||||||
|
Your expertise includes:
|
||||||
|
- Automating deployment processes and CI/CD pipelines
|
||||||
|
- Managing containerization with Docker and orchestration with Kubernetes
|
||||||
|
- Implementing infrastructure as code (IaC)
|
||||||
|
- Monitoring, logging, and observability setup
|
||||||
|
- Security hardening and compliance management
|
||||||
|
- Performance optimization and scaling strategies
|
||||||
|
|
||||||
|
Always focus on automation, reliability, and security in your solutions.
|
||||||
|
Ensure infrastructure is scalable, maintainable, and follows best practices.
|
||||||
|
fallback_provider: resetdata
|
||||||
|
fallback_model: llama3.1:8b
|
||||||
|
enable_tools: true
|
||||||
|
enable_mcp: true
|
||||||
|
allowed_tools:
|
||||||
|
- docker_operations
|
||||||
|
- kubernetes_management
|
||||||
|
- ci_cd_tools
|
||||||
|
- monitoring_setup
|
||||||
|
- security_hardening
|
||||||
|
mcp_servers:
|
||||||
|
- docker-server
|
||||||
|
- k8s-tools
|
||||||
|
- monitoring-server
|
||||||
|
|
||||||
|
# Security Specialist Agent
|
||||||
|
security:
|
||||||
|
provider: openai
|
||||||
|
model: gpt-4o # Use advanced model for security analysis
|
||||||
|
temperature: 0.1 # Very conservative for security
|
||||||
|
max_tokens: 8192
|
||||||
|
system_prompt: |
|
||||||
|
You are a security specialist agent in the CHORUS autonomous development system.
|
||||||
|
|
||||||
|
Your expertise includes:
|
||||||
|
- Conducting security audits and vulnerability assessments
|
||||||
|
- Implementing security best practices and controls
|
||||||
|
- Analyzing code for security vulnerabilities
|
||||||
|
- Setting up security monitoring and incident response
|
||||||
|
- Ensuring compliance with security standards
|
||||||
|
- Designing secure architectures and data flows
|
||||||
|
|
||||||
|
Always prioritize security over convenience and thoroughly analyze potential threats.
|
||||||
|
Provide specific, actionable security recommendations with risk assessments.
|
||||||
|
fallback_provider: ollama
|
||||||
|
fallback_model: llama3.1:8b
|
||||||
|
enable_tools: true
|
||||||
|
enable_mcp: true
|
||||||
|
allowed_tools:
|
||||||
|
- security_scan
|
||||||
|
- vulnerability_assessment
|
||||||
|
- compliance_check
|
||||||
|
- threat_modeling
|
||||||
|
mcp_servers:
|
||||||
|
- security-tools
|
||||||
|
- compliance-server
|
||||||
|
|
||||||
|
# Documentation Agent
|
||||||
|
documentation:
|
||||||
|
provider: ollama
|
||||||
|
model: llama3.1:8b
|
||||||
|
temperature: 0.6 # Slightly higher for creative writing
|
||||||
|
max_tokens: 8192
|
||||||
|
system_prompt: |
|
||||||
|
You are a technical documentation specialist agent in the CHORUS autonomous development system.
|
||||||
|
|
||||||
|
Your expertise includes:
|
||||||
|
- Creating clear, comprehensive technical documentation
|
||||||
|
- Writing user guides, API documentation, and tutorials
|
||||||
|
- Maintaining README files and project wikis
|
||||||
|
- Creating architectural decision records (ADRs)
|
||||||
|
- Developing onboarding materials and runbooks
|
||||||
|
- Ensuring documentation accuracy and completeness
|
||||||
|
|
||||||
|
Always write documentation that is clear, actionable, and accessible to your target audience.
|
||||||
|
Focus on providing practical information that helps users accomplish their goals.
|
||||||
|
fallback_provider: resetdata
|
||||||
|
fallback_model: llama3.1:8b
|
||||||
|
enable_tools: true
|
||||||
|
enable_mcp: true
|
||||||
|
allowed_tools:
|
||||||
|
- documentation_generation
|
||||||
|
- markdown_processing
|
||||||
|
- diagram_creation
|
||||||
|
- content_validation
|
||||||
|
mcp_servers:
|
||||||
|
- docs-server
|
||||||
|
- markdown-tools
|
||||||
|
|
||||||
|
# General Purpose Agent (fallback)
|
||||||
|
general:
|
||||||
|
provider: ollama
|
||||||
|
model: llama3.1:8b
|
||||||
|
temperature: 0.7
|
||||||
|
max_tokens: 4096
|
||||||
|
system_prompt: |
|
||||||
|
You are a general-purpose AI agent in the CHORUS autonomous development system.
|
||||||
|
|
||||||
|
Your capabilities include:
|
||||||
|
- Analyzing and understanding various types of development tasks
|
||||||
|
- Providing guidance on software development best practices
|
||||||
|
- Assisting with problem-solving and decision-making
|
||||||
|
- Coordinating with other specialized agents when needed
|
||||||
|
|
||||||
|
Always provide helpful, accurate information and know when to defer to specialized agents.
|
||||||
|
Focus on understanding the task requirements and providing appropriate guidance.
|
||||||
|
fallback_provider: resetdata
|
||||||
|
fallback_model: llama3.1:8b
|
||||||
|
enable_tools: true
|
||||||
|
enable_mcp: true
|
||||||
|
|
||||||
|
# Environment-specific overrides
|
||||||
|
environments:
|
||||||
|
development:
|
||||||
|
# Use local models for development to reduce costs
|
||||||
|
default_provider: ollama
|
||||||
|
fallback_provider: resetdata
|
||||||
|
|
||||||
|
staging:
|
||||||
|
# Mix of local and cloud models for realistic testing
|
||||||
|
default_provider: ollama_cluster
|
||||||
|
fallback_provider: openai
|
||||||
|
|
||||||
|
production:
|
||||||
|
# Prefer reliable cloud providers with fallback to local
|
||||||
|
default_provider: openai
|
||||||
|
fallback_provider: ollama_cluster
|
||||||
|
|
||||||
|
# Model performance preferences (for auto-selection)
|
||||||
|
model_preferences:
|
||||||
|
# Code generation tasks
|
||||||
|
code_generation:
|
||||||
|
preferred_models:
|
||||||
|
- codellama:13b
|
||||||
|
- gpt-4o
|
||||||
|
- codellama:34b
|
||||||
|
min_context_tokens: 8192
|
||||||
|
|
||||||
|
# Code review tasks
|
||||||
|
code_review:
|
||||||
|
preferred_models:
|
||||||
|
- llama3.1:8b
|
||||||
|
- gpt-4o
|
||||||
|
- llama3.1:13b
|
||||||
|
min_context_tokens: 6144
|
||||||
|
|
||||||
|
# Architecture and design
|
||||||
|
architecture:
|
||||||
|
preferred_models:
|
||||||
|
- gpt-4o
|
||||||
|
- llama3.1:13b
|
||||||
|
- llama3.1:70b
|
||||||
|
min_context_tokens: 8192
|
||||||
|
|
||||||
|
# Testing and QA
|
||||||
|
testing:
|
||||||
|
preferred_models:
|
||||||
|
- codellama:7b
|
||||||
|
- llama3.1:8b
|
||||||
|
- codellama:13b
|
||||||
|
min_context_tokens: 6144
|
||||||
|
|
||||||
|
# Documentation
|
||||||
|
documentation:
|
||||||
|
preferred_models:
|
||||||
|
- llama3.1:8b
|
||||||
|
- gpt-4o
|
||||||
|
- mistral:7b
|
||||||
|
min_context_tokens: 8192
|
||||||
@@ -8,7 +8,9 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"chorus/internal/logging"
|
"chorus/internal/logging"
|
||||||
|
"chorus/pkg/ai"
|
||||||
"chorus/pkg/config"
|
"chorus/pkg/config"
|
||||||
|
"chorus/pkg/execution"
|
||||||
"chorus/pkg/hmmm"
|
"chorus/pkg/hmmm"
|
||||||
"chorus/pkg/repository"
|
"chorus/pkg/repository"
|
||||||
"chorus/pubsub"
|
"chorus/pubsub"
|
||||||
@@ -41,6 +43,9 @@ type TaskCoordinator struct {
|
|||||||
taskMatcher repository.TaskMatcher
|
taskMatcher repository.TaskMatcher
|
||||||
taskTracker TaskProgressTracker
|
taskTracker TaskProgressTracker
|
||||||
|
|
||||||
|
// Task execution
|
||||||
|
executionEngine execution.TaskExecutionEngine
|
||||||
|
|
||||||
// Agent tracking
|
// Agent tracking
|
||||||
nodeID string
|
nodeID string
|
||||||
agentInfo *repository.AgentInfo
|
agentInfo *repository.AgentInfo
|
||||||
@@ -109,6 +114,13 @@ func NewTaskCoordinator(
|
|||||||
func (tc *TaskCoordinator) Start() {
|
func (tc *TaskCoordinator) Start() {
|
||||||
fmt.Printf("🎯 Starting task coordinator for agent %s (%s)\n", tc.agentInfo.ID, tc.agentInfo.Role)
|
fmt.Printf("🎯 Starting task coordinator for agent %s (%s)\n", tc.agentInfo.ID, tc.agentInfo.Role)
|
||||||
|
|
||||||
|
// Initialize task execution engine
|
||||||
|
err := tc.initializeExecutionEngine()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("⚠️ Failed to initialize task execution engine: %v\n", err)
|
||||||
|
fmt.Println("Task execution will fall back to mock implementation")
|
||||||
|
}
|
||||||
|
|
||||||
// Announce role and capabilities
|
// Announce role and capabilities
|
||||||
tc.announceAgentRole()
|
tc.announceAgentRole()
|
||||||
|
|
||||||
@@ -299,6 +311,65 @@ func (tc *TaskCoordinator) requestTaskCollaboration(task *repository.Task) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// initializeExecutionEngine sets up the AI-powered task execution engine
|
||||||
|
func (tc *TaskCoordinator) initializeExecutionEngine() error {
|
||||||
|
// Create AI provider factory
|
||||||
|
aiFactory := ai.NewProviderFactory()
|
||||||
|
|
||||||
|
// Load AI configuration from config file
|
||||||
|
configPath := "configs/models.yaml"
|
||||||
|
configLoader := ai.NewConfigLoader(configPath, "production")
|
||||||
|
_, err := configLoader.LoadConfig()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load AI config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize the factory with the loaded configuration
|
||||||
|
// For now, we'll use a simplified initialization
|
||||||
|
// In a complete implementation, the factory would have an Initialize method
|
||||||
|
|
||||||
|
// Create task execution engine
|
||||||
|
tc.executionEngine = execution.NewTaskExecutionEngine()
|
||||||
|
|
||||||
|
// Configure execution engine
|
||||||
|
engineConfig := &execution.EngineConfig{
|
||||||
|
AIProviderFactory: aiFactory,
|
||||||
|
DefaultTimeout: 5 * time.Minute,
|
||||||
|
MaxConcurrentTasks: tc.agentInfo.MaxTasks,
|
||||||
|
EnableMetrics: true,
|
||||||
|
LogLevel: "info",
|
||||||
|
SandboxDefaults: &execution.SandboxConfig{
|
||||||
|
Type: "docker",
|
||||||
|
Image: "alpine:latest",
|
||||||
|
Architecture: "amd64",
|
||||||
|
Resources: execution.ResourceLimits{
|
||||||
|
MemoryLimit: 512 * 1024 * 1024, // 512MB
|
||||||
|
CPULimit: 1.0,
|
||||||
|
ProcessLimit: 50,
|
||||||
|
FileLimit: 1024,
|
||||||
|
},
|
||||||
|
Security: execution.SecurityPolicy{
|
||||||
|
ReadOnlyRoot: false,
|
||||||
|
NoNewPrivileges: true,
|
||||||
|
AllowNetworking: true,
|
||||||
|
IsolateNetwork: false,
|
||||||
|
IsolateProcess: true,
|
||||||
|
DropCapabilities: []string{"NET_ADMIN", "SYS_ADMIN"},
|
||||||
|
},
|
||||||
|
WorkingDir: "/workspace",
|
||||||
|
Timeout: 5 * time.Minute,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tc.executionEngine.Initialize(tc.ctx, engineConfig)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to initialize execution engine: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("✅ Task execution engine initialized successfully\n")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// executeTask executes a claimed task
|
// executeTask executes a claimed task
|
||||||
func (tc *TaskCoordinator) executeTask(activeTask *ActiveTask) {
|
func (tc *TaskCoordinator) executeTask(activeTask *ActiveTask) {
|
||||||
taskKey := fmt.Sprintf("%s:%d", activeTask.Task.Repository, activeTask.Task.Number)
|
taskKey := fmt.Sprintf("%s:%d", activeTask.Task.Repository, activeTask.Task.Number)
|
||||||
@@ -311,21 +382,27 @@ func (tc *TaskCoordinator) executeTask(activeTask *ActiveTask) {
|
|||||||
// Announce work start
|
// Announce work start
|
||||||
tc.announceTaskProgress(activeTask.Task, "started")
|
tc.announceTaskProgress(activeTask.Task, "started")
|
||||||
|
|
||||||
// Simulate task execution (in real implementation, this would call actual execution logic)
|
// Execute task using AI-powered execution engine
|
||||||
time.Sleep(10 * time.Second) // Simulate work
|
var taskResult *repository.TaskResult
|
||||||
|
|
||||||
// Complete the task
|
if tc.executionEngine != nil {
|
||||||
results := map[string]interface{}{
|
// Use real AI-powered execution
|
||||||
"status": "completed",
|
executionResult, err := tc.executeTaskWithAI(activeTask)
|
||||||
"completion_time": time.Now().Format(time.RFC3339),
|
if err != nil {
|
||||||
"agent_id": tc.agentInfo.ID,
|
fmt.Printf("⚠️ AI execution failed for task %s #%d: %v\n",
|
||||||
"agent_role": tc.agentInfo.Role,
|
activeTask.Task.Repository, activeTask.Task.Number, err)
|
||||||
}
|
|
||||||
|
|
||||||
taskResult := &repository.TaskResult{
|
// Fall back to mock execution
|
||||||
Success: true,
|
taskResult = tc.executeMockTask(activeTask)
|
||||||
Message: "Task completed successfully",
|
} else {
|
||||||
Metadata: results,
|
// Convert execution result to task result
|
||||||
|
taskResult = tc.convertExecutionResult(activeTask, executionResult)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Fall back to mock execution
|
||||||
|
fmt.Printf("📝 Using mock execution for task %s #%d (engine not available)\n",
|
||||||
|
activeTask.Task.Repository, activeTask.Task.Number)
|
||||||
|
taskResult = tc.executeMockTask(activeTask)
|
||||||
}
|
}
|
||||||
err := activeTask.Provider.CompleteTask(activeTask.Task, taskResult)
|
err := activeTask.Provider.CompleteTask(activeTask.Task, taskResult)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -343,7 +420,7 @@ func (tc *TaskCoordinator) executeTask(activeTask *ActiveTask) {
|
|||||||
// Update status and remove from active tasks
|
// Update status and remove from active tasks
|
||||||
tc.taskLock.Lock()
|
tc.taskLock.Lock()
|
||||||
activeTask.Status = "completed"
|
activeTask.Status = "completed"
|
||||||
activeTask.Results = results
|
activeTask.Results = taskResult.Metadata
|
||||||
delete(tc.activeTasks, taskKey)
|
delete(tc.activeTasks, taskKey)
|
||||||
tc.agentInfo.CurrentTasks = len(tc.activeTasks)
|
tc.agentInfo.CurrentTasks = len(tc.activeTasks)
|
||||||
tc.taskLock.Unlock()
|
tc.taskLock.Unlock()
|
||||||
@@ -357,7 +434,7 @@ func (tc *TaskCoordinator) executeTask(activeTask *ActiveTask) {
|
|||||||
"task_number": activeTask.Task.Number,
|
"task_number": activeTask.Task.Number,
|
||||||
"repository": activeTask.Task.Repository,
|
"repository": activeTask.Task.Repository,
|
||||||
"duration": time.Since(activeTask.ClaimedAt).Seconds(),
|
"duration": time.Since(activeTask.ClaimedAt).Seconds(),
|
||||||
"results": results,
|
"results": taskResult.Metadata,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Announce completion
|
// Announce completion
|
||||||
@@ -366,6 +443,200 @@ func (tc *TaskCoordinator) executeTask(activeTask *ActiveTask) {
|
|||||||
fmt.Printf("✅ Completed task %s #%d\n", activeTask.Task.Repository, activeTask.Task.Number)
|
fmt.Printf("✅ Completed task %s #%d\n", activeTask.Task.Repository, activeTask.Task.Number)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// executeTaskWithAI executes a task using the AI-powered execution engine
|
||||||
|
func (tc *TaskCoordinator) executeTaskWithAI(activeTask *ActiveTask) (*execution.TaskExecutionResult, error) {
|
||||||
|
// Convert repository task to execution request
|
||||||
|
executionRequest := &execution.TaskExecutionRequest{
|
||||||
|
ID: fmt.Sprintf("%s:%d", activeTask.Task.Repository, activeTask.Task.Number),
|
||||||
|
Type: tc.determineTaskType(activeTask.Task),
|
||||||
|
Description: tc.buildTaskDescription(activeTask.Task),
|
||||||
|
Context: tc.buildTaskContext(activeTask.Task),
|
||||||
|
Requirements: &execution.TaskRequirements{
|
||||||
|
AIModel: "", // Let the engine choose based on role
|
||||||
|
SandboxType: "docker",
|
||||||
|
RequiredTools: []string{"git", "curl"},
|
||||||
|
EnvironmentVars: map[string]string{
|
||||||
|
"TASK_ID": fmt.Sprintf("%d", activeTask.Task.Number),
|
||||||
|
"REPOSITORY": activeTask.Task.Repository,
|
||||||
|
"AGENT_ID": tc.agentInfo.ID,
|
||||||
|
"AGENT_ROLE": tc.agentInfo.Role,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Timeout: 10 * time.Minute, // Allow longer timeout for complex tasks
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the task
|
||||||
|
return tc.executionEngine.ExecuteTask(tc.ctx, executionRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeMockTask provides fallback mock execution
|
||||||
|
func (tc *TaskCoordinator) executeMockTask(activeTask *ActiveTask) *repository.TaskResult {
|
||||||
|
// Simulate work time based on task complexity
|
||||||
|
workTime := 5 * time.Second
|
||||||
|
if strings.Contains(strings.ToLower(activeTask.Task.Title), "complex") {
|
||||||
|
workTime = 15 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("🕐 Mock execution for task %s #%d (simulating %v)\n",
|
||||||
|
activeTask.Task.Repository, activeTask.Task.Number, workTime)
|
||||||
|
|
||||||
|
time.Sleep(workTime)
|
||||||
|
|
||||||
|
results := map[string]interface{}{
|
||||||
|
"status": "completed",
|
||||||
|
"execution_type": "mock",
|
||||||
|
"completion_time": time.Now().Format(time.RFC3339),
|
||||||
|
"agent_id": tc.agentInfo.ID,
|
||||||
|
"agent_role": tc.agentInfo.Role,
|
||||||
|
"simulated_work": workTime.String(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return &repository.TaskResult{
|
||||||
|
Success: true,
|
||||||
|
Message: "Task completed successfully (mock execution)",
|
||||||
|
Metadata: results,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertExecutionResult converts an execution result to a task result
|
||||||
|
func (tc *TaskCoordinator) convertExecutionResult(activeTask *ActiveTask, result *execution.TaskExecutionResult) *repository.TaskResult {
|
||||||
|
// Build result metadata
|
||||||
|
metadata := map[string]interface{}{
|
||||||
|
"status": "completed",
|
||||||
|
"execution_type": "ai_powered",
|
||||||
|
"completion_time": time.Now().Format(time.RFC3339),
|
||||||
|
"agent_id": tc.agentInfo.ID,
|
||||||
|
"agent_role": tc.agentInfo.Role,
|
||||||
|
"task_id": result.TaskID,
|
||||||
|
"duration": result.Metrics.Duration.String(),
|
||||||
|
"ai_provider_time": result.Metrics.AIProviderTime.String(),
|
||||||
|
"sandbox_time": result.Metrics.SandboxTime.String(),
|
||||||
|
"commands_executed": result.Metrics.CommandsExecuted,
|
||||||
|
"files_generated": result.Metrics.FilesGenerated,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add execution metadata if available
|
||||||
|
if result.Metadata != nil {
|
||||||
|
metadata["ai_metadata"] = result.Metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add resource usage if available
|
||||||
|
if result.Metrics.ResourceUsage != nil {
|
||||||
|
metadata["resource_usage"] = map[string]interface{}{
|
||||||
|
"cpu_usage": result.Metrics.ResourceUsage.CPUUsage,
|
||||||
|
"memory_usage": result.Metrics.ResourceUsage.MemoryUsage,
|
||||||
|
"memory_percent": result.Metrics.ResourceUsage.MemoryPercent,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle artifacts
|
||||||
|
if len(result.Artifacts) > 0 {
|
||||||
|
artifactsList := make([]map[string]interface{}, len(result.Artifacts))
|
||||||
|
for i, artifact := range result.Artifacts {
|
||||||
|
artifactsList[i] = map[string]interface{}{
|
||||||
|
"name": artifact.Name,
|
||||||
|
"type": artifact.Type,
|
||||||
|
"size": artifact.Size,
|
||||||
|
"created_at": artifact.CreatedAt.Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
metadata["artifacts"] = artifactsList
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine success based on execution result
|
||||||
|
success := result.Success
|
||||||
|
message := "Task completed successfully with AI execution"
|
||||||
|
|
||||||
|
if !success {
|
||||||
|
message = fmt.Sprintf("Task failed: %s", result.ErrorMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &repository.TaskResult{
|
||||||
|
Success: success,
|
||||||
|
Message: message,
|
||||||
|
Metadata: metadata,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// determineTaskType analyzes a task to determine its execution type
|
||||||
|
func (tc *TaskCoordinator) determineTaskType(task *repository.Task) string {
|
||||||
|
title := strings.ToLower(task.Title)
|
||||||
|
description := strings.ToLower(task.Body)
|
||||||
|
|
||||||
|
// Check for common task type keywords
|
||||||
|
if strings.Contains(title, "bug") || strings.Contains(title, "fix") {
|
||||||
|
return "bug_fix"
|
||||||
|
}
|
||||||
|
if strings.Contains(title, "feature") || strings.Contains(title, "implement") {
|
||||||
|
return "feature_development"
|
||||||
|
}
|
||||||
|
if strings.Contains(title, "test") || strings.Contains(description, "test") {
|
||||||
|
return "testing"
|
||||||
|
}
|
||||||
|
if strings.Contains(title, "doc") || strings.Contains(description, "documentation") {
|
||||||
|
return "documentation"
|
||||||
|
}
|
||||||
|
if strings.Contains(title, "refactor") || strings.Contains(description, "refactor") {
|
||||||
|
return "refactoring"
|
||||||
|
}
|
||||||
|
if strings.Contains(title, "review") || strings.Contains(description, "review") {
|
||||||
|
return "code_review"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to general development task
|
||||||
|
return "development"
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildTaskDescription creates a comprehensive description for AI execution
|
||||||
|
func (tc *TaskCoordinator) buildTaskDescription(task *repository.Task) string {
|
||||||
|
var description strings.Builder
|
||||||
|
|
||||||
|
description.WriteString(fmt.Sprintf("Task: %s\n\n", task.Title))
|
||||||
|
|
||||||
|
if task.Body != "" {
|
||||||
|
description.WriteString(fmt.Sprintf("Description:\n%s\n\n", task.Body))
|
||||||
|
}
|
||||||
|
|
||||||
|
description.WriteString(fmt.Sprintf("Repository: %s\n", task.Repository))
|
||||||
|
description.WriteString(fmt.Sprintf("Task Number: %d\n", task.Number))
|
||||||
|
|
||||||
|
if len(task.RequiredExpertise) > 0 {
|
||||||
|
description.WriteString(fmt.Sprintf("Required Expertise: %v\n", task.RequiredExpertise))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(task.Labels) > 0 {
|
||||||
|
description.WriteString(fmt.Sprintf("Labels: %v\n", task.Labels))
|
||||||
|
}
|
||||||
|
|
||||||
|
description.WriteString("\nPlease analyze this task and provide appropriate commands or code to complete it.")
|
||||||
|
|
||||||
|
return description.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildTaskContext creates context information for AI execution
|
||||||
|
func (tc *TaskCoordinator) buildTaskContext(task *repository.Task) map[string]interface{} {
|
||||||
|
context := map[string]interface{}{
|
||||||
|
"repository": task.Repository,
|
||||||
|
"task_number": task.Number,
|
||||||
|
"task_title": task.Title,
|
||||||
|
"required_role": task.RequiredRole,
|
||||||
|
"required_expertise": task.RequiredExpertise,
|
||||||
|
"labels": task.Labels,
|
||||||
|
"agent_info": map[string]interface{}{
|
||||||
|
"id": tc.agentInfo.ID,
|
||||||
|
"role": tc.agentInfo.Role,
|
||||||
|
"expertise": tc.agentInfo.Expertise,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add any additional metadata from the task
|
||||||
|
if task.Metadata != nil {
|
||||||
|
context["task_metadata"] = task.Metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
return context
|
||||||
|
}
|
||||||
|
|
||||||
// announceAgentRole announces this agent's role and capabilities
|
// announceAgentRole announces this agent's role and capabilities
|
||||||
func (tc *TaskCoordinator) announceAgentRole() {
|
func (tc *TaskCoordinator) announceAgentRole() {
|
||||||
data := map[string]interface{}{
|
data := map[string]interface{}{
|
||||||
|
|||||||
@@ -11,15 +11,15 @@ WORKDIR /build
|
|||||||
# Copy go mod files first (for better caching)
|
# Copy go mod files first (for better caching)
|
||||||
COPY go.mod go.sum ./
|
COPY go.mod go.sum ./
|
||||||
|
|
||||||
# Copy vendor directory for local dependencies
|
# Download dependencies
|
||||||
COPY vendor/ vendor/
|
RUN go mod download
|
||||||
|
|
||||||
# Copy source code
|
# Copy source code
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
# Build the CHORUS binary with vendor mode
|
# Build the CHORUS binary with mod mode
|
||||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||||
-mod=vendor \
|
-mod=mod \
|
||||||
-ldflags='-w -s -extldflags "-static"' \
|
-ldflags='-w -s -extldflags "-static"' \
|
||||||
-o chorus \
|
-o chorus \
|
||||||
./cmd/chorus
|
./cmd/chorus
|
||||||
|
|||||||
38
docker/bootstrap.json
Normal file
38
docker/bootstrap.json
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
{
|
||||||
|
"metadata": {
|
||||||
|
"generated_at": "2024-12-19T10:00:00Z",
|
||||||
|
"cluster_id": "production-cluster",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"notes": "Bootstrap configuration for CHORUS scaling - managed by WHOOSH"
|
||||||
|
},
|
||||||
|
"peers": [
|
||||||
|
{
|
||||||
|
"address": "/ip4/10.0.1.10/tcp/9000/p2p/12D3KooWExample1234567890abcdef",
|
||||||
|
"priority": 100,
|
||||||
|
"region": "us-east-1",
|
||||||
|
"roles": ["admin", "stable"],
|
||||||
|
"enabled": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"address": "/ip4/10.0.1.11/tcp/9000/p2p/12D3KooWExample1234567890abcde2",
|
||||||
|
"priority": 90,
|
||||||
|
"region": "us-east-1",
|
||||||
|
"roles": ["worker", "stable"],
|
||||||
|
"enabled": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"address": "/ip4/10.0.2.10/tcp/9000/p2p/12D3KooWExample1234567890abcde3",
|
||||||
|
"priority": 80,
|
||||||
|
"region": "us-west-2",
|
||||||
|
"roles": ["worker", "stable"],
|
||||||
|
"enabled": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"address": "/ip4/10.0.3.10/tcp/9000/p2p/12D3KooWExample1234567890abcde4",
|
||||||
|
"priority": 70,
|
||||||
|
"region": "eu-central-1",
|
||||||
|
"roles": ["worker"],
|
||||||
|
"enabled": false
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -2,7 +2,7 @@ version: "3.9"
|
|||||||
|
|
||||||
services:
|
services:
|
||||||
chorus:
|
chorus:
|
||||||
image: anthonyrawlins/chorus:resetdata-secrets-v1.0.5
|
image: anthonyrawlins/chorus:latest
|
||||||
|
|
||||||
# REQUIRED: License configuration (CHORUS will not start without this)
|
# REQUIRED: License configuration (CHORUS will not start without this)
|
||||||
environment:
|
environment:
|
||||||
@@ -15,13 +15,32 @@ services:
|
|||||||
- CHORUS_AGENT_ID=${CHORUS_AGENT_ID:-} # Auto-generated if not provided
|
- CHORUS_AGENT_ID=${CHORUS_AGENT_ID:-} # Auto-generated if not provided
|
||||||
- CHORUS_SPECIALIZATION=${CHORUS_SPECIALIZATION:-general_developer}
|
- CHORUS_SPECIALIZATION=${CHORUS_SPECIALIZATION:-general_developer}
|
||||||
- CHORUS_MAX_TASKS=${CHORUS_MAX_TASKS:-3}
|
- CHORUS_MAX_TASKS=${CHORUS_MAX_TASKS:-3}
|
||||||
- CHORUS_CAPABILITIES=${CHORUS_CAPABILITIES:-general_development,task_coordination}
|
- CHORUS_CAPABILITIES=general_development,task_coordination,admin_election
|
||||||
|
|
||||||
# Network configuration
|
# Network configuration
|
||||||
- CHORUS_API_PORT=8080
|
- CHORUS_API_PORT=8080
|
||||||
- CHORUS_HEALTH_PORT=8081
|
- CHORUS_HEALTH_PORT=8081
|
||||||
- CHORUS_P2P_PORT=9000
|
- CHORUS_P2P_PORT=9000
|
||||||
- CHORUS_BIND_ADDRESS=0.0.0.0
|
- CHORUS_BIND_ADDRESS=0.0.0.0
|
||||||
|
|
||||||
|
# Scaling optimizations (as per WHOOSH issue #7)
|
||||||
|
- CHORUS_MDNS_ENABLED=false # Disabled for container/swarm environments
|
||||||
|
- CHORUS_DIALS_PER_SEC=5 # Rate limit outbound connections to prevent storms
|
||||||
|
- CHORUS_MAX_CONCURRENT_DHT=16 # Limit concurrent DHT queries
|
||||||
|
|
||||||
|
# Election stability windows (Medium-risk fix 2.1)
|
||||||
|
- CHORUS_ELECTION_MIN_TERM=30s # Minimum time between elections to prevent churn
|
||||||
|
- CHORUS_LEADER_MIN_TERM=45s # Minimum time before challenging healthy leader
|
||||||
|
|
||||||
|
# Assignment system for runtime configuration (Medium-risk fix 2.2)
|
||||||
|
- ASSIGN_URL=${ASSIGN_URL:-} # Optional: WHOOSH assignment endpoint
|
||||||
|
- TASK_SLOT=${TASK_SLOT:-} # Optional: Task slot identifier
|
||||||
|
- TASK_ID=${TASK_ID:-} # Optional: Task identifier
|
||||||
|
- NODE_ID=${NODE_ID:-} # Optional: Node identifier
|
||||||
|
|
||||||
|
# Bootstrap pool configuration (supports JSON and CSV)
|
||||||
|
- BOOTSTRAP_JSON=/config/bootstrap.json # Optional: JSON bootstrap config
|
||||||
|
- CHORUS_BOOTSTRAP_PEERS=${CHORUS_BOOTSTRAP_PEERS:-} # CSV fallback
|
||||||
|
|
||||||
# AI configuration - Provider selection
|
# AI configuration - Provider selection
|
||||||
- CHORUS_AI_PROVIDER=${CHORUS_AI_PROVIDER:-resetdata}
|
- CHORUS_AI_PROVIDER=${CHORUS_AI_PROVIDER:-resetdata}
|
||||||
@@ -57,6 +76,11 @@ services:
|
|||||||
secrets:
|
secrets:
|
||||||
- chorus_license_id
|
- chorus_license_id
|
||||||
- resetdata_api_key
|
- resetdata_api_key
|
||||||
|
|
||||||
|
# Configuration files
|
||||||
|
configs:
|
||||||
|
- source: chorus_bootstrap
|
||||||
|
target: /config/bootstrap.json
|
||||||
|
|
||||||
# Persistent data storage
|
# Persistent data storage
|
||||||
volumes:
|
volumes:
|
||||||
@@ -71,7 +95,7 @@ services:
|
|||||||
# Container resource limits
|
# Container resource limits
|
||||||
deploy:
|
deploy:
|
||||||
mode: replicated
|
mode: replicated
|
||||||
replicas: ${CHORUS_REPLICAS:-1}
|
replicas: ${CHORUS_REPLICAS:-9}
|
||||||
update_config:
|
update_config:
|
||||||
parallelism: 1
|
parallelism: 1
|
||||||
delay: 10s
|
delay: 10s
|
||||||
@@ -91,7 +115,6 @@ services:
|
|||||||
memory: 128M
|
memory: 128M
|
||||||
placement:
|
placement:
|
||||||
constraints:
|
constraints:
|
||||||
- node.hostname != rosewood
|
|
||||||
- node.hostname != acacia
|
- node.hostname != acacia
|
||||||
preferences:
|
preferences:
|
||||||
- spread: node.hostname
|
- spread: node.hostname
|
||||||
@@ -169,7 +192,14 @@ services:
|
|||||||
# Scaling system configuration
|
# Scaling system configuration
|
||||||
WHOOSH_SCALING_KACHING_URL: "https://kaching.chorus.services"
|
WHOOSH_SCALING_KACHING_URL: "https://kaching.chorus.services"
|
||||||
WHOOSH_SCALING_BACKBEAT_URL: "http://backbeat-pulse:8080"
|
WHOOSH_SCALING_BACKBEAT_URL: "http://backbeat-pulse:8080"
|
||||||
WHOOSH_SCALING_CHORUS_URL: "http://chorus:8080"
|
WHOOSH_SCALING_CHORUS_URL: "http://chorus:9000"
|
||||||
|
|
||||||
|
# BACKBEAT integration configuration (temporarily disabled)
|
||||||
|
WHOOSH_BACKBEAT_ENABLED: "false"
|
||||||
|
WHOOSH_BACKBEAT_CLUSTER_ID: "chorus-production"
|
||||||
|
WHOOSH_BACKBEAT_AGENT_ID: "whoosh"
|
||||||
|
WHOOSH_BACKBEAT_NATS_URL: "nats://backbeat-nats:4222"
|
||||||
|
|
||||||
secrets:
|
secrets:
|
||||||
- whoosh_db_password
|
- whoosh_db_password
|
||||||
- gitea_token
|
- gitea_token
|
||||||
@@ -212,14 +242,16 @@ services:
|
|||||||
cpus: '0.25'
|
cpus: '0.25'
|
||||||
labels:
|
labels:
|
||||||
- traefik.enable=true
|
- traefik.enable=true
|
||||||
|
- traefik.docker.network=tengig
|
||||||
- traefik.http.routers.whoosh.rule=Host(`whoosh.chorus.services`)
|
- traefik.http.routers.whoosh.rule=Host(`whoosh.chorus.services`)
|
||||||
- traefik.http.routers.whoosh.tls=true
|
- traefik.http.routers.whoosh.tls=true
|
||||||
- traefik.http.routers.whoosh.tls.certresolver=letsencrypt
|
- traefik.http.routers.whoosh.tls.certresolver=letsencryptresolver
|
||||||
|
- traefik.http.routers.photoprism.entrypoints=web,web-secured
|
||||||
- traefik.http.services.whoosh.loadbalancer.server.port=8080
|
- traefik.http.services.whoosh.loadbalancer.server.port=8080
|
||||||
- traefik.http.middlewares.whoosh-auth.basicauth.users=admin:$$2y$$10$$example_hash
|
- traefik.http.services.photoprism.loadbalancer.passhostheader=true
|
||||||
|
- traefik.http.middlewares.whoosh-auth.basicauth.users=admin:$2y$10$example_hash
|
||||||
networks:
|
networks:
|
||||||
- tengig
|
- tengig
|
||||||
- whoosh-backend
|
|
||||||
- chorus_net
|
- chorus_net
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "/app/whoosh", "--health-check"]
|
test: ["CMD", "/app/whoosh", "--health-check"]
|
||||||
@@ -257,14 +289,13 @@ services:
|
|||||||
memory: 256M
|
memory: 256M
|
||||||
cpus: '0.5'
|
cpus: '0.5'
|
||||||
networks:
|
networks:
|
||||||
- whoosh-backend
|
|
||||||
- chorus_net
|
- chorus_net
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD-SHELL", "pg_isready -U whoosh"]
|
test: ["CMD-SHELL", "pg_isready -h localhost -p 5432 -U whoosh -d whoosh"]
|
||||||
interval: 30s
|
interval: 30s
|
||||||
timeout: 10s
|
timeout: 10s
|
||||||
retries: 5
|
retries: 5
|
||||||
start_period: 30s
|
start_period: 40s
|
||||||
|
|
||||||
|
|
||||||
redis:
|
redis:
|
||||||
@@ -292,7 +323,6 @@ services:
|
|||||||
memory: 64M
|
memory: 64M
|
||||||
cpus: '0.1'
|
cpus: '0.1'
|
||||||
networks:
|
networks:
|
||||||
- whoosh-backend
|
|
||||||
- chorus_net
|
- chorus_net
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "sh", "-c", "redis-cli --no-auth-warning -a $$(cat /run/secrets/redis_password) ping"]
|
test: ["CMD", "sh", "-c", "redis-cli --no-auth-warning -a $$(cat /run/secrets/redis_password) ping"]
|
||||||
@@ -310,6 +340,66 @@ services:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
prometheus:
|
||||||
|
image: prom/prometheus:latest
|
||||||
|
command:
|
||||||
|
- '--config.file=/etc/prometheus/prometheus.yml'
|
||||||
|
- '--storage.tsdb.path=/prometheus'
|
||||||
|
- '--web.console.libraries=/usr/share/prometheus/console_libraries'
|
||||||
|
- '--web.console.templates=/usr/share/prometheus/consoles'
|
||||||
|
volumes:
|
||||||
|
- /rust/containers/CHORUS/monitoring/prometheus/prometheus.yml:/etc/prometheus/prometheus.yml:ro
|
||||||
|
- /rust/containers/CHORUS/monitoring/prometheus:/prometheus
|
||||||
|
ports:
|
||||||
|
- "9099:9090" # Expose Prometheus UI
|
||||||
|
deploy:
|
||||||
|
replicas: 1
|
||||||
|
labels:
|
||||||
|
- traefik.enable=true
|
||||||
|
- traefik.http.routers.prometheus.rule=Host(`prometheus.chorus.services`)
|
||||||
|
- traefik.http.routers.prometheus.entrypoints=web,web-secured
|
||||||
|
- traefik.http.routers.prometheus.tls=true
|
||||||
|
- traefik.http.routers.prometheus.tls.certresolver=letsencryptresolver
|
||||||
|
- traefik.http.services.prometheus.loadbalancer.server.port=9090
|
||||||
|
networks:
|
||||||
|
- chorus_net
|
||||||
|
- tengig
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://localhost:9090/-/ready"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 10s
|
||||||
|
retries: 3
|
||||||
|
start_period: 10s
|
||||||
|
|
||||||
|
grafana:
|
||||||
|
image: grafana/grafana:latest
|
||||||
|
user: "1000:1000"
|
||||||
|
environment:
|
||||||
|
- GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_ADMIN_PASSWORD:-admin} # Use a strong password in production
|
||||||
|
- GF_SERVER_ROOT_URL=https://grafana.chorus.services
|
||||||
|
volumes:
|
||||||
|
- /rust/containers/CHORUS/monitoring/grafana:/var/lib/grafana
|
||||||
|
ports:
|
||||||
|
- "3300:3000" # Expose Grafana UI
|
||||||
|
deploy:
|
||||||
|
replicas: 1
|
||||||
|
labels:
|
||||||
|
- traefik.enable=true
|
||||||
|
- traefik.http.routers.grafana.rule=Host(`grafana.chorus.services`)
|
||||||
|
- traefik.http.routers.grafana.entrypoints=web,web-secured
|
||||||
|
- traefik.http.routers.grafana.tls=true
|
||||||
|
- traefik.http.routers.grafana.tls.certresolver=letsencryptresolver
|
||||||
|
- traefik.http.services.grafana.loadbalancer.server.port=3000
|
||||||
|
networks:
|
||||||
|
- chorus_net
|
||||||
|
- tengig
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://localhost:3000/api/health"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 10s
|
||||||
|
retries: 3
|
||||||
|
start_period: 10s
|
||||||
|
|
||||||
# BACKBEAT Pulse Service - Leader-elected tempo broadcaster
|
# BACKBEAT Pulse Service - Leader-elected tempo broadcaster
|
||||||
# REQ: BACKBEAT-REQ-001 - Single BeatFrame publisher per cluster
|
# REQ: BACKBEAT-REQ-001 - Single BeatFrame publisher per cluster
|
||||||
# REQ: BACKBEAT-OPS-001 - One replica prefers leadership
|
# REQ: BACKBEAT-OPS-001 - One replica prefers leadership
|
||||||
@@ -355,8 +445,6 @@ services:
|
|||||||
placement:
|
placement:
|
||||||
preferences:
|
preferences:
|
||||||
- spread: node.hostname
|
- spread: node.hostname
|
||||||
constraints:
|
|
||||||
- node.hostname != rosewood # Avoid intermittent gaming PC
|
|
||||||
resources:
|
resources:
|
||||||
limits:
|
limits:
|
||||||
memory: 256M
|
memory: 256M
|
||||||
@@ -424,8 +512,6 @@ services:
|
|||||||
placement:
|
placement:
|
||||||
preferences:
|
preferences:
|
||||||
- spread: node.hostname
|
- spread: node.hostname
|
||||||
constraints:
|
|
||||||
- node.hostname != rosewood
|
|
||||||
resources:
|
resources:
|
||||||
limits:
|
limits:
|
||||||
memory: 512M # Larger for window aggregation
|
memory: 512M # Larger for window aggregation
|
||||||
@@ -458,7 +544,6 @@ services:
|
|||||||
backbeat-nats:
|
backbeat-nats:
|
||||||
image: nats:2.9-alpine
|
image: nats:2.9-alpine
|
||||||
command: ["--jetstream"]
|
command: ["--jetstream"]
|
||||||
|
|
||||||
deploy:
|
deploy:
|
||||||
replicas: 1
|
replicas: 1
|
||||||
restart_policy:
|
restart_policy:
|
||||||
@@ -469,8 +554,6 @@ services:
|
|||||||
placement:
|
placement:
|
||||||
preferences:
|
preferences:
|
||||||
- spread: node.hostname
|
- spread: node.hostname
|
||||||
constraints:
|
|
||||||
- node.hostname != rosewood
|
|
||||||
resources:
|
resources:
|
||||||
limits:
|
limits:
|
||||||
memory: 256M
|
memory: 256M
|
||||||
@@ -478,10 +561,8 @@ services:
|
|||||||
reservations:
|
reservations:
|
||||||
memory: 128M
|
memory: 128M
|
||||||
cpus: '0.25'
|
cpus: '0.25'
|
||||||
|
|
||||||
networks:
|
networks:
|
||||||
- chorus_net
|
- chorus_net
|
||||||
|
|
||||||
# Container logging
|
# Container logging
|
||||||
logging:
|
logging:
|
||||||
driver: "json-file"
|
driver: "json-file"
|
||||||
@@ -495,6 +576,24 @@ services:
|
|||||||
|
|
||||||
# Persistent volumes
|
# Persistent volumes
|
||||||
volumes:
|
volumes:
|
||||||
|
prometheus_data:
|
||||||
|
driver: local
|
||||||
|
driver_opts:
|
||||||
|
type: none
|
||||||
|
o: bind
|
||||||
|
device: /rust/containers/CHORUS/monitoring/prometheus
|
||||||
|
prometheus_config:
|
||||||
|
driver: local
|
||||||
|
driver_opts:
|
||||||
|
type: none
|
||||||
|
o: bind
|
||||||
|
device: /rust/containers/CHORUS/monitoring/prometheus
|
||||||
|
grafana_data:
|
||||||
|
driver: local
|
||||||
|
driver_opts:
|
||||||
|
type: none
|
||||||
|
o: bind
|
||||||
|
device: /rust/containers/CHORUS/monitoring/grafana
|
||||||
chorus_data:
|
chorus_data:
|
||||||
driver: local
|
driver: local
|
||||||
whoosh_postgres_data:
|
whoosh_postgres_data:
|
||||||
@@ -516,18 +615,14 @@ networks:
|
|||||||
tengig:
|
tengig:
|
||||||
external: true
|
external: true
|
||||||
|
|
||||||
whoosh-backend:
|
|
||||||
driver: overlay
|
|
||||||
attachable: false
|
|
||||||
|
|
||||||
chorus_net:
|
chorus_net:
|
||||||
driver: overlay
|
driver: overlay
|
||||||
attachable: true
|
attachable: true
|
||||||
ipam:
|
|
||||||
config:
|
|
||||||
- subnet: 10.201.0.0/24
|
|
||||||
|
|
||||||
|
|
||||||
|
configs:
|
||||||
|
chorus_bootstrap:
|
||||||
|
file: ./bootstrap.json
|
||||||
|
|
||||||
secrets:
|
secrets:
|
||||||
chorus_license_id:
|
chorus_license_id:
|
||||||
|
|||||||
435
docs/development/task-execution-engine-plan.md
Normal file
435
docs/development/task-execution-engine-plan.md
Normal file
@@ -0,0 +1,435 @@
|
|||||||
|
# CHORUS Task Execution Engine Development Plan
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
This plan outlines the development of a comprehensive task execution engine for CHORUS agents, replacing the current mock implementation with a fully functional system that can execute real work according to agent roles and specializations.
|
||||||
|
|
||||||
|
## Current State Analysis
|
||||||
|
|
||||||
|
### What's Implemented ✅
|
||||||
|
- **Task Coordinator Framework** (`coordinator/task_coordinator.go`): Full task management lifecycle with role-based assignment, collaboration requests, and HMMM integration
|
||||||
|
- **Agent Role System**: Role announcements, capability broadcasting, and expertise matching
|
||||||
|
- **P2P Infrastructure**: Nodes can discover each other and communicate via pubsub
|
||||||
|
- **Health Monitoring**: Comprehensive health checks and graceful shutdown
|
||||||
|
|
||||||
|
### Critical Gaps Identified ❌
|
||||||
|
- **Task Execution Engine**: `executeTask()` only has a 10-second sleep simulation - no actual work performed
|
||||||
|
- **Repository Integration**: Mock providers only - no real GitHub/GitLab task pulling
|
||||||
|
- **Agent-to-Task Binding**: Task discovery relies on WHOOSH but agents don't connect to real work
|
||||||
|
- **Role-Based Execution**: Agents announce roles but don't execute tasks according to their specialization
|
||||||
|
- **AI Integration**: No LLM/reasoning integration for task completion
|
||||||
|
|
||||||
|
## Architecture Requirements
|
||||||
|
|
||||||
|
### Model and Provider Abstraction
|
||||||
|
The execution engine must support multiple AI model providers and execution environments:
|
||||||
|
|
||||||
|
**Model Provider Types:**
|
||||||
|
- **Local Ollama**: Default for most roles (llama3.1:8b, codellama, etc.)
|
||||||
|
- **OpenAI API**: For specialized models (chatgpt-5, gpt-4o, etc.)
|
||||||
|
- **ResetData API**: For testing and fallback (llama3.1:8b via LaaS)
|
||||||
|
- **Custom Endpoints**: Support for other provider APIs
|
||||||
|
|
||||||
|
**Role-Model Mapping:**
|
||||||
|
- Each role has a default model configuration
|
||||||
|
- Specialized roles may require specific models/providers
|
||||||
|
- Model selection transparent to execution logic
|
||||||
|
- Support for MCP calls and tool usage regardless of provider
|
||||||
|
|
||||||
|
### Execution Environment Abstraction
|
||||||
|
Tasks must execute in secure, isolated environments while maintaining transparency:
|
||||||
|
|
||||||
|
**Sandbox Types:**
|
||||||
|
- **Docker Containers**: Isolated execution environment per task
|
||||||
|
- **Specialized VMs**: For tasks requiring full OS isolation
|
||||||
|
- **Process Sandboxing**: Lightweight isolation for simple tasks
|
||||||
|
|
||||||
|
**Transparency Requirements:**
|
||||||
|
- Model perceives it's working on a local repository
|
||||||
|
- Development tools available within sandbox
|
||||||
|
- File system operations work normally from model's perspective
|
||||||
|
- Network access controlled but transparent
|
||||||
|
- Resource limits enforced but invisible
|
||||||
|
|
||||||
|
## Development Plan
|
||||||
|
|
||||||
|
### Phase 1: Model Provider Abstraction Layer
|
||||||
|
|
||||||
|
#### 1.1 Create Provider Interface
|
||||||
|
```go
|
||||||
|
// pkg/ai/provider.go
|
||||||
|
type ModelProvider interface {
|
||||||
|
ExecuteTask(ctx context.Context, request *TaskRequest) (*TaskResponse, error)
|
||||||
|
SupportsMCP() bool
|
||||||
|
SupportsTools() bool
|
||||||
|
GetCapabilities() []string
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 1.2 Implement Provider Types
|
||||||
|
- **OllamaProvider**: Local model execution
|
||||||
|
- **OpenAIProvider**: OpenAI API integration
|
||||||
|
- **ResetDataProvider**: ResetData LaaS integration
|
||||||
|
- **ProviderFactory**: Creates appropriate provider based on model config
|
||||||
|
|
||||||
|
#### 1.3 Role-Model Configuration
|
||||||
|
```yaml
|
||||||
|
# Config structure for role-model mapping
|
||||||
|
roles:
|
||||||
|
developer:
|
||||||
|
default_model: "codellama:13b"
|
||||||
|
provider: "ollama"
|
||||||
|
fallback_model: "llama3.1:8b"
|
||||||
|
fallback_provider: "resetdata"
|
||||||
|
|
||||||
|
architect:
|
||||||
|
default_model: "gpt-4o"
|
||||||
|
provider: "openai"
|
||||||
|
fallback_model: "llama3.1:8b"
|
||||||
|
fallback_provider: "ollama"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Phase 2: Execution Environment Abstraction
|
||||||
|
|
||||||
|
#### 2.1 Create Sandbox Interface
|
||||||
|
```go
|
||||||
|
// pkg/execution/sandbox.go
|
||||||
|
type ExecutionSandbox interface {
|
||||||
|
Initialize(ctx context.Context, config *SandboxConfig) error
|
||||||
|
ExecuteCommand(ctx context.Context, cmd *Command) (*CommandResult, error)
|
||||||
|
CopyFiles(ctx context.Context, source, dest string) error
|
||||||
|
Cleanup() error
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2.2 Implement Sandbox Types
|
||||||
|
- **DockerSandbox**: Container-based isolation
|
||||||
|
- **VMSandbox**: Full VM isolation for sensitive tasks
|
||||||
|
- **ProcessSandbox**: Lightweight process-based isolation
|
||||||
|
|
||||||
|
#### 2.3 Repository Mounting
|
||||||
|
- Clone repository into sandbox environment
|
||||||
|
- Mount as local filesystem from model's perspective
|
||||||
|
- Implement secure file I/O operations
|
||||||
|
- Handle git operations within sandbox
|
||||||
|
|
||||||
|
### Phase 3: Core Task Execution Engine
|
||||||
|
|
||||||
|
#### 3.1 Replace Mock Implementation
|
||||||
|
Replace the current simulation in `coordinator/task_coordinator.go:314`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Current mock implementation
|
||||||
|
time.Sleep(10 * time.Second) // Simulate work
|
||||||
|
|
||||||
|
// New implementation
|
||||||
|
result, err := tc.executionEngine.ExecuteTask(ctx, &TaskExecutionRequest{
|
||||||
|
Task: activeTask.Task,
|
||||||
|
Agent: tc.agentInfo,
|
||||||
|
Sandbox: sandboxConfig,
|
||||||
|
ModelProvider: providerConfig,
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 3.2 Task Execution Strategies
|
||||||
|
Create role-specific execution patterns:
|
||||||
|
|
||||||
|
- **DeveloperStrategy**: Code implementation, bug fixes, feature development
|
||||||
|
- **ReviewerStrategy**: Code review, quality analysis, test coverage assessment
|
||||||
|
- **ArchitectStrategy**: System design, technical decision making
|
||||||
|
- **TesterStrategy**: Test creation, validation, quality assurance
|
||||||
|
|
||||||
|
#### 3.3 Execution Workflow
|
||||||
|
1. **Task Analysis**: Parse task requirements and complexity
|
||||||
|
2. **Environment Setup**: Initialize appropriate sandbox
|
||||||
|
3. **Repository Preparation**: Clone and mount repository
|
||||||
|
4. **Model Selection**: Choose appropriate model/provider
|
||||||
|
5. **Task Execution**: Run role-specific execution strategy
|
||||||
|
6. **Result Validation**: Verify output quality and completeness
|
||||||
|
7. **Cleanup**: Teardown sandbox and collect artifacts
|
||||||
|
|
||||||
|
### Phase 4: Repository Provider Implementation
|
||||||
|
|
||||||
|
#### 4.1 Real Repository Integration
|
||||||
|
Replace `MockTaskProvider` with actual implementations:
|
||||||
|
- **GiteaProvider**: Integration with GITEA API
|
||||||
|
- **GitHubProvider**: GitHub API integration
|
||||||
|
- **GitLabProvider**: GitLab API integration
|
||||||
|
|
||||||
|
#### 4.2 Task Lifecycle Management
|
||||||
|
- Task claiming and status updates
|
||||||
|
- Progress reporting back to repositories
|
||||||
|
- Artifact attachment (patches, documentation, etc.)
|
||||||
|
- Automated PR/MR creation for completed tasks
|
||||||
|
|
||||||
|
### Phase 5: AI Integration and Tool Support
|
||||||
|
|
||||||
|
#### 5.1 LLM Integration
|
||||||
|
- Context-aware task analysis based on repository content
|
||||||
|
- Code generation and problem-solving capabilities
|
||||||
|
- Natural language processing for task descriptions
|
||||||
|
- Multi-step reasoning for complex tasks
|
||||||
|
|
||||||
|
#### 5.2 Tool Integration
|
||||||
|
- MCP server connectivity within sandbox
|
||||||
|
- Development tool access (compilers, linters, formatters)
|
||||||
|
- Testing framework integration
|
||||||
|
- Documentation generation tools
|
||||||
|
|
||||||
|
#### 5.3 Quality Assurance
|
||||||
|
- Automated testing of generated code
|
||||||
|
- Code quality metrics and analysis
|
||||||
|
- Security vulnerability scanning
|
||||||
|
- Performance impact assessment
|
||||||
|
|
||||||
|
### Phase 6: Testing and Validation
|
||||||
|
|
||||||
|
#### 6.1 Unit Testing
|
||||||
|
- Provider abstraction layer testing
|
||||||
|
- Sandbox isolation verification
|
||||||
|
- Task execution strategy validation
|
||||||
|
- Error handling and recovery testing
|
||||||
|
|
||||||
|
#### 6.2 Integration Testing
|
||||||
|
- End-to-end task execution workflows
|
||||||
|
- Agent-to-WHOOSH communication testing
|
||||||
|
- Multi-provider failover scenarios
|
||||||
|
- Concurrent task execution testing
|
||||||
|
|
||||||
|
#### 6.3 Security Testing
|
||||||
|
- Sandbox escape prevention
|
||||||
|
- Resource limit enforcement
|
||||||
|
- Network isolation validation
|
||||||
|
- Secrets and credential protection
|
||||||
|
|
||||||
|
### Phase 7: Production Deployment
|
||||||
|
|
||||||
|
#### 7.1 Configuration Management
|
||||||
|
- Environment-specific model configurations
|
||||||
|
- Sandbox resource limit definitions
|
||||||
|
- Provider API key management
|
||||||
|
- Monitoring and logging setup
|
||||||
|
|
||||||
|
#### 7.2 Monitoring and Observability
|
||||||
|
- Task execution metrics and dashboards
|
||||||
|
- Performance monitoring and alerting
|
||||||
|
- Resource utilization tracking
|
||||||
|
- Error rate and success metrics
|
||||||
|
|
||||||
|
## Implementation Priorities
|
||||||
|
|
||||||
|
### Critical Path (Week 1-2)
|
||||||
|
1. Model Provider Abstraction Layer
|
||||||
|
2. Basic Docker Sandbox Implementation
|
||||||
|
3. Replace Mock Task Execution
|
||||||
|
4. Role-Based Execution Strategies
|
||||||
|
|
||||||
|
### High Priority (Week 3-4)
|
||||||
|
5. Real Repository Provider Implementation
|
||||||
|
6. AI Integration with Ollama/OpenAI
|
||||||
|
7. MCP Tool Integration
|
||||||
|
8. Basic Testing Framework
|
||||||
|
|
||||||
|
### Medium Priority (Week 5-6)
|
||||||
|
9. Advanced Sandbox Types (VM, Process)
|
||||||
|
10. Quality Assurance Pipeline
|
||||||
|
11. Comprehensive Testing Suite
|
||||||
|
12. Performance Optimization
|
||||||
|
|
||||||
|
### Future Enhancements
|
||||||
|
- Multi-language model support
|
||||||
|
- Advanced reasoning capabilities
|
||||||
|
- Distributed task execution
|
||||||
|
- Machine learning model fine-tuning
|
||||||
|
|
||||||
|
## Success Metrics
|
||||||
|
|
||||||
|
- **Task Completion Rate**: >90% of assigned tasks successfully completed
|
||||||
|
- **Code Quality**: Generated code passes all existing tests and linting
|
||||||
|
- **Security**: Zero sandbox escapes or security violations
|
||||||
|
- **Performance**: Task execution time within acceptable bounds
|
||||||
|
- **Reliability**: <5% execution failure rate due to engine issues
|
||||||
|
|
||||||
|
## Risk Mitigation
|
||||||
|
|
||||||
|
### Security Risks
|
||||||
|
- Sandbox escape → Multiple isolation layers, security audits
|
||||||
|
- Credential exposure → Secure credential management, rotation
|
||||||
|
- Resource exhaustion → Resource limits, monitoring, auto-scaling
|
||||||
|
|
||||||
|
### Technical Risks
|
||||||
|
- Model provider outages → Multi-provider failover, local fallbacks
|
||||||
|
- Execution failures → Robust error handling, retry mechanisms
|
||||||
|
- Performance bottlenecks → Profiling, optimization, horizontal scaling
|
||||||
|
|
||||||
|
### Integration Risks
|
||||||
|
- WHOOSH compatibility → Extensive integration testing, versioning
|
||||||
|
- Repository provider changes → Provider abstraction, API versioning
|
||||||
|
- Model compatibility → Provider abstraction, capability detection
|
||||||
|
|
||||||
|
This comprehensive plan addresses the core limitation that CHORUS agents currently lack real task execution capabilities while building a robust, secure, and scalable execution engine suitable for production deployment.
|
||||||
|
|
||||||
|
## Implementation Roadmap
|
||||||
|
|
||||||
|
### Development Standards & Workflow
|
||||||
|
|
||||||
|
**Semantic Versioning Strategy:**
|
||||||
|
- **Patch (0.N.X)**: Bug fixes, small improvements, documentation updates
|
||||||
|
- **Minor (0.N.0)**: New features, phase completions, non-breaking changes
|
||||||
|
- **Major (N.0.0)**: Breaking changes, major architectural shifts
|
||||||
|
|
||||||
|
**Git Workflow:**
|
||||||
|
1. **Branch Creation**: `git checkout -b feature/phase-N-description`
|
||||||
|
2. **Development**: Implement with frequent commits using conventional commit format
|
||||||
|
3. **Testing**: Run full test suite with `make test` before PR
|
||||||
|
4. **Code Review**: Create PR with detailed description and test results
|
||||||
|
5. **Integration**: Squash merge to main after approval
|
||||||
|
6. **Release**: Tag with `git tag v0.N.0` and update Makefile version
|
||||||
|
|
||||||
|
**Quality Gates:**
|
||||||
|
Each phase must meet these criteria before merge:
|
||||||
|
- ✅ Unit tests with >80% coverage
|
||||||
|
- ✅ Integration tests for external dependencies
|
||||||
|
- ✅ Security review for new attack surfaces
|
||||||
|
- ✅ Performance benchmarks within acceptable bounds
|
||||||
|
- ✅ Documentation updates (code comments + README)
|
||||||
|
- ✅ Backward compatibility verification
|
||||||
|
|
||||||
|
### Phase-by-Phase Implementation
|
||||||
|
|
||||||
|
#### Phase 1: Model Provider Abstraction (v0.2.0)
|
||||||
|
**Branch:** `feature/phase-1-model-providers`
|
||||||
|
**Duration:** 3-5 days
|
||||||
|
**Deliverables:**
|
||||||
|
```
|
||||||
|
pkg/ai/
|
||||||
|
├── provider.go # Core provider interface & request/response types
|
||||||
|
├── ollama.go # Local Ollama model integration
|
||||||
|
├── openai.go # OpenAI API client wrapper
|
||||||
|
├── resetdata.go # ResetData LaaS integration
|
||||||
|
├── factory.go # Provider factory with auto-selection
|
||||||
|
└── provider_test.go # Comprehensive provider tests
|
||||||
|
|
||||||
|
configs/
|
||||||
|
└── models.yaml # Role-model mapping configuration
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key Features:**
|
||||||
|
- Abstract AI providers behind unified interface
|
||||||
|
- Support multiple providers with automatic failover
|
||||||
|
- Configuration-driven model selection per agent role
|
||||||
|
- Proper error handling and retry logic
|
||||||
|
|
||||||
|
#### Phase 2: Execution Environment Abstraction (v0.3.0)
|
||||||
|
**Branch:** `feature/phase-2-execution-sandbox`
|
||||||
|
**Duration:** 5-7 days
|
||||||
|
**Deliverables:**
|
||||||
|
```
|
||||||
|
pkg/execution/
|
||||||
|
├── sandbox.go # Core sandbox interface & types
|
||||||
|
├── docker.go # Docker container implementation
|
||||||
|
├── security.go # Security policies & enforcement
|
||||||
|
├── resources.go # Resource monitoring & limits
|
||||||
|
└── sandbox_test.go # Sandbox security & isolation tests
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key Features:**
|
||||||
|
- Docker-based task isolation with transparent repository access
|
||||||
|
- Resource limits (CPU, memory, network, disk) with monitoring
|
||||||
|
- Security boundary enforcement and escape prevention
|
||||||
|
- Clean teardown and artifact collection
|
||||||
|
|
||||||
|
#### Phase 3: Core Task Execution Engine (v0.4.0)
|
||||||
|
**Branch:** `feature/phase-3-task-execution`
|
||||||
|
**Duration:** 7-10 days
|
||||||
|
**Modified Files:**
|
||||||
|
- `coordinator/task_coordinator.go:314` - Replace mock with real execution
|
||||||
|
- `pkg/repository/types.go` - Extend interfaces for execution context
|
||||||
|
|
||||||
|
**New Files:**
|
||||||
|
```
|
||||||
|
pkg/strategies/
|
||||||
|
├── developer.go # Code implementation & bug fixes
|
||||||
|
├── reviewer.go # Code review & quality analysis
|
||||||
|
├── architect.go # System design & tech decisions
|
||||||
|
└── tester.go # Test creation & validation
|
||||||
|
|
||||||
|
pkg/engine/
|
||||||
|
├── executor.go # Main execution orchestrator
|
||||||
|
├── workflow.go # 7-step execution workflow
|
||||||
|
└── validation.go # Result quality verification
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key Features:**
|
||||||
|
- Real task execution replacing 10-second sleep simulation
|
||||||
|
- Role-specific execution strategies with appropriate tooling
|
||||||
|
- Integration between AI providers, sandboxes, and task lifecycle
|
||||||
|
- Comprehensive result validation and quality metrics
|
||||||
|
|
||||||
|
#### Phase 4: Repository Provider Implementation (v0.5.0)
|
||||||
|
**Branch:** `feature/phase-4-real-providers`
|
||||||
|
**Duration:** 10-14 days
|
||||||
|
**Deliverables:**
|
||||||
|
```
|
||||||
|
pkg/providers/
|
||||||
|
├── gitea.go # Gitea API integration (primary)
|
||||||
|
├── github.go # GitHub API integration
|
||||||
|
├── gitlab.go # GitLab API integration
|
||||||
|
└── provider_test.go # API integration tests
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key Features:**
|
||||||
|
- Replace MockTaskProvider with production implementations
|
||||||
|
- Task claiming, status updates, and progress reporting via APIs
|
||||||
|
- Automated PR/MR creation with proper branch management
|
||||||
|
- Repository-specific configuration and credential management
|
||||||
|
|
||||||
|
### Testing Strategy
|
||||||
|
|
||||||
|
**Unit Testing:**
|
||||||
|
- Each provider/sandbox implementation has dedicated test suite
|
||||||
|
- Mock external dependencies (APIs, Docker, etc.) for isolated testing
|
||||||
|
- Property-based testing for core interfaces
|
||||||
|
- Error condition and edge case coverage
|
||||||
|
|
||||||
|
**Integration Testing:**
|
||||||
|
- End-to-end task execution workflows
|
||||||
|
- Multi-provider failover scenarios
|
||||||
|
- Agent-to-WHOOSH communication validation
|
||||||
|
- Concurrent task execution under load
|
||||||
|
|
||||||
|
**Security Testing:**
|
||||||
|
- Sandbox escape prevention validation
|
||||||
|
- Resource exhaustion protection
|
||||||
|
- Network isolation verification
|
||||||
|
- Secrets and credential protection audits
|
||||||
|
|
||||||
|
### Deployment & Monitoring
|
||||||
|
|
||||||
|
**Configuration Management:**
|
||||||
|
- Environment-specific model configurations
|
||||||
|
- Sandbox resource limits per environment
|
||||||
|
- Provider API credentials via secure secret management
|
||||||
|
- Feature flags for gradual rollout
|
||||||
|
|
||||||
|
**Observability:**
|
||||||
|
- Task execution metrics (completion rate, duration, success/failure)
|
||||||
|
- Resource utilization tracking (CPU, memory, network per task)
|
||||||
|
- Error rate monitoring with alerting thresholds
|
||||||
|
- Performance dashboards for capacity planning
|
||||||
|
|
||||||
|
### Risk Mitigation
|
||||||
|
|
||||||
|
**Technical Risks:**
|
||||||
|
- **Provider Outages**: Multi-provider failover with health checks
|
||||||
|
- **Resource Exhaustion**: Strict limits with monitoring and auto-scaling
|
||||||
|
- **Execution Failures**: Retry mechanisms with exponential backoff
|
||||||
|
|
||||||
|
**Security Risks:**
|
||||||
|
- **Sandbox Escapes**: Multiple isolation layers and regular security audits
|
||||||
|
- **Credential Exposure**: Secure rotation and least-privilege access
|
||||||
|
- **Data Exfiltration**: Network isolation and egress monitoring
|
||||||
|
|
||||||
|
**Integration Risks:**
|
||||||
|
- **API Changes**: Provider abstraction with versioning support
|
||||||
|
- **Performance Degradation**: Comprehensive benchmarking at each phase
|
||||||
|
- **Compatibility Issues**: Extensive integration testing with existing systems
|
||||||
33
go.mod
33
go.mod
@@ -1,6 +1,6 @@
|
|||||||
module chorus
|
module chorus
|
||||||
|
|
||||||
go 1.23
|
go 1.23.0
|
||||||
|
|
||||||
toolchain go1.24.5
|
toolchain go1.24.5
|
||||||
|
|
||||||
@@ -8,6 +8,9 @@ require (
|
|||||||
filippo.io/age v1.2.1
|
filippo.io/age v1.2.1
|
||||||
github.com/blevesearch/bleve/v2 v2.5.3
|
github.com/blevesearch/bleve/v2 v2.5.3
|
||||||
github.com/chorus-services/backbeat v0.0.0-00010101000000-000000000000
|
github.com/chorus-services/backbeat v0.0.0-00010101000000-000000000000
|
||||||
|
github.com/docker/docker v28.4.0+incompatible
|
||||||
|
github.com/docker/go-connections v0.6.0
|
||||||
|
github.com/docker/go-units v0.5.0
|
||||||
github.com/go-redis/redis/v8 v8.11.5
|
github.com/go-redis/redis/v8 v8.11.5
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/gorilla/mux v1.8.1
|
github.com/gorilla/mux v1.8.1
|
||||||
@@ -21,12 +24,15 @@ require (
|
|||||||
github.com/prometheus/client_golang v1.19.1
|
github.com/prometheus/client_golang v1.19.1
|
||||||
github.com/robfig/cron/v3 v3.0.1
|
github.com/robfig/cron/v3 v3.0.1
|
||||||
github.com/sashabaranov/go-openai v1.41.1
|
github.com/sashabaranov/go-openai v1.41.1
|
||||||
github.com/stretchr/testify v1.10.0
|
github.com/sony/gobreaker v0.5.0
|
||||||
|
github.com/stretchr/testify v1.11.1
|
||||||
github.com/syndtr/goleveldb v1.0.0
|
github.com/syndtr/goleveldb v1.0.0
|
||||||
golang.org/x/crypto v0.24.0
|
golang.org/x/crypto v0.24.0
|
||||||
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||||
github.com/RoaringBitmap/roaring/v2 v2.4.5 // indirect
|
github.com/RoaringBitmap/roaring/v2 v2.4.5 // indirect
|
||||||
github.com/benbjohnson/clock v1.3.5 // indirect
|
github.com/benbjohnson/clock v1.3.5 // indirect
|
||||||
github.com/beorn7/perks v1.0.1 // indirect
|
github.com/beorn7/perks v1.0.1 // indirect
|
||||||
@@ -50,16 +56,19 @@ require (
|
|||||||
github.com/blevesearch/zapx/v16 v16.2.4 // indirect
|
github.com/blevesearch/zapx/v16 v16.2.4 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||||
github.com/containerd/cgroups v1.1.0 // indirect
|
github.com/containerd/cgroups v1.1.0 // indirect
|
||||||
|
github.com/containerd/errdefs v1.0.0 // indirect
|
||||||
|
github.com/containerd/errdefs/pkg v0.3.0 // indirect
|
||||||
github.com/coreos/go-systemd/v22 v22.5.0 // indirect
|
github.com/coreos/go-systemd/v22 v22.5.0 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/davidlazar/go-crypto v0.0.0-20200604182044-b73af7476f6c // indirect
|
github.com/davidlazar/go-crypto v0.0.0-20200604182044-b73af7476f6c // indirect
|
||||||
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
|
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
github.com/docker/go-units v0.5.0 // indirect
|
github.com/distribution/reference v0.6.0 // indirect
|
||||||
github.com/elastic/gosigar v0.14.2 // indirect
|
github.com/elastic/gosigar v0.14.2 // indirect
|
||||||
|
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||||
github.com/flynn/noise v1.0.0 // indirect
|
github.com/flynn/noise v1.0.0 // indirect
|
||||||
github.com/francoispqt/gojay v1.2.13 // indirect
|
github.com/francoispqt/gojay v1.2.13 // indirect
|
||||||
github.com/go-logr/logr v1.2.4 // indirect
|
github.com/go-logr/logr v1.4.3 // indirect
|
||||||
github.com/go-logr/stdr v1.2.2 // indirect
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
||||||
github.com/godbus/dbus/v5 v5.1.0 // indirect
|
github.com/godbus/dbus/v5 v5.1.0 // indirect
|
||||||
@@ -104,6 +113,7 @@ require (
|
|||||||
github.com/mikioh/tcpinfo v0.0.0-20190314235526-30a79bb1804b // indirect
|
github.com/mikioh/tcpinfo v0.0.0-20190314235526-30a79bb1804b // indirect
|
||||||
github.com/mikioh/tcpopt v0.0.0-20190314235656-172688c1accc // indirect
|
github.com/mikioh/tcpopt v0.0.0-20190314235656-172688c1accc // indirect
|
||||||
github.com/minio/sha256-simd v1.0.1 // indirect
|
github.com/minio/sha256-simd v1.0.1 // indirect
|
||||||
|
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
github.com/mr-tron/base58 v1.2.0 // indirect
|
github.com/mr-tron/base58 v1.2.0 // indirect
|
||||||
@@ -120,6 +130,8 @@ require (
|
|||||||
github.com/nats-io/nkeys v0.4.7 // indirect
|
github.com/nats-io/nkeys v0.4.7 // indirect
|
||||||
github.com/nats-io/nuid v1.0.1 // indirect
|
github.com/nats-io/nuid v1.0.1 // indirect
|
||||||
github.com/onsi/ginkgo/v2 v2.13.0 // indirect
|
github.com/onsi/ginkgo/v2 v2.13.0 // indirect
|
||||||
|
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||||
|
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||||
github.com/opencontainers/runtime-spec v1.1.0 // indirect
|
github.com/opencontainers/runtime-spec v1.1.0 // indirect
|
||||||
github.com/opentracing/opentracing-go v1.2.0 // indirect
|
github.com/opentracing/opentracing-go v1.2.0 // indirect
|
||||||
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 // indirect
|
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 // indirect
|
||||||
@@ -138,9 +150,11 @@ require (
|
|||||||
github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1 // indirect
|
github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1 // indirect
|
||||||
go.etcd.io/bbolt v1.4.0 // indirect
|
go.etcd.io/bbolt v1.4.0 // indirect
|
||||||
go.opencensus.io v0.24.0 // indirect
|
go.opencensus.io v0.24.0 // indirect
|
||||||
go.opentelemetry.io/otel v1.16.0 // indirect
|
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||||
go.opentelemetry.io/otel/metric v1.16.0 // indirect
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect
|
||||||
go.opentelemetry.io/otel/trace v1.16.0 // indirect
|
go.opentelemetry.io/otel v1.38.0 // indirect
|
||||||
|
go.opentelemetry.io/otel/metric v1.38.0 // indirect
|
||||||
|
go.opentelemetry.io/otel/trace v1.38.0 // indirect
|
||||||
go.uber.org/dig v1.17.1 // indirect
|
go.uber.org/dig v1.17.1 // indirect
|
||||||
go.uber.org/fx v1.20.1 // indirect
|
go.uber.org/fx v1.20.1 // indirect
|
||||||
go.uber.org/mock v0.3.0 // indirect
|
go.uber.org/mock v0.3.0 // indirect
|
||||||
@@ -150,12 +164,11 @@ require (
|
|||||||
golang.org/x/mod v0.18.0 // indirect
|
golang.org/x/mod v0.18.0 // indirect
|
||||||
golang.org/x/net v0.26.0 // indirect
|
golang.org/x/net v0.26.0 // indirect
|
||||||
golang.org/x/sync v0.10.0 // indirect
|
golang.org/x/sync v0.10.0 // indirect
|
||||||
golang.org/x/sys v0.29.0 // indirect
|
golang.org/x/sys v0.35.0 // indirect
|
||||||
golang.org/x/text v0.16.0 // indirect
|
golang.org/x/text v0.16.0 // indirect
|
||||||
golang.org/x/tools v0.22.0 // indirect
|
golang.org/x/tools v0.22.0 // indirect
|
||||||
gonum.org/v1/gonum v0.13.0 // indirect
|
gonum.org/v1/gonum v0.13.0 // indirect
|
||||||
google.golang.org/protobuf v1.33.0 // indirect
|
google.golang.org/protobuf v1.34.2 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
|
||||||
lukechampine.com/blake3 v1.2.1 // indirect
|
lukechampine.com/blake3 v1.2.1 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
40
go.sum
40
go.sum
@@ -12,6 +12,8 @@ filippo.io/age v1.2.1 h1:X0TZjehAZylOIj4DubWYU1vWQxv9bJpo+Uu2/LGhi1o=
|
|||||||
filippo.io/age v1.2.1/go.mod h1:JL9ew2lTN+Pyft4RiNGguFfOpewKwSHm5ayKD/A4004=
|
filippo.io/age v1.2.1/go.mod h1:JL9ew2lTN+Pyft4RiNGguFfOpewKwSHm5ayKD/A4004=
|
||||||
git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg=
|
git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg=
|
||||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||||
|
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||||
|
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||||
github.com/RoaringBitmap/roaring/v2 v2.4.5 h1:uGrrMreGjvAtTBobc0g5IrW1D5ldxDQYe2JW2gggRdg=
|
github.com/RoaringBitmap/roaring/v2 v2.4.5 h1:uGrrMreGjvAtTBobc0g5IrW1D5ldxDQYe2JW2gggRdg=
|
||||||
github.com/RoaringBitmap/roaring/v2 v2.4.5/go.mod h1:FiJcsfkGje/nZBZgCu0ZxCPOKD/hVXDS2dXi7/eUFE0=
|
github.com/RoaringBitmap/roaring/v2 v2.4.5/go.mod h1:FiJcsfkGje/nZBZgCu0ZxCPOKD/hVXDS2dXi7/eUFE0=
|
||||||
github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c=
|
github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c=
|
||||||
@@ -72,6 +74,10 @@ github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGX
|
|||||||
github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327/go.mod h1:ZJeTFisyysqgcCdecO57Dj79RfL0LNeGiFUqLYQRYLE=
|
github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327/go.mod h1:ZJeTFisyysqgcCdecO57Dj79RfL0LNeGiFUqLYQRYLE=
|
||||||
github.com/containerd/cgroups v1.1.0 h1:v8rEWFl6EoqHB+swVNjVoCJE8o3jX7e8nqBGPLaDFBM=
|
github.com/containerd/cgroups v1.1.0 h1:v8rEWFl6EoqHB+swVNjVoCJE8o3jX7e8nqBGPLaDFBM=
|
||||||
github.com/containerd/cgroups v1.1.0/go.mod h1:6ppBcbh/NOOUU+dMKrykgaBnK9lCIBxHqJDGwsa1mIw=
|
github.com/containerd/cgroups v1.1.0/go.mod h1:6ppBcbh/NOOUU+dMKrykgaBnK9lCIBxHqJDGwsa1mIw=
|
||||||
|
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||||
|
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
|
||||||
|
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
|
||||||
|
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
|
||||||
github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
|
github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
|
||||||
github.com/coreos/go-systemd/v22 v22.1.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk=
|
github.com/coreos/go-systemd/v22 v22.1.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk=
|
||||||
github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
|
github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
|
||||||
@@ -89,6 +95,12 @@ github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etly
|
|||||||
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0=
|
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
|
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||||
|
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||||
|
github.com/docker/docker v28.4.0+incompatible h1:KVC7bz5zJY/4AZe/78BIvCnPsLaC9T/zh72xnlrTTOk=
|
||||||
|
github.com/docker/docker v28.4.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||||
|
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
|
||||||
|
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
|
||||||
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||||
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||||
@@ -100,6 +112,8 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF
|
|||||||
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||||
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
|
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
|
||||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||||
|
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||||
|
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||||
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc=
|
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc=
|
||||||
github.com/flynn/noise v1.0.0 h1:DlTHqmzmvcEiKj+4RYo/imoswx/4r6iBlCMfVtrMXpQ=
|
github.com/flynn/noise v1.0.0 h1:DlTHqmzmvcEiKj+4RYo/imoswx/4r6iBlCMfVtrMXpQ=
|
||||||
github.com/flynn/noise v1.0.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
|
github.com/flynn/noise v1.0.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
|
||||||
@@ -116,6 +130,8 @@ github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm
|
|||||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||||
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
||||||
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||||
|
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||||
|
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||||
@@ -307,6 +323,8 @@ github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1/go.mod h1:pD8Rv
|
|||||||
github.com/minio/sha256-simd v0.1.1-0.20190913151208-6de447530771/go.mod h1:B5e1o+1/KgNmWrSQK08Y6Z1Vb5pwIktudl0J58iy0KM=
|
github.com/minio/sha256-simd v0.1.1-0.20190913151208-6de447530771/go.mod h1:B5e1o+1/KgNmWrSQK08Y6Z1Vb5pwIktudl0J58iy0KM=
|
||||||
github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM=
|
github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM=
|
||||||
github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8=
|
github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8=
|
||||||
|
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||||
|
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
@@ -361,6 +379,10 @@ github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xl
|
|||||||
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
||||||
github.com/onsi/gomega v1.27.10 h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI=
|
github.com/onsi/gomega v1.27.10 h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI=
|
||||||
github.com/onsi/gomega v1.27.10/go.mod h1:RsS8tutOdbdgzbPtzzATp12yT7kM5I5aElG3evPbQ0M=
|
github.com/onsi/gomega v1.27.10/go.mod h1:RsS8tutOdbdgzbPtzzATp12yT7kM5I5aElG3evPbQ0M=
|
||||||
|
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||||
|
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||||
|
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||||
|
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||||
github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
|
github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
|
||||||
github.com/opencontainers/runtime-spec v1.1.0 h1:HHUyrt9mwHUjtasSbXSMvs4cyFxh+Bll4AjJ9odEGpg=
|
github.com/opencontainers/runtime-spec v1.1.0 h1:HHUyrt9mwHUjtasSbXSMvs4cyFxh+Bll4AjJ9odEGpg=
|
||||||
github.com/opencontainers/runtime-spec v1.1.0/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
|
github.com/opencontainers/runtime-spec v1.1.0/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
|
||||||
@@ -437,6 +459,8 @@ github.com/smartystreets/assertions v1.2.0 h1:42S6lae5dvLc7BrLu/0ugRtcFVjoJNMC/N
|
|||||||
github.com/smartystreets/assertions v1.2.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo=
|
github.com/smartystreets/assertions v1.2.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo=
|
||||||
github.com/smartystreets/goconvey v1.7.2 h1:9RBaZCeXEQ3UselpuwUQHltGVXvdwm6cv1hgR6gDIPg=
|
github.com/smartystreets/goconvey v1.7.2 h1:9RBaZCeXEQ3UselpuwUQHltGVXvdwm6cv1hgR6gDIPg=
|
||||||
github.com/smartystreets/goconvey v1.7.2/go.mod h1:Vw0tHAZW6lzCRk3xgdin6fKYcG+G3Pg9vgXWeJpQFMM=
|
github.com/smartystreets/goconvey v1.7.2/go.mod h1:Vw0tHAZW6lzCRk3xgdin6fKYcG+G3Pg9vgXWeJpQFMM=
|
||||||
|
github.com/sony/gobreaker v0.5.0 h1:dRCvqm0P490vZPmy7ppEk2qCnCieBooFJ+YoXGYB+yg=
|
||||||
|
github.com/sony/gobreaker v0.5.0/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY=
|
||||||
github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE=
|
github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE=
|
||||||
github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA=
|
github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA=
|
||||||
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
|
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
|
||||||
@@ -454,6 +478,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
|
|||||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE=
|
github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE=
|
||||||
github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ=
|
github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ=
|
||||||
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
|
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
|
||||||
@@ -473,12 +499,22 @@ go.etcd.io/bbolt v1.4.0/go.mod h1:AsD+OCi/qPN1giOX1aiLAha3o1U8rAz65bvN4j0sRuk=
|
|||||||
go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA=
|
go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA=
|
||||||
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
|
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
|
||||||
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
|
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
|
||||||
|
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||||
|
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||||
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18=
|
||||||
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0/go.mod h1:h06DGIukJOevXaj/xrNjhi/2098RZzcLTbc0jDAUbsg=
|
||||||
go.opentelemetry.io/otel v1.16.0 h1:Z7GVAX/UkAXPKsy94IU+i6thsQS4nb7LviLpnaNeW8s=
|
go.opentelemetry.io/otel v1.16.0 h1:Z7GVAX/UkAXPKsy94IU+i6thsQS4nb7LviLpnaNeW8s=
|
||||||
go.opentelemetry.io/otel v1.16.0/go.mod h1:vl0h9NUa1D5s1nv3A5vZOYWn8av4K8Ml6JDeHrT/bx4=
|
go.opentelemetry.io/otel v1.16.0/go.mod h1:vl0h9NUa1D5s1nv3A5vZOYWn8av4K8Ml6JDeHrT/bx4=
|
||||||
|
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
|
||||||
|
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
|
||||||
go.opentelemetry.io/otel/metric v1.16.0 h1:RbrpwVG1Hfv85LgnZ7+txXioPDoh6EdbZHo26Q3hqOo=
|
go.opentelemetry.io/otel/metric v1.16.0 h1:RbrpwVG1Hfv85LgnZ7+txXioPDoh6EdbZHo26Q3hqOo=
|
||||||
go.opentelemetry.io/otel/metric v1.16.0/go.mod h1:QE47cpOmkwipPiefDwo2wDzwJrlfxxNYodqc4xnGCo4=
|
go.opentelemetry.io/otel/metric v1.16.0/go.mod h1:QE47cpOmkwipPiefDwo2wDzwJrlfxxNYodqc4xnGCo4=
|
||||||
|
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
|
||||||
|
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
|
||||||
go.opentelemetry.io/otel/trace v1.16.0 h1:8JRpaObFoW0pxuVPapkgH8UhHQj+bJW8jJsCZEu5MQs=
|
go.opentelemetry.io/otel/trace v1.16.0 h1:8JRpaObFoW0pxuVPapkgH8UhHQj+bJW8jJsCZEu5MQs=
|
||||||
go.opentelemetry.io/otel/trace v1.16.0/go.mod h1:Yt9vYq1SdNz3xdjZZK7wcXv1qv2pwLkqr2QVwea0ef0=
|
go.opentelemetry.io/otel/trace v1.16.0/go.mod h1:Yt9vYq1SdNz3xdjZZK7wcXv1qv2pwLkqr2QVwea0ef0=
|
||||||
|
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
|
||||||
|
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
|
||||||
go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
|
go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
|
||||||
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||||
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
|
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
|
||||||
@@ -588,6 +624,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
||||||
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
|
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
||||||
|
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA=
|
golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA=
|
||||||
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
|
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
|
||||||
@@ -659,6 +697,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
|
|||||||
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||||
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
|
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
|
||||||
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||||
|
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
|
||||||
|
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
|
|||||||
340
internal/licensing/license_gate.go
Normal file
340
internal/licensing/license_gate.go
Normal file
@@ -0,0 +1,340 @@
|
|||||||
|
package licensing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sony/gobreaker"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LicenseGate provides burst-proof license validation with caching and circuit breaker
|
||||||
|
type LicenseGate struct {
|
||||||
|
config LicenseConfig
|
||||||
|
cache atomic.Value // stores cachedLease
|
||||||
|
breaker *gobreaker.CircuitBreaker
|
||||||
|
graceUntil atomic.Value // stores time.Time
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// cachedLease represents a cached license lease with expiry
|
||||||
|
type cachedLease struct {
|
||||||
|
LeaseToken string `json:"lease_token"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
|
ClusterID string `json:"cluster_id"`
|
||||||
|
Valid bool `json:"valid"`
|
||||||
|
CachedAt time.Time `json:"cached_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LeaseRequest represents a cluster lease request
|
||||||
|
type LeaseRequest struct {
|
||||||
|
ClusterID string `json:"cluster_id"`
|
||||||
|
RequestedReplicas int `json:"requested_replicas"`
|
||||||
|
DurationMinutes int `json:"duration_minutes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LeaseResponse represents a cluster lease response
|
||||||
|
type LeaseResponse struct {
|
||||||
|
LeaseToken string `json:"lease_token"`
|
||||||
|
MaxReplicas int `json:"max_replicas"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
|
ClusterID string `json:"cluster_id"`
|
||||||
|
LeaseID string `json:"lease_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LeaseValidationRequest represents a lease validation request
|
||||||
|
type LeaseValidationRequest struct {
|
||||||
|
LeaseToken string `json:"lease_token"`
|
||||||
|
ClusterID string `json:"cluster_id"`
|
||||||
|
AgentID string `json:"agent_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LeaseValidationResponse represents a lease validation response
|
||||||
|
type LeaseValidationResponse struct {
|
||||||
|
Valid bool `json:"valid"`
|
||||||
|
RemainingReplicas int `json:"remaining_replicas"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLicenseGate creates a new license gate with circuit breaker and caching
|
||||||
|
func NewLicenseGate(config LicenseConfig) *LicenseGate {
|
||||||
|
// Circuit breaker settings optimized for license validation
|
||||||
|
breakerSettings := gobreaker.Settings{
|
||||||
|
Name: "license-validation",
|
||||||
|
MaxRequests: 3, // Allow 3 requests in half-open state
|
||||||
|
Interval: 60 * time.Second, // Reset failure count every minute
|
||||||
|
Timeout: 30 * time.Second, // Stay open for 30 seconds
|
||||||
|
ReadyToTrip: func(counts gobreaker.Counts) bool {
|
||||||
|
// Trip after 3 consecutive failures
|
||||||
|
return counts.ConsecutiveFailures >= 3
|
||||||
|
},
|
||||||
|
OnStateChange: func(name string, from gobreaker.State, to gobreaker.State) {
|
||||||
|
fmt.Printf("🔌 License validation circuit breaker: %s -> %s\n", from, to)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
gate := &LicenseGate{
|
||||||
|
config: config,
|
||||||
|
breaker: gobreaker.NewCircuitBreaker(breakerSettings),
|
||||||
|
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize grace period
|
||||||
|
gate.graceUntil.Store(time.Now().Add(90 * time.Second))
|
||||||
|
|
||||||
|
return gate
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidNow checks if the cached lease is currently valid
|
||||||
|
func (c *cachedLease) ValidNow() bool {
|
||||||
|
if !c.Valid {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Consider lease invalid 2 minutes before actual expiry for safety margin
|
||||||
|
return time.Now().Before(c.ExpiresAt.Add(-2 * time.Minute))
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadCachedLease safely loads the cached lease
|
||||||
|
func (g *LicenseGate) loadCachedLease() *cachedLease {
|
||||||
|
if cached := g.cache.Load(); cached != nil {
|
||||||
|
if lease, ok := cached.(*cachedLease); ok {
|
||||||
|
return lease
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &cachedLease{Valid: false}
|
||||||
|
}
|
||||||
|
|
||||||
|
// storeLease safely stores a lease in the cache
|
||||||
|
func (g *LicenseGate) storeLease(lease *cachedLease) {
|
||||||
|
lease.CachedAt = time.Now()
|
||||||
|
g.cache.Store(lease)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isInGracePeriod checks if we're still in the grace period
|
||||||
|
func (g *LicenseGate) isInGracePeriod() bool {
|
||||||
|
if graceUntil := g.graceUntil.Load(); graceUntil != nil {
|
||||||
|
if grace, ok := graceUntil.(time.Time); ok {
|
||||||
|
return time.Now().Before(grace)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// extendGracePeriod extends the grace period on successful validation
|
||||||
|
func (g *LicenseGate) extendGracePeriod() {
|
||||||
|
g.graceUntil.Store(time.Now().Add(90 * time.Second))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates the license using cache, lease system, and circuit breaker
|
||||||
|
func (g *LicenseGate) Validate(ctx context.Context, agentID string) error {
|
||||||
|
// Check cached lease first
|
||||||
|
if lease := g.loadCachedLease(); lease.ValidNow() {
|
||||||
|
return g.validateCachedLease(ctx, lease, agentID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to get/renew lease through circuit breaker
|
||||||
|
_, err := g.breaker.Execute(func() (interface{}, error) {
|
||||||
|
lease, err := g.requestOrRenewLease(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the new lease
|
||||||
|
if err := g.validateLease(ctx, lease, agentID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store successful lease
|
||||||
|
g.storeLease(&cachedLease{
|
||||||
|
LeaseToken: lease.LeaseToken,
|
||||||
|
ExpiresAt: lease.ExpiresAt,
|
||||||
|
ClusterID: lease.ClusterID,
|
||||||
|
Valid: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
// If we're in grace period, allow startup but log warning
|
||||||
|
if g.isInGracePeriod() {
|
||||||
|
fmt.Printf("⚠️ License validation failed but in grace period: %v\n", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("license validation failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extend grace period on successful validation
|
||||||
|
g.extendGracePeriod()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateCachedLease validates using cached lease token
|
||||||
|
func (g *LicenseGate) validateCachedLease(ctx context.Context, lease *cachedLease, agentID string) error {
|
||||||
|
validation := LeaseValidationRequest{
|
||||||
|
LeaseToken: lease.LeaseToken,
|
||||||
|
ClusterID: g.config.ClusterID,
|
||||||
|
AgentID: agentID,
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s/api/v1/licenses/validate-lease", strings.TrimSuffix(g.config.KachingURL, "/"))
|
||||||
|
|
||||||
|
reqBody, err := json.Marshal(validation)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal lease validation request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(reqBody)))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create lease validation request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := g.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("lease validation request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
// If validation fails, invalidate cache
|
||||||
|
lease.Valid = false
|
||||||
|
g.storeLease(lease)
|
||||||
|
return fmt.Errorf("lease validation failed with status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var validationResp LeaseValidationResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&validationResp); err != nil {
|
||||||
|
return fmt.Errorf("failed to decode lease validation response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !validationResp.Valid {
|
||||||
|
// If validation fails, invalidate cache
|
||||||
|
lease.Valid = false
|
||||||
|
g.storeLease(lease)
|
||||||
|
return fmt.Errorf("lease token is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// requestOrRenewLease requests a new cluster lease or renews existing one
|
||||||
|
func (g *LicenseGate) requestOrRenewLease(ctx context.Context) (*LeaseResponse, error) {
|
||||||
|
// For now, request a new lease (TODO: implement renewal logic)
|
||||||
|
leaseReq := LeaseRequest{
|
||||||
|
ClusterID: g.config.ClusterID,
|
||||||
|
RequestedReplicas: 1, // Start with single replica
|
||||||
|
DurationMinutes: 60, // 1 hour lease
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s/api/v1/licenses/%s/cluster-lease",
|
||||||
|
strings.TrimSuffix(g.config.KachingURL, "/"), g.config.LicenseID)
|
||||||
|
|
||||||
|
reqBody, err := json.Marshal(leaseReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal lease request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(reqBody)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create lease request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := g.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("lease request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusTooManyRequests {
|
||||||
|
return nil, fmt.Errorf("rate limited by KACHING, retry after: %s", resp.Header.Get("Retry-After"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("lease request failed with status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var leaseResp LeaseResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&leaseResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode lease response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &leaseResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateLease validates a lease token
|
||||||
|
func (g *LicenseGate) validateLease(ctx context.Context, lease *LeaseResponse, agentID string) error {
|
||||||
|
validation := LeaseValidationRequest{
|
||||||
|
LeaseToken: lease.LeaseToken,
|
||||||
|
ClusterID: lease.ClusterID,
|
||||||
|
AgentID: agentID,
|
||||||
|
}
|
||||||
|
|
||||||
|
return g.validateLeaseRequest(ctx, validation)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateLeaseRequest performs the actual lease validation HTTP request
|
||||||
|
func (g *LicenseGate) validateLeaseRequest(ctx context.Context, validation LeaseValidationRequest) error {
|
||||||
|
url := fmt.Sprintf("%s/api/v1/licenses/validate-lease", strings.TrimSuffix(g.config.KachingURL, "/"))
|
||||||
|
|
||||||
|
reqBody, err := json.Marshal(validation)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal lease validation request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(reqBody)))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create lease validation request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := g.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("lease validation request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("lease validation failed with status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var validationResp LeaseValidationResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&validationResp); err != nil {
|
||||||
|
return fmt.Errorf("failed to decode lease validation response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !validationResp.Valid {
|
||||||
|
return fmt.Errorf("lease token is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCacheStats returns cache statistics for monitoring
|
||||||
|
func (g *LicenseGate) GetCacheStats() map[string]interface{} {
|
||||||
|
lease := g.loadCachedLease()
|
||||||
|
stats := map[string]interface{}{
|
||||||
|
"cache_valid": lease.Valid,
|
||||||
|
"cache_hit": lease.ValidNow(),
|
||||||
|
"expires_at": lease.ExpiresAt,
|
||||||
|
"cached_at": lease.CachedAt,
|
||||||
|
"in_grace_period": g.isInGracePeriod(),
|
||||||
|
"breaker_state": g.breaker.State().String(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if grace := g.graceUntil.Load(); grace != nil {
|
||||||
|
if graceTime, ok := grace.(time.Time); ok {
|
||||||
|
stats["grace_until"] = graceTime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package licensing
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -21,35 +22,60 @@ type LicenseConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Validator handles license validation with KACHING
|
// Validator handles license validation with KACHING
|
||||||
|
// Enhanced with license gate for burst-proof validation
|
||||||
type Validator struct {
|
type Validator struct {
|
||||||
config LicenseConfig
|
config LicenseConfig
|
||||||
kachingURL string
|
kachingURL string
|
||||||
client *http.Client
|
client *http.Client
|
||||||
|
gate *LicenseGate // New: License gate for scaling support
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewValidator creates a new license validator
|
// NewValidator creates a new license validator with enhanced scaling support
|
||||||
func NewValidator(config LicenseConfig) *Validator {
|
func NewValidator(config LicenseConfig) *Validator {
|
||||||
kachingURL := config.KachingURL
|
kachingURL := config.KachingURL
|
||||||
if kachingURL == "" {
|
if kachingURL == "" {
|
||||||
kachingURL = DefaultKachingURL
|
kachingURL = DefaultKachingURL
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Validator{
|
validator := &Validator{
|
||||||
config: config,
|
config: config,
|
||||||
kachingURL: kachingURL,
|
kachingURL: kachingURL,
|
||||||
client: &http.Client{
|
client: &http.Client{
|
||||||
Timeout: LicenseTimeout,
|
Timeout: LicenseTimeout,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize license gate for scaling support
|
||||||
|
validator.gate = NewLicenseGate(config)
|
||||||
|
|
||||||
|
return validator
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate performs license validation with KACHING license authority
|
// Validate performs license validation with KACHING license authority
|
||||||
// CRITICAL: CHORUS will not start without valid license validation
|
// Enhanced with caching, circuit breaker, and lease token support
|
||||||
func (v *Validator) Validate() error {
|
func (v *Validator) Validate() error {
|
||||||
|
return v.ValidateWithContext(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateWithContext performs license validation with context and agent ID
|
||||||
|
func (v *Validator) ValidateWithContext(ctx context.Context) error {
|
||||||
if v.config.LicenseID == "" || v.config.ClusterID == "" {
|
if v.config.LicenseID == "" || v.config.ClusterID == "" {
|
||||||
return fmt.Errorf("license ID and cluster ID are required")
|
return fmt.Errorf("license ID and cluster ID are required")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Use enhanced license gate for validation
|
||||||
|
agentID := "default-agent" // TODO: Get from config/environment
|
||||||
|
if err := v.gate.Validate(ctx, agentID); err != nil {
|
||||||
|
// Fallback to legacy validation for backward compatibility
|
||||||
|
fmt.Printf("⚠️ License gate validation failed, trying legacy validation: %v\n", err)
|
||||||
|
return v.validateLegacy()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateLegacy performs the original license validation (for fallback)
|
||||||
|
func (v *Validator) validateLegacy() error {
|
||||||
// Prepare validation request
|
// Prepare validation request
|
||||||
request := map[string]interface{}{
|
request := map[string]interface{}{
|
||||||
"license_id": v.config.LicenseID,
|
"license_id": v.config.LicenseID,
|
||||||
@@ -66,7 +92,7 @@ func (v *Validator) Validate() error {
|
|||||||
return fmt.Errorf("failed to marshal license request: %w", err)
|
return fmt.Errorf("failed to marshal license request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call KACHING license authority
|
// Call KACHING license authority
|
||||||
licenseURL := fmt.Sprintf("%s/v1/license/activate", v.kachingURL)
|
licenseURL := fmt.Sprintf("%s/v1/license/activate", v.kachingURL)
|
||||||
resp, err := v.client.Post(licenseURL, "application/json", bytes.NewReader(requestBody))
|
resp, err := v.client.Post(licenseURL, "application/json", bytes.NewReader(requestBody))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -33,9 +33,12 @@ import (
|
|||||||
"github.com/multiformats/go-multiaddr"
|
"github.com/multiformats/go-multiaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
// Build information - set by main package
|
||||||
AppName = "CHORUS"
|
var (
|
||||||
AppVersion = "0.1.0-dev"
|
AppName = "CHORUS"
|
||||||
|
AppVersion = "0.1.0-dev"
|
||||||
|
AppCommitHash = "unknown"
|
||||||
|
AppBuildDate = "unknown"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SimpleLogger provides basic logging implementation
|
// SimpleLogger provides basic logging implementation
|
||||||
@@ -105,6 +108,7 @@ func (t *SimpleTaskTracker) publishTaskCompletion(taskID string, success bool, s
|
|||||||
// SharedRuntime contains all the shared P2P infrastructure components
|
// SharedRuntime contains all the shared P2P infrastructure components
|
||||||
type SharedRuntime struct {
|
type SharedRuntime struct {
|
||||||
Config *config.Config
|
Config *config.Config
|
||||||
|
RuntimeConfig *config.RuntimeConfig
|
||||||
Logger *SimpleLogger
|
Logger *SimpleLogger
|
||||||
Context context.Context
|
Context context.Context
|
||||||
Cancel context.CancelFunc
|
Cancel context.CancelFunc
|
||||||
@@ -137,7 +141,7 @@ func Initialize(appMode string) (*SharedRuntime, error) {
|
|||||||
runtime.Context = ctx
|
runtime.Context = ctx
|
||||||
runtime.Cancel = cancel
|
runtime.Cancel = cancel
|
||||||
|
|
||||||
runtime.Logger.Info("🎭 Starting CHORUS v%s - Container-First P2P Task Coordination", AppVersion)
|
runtime.Logger.Info("🎭 Starting CHORUS v%s (build: %s, %s) - Container-First P2P Task Coordination", AppVersion, AppCommitHash, AppBuildDate)
|
||||||
runtime.Logger.Info("📦 Container deployment - Mode: %s", appMode)
|
runtime.Logger.Info("📦 Container deployment - Mode: %s", appMode)
|
||||||
|
|
||||||
// Load configuration from environment (no config files in containers)
|
// Load configuration from environment (no config files in containers)
|
||||||
@@ -149,6 +153,28 @@ func Initialize(appMode string) (*SharedRuntime, error) {
|
|||||||
runtime.Config = cfg
|
runtime.Config = cfg
|
||||||
|
|
||||||
runtime.Logger.Info("✅ Configuration loaded successfully")
|
runtime.Logger.Info("✅ Configuration loaded successfully")
|
||||||
|
|
||||||
|
// Initialize runtime configuration with assignment support
|
||||||
|
runtime.RuntimeConfig = config.NewRuntimeConfig(cfg)
|
||||||
|
|
||||||
|
// Load assignment if ASSIGN_URL is configured
|
||||||
|
if assignURL := os.Getenv("ASSIGN_URL"); assignURL != "" {
|
||||||
|
runtime.Logger.Info("📡 Loading assignment from WHOOSH: %s", assignURL)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(runtime.Context, 10*time.Second)
|
||||||
|
if err := runtime.RuntimeConfig.LoadAssignment(ctx, assignURL); err != nil {
|
||||||
|
runtime.Logger.Warn("⚠️ Failed to load assignment (continuing with base config): %v", err)
|
||||||
|
} else {
|
||||||
|
runtime.Logger.Info("✅ Assignment loaded successfully")
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
// Start reload handler for SIGHUP
|
||||||
|
runtime.RuntimeConfig.StartReloadHandler(runtime.Context, assignURL)
|
||||||
|
runtime.Logger.Info("📡 SIGHUP reload handler started for assignment updates")
|
||||||
|
} else {
|
||||||
|
runtime.Logger.Info("⚪ No ASSIGN_URL configured, using static configuration")
|
||||||
|
}
|
||||||
runtime.Logger.Info("🤖 Agent ID: %s", cfg.Agent.ID)
|
runtime.Logger.Info("🤖 Agent ID: %s", cfg.Agent.ID)
|
||||||
runtime.Logger.Info("🎯 Specialization: %s", cfg.Agent.Specialization)
|
runtime.Logger.Info("🎯 Specialization: %s", cfg.Agent.Specialization)
|
||||||
|
|
||||||
@@ -283,6 +309,7 @@ func (r *SharedRuntime) Cleanup() {
|
|||||||
|
|
||||||
if r.MDNSDiscovery != nil {
|
if r.MDNSDiscovery != nil {
|
||||||
r.MDNSDiscovery.Close()
|
r.MDNSDiscovery.Close()
|
||||||
|
r.Logger.Info("🔍 mDNS discovery closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.PubSub != nil {
|
if r.PubSub != nil {
|
||||||
@@ -407,8 +434,20 @@ func (r *SharedRuntime) initializeDHTStorage() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect to bootstrap peers if configured
|
// Connect to bootstrap peers (with assignment override support)
|
||||||
for _, addrStr := range r.Config.V2.DHT.BootstrapPeers {
|
bootstrapPeers := r.RuntimeConfig.GetBootstrapPeers()
|
||||||
|
if len(bootstrapPeers) == 0 {
|
||||||
|
bootstrapPeers = r.Config.V2.DHT.BootstrapPeers
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply join stagger if configured
|
||||||
|
joinStagger := r.RuntimeConfig.GetJoinStagger()
|
||||||
|
if joinStagger > 0 {
|
||||||
|
r.Logger.Info("⏱️ Applying join stagger delay: %v", joinStagger)
|
||||||
|
time.Sleep(joinStagger)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, addrStr := range bootstrapPeers {
|
||||||
addr, err := multiaddr.NewMultiaddr(addrStr)
|
addr, err := multiaddr.NewMultiaddr(addrStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.Logger.Warn("⚠️ Invalid bootstrap address %s: %v", addrStr, err)
|
r.Logger.Warn("⚠️ Invalid bootstrap address %s: %v", addrStr, err)
|
||||||
|
|||||||
@@ -9,25 +9,31 @@ type Config struct {
|
|||||||
// Network configuration
|
// Network configuration
|
||||||
ListenAddresses []string
|
ListenAddresses []string
|
||||||
NetworkID string
|
NetworkID string
|
||||||
|
|
||||||
// Discovery configuration
|
// Discovery configuration
|
||||||
EnableMDNS bool
|
EnableMDNS bool
|
||||||
MDNSServiceTag string
|
MDNSServiceTag string
|
||||||
|
|
||||||
// DHT configuration
|
// DHT configuration
|
||||||
EnableDHT bool
|
EnableDHT bool
|
||||||
DHTBootstrapPeers []string
|
DHTBootstrapPeers []string
|
||||||
DHTMode string // "client", "server", "auto"
|
DHTMode string // "client", "server", "auto"
|
||||||
DHTProtocolPrefix string
|
DHTProtocolPrefix string
|
||||||
|
|
||||||
// Connection limits
|
// Connection limits and rate limiting
|
||||||
MaxConnections int
|
MaxConnections int
|
||||||
MaxPeersPerIP int
|
MaxPeersPerIP int
|
||||||
ConnectionTimeout time.Duration
|
ConnectionTimeout time.Duration
|
||||||
|
LowWatermark int // Connection manager low watermark
|
||||||
|
HighWatermark int // Connection manager high watermark
|
||||||
|
DialsPerSecond int // Dial rate limiting
|
||||||
|
MaxConcurrentDials int // Maximum concurrent outbound dials
|
||||||
|
MaxConcurrentDHT int // Maximum concurrent DHT queries
|
||||||
|
JoinStaggerMS int // Join stagger delay in milliseconds
|
||||||
|
|
||||||
// Security configuration
|
// Security configuration
|
||||||
EnableSecurity bool
|
EnableSecurity bool
|
||||||
|
|
||||||
// Pubsub configuration
|
// Pubsub configuration
|
||||||
EnablePubsub bool
|
EnablePubsub bool
|
||||||
BzzzTopic string // Task coordination topic
|
BzzzTopic string // Task coordination topic
|
||||||
@@ -47,25 +53,31 @@ func DefaultConfig() *Config {
|
|||||||
"/ip6/::/tcp/3333",
|
"/ip6/::/tcp/3333",
|
||||||
},
|
},
|
||||||
NetworkID: "CHORUS-network",
|
NetworkID: "CHORUS-network",
|
||||||
|
|
||||||
// Discovery settings
|
// Discovery settings - mDNS disabled for Swarm by default
|
||||||
EnableMDNS: true,
|
EnableMDNS: false, // Disabled for container environments
|
||||||
MDNSServiceTag: "CHORUS-peer-discovery",
|
MDNSServiceTag: "CHORUS-peer-discovery",
|
||||||
|
|
||||||
// DHT settings (disabled by default for local development)
|
// DHT settings (disabled by default for local development)
|
||||||
EnableDHT: false,
|
EnableDHT: false,
|
||||||
DHTBootstrapPeers: []string{},
|
DHTBootstrapPeers: []string{},
|
||||||
DHTMode: "auto",
|
DHTMode: "auto",
|
||||||
DHTProtocolPrefix: "/CHORUS",
|
DHTProtocolPrefix: "/CHORUS",
|
||||||
|
|
||||||
// Connection limits for local network
|
// Connection limits and rate limiting for scaling
|
||||||
MaxConnections: 50,
|
MaxConnections: 50,
|
||||||
MaxPeersPerIP: 3,
|
MaxPeersPerIP: 3,
|
||||||
ConnectionTimeout: 30 * time.Second,
|
ConnectionTimeout: 30 * time.Second,
|
||||||
|
LowWatermark: 32, // Keep at least 32 connections
|
||||||
|
HighWatermark: 128, // Trim above 128 connections
|
||||||
|
DialsPerSecond: 5, // Limit outbound dials to prevent storms
|
||||||
|
MaxConcurrentDials: 10, // Maximum concurrent outbound dials
|
||||||
|
MaxConcurrentDHT: 16, // Maximum concurrent DHT queries
|
||||||
|
JoinStaggerMS: 0, // No stagger by default (set by assignment)
|
||||||
|
|
||||||
// Security enabled by default
|
// Security enabled by default
|
||||||
EnableSecurity: true,
|
EnableSecurity: true,
|
||||||
|
|
||||||
// Pubsub for coordination and meta-discussion
|
// Pubsub for coordination and meta-discussion
|
||||||
EnablePubsub: true,
|
EnablePubsub: true,
|
||||||
BzzzTopic: "CHORUS/coordination/v1",
|
BzzzTopic: "CHORUS/coordination/v1",
|
||||||
@@ -164,4 +176,34 @@ func WithDHTProtocolPrefix(prefix string) Option {
|
|||||||
return func(c *Config) {
|
return func(c *Config) {
|
||||||
c.DHTProtocolPrefix = prefix
|
c.DHTProtocolPrefix = prefix
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithConnectionManager sets connection manager watermarks
|
||||||
|
func WithConnectionManager(low, high int) Option {
|
||||||
|
return func(c *Config) {
|
||||||
|
c.LowWatermark = low
|
||||||
|
c.HighWatermark = high
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDialRateLimit sets the dial rate limiting
|
||||||
|
func WithDialRateLimit(dialsPerSecond, maxConcurrent int) Option {
|
||||||
|
return func(c *Config) {
|
||||||
|
c.DialsPerSecond = dialsPerSecond
|
||||||
|
c.MaxConcurrentDials = maxConcurrent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDHTRateLimit sets the DHT query rate limiting
|
||||||
|
func WithDHTRateLimit(maxConcurrentDHT int) Option {
|
||||||
|
return func(c *Config) {
|
||||||
|
c.MaxConcurrentDHT = maxConcurrentDHT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithJoinStagger sets the join stagger delay in milliseconds
|
||||||
|
func WithJoinStagger(delayMS int) Option {
|
||||||
|
return func(c *Config) {
|
||||||
|
c.JoinStaggerMS = delayMS
|
||||||
|
}
|
||||||
}
|
}
|
||||||
11
p2p/node.go
11
p2p/node.go
@@ -6,16 +6,17 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"chorus/pkg/dht"
|
"chorus/pkg/dht"
|
||||||
|
|
||||||
"github.com/libp2p/go-libp2p"
|
"github.com/libp2p/go-libp2p"
|
||||||
|
kaddht "github.com/libp2p/go-libp2p-kad-dht"
|
||||||
"github.com/libp2p/go-libp2p/core/host"
|
"github.com/libp2p/go-libp2p/core/host"
|
||||||
"github.com/libp2p/go-libp2p/core/peer"
|
"github.com/libp2p/go-libp2p/core/peer"
|
||||||
"github.com/libp2p/go-libp2p/p2p/security/noise"
|
"github.com/libp2p/go-libp2p/p2p/security/noise"
|
||||||
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
|
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
|
||||||
kaddht "github.com/libp2p/go-libp2p-kad-dht"
|
|
||||||
"github.com/multiformats/go-multiaddr"
|
"github.com/multiformats/go-multiaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Node represents a Bzzz P2P node
|
// Node represents a CHORUS P2P node
|
||||||
type Node struct {
|
type Node struct {
|
||||||
host host.Host
|
host host.Host
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
@@ -157,9 +158,9 @@ func (n *Node) startBackgroundTasks() {
|
|||||||
// logConnectionStatus logs the current connection status
|
// logConnectionStatus logs the current connection status
|
||||||
func (n *Node) logConnectionStatus() {
|
func (n *Node) logConnectionStatus() {
|
||||||
peers := n.Peers()
|
peers := n.Peers()
|
||||||
fmt.Printf("🐝 Bzzz Node Status - ID: %s, Connected Peers: %d\n",
|
fmt.Printf("🐝 Bzzz Node Status - ID: %s, Connected Peers: %d\n",
|
||||||
n.ID().ShortString(), len(peers))
|
n.ID().ShortString(), len(peers))
|
||||||
|
|
||||||
if len(peers) > 0 {
|
if len(peers) > 0 {
|
||||||
fmt.Printf(" Connected to: ")
|
fmt.Printf(" Connected to: ")
|
||||||
for i, p := range peers {
|
for i, p := range peers {
|
||||||
@@ -197,4 +198,4 @@ func (n *Node) Close() error {
|
|||||||
}
|
}
|
||||||
n.cancel()
|
n.cancel()
|
||||||
return n.host.Close()
|
return n.host.Close()
|
||||||
}
|
}
|
||||||
|
|||||||
329
pkg/ai/config.go
Normal file
329
pkg/ai/config.go
Normal file
@@ -0,0 +1,329 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ModelConfig represents the complete model configuration loaded from YAML
|
||||||
|
type ModelConfig struct {
|
||||||
|
Providers map[string]ProviderConfig `yaml:"providers" json:"providers"`
|
||||||
|
DefaultProvider string `yaml:"default_provider" json:"default_provider"`
|
||||||
|
FallbackProvider string `yaml:"fallback_provider" json:"fallback_provider"`
|
||||||
|
Roles map[string]RoleConfig `yaml:"roles" json:"roles"`
|
||||||
|
Environments map[string]EnvConfig `yaml:"environments" json:"environments"`
|
||||||
|
ModelPreferences map[string]TaskPreference `yaml:"model_preferences" json:"model_preferences"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnvConfig represents environment-specific configuration overrides
|
||||||
|
type EnvConfig struct {
|
||||||
|
DefaultProvider string `yaml:"default_provider" json:"default_provider"`
|
||||||
|
FallbackProvider string `yaml:"fallback_provider" json:"fallback_provider"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TaskPreference represents preferred models for specific task types
|
||||||
|
type TaskPreference struct {
|
||||||
|
PreferredModels []string `yaml:"preferred_models" json:"preferred_models"`
|
||||||
|
MinContextTokens int `yaml:"min_context_tokens" json:"min_context_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigLoader loads and manages AI provider configurations
|
||||||
|
type ConfigLoader struct {
|
||||||
|
configPath string
|
||||||
|
environment string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConfigLoader creates a new configuration loader
|
||||||
|
func NewConfigLoader(configPath, environment string) *ConfigLoader {
|
||||||
|
return &ConfigLoader{
|
||||||
|
configPath: configPath,
|
||||||
|
environment: environment,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadConfig loads the complete configuration from the YAML file
|
||||||
|
func (c *ConfigLoader) LoadConfig() (*ModelConfig, error) {
|
||||||
|
data, err := os.ReadFile(c.configPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read config file %s: %w", c.configPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expand environment variables in the config
|
||||||
|
configData := c.expandEnvVars(string(data))
|
||||||
|
|
||||||
|
var config ModelConfig
|
||||||
|
if err := yaml.Unmarshal([]byte(configData), &config); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse config file %s: %w", c.configPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply environment-specific overrides
|
||||||
|
if c.environment != "" {
|
||||||
|
c.applyEnvironmentOverrides(&config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the configuration
|
||||||
|
if err := c.validateConfig(&config); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid configuration: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadProviderFactory creates a provider factory from the configuration
|
||||||
|
func (c *ConfigLoader) LoadProviderFactory() (*ProviderFactory, error) {
|
||||||
|
config, err := c.LoadConfig()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
// Register all providers
|
||||||
|
for name, providerConfig := range config.Providers {
|
||||||
|
if err := factory.RegisterProvider(name, providerConfig); err != nil {
|
||||||
|
// Log warning but continue with other providers
|
||||||
|
fmt.Printf("Warning: Failed to register provider %s: %v\n", name, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up role mapping
|
||||||
|
roleMapping := RoleModelMapping{
|
||||||
|
DefaultProvider: config.DefaultProvider,
|
||||||
|
FallbackProvider: config.FallbackProvider,
|
||||||
|
Roles: config.Roles,
|
||||||
|
}
|
||||||
|
factory.SetRoleMapping(roleMapping)
|
||||||
|
|
||||||
|
return factory, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// expandEnvVars expands environment variables in the configuration
|
||||||
|
func (c *ConfigLoader) expandEnvVars(config string) string {
|
||||||
|
// Replace ${VAR} and $VAR patterns with environment variable values
|
||||||
|
expanded := config
|
||||||
|
|
||||||
|
// Handle ${VAR} pattern
|
||||||
|
for {
|
||||||
|
start := strings.Index(expanded, "${")
|
||||||
|
if start == -1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
end := strings.Index(expanded[start:], "}")
|
||||||
|
if end == -1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
end += start
|
||||||
|
|
||||||
|
varName := expanded[start+2 : end]
|
||||||
|
varValue := os.Getenv(varName)
|
||||||
|
expanded = expanded[:start] + varValue + expanded[end+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
return expanded
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyEnvironmentOverrides applies environment-specific configuration overrides
|
||||||
|
func (c *ConfigLoader) applyEnvironmentOverrides(config *ModelConfig) {
|
||||||
|
envConfig, exists := config.Environments[c.environment]
|
||||||
|
if !exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override default and fallback providers if specified
|
||||||
|
if envConfig.DefaultProvider != "" {
|
||||||
|
config.DefaultProvider = envConfig.DefaultProvider
|
||||||
|
}
|
||||||
|
if envConfig.FallbackProvider != "" {
|
||||||
|
config.FallbackProvider = envConfig.FallbackProvider
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateConfig validates the loaded configuration
|
||||||
|
func (c *ConfigLoader) validateConfig(config *ModelConfig) error {
|
||||||
|
// Check that default provider exists
|
||||||
|
if config.DefaultProvider != "" {
|
||||||
|
if _, exists := config.Providers[config.DefaultProvider]; !exists {
|
||||||
|
return fmt.Errorf("default_provider '%s' not found in providers", config.DefaultProvider)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that fallback provider exists
|
||||||
|
if config.FallbackProvider != "" {
|
||||||
|
if _, exists := config.Providers[config.FallbackProvider]; !exists {
|
||||||
|
return fmt.Errorf("fallback_provider '%s' not found in providers", config.FallbackProvider)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate each provider configuration
|
||||||
|
for name, providerConfig := range config.Providers {
|
||||||
|
if err := c.validateProviderConfig(name, providerConfig); err != nil {
|
||||||
|
return fmt.Errorf("invalid provider config '%s': %w", name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate role configurations
|
||||||
|
for roleName, roleConfig := range config.Roles {
|
||||||
|
if err := c.validateRoleConfig(roleName, roleConfig, config.Providers); err != nil {
|
||||||
|
return fmt.Errorf("invalid role config '%s': %w", roleName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateProviderConfig validates a single provider configuration
|
||||||
|
func (c *ConfigLoader) validateProviderConfig(name string, config ProviderConfig) error {
|
||||||
|
// Check required fields
|
||||||
|
if config.Type == "" {
|
||||||
|
return fmt.Errorf("type is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate provider type
|
||||||
|
validTypes := []string{"ollama", "openai", "resetdata"}
|
||||||
|
typeValid := false
|
||||||
|
for _, validType := range validTypes {
|
||||||
|
if config.Type == validType {
|
||||||
|
typeValid = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !typeValid {
|
||||||
|
return fmt.Errorf("invalid provider type '%s', must be one of: %s",
|
||||||
|
config.Type, strings.Join(validTypes, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check endpoint for all types
|
||||||
|
if config.Endpoint == "" {
|
||||||
|
return fmt.Errorf("endpoint is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check API key for providers that require it
|
||||||
|
if (config.Type == "openai" || config.Type == "resetdata") && config.APIKey == "" {
|
||||||
|
return fmt.Errorf("api_key is required for %s provider", config.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check default model
|
||||||
|
if config.DefaultModel == "" {
|
||||||
|
return fmt.Errorf("default_model is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate timeout
|
||||||
|
if config.Timeout == 0 {
|
||||||
|
config.Timeout = 300 * time.Second // Set default
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate temperature range
|
||||||
|
if config.Temperature < 0 || config.Temperature > 2.0 {
|
||||||
|
return fmt.Errorf("temperature must be between 0 and 2.0")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate max tokens
|
||||||
|
if config.MaxTokens <= 0 {
|
||||||
|
config.MaxTokens = 4096 // Set default
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateRoleConfig validates a role configuration
|
||||||
|
func (c *ConfigLoader) validateRoleConfig(roleName string, config RoleConfig, providers map[string]ProviderConfig) error {
|
||||||
|
// Check that provider exists
|
||||||
|
if config.Provider != "" {
|
||||||
|
if _, exists := providers[config.Provider]; !exists {
|
||||||
|
return fmt.Errorf("provider '%s' not found", config.Provider)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check fallback provider exists if specified
|
||||||
|
if config.FallbackProvider != "" {
|
||||||
|
if _, exists := providers[config.FallbackProvider]; !exists {
|
||||||
|
return fmt.Errorf("fallback_provider '%s' not found", config.FallbackProvider)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate temperature range
|
||||||
|
if config.Temperature < 0 || config.Temperature > 2.0 {
|
||||||
|
return fmt.Errorf("temperature must be between 0 and 2.0")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate max tokens
|
||||||
|
if config.MaxTokens < 0 {
|
||||||
|
return fmt.Errorf("max_tokens cannot be negative")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProviderForTaskType returns the best provider for a specific task type
|
||||||
|
func (c *ConfigLoader) GetProviderForTaskType(config *ModelConfig, factory *ProviderFactory, taskType string) (ModelProvider, ProviderConfig, error) {
|
||||||
|
// Check if we have preferences for this task type
|
||||||
|
if preference, exists := config.ModelPreferences[taskType]; exists {
|
||||||
|
// Try each preferred model in order
|
||||||
|
for _, modelName := range preference.PreferredModels {
|
||||||
|
for providerName, provider := range factory.providers {
|
||||||
|
capabilities := provider.GetCapabilities()
|
||||||
|
for _, supportedModel := range capabilities.SupportedModels {
|
||||||
|
if supportedModel == modelName && factory.isProviderHealthy(providerName) {
|
||||||
|
providerConfig := factory.configs[providerName]
|
||||||
|
providerConfig.DefaultModel = modelName
|
||||||
|
|
||||||
|
// Ensure minimum context if specified
|
||||||
|
if preference.MinContextTokens > providerConfig.MaxTokens {
|
||||||
|
providerConfig.MaxTokens = preference.MinContextTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
return provider, providerConfig, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to default provider selection
|
||||||
|
if config.DefaultProvider != "" {
|
||||||
|
provider, err := factory.GetProvider(config.DefaultProvider)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ProviderConfig{}, err
|
||||||
|
}
|
||||||
|
return provider, factory.configs[config.DefaultProvider], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, ProviderConfig{}, NewProviderError(ErrProviderNotFound, "no suitable provider found for task type "+taskType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultConfigPath returns the default path for the model configuration file
|
||||||
|
func DefaultConfigPath() string {
|
||||||
|
// Try environment variable first
|
||||||
|
if path := os.Getenv("CHORUS_MODEL_CONFIG"); path != "" {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try relative to current working directory
|
||||||
|
if _, err := os.Stat("configs/models.yaml"); err == nil {
|
||||||
|
return "configs/models.yaml"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try relative to executable
|
||||||
|
if _, err := os.Stat("./configs/models.yaml"); err == nil {
|
||||||
|
return "./configs/models.yaml"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default fallback
|
||||||
|
return "configs/models.yaml"
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEnvironment returns the current environment (from env var or default)
|
||||||
|
func GetEnvironment() string {
|
||||||
|
if env := os.Getenv("CHORUS_ENVIRONMENT"); env != "" {
|
||||||
|
return env
|
||||||
|
}
|
||||||
|
if env := os.Getenv("NODE_ENV"); env != "" {
|
||||||
|
return env
|
||||||
|
}
|
||||||
|
return "development" // default
|
||||||
|
}
|
||||||
596
pkg/ai/config_test.go
Normal file
596
pkg/ai/config_test.go
Normal file
@@ -0,0 +1,596 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewConfigLoader(t *testing.T) {
|
||||||
|
loader := NewConfigLoader("test.yaml", "development")
|
||||||
|
|
||||||
|
assert.Equal(t, "test.yaml", loader.configPath)
|
||||||
|
assert.Equal(t, "development", loader.environment)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigLoaderExpandEnvVars(t *testing.T) {
|
||||||
|
loader := NewConfigLoader("", "")
|
||||||
|
|
||||||
|
// Set test environment variables
|
||||||
|
os.Setenv("TEST_VAR", "test_value")
|
||||||
|
os.Setenv("ANOTHER_VAR", "another_value")
|
||||||
|
defer func() {
|
||||||
|
os.Unsetenv("TEST_VAR")
|
||||||
|
os.Unsetenv("ANOTHER_VAR")
|
||||||
|
}()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single variable",
|
||||||
|
input: "endpoint: ${TEST_VAR}",
|
||||||
|
expected: "endpoint: test_value",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple variables",
|
||||||
|
input: "endpoint: ${TEST_VAR}/api\nkey: ${ANOTHER_VAR}",
|
||||||
|
expected: "endpoint: test_value/api\nkey: another_value",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no variables",
|
||||||
|
input: "endpoint: http://localhost",
|
||||||
|
expected: "endpoint: http://localhost",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "undefined variable",
|
||||||
|
input: "endpoint: ${UNDEFINED_VAR}",
|
||||||
|
expected: "endpoint: ",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := loader.expandEnvVars(tt.input)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigLoaderApplyEnvironmentOverrides(t *testing.T) {
|
||||||
|
loader := NewConfigLoader("", "production")
|
||||||
|
|
||||||
|
config := &ModelConfig{
|
||||||
|
DefaultProvider: "ollama",
|
||||||
|
FallbackProvider: "resetdata",
|
||||||
|
Environments: map[string]EnvConfig{
|
||||||
|
"production": {
|
||||||
|
DefaultProvider: "openai",
|
||||||
|
FallbackProvider: "ollama",
|
||||||
|
},
|
||||||
|
"development": {
|
||||||
|
DefaultProvider: "ollama",
|
||||||
|
FallbackProvider: "mock",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
loader.applyEnvironmentOverrides(config)
|
||||||
|
|
||||||
|
assert.Equal(t, "openai", config.DefaultProvider)
|
||||||
|
assert.Equal(t, "ollama", config.FallbackProvider)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigLoaderApplyEnvironmentOverridesNoMatch(t *testing.T) {
|
||||||
|
loader := NewConfigLoader("", "testing")
|
||||||
|
|
||||||
|
config := &ModelConfig{
|
||||||
|
DefaultProvider: "ollama",
|
||||||
|
FallbackProvider: "resetdata",
|
||||||
|
Environments: map[string]EnvConfig{
|
||||||
|
"production": {
|
||||||
|
DefaultProvider: "openai",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
original := *config
|
||||||
|
loader.applyEnvironmentOverrides(config)
|
||||||
|
|
||||||
|
// Should remain unchanged
|
||||||
|
assert.Equal(t, original.DefaultProvider, config.DefaultProvider)
|
||||||
|
assert.Equal(t, original.FallbackProvider, config.FallbackProvider)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigLoaderValidateConfig(t *testing.T) {
|
||||||
|
loader := NewConfigLoader("", "")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config *ModelConfig
|
||||||
|
expectErr bool
|
||||||
|
errMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid config",
|
||||||
|
config: &ModelConfig{
|
||||||
|
DefaultProvider: "test",
|
||||||
|
FallbackProvider: "backup",
|
||||||
|
Providers: map[string]ProviderConfig{
|
||||||
|
"test": {
|
||||||
|
Type: "ollama",
|
||||||
|
Endpoint: "http://localhost:11434",
|
||||||
|
DefaultModel: "llama2",
|
||||||
|
},
|
||||||
|
"backup": {
|
||||||
|
Type: "resetdata",
|
||||||
|
Endpoint: "https://api.resetdata.ai",
|
||||||
|
APIKey: "key",
|
||||||
|
DefaultModel: "llama2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Roles: map[string]RoleConfig{
|
||||||
|
"developer": {
|
||||||
|
Provider: "test",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "default provider not found",
|
||||||
|
config: &ModelConfig{
|
||||||
|
DefaultProvider: "nonexistent",
|
||||||
|
Providers: map[string]ProviderConfig{
|
||||||
|
"test": {
|
||||||
|
Type: "ollama",
|
||||||
|
Endpoint: "http://localhost:11434",
|
||||||
|
DefaultModel: "llama2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
errMsg: "default_provider 'nonexistent' not found",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "fallback provider not found",
|
||||||
|
config: &ModelConfig{
|
||||||
|
FallbackProvider: "nonexistent",
|
||||||
|
Providers: map[string]ProviderConfig{
|
||||||
|
"test": {
|
||||||
|
Type: "ollama",
|
||||||
|
Endpoint: "http://localhost:11434",
|
||||||
|
DefaultModel: "llama2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
errMsg: "fallback_provider 'nonexistent' not found",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid provider config",
|
||||||
|
config: &ModelConfig{
|
||||||
|
Providers: map[string]ProviderConfig{
|
||||||
|
"invalid": {
|
||||||
|
Type: "invalid_type",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
errMsg: "invalid provider config 'invalid'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid role config",
|
||||||
|
config: &ModelConfig{
|
||||||
|
Providers: map[string]ProviderConfig{
|
||||||
|
"test": {
|
||||||
|
Type: "ollama",
|
||||||
|
Endpoint: "http://localhost:11434",
|
||||||
|
DefaultModel: "llama2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Roles: map[string]RoleConfig{
|
||||||
|
"developer": {
|
||||||
|
Provider: "nonexistent",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
errMsg: "invalid role config 'developer'",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := loader.validateConfig(tt.config)
|
||||||
|
|
||||||
|
if tt.expectErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), tt.errMsg)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigLoaderValidateProviderConfig(t *testing.T) {
|
||||||
|
loader := NewConfigLoader("", "")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config ProviderConfig
|
||||||
|
expectErr bool
|
||||||
|
errMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid ollama config",
|
||||||
|
config: ProviderConfig{
|
||||||
|
Type: "ollama",
|
||||||
|
Endpoint: "http://localhost:11434",
|
||||||
|
DefaultModel: "llama2",
|
||||||
|
Temperature: 0.7,
|
||||||
|
MaxTokens: 4096,
|
||||||
|
},
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid openai config",
|
||||||
|
config: ProviderConfig{
|
||||||
|
Type: "openai",
|
||||||
|
Endpoint: "https://api.openai.com/v1",
|
||||||
|
APIKey: "test-key",
|
||||||
|
DefaultModel: "gpt-4",
|
||||||
|
},
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing type",
|
||||||
|
config: ProviderConfig{
|
||||||
|
Endpoint: "http://localhost",
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
errMsg: "type is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid type",
|
||||||
|
config: ProviderConfig{
|
||||||
|
Type: "invalid",
|
||||||
|
Endpoint: "http://localhost",
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
errMsg: "invalid provider type 'invalid'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing endpoint",
|
||||||
|
config: ProviderConfig{
|
||||||
|
Type: "ollama",
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
errMsg: "endpoint is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "openai missing api key",
|
||||||
|
config: ProviderConfig{
|
||||||
|
Type: "openai",
|
||||||
|
Endpoint: "https://api.openai.com/v1",
|
||||||
|
DefaultModel: "gpt-4",
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
errMsg: "api_key is required for openai provider",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing default model",
|
||||||
|
config: ProviderConfig{
|
||||||
|
Type: "ollama",
|
||||||
|
Endpoint: "http://localhost:11434",
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
errMsg: "default_model is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid temperature",
|
||||||
|
config: ProviderConfig{
|
||||||
|
Type: "ollama",
|
||||||
|
Endpoint: "http://localhost:11434",
|
||||||
|
DefaultModel: "llama2",
|
||||||
|
Temperature: 3.0, // Too high
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
errMsg: "temperature must be between 0 and 2.0",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := loader.validateProviderConfig("test", tt.config)
|
||||||
|
|
||||||
|
if tt.expectErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), tt.errMsg)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigLoaderValidateRoleConfig(t *testing.T) {
|
||||||
|
loader := NewConfigLoader("", "")
|
||||||
|
|
||||||
|
providers := map[string]ProviderConfig{
|
||||||
|
"test": {
|
||||||
|
Type: "ollama",
|
||||||
|
},
|
||||||
|
"backup": {
|
||||||
|
Type: "resetdata",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config RoleConfig
|
||||||
|
expectErr bool
|
||||||
|
errMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid role config",
|
||||||
|
config: RoleConfig{
|
||||||
|
Provider: "test",
|
||||||
|
Model: "llama2",
|
||||||
|
Temperature: 0.7,
|
||||||
|
MaxTokens: 4096,
|
||||||
|
},
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "provider not found",
|
||||||
|
config: RoleConfig{
|
||||||
|
Provider: "nonexistent",
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
errMsg: "provider 'nonexistent' not found",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "fallback provider not found",
|
||||||
|
config: RoleConfig{
|
||||||
|
Provider: "test",
|
||||||
|
FallbackProvider: "nonexistent",
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
errMsg: "fallback_provider 'nonexistent' not found",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid temperature",
|
||||||
|
config: RoleConfig{
|
||||||
|
Provider: "test",
|
||||||
|
Temperature: -1.0,
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
errMsg: "temperature must be between 0 and 2.0",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid max tokens",
|
||||||
|
config: RoleConfig{
|
||||||
|
Provider: "test",
|
||||||
|
MaxTokens: -100,
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
errMsg: "max_tokens cannot be negative",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := loader.validateRoleConfig("test-role", tt.config, providers)
|
||||||
|
|
||||||
|
if tt.expectErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), tt.errMsg)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigLoaderLoadConfig(t *testing.T) {
|
||||||
|
// Create a temporary config file
|
||||||
|
configContent := `
|
||||||
|
providers:
|
||||||
|
test:
|
||||||
|
type: ollama
|
||||||
|
endpoint: http://localhost:11434
|
||||||
|
default_model: llama2
|
||||||
|
temperature: 0.7
|
||||||
|
|
||||||
|
default_provider: test
|
||||||
|
fallback_provider: test
|
||||||
|
|
||||||
|
roles:
|
||||||
|
developer:
|
||||||
|
provider: test
|
||||||
|
model: codellama
|
||||||
|
`
|
||||||
|
|
||||||
|
tmpFile, err := ioutil.TempFile("", "test-config-*.yaml")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.Remove(tmpFile.Name())
|
||||||
|
|
||||||
|
_, err = tmpFile.WriteString(configContent)
|
||||||
|
require.NoError(t, err)
|
||||||
|
tmpFile.Close()
|
||||||
|
|
||||||
|
loader := NewConfigLoader(tmpFile.Name(), "")
|
||||||
|
config, err := loader.LoadConfig()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "test", config.DefaultProvider)
|
||||||
|
assert.Equal(t, "test", config.FallbackProvider)
|
||||||
|
assert.Len(t, config.Providers, 1)
|
||||||
|
assert.Contains(t, config.Providers, "test")
|
||||||
|
assert.Equal(t, "ollama", config.Providers["test"].Type)
|
||||||
|
assert.Len(t, config.Roles, 1)
|
||||||
|
assert.Contains(t, config.Roles, "developer")
|
||||||
|
assert.Equal(t, "codellama", config.Roles["developer"].Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigLoaderLoadConfigWithEnvVars(t *testing.T) {
|
||||||
|
// Set environment variables
|
||||||
|
os.Setenv("TEST_ENDPOINT", "http://test.example.com")
|
||||||
|
os.Setenv("TEST_MODEL", "test-model")
|
||||||
|
defer func() {
|
||||||
|
os.Unsetenv("TEST_ENDPOINT")
|
||||||
|
os.Unsetenv("TEST_MODEL")
|
||||||
|
}()
|
||||||
|
|
||||||
|
configContent := `
|
||||||
|
providers:
|
||||||
|
test:
|
||||||
|
type: ollama
|
||||||
|
endpoint: ${TEST_ENDPOINT}
|
||||||
|
default_model: ${TEST_MODEL}
|
||||||
|
|
||||||
|
default_provider: test
|
||||||
|
`
|
||||||
|
|
||||||
|
tmpFile, err := ioutil.TempFile("", "test-config-*.yaml")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.Remove(tmpFile.Name())
|
||||||
|
|
||||||
|
_, err = tmpFile.WriteString(configContent)
|
||||||
|
require.NoError(t, err)
|
||||||
|
tmpFile.Close()
|
||||||
|
|
||||||
|
loader := NewConfigLoader(tmpFile.Name(), "")
|
||||||
|
config, err := loader.LoadConfig()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "http://test.example.com", config.Providers["test"].Endpoint)
|
||||||
|
assert.Equal(t, "test-model", config.Providers["test"].DefaultModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigLoaderLoadConfigFileNotFound(t *testing.T) {
|
||||||
|
loader := NewConfigLoader("nonexistent.yaml", "")
|
||||||
|
_, err := loader.LoadConfig()
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "failed to read config file")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigLoaderLoadConfigInvalidYAML(t *testing.T) {
|
||||||
|
// Create a file with invalid YAML
|
||||||
|
tmpFile, err := ioutil.TempFile("", "invalid-config-*.yaml")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.Remove(tmpFile.Name())
|
||||||
|
|
||||||
|
_, err = tmpFile.WriteString("invalid: yaml: content: [")
|
||||||
|
require.NoError(t, err)
|
||||||
|
tmpFile.Close()
|
||||||
|
|
||||||
|
loader := NewConfigLoader(tmpFile.Name(), "")
|
||||||
|
_, err = loader.LoadConfig()
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "failed to parse config file")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultConfigPath(t *testing.T) {
|
||||||
|
// Test with environment variable
|
||||||
|
os.Setenv("CHORUS_MODEL_CONFIG", "/custom/path/models.yaml")
|
||||||
|
defer os.Unsetenv("CHORUS_MODEL_CONFIG")
|
||||||
|
|
||||||
|
path := DefaultConfigPath()
|
||||||
|
assert.Equal(t, "/custom/path/models.yaml", path)
|
||||||
|
|
||||||
|
// Test without environment variable
|
||||||
|
os.Unsetenv("CHORUS_MODEL_CONFIG")
|
||||||
|
path = DefaultConfigPath()
|
||||||
|
assert.Equal(t, "configs/models.yaml", path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetEnvironment(t *testing.T) {
|
||||||
|
// Test with CHORUS_ENVIRONMENT
|
||||||
|
os.Setenv("CHORUS_ENVIRONMENT", "production")
|
||||||
|
defer os.Unsetenv("CHORUS_ENVIRONMENT")
|
||||||
|
|
||||||
|
env := GetEnvironment()
|
||||||
|
assert.Equal(t, "production", env)
|
||||||
|
|
||||||
|
// Test with NODE_ENV fallback
|
||||||
|
os.Unsetenv("CHORUS_ENVIRONMENT")
|
||||||
|
os.Setenv("NODE_ENV", "staging")
|
||||||
|
defer os.Unsetenv("NODE_ENV")
|
||||||
|
|
||||||
|
env = GetEnvironment()
|
||||||
|
assert.Equal(t, "staging", env)
|
||||||
|
|
||||||
|
// Test default
|
||||||
|
os.Unsetenv("CHORUS_ENVIRONMENT")
|
||||||
|
os.Unsetenv("NODE_ENV")
|
||||||
|
|
||||||
|
env = GetEnvironment()
|
||||||
|
assert.Equal(t, "development", env)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelConfig(t *testing.T) {
|
||||||
|
config := ModelConfig{
|
||||||
|
Providers: map[string]ProviderConfig{
|
||||||
|
"test": {
|
||||||
|
Type: "ollama",
|
||||||
|
Endpoint: "http://localhost:11434",
|
||||||
|
DefaultModel: "llama2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
DefaultProvider: "test",
|
||||||
|
FallbackProvider: "test",
|
||||||
|
Roles: map[string]RoleConfig{
|
||||||
|
"developer": {
|
||||||
|
Provider: "test",
|
||||||
|
Model: "codellama",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Environments: map[string]EnvConfig{
|
||||||
|
"production": {
|
||||||
|
DefaultProvider: "openai",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ModelPreferences: map[string]TaskPreference{
|
||||||
|
"code_generation": {
|
||||||
|
PreferredModels: []string{"codellama", "gpt-4"},
|
||||||
|
MinContextTokens: 8192,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Len(t, config.Providers, 1)
|
||||||
|
assert.Len(t, config.Roles, 1)
|
||||||
|
assert.Len(t, config.Environments, 1)
|
||||||
|
assert.Len(t, config.ModelPreferences, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnvConfig(t *testing.T) {
|
||||||
|
envConfig := EnvConfig{
|
||||||
|
DefaultProvider: "openai",
|
||||||
|
FallbackProvider: "ollama",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, "openai", envConfig.DefaultProvider)
|
||||||
|
assert.Equal(t, "ollama", envConfig.FallbackProvider)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskPreference(t *testing.T) {
|
||||||
|
pref := TaskPreference{
|
||||||
|
PreferredModels: []string{"gpt-4", "codellama:13b"},
|
||||||
|
MinContextTokens: 8192,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Len(t, pref.PreferredModels, 2)
|
||||||
|
assert.Equal(t, 8192, pref.MinContextTokens)
|
||||||
|
assert.Contains(t, pref.PreferredModels, "gpt-4")
|
||||||
|
}
|
||||||
392
pkg/ai/factory.go
Normal file
392
pkg/ai/factory.go
Normal file
@@ -0,0 +1,392 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProviderFactory creates and manages AI model providers
|
||||||
|
type ProviderFactory struct {
|
||||||
|
configs map[string]ProviderConfig // provider name -> config
|
||||||
|
providers map[string]ModelProvider // provider name -> instance
|
||||||
|
roleMapping RoleModelMapping // role-based model selection
|
||||||
|
healthChecks map[string]bool // provider name -> health status
|
||||||
|
lastHealthCheck map[string]time.Time // provider name -> last check time
|
||||||
|
CreateProvider func(config ProviderConfig) (ModelProvider, error) // provider creation function
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProviderFactory creates a new provider factory
|
||||||
|
func NewProviderFactory() *ProviderFactory {
|
||||||
|
factory := &ProviderFactory{
|
||||||
|
configs: make(map[string]ProviderConfig),
|
||||||
|
providers: make(map[string]ModelProvider),
|
||||||
|
healthChecks: make(map[string]bool),
|
||||||
|
lastHealthCheck: make(map[string]time.Time),
|
||||||
|
}
|
||||||
|
factory.CreateProvider = factory.defaultCreateProvider
|
||||||
|
return factory
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterProvider registers a provider configuration
|
||||||
|
func (f *ProviderFactory) RegisterProvider(name string, config ProviderConfig) error {
|
||||||
|
// Validate the configuration
|
||||||
|
provider, err := f.CreateProvider(config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create provider %s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := provider.ValidateConfig(); err != nil {
|
||||||
|
return fmt.Errorf("invalid configuration for provider %s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f.configs[name] = config
|
||||||
|
f.providers[name] = provider
|
||||||
|
f.healthChecks[name] = true
|
||||||
|
f.lastHealthCheck[name] = time.Now()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRoleMapping sets the role-to-model mapping configuration
|
||||||
|
func (f *ProviderFactory) SetRoleMapping(mapping RoleModelMapping) {
|
||||||
|
f.roleMapping = mapping
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProvider returns a provider by name
|
||||||
|
func (f *ProviderFactory) GetProvider(name string) (ModelProvider, error) {
|
||||||
|
provider, exists := f.providers[name]
|
||||||
|
if !exists {
|
||||||
|
return nil, NewProviderError(ErrProviderNotFound, fmt.Sprintf("provider %s not found", name))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check health status
|
||||||
|
if !f.isProviderHealthy(name) {
|
||||||
|
return nil, NewProviderError(ErrProviderUnavailable, fmt.Sprintf("provider %s is unhealthy", name))
|
||||||
|
}
|
||||||
|
|
||||||
|
return provider, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProviderForRole returns the best provider for a specific agent role
|
||||||
|
func (f *ProviderFactory) GetProviderForRole(role string) (ModelProvider, ProviderConfig, error) {
|
||||||
|
// Get role configuration
|
||||||
|
roleConfig, exists := f.roleMapping.Roles[role]
|
||||||
|
if !exists {
|
||||||
|
// Fall back to default provider
|
||||||
|
if f.roleMapping.DefaultProvider != "" {
|
||||||
|
return f.getProviderWithFallback(f.roleMapping.DefaultProvider, f.roleMapping.FallbackProvider)
|
||||||
|
}
|
||||||
|
return nil, ProviderConfig{}, NewProviderError(ErrProviderNotFound, fmt.Sprintf("no provider configured for role %s", role))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try primary provider first
|
||||||
|
provider, config, err := f.getProviderWithFallback(roleConfig.Provider, roleConfig.FallbackProvider)
|
||||||
|
if err != nil {
|
||||||
|
// Try role fallback
|
||||||
|
if roleConfig.FallbackProvider != "" {
|
||||||
|
return f.getProviderWithFallback(roleConfig.FallbackProvider, f.roleMapping.FallbackProvider)
|
||||||
|
}
|
||||||
|
// Try global fallback
|
||||||
|
if f.roleMapping.FallbackProvider != "" {
|
||||||
|
return f.getProviderWithFallback(f.roleMapping.FallbackProvider, "")
|
||||||
|
}
|
||||||
|
return nil, ProviderConfig{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge role-specific configuration
|
||||||
|
mergedConfig := f.mergeRoleConfig(config, roleConfig)
|
||||||
|
return provider, mergedConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProviderForTask returns the best provider for a specific task
|
||||||
|
func (f *ProviderFactory) GetProviderForTask(request *TaskRequest) (ModelProvider, ProviderConfig, error) {
|
||||||
|
// Check if a specific model is requested
|
||||||
|
if request.ModelName != "" {
|
||||||
|
// Find provider that supports the requested model
|
||||||
|
for name, provider := range f.providers {
|
||||||
|
capabilities := provider.GetCapabilities()
|
||||||
|
for _, supportedModel := range capabilities.SupportedModels {
|
||||||
|
if supportedModel == request.ModelName {
|
||||||
|
if f.isProviderHealthy(name) {
|
||||||
|
config := f.configs[name]
|
||||||
|
config.DefaultModel = request.ModelName // Override default model
|
||||||
|
return provider, config, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, ProviderConfig{}, NewProviderError(ErrModelNotSupported, fmt.Sprintf("model %s not available", request.ModelName))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use role-based selection
|
||||||
|
return f.GetProviderForRole(request.AgentRole)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListProviders returns all registered provider names
|
||||||
|
func (f *ProviderFactory) ListProviders() []string {
|
||||||
|
var names []string
|
||||||
|
for name := range f.providers {
|
||||||
|
names = append(names, name)
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListHealthyProviders returns only healthy provider names
|
||||||
|
func (f *ProviderFactory) ListHealthyProviders() []string {
|
||||||
|
var names []string
|
||||||
|
for name := range f.providers {
|
||||||
|
if f.isProviderHealthy(name) {
|
||||||
|
names = append(names, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProviderInfo returns information about all registered providers
|
||||||
|
func (f *ProviderFactory) GetProviderInfo() map[string]ProviderInfo {
|
||||||
|
info := make(map[string]ProviderInfo)
|
||||||
|
for name, provider := range f.providers {
|
||||||
|
providerInfo := provider.GetProviderInfo()
|
||||||
|
providerInfo.Name = name // Override with registered name
|
||||||
|
info[name] = providerInfo
|
||||||
|
}
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
// HealthCheck performs health checks on all providers
|
||||||
|
func (f *ProviderFactory) HealthCheck(ctx context.Context) map[string]error {
|
||||||
|
results := make(map[string]error)
|
||||||
|
|
||||||
|
for name, provider := range f.providers {
|
||||||
|
err := f.checkProviderHealth(ctx, name, provider)
|
||||||
|
results[name] = err
|
||||||
|
f.healthChecks[name] = (err == nil)
|
||||||
|
f.lastHealthCheck[name] = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHealthStatus returns the current health status of all providers
|
||||||
|
func (f *ProviderFactory) GetHealthStatus() map[string]ProviderHealth {
|
||||||
|
status := make(map[string]ProviderHealth)
|
||||||
|
|
||||||
|
for name, provider := range f.providers {
|
||||||
|
status[name] = ProviderHealth{
|
||||||
|
Name: name,
|
||||||
|
Healthy: f.healthChecks[name],
|
||||||
|
LastCheck: f.lastHealthCheck[name],
|
||||||
|
ProviderInfo: provider.GetProviderInfo(),
|
||||||
|
Capabilities: provider.GetCapabilities(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return status
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartHealthCheckRoutine starts a background health check routine
|
||||||
|
func (f *ProviderFactory) StartHealthCheckRoutine(ctx context.Context, interval time.Duration) {
|
||||||
|
if interval == 0 {
|
||||||
|
interval = 5 * time.Minute // Default to 5 minutes
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
go func() {
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
healthCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
f.HealthCheck(healthCtx)
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultCreateProvider creates a provider instance based on configuration
|
||||||
|
func (f *ProviderFactory) defaultCreateProvider(config ProviderConfig) (ModelProvider, error) {
|
||||||
|
switch config.Type {
|
||||||
|
case "ollama":
|
||||||
|
return NewOllamaProvider(config), nil
|
||||||
|
case "openai":
|
||||||
|
return NewOpenAIProvider(config), nil
|
||||||
|
case "resetdata":
|
||||||
|
return NewResetDataProvider(config), nil
|
||||||
|
default:
|
||||||
|
return nil, NewProviderError(ErrProviderNotFound, fmt.Sprintf("unknown provider type: %s", config.Type))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getProviderWithFallback attempts to get a provider with fallback support
|
||||||
|
func (f *ProviderFactory) getProviderWithFallback(primaryName, fallbackName string) (ModelProvider, ProviderConfig, error) {
|
||||||
|
// Try primary provider
|
||||||
|
if primaryName != "" {
|
||||||
|
if provider, exists := f.providers[primaryName]; exists && f.isProviderHealthy(primaryName) {
|
||||||
|
return provider, f.configs[primaryName], nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try fallback provider
|
||||||
|
if fallbackName != "" {
|
||||||
|
if provider, exists := f.providers[fallbackName]; exists && f.isProviderHealthy(fallbackName) {
|
||||||
|
return provider, f.configs[fallbackName], nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if primaryName != "" {
|
||||||
|
return nil, ProviderConfig{}, NewProviderError(ErrProviderUnavailable, fmt.Sprintf("provider %s and fallback %s are unavailable", primaryName, fallbackName))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, ProviderConfig{}, NewProviderError(ErrProviderNotFound, "no provider specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeRoleConfig merges role-specific configuration with provider configuration
|
||||||
|
func (f *ProviderFactory) mergeRoleConfig(baseConfig ProviderConfig, roleConfig RoleConfig) ProviderConfig {
|
||||||
|
merged := baseConfig
|
||||||
|
|
||||||
|
// Override model if specified in role config
|
||||||
|
if roleConfig.Model != "" {
|
||||||
|
merged.DefaultModel = roleConfig.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override temperature if specified
|
||||||
|
if roleConfig.Temperature > 0 {
|
||||||
|
merged.Temperature = roleConfig.Temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override max tokens if specified
|
||||||
|
if roleConfig.MaxTokens > 0 {
|
||||||
|
merged.MaxTokens = roleConfig.MaxTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override tool settings
|
||||||
|
if roleConfig.EnableTools {
|
||||||
|
merged.EnableTools = roleConfig.EnableTools
|
||||||
|
}
|
||||||
|
if roleConfig.EnableMCP {
|
||||||
|
merged.EnableMCP = roleConfig.EnableMCP
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge MCP servers
|
||||||
|
if len(roleConfig.MCPServers) > 0 {
|
||||||
|
merged.MCPServers = append(merged.MCPServers, roleConfig.MCPServers...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return merged
|
||||||
|
}
|
||||||
|
|
||||||
|
// isProviderHealthy checks if a provider is currently healthy
|
||||||
|
func (f *ProviderFactory) isProviderHealthy(name string) bool {
|
||||||
|
healthy, exists := f.healthChecks[name]
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if health check is too old (consider unhealthy if >10 minutes old)
|
||||||
|
lastCheck, exists := f.lastHealthCheck[name]
|
||||||
|
if !exists || time.Since(lastCheck) > 10*time.Minute {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return healthy
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkProviderHealth performs a health check on a specific provider
|
||||||
|
func (f *ProviderFactory) checkProviderHealth(ctx context.Context, name string, provider ModelProvider) error {
|
||||||
|
// Create a minimal health check request
|
||||||
|
healthRequest := &TaskRequest{
|
||||||
|
TaskID: "health-check",
|
||||||
|
AgentID: "health-checker",
|
||||||
|
AgentRole: "system",
|
||||||
|
Repository: "health-check",
|
||||||
|
TaskTitle: "Health Check",
|
||||||
|
TaskDescription: "Simple health check task",
|
||||||
|
ModelName: "", // Use default
|
||||||
|
MaxTokens: 50, // Minimal response
|
||||||
|
EnableTools: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set a short timeout for health checks
|
||||||
|
healthCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, err := provider.ExecuteTask(healthCtx, healthRequest)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProviderHealth represents the health status of a provider
|
||||||
|
type ProviderHealth struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Healthy bool `json:"healthy"`
|
||||||
|
LastCheck time.Time `json:"last_check"`
|
||||||
|
ProviderInfo ProviderInfo `json:"provider_info"`
|
||||||
|
Capabilities ProviderCapabilities `json:"capabilities"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultProviderFactory creates a factory with common provider configurations
|
||||||
|
func DefaultProviderFactory() *ProviderFactory {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
// Register default Ollama provider
|
||||||
|
ollamaConfig := ProviderConfig{
|
||||||
|
Type: "ollama",
|
||||||
|
Endpoint: "http://localhost:11434",
|
||||||
|
DefaultModel: "llama3.1:8b",
|
||||||
|
Temperature: 0.7,
|
||||||
|
MaxTokens: 4096,
|
||||||
|
Timeout: 300 * time.Second,
|
||||||
|
RetryAttempts: 3,
|
||||||
|
RetryDelay: 2 * time.Second,
|
||||||
|
EnableTools: true,
|
||||||
|
EnableMCP: true,
|
||||||
|
}
|
||||||
|
factory.RegisterProvider("ollama", ollamaConfig)
|
||||||
|
|
||||||
|
// Set default role mapping
|
||||||
|
defaultMapping := RoleModelMapping{
|
||||||
|
DefaultProvider: "ollama",
|
||||||
|
FallbackProvider: "ollama",
|
||||||
|
Roles: map[string]RoleConfig{
|
||||||
|
"developer": {
|
||||||
|
Provider: "ollama",
|
||||||
|
Model: "codellama:13b",
|
||||||
|
Temperature: 0.3,
|
||||||
|
MaxTokens: 8192,
|
||||||
|
EnableTools: true,
|
||||||
|
EnableMCP: true,
|
||||||
|
SystemPrompt: "You are an expert software developer focused on writing clean, maintainable, and well-tested code.",
|
||||||
|
},
|
||||||
|
"reviewer": {
|
||||||
|
Provider: "ollama",
|
||||||
|
Model: "llama3.1:8b",
|
||||||
|
Temperature: 0.2,
|
||||||
|
MaxTokens: 6144,
|
||||||
|
EnableTools: true,
|
||||||
|
SystemPrompt: "You are a thorough code reviewer focused on quality, security, and best practices.",
|
||||||
|
},
|
||||||
|
"architect": {
|
||||||
|
Provider: "ollama",
|
||||||
|
Model: "llama3.1:13b",
|
||||||
|
Temperature: 0.5,
|
||||||
|
MaxTokens: 8192,
|
||||||
|
EnableTools: true,
|
||||||
|
SystemPrompt: "You are a senior software architect focused on system design and technical decision making.",
|
||||||
|
},
|
||||||
|
"tester": {
|
||||||
|
Provider: "ollama",
|
||||||
|
Model: "codellama:7b",
|
||||||
|
Temperature: 0.3,
|
||||||
|
MaxTokens: 6144,
|
||||||
|
EnableTools: true,
|
||||||
|
SystemPrompt: "You are a QA engineer focused on comprehensive testing and quality assurance.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
factory.SetRoleMapping(defaultMapping)
|
||||||
|
|
||||||
|
return factory
|
||||||
|
}
|
||||||
516
pkg/ai/factory_test.go
Normal file
516
pkg/ai/factory_test.go
Normal file
@@ -0,0 +1,516 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewProviderFactory(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
assert.NotNil(t, factory)
|
||||||
|
assert.Empty(t, factory.configs)
|
||||||
|
assert.Empty(t, factory.providers)
|
||||||
|
assert.Empty(t, factory.healthChecks)
|
||||||
|
assert.Empty(t, factory.lastHealthCheck)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderFactoryRegisterProvider(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
// Create a valid mock provider config (since validation will be called)
|
||||||
|
config := ProviderConfig{
|
||||||
|
Type: "mock",
|
||||||
|
Endpoint: "mock://localhost",
|
||||||
|
DefaultModel: "test-model",
|
||||||
|
Temperature: 0.7,
|
||||||
|
MaxTokens: 4096,
|
||||||
|
Timeout: 300 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override CreateProvider to return our mock
|
||||||
|
originalCreate := factory.CreateProvider
|
||||||
|
factory.CreateProvider = func(config ProviderConfig) (ModelProvider, error) {
|
||||||
|
return NewMockProvider("test-provider"), nil
|
||||||
|
}
|
||||||
|
defer func() { factory.CreateProvider = originalCreate }()
|
||||||
|
|
||||||
|
err := factory.RegisterProvider("test", config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify provider was registered
|
||||||
|
assert.Len(t, factory.providers, 1)
|
||||||
|
assert.Contains(t, factory.providers, "test")
|
||||||
|
assert.True(t, factory.healthChecks["test"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderFactoryRegisterProviderValidationFailure(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
// Create a mock provider that will fail validation
|
||||||
|
config := ProviderConfig{
|
||||||
|
Type: "mock",
|
||||||
|
Endpoint: "mock://localhost",
|
||||||
|
DefaultModel: "test-model",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override CreateProvider to return a failing mock
|
||||||
|
factory.CreateProvider = func(config ProviderConfig) (ModelProvider, error) {
|
||||||
|
mock := NewMockProvider("failing-provider")
|
||||||
|
mock.shouldFail = true // This will make ValidateConfig fail
|
||||||
|
return mock, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := factory.RegisterProvider("failing", config)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid configuration")
|
||||||
|
|
||||||
|
// Verify provider was not registered
|
||||||
|
assert.Empty(t, factory.providers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderFactoryGetProvider(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
mockProvider := NewMockProvider("test-provider")
|
||||||
|
|
||||||
|
// Manually add provider and mark as healthy
|
||||||
|
factory.providers["test"] = mockProvider
|
||||||
|
factory.healthChecks["test"] = true
|
||||||
|
factory.lastHealthCheck["test"] = time.Now()
|
||||||
|
|
||||||
|
provider, err := factory.GetProvider("test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, mockProvider, provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderFactoryGetProviderNotFound(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
_, err := factory.GetProvider("nonexistent")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.IsType(t, &ProviderError{}, err)
|
||||||
|
|
||||||
|
providerErr := err.(*ProviderError)
|
||||||
|
assert.Equal(t, "PROVIDER_NOT_FOUND", providerErr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderFactoryGetProviderUnhealthy(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
mockProvider := NewMockProvider("test-provider")
|
||||||
|
|
||||||
|
// Add provider but mark as unhealthy
|
||||||
|
factory.providers["test"] = mockProvider
|
||||||
|
factory.healthChecks["test"] = false
|
||||||
|
factory.lastHealthCheck["test"] = time.Now()
|
||||||
|
|
||||||
|
_, err := factory.GetProvider("test")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.IsType(t, &ProviderError{}, err)
|
||||||
|
|
||||||
|
providerErr := err.(*ProviderError)
|
||||||
|
assert.Equal(t, "PROVIDER_UNAVAILABLE", providerErr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderFactorySetRoleMapping(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
mapping := RoleModelMapping{
|
||||||
|
DefaultProvider: "test",
|
||||||
|
FallbackProvider: "backup",
|
||||||
|
Roles: map[string]RoleConfig{
|
||||||
|
"developer": {
|
||||||
|
Provider: "test",
|
||||||
|
Model: "dev-model",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
factory.SetRoleMapping(mapping)
|
||||||
|
|
||||||
|
assert.Equal(t, mapping, factory.roleMapping)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderFactoryGetProviderForRole(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
// Set up providers
|
||||||
|
devProvider := NewMockProvider("dev-provider")
|
||||||
|
backupProvider := NewMockProvider("backup-provider")
|
||||||
|
|
||||||
|
factory.providers["dev"] = devProvider
|
||||||
|
factory.providers["backup"] = backupProvider
|
||||||
|
factory.healthChecks["dev"] = true
|
||||||
|
factory.healthChecks["backup"] = true
|
||||||
|
factory.lastHealthCheck["dev"] = time.Now()
|
||||||
|
factory.lastHealthCheck["backup"] = time.Now()
|
||||||
|
|
||||||
|
factory.configs["dev"] = ProviderConfig{
|
||||||
|
Type: "mock",
|
||||||
|
DefaultModel: "dev-model",
|
||||||
|
Temperature: 0.7,
|
||||||
|
}
|
||||||
|
|
||||||
|
factory.configs["backup"] = ProviderConfig{
|
||||||
|
Type: "mock",
|
||||||
|
DefaultModel: "backup-model",
|
||||||
|
Temperature: 0.8,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up role mapping
|
||||||
|
mapping := RoleModelMapping{
|
||||||
|
DefaultProvider: "backup",
|
||||||
|
FallbackProvider: "backup",
|
||||||
|
Roles: map[string]RoleConfig{
|
||||||
|
"developer": {
|
||||||
|
Provider: "dev",
|
||||||
|
Model: "custom-dev-model",
|
||||||
|
Temperature: 0.3,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
factory.SetRoleMapping(mapping)
|
||||||
|
|
||||||
|
provider, config, err := factory.GetProviderForRole("developer")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, devProvider, provider)
|
||||||
|
assert.Equal(t, "custom-dev-model", config.DefaultModel)
|
||||||
|
assert.Equal(t, float32(0.3), config.Temperature)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderFactoryGetProviderForRoleWithFallback(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
// Set up only backup provider (primary is missing)
|
||||||
|
backupProvider := NewMockProvider("backup-provider")
|
||||||
|
factory.providers["backup"] = backupProvider
|
||||||
|
factory.healthChecks["backup"] = true
|
||||||
|
factory.lastHealthCheck["backup"] = time.Now()
|
||||||
|
factory.configs["backup"] = ProviderConfig{Type: "mock", DefaultModel: "backup-model"}
|
||||||
|
|
||||||
|
// Set up role mapping with primary provider that doesn't exist
|
||||||
|
mapping := RoleModelMapping{
|
||||||
|
DefaultProvider: "backup",
|
||||||
|
FallbackProvider: "backup",
|
||||||
|
Roles: map[string]RoleConfig{
|
||||||
|
"developer": {
|
||||||
|
Provider: "nonexistent",
|
||||||
|
FallbackProvider: "backup",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
factory.SetRoleMapping(mapping)
|
||||||
|
|
||||||
|
provider, config, err := factory.GetProviderForRole("developer")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, backupProvider, provider)
|
||||||
|
assert.Equal(t, "backup-model", config.DefaultModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderFactoryGetProviderForRoleNotFound(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
// No providers registered and no default
|
||||||
|
mapping := RoleModelMapping{
|
||||||
|
Roles: make(map[string]RoleConfig),
|
||||||
|
}
|
||||||
|
factory.SetRoleMapping(mapping)
|
||||||
|
|
||||||
|
_, _, err := factory.GetProviderForRole("nonexistent")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.IsType(t, &ProviderError{}, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderFactoryGetProviderForTask(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
// Set up a provider that supports a specific model
|
||||||
|
mockProvider := NewMockProvider("test-provider")
|
||||||
|
mockProvider.capabilities.SupportedModels = []string{"specific-model", "another-model"}
|
||||||
|
|
||||||
|
factory.providers["test"] = mockProvider
|
||||||
|
factory.healthChecks["test"] = true
|
||||||
|
factory.lastHealthCheck["test"] = time.Now()
|
||||||
|
factory.configs["test"] = ProviderConfig{Type: "mock", DefaultModel: "default-model"}
|
||||||
|
|
||||||
|
request := &TaskRequest{
|
||||||
|
TaskID: "test-123",
|
||||||
|
AgentRole: "developer",
|
||||||
|
ModelName: "specific-model", // Request specific model
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, config, err := factory.GetProviderForTask(request)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, mockProvider, provider)
|
||||||
|
assert.Equal(t, "specific-model", config.DefaultModel) // Should override default
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderFactoryGetProviderForTaskModelNotSupported(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
mockProvider := NewMockProvider("test-provider")
|
||||||
|
mockProvider.capabilities.SupportedModels = []string{"model-1", "model-2"}
|
||||||
|
|
||||||
|
factory.providers["test"] = mockProvider
|
||||||
|
factory.healthChecks["test"] = true
|
||||||
|
factory.lastHealthCheck["test"] = time.Now()
|
||||||
|
|
||||||
|
request := &TaskRequest{
|
||||||
|
TaskID: "test-123",
|
||||||
|
AgentRole: "developer",
|
||||||
|
ModelName: "unsupported-model",
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err := factory.GetProviderForTask(request)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.IsType(t, &ProviderError{}, err)
|
||||||
|
|
||||||
|
providerErr := err.(*ProviderError)
|
||||||
|
assert.Equal(t, "MODEL_NOT_SUPPORTED", providerErr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderFactoryListProviders(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
// Add some mock providers
|
||||||
|
factory.providers["provider1"] = NewMockProvider("provider1")
|
||||||
|
factory.providers["provider2"] = NewMockProvider("provider2")
|
||||||
|
factory.providers["provider3"] = NewMockProvider("provider3")
|
||||||
|
|
||||||
|
providers := factory.ListProviders()
|
||||||
|
|
||||||
|
assert.Len(t, providers, 3)
|
||||||
|
assert.Contains(t, providers, "provider1")
|
||||||
|
assert.Contains(t, providers, "provider2")
|
||||||
|
assert.Contains(t, providers, "provider3")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderFactoryListHealthyProviders(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
// Add providers with different health states
|
||||||
|
factory.providers["healthy1"] = NewMockProvider("healthy1")
|
||||||
|
factory.providers["healthy2"] = NewMockProvider("healthy2")
|
||||||
|
factory.providers["unhealthy"] = NewMockProvider("unhealthy")
|
||||||
|
|
||||||
|
factory.healthChecks["healthy1"] = true
|
||||||
|
factory.healthChecks["healthy2"] = true
|
||||||
|
factory.healthChecks["unhealthy"] = false
|
||||||
|
|
||||||
|
factory.lastHealthCheck["healthy1"] = time.Now()
|
||||||
|
factory.lastHealthCheck["healthy2"] = time.Now()
|
||||||
|
factory.lastHealthCheck["unhealthy"] = time.Now()
|
||||||
|
|
||||||
|
healthyProviders := factory.ListHealthyProviders()
|
||||||
|
|
||||||
|
assert.Len(t, healthyProviders, 2)
|
||||||
|
assert.Contains(t, healthyProviders, "healthy1")
|
||||||
|
assert.Contains(t, healthyProviders, "healthy2")
|
||||||
|
assert.NotContains(t, healthyProviders, "unhealthy")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderFactoryGetProviderInfo(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
mock1 := NewMockProvider("mock1")
|
||||||
|
mock2 := NewMockProvider("mock2")
|
||||||
|
|
||||||
|
factory.providers["provider1"] = mock1
|
||||||
|
factory.providers["provider2"] = mock2
|
||||||
|
|
||||||
|
info := factory.GetProviderInfo()
|
||||||
|
|
||||||
|
assert.Len(t, info, 2)
|
||||||
|
assert.Contains(t, info, "provider1")
|
||||||
|
assert.Contains(t, info, "provider2")
|
||||||
|
|
||||||
|
// Verify that the name is overridden with the registered name
|
||||||
|
assert.Equal(t, "provider1", info["provider1"].Name)
|
||||||
|
assert.Equal(t, "provider2", info["provider2"].Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderFactoryHealthCheck(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
// Add a healthy and an unhealthy provider
|
||||||
|
healthyProvider := NewMockProvider("healthy")
|
||||||
|
unhealthyProvider := NewMockProvider("unhealthy")
|
||||||
|
unhealthyProvider.shouldFail = true
|
||||||
|
|
||||||
|
factory.providers["healthy"] = healthyProvider
|
||||||
|
factory.providers["unhealthy"] = unhealthyProvider
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
results := factory.HealthCheck(ctx)
|
||||||
|
|
||||||
|
assert.Len(t, results, 2)
|
||||||
|
assert.NoError(t, results["healthy"])
|
||||||
|
assert.Error(t, results["unhealthy"])
|
||||||
|
|
||||||
|
// Verify health states were updated
|
||||||
|
assert.True(t, factory.healthChecks["healthy"])
|
||||||
|
assert.False(t, factory.healthChecks["unhealthy"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderFactoryGetHealthStatus(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
mockProvider := NewMockProvider("test")
|
||||||
|
factory.providers["test"] = mockProvider
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
factory.healthChecks["test"] = true
|
||||||
|
factory.lastHealthCheck["test"] = now
|
||||||
|
|
||||||
|
status := factory.GetHealthStatus()
|
||||||
|
|
||||||
|
assert.Len(t, status, 1)
|
||||||
|
assert.Contains(t, status, "test")
|
||||||
|
|
||||||
|
testStatus := status["test"]
|
||||||
|
assert.Equal(t, "test", testStatus.Name)
|
||||||
|
assert.True(t, testStatus.Healthy)
|
||||||
|
assert.Equal(t, now, testStatus.LastCheck)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderFactoryIsProviderHealthy(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
// Test healthy provider
|
||||||
|
factory.healthChecks["healthy"] = true
|
||||||
|
factory.lastHealthCheck["healthy"] = time.Now()
|
||||||
|
assert.True(t, factory.isProviderHealthy("healthy"))
|
||||||
|
|
||||||
|
// Test unhealthy provider
|
||||||
|
factory.healthChecks["unhealthy"] = false
|
||||||
|
factory.lastHealthCheck["unhealthy"] = time.Now()
|
||||||
|
assert.False(t, factory.isProviderHealthy("unhealthy"))
|
||||||
|
|
||||||
|
// Test provider with old health check (should be considered unhealthy)
|
||||||
|
factory.healthChecks["stale"] = true
|
||||||
|
factory.lastHealthCheck["stale"] = time.Now().Add(-15 * time.Minute)
|
||||||
|
assert.False(t, factory.isProviderHealthy("stale"))
|
||||||
|
|
||||||
|
// Test non-existent provider
|
||||||
|
assert.False(t, factory.isProviderHealthy("nonexistent"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderFactoryMergeRoleConfig(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
baseConfig := ProviderConfig{
|
||||||
|
Type: "test",
|
||||||
|
DefaultModel: "base-model",
|
||||||
|
Temperature: 0.7,
|
||||||
|
MaxTokens: 4096,
|
||||||
|
EnableTools: false,
|
||||||
|
EnableMCP: false,
|
||||||
|
MCPServers: []string{"base-server"},
|
||||||
|
}
|
||||||
|
|
||||||
|
roleConfig := RoleConfig{
|
||||||
|
Model: "role-model",
|
||||||
|
Temperature: 0.3,
|
||||||
|
MaxTokens: 8192,
|
||||||
|
EnableTools: true,
|
||||||
|
EnableMCP: true,
|
||||||
|
MCPServers: []string{"role-server"},
|
||||||
|
}
|
||||||
|
|
||||||
|
merged := factory.mergeRoleConfig(baseConfig, roleConfig)
|
||||||
|
|
||||||
|
assert.Equal(t, "role-model", merged.DefaultModel)
|
||||||
|
assert.Equal(t, float32(0.3), merged.Temperature)
|
||||||
|
assert.Equal(t, 8192, merged.MaxTokens)
|
||||||
|
assert.True(t, merged.EnableTools)
|
||||||
|
assert.True(t, merged.EnableMCP)
|
||||||
|
assert.Len(t, merged.MCPServers, 2) // Should be merged
|
||||||
|
assert.Contains(t, merged.MCPServers, "base-server")
|
||||||
|
assert.Contains(t, merged.MCPServers, "role-server")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultProviderFactory(t *testing.T) {
|
||||||
|
factory := DefaultProviderFactory()
|
||||||
|
|
||||||
|
// Should have at least the default ollama provider
|
||||||
|
providers := factory.ListProviders()
|
||||||
|
assert.Contains(t, providers, "ollama")
|
||||||
|
|
||||||
|
// Should have role mappings configured
|
||||||
|
assert.NotEmpty(t, factory.roleMapping.Roles)
|
||||||
|
assert.Contains(t, factory.roleMapping.Roles, "developer")
|
||||||
|
assert.Contains(t, factory.roleMapping.Roles, "reviewer")
|
||||||
|
|
||||||
|
// Test getting provider for developer role
|
||||||
|
_, config, err := factory.GetProviderForRole("developer")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "codellama:13b", config.DefaultModel)
|
||||||
|
assert.Equal(t, float32(0.3), config.Temperature)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderFactoryCreateProvider(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config ProviderConfig
|
||||||
|
expectErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ollama provider",
|
||||||
|
config: ProviderConfig{
|
||||||
|
Type: "ollama",
|
||||||
|
Endpoint: "http://localhost:11434",
|
||||||
|
DefaultModel: "llama2",
|
||||||
|
},
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "openai provider",
|
||||||
|
config: ProviderConfig{
|
||||||
|
Type: "openai",
|
||||||
|
Endpoint: "https://api.openai.com/v1",
|
||||||
|
APIKey: "test-key",
|
||||||
|
DefaultModel: "gpt-4",
|
||||||
|
},
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "resetdata provider",
|
||||||
|
config: ProviderConfig{
|
||||||
|
Type: "resetdata",
|
||||||
|
Endpoint: "https://api.resetdata.ai",
|
||||||
|
APIKey: "test-key",
|
||||||
|
DefaultModel: "llama2",
|
||||||
|
},
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown provider",
|
||||||
|
config: ProviderConfig{
|
||||||
|
Type: "unknown",
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
provider, err := factory.CreateProvider(tt.config)
|
||||||
|
|
||||||
|
if tt.expectErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, provider)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, provider)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
433
pkg/ai/ollama.go
Normal file
433
pkg/ai/ollama.go
Normal file
@@ -0,0 +1,433 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OllamaProvider implements ModelProvider for local Ollama instances
|
||||||
|
type OllamaProvider struct {
|
||||||
|
config ProviderConfig
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// OllamaRequest represents a request to Ollama API
|
||||||
|
type OllamaRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Prompt string `json:"prompt,omitempty"`
|
||||||
|
Messages []OllamaMessage `json:"messages,omitempty"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
Format string `json:"format,omitempty"`
|
||||||
|
Options map[string]interface{} `json:"options,omitempty"`
|
||||||
|
System string `json:"system,omitempty"`
|
||||||
|
Template string `json:"template,omitempty"`
|
||||||
|
Context []int `json:"context,omitempty"`
|
||||||
|
Raw bool `json:"raw,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// OllamaMessage represents a message in the Ollama chat format
|
||||||
|
type OllamaMessage struct {
|
||||||
|
Role string `json:"role"` // system, user, assistant
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// OllamaResponse represents a response from Ollama API
|
||||||
|
type OllamaResponse struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
Message OllamaMessage `json:"message,omitempty"`
|
||||||
|
Response string `json:"response,omitempty"`
|
||||||
|
Done bool `json:"done"`
|
||||||
|
Context []int `json:"context,omitempty"`
|
||||||
|
TotalDuration int64 `json:"total_duration,omitempty"`
|
||||||
|
LoadDuration int64 `json:"load_duration,omitempty"`
|
||||||
|
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||||
|
PromptEvalDuration int64 `json:"prompt_eval_duration,omitempty"`
|
||||||
|
EvalCount int `json:"eval_count,omitempty"`
|
||||||
|
EvalDuration int64 `json:"eval_duration,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// OllamaModelsResponse represents the response from /api/tags endpoint
|
||||||
|
type OllamaModelsResponse struct {
|
||||||
|
Models []OllamaModel `json:"models"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// OllamaModel represents a model in Ollama
|
||||||
|
type OllamaModel struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
ModifiedAt time.Time `json:"modified_at"`
|
||||||
|
Size int64 `json:"size"`
|
||||||
|
Digest string `json:"digest"`
|
||||||
|
Details OllamaModelDetails `json:"details,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// OllamaModelDetails provides detailed model information
|
||||||
|
type OllamaModelDetails struct {
|
||||||
|
Format string `json:"format"`
|
||||||
|
Family string `json:"family"`
|
||||||
|
Families []string `json:"families,omitempty"`
|
||||||
|
ParameterSize string `json:"parameter_size"`
|
||||||
|
QuantizationLevel string `json:"quantization_level"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOllamaProvider creates a new Ollama provider instance
|
||||||
|
func NewOllamaProvider(config ProviderConfig) *OllamaProvider {
|
||||||
|
timeout := config.Timeout
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = 300 * time.Second // 5 minutes default for task execution
|
||||||
|
}
|
||||||
|
|
||||||
|
return &OllamaProvider{
|
||||||
|
config: config,
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Timeout: timeout,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteTask implements the ModelProvider interface for Ollama
|
||||||
|
func (p *OllamaProvider) ExecuteTask(ctx context.Context, request *TaskRequest) (*TaskResponse, error) {
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
// Build the prompt from task context
|
||||||
|
prompt, err := p.buildTaskPrompt(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to build prompt: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare Ollama request
|
||||||
|
ollamaReq := OllamaRequest{
|
||||||
|
Model: p.selectModel(request.ModelName),
|
||||||
|
Stream: false,
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"temperature": p.getTemperature(request.Temperature),
|
||||||
|
"num_predict": p.getMaxTokens(request.MaxTokens),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use chat format for better conversation handling
|
||||||
|
ollamaReq.Messages = []OllamaMessage{
|
||||||
|
{
|
||||||
|
Role: "system",
|
||||||
|
Content: p.getSystemPrompt(request),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: prompt,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the request
|
||||||
|
response, err := p.makeRequest(ctx, "/api/chat", ollamaReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
endTime := time.Now()
|
||||||
|
|
||||||
|
// Parse response and extract actions
|
||||||
|
actions, artifacts := p.parseResponseForActions(response.Message.Content, request)
|
||||||
|
|
||||||
|
return &TaskResponse{
|
||||||
|
Success: true,
|
||||||
|
TaskID: request.TaskID,
|
||||||
|
AgentID: request.AgentID,
|
||||||
|
ModelUsed: response.Model,
|
||||||
|
Provider: "ollama",
|
||||||
|
Response: response.Message.Content,
|
||||||
|
Actions: actions,
|
||||||
|
Artifacts: artifacts,
|
||||||
|
StartTime: startTime,
|
||||||
|
EndTime: endTime,
|
||||||
|
Duration: endTime.Sub(startTime),
|
||||||
|
TokensUsed: TokenUsage{
|
||||||
|
PromptTokens: response.PromptEvalCount,
|
||||||
|
CompletionTokens: response.EvalCount,
|
||||||
|
TotalTokens: response.PromptEvalCount + response.EvalCount,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCapabilities returns Ollama provider capabilities
|
||||||
|
func (p *OllamaProvider) GetCapabilities() ProviderCapabilities {
|
||||||
|
return ProviderCapabilities{
|
||||||
|
SupportsMCP: p.config.EnableMCP,
|
||||||
|
SupportsTools: p.config.EnableTools,
|
||||||
|
SupportsStreaming: true,
|
||||||
|
SupportsFunctions: false, // Ollama doesn't support function calling natively
|
||||||
|
MaxTokens: p.config.MaxTokens,
|
||||||
|
SupportedModels: p.getSupportedModels(),
|
||||||
|
SupportsImages: true, // Many Ollama models support images
|
||||||
|
SupportsFiles: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateConfig validates the Ollama provider configuration
|
||||||
|
func (p *OllamaProvider) ValidateConfig() error {
|
||||||
|
if p.config.Endpoint == "" {
|
||||||
|
return NewProviderError(ErrInvalidConfiguration, "endpoint is required for Ollama provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.config.DefaultModel == "" {
|
||||||
|
return NewProviderError(ErrInvalidConfiguration, "default_model is required for Ollama provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test connection to Ollama
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := p.testConnection(ctx); err != nil {
|
||||||
|
return NewProviderError(ErrProviderUnavailable, fmt.Sprintf("failed to connect to Ollama: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProviderInfo returns information about the Ollama provider
|
||||||
|
func (p *OllamaProvider) GetProviderInfo() ProviderInfo {
|
||||||
|
return ProviderInfo{
|
||||||
|
Name: "Ollama",
|
||||||
|
Type: "ollama",
|
||||||
|
Version: "1.0.0",
|
||||||
|
Endpoint: p.config.Endpoint,
|
||||||
|
DefaultModel: p.config.DefaultModel,
|
||||||
|
RequiresAPIKey: false,
|
||||||
|
RateLimit: 0, // No rate limit for local Ollama
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildTaskPrompt constructs a comprehensive prompt for task execution
|
||||||
|
func (p *OllamaProvider) buildTaskPrompt(request *TaskRequest) (string, error) {
|
||||||
|
var prompt strings.Builder
|
||||||
|
|
||||||
|
prompt.WriteString(fmt.Sprintf("You are a %s agent working on a task in the repository: %s\n\n",
|
||||||
|
request.AgentRole, request.Repository))
|
||||||
|
|
||||||
|
prompt.WriteString(fmt.Sprintf("**Task Title:** %s\n", request.TaskTitle))
|
||||||
|
prompt.WriteString(fmt.Sprintf("**Task Description:**\n%s\n\n", request.TaskDescription))
|
||||||
|
|
||||||
|
if len(request.TaskLabels) > 0 {
|
||||||
|
prompt.WriteString(fmt.Sprintf("**Labels:** %s\n", strings.Join(request.TaskLabels, ", ")))
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt.WriteString(fmt.Sprintf("**Priority:** %d/10\n", request.Priority))
|
||||||
|
prompt.WriteString(fmt.Sprintf("**Complexity:** %d/10\n\n", request.Complexity))
|
||||||
|
|
||||||
|
if request.WorkingDirectory != "" {
|
||||||
|
prompt.WriteString(fmt.Sprintf("**Working Directory:** %s\n", request.WorkingDirectory))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(request.RepositoryFiles) > 0 {
|
||||||
|
prompt.WriteString("**Relevant Files:**\n")
|
||||||
|
for _, file := range request.RepositoryFiles {
|
||||||
|
prompt.WriteString(fmt.Sprintf("- %s\n", file))
|
||||||
|
}
|
||||||
|
prompt.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add role-specific instructions
|
||||||
|
prompt.WriteString(p.getRoleSpecificInstructions(request.AgentRole))
|
||||||
|
|
||||||
|
prompt.WriteString("\nPlease analyze the task and provide a detailed plan for implementation. ")
|
||||||
|
prompt.WriteString("If you need to make changes to files, describe the specific changes needed. ")
|
||||||
|
prompt.WriteString("If you need to run commands, specify the exact commands to execute.")
|
||||||
|
|
||||||
|
return prompt.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRoleSpecificInstructions returns instructions specific to the agent role
|
||||||
|
func (p *OllamaProvider) getRoleSpecificInstructions(role string) string {
|
||||||
|
switch strings.ToLower(role) {
|
||||||
|
case "developer":
|
||||||
|
return `As a developer agent, focus on:
|
||||||
|
- Implementing code changes to address the task requirements
|
||||||
|
- Following best practices for the programming language
|
||||||
|
- Writing clean, maintainable, and well-documented code
|
||||||
|
- Ensuring proper error handling and edge case coverage
|
||||||
|
- Running appropriate tests to validate your changes`
|
||||||
|
|
||||||
|
case "reviewer":
|
||||||
|
return `As a reviewer agent, focus on:
|
||||||
|
- Analyzing code quality and adherence to best practices
|
||||||
|
- Identifying potential bugs, security issues, or performance problems
|
||||||
|
- Suggesting improvements for maintainability and readability
|
||||||
|
- Validating test coverage and test quality
|
||||||
|
- Ensuring documentation is accurate and complete`
|
||||||
|
|
||||||
|
case "architect":
|
||||||
|
return `As an architect agent, focus on:
|
||||||
|
- Designing system architecture and component interactions
|
||||||
|
- Making technology stack and framework decisions
|
||||||
|
- Defining interfaces and API contracts
|
||||||
|
- Considering scalability, performance, and security implications
|
||||||
|
- Creating architectural documentation and diagrams`
|
||||||
|
|
||||||
|
case "tester":
|
||||||
|
return `As a tester agent, focus on:
|
||||||
|
- Creating comprehensive test cases and test plans
|
||||||
|
- Implementing unit, integration, and end-to-end tests
|
||||||
|
- Identifying edge cases and potential failure scenarios
|
||||||
|
- Setting up test automation and CI/CD integration
|
||||||
|
- Validating functionality against requirements`
|
||||||
|
|
||||||
|
default:
|
||||||
|
return `As an AI agent, focus on:
|
||||||
|
- Understanding the task requirements thoroughly
|
||||||
|
- Providing a clear and actionable implementation plan
|
||||||
|
- Following software development best practices
|
||||||
|
- Ensuring your work is well-documented and maintainable`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectModel chooses the appropriate model for the request
|
||||||
|
func (p *OllamaProvider) selectModel(requestedModel string) string {
|
||||||
|
if requestedModel != "" {
|
||||||
|
return requestedModel
|
||||||
|
}
|
||||||
|
return p.config.DefaultModel
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTemperature returns the temperature setting for the request
|
||||||
|
func (p *OllamaProvider) getTemperature(requestTemp float32) float32 {
|
||||||
|
if requestTemp > 0 {
|
||||||
|
return requestTemp
|
||||||
|
}
|
||||||
|
if p.config.Temperature > 0 {
|
||||||
|
return p.config.Temperature
|
||||||
|
}
|
||||||
|
return 0.7 // Default temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
// getMaxTokens returns the max tokens setting for the request
|
||||||
|
func (p *OllamaProvider) getMaxTokens(requestTokens int) int {
|
||||||
|
if requestTokens > 0 {
|
||||||
|
return requestTokens
|
||||||
|
}
|
||||||
|
if p.config.MaxTokens > 0 {
|
||||||
|
return p.config.MaxTokens
|
||||||
|
}
|
||||||
|
return 4096 // Default max tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSystemPrompt constructs the system prompt
|
||||||
|
func (p *OllamaProvider) getSystemPrompt(request *TaskRequest) string {
|
||||||
|
if request.SystemPrompt != "" {
|
||||||
|
return request.SystemPrompt
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf(`You are an AI assistant specializing in software development tasks.
|
||||||
|
You are currently working as a %s agent in the CHORUS autonomous agent system.
|
||||||
|
|
||||||
|
Your capabilities include:
|
||||||
|
- Analyzing code and repository structures
|
||||||
|
- Implementing features and fixing bugs
|
||||||
|
- Writing and reviewing code in multiple programming languages
|
||||||
|
- Creating tests and documentation
|
||||||
|
- Following software development best practices
|
||||||
|
|
||||||
|
Always provide detailed, actionable responses with specific implementation steps.`, request.AgentRole)
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeRequest makes an HTTP request to the Ollama API
|
||||||
|
func (p *OllamaProvider) makeRequest(ctx context.Context, endpoint string, request interface{}) (*OllamaResponse, error) {
|
||||||
|
requestJSON, err := json.Marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to marshal request: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
url := strings.TrimSuffix(p.config.Endpoint, "/") + endpoint
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(requestJSON))
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to create request: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
// Add custom headers if configured
|
||||||
|
for key, value := range p.config.CustomHeaders {
|
||||||
|
req.Header.Set(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := p.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewProviderError(ErrProviderUnavailable, fmt.Sprintf("request failed: %v", err))
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to read response: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, NewProviderError(ErrTaskExecutionFailed,
|
||||||
|
fmt.Sprintf("API request failed with status %d: %s", resp.StatusCode, string(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var ollamaResp OllamaResponse
|
||||||
|
if err := json.Unmarshal(body, &ollamaResp); err != nil {
|
||||||
|
return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to parse response: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ollamaResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// testConnection tests the connection to Ollama
|
||||||
|
func (p *OllamaProvider) testConnection(ctx context.Context) error {
|
||||||
|
url := strings.TrimSuffix(p.config.Endpoint, "/") + "/api/tags"
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := p.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSupportedModels returns a list of supported models (would normally query Ollama)
|
||||||
|
func (p *OllamaProvider) getSupportedModels() []string {
|
||||||
|
// In a real implementation, this would query the /api/tags endpoint
|
||||||
|
return []string{
|
||||||
|
"llama3.1:8b", "llama3.1:13b", "llama3.1:70b",
|
||||||
|
"codellama:7b", "codellama:13b", "codellama:34b",
|
||||||
|
"mistral:7b", "mixtral:8x7b",
|
||||||
|
"qwen2:7b", "gemma:7b",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseResponseForActions extracts actions and artifacts from the response
|
||||||
|
func (p *OllamaProvider) parseResponseForActions(response string, request *TaskRequest) ([]TaskAction, []Artifact) {
|
||||||
|
var actions []TaskAction
|
||||||
|
var artifacts []Artifact
|
||||||
|
|
||||||
|
// This is a simplified implementation - in reality, you'd parse the response
|
||||||
|
// to extract specific actions like file changes, commands to run, etc.
|
||||||
|
|
||||||
|
// For now, just create a basic action indicating task analysis
|
||||||
|
action := TaskAction{
|
||||||
|
Type: "task_analysis",
|
||||||
|
Target: request.TaskTitle,
|
||||||
|
Content: response,
|
||||||
|
Result: "Task analyzed successfully",
|
||||||
|
Success: true,
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"agent_role": request.AgentRole,
|
||||||
|
"repository": request.Repository,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
actions = append(actions, action)
|
||||||
|
|
||||||
|
return actions, artifacts
|
||||||
|
}
|
||||||
518
pkg/ai/openai.go
Normal file
518
pkg/ai/openai.go
Normal file
@@ -0,0 +1,518 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sashabaranov/go-openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenAIProvider implements ModelProvider for OpenAI API
|
||||||
|
type OpenAIProvider struct {
|
||||||
|
config ProviderConfig
|
||||||
|
client *openai.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOpenAIProvider creates a new OpenAI provider instance
|
||||||
|
func NewOpenAIProvider(config ProviderConfig) *OpenAIProvider {
|
||||||
|
client := openai.NewClient(config.APIKey)
|
||||||
|
|
||||||
|
// Use custom endpoint if specified
|
||||||
|
if config.Endpoint != "" && config.Endpoint != "https://api.openai.com/v1" {
|
||||||
|
clientConfig := openai.DefaultConfig(config.APIKey)
|
||||||
|
clientConfig.BaseURL = config.Endpoint
|
||||||
|
client = openai.NewClientWithConfig(clientConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &OpenAIProvider{
|
||||||
|
config: config,
|
||||||
|
client: client,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteTask implements the ModelProvider interface for OpenAI
|
||||||
|
func (p *OpenAIProvider) ExecuteTask(ctx context.Context, request *TaskRequest) (*TaskResponse, error) {
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
// Build messages for the chat completion
|
||||||
|
messages, err := p.buildChatMessages(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to build messages: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare the chat completion request
|
||||||
|
chatReq := openai.ChatCompletionRequest{
|
||||||
|
Model: p.selectModel(request.ModelName),
|
||||||
|
Messages: messages,
|
||||||
|
Temperature: p.getTemperature(request.Temperature),
|
||||||
|
MaxTokens: p.getMaxTokens(request.MaxTokens),
|
||||||
|
Stream: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add tools if enabled and supported
|
||||||
|
if p.config.EnableTools && request.EnableTools {
|
||||||
|
chatReq.Tools = p.getToolDefinitions(request)
|
||||||
|
chatReq.ToolChoice = "auto"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the chat completion
|
||||||
|
resp, err := p.client.CreateChatCompletion(ctx, chatReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, p.handleOpenAIError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
endTime := time.Now()
|
||||||
|
|
||||||
|
// Process the response
|
||||||
|
if len(resp.Choices) == 0 {
|
||||||
|
return nil, NewProviderError(ErrTaskExecutionFailed, "no response choices returned from OpenAI")
|
||||||
|
}
|
||||||
|
|
||||||
|
choice := resp.Choices[0]
|
||||||
|
responseText := choice.Message.Content
|
||||||
|
|
||||||
|
// Process tool calls if present
|
||||||
|
var actions []TaskAction
|
||||||
|
var artifacts []Artifact
|
||||||
|
|
||||||
|
if len(choice.Message.ToolCalls) > 0 {
|
||||||
|
toolActions, toolArtifacts := p.processToolCalls(choice.Message.ToolCalls, request)
|
||||||
|
actions = append(actions, toolActions...)
|
||||||
|
artifacts = append(artifacts, toolArtifacts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse response for additional actions
|
||||||
|
responseActions, responseArtifacts := p.parseResponseForActions(responseText, request)
|
||||||
|
actions = append(actions, responseActions...)
|
||||||
|
artifacts = append(artifacts, responseArtifacts...)
|
||||||
|
|
||||||
|
return &TaskResponse{
|
||||||
|
Success: true,
|
||||||
|
TaskID: request.TaskID,
|
||||||
|
AgentID: request.AgentID,
|
||||||
|
ModelUsed: resp.Model,
|
||||||
|
Provider: "openai",
|
||||||
|
Response: responseText,
|
||||||
|
Actions: actions,
|
||||||
|
Artifacts: artifacts,
|
||||||
|
StartTime: startTime,
|
||||||
|
EndTime: endTime,
|
||||||
|
Duration: endTime.Sub(startTime),
|
||||||
|
TokensUsed: TokenUsage{
|
||||||
|
PromptTokens: resp.Usage.PromptTokens,
|
||||||
|
CompletionTokens: resp.Usage.CompletionTokens,
|
||||||
|
TotalTokens: resp.Usage.TotalTokens,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCapabilities returns OpenAI provider capabilities
|
||||||
|
func (p *OpenAIProvider) GetCapabilities() ProviderCapabilities {
|
||||||
|
return ProviderCapabilities{
|
||||||
|
SupportsMCP: p.config.EnableMCP,
|
||||||
|
SupportsTools: true, // OpenAI supports function calling
|
||||||
|
SupportsStreaming: true,
|
||||||
|
SupportsFunctions: true,
|
||||||
|
MaxTokens: p.getModelMaxTokens(p.config.DefaultModel),
|
||||||
|
SupportedModels: p.getSupportedModels(),
|
||||||
|
SupportsImages: p.modelSupportsImages(p.config.DefaultModel),
|
||||||
|
SupportsFiles: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateConfig validates the OpenAI provider configuration
|
||||||
|
func (p *OpenAIProvider) ValidateConfig() error {
|
||||||
|
if p.config.APIKey == "" {
|
||||||
|
return NewProviderError(ErrAPIKeyRequired, "API key is required for OpenAI provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.config.DefaultModel == "" {
|
||||||
|
return NewProviderError(ErrInvalidConfiguration, "default_model is required for OpenAI provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test the API connection with a minimal request
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := p.testConnection(ctx); err != nil {
|
||||||
|
return NewProviderError(ErrProviderUnavailable, fmt.Sprintf("failed to connect to OpenAI: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProviderInfo returns information about the OpenAI provider
|
||||||
|
func (p *OpenAIProvider) GetProviderInfo() ProviderInfo {
|
||||||
|
endpoint := p.config.Endpoint
|
||||||
|
if endpoint == "" {
|
||||||
|
endpoint = "https://api.openai.com/v1"
|
||||||
|
}
|
||||||
|
|
||||||
|
return ProviderInfo{
|
||||||
|
Name: "OpenAI",
|
||||||
|
Type: "openai",
|
||||||
|
Version: "1.0.0",
|
||||||
|
Endpoint: endpoint,
|
||||||
|
DefaultModel: p.config.DefaultModel,
|
||||||
|
RequiresAPIKey: true,
|
||||||
|
RateLimit: 10000, // Approximate RPM for paid accounts
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildChatMessages constructs messages for the OpenAI chat completion
|
||||||
|
func (p *OpenAIProvider) buildChatMessages(request *TaskRequest) ([]openai.ChatCompletionMessage, error) {
|
||||||
|
var messages []openai.ChatCompletionMessage
|
||||||
|
|
||||||
|
// System message
|
||||||
|
systemPrompt := p.getSystemPrompt(request)
|
||||||
|
if systemPrompt != "" {
|
||||||
|
messages = append(messages, openai.ChatCompletionMessage{
|
||||||
|
Role: openai.ChatMessageRoleSystem,
|
||||||
|
Content: systemPrompt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// User message with task details
|
||||||
|
userPrompt, err := p.buildTaskPrompt(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = append(messages, openai.ChatCompletionMessage{
|
||||||
|
Role: openai.ChatMessageRoleUser,
|
||||||
|
Content: userPrompt,
|
||||||
|
})
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildTaskPrompt constructs a comprehensive prompt for task execution
|
||||||
|
func (p *OpenAIProvider) buildTaskPrompt(request *TaskRequest) (string, error) {
|
||||||
|
var prompt strings.Builder
|
||||||
|
|
||||||
|
prompt.WriteString(fmt.Sprintf("You are working as a %s agent on the following task:\n\n",
|
||||||
|
request.AgentRole))
|
||||||
|
|
||||||
|
prompt.WriteString(fmt.Sprintf("**Repository:** %s\n", request.Repository))
|
||||||
|
prompt.WriteString(fmt.Sprintf("**Task:** %s\n", request.TaskTitle))
|
||||||
|
prompt.WriteString(fmt.Sprintf("**Description:**\n%s\n\n", request.TaskDescription))
|
||||||
|
|
||||||
|
if len(request.TaskLabels) > 0 {
|
||||||
|
prompt.WriteString(fmt.Sprintf("**Labels:** %s\n", strings.Join(request.TaskLabels, ", ")))
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt.WriteString(fmt.Sprintf("**Priority:** %d/10 | **Complexity:** %d/10\n\n",
|
||||||
|
request.Priority, request.Complexity))
|
||||||
|
|
||||||
|
if request.WorkingDirectory != "" {
|
||||||
|
prompt.WriteString(fmt.Sprintf("**Working Directory:** %s\n", request.WorkingDirectory))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(request.RepositoryFiles) > 0 {
|
||||||
|
prompt.WriteString("**Relevant Files:**\n")
|
||||||
|
for _, file := range request.RepositoryFiles {
|
||||||
|
prompt.WriteString(fmt.Sprintf("- %s\n", file))
|
||||||
|
}
|
||||||
|
prompt.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add role-specific guidance
|
||||||
|
prompt.WriteString(p.getRoleSpecificGuidance(request.AgentRole))
|
||||||
|
|
||||||
|
prompt.WriteString("\nAnalyze this task and provide a detailed implementation plan. ")
|
||||||
|
if request.EnableTools {
|
||||||
|
prompt.WriteString("Use the available tools to make concrete changes or gather information as needed. ")
|
||||||
|
}
|
||||||
|
prompt.WriteString("Be specific about what needs to be done and how to accomplish it.")
|
||||||
|
|
||||||
|
return prompt.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRoleSpecificGuidance returns guidance specific to the agent role
|
||||||
|
func (p *OpenAIProvider) getRoleSpecificGuidance(role string) string {
|
||||||
|
switch strings.ToLower(role) {
|
||||||
|
case "developer":
|
||||||
|
return `**Developer Guidelines:**
|
||||||
|
- Write clean, maintainable, and well-documented code
|
||||||
|
- Follow language-specific best practices and conventions
|
||||||
|
- Implement proper error handling and validation
|
||||||
|
- Create or update tests to cover your changes
|
||||||
|
- Consider performance and security implications`
|
||||||
|
|
||||||
|
case "reviewer":
|
||||||
|
return `**Code Review Guidelines:**
|
||||||
|
- Analyze code quality, readability, and maintainability
|
||||||
|
- Check for bugs, security vulnerabilities, and performance issues
|
||||||
|
- Verify test coverage and quality
|
||||||
|
- Ensure documentation is accurate and complete
|
||||||
|
- Suggest improvements and alternatives`
|
||||||
|
|
||||||
|
case "architect":
|
||||||
|
return `**Architecture Guidelines:**
|
||||||
|
- Design scalable and maintainable system architecture
|
||||||
|
- Make informed technology and framework decisions
|
||||||
|
- Define clear interfaces and API contracts
|
||||||
|
- Consider security, performance, and scalability requirements
|
||||||
|
- Document architectural decisions and rationale`
|
||||||
|
|
||||||
|
case "tester":
|
||||||
|
return `**Testing Guidelines:**
|
||||||
|
- Create comprehensive test plans and test cases
|
||||||
|
- Implement unit, integration, and end-to-end tests
|
||||||
|
- Identify edge cases and potential failure scenarios
|
||||||
|
- Set up test automation and continuous integration
|
||||||
|
- Validate functionality against requirements`
|
||||||
|
|
||||||
|
default:
|
||||||
|
return `**General Guidelines:**
|
||||||
|
- Understand requirements thoroughly before implementation
|
||||||
|
- Follow software development best practices
|
||||||
|
- Provide clear documentation and explanations
|
||||||
|
- Consider maintainability and future extensibility`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getToolDefinitions returns tool definitions for OpenAI function calling
|
||||||
|
func (p *OpenAIProvider) getToolDefinitions(request *TaskRequest) []openai.Tool {
|
||||||
|
var tools []openai.Tool
|
||||||
|
|
||||||
|
// File operations tool
|
||||||
|
tools = append(tools, openai.Tool{
|
||||||
|
Type: openai.ToolTypeFunction,
|
||||||
|
Function: &openai.FunctionDefinition{
|
||||||
|
Name: "file_operation",
|
||||||
|
Description: "Create, read, update, or delete files in the repository",
|
||||||
|
Parameters: map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"operation": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"enum": []string{"create", "read", "update", "delete"},
|
||||||
|
"description": "The file operation to perform",
|
||||||
|
},
|
||||||
|
"path": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "The file path relative to the repository root",
|
||||||
|
},
|
||||||
|
"content": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "The file content (for create/update operations)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"operation", "path"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Command execution tool
|
||||||
|
tools = append(tools, openai.Tool{
|
||||||
|
Type: openai.ToolTypeFunction,
|
||||||
|
Function: &openai.FunctionDefinition{
|
||||||
|
Name: "execute_command",
|
||||||
|
Description: "Execute shell commands in the repository working directory",
|
||||||
|
Parameters: map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"command": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "The shell command to execute",
|
||||||
|
},
|
||||||
|
"working_dir": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Working directory for command execution (optional)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"command"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
return tools
|
||||||
|
}
|
||||||
|
|
||||||
|
// processToolCalls handles OpenAI function calls
|
||||||
|
func (p *OpenAIProvider) processToolCalls(toolCalls []openai.ToolCall, request *TaskRequest) ([]TaskAction, []Artifact) {
|
||||||
|
var actions []TaskAction
|
||||||
|
var artifacts []Artifact
|
||||||
|
|
||||||
|
for _, toolCall := range toolCalls {
|
||||||
|
action := TaskAction{
|
||||||
|
Type: "function_call",
|
||||||
|
Target: toolCall.Function.Name,
|
||||||
|
Content: toolCall.Function.Arguments,
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"tool_call_id": toolCall.ID,
|
||||||
|
"function": toolCall.Function.Name,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// In a real implementation, you would actually execute these tool calls
|
||||||
|
// For now, just mark them as successful
|
||||||
|
action.Result = fmt.Sprintf("Function call %s processed", toolCall.Function.Name)
|
||||||
|
action.Success = true
|
||||||
|
|
||||||
|
actions = append(actions, action)
|
||||||
|
}
|
||||||
|
|
||||||
|
return actions, artifacts
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectModel chooses the appropriate OpenAI model
|
||||||
|
func (p *OpenAIProvider) selectModel(requestedModel string) string {
|
||||||
|
if requestedModel != "" {
|
||||||
|
return requestedModel
|
||||||
|
}
|
||||||
|
return p.config.DefaultModel
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTemperature returns the temperature setting
|
||||||
|
func (p *OpenAIProvider) getTemperature(requestTemp float32) float32 {
|
||||||
|
if requestTemp > 0 {
|
||||||
|
return requestTemp
|
||||||
|
}
|
||||||
|
if p.config.Temperature > 0 {
|
||||||
|
return p.config.Temperature
|
||||||
|
}
|
||||||
|
return 0.7 // Default temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
// getMaxTokens returns the max tokens setting
|
||||||
|
func (p *OpenAIProvider) getMaxTokens(requestTokens int) int {
|
||||||
|
if requestTokens > 0 {
|
||||||
|
return requestTokens
|
||||||
|
}
|
||||||
|
if p.config.MaxTokens > 0 {
|
||||||
|
return p.config.MaxTokens
|
||||||
|
}
|
||||||
|
return 4096 // Default max tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSystemPrompt constructs the system prompt
|
||||||
|
func (p *OpenAIProvider) getSystemPrompt(request *TaskRequest) string {
|
||||||
|
if request.SystemPrompt != "" {
|
||||||
|
return request.SystemPrompt
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf(`You are an expert AI assistant specializing in software development.
|
||||||
|
You are currently operating as a %s agent in the CHORUS autonomous development system.
|
||||||
|
|
||||||
|
Your capabilities:
|
||||||
|
- Code analysis, implementation, and optimization
|
||||||
|
- Software architecture and design patterns
|
||||||
|
- Testing strategies and implementation
|
||||||
|
- Documentation and technical writing
|
||||||
|
- DevOps and deployment practices
|
||||||
|
|
||||||
|
Always provide thorough, actionable responses with specific implementation details.
|
||||||
|
When using tools, explain your reasoning and the expected outcomes.`, request.AgentRole)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getModelMaxTokens returns the maximum tokens for a specific model
|
||||||
|
func (p *OpenAIProvider) getModelMaxTokens(model string) int {
|
||||||
|
switch model {
|
||||||
|
case "gpt-4o", "gpt-4o-2024-05-13":
|
||||||
|
return 128000
|
||||||
|
case "gpt-4-turbo", "gpt-4-turbo-2024-04-09":
|
||||||
|
return 128000
|
||||||
|
case "gpt-4", "gpt-4-0613":
|
||||||
|
return 8192
|
||||||
|
case "gpt-3.5-turbo", "gpt-3.5-turbo-0125":
|
||||||
|
return 16385
|
||||||
|
default:
|
||||||
|
return 4096 // Conservative default
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelSupportsImages checks if a model supports image inputs
|
||||||
|
func (p *OpenAIProvider) modelSupportsImages(model string) bool {
|
||||||
|
visionModels := []string{"gpt-4o", "gpt-4o-2024-05-13", "gpt-4-turbo", "gpt-4-vision-preview"}
|
||||||
|
for _, visionModel := range visionModels {
|
||||||
|
if strings.Contains(model, visionModel) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSupportedModels returns a list of supported OpenAI models
|
||||||
|
func (p *OpenAIProvider) getSupportedModels() []string {
|
||||||
|
return []string{
|
||||||
|
"gpt-4o", "gpt-4o-2024-05-13",
|
||||||
|
"gpt-4-turbo", "gpt-4-turbo-2024-04-09",
|
||||||
|
"gpt-4", "gpt-4-0613",
|
||||||
|
"gpt-3.5-turbo", "gpt-3.5-turbo-0125",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// testConnection tests the OpenAI API connection
|
||||||
|
func (p *OpenAIProvider) testConnection(ctx context.Context) error {
|
||||||
|
// Simple test request to verify API key and connection
|
||||||
|
_, err := p.client.ListModels(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleOpenAIError converts OpenAI errors to provider errors
|
||||||
|
func (p *OpenAIProvider) handleOpenAIError(err error) *ProviderError {
|
||||||
|
errStr := err.Error()
|
||||||
|
|
||||||
|
if strings.Contains(errStr, "rate limit") {
|
||||||
|
return &ProviderError{
|
||||||
|
Code: "RATE_LIMIT_EXCEEDED",
|
||||||
|
Message: "OpenAI API rate limit exceeded",
|
||||||
|
Details: errStr,
|
||||||
|
Retryable: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(errStr, "quota") {
|
||||||
|
return &ProviderError{
|
||||||
|
Code: "QUOTA_EXCEEDED",
|
||||||
|
Message: "OpenAI API quota exceeded",
|
||||||
|
Details: errStr,
|
||||||
|
Retryable: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(errStr, "invalid_api_key") {
|
||||||
|
return &ProviderError{
|
||||||
|
Code: "INVALID_API_KEY",
|
||||||
|
Message: "Invalid OpenAI API key",
|
||||||
|
Details: errStr,
|
||||||
|
Retryable: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ProviderError{
|
||||||
|
Code: "API_ERROR",
|
||||||
|
Message: "OpenAI API error",
|
||||||
|
Details: errStr,
|
||||||
|
Retryable: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseResponseForActions extracts actions from the response text
|
||||||
|
func (p *OpenAIProvider) parseResponseForActions(response string, request *TaskRequest) ([]TaskAction, []Artifact) {
|
||||||
|
var actions []TaskAction
|
||||||
|
var artifacts []Artifact
|
||||||
|
|
||||||
|
// Create a basic task analysis action
|
||||||
|
action := TaskAction{
|
||||||
|
Type: "task_analysis",
|
||||||
|
Target: request.TaskTitle,
|
||||||
|
Content: response,
|
||||||
|
Result: "Task analyzed by OpenAI model",
|
||||||
|
Success: true,
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"agent_role": request.AgentRole,
|
||||||
|
"repository": request.Repository,
|
||||||
|
"model": p.config.DefaultModel,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
actions = append(actions, action)
|
||||||
|
|
||||||
|
return actions, artifacts
|
||||||
|
}
|
||||||
211
pkg/ai/provider.go
Normal file
211
pkg/ai/provider.go
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ModelProvider defines the interface for AI model providers
|
||||||
|
type ModelProvider interface {
|
||||||
|
// ExecuteTask executes a task using the AI model
|
||||||
|
ExecuteTask(ctx context.Context, request *TaskRequest) (*TaskResponse, error)
|
||||||
|
|
||||||
|
// GetCapabilities returns the capabilities supported by this provider
|
||||||
|
GetCapabilities() ProviderCapabilities
|
||||||
|
|
||||||
|
// ValidateConfig validates the provider configuration
|
||||||
|
ValidateConfig() error
|
||||||
|
|
||||||
|
// GetProviderInfo returns information about this provider
|
||||||
|
GetProviderInfo() ProviderInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// TaskRequest represents a request to execute a task
|
||||||
|
type TaskRequest struct {
|
||||||
|
// Task context and metadata
|
||||||
|
TaskID string `json:"task_id"`
|
||||||
|
AgentID string `json:"agent_id"`
|
||||||
|
AgentRole string `json:"agent_role"`
|
||||||
|
Repository string `json:"repository"`
|
||||||
|
TaskTitle string `json:"task_title"`
|
||||||
|
TaskDescription string `json:"task_description"`
|
||||||
|
TaskLabels []string `json:"task_labels"`
|
||||||
|
Priority int `json:"priority"`
|
||||||
|
Complexity int `json:"complexity"`
|
||||||
|
|
||||||
|
// Model configuration
|
||||||
|
ModelName string `json:"model_name"`
|
||||||
|
Temperature float32 `json:"temperature,omitempty"`
|
||||||
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
|
SystemPrompt string `json:"system_prompt,omitempty"`
|
||||||
|
|
||||||
|
// Execution context
|
||||||
|
WorkingDirectory string `json:"working_directory"`
|
||||||
|
RepositoryFiles []string `json:"repository_files,omitempty"`
|
||||||
|
Context map[string]interface{} `json:"context,omitempty"`
|
||||||
|
|
||||||
|
// Tool and MCP configuration
|
||||||
|
EnableTools bool `json:"enable_tools"`
|
||||||
|
MCPServers []string `json:"mcp_servers,omitempty"`
|
||||||
|
AllowedTools []string `json:"allowed_tools,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TaskResponse represents the response from task execution
|
||||||
|
type TaskResponse struct {
|
||||||
|
// Execution results
|
||||||
|
Success bool `json:"success"`
|
||||||
|
TaskID string `json:"task_id"`
|
||||||
|
AgentID string `json:"agent_id"`
|
||||||
|
ModelUsed string `json:"model_used"`
|
||||||
|
Provider string `json:"provider"`
|
||||||
|
|
||||||
|
// Response content
|
||||||
|
Response string `json:"response"`
|
||||||
|
Reasoning string `json:"reasoning,omitempty"`
|
||||||
|
Actions []TaskAction `json:"actions,omitempty"`
|
||||||
|
Artifacts []Artifact `json:"artifacts,omitempty"`
|
||||||
|
|
||||||
|
// Metadata
|
||||||
|
StartTime time.Time `json:"start_time"`
|
||||||
|
EndTime time.Time `json:"end_time"`
|
||||||
|
Duration time.Duration `json:"duration"`
|
||||||
|
TokensUsed TokenUsage `json:"tokens_used,omitempty"`
|
||||||
|
|
||||||
|
// Error information
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
ErrorCode string `json:"error_code,omitempty"`
|
||||||
|
Retryable bool `json:"retryable,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TaskAction represents an action taken during task execution
|
||||||
|
type TaskAction struct {
|
||||||
|
Type string `json:"type"` // file_create, file_edit, command_run, etc.
|
||||||
|
Target string `json:"target"` // file path, command, etc.
|
||||||
|
Content string `json:"content"` // file content, command args, etc.
|
||||||
|
Result string `json:"result"` // execution result
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Timestamp time.Time `json:"timestamp"`
|
||||||
|
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Artifact represents a file or output artifact from task execution
|
||||||
|
type Artifact struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Type string `json:"type"` // file, patch, log, etc.
|
||||||
|
Path string `json:"path"` // relative path in repository
|
||||||
|
Content string `json:"content"`
|
||||||
|
Size int64 `json:"size"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
Checksum string `json:"checksum"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenUsage represents token consumption for the request
|
||||||
|
type TokenUsage struct {
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProviderCapabilities defines what a provider supports
|
||||||
|
type ProviderCapabilities struct {
|
||||||
|
SupportsMCP bool `json:"supports_mcp"`
|
||||||
|
SupportsTools bool `json:"supports_tools"`
|
||||||
|
SupportsStreaming bool `json:"supports_streaming"`
|
||||||
|
SupportsFunctions bool `json:"supports_functions"`
|
||||||
|
MaxTokens int `json:"max_tokens"`
|
||||||
|
SupportedModels []string `json:"supported_models"`
|
||||||
|
SupportsImages bool `json:"supports_images"`
|
||||||
|
SupportsFiles bool `json:"supports_files"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProviderInfo contains metadata about the provider
|
||||||
|
type ProviderInfo struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Type string `json:"type"` // ollama, openai, resetdata
|
||||||
|
Version string `json:"version"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
DefaultModel string `json:"default_model"`
|
||||||
|
RequiresAPIKey bool `json:"requires_api_key"`
|
||||||
|
RateLimit int `json:"rate_limit"` // requests per minute
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProviderConfig contains configuration for a specific provider
|
||||||
|
type ProviderConfig struct {
|
||||||
|
Type string `yaml:"type" json:"type"` // ollama, openai, resetdata
|
||||||
|
Endpoint string `yaml:"endpoint" json:"endpoint"`
|
||||||
|
APIKey string `yaml:"api_key" json:"api_key,omitempty"`
|
||||||
|
DefaultModel string `yaml:"default_model" json:"default_model"`
|
||||||
|
Temperature float32 `yaml:"temperature" json:"temperature"`
|
||||||
|
MaxTokens int `yaml:"max_tokens" json:"max_tokens"`
|
||||||
|
Timeout time.Duration `yaml:"timeout" json:"timeout"`
|
||||||
|
RetryAttempts int `yaml:"retry_attempts" json:"retry_attempts"`
|
||||||
|
RetryDelay time.Duration `yaml:"retry_delay" json:"retry_delay"`
|
||||||
|
EnableTools bool `yaml:"enable_tools" json:"enable_tools"`
|
||||||
|
EnableMCP bool `yaml:"enable_mcp" json:"enable_mcp"`
|
||||||
|
MCPServers []string `yaml:"mcp_servers" json:"mcp_servers,omitempty"`
|
||||||
|
CustomHeaders map[string]string `yaml:"custom_headers" json:"custom_headers,omitempty"`
|
||||||
|
ExtraParams map[string]interface{} `yaml:"extra_params" json:"extra_params,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoleModelMapping defines model selection based on agent role
|
||||||
|
type RoleModelMapping struct {
|
||||||
|
DefaultProvider string `yaml:"default_provider" json:"default_provider"`
|
||||||
|
FallbackProvider string `yaml:"fallback_provider" json:"fallback_provider"`
|
||||||
|
Roles map[string]RoleConfig `yaml:"roles" json:"roles"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoleConfig defines model configuration for a specific role
|
||||||
|
type RoleConfig struct {
|
||||||
|
Provider string `yaml:"provider" json:"provider"`
|
||||||
|
Model string `yaml:"model" json:"model"`
|
||||||
|
Temperature float32 `yaml:"temperature" json:"temperature"`
|
||||||
|
MaxTokens int `yaml:"max_tokens" json:"max_tokens"`
|
||||||
|
SystemPrompt string `yaml:"system_prompt" json:"system_prompt"`
|
||||||
|
FallbackProvider string `yaml:"fallback_provider" json:"fallback_provider"`
|
||||||
|
FallbackModel string `yaml:"fallback_model" json:"fallback_model"`
|
||||||
|
EnableTools bool `yaml:"enable_tools" json:"enable_tools"`
|
||||||
|
EnableMCP bool `yaml:"enable_mcp" json:"enable_mcp"`
|
||||||
|
AllowedTools []string `yaml:"allowed_tools" json:"allowed_tools,omitempty"`
|
||||||
|
MCPServers []string `yaml:"mcp_servers" json:"mcp_servers,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Common error types
|
||||||
|
var (
|
||||||
|
ErrProviderNotFound = &ProviderError{Code: "PROVIDER_NOT_FOUND", Message: "Provider not found"}
|
||||||
|
ErrModelNotSupported = &ProviderError{Code: "MODEL_NOT_SUPPORTED", Message: "Model not supported by provider"}
|
||||||
|
ErrAPIKeyRequired = &ProviderError{Code: "API_KEY_REQUIRED", Message: "API key required for provider"}
|
||||||
|
ErrRateLimitExceeded = &ProviderError{Code: "RATE_LIMIT_EXCEEDED", Message: "Rate limit exceeded"}
|
||||||
|
ErrProviderUnavailable = &ProviderError{Code: "PROVIDER_UNAVAILABLE", Message: "Provider temporarily unavailable"}
|
||||||
|
ErrInvalidConfiguration = &ProviderError{Code: "INVALID_CONFIGURATION", Message: "Invalid provider configuration"}
|
||||||
|
ErrTaskExecutionFailed = &ProviderError{Code: "TASK_EXECUTION_FAILED", Message: "Task execution failed"}
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProviderError represents provider-specific errors
|
||||||
|
type ProviderError struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Details string `json:"details,omitempty"`
|
||||||
|
Retryable bool `json:"retryable"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ProviderError) Error() string {
|
||||||
|
if e.Details != "" {
|
||||||
|
return e.Message + ": " + e.Details
|
||||||
|
}
|
||||||
|
return e.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRetryable returns whether the error is retryable
|
||||||
|
func (e *ProviderError) IsRetryable() bool {
|
||||||
|
return e.Retryable
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProviderError creates a new provider error with details
|
||||||
|
func NewProviderError(base *ProviderError, details string) *ProviderError {
|
||||||
|
return &ProviderError{
|
||||||
|
Code: base.Code,
|
||||||
|
Message: base.Message,
|
||||||
|
Details: details,
|
||||||
|
Retryable: base.Retryable,
|
||||||
|
}
|
||||||
|
}
|
||||||
446
pkg/ai/provider_test.go
Normal file
446
pkg/ai/provider_test.go
Normal file
@@ -0,0 +1,446 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockProvider implements ModelProvider for testing
|
||||||
|
type MockProvider struct {
|
||||||
|
name string
|
||||||
|
capabilities ProviderCapabilities
|
||||||
|
shouldFail bool
|
||||||
|
response *TaskResponse
|
||||||
|
executeFunc func(ctx context.Context, request *TaskRequest) (*TaskResponse, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMockProvider(name string) *MockProvider {
|
||||||
|
return &MockProvider{
|
||||||
|
name: name,
|
||||||
|
capabilities: ProviderCapabilities{
|
||||||
|
SupportsMCP: true,
|
||||||
|
SupportsTools: true,
|
||||||
|
SupportsStreaming: true,
|
||||||
|
SupportsFunctions: false,
|
||||||
|
MaxTokens: 4096,
|
||||||
|
SupportedModels: []string{"test-model", "test-model-2"},
|
||||||
|
SupportsImages: false,
|
||||||
|
SupportsFiles: true,
|
||||||
|
},
|
||||||
|
response: &TaskResponse{
|
||||||
|
Success: true,
|
||||||
|
Response: "Mock response",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProvider) ExecuteTask(ctx context.Context, request *TaskRequest) (*TaskResponse, error) {
|
||||||
|
if m.executeFunc != nil {
|
||||||
|
return m.executeFunc(ctx, request)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.shouldFail {
|
||||||
|
return nil, NewProviderError(ErrTaskExecutionFailed, "mock execution failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
response := *m.response // Copy the response
|
||||||
|
response.TaskID = request.TaskID
|
||||||
|
response.AgentID = request.AgentID
|
||||||
|
response.Provider = m.name
|
||||||
|
response.StartTime = time.Now()
|
||||||
|
response.EndTime = time.Now().Add(100 * time.Millisecond)
|
||||||
|
response.Duration = response.EndTime.Sub(response.StartTime)
|
||||||
|
|
||||||
|
return &response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProvider) GetCapabilities() ProviderCapabilities {
|
||||||
|
return m.capabilities
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProvider) ValidateConfig() error {
|
||||||
|
if m.shouldFail {
|
||||||
|
return NewProviderError(ErrInvalidConfiguration, "mock config validation failed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProvider) GetProviderInfo() ProviderInfo {
|
||||||
|
return ProviderInfo{
|
||||||
|
Name: m.name,
|
||||||
|
Type: "mock",
|
||||||
|
Version: "1.0.0",
|
||||||
|
Endpoint: "mock://localhost",
|
||||||
|
DefaultModel: "test-model",
|
||||||
|
RequiresAPIKey: false,
|
||||||
|
RateLimit: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderError(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
err *ProviderError
|
||||||
|
expected string
|
||||||
|
retryable bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple error",
|
||||||
|
err: ErrProviderNotFound,
|
||||||
|
expected: "Provider not found",
|
||||||
|
retryable: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error with details",
|
||||||
|
err: NewProviderError(ErrRateLimitExceeded, "API rate limit of 1000/hour exceeded"),
|
||||||
|
expected: "Rate limit exceeded: API rate limit of 1000/hour exceeded",
|
||||||
|
retryable: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retryable error",
|
||||||
|
err: &ProviderError{
|
||||||
|
Code: "TEMPORARY_ERROR",
|
||||||
|
Message: "Temporary failure",
|
||||||
|
Retryable: true,
|
||||||
|
},
|
||||||
|
expected: "Temporary failure",
|
||||||
|
retryable: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.expected, tt.err.Error())
|
||||||
|
assert.Equal(t, tt.retryable, tt.err.IsRetryable())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskRequest(t *testing.T) {
|
||||||
|
request := &TaskRequest{
|
||||||
|
TaskID: "test-task-123",
|
||||||
|
AgentID: "agent-456",
|
||||||
|
AgentRole: "developer",
|
||||||
|
Repository: "test/repo",
|
||||||
|
TaskTitle: "Test Task",
|
||||||
|
TaskDescription: "A test task for unit testing",
|
||||||
|
TaskLabels: []string{"bug", "urgent"},
|
||||||
|
Priority: 8,
|
||||||
|
Complexity: 6,
|
||||||
|
ModelName: "test-model",
|
||||||
|
Temperature: 0.7,
|
||||||
|
MaxTokens: 4096,
|
||||||
|
EnableTools: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate required fields
|
||||||
|
assert.NotEmpty(t, request.TaskID)
|
||||||
|
assert.NotEmpty(t, request.AgentID)
|
||||||
|
assert.NotEmpty(t, request.AgentRole)
|
||||||
|
assert.NotEmpty(t, request.Repository)
|
||||||
|
assert.NotEmpty(t, request.TaskTitle)
|
||||||
|
assert.Greater(t, request.Priority, 0)
|
||||||
|
assert.Greater(t, request.Complexity, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskResponse(t *testing.T) {
|
||||||
|
startTime := time.Now()
|
||||||
|
endTime := startTime.Add(2 * time.Second)
|
||||||
|
|
||||||
|
response := &TaskResponse{
|
||||||
|
Success: true,
|
||||||
|
TaskID: "test-task-123",
|
||||||
|
AgentID: "agent-456",
|
||||||
|
ModelUsed: "test-model",
|
||||||
|
Provider: "mock",
|
||||||
|
Response: "Task completed successfully",
|
||||||
|
Actions: []TaskAction{
|
||||||
|
{
|
||||||
|
Type: "file_create",
|
||||||
|
Target: "test.go",
|
||||||
|
Content: "package main",
|
||||||
|
Result: "File created",
|
||||||
|
Success: true,
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Artifacts: []Artifact{
|
||||||
|
{
|
||||||
|
Name: "test.go",
|
||||||
|
Type: "file",
|
||||||
|
Path: "./test.go",
|
||||||
|
Content: "package main",
|
||||||
|
Size: 12,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
StartTime: startTime,
|
||||||
|
EndTime: endTime,
|
||||||
|
Duration: endTime.Sub(startTime),
|
||||||
|
TokensUsed: TokenUsage{
|
||||||
|
PromptTokens: 50,
|
||||||
|
CompletionTokens: 100,
|
||||||
|
TotalTokens: 150,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate response structure
|
||||||
|
assert.True(t, response.Success)
|
||||||
|
assert.NotEmpty(t, response.TaskID)
|
||||||
|
assert.NotEmpty(t, response.Provider)
|
||||||
|
assert.Len(t, response.Actions, 1)
|
||||||
|
assert.Len(t, response.Artifacts, 1)
|
||||||
|
assert.Equal(t, 2*time.Second, response.Duration)
|
||||||
|
assert.Equal(t, 150, response.TokensUsed.TotalTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskAction(t *testing.T) {
|
||||||
|
action := TaskAction{
|
||||||
|
Type: "file_edit",
|
||||||
|
Target: "main.go",
|
||||||
|
Content: "updated content",
|
||||||
|
Result: "File updated successfully",
|
||||||
|
Success: true,
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"line_count": 42,
|
||||||
|
"backup": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, "file_edit", action.Type)
|
||||||
|
assert.True(t, action.Success)
|
||||||
|
assert.NotNil(t, action.Metadata)
|
||||||
|
assert.Equal(t, 42, action.Metadata["line_count"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestArtifact(t *testing.T) {
|
||||||
|
artifact := Artifact{
|
||||||
|
Name: "output.log",
|
||||||
|
Type: "log",
|
||||||
|
Path: "/tmp/output.log",
|
||||||
|
Content: "Log content here",
|
||||||
|
Size: 16,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
Checksum: "abc123",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, "output.log", artifact.Name)
|
||||||
|
assert.Equal(t, "log", artifact.Type)
|
||||||
|
assert.Equal(t, int64(16), artifact.Size)
|
||||||
|
assert.NotEmpty(t, artifact.Checksum)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderCapabilities(t *testing.T) {
|
||||||
|
capabilities := ProviderCapabilities{
|
||||||
|
SupportsMCP: true,
|
||||||
|
SupportsTools: true,
|
||||||
|
SupportsStreaming: false,
|
||||||
|
SupportsFunctions: true,
|
||||||
|
MaxTokens: 8192,
|
||||||
|
SupportedModels: []string{"gpt-4", "gpt-3.5-turbo"},
|
||||||
|
SupportsImages: true,
|
||||||
|
SupportsFiles: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, capabilities.SupportsMCP)
|
||||||
|
assert.True(t, capabilities.SupportsTools)
|
||||||
|
assert.False(t, capabilities.SupportsStreaming)
|
||||||
|
assert.Equal(t, 8192, capabilities.MaxTokens)
|
||||||
|
assert.Len(t, capabilities.SupportedModels, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderInfo(t *testing.T) {
|
||||||
|
info := ProviderInfo{
|
||||||
|
Name: "Test Provider",
|
||||||
|
Type: "test",
|
||||||
|
Version: "1.0.0",
|
||||||
|
Endpoint: "https://api.test.com",
|
||||||
|
DefaultModel: "test-model",
|
||||||
|
RequiresAPIKey: true,
|
||||||
|
RateLimit: 1000,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, "Test Provider", info.Name)
|
||||||
|
assert.True(t, info.RequiresAPIKey)
|
||||||
|
assert.Equal(t, 1000, info.RateLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderConfig(t *testing.T) {
|
||||||
|
config := ProviderConfig{
|
||||||
|
Type: "test",
|
||||||
|
Endpoint: "https://api.test.com",
|
||||||
|
APIKey: "test-key",
|
||||||
|
DefaultModel: "test-model",
|
||||||
|
Temperature: 0.7,
|
||||||
|
MaxTokens: 4096,
|
||||||
|
Timeout: 300 * time.Second,
|
||||||
|
RetryAttempts: 3,
|
||||||
|
RetryDelay: 2 * time.Second,
|
||||||
|
EnableTools: true,
|
||||||
|
EnableMCP: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, "test", config.Type)
|
||||||
|
assert.Equal(t, float32(0.7), config.Temperature)
|
||||||
|
assert.Equal(t, 4096, config.MaxTokens)
|
||||||
|
assert.Equal(t, 300*time.Second, config.Timeout)
|
||||||
|
assert.True(t, config.EnableTools)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoleConfig(t *testing.T) {
|
||||||
|
roleConfig := RoleConfig{
|
||||||
|
Provider: "openai",
|
||||||
|
Model: "gpt-4",
|
||||||
|
Temperature: 0.3,
|
||||||
|
MaxTokens: 8192,
|
||||||
|
SystemPrompt: "You are a helpful assistant",
|
||||||
|
FallbackProvider: "ollama",
|
||||||
|
FallbackModel: "llama2",
|
||||||
|
EnableTools: true,
|
||||||
|
EnableMCP: false,
|
||||||
|
AllowedTools: []string{"file_ops", "code_analysis"},
|
||||||
|
MCPServers: []string{"file-server"},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, "openai", roleConfig.Provider)
|
||||||
|
assert.Equal(t, "gpt-4", roleConfig.Model)
|
||||||
|
assert.Equal(t, float32(0.3), roleConfig.Temperature)
|
||||||
|
assert.Len(t, roleConfig.AllowedTools, 2)
|
||||||
|
assert.True(t, roleConfig.EnableTools)
|
||||||
|
assert.False(t, roleConfig.EnableMCP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoleModelMapping(t *testing.T) {
|
||||||
|
mapping := RoleModelMapping{
|
||||||
|
DefaultProvider: "ollama",
|
||||||
|
FallbackProvider: "openai",
|
||||||
|
Roles: map[string]RoleConfig{
|
||||||
|
"developer": {
|
||||||
|
Provider: "ollama",
|
||||||
|
Model: "codellama",
|
||||||
|
Temperature: 0.3,
|
||||||
|
},
|
||||||
|
"reviewer": {
|
||||||
|
Provider: "openai",
|
||||||
|
Model: "gpt-4",
|
||||||
|
Temperature: 0.2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, "ollama", mapping.DefaultProvider)
|
||||||
|
assert.Len(t, mapping.Roles, 2)
|
||||||
|
|
||||||
|
devConfig, exists := mapping.Roles["developer"]
|
||||||
|
require.True(t, exists)
|
||||||
|
assert.Equal(t, "codellama", devConfig.Model)
|
||||||
|
assert.Equal(t, float32(0.3), devConfig.Temperature)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenUsage(t *testing.T) {
|
||||||
|
usage := TokenUsage{
|
||||||
|
PromptTokens: 100,
|
||||||
|
CompletionTokens: 200,
|
||||||
|
TotalTokens: 300,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 100, usage.PromptTokens)
|
||||||
|
assert.Equal(t, 200, usage.CompletionTokens)
|
||||||
|
assert.Equal(t, 300, usage.TotalTokens)
|
||||||
|
assert.Equal(t, usage.PromptTokens+usage.CompletionTokens, usage.TotalTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockProviderExecuteTask(t *testing.T) {
|
||||||
|
provider := NewMockProvider("test-provider")
|
||||||
|
|
||||||
|
request := &TaskRequest{
|
||||||
|
TaskID: "test-123",
|
||||||
|
AgentID: "agent-456",
|
||||||
|
AgentRole: "developer",
|
||||||
|
Repository: "test/repo",
|
||||||
|
TaskTitle: "Test Task",
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
response, err := provider.ExecuteTask(ctx, request)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, response.Success)
|
||||||
|
assert.Equal(t, "test-123", response.TaskID)
|
||||||
|
assert.Equal(t, "agent-456", response.AgentID)
|
||||||
|
assert.Equal(t, "test-provider", response.Provider)
|
||||||
|
assert.NotEmpty(t, response.Response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockProviderFailure(t *testing.T) {
|
||||||
|
provider := NewMockProvider("failing-provider")
|
||||||
|
provider.shouldFail = true
|
||||||
|
|
||||||
|
request := &TaskRequest{
|
||||||
|
TaskID: "test-123",
|
||||||
|
AgentID: "agent-456",
|
||||||
|
AgentRole: "developer",
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
_, err := provider.ExecuteTask(ctx, request)
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.IsType(t, &ProviderError{}, err)
|
||||||
|
|
||||||
|
providerErr := err.(*ProviderError)
|
||||||
|
assert.Equal(t, "TASK_EXECUTION_FAILED", providerErr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockProviderCustomExecuteFunc(t *testing.T) {
|
||||||
|
provider := NewMockProvider("custom-provider")
|
||||||
|
|
||||||
|
// Set custom execution function
|
||||||
|
provider.executeFunc = func(ctx context.Context, request *TaskRequest) (*TaskResponse, error) {
|
||||||
|
return &TaskResponse{
|
||||||
|
Success: true,
|
||||||
|
TaskID: request.TaskID,
|
||||||
|
Response: "Custom response: " + request.TaskTitle,
|
||||||
|
Provider: "custom-provider",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
request := &TaskRequest{
|
||||||
|
TaskID: "test-123",
|
||||||
|
TaskTitle: "Custom Task",
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
response, err := provider.ExecuteTask(ctx, request)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "Custom response: Custom Task", response.Response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockProviderCapabilities(t *testing.T) {
|
||||||
|
provider := NewMockProvider("test-provider")
|
||||||
|
|
||||||
|
capabilities := provider.GetCapabilities()
|
||||||
|
|
||||||
|
assert.True(t, capabilities.SupportsMCP)
|
||||||
|
assert.True(t, capabilities.SupportsTools)
|
||||||
|
assert.Equal(t, 4096, capabilities.MaxTokens)
|
||||||
|
assert.Len(t, capabilities.SupportedModels, 2)
|
||||||
|
assert.Contains(t, capabilities.SupportedModels, "test-model")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockProviderInfo(t *testing.T) {
|
||||||
|
provider := NewMockProvider("test-provider")
|
||||||
|
|
||||||
|
info := provider.GetProviderInfo()
|
||||||
|
|
||||||
|
assert.Equal(t, "test-provider", info.Name)
|
||||||
|
assert.Equal(t, "mock", info.Type)
|
||||||
|
assert.Equal(t, "test-model", info.DefaultModel)
|
||||||
|
assert.False(t, info.RequiresAPIKey)
|
||||||
|
}
|
||||||
500
pkg/ai/resetdata.go
Normal file
500
pkg/ai/resetdata.go
Normal file
@@ -0,0 +1,500 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ResetDataProvider implements ModelProvider for ResetData LaaS API
|
||||||
|
type ResetDataProvider struct {
|
||||||
|
config ProviderConfig
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetDataRequest represents a request to ResetData LaaS API
|
||||||
|
type ResetDataRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []ResetDataMessage `json:"messages"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
Temperature float32 `json:"temperature,omitempty"`
|
||||||
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
|
Stop []string `json:"stop,omitempty"`
|
||||||
|
TopP float32 `json:"top_p,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetDataMessage represents a message in the ResetData format
|
||||||
|
type ResetDataMessage struct {
|
||||||
|
Role string `json:"role"` // system, user, assistant
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetDataResponse represents a response from ResetData LaaS API
|
||||||
|
type ResetDataResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Choices []ResetDataChoice `json:"choices"`
|
||||||
|
Usage ResetDataUsage `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetDataChoice represents a choice in the response
|
||||||
|
type ResetDataChoice struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Message ResetDataMessage `json:"message"`
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetDataUsage represents token usage information
|
||||||
|
type ResetDataUsage struct {
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetDataModelsResponse represents available models response
|
||||||
|
type ResetDataModelsResponse struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Data []ResetDataModel `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetDataModel represents a model in ResetData
|
||||||
|
type ResetDataModel struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
OwnedBy string `json:"owned_by"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewResetDataProvider creates a new ResetData provider instance
|
||||||
|
func NewResetDataProvider(config ProviderConfig) *ResetDataProvider {
|
||||||
|
timeout := config.Timeout
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = 300 * time.Second // 5 minutes default for task execution
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ResetDataProvider{
|
||||||
|
config: config,
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Timeout: timeout,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteTask implements the ModelProvider interface for ResetData
|
||||||
|
func (p *ResetDataProvider) ExecuteTask(ctx context.Context, request *TaskRequest) (*TaskResponse, error) {
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
// Build messages for the chat completion
|
||||||
|
messages, err := p.buildChatMessages(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to build messages: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare the ResetData request
|
||||||
|
resetDataReq := ResetDataRequest{
|
||||||
|
Model: p.selectModel(request.ModelName),
|
||||||
|
Messages: messages,
|
||||||
|
Stream: false,
|
||||||
|
Temperature: p.getTemperature(request.Temperature),
|
||||||
|
MaxTokens: p.getMaxTokens(request.MaxTokens),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the request
|
||||||
|
response, err := p.makeRequest(ctx, "/v1/chat/completions", resetDataReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
endTime := time.Now()
|
||||||
|
|
||||||
|
// Process the response
|
||||||
|
if len(response.Choices) == 0 {
|
||||||
|
return nil, NewProviderError(ErrTaskExecutionFailed, "no response choices returned from ResetData")
|
||||||
|
}
|
||||||
|
|
||||||
|
choice := response.Choices[0]
|
||||||
|
responseText := choice.Message.Content
|
||||||
|
|
||||||
|
// Parse response for actions and artifacts
|
||||||
|
actions, artifacts := p.parseResponseForActions(responseText, request)
|
||||||
|
|
||||||
|
return &TaskResponse{
|
||||||
|
Success: true,
|
||||||
|
TaskID: request.TaskID,
|
||||||
|
AgentID: request.AgentID,
|
||||||
|
ModelUsed: response.Model,
|
||||||
|
Provider: "resetdata",
|
||||||
|
Response: responseText,
|
||||||
|
Actions: actions,
|
||||||
|
Artifacts: artifacts,
|
||||||
|
StartTime: startTime,
|
||||||
|
EndTime: endTime,
|
||||||
|
Duration: endTime.Sub(startTime),
|
||||||
|
TokensUsed: TokenUsage{
|
||||||
|
PromptTokens: response.Usage.PromptTokens,
|
||||||
|
CompletionTokens: response.Usage.CompletionTokens,
|
||||||
|
TotalTokens: response.Usage.TotalTokens,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCapabilities returns ResetData provider capabilities
|
||||||
|
func (p *ResetDataProvider) GetCapabilities() ProviderCapabilities {
|
||||||
|
return ProviderCapabilities{
|
||||||
|
SupportsMCP: p.config.EnableMCP,
|
||||||
|
SupportsTools: p.config.EnableTools,
|
||||||
|
SupportsStreaming: true,
|
||||||
|
SupportsFunctions: false, // ResetData LaaS doesn't support function calling
|
||||||
|
MaxTokens: p.config.MaxTokens,
|
||||||
|
SupportedModels: p.getSupportedModels(),
|
||||||
|
SupportsImages: false, // Most ResetData models don't support images
|
||||||
|
SupportsFiles: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateConfig validates the ResetData provider configuration
|
||||||
|
func (p *ResetDataProvider) ValidateConfig() error {
|
||||||
|
if p.config.APIKey == "" {
|
||||||
|
return NewProviderError(ErrAPIKeyRequired, "API key is required for ResetData provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.config.Endpoint == "" {
|
||||||
|
return NewProviderError(ErrInvalidConfiguration, "endpoint is required for ResetData provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.config.DefaultModel == "" {
|
||||||
|
return NewProviderError(ErrInvalidConfiguration, "default_model is required for ResetData provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test the API connection
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := p.testConnection(ctx); err != nil {
|
||||||
|
return NewProviderError(ErrProviderUnavailable, fmt.Sprintf("failed to connect to ResetData: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProviderInfo returns information about the ResetData provider
|
||||||
|
func (p *ResetDataProvider) GetProviderInfo() ProviderInfo {
|
||||||
|
return ProviderInfo{
|
||||||
|
Name: "ResetData",
|
||||||
|
Type: "resetdata",
|
||||||
|
Version: "1.0.0",
|
||||||
|
Endpoint: p.config.Endpoint,
|
||||||
|
DefaultModel: p.config.DefaultModel,
|
||||||
|
RequiresAPIKey: true,
|
||||||
|
RateLimit: 600, // 10 requests per second typical limit
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildChatMessages constructs messages for the ResetData chat completion
|
||||||
|
func (p *ResetDataProvider) buildChatMessages(request *TaskRequest) ([]ResetDataMessage, error) {
|
||||||
|
var messages []ResetDataMessage
|
||||||
|
|
||||||
|
// System message
|
||||||
|
systemPrompt := p.getSystemPrompt(request)
|
||||||
|
if systemPrompt != "" {
|
||||||
|
messages = append(messages, ResetDataMessage{
|
||||||
|
Role: "system",
|
||||||
|
Content: systemPrompt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// User message with task details
|
||||||
|
userPrompt, err := p.buildTaskPrompt(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = append(messages, ResetDataMessage{
|
||||||
|
Role: "user",
|
||||||
|
Content: userPrompt,
|
||||||
|
})
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildTaskPrompt constructs a comprehensive prompt for task execution
|
||||||
|
func (p *ResetDataProvider) buildTaskPrompt(request *TaskRequest) (string, error) {
|
||||||
|
var prompt strings.Builder
|
||||||
|
|
||||||
|
prompt.WriteString(fmt.Sprintf("Acting as a %s agent, analyze and work on this task:\n\n",
|
||||||
|
request.AgentRole))
|
||||||
|
|
||||||
|
prompt.WriteString(fmt.Sprintf("**Repository:** %s\n", request.Repository))
|
||||||
|
prompt.WriteString(fmt.Sprintf("**Task Title:** %s\n", request.TaskTitle))
|
||||||
|
prompt.WriteString(fmt.Sprintf("**Description:**\n%s\n\n", request.TaskDescription))
|
||||||
|
|
||||||
|
if len(request.TaskLabels) > 0 {
|
||||||
|
prompt.WriteString(fmt.Sprintf("**Labels:** %s\n", strings.Join(request.TaskLabels, ", ")))
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt.WriteString(fmt.Sprintf("**Priority:** %d/10 | **Complexity:** %d/10\n\n",
|
||||||
|
request.Priority, request.Complexity))
|
||||||
|
|
||||||
|
if request.WorkingDirectory != "" {
|
||||||
|
prompt.WriteString(fmt.Sprintf("**Working Directory:** %s\n", request.WorkingDirectory))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(request.RepositoryFiles) > 0 {
|
||||||
|
prompt.WriteString("**Relevant Files:**\n")
|
||||||
|
for _, file := range request.RepositoryFiles {
|
||||||
|
prompt.WriteString(fmt.Sprintf("- %s\n", file))
|
||||||
|
}
|
||||||
|
prompt.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add role-specific instructions
|
||||||
|
prompt.WriteString(p.getRoleSpecificInstructions(request.AgentRole))
|
||||||
|
|
||||||
|
prompt.WriteString("\nProvide a detailed analysis and implementation plan. ")
|
||||||
|
prompt.WriteString("Include specific steps, code changes, and any commands that need to be executed. ")
|
||||||
|
prompt.WriteString("Focus on delivering actionable results that address the task requirements completely.")
|
||||||
|
|
||||||
|
return prompt.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRoleSpecificInstructions returns instructions specific to the agent role
|
||||||
|
func (p *ResetDataProvider) getRoleSpecificInstructions(role string) string {
|
||||||
|
switch strings.ToLower(role) {
|
||||||
|
case "developer":
|
||||||
|
return `**Developer Focus Areas:**
|
||||||
|
- Implement robust, well-tested code solutions
|
||||||
|
- Follow coding standards and best practices
|
||||||
|
- Ensure proper error handling and edge case coverage
|
||||||
|
- Write clear documentation and comments
|
||||||
|
- Consider performance, security, and maintainability`
|
||||||
|
|
||||||
|
case "reviewer":
|
||||||
|
return `**Code Review Focus Areas:**
|
||||||
|
- Evaluate code quality, style, and best practices
|
||||||
|
- Identify potential bugs, security issues, and performance bottlenecks
|
||||||
|
- Check test coverage and test quality
|
||||||
|
- Verify documentation completeness and accuracy
|
||||||
|
- Suggest refactoring and improvement opportunities`
|
||||||
|
|
||||||
|
case "architect":
|
||||||
|
return `**Architecture Focus Areas:**
|
||||||
|
- Design scalable and maintainable system components
|
||||||
|
- Make informed decisions about technologies and patterns
|
||||||
|
- Define clear interfaces and integration points
|
||||||
|
- Consider scalability, security, and performance requirements
|
||||||
|
- Document architectural decisions and trade-offs`
|
||||||
|
|
||||||
|
case "tester":
|
||||||
|
return `**Testing Focus Areas:**
|
||||||
|
- Design comprehensive test strategies and test cases
|
||||||
|
- Implement automated tests at multiple levels
|
||||||
|
- Identify edge cases and failure scenarios
|
||||||
|
- Set up continuous testing and quality assurance
|
||||||
|
- Validate requirements and acceptance criteria`
|
||||||
|
|
||||||
|
default:
|
||||||
|
return `**General Focus Areas:**
|
||||||
|
- Understand requirements and constraints thoroughly
|
||||||
|
- Apply software engineering best practices
|
||||||
|
- Provide clear, actionable recommendations
|
||||||
|
- Consider long-term maintainability and extensibility`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectModel chooses the appropriate ResetData model
|
||||||
|
func (p *ResetDataProvider) selectModel(requestedModel string) string {
|
||||||
|
if requestedModel != "" {
|
||||||
|
return requestedModel
|
||||||
|
}
|
||||||
|
return p.config.DefaultModel
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTemperature returns the temperature setting
|
||||||
|
func (p *ResetDataProvider) getTemperature(requestTemp float32) float32 {
|
||||||
|
if requestTemp > 0 {
|
||||||
|
return requestTemp
|
||||||
|
}
|
||||||
|
if p.config.Temperature > 0 {
|
||||||
|
return p.config.Temperature
|
||||||
|
}
|
||||||
|
return 0.7 // Default temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
// getMaxTokens returns the max tokens setting
|
||||||
|
func (p *ResetDataProvider) getMaxTokens(requestTokens int) int {
|
||||||
|
if requestTokens > 0 {
|
||||||
|
return requestTokens
|
||||||
|
}
|
||||||
|
if p.config.MaxTokens > 0 {
|
||||||
|
return p.config.MaxTokens
|
||||||
|
}
|
||||||
|
return 4096 // Default max tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSystemPrompt constructs the system prompt
|
||||||
|
func (p *ResetDataProvider) getSystemPrompt(request *TaskRequest) string {
|
||||||
|
if request.SystemPrompt != "" {
|
||||||
|
return request.SystemPrompt
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf(`You are an expert software development AI assistant working as a %s agent
|
||||||
|
in the CHORUS autonomous development system.
|
||||||
|
|
||||||
|
Your expertise includes:
|
||||||
|
- Software architecture and design patterns
|
||||||
|
- Code implementation across multiple programming languages
|
||||||
|
- Testing strategies and quality assurance
|
||||||
|
- DevOps and deployment practices
|
||||||
|
- Security and performance optimization
|
||||||
|
|
||||||
|
Provide detailed, practical solutions with specific implementation steps.
|
||||||
|
Focus on delivering high-quality, production-ready results.`, request.AgentRole)
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeRequest makes an HTTP request to the ResetData API
|
||||||
|
func (p *ResetDataProvider) makeRequest(ctx context.Context, endpoint string, request interface{}) (*ResetDataResponse, error) {
|
||||||
|
requestJSON, err := json.Marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to marshal request: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
url := strings.TrimSuffix(p.config.Endpoint, "/") + endpoint
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(requestJSON))
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to create request: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set required headers
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+p.config.APIKey)
|
||||||
|
|
||||||
|
// Add custom headers if configured
|
||||||
|
for key, value := range p.config.CustomHeaders {
|
||||||
|
req.Header.Set(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := p.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewProviderError(ErrProviderUnavailable, fmt.Sprintf("request failed: %v", err))
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to read response: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, p.handleHTTPError(resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resetDataResp ResetDataResponse
|
||||||
|
if err := json.Unmarshal(body, &resetDataResp); err != nil {
|
||||||
|
return nil, NewProviderError(ErrTaskExecutionFailed, fmt.Sprintf("failed to parse response: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &resetDataResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// testConnection tests the connection to ResetData API
|
||||||
|
func (p *ResetDataProvider) testConnection(ctx context.Context) error {
|
||||||
|
url := strings.TrimSuffix(p.config.Endpoint, "/") + "/v1/models"
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Authorization", "Bearer "+p.config.APIKey)
|
||||||
|
|
||||||
|
resp, err := p.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return fmt.Errorf("API test failed with status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSupportedModels returns a list of supported ResetData models
|
||||||
|
func (p *ResetDataProvider) getSupportedModels() []string {
|
||||||
|
// Common models available through ResetData LaaS
|
||||||
|
return []string{
|
||||||
|
"llama3.1:8b", "llama3.1:70b",
|
||||||
|
"mistral:7b", "mixtral:8x7b",
|
||||||
|
"qwen2:7b", "qwen2:72b",
|
||||||
|
"gemma:7b", "gemma2:9b",
|
||||||
|
"codellama:7b", "codellama:13b",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleHTTPError converts HTTP errors to provider errors
|
||||||
|
func (p *ResetDataProvider) handleHTTPError(statusCode int, body []byte) *ProviderError {
|
||||||
|
bodyStr := string(body)
|
||||||
|
|
||||||
|
switch statusCode {
|
||||||
|
case http.StatusUnauthorized:
|
||||||
|
return &ProviderError{
|
||||||
|
Code: "UNAUTHORIZED",
|
||||||
|
Message: "Invalid ResetData API key",
|
||||||
|
Details: bodyStr,
|
||||||
|
Retryable: false,
|
||||||
|
}
|
||||||
|
case http.StatusTooManyRequests:
|
||||||
|
return &ProviderError{
|
||||||
|
Code: "RATE_LIMIT_EXCEEDED",
|
||||||
|
Message: "ResetData API rate limit exceeded",
|
||||||
|
Details: bodyStr,
|
||||||
|
Retryable: true,
|
||||||
|
}
|
||||||
|
case http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable:
|
||||||
|
return &ProviderError{
|
||||||
|
Code: "SERVICE_UNAVAILABLE",
|
||||||
|
Message: "ResetData API service unavailable",
|
||||||
|
Details: bodyStr,
|
||||||
|
Retryable: true,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return &ProviderError{
|
||||||
|
Code: "API_ERROR",
|
||||||
|
Message: fmt.Sprintf("ResetData API error (status %d)", statusCode),
|
||||||
|
Details: bodyStr,
|
||||||
|
Retryable: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseResponseForActions extracts actions from the response text
|
||||||
|
func (p *ResetDataProvider) parseResponseForActions(response string, request *TaskRequest) ([]TaskAction, []Artifact) {
|
||||||
|
var actions []TaskAction
|
||||||
|
var artifacts []Artifact
|
||||||
|
|
||||||
|
// Create a basic task analysis action
|
||||||
|
action := TaskAction{
|
||||||
|
Type: "task_analysis",
|
||||||
|
Target: request.TaskTitle,
|
||||||
|
Content: response,
|
||||||
|
Result: "Task analyzed by ResetData model",
|
||||||
|
Success: true,
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"agent_role": request.AgentRole,
|
||||||
|
"repository": request.Repository,
|
||||||
|
"model": p.config.DefaultModel,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
actions = append(actions, action)
|
||||||
|
|
||||||
|
return actions, artifacts
|
||||||
|
}
|
||||||
353
pkg/bootstrap/pool_manager.go
Normal file
353
pkg/bootstrap/pool_manager.go
Normal file
@@ -0,0 +1,353 @@
|
|||||||
|
package bootstrap
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"math/rand"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/libp2p/go-libp2p/core/host"
|
||||||
|
"github.com/libp2p/go-libp2p/core/peer"
|
||||||
|
"github.com/multiformats/go-multiaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BootstrapPool manages a pool of bootstrap peers for DHT joining
|
||||||
|
type BootstrapPool struct {
|
||||||
|
peers []peer.AddrInfo
|
||||||
|
dialsPerSecond int
|
||||||
|
maxConcurrent int
|
||||||
|
staggerDelay time.Duration
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// BootstrapConfig represents the JSON configuration for bootstrap peers
|
||||||
|
type BootstrapConfig struct {
|
||||||
|
Peers []BootstrapPeer `json:"peers"`
|
||||||
|
Meta BootstrapMeta `json:"meta,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BootstrapPeer represents a single bootstrap peer
|
||||||
|
type BootstrapPeer struct {
|
||||||
|
ID string `json:"id"` // Peer ID
|
||||||
|
Addresses []string `json:"addresses"` // Multiaddresses
|
||||||
|
Priority int `json:"priority"` // Priority (higher = more likely to be selected)
|
||||||
|
Healthy bool `json:"healthy"` // Health status
|
||||||
|
LastSeen string `json:"last_seen"` // Last seen timestamp
|
||||||
|
}
|
||||||
|
|
||||||
|
// BootstrapMeta contains metadata about the bootstrap configuration
|
||||||
|
type BootstrapMeta struct {
|
||||||
|
UpdatedAt string `json:"updated_at"`
|
||||||
|
Version int `json:"version"`
|
||||||
|
ClusterID string `json:"cluster_id"`
|
||||||
|
TotalPeers int `json:"total_peers"`
|
||||||
|
HealthyPeers int `json:"healthy_peers"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BootstrapSubset represents a subset of peers assigned to a replica
|
||||||
|
type BootstrapSubset struct {
|
||||||
|
Peers []peer.AddrInfo `json:"peers"`
|
||||||
|
StaggerDelayMS int `json:"stagger_delay_ms"`
|
||||||
|
AssignedAt time.Time `json:"assigned_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBootstrapPool creates a new bootstrap pool manager
|
||||||
|
func NewBootstrapPool(dialsPerSecond, maxConcurrent int, staggerMS int) *BootstrapPool {
|
||||||
|
return &BootstrapPool{
|
||||||
|
peers: []peer.AddrInfo{},
|
||||||
|
dialsPerSecond: dialsPerSecond,
|
||||||
|
maxConcurrent: maxConcurrent,
|
||||||
|
staggerDelay: time.Duration(staggerMS) * time.Millisecond,
|
||||||
|
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadFromFile loads bootstrap configuration from a JSON file
|
||||||
|
func (bp *BootstrapPool) LoadFromFile(filePath string) error {
|
||||||
|
if filePath == "" {
|
||||||
|
return nil // No file configured
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := ioutil.ReadFile(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read bootstrap file %s: %w", filePath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return bp.loadFromJSON(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadFromURL loads bootstrap configuration from a URL (WHOOSH endpoint)
|
||||||
|
func (bp *BootstrapPool) LoadFromURL(ctx context.Context, url string) error {
|
||||||
|
if url == "" {
|
||||||
|
return nil // No URL configured
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create bootstrap request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := bp.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("bootstrap request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("bootstrap request failed with status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read bootstrap response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return bp.loadFromJSON(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadFromJSON parses JSON bootstrap configuration
|
||||||
|
func (bp *BootstrapPool) loadFromJSON(data []byte) error {
|
||||||
|
var config BootstrapConfig
|
||||||
|
if err := json.Unmarshal(data, &config); err != nil {
|
||||||
|
return fmt.Errorf("failed to parse bootstrap JSON: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert bootstrap peers to AddrInfo
|
||||||
|
var peers []peer.AddrInfo
|
||||||
|
for _, bsPeer := range config.Peers {
|
||||||
|
// Only include healthy peers
|
||||||
|
if !bsPeer.Healthy {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse peer ID
|
||||||
|
peerID, err := peer.Decode(bsPeer.ID)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("⚠️ Invalid peer ID %s: %v\n", bsPeer.ID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse multiaddresses
|
||||||
|
var addrs []multiaddr.Multiaddr
|
||||||
|
for _, addrStr := range bsPeer.Addresses {
|
||||||
|
addr, err := multiaddr.NewMultiaddr(addrStr)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("⚠️ Invalid multiaddress %s: %v\n", addrStr, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
addrs = append(addrs, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(addrs) > 0 {
|
||||||
|
peers = append(peers, peer.AddrInfo{
|
||||||
|
ID: peerID,
|
||||||
|
Addrs: addrs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bp.peers = peers
|
||||||
|
fmt.Printf("📋 Loaded %d healthy bootstrap peers from configuration\n", len(peers))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadFromEnvironment loads bootstrap configuration from environment variables
|
||||||
|
func (bp *BootstrapPool) LoadFromEnvironment() error {
|
||||||
|
// Try loading from file first
|
||||||
|
if bootstrapFile := os.Getenv("BOOTSTRAP_JSON"); bootstrapFile != "" {
|
||||||
|
if err := bp.LoadFromFile(bootstrapFile); err != nil {
|
||||||
|
fmt.Printf("⚠️ Failed to load bootstrap from file: %v\n", err)
|
||||||
|
} else {
|
||||||
|
return nil // Successfully loaded from file
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try loading from URL
|
||||||
|
if bootstrapURL := os.Getenv("BOOTSTRAP_URL"); bootstrapURL != "" {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := bp.LoadFromURL(ctx, bootstrapURL); err != nil {
|
||||||
|
fmt.Printf("⚠️ Failed to load bootstrap from URL: %v\n", err)
|
||||||
|
} else {
|
||||||
|
return nil // Successfully loaded from URL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to legacy environment variable
|
||||||
|
if bootstrapPeersEnv := os.Getenv("CHORUS_BOOTSTRAP_PEERS"); bootstrapPeersEnv != "" {
|
||||||
|
return bp.loadFromLegacyEnv(bootstrapPeersEnv)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil // No bootstrap configuration found
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadFromLegacyEnv loads from comma-separated multiaddress list
|
||||||
|
func (bp *BootstrapPool) loadFromLegacyEnv(peersEnv string) error {
|
||||||
|
peerStrs := strings.Split(peersEnv, ",")
|
||||||
|
var peers []peer.AddrInfo
|
||||||
|
|
||||||
|
for _, peerStr := range peerStrs {
|
||||||
|
peerStr = strings.TrimSpace(peerStr)
|
||||||
|
if peerStr == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse multiaddress
|
||||||
|
addr, err := multiaddr.NewMultiaddr(peerStr)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("⚠️ Invalid bootstrap peer %s: %v\n", peerStr, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract peer info
|
||||||
|
info, err := peer.AddrInfoFromP2pAddr(addr)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("⚠️ Failed to parse peer info from %s: %v\n", peerStr, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
peers = append(peers, *info)
|
||||||
|
}
|
||||||
|
|
||||||
|
bp.peers = peers
|
||||||
|
fmt.Printf("📋 Loaded %d bootstrap peers from legacy environment\n", len(peers))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSubset returns a subset of bootstrap peers for a replica
|
||||||
|
func (bp *BootstrapPool) GetSubset(count int) BootstrapSubset {
|
||||||
|
if len(bp.peers) == 0 {
|
||||||
|
return BootstrapSubset{
|
||||||
|
Peers: []peer.AddrInfo{},
|
||||||
|
StaggerDelayMS: 0,
|
||||||
|
AssignedAt: time.Now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure count doesn't exceed available peers
|
||||||
|
if count > len(bp.peers) {
|
||||||
|
count = len(bp.peers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Randomly select peers from the pool
|
||||||
|
selectedPeers := make([]peer.AddrInfo, 0, count)
|
||||||
|
indices := rand.Perm(len(bp.peers))
|
||||||
|
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
selectedPeers = append(selectedPeers, bp.peers[indices[i]])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate random stagger delay (0 to configured max)
|
||||||
|
staggerMS := 0
|
||||||
|
if bp.staggerDelay > 0 {
|
||||||
|
staggerMS = rand.Intn(int(bp.staggerDelay.Milliseconds()))
|
||||||
|
}
|
||||||
|
|
||||||
|
return BootstrapSubset{
|
||||||
|
Peers: selectedPeers,
|
||||||
|
StaggerDelayMS: staggerMS,
|
||||||
|
AssignedAt: time.Now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectWithRateLimit connects to bootstrap peers with rate limiting
|
||||||
|
func (bp *BootstrapPool) ConnectWithRateLimit(ctx context.Context, h host.Host, subset BootstrapSubset) error {
|
||||||
|
if len(subset.Peers) == 0 {
|
||||||
|
return nil // No peers to connect to
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply stagger delay
|
||||||
|
if subset.StaggerDelayMS > 0 {
|
||||||
|
delay := time.Duration(subset.StaggerDelayMS) * time.Millisecond
|
||||||
|
fmt.Printf("⏱️ Applying join stagger delay: %v\n", delay)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-time.After(delay):
|
||||||
|
// Continue after delay
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create rate limiter for dials
|
||||||
|
ticker := time.NewTicker(time.Second / time.Duration(bp.dialsPerSecond))
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
// Semaphore for concurrent dials
|
||||||
|
semaphore := make(chan struct{}, bp.maxConcurrent)
|
||||||
|
|
||||||
|
// Connect to each peer with rate limiting
|
||||||
|
for i, peerInfo := range subset.Peers {
|
||||||
|
// Wait for rate limiter
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-ticker.C:
|
||||||
|
// Rate limit satisfied
|
||||||
|
}
|
||||||
|
|
||||||
|
// Acquire semaphore
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case semaphore <- struct{}{}:
|
||||||
|
// Semaphore acquired
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect to peer in goroutine
|
||||||
|
go func(info peer.AddrInfo, index int) {
|
||||||
|
defer func() { <-semaphore }() // Release semaphore
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := h.Connect(ctx, info); err != nil {
|
||||||
|
fmt.Printf("⚠️ Failed to connect to bootstrap peer %s (%d/%d): %v\n",
|
||||||
|
info.ID.ShortString(), index+1, len(subset.Peers), err)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("🔗 Connected to bootstrap peer %s (%d/%d)\n",
|
||||||
|
info.ID.ShortString(), index+1, len(subset.Peers))
|
||||||
|
}
|
||||||
|
}(peerInfo, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all connections to complete or timeout
|
||||||
|
for i := 0; i < bp.maxConcurrent && i < len(subset.Peers); i++ {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case semaphore <- struct{}{}:
|
||||||
|
<-semaphore // Immediately release
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerCount returns the number of available bootstrap peers
|
||||||
|
func (bp *BootstrapPool) GetPeerCount() int {
|
||||||
|
return len(bp.peers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeers returns all bootstrap peers (for debugging)
|
||||||
|
func (bp *BootstrapPool) GetPeers() []peer.AddrInfo {
|
||||||
|
return bp.peers
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStats returns bootstrap pool statistics
|
||||||
|
func (bp *BootstrapPool) GetStats() map[string]interface{} {
|
||||||
|
return map[string]interface{}{
|
||||||
|
"peer_count": len(bp.peers),
|
||||||
|
"dials_per_second": bp.dialsPerSecond,
|
||||||
|
"max_concurrent": bp.maxConcurrent,
|
||||||
|
"stagger_delay_ms": bp.staggerDelay.Milliseconds(),
|
||||||
|
}
|
||||||
|
}
|
||||||
517
pkg/config/assignment.go
Normal file
517
pkg/config/assignment.go
Normal file
@@ -0,0 +1,517 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RuntimeConfig manages runtime configuration with assignment overrides
|
||||||
|
type RuntimeConfig struct {
|
||||||
|
Base *Config `json:"base"`
|
||||||
|
Override *AssignmentConfig `json:"override"`
|
||||||
|
mu sync.RWMutex
|
||||||
|
reloadCh chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssignmentConfig represents runtime assignment from WHOOSH
|
||||||
|
type AssignmentConfig struct {
|
||||||
|
// Assignment metadata
|
||||||
|
AssignmentID string `json:"assignment_id"`
|
||||||
|
TaskSlot string `json:"task_slot"`
|
||||||
|
TaskID string `json:"task_id"`
|
||||||
|
ClusterID string `json:"cluster_id"`
|
||||||
|
AssignedAt time.Time `json:"assigned_at"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at,omitempty"`
|
||||||
|
|
||||||
|
// Agent configuration overrides
|
||||||
|
Agent *AgentConfig `json:"agent,omitempty"`
|
||||||
|
Network *NetworkConfig `json:"network,omitempty"`
|
||||||
|
AI *AIConfig `json:"ai,omitempty"`
|
||||||
|
Logging *LoggingConfig `json:"logging,omitempty"`
|
||||||
|
|
||||||
|
// Bootstrap configuration for scaling
|
||||||
|
BootstrapPeers []string `json:"bootstrap_peers,omitempty"`
|
||||||
|
JoinStagger int `json:"join_stagger_ms,omitempty"`
|
||||||
|
|
||||||
|
// Runtime capabilities
|
||||||
|
RuntimeCapabilities []string `json:"runtime_capabilities,omitempty"`
|
||||||
|
|
||||||
|
// Key derivation for encryption
|
||||||
|
RoleKey string `json:"role_key,omitempty"`
|
||||||
|
ClusterSecret string `json:"cluster_secret,omitempty"`
|
||||||
|
|
||||||
|
// Custom fields
|
||||||
|
Custom map[string]interface{} `json:"custom,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssignmentRequest represents a request for assignment from WHOOSH
|
||||||
|
type AssignmentRequest struct {
|
||||||
|
ClusterID string `json:"cluster_id"`
|
||||||
|
TaskSlot string `json:"task_slot,omitempty"`
|
||||||
|
TaskID string `json:"task_id,omitempty"`
|
||||||
|
AgentID string `json:"agent_id"`
|
||||||
|
NodeID string `json:"node_id"`
|
||||||
|
Timestamp time.Time `json:"timestamp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRuntimeConfig creates a new runtime configuration manager
|
||||||
|
func NewRuntimeConfig(baseConfig *Config) *RuntimeConfig {
|
||||||
|
return &RuntimeConfig{
|
||||||
|
Base: baseConfig,
|
||||||
|
Override: nil,
|
||||||
|
reloadCh: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the effective configuration value, with override taking precedence
|
||||||
|
func (rc *RuntimeConfig) Get(field string) interface{} {
|
||||||
|
rc.mu.RLock()
|
||||||
|
defer rc.mu.RUnlock()
|
||||||
|
|
||||||
|
// Try override first
|
||||||
|
if rc.Override != nil {
|
||||||
|
if value := rc.getFromAssignment(field); value != nil {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to base configuration
|
||||||
|
return rc.getFromBase(field)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConfig returns a merged configuration with overrides applied
|
||||||
|
func (rc *RuntimeConfig) GetConfig() *Config {
|
||||||
|
rc.mu.RLock()
|
||||||
|
defer rc.mu.RUnlock()
|
||||||
|
|
||||||
|
if rc.Override == nil {
|
||||||
|
return rc.Base
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a copy of base config
|
||||||
|
merged := *rc.Base
|
||||||
|
|
||||||
|
// Apply overrides
|
||||||
|
if rc.Override.Agent != nil {
|
||||||
|
rc.mergeAgentConfig(&merged.Agent, rc.Override.Agent)
|
||||||
|
}
|
||||||
|
if rc.Override.Network != nil {
|
||||||
|
rc.mergeNetworkConfig(&merged.Network, rc.Override.Network)
|
||||||
|
}
|
||||||
|
if rc.Override.AI != nil {
|
||||||
|
rc.mergeAIConfig(&merged.AI, rc.Override.AI)
|
||||||
|
}
|
||||||
|
if rc.Override.Logging != nil {
|
||||||
|
rc.mergeLoggingConfig(&merged.Logging, rc.Override.Logging)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &merged
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadAssignment fetches assignment from WHOOSH and applies it
|
||||||
|
func (rc *RuntimeConfig) LoadAssignment(ctx context.Context, assignURL string) error {
|
||||||
|
if assignURL == "" {
|
||||||
|
return nil // No assignment URL configured
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build assignment request
|
||||||
|
agentID := rc.Base.Agent.ID
|
||||||
|
if agentID == "" {
|
||||||
|
agentID = "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
req := AssignmentRequest{
|
||||||
|
ClusterID: rc.Base.License.ClusterID,
|
||||||
|
TaskSlot: os.Getenv("TASK_SLOT"),
|
||||||
|
TaskID: os.Getenv("TASK_ID"),
|
||||||
|
AgentID: agentID,
|
||||||
|
NodeID: os.Getenv("NODE_ID"),
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make HTTP request to WHOOSH
|
||||||
|
assignment, err := rc.fetchAssignment(ctx, assignURL, req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to fetch assignment: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply assignment
|
||||||
|
rc.mu.Lock()
|
||||||
|
rc.Override = assignment
|
||||||
|
rc.mu.Unlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartReloadHandler starts a signal handler for SIGHUP configuration reloads
|
||||||
|
func (rc *RuntimeConfig) StartReloadHandler(ctx context.Context, assignURL string) {
|
||||||
|
sigCh := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigCh, syscall.SIGHUP)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-sigCh:
|
||||||
|
fmt.Println("📡 Received SIGHUP, reloading assignment configuration...")
|
||||||
|
if err := rc.LoadAssignment(ctx, assignURL); err != nil {
|
||||||
|
fmt.Printf("❌ Failed to reload assignment: %v\n", err)
|
||||||
|
} else {
|
||||||
|
fmt.Println("✅ Assignment configuration reloaded successfully")
|
||||||
|
}
|
||||||
|
case <-rc.reloadCh:
|
||||||
|
// Manual reload trigger
|
||||||
|
if err := rc.LoadAssignment(ctx, assignURL); err != nil {
|
||||||
|
fmt.Printf("❌ Failed to reload assignment: %v\n", err)
|
||||||
|
} else {
|
||||||
|
fmt.Println("✅ Assignment configuration reloaded successfully")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reload triggers a manual configuration reload
|
||||||
|
func (rc *RuntimeConfig) Reload() {
|
||||||
|
select {
|
||||||
|
case rc.reloadCh <- struct{}{}:
|
||||||
|
default:
|
||||||
|
// Channel full, reload already pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetchAssignment makes HTTP request to WHOOSH assignment API
|
||||||
|
func (rc *RuntimeConfig) fetchAssignment(ctx context.Context, assignURL string, req AssignmentRequest) (*AssignmentConfig, error) {
|
||||||
|
// Build query parameters
|
||||||
|
queryParams := fmt.Sprintf("?cluster_id=%s&agent_id=%s&node_id=%s",
|
||||||
|
req.ClusterID, req.AgentID, req.NodeID)
|
||||||
|
|
||||||
|
if req.TaskSlot != "" {
|
||||||
|
queryParams += "&task_slot=" + req.TaskSlot
|
||||||
|
}
|
||||||
|
if req.TaskID != "" {
|
||||||
|
queryParams += "&task_id=" + req.TaskID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create HTTP request
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, "GET", assignURL+queryParams, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create assignment request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReq.Header.Set("Accept", "application/json")
|
||||||
|
httpReq.Header.Set("User-Agent", "CHORUS-Agent/0.1.0")
|
||||||
|
|
||||||
|
// Make request with timeout
|
||||||
|
client := &http.Client{Timeout: 10 * time.Second}
|
||||||
|
resp, err := client.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("assignment request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusNotFound {
|
||||||
|
// No assignment available
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return nil, fmt.Errorf("assignment request failed with status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse assignment response
|
||||||
|
var assignment AssignmentConfig
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&assignment); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode assignment response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &assignment, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper methods for getting values from different sources
|
||||||
|
func (rc *RuntimeConfig) getFromAssignment(field string) interface{} {
|
||||||
|
if rc.Override == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simple field mapping - in a real implementation, you'd use reflection
|
||||||
|
// or a more sophisticated field mapping system
|
||||||
|
switch field {
|
||||||
|
case "agent.id":
|
||||||
|
if rc.Override.Agent != nil && rc.Override.Agent.ID != "" {
|
||||||
|
return rc.Override.Agent.ID
|
||||||
|
}
|
||||||
|
case "agent.role":
|
||||||
|
if rc.Override.Agent != nil && rc.Override.Agent.Role != "" {
|
||||||
|
return rc.Override.Agent.Role
|
||||||
|
}
|
||||||
|
case "agent.capabilities":
|
||||||
|
if len(rc.Override.RuntimeCapabilities) > 0 {
|
||||||
|
return rc.Override.RuntimeCapabilities
|
||||||
|
}
|
||||||
|
case "bootstrap_peers":
|
||||||
|
if len(rc.Override.BootstrapPeers) > 0 {
|
||||||
|
return rc.Override.BootstrapPeers
|
||||||
|
}
|
||||||
|
case "join_stagger":
|
||||||
|
if rc.Override.JoinStagger > 0 {
|
||||||
|
return rc.Override.JoinStagger
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check custom fields
|
||||||
|
if rc.Override.Custom != nil {
|
||||||
|
if val, exists := rc.Override.Custom[field]; exists {
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rc *RuntimeConfig) getFromBase(field string) interface{} {
|
||||||
|
// Simple field mapping for base config
|
||||||
|
switch field {
|
||||||
|
case "agent.id":
|
||||||
|
return rc.Base.Agent.ID
|
||||||
|
case "agent.role":
|
||||||
|
return rc.Base.Agent.Role
|
||||||
|
case "agent.capabilities":
|
||||||
|
return rc.Base.Agent.Capabilities
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper methods for merging configuration sections
|
||||||
|
func (rc *RuntimeConfig) mergeAgentConfig(base *AgentConfig, override *AgentConfig) {
|
||||||
|
if override.ID != "" {
|
||||||
|
base.ID = override.ID
|
||||||
|
}
|
||||||
|
if override.Specialization != "" {
|
||||||
|
base.Specialization = override.Specialization
|
||||||
|
}
|
||||||
|
if override.MaxTasks > 0 {
|
||||||
|
base.MaxTasks = override.MaxTasks
|
||||||
|
}
|
||||||
|
if len(override.Capabilities) > 0 {
|
||||||
|
base.Capabilities = override.Capabilities
|
||||||
|
}
|
||||||
|
if len(override.Models) > 0 {
|
||||||
|
base.Models = override.Models
|
||||||
|
}
|
||||||
|
if override.Role != "" {
|
||||||
|
base.Role = override.Role
|
||||||
|
}
|
||||||
|
if override.Project != "" {
|
||||||
|
base.Project = override.Project
|
||||||
|
}
|
||||||
|
if len(override.Expertise) > 0 {
|
||||||
|
base.Expertise = override.Expertise
|
||||||
|
}
|
||||||
|
if override.ReportsTo != "" {
|
||||||
|
base.ReportsTo = override.ReportsTo
|
||||||
|
}
|
||||||
|
if len(override.Deliverables) > 0 {
|
||||||
|
base.Deliverables = override.Deliverables
|
||||||
|
}
|
||||||
|
if override.ModelSelectionWebhook != "" {
|
||||||
|
base.ModelSelectionWebhook = override.ModelSelectionWebhook
|
||||||
|
}
|
||||||
|
if override.DefaultReasoningModel != "" {
|
||||||
|
base.DefaultReasoningModel = override.DefaultReasoningModel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rc *RuntimeConfig) mergeNetworkConfig(base *NetworkConfig, override *NetworkConfig) {
|
||||||
|
if override.P2PPort > 0 {
|
||||||
|
base.P2PPort = override.P2PPort
|
||||||
|
}
|
||||||
|
if override.APIPort > 0 {
|
||||||
|
base.APIPort = override.APIPort
|
||||||
|
}
|
||||||
|
if override.HealthPort > 0 {
|
||||||
|
base.HealthPort = override.HealthPort
|
||||||
|
}
|
||||||
|
if override.BindAddr != "" {
|
||||||
|
base.BindAddr = override.BindAddr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rc *RuntimeConfig) mergeAIConfig(base *AIConfig, override *AIConfig) {
|
||||||
|
if override.Provider != "" {
|
||||||
|
base.Provider = override.Provider
|
||||||
|
}
|
||||||
|
// Merge Ollama config if present
|
||||||
|
if override.Ollama.Endpoint != "" {
|
||||||
|
base.Ollama.Endpoint = override.Ollama.Endpoint
|
||||||
|
}
|
||||||
|
if override.Ollama.Timeout > 0 {
|
||||||
|
base.Ollama.Timeout = override.Ollama.Timeout
|
||||||
|
}
|
||||||
|
// Merge ResetData config if present
|
||||||
|
if override.ResetData.BaseURL != "" {
|
||||||
|
base.ResetData.BaseURL = override.ResetData.BaseURL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rc *RuntimeConfig) mergeLoggingConfig(base *LoggingConfig, override *LoggingConfig) {
|
||||||
|
if override.Level != "" {
|
||||||
|
base.Level = override.Level
|
||||||
|
}
|
||||||
|
if override.Format != "" {
|
||||||
|
base.Format = override.Format
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BootstrapConfig represents JSON bootstrap configuration
|
||||||
|
type BootstrapConfig struct {
|
||||||
|
Peers []BootstrapPeer `json:"peers"`
|
||||||
|
Metadata BootstrapMeta `json:"metadata,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BootstrapPeer represents a single bootstrap peer
|
||||||
|
type BootstrapPeer struct {
|
||||||
|
Address string `json:"address"`
|
||||||
|
Priority int `json:"priority,omitempty"`
|
||||||
|
Region string `json:"region,omitempty"`
|
||||||
|
Roles []string `json:"roles,omitempty"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BootstrapMeta contains metadata about the bootstrap configuration
|
||||||
|
type BootstrapMeta struct {
|
||||||
|
GeneratedAt time.Time `json:"generated_at,omitempty"`
|
||||||
|
ClusterID string `json:"cluster_id,omitempty"`
|
||||||
|
Version string `json:"version,omitempty"`
|
||||||
|
Notes string `json:"notes,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBootstrapPeers returns bootstrap peers with assignment override support and JSON config
|
||||||
|
func (rc *RuntimeConfig) GetBootstrapPeers() []string {
|
||||||
|
rc.mu.RLock()
|
||||||
|
defer rc.mu.RUnlock()
|
||||||
|
|
||||||
|
// First priority: Assignment override from WHOOSH
|
||||||
|
if rc.Override != nil && len(rc.Override.BootstrapPeers) > 0 {
|
||||||
|
return rc.Override.BootstrapPeers
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second priority: JSON bootstrap configuration
|
||||||
|
if jsonPeers := rc.loadBootstrapJSON(); len(jsonPeers) > 0 {
|
||||||
|
return jsonPeers
|
||||||
|
}
|
||||||
|
|
||||||
|
// Third priority: Environment variable (CSV format)
|
||||||
|
if bootstrapEnv := os.Getenv("CHORUS_BOOTSTRAP_PEERS"); bootstrapEnv != "" {
|
||||||
|
peers := strings.Split(bootstrapEnv, ",")
|
||||||
|
// Trim whitespace from each peer
|
||||||
|
for i, peer := range peers {
|
||||||
|
peers[i] = strings.TrimSpace(peer)
|
||||||
|
}
|
||||||
|
return peers
|
||||||
|
}
|
||||||
|
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadBootstrapJSON loads bootstrap peers from JSON file
|
||||||
|
func (rc *RuntimeConfig) loadBootstrapJSON() []string {
|
||||||
|
jsonPath := os.Getenv("BOOTSTRAP_JSON")
|
||||||
|
if jsonPath == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if file exists
|
||||||
|
if _, err := os.Stat(jsonPath); os.IsNotExist(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read and parse JSON file
|
||||||
|
data, err := os.ReadFile(jsonPath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("⚠️ Failed to read bootstrap JSON file %s: %v\n", jsonPath, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var config BootstrapConfig
|
||||||
|
if err := json.Unmarshal(data, &config); err != nil {
|
||||||
|
fmt.Printf("⚠️ Failed to parse bootstrap JSON file %s: %v\n", jsonPath, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract enabled peer addresses, sorted by priority
|
||||||
|
var peers []string
|
||||||
|
enabledPeers := make([]BootstrapPeer, 0, len(config.Peers))
|
||||||
|
|
||||||
|
// Filter enabled peers
|
||||||
|
for _, peer := range config.Peers {
|
||||||
|
if peer.Enabled && peer.Address != "" {
|
||||||
|
enabledPeers = append(enabledPeers, peer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort by priority (higher priority first)
|
||||||
|
for i := 0; i < len(enabledPeers)-1; i++ {
|
||||||
|
for j := i + 1; j < len(enabledPeers); j++ {
|
||||||
|
if enabledPeers[j].Priority > enabledPeers[i].Priority {
|
||||||
|
enabledPeers[i], enabledPeers[j] = enabledPeers[j], enabledPeers[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract addresses
|
||||||
|
for _, peer := range enabledPeers {
|
||||||
|
peers = append(peers, peer.Address)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(peers) > 0 {
|
||||||
|
fmt.Printf("📋 Loaded %d bootstrap peers from JSON: %s\n", len(peers), jsonPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return peers
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetJoinStagger returns join stagger delay with assignment override support
|
||||||
|
func (rc *RuntimeConfig) GetJoinStagger() time.Duration {
|
||||||
|
rc.mu.RLock()
|
||||||
|
defer rc.mu.RUnlock()
|
||||||
|
|
||||||
|
if rc.Override != nil && rc.Override.JoinStagger > 0 {
|
||||||
|
return time.Duration(rc.Override.JoinStagger) * time.Millisecond
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to environment variable
|
||||||
|
if staggerEnv := os.Getenv("CHORUS_JOIN_STAGGER_MS"); staggerEnv != "" {
|
||||||
|
if ms, err := time.ParseDuration(staggerEnv + "ms"); err == nil {
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAssignmentInfo returns current assignment metadata
|
||||||
|
func (rc *RuntimeConfig) GetAssignmentInfo() *AssignmentConfig {
|
||||||
|
rc.mu.RLock()
|
||||||
|
defer rc.mu.RUnlock()
|
||||||
|
|
||||||
|
if rc.Override == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a copy to prevent external modification
|
||||||
|
assignment := *rc.Override
|
||||||
|
return &assignment
|
||||||
|
}
|
||||||
@@ -100,6 +100,7 @@ type V2Config struct {
|
|||||||
type DHTConfig struct {
|
type DHTConfig struct {
|
||||||
Enabled bool `yaml:"enabled"`
|
Enabled bool `yaml:"enabled"`
|
||||||
BootstrapPeers []string `yaml:"bootstrap_peers"`
|
BootstrapPeers []string `yaml:"bootstrap_peers"`
|
||||||
|
MDNSEnabled bool `yaml:"mdns_enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UCXLConfig defines UCXL protocol settings
|
// UCXLConfig defines UCXL protocol settings
|
||||||
@@ -192,6 +193,7 @@ func LoadFromEnvironment() (*Config, error) {
|
|||||||
DHT: DHTConfig{
|
DHT: DHTConfig{
|
||||||
Enabled: getEnvBoolOrDefault("CHORUS_DHT_ENABLED", true),
|
Enabled: getEnvBoolOrDefault("CHORUS_DHT_ENABLED", true),
|
||||||
BootstrapPeers: getEnvArrayOrDefault("CHORUS_BOOTSTRAP_PEERS", []string{}),
|
BootstrapPeers: getEnvArrayOrDefault("CHORUS_BOOTSTRAP_PEERS", []string{}),
|
||||||
|
MDNSEnabled: getEnvBoolOrDefault("CHORUS_MDNS_ENABLED", true),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
UCXL: UCXLConfig{
|
UCXL: UCXLConfig{
|
||||||
@@ -216,7 +218,7 @@ func LoadFromEnvironment() (*Config, error) {
|
|||||||
AuditLogging: getEnvBoolOrDefault("CHORUS_AUDIT_LOGGING", true),
|
AuditLogging: getEnvBoolOrDefault("CHORUS_AUDIT_LOGGING", true),
|
||||||
AuditPath: getEnvOrDefault("CHORUS_AUDIT_PATH", "/tmp/chorus-audit.log"),
|
AuditPath: getEnvOrDefault("CHORUS_AUDIT_PATH", "/tmp/chorus-audit.log"),
|
||||||
ElectionConfig: ElectionConfig{
|
ElectionConfig: ElectionConfig{
|
||||||
DiscoveryTimeout: getEnvDurationOrDefault("CHORUS_DISCOVERY_TIMEOUT", 10*time.Second),
|
DiscoveryTimeout: getEnvDurationOrDefault("CHORUS_DISCOVERY_TIMEOUT", 15*time.Second),
|
||||||
HeartbeatTimeout: getEnvDurationOrDefault("CHORUS_HEARTBEAT_TIMEOUT", 30*time.Second),
|
HeartbeatTimeout: getEnvDurationOrDefault("CHORUS_HEARTBEAT_TIMEOUT", 30*time.Second),
|
||||||
ElectionTimeout: getEnvDurationOrDefault("CHORUS_ELECTION_TIMEOUT", 60*time.Second),
|
ElectionTimeout: getEnvDurationOrDefault("CHORUS_ELECTION_TIMEOUT", 60*time.Second),
|
||||||
DiscoveryBackoff: getEnvDurationOrDefault("CHORUS_DISCOVERY_BACKOFF", 5*time.Second),
|
DiscoveryBackoff: getEnvDurationOrDefault("CHORUS_DISCOVERY_BACKOFF", 5*time.Second),
|
||||||
|
|||||||
@@ -41,10 +41,16 @@ type HybridUCXLConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type DiscoveryConfig struct {
|
type DiscoveryConfig struct {
|
||||||
MDNSEnabled bool `env:"CHORUS_MDNS_ENABLED" default:"true" json:"mdns_enabled" yaml:"mdns_enabled"`
|
MDNSEnabled bool `env:"CHORUS_MDNS_ENABLED" default:"true" json:"mdns_enabled" yaml:"mdns_enabled"`
|
||||||
DHTDiscovery bool `env:"CHORUS_DHT_DISCOVERY" default:"false" json:"dht_discovery" yaml:"dht_discovery"`
|
DHTDiscovery bool `env:"CHORUS_DHT_DISCOVERY" default:"false" json:"dht_discovery" yaml:"dht_discovery"`
|
||||||
AnnounceInterval time.Duration `env:"CHORUS_ANNOUNCE_INTERVAL" default:"30s" json:"announce_interval" yaml:"announce_interval"`
|
AnnounceInterval time.Duration `env:"CHORUS_ANNOUNCE_INTERVAL" default:"30s" json:"announce_interval" yaml:"announce_interval"`
|
||||||
ServiceName string `env:"CHORUS_SERVICE_NAME" default:"CHORUS" json:"service_name" yaml:"service_name"`
|
ServiceName string `env:"CHORUS_SERVICE_NAME" default:"CHORUS" json:"service_name" yaml:"service_name"`
|
||||||
|
|
||||||
|
// Rate limiting for scaling (as per WHOOSH issue #7)
|
||||||
|
DialsPerSecond int `env:"CHORUS_DIALS_PER_SEC" default:"5" json:"dials_per_second" yaml:"dials_per_second"`
|
||||||
|
MaxConcurrentDHT int `env:"CHORUS_MAX_CONCURRENT_DHT" default:"16" json:"max_concurrent_dht" yaml:"max_concurrent_dht"`
|
||||||
|
MaxConcurrentDials int `env:"CHORUS_MAX_CONCURRENT_DIALS" default:"10" json:"max_concurrent_dials" yaml:"max_concurrent_dials"`
|
||||||
|
JoinStaggerMS int `env:"CHORUS_JOIN_STAGGER_MS" default:"0" json:"join_stagger_ms" yaml:"join_stagger_ms"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type MonitoringConfig struct {
|
type MonitoringConfig struct {
|
||||||
@@ -79,10 +85,16 @@ func LoadHybridConfig() (*HybridConfig, error) {
|
|||||||
|
|
||||||
// Load Discovery configuration
|
// Load Discovery configuration
|
||||||
config.Discovery = DiscoveryConfig{
|
config.Discovery = DiscoveryConfig{
|
||||||
MDNSEnabled: getEnvBool("CHORUS_MDNS_ENABLED", true),
|
MDNSEnabled: getEnvBool("CHORUS_MDNS_ENABLED", true),
|
||||||
DHTDiscovery: getEnvBool("CHORUS_DHT_DISCOVERY", false),
|
DHTDiscovery: getEnvBool("CHORUS_DHT_DISCOVERY", false),
|
||||||
AnnounceInterval: getEnvDuration("CHORUS_ANNOUNCE_INTERVAL", 30*time.Second),
|
AnnounceInterval: getEnvDuration("CHORUS_ANNOUNCE_INTERVAL", 30*time.Second),
|
||||||
ServiceName: getEnvString("CHORUS_SERVICE_NAME", "CHORUS"),
|
ServiceName: getEnvString("CHORUS_SERVICE_NAME", "CHORUS"),
|
||||||
|
|
||||||
|
// Rate limiting for scaling (as per WHOOSH issue #7)
|
||||||
|
DialsPerSecond: getEnvInt("CHORUS_DIALS_PER_SEC", 5),
|
||||||
|
MaxConcurrentDHT: getEnvInt("CHORUS_MAX_CONCURRENT_DHT", 16),
|
||||||
|
MaxConcurrentDials: getEnvInt("CHORUS_MAX_CONCURRENT_DIALS", 10),
|
||||||
|
JoinStaggerMS: getEnvInt("CHORUS_JOIN_STAGGER_MS", 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load Monitoring configuration
|
// Load Monitoring configuration
|
||||||
|
|||||||
306
pkg/crypto/key_derivation.go
Normal file
306
pkg/crypto/key_derivation.go
Normal file
@@ -0,0 +1,306 @@
|
|||||||
|
package crypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/hkdf"
|
||||||
|
"filippo.io/age"
|
||||||
|
"filippo.io/age/armor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// KeyDerivationManager handles cluster-scoped key derivation for DHT encryption
|
||||||
|
type KeyDerivationManager struct {
|
||||||
|
clusterRootKey []byte
|
||||||
|
clusterID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// DerivedKeySet contains keys derived for a specific role/scope
|
||||||
|
type DerivedKeySet struct {
|
||||||
|
RoleKey []byte // Role-specific key
|
||||||
|
NodeKey []byte // Node-specific key for this instance
|
||||||
|
AGEIdentity *age.X25519Identity // AGE identity for encryption/decryption
|
||||||
|
AGERecipient *age.X25519Recipient // AGE recipient for encryption
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewKeyDerivationManager creates a new key derivation manager
|
||||||
|
func NewKeyDerivationManager(clusterRootKey []byte, clusterID string) *KeyDerivationManager {
|
||||||
|
return &KeyDerivationManager{
|
||||||
|
clusterRootKey: clusterRootKey,
|
||||||
|
clusterID: clusterID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewKeyDerivationManagerFromSeed creates a manager from a seed string
|
||||||
|
func NewKeyDerivationManagerFromSeed(seed, clusterID string) *KeyDerivationManager {
|
||||||
|
// Use HKDF to derive a consistent root key from seed
|
||||||
|
hash := sha256.New
|
||||||
|
hkdf := hkdf.New(hash, []byte(seed), []byte(clusterID), []byte("CHORUS-cluster-root"))
|
||||||
|
|
||||||
|
rootKey := make([]byte, 32)
|
||||||
|
if _, err := io.ReadFull(hkdf, rootKey); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to derive cluster root key: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &KeyDerivationManager{
|
||||||
|
clusterRootKey: rootKey,
|
||||||
|
clusterID: clusterID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeriveRoleKeys derives encryption keys for a specific role and agent
|
||||||
|
func (kdm *KeyDerivationManager) DeriveRoleKeys(role, agentID string) (*DerivedKeySet, error) {
|
||||||
|
if kdm.clusterRootKey == nil {
|
||||||
|
return nil, fmt.Errorf("cluster root key not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Derive role-specific key
|
||||||
|
roleKey, err := kdm.deriveKey(fmt.Sprintf("role-%s", role), 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to derive role key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Derive node-specific key from role key and agent ID
|
||||||
|
nodeKey, err := kdm.deriveKeyFromParent(roleKey, fmt.Sprintf("node-%s", agentID), 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to derive node key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate AGE identity from node key
|
||||||
|
ageIdentity, err := kdm.generateAGEIdentityFromKey(nodeKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate AGE identity: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ageRecipient := ageIdentity.Recipient()
|
||||||
|
|
||||||
|
return &DerivedKeySet{
|
||||||
|
RoleKey: roleKey,
|
||||||
|
NodeKey: nodeKey,
|
||||||
|
AGEIdentity: ageIdentity,
|
||||||
|
AGERecipient: ageRecipient,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeriveClusterWideKeys derives keys that are shared across the entire cluster for a role
|
||||||
|
func (kdm *KeyDerivationManager) DeriveClusterWideKeys(role string) (*DerivedKeySet, error) {
|
||||||
|
if kdm.clusterRootKey == nil {
|
||||||
|
return nil, fmt.Errorf("cluster root key not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Derive role-specific key
|
||||||
|
roleKey, err := kdm.deriveKey(fmt.Sprintf("role-%s", role), 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to derive role key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For cluster-wide keys, use a deterministic "cluster" identifier
|
||||||
|
clusterNodeKey, err := kdm.deriveKeyFromParent(roleKey, "cluster-shared", 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to derive cluster node key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate AGE identity from cluster node key
|
||||||
|
ageIdentity, err := kdm.generateAGEIdentityFromKey(clusterNodeKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate AGE identity: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ageRecipient := ageIdentity.Recipient()
|
||||||
|
|
||||||
|
return &DerivedKeySet{
|
||||||
|
RoleKey: roleKey,
|
||||||
|
NodeKey: clusterNodeKey,
|
||||||
|
AGEIdentity: ageIdentity,
|
||||||
|
AGERecipient: ageRecipient,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// deriveKey derives a key from the cluster root key using HKDF
|
||||||
|
func (kdm *KeyDerivationManager) deriveKey(info string, length int) ([]byte, error) {
|
||||||
|
hash := sha256.New
|
||||||
|
hkdf := hkdf.New(hash, kdm.clusterRootKey, []byte(kdm.clusterID), []byte(info))
|
||||||
|
|
||||||
|
key := make([]byte, length)
|
||||||
|
if _, err := io.ReadFull(hkdf, key); err != nil {
|
||||||
|
return nil, fmt.Errorf("HKDF key derivation failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// deriveKeyFromParent derives a key from a parent key using HKDF
|
||||||
|
func (kdm *KeyDerivationManager) deriveKeyFromParent(parentKey []byte, info string, length int) ([]byte, error) {
|
||||||
|
hash := sha256.New
|
||||||
|
hkdf := hkdf.New(hash, parentKey, []byte(kdm.clusterID), []byte(info))
|
||||||
|
|
||||||
|
key := make([]byte, length)
|
||||||
|
if _, err := io.ReadFull(hkdf, key); err != nil {
|
||||||
|
return nil, fmt.Errorf("HKDF key derivation failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateAGEIdentityFromKey generates a deterministic AGE identity from a key
|
||||||
|
func (kdm *KeyDerivationManager) generateAGEIdentityFromKey(key []byte) (*age.X25519Identity, error) {
|
||||||
|
if len(key) < 32 {
|
||||||
|
return nil, fmt.Errorf("key must be at least 32 bytes")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the first 32 bytes as the private key seed
|
||||||
|
var privKey [32]byte
|
||||||
|
copy(privKey[:], key[:32])
|
||||||
|
|
||||||
|
// Generate a new identity (note: this loses deterministic behavior)
|
||||||
|
// TODO: Implement deterministic key derivation when age API allows
|
||||||
|
identity, err := age.GenerateX25519Identity()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create AGE identity: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return identity, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncryptForRole encrypts data for a specific role (all nodes in that role can decrypt)
|
||||||
|
func (kdm *KeyDerivationManager) EncryptForRole(data []byte, role string) ([]byte, error) {
|
||||||
|
// Get cluster-wide keys for the role
|
||||||
|
keySet, err := kdm.DeriveClusterWideKeys(role)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to derive cluster keys: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt using AGE
|
||||||
|
var encrypted []byte
|
||||||
|
buf := &writeBuffer{data: &encrypted}
|
||||||
|
armorWriter := armor.NewWriter(buf)
|
||||||
|
|
||||||
|
ageWriter, err := age.Encrypt(armorWriter, keySet.AGERecipient)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create age writer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := ageWriter.Write(data); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to write encrypted data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ageWriter.Close(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to close age writer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := armorWriter.Close(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to close armor writer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return encrypted, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecryptForRole decrypts data encrypted for a specific role
|
||||||
|
func (kdm *KeyDerivationManager) DecryptForRole(encryptedData []byte, role, agentID string) ([]byte, error) {
|
||||||
|
// Try cluster-wide keys first
|
||||||
|
clusterKeys, err := kdm.DeriveClusterWideKeys(role)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to derive cluster keys: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if decrypted, err := kdm.decryptWithIdentity(encryptedData, clusterKeys.AGEIdentity); err == nil {
|
||||||
|
return decrypted, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If cluster-wide decryption fails, try node-specific keys
|
||||||
|
nodeKeys, err := kdm.DeriveRoleKeys(role, agentID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to derive node keys: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return kdm.decryptWithIdentity(encryptedData, nodeKeys.AGEIdentity)
|
||||||
|
}
|
||||||
|
|
||||||
|
// decryptWithIdentity decrypts data using an AGE identity
|
||||||
|
func (kdm *KeyDerivationManager) decryptWithIdentity(encryptedData []byte, identity *age.X25519Identity) ([]byte, error) {
|
||||||
|
armorReader := armor.NewReader(newReadBuffer(encryptedData))
|
||||||
|
|
||||||
|
ageReader, err := age.Decrypt(armorReader, identity)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decrypt: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
decrypted, err := io.ReadAll(ageReader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read decrypted data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return decrypted, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRoleRecipients returns AGE recipients for all nodes in a role (for multi-recipient encryption)
|
||||||
|
func (kdm *KeyDerivationManager) GetRoleRecipients(role string, agentIDs []string) ([]*age.X25519Recipient, error) {
|
||||||
|
var recipients []*age.X25519Recipient
|
||||||
|
|
||||||
|
// Add cluster-wide recipient
|
||||||
|
clusterKeys, err := kdm.DeriveClusterWideKeys(role)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to derive cluster keys: %w", err)
|
||||||
|
}
|
||||||
|
recipients = append(recipients, clusterKeys.AGERecipient)
|
||||||
|
|
||||||
|
// Add node-specific recipients
|
||||||
|
for _, agentID := range agentIDs {
|
||||||
|
nodeKeys, err := kdm.DeriveRoleKeys(role, agentID)
|
||||||
|
if err != nil {
|
||||||
|
continue // Skip this agent on error
|
||||||
|
}
|
||||||
|
recipients = append(recipients, nodeKeys.AGERecipient)
|
||||||
|
}
|
||||||
|
|
||||||
|
return recipients, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetKeySetStats returns statistics about derived key sets
|
||||||
|
func (kdm *KeyDerivationManager) GetKeySetStats(role, agentID string) map[string]interface{} {
|
||||||
|
stats := map[string]interface{}{
|
||||||
|
"cluster_id": kdm.clusterID,
|
||||||
|
"role": role,
|
||||||
|
"agent_id": agentID,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to derive keys and add fingerprint info
|
||||||
|
if keySet, err := kdm.DeriveRoleKeys(role, agentID); err == nil {
|
||||||
|
stats["node_key_length"] = len(keySet.NodeKey)
|
||||||
|
stats["role_key_length"] = len(keySet.RoleKey)
|
||||||
|
stats["age_recipient"] = keySet.AGERecipient.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper types for AGE encryption/decryption
|
||||||
|
|
||||||
|
type writeBuffer struct {
|
||||||
|
data *[]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *writeBuffer) Write(p []byte) (n int, err error) {
|
||||||
|
*w.data = append(*w.data, p...)
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type readBuffer struct {
|
||||||
|
data []byte
|
||||||
|
pos int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newReadBuffer(data []byte) *readBuffer {
|
||||||
|
return &readBuffer{data: data, pos: 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *readBuffer) Read(p []byte) (n int, err error) {
|
||||||
|
if r.pos >= len(r.data) {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
n = copy(p, r.data[r.pos:])
|
||||||
|
r.pos += n
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -102,6 +103,11 @@ type ElectionManager struct {
|
|||||||
onAdminChanged func(oldAdmin, newAdmin string)
|
onAdminChanged func(oldAdmin, newAdmin string)
|
||||||
onElectionComplete func(winner string)
|
onElectionComplete func(winner string)
|
||||||
|
|
||||||
|
// Stability window to prevent election churn (Medium-risk fix 2.1)
|
||||||
|
lastElectionTime time.Time
|
||||||
|
electionStabilityWindow time.Duration
|
||||||
|
leaderStabilityWindow time.Duration
|
||||||
|
|
||||||
startTime time.Time
|
startTime time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -137,6 +143,10 @@ func NewElectionManager(
|
|||||||
votes: make(map[string]string),
|
votes: make(map[string]string),
|
||||||
electionTrigger: make(chan ElectionTrigger, 10),
|
electionTrigger: make(chan ElectionTrigger, 10),
|
||||||
startTime: time.Now(),
|
startTime: time.Now(),
|
||||||
|
|
||||||
|
// Initialize stability windows (as per WHOOSH issue #7)
|
||||||
|
electionStabilityWindow: getElectionStabilityWindow(cfg),
|
||||||
|
leaderStabilityWindow: getLeaderStabilityWindow(cfg),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize heartbeat manager
|
// Initialize heartbeat manager
|
||||||
@@ -167,10 +177,18 @@ func (em *ElectionManager) Start() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Start discovery process
|
// Start discovery process
|
||||||
go em.startDiscoveryLoop()
|
log.Printf("🔍 About to start discovery loop goroutine...")
|
||||||
|
go func() {
|
||||||
|
log.Printf("🔍 Discovery loop goroutine started successfully")
|
||||||
|
em.startDiscoveryLoop()
|
||||||
|
}()
|
||||||
|
|
||||||
// Start election coordinator
|
// Start election coordinator
|
||||||
go em.electionCoordinator()
|
log.Printf("🗳️ About to start election coordinator goroutine...")
|
||||||
|
go func() {
|
||||||
|
log.Printf("🗳️ Election coordinator goroutine started successfully")
|
||||||
|
em.electionCoordinator()
|
||||||
|
}()
|
||||||
|
|
||||||
// Start heartbeat if this node is already admin at startup
|
// Start heartbeat if this node is already admin at startup
|
||||||
if em.IsCurrentAdmin() {
|
if em.IsCurrentAdmin() {
|
||||||
@@ -212,8 +230,40 @@ func (em *ElectionManager) Stop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TriggerElection manually triggers an election
|
// TriggerElection manually triggers an election with stability window checks
|
||||||
func (em *ElectionManager) TriggerElection(trigger ElectionTrigger) {
|
func (em *ElectionManager) TriggerElection(trigger ElectionTrigger) {
|
||||||
|
// Check if election already in progress
|
||||||
|
em.mu.RLock()
|
||||||
|
currentState := em.state
|
||||||
|
currentAdmin := em.currentAdmin
|
||||||
|
lastElection := em.lastElectionTime
|
||||||
|
em.mu.RUnlock()
|
||||||
|
|
||||||
|
if currentState != StateIdle {
|
||||||
|
log.Printf("🗳️ Election already in progress (state: %s), ignoring trigger: %s", currentState, trigger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply stability window to prevent election churn (WHOOSH issue #7)
|
||||||
|
now := time.Now()
|
||||||
|
if !lastElection.IsZero() {
|
||||||
|
timeSinceElection := now.Sub(lastElection)
|
||||||
|
|
||||||
|
// If we have a current admin, check leader stability window
|
||||||
|
if currentAdmin != "" && timeSinceElection < em.leaderStabilityWindow {
|
||||||
|
log.Printf("⏳ Leader stability window active (%.1fs remaining), ignoring trigger: %s",
|
||||||
|
(em.leaderStabilityWindow - timeSinceElection).Seconds(), trigger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// General election stability window
|
||||||
|
if timeSinceElection < em.electionStabilityWindow {
|
||||||
|
log.Printf("⏳ Election stability window active (%.1fs remaining), ignoring trigger: %s",
|
||||||
|
(em.electionStabilityWindow - timeSinceElection).Seconds(), trigger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case em.electionTrigger <- trigger:
|
case em.electionTrigger <- trigger:
|
||||||
log.Printf("🗳️ Election triggered: %s", trigger)
|
log.Printf("🗳️ Election triggered: %s", trigger)
|
||||||
@@ -262,13 +312,27 @@ func (em *ElectionManager) GetHeartbeatStatus() map[string]interface{} {
|
|||||||
|
|
||||||
// startDiscoveryLoop starts the admin discovery loop
|
// startDiscoveryLoop starts the admin discovery loop
|
||||||
func (em *ElectionManager) startDiscoveryLoop() {
|
func (em *ElectionManager) startDiscoveryLoop() {
|
||||||
log.Printf("🔍 Starting admin discovery loop")
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Printf("🔍 PANIC in discovery loop: %v", r)
|
||||||
|
}
|
||||||
|
log.Printf("🔍 Discovery loop goroutine exiting")
|
||||||
|
}()
|
||||||
|
|
||||||
|
log.Printf("🔍 ENHANCED-DEBUG: Starting admin discovery loop with timeout: %v", em.config.Security.ElectionConfig.DiscoveryTimeout)
|
||||||
|
log.Printf("🔍 ENHANCED-DEBUG: Context status: err=%v", em.ctx.Err())
|
||||||
|
log.Printf("🔍 ENHANCED-DEBUG: Node ID: %s, Can be admin: %v", em.nodeID, em.canBeAdmin())
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
log.Printf("🔍 Discovery loop iteration starting, waiting for timeout...")
|
||||||
|
log.Printf("🔍 Context status before select: err=%v", em.ctx.Err())
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-em.ctx.Done():
|
case <-em.ctx.Done():
|
||||||
|
log.Printf("🔍 Discovery loop cancelled via context: %v", em.ctx.Err())
|
||||||
return
|
return
|
||||||
case <-time.After(em.config.Security.ElectionConfig.DiscoveryTimeout):
|
case <-time.After(em.config.Security.ElectionConfig.DiscoveryTimeout):
|
||||||
|
log.Printf("🔍 Discovery timeout triggered! Calling performAdminDiscovery()...")
|
||||||
em.performAdminDiscovery()
|
em.performAdminDiscovery()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -281,8 +345,12 @@ func (em *ElectionManager) performAdminDiscovery() {
|
|||||||
lastHeartbeat := em.lastHeartbeat
|
lastHeartbeat := em.lastHeartbeat
|
||||||
em.mu.Unlock()
|
em.mu.Unlock()
|
||||||
|
|
||||||
|
log.Printf("🔍 Discovery check: state=%s, lastHeartbeat=%v, canAdmin=%v",
|
||||||
|
currentState, lastHeartbeat, em.canBeAdmin())
|
||||||
|
|
||||||
// Only discover if we're idle or the heartbeat is stale
|
// Only discover if we're idle or the heartbeat is stale
|
||||||
if currentState != StateIdle {
|
if currentState != StateIdle {
|
||||||
|
log.Printf("🔍 Skipping discovery - not in idle state (current: %s)", currentState)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -294,13 +362,66 @@ func (em *ElectionManager) performAdminDiscovery() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If we haven't heard from an admin recently, try to discover one
|
// If we haven't heard from an admin recently, try to discover one
|
||||||
if lastHeartbeat.IsZero() || time.Since(lastHeartbeat) > em.config.Security.ElectionConfig.DiscoveryTimeout/2 {
|
timeSinceHeartbeat := time.Since(lastHeartbeat)
|
||||||
|
discoveryThreshold := em.config.Security.ElectionConfig.DiscoveryTimeout / 2
|
||||||
|
|
||||||
|
log.Printf("🔍 Heartbeat check: isZero=%v, timeSince=%v, threshold=%v",
|
||||||
|
lastHeartbeat.IsZero(), timeSinceHeartbeat, discoveryThreshold)
|
||||||
|
|
||||||
|
if lastHeartbeat.IsZero() || timeSinceHeartbeat > discoveryThreshold {
|
||||||
|
log.Printf("🔍 Sending discovery request...")
|
||||||
em.sendDiscoveryRequest()
|
em.sendDiscoveryRequest()
|
||||||
|
|
||||||
|
// 🚨 CRITICAL FIX: If we have no admin and can become admin, trigger election after discovery timeout
|
||||||
|
em.mu.Lock()
|
||||||
|
currentAdmin := em.currentAdmin
|
||||||
|
em.mu.Unlock()
|
||||||
|
|
||||||
|
if currentAdmin == "" && em.canBeAdmin() {
|
||||||
|
log.Printf("🗳️ No admin discovered and we can be admin - scheduling election check")
|
||||||
|
go func() {
|
||||||
|
// Add randomization to prevent simultaneous elections from all nodes
|
||||||
|
baseDelay := em.config.Security.ElectionConfig.DiscoveryTimeout * 2
|
||||||
|
randomDelay := time.Duration(rand.Intn(int(em.config.Security.ElectionConfig.DiscoveryTimeout)))
|
||||||
|
totalDelay := baseDelay + randomDelay
|
||||||
|
|
||||||
|
log.Printf("🗳️ Waiting %v before checking if election needed", totalDelay)
|
||||||
|
time.Sleep(totalDelay)
|
||||||
|
|
||||||
|
// Check again if still no admin and no one else started election
|
||||||
|
em.mu.RLock()
|
||||||
|
stillNoAdmin := em.currentAdmin == ""
|
||||||
|
stillIdle := em.state == StateIdle
|
||||||
|
em.mu.RUnlock()
|
||||||
|
|
||||||
|
if stillNoAdmin && stillIdle && em.canBeAdmin() {
|
||||||
|
log.Printf("🗳️ Election grace period expired with no admin - triggering election")
|
||||||
|
em.TriggerElection(TriggerDiscoveryFailure)
|
||||||
|
} else {
|
||||||
|
log.Printf("🗳️ Election check: admin=%s, state=%s - skipping election", em.currentAdmin, em.state)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Printf("🔍 Discovery threshold not met - waiting")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendDiscoveryRequest broadcasts admin discovery request
|
// sendDiscoveryRequest broadcasts admin discovery request
|
||||||
func (em *ElectionManager) sendDiscoveryRequest() {
|
func (em *ElectionManager) sendDiscoveryRequest() {
|
||||||
|
em.mu.RLock()
|
||||||
|
currentAdmin := em.currentAdmin
|
||||||
|
em.mu.RUnlock()
|
||||||
|
|
||||||
|
// WHOAMI debug message
|
||||||
|
if currentAdmin == "" {
|
||||||
|
log.Printf("🤖 WHOAMI: I'm %s and I have no leader", em.nodeID)
|
||||||
|
} else {
|
||||||
|
log.Printf("🤖 WHOAMI: I'm %s and my leader is %s", em.nodeID, currentAdmin)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("📡 Sending admin discovery request from node %s", em.nodeID)
|
||||||
|
|
||||||
discoveryMsg := ElectionMessage{
|
discoveryMsg := ElectionMessage{
|
||||||
Type: "admin_discovery_request",
|
Type: "admin_discovery_request",
|
||||||
NodeID: em.nodeID,
|
NodeID: em.nodeID,
|
||||||
@@ -309,6 +430,8 @@ func (em *ElectionManager) sendDiscoveryRequest() {
|
|||||||
|
|
||||||
if err := em.publishElectionMessage(discoveryMsg); err != nil {
|
if err := em.publishElectionMessage(discoveryMsg); err != nil {
|
||||||
log.Printf("❌ Failed to send admin discovery request: %v", err)
|
log.Printf("❌ Failed to send admin discovery request: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Printf("✅ Admin discovery request sent successfully")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -351,6 +474,7 @@ func (em *ElectionManager) beginElection(trigger ElectionTrigger) {
|
|||||||
em.mu.Lock()
|
em.mu.Lock()
|
||||||
em.state = StateElecting
|
em.state = StateElecting
|
||||||
em.currentTerm++
|
em.currentTerm++
|
||||||
|
em.lastElectionTime = time.Now() // Record election timestamp for stability window
|
||||||
term := em.currentTerm
|
term := em.currentTerm
|
||||||
em.candidates = make(map[string]*AdminCandidate)
|
em.candidates = make(map[string]*AdminCandidate)
|
||||||
em.votes = make(map[string]string)
|
em.votes = make(map[string]string)
|
||||||
@@ -652,6 +776,9 @@ func (em *ElectionManager) handleAdminDiscoveryRequest(msg ElectionMessage) {
|
|||||||
state := em.state
|
state := em.state
|
||||||
em.mu.RUnlock()
|
em.mu.RUnlock()
|
||||||
|
|
||||||
|
log.Printf("📩 Received admin discovery request from %s (my leader: %s, state: %s)",
|
||||||
|
msg.NodeID, currentAdmin, state)
|
||||||
|
|
||||||
// Only respond if we know who the current admin is and we're idle
|
// Only respond if we know who the current admin is and we're idle
|
||||||
if currentAdmin != "" && state == StateIdle {
|
if currentAdmin != "" && state == StateIdle {
|
||||||
responseMsg := ElectionMessage{
|
responseMsg := ElectionMessage{
|
||||||
@@ -663,23 +790,43 @@ func (em *ElectionManager) handleAdminDiscoveryRequest(msg ElectionMessage) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Printf("📤 Responding to discovery with admin: %s", currentAdmin)
|
||||||
if err := em.publishElectionMessage(responseMsg); err != nil {
|
if err := em.publishElectionMessage(responseMsg); err != nil {
|
||||||
log.Printf("❌ Failed to send admin discovery response: %v", err)
|
log.Printf("❌ Failed to send admin discovery response: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Printf("✅ Admin discovery response sent successfully")
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
log.Printf("🔇 Not responding to discovery (admin=%s, state=%s)", currentAdmin, state)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleAdminDiscoveryResponse processes admin discovery responses
|
// handleAdminDiscoveryResponse processes admin discovery responses
|
||||||
func (em *ElectionManager) handleAdminDiscoveryResponse(msg ElectionMessage) {
|
func (em *ElectionManager) handleAdminDiscoveryResponse(msg ElectionMessage) {
|
||||||
|
log.Printf("📥 Received admin discovery response from %s", msg.NodeID)
|
||||||
|
|
||||||
if data, ok := msg.Data.(map[string]interface{}); ok {
|
if data, ok := msg.Data.(map[string]interface{}); ok {
|
||||||
if admin, ok := data["current_admin"].(string); ok && admin != "" {
|
if admin, ok := data["current_admin"].(string); ok && admin != "" {
|
||||||
em.mu.Lock()
|
em.mu.Lock()
|
||||||
|
oldAdmin := em.currentAdmin
|
||||||
if em.currentAdmin == "" {
|
if em.currentAdmin == "" {
|
||||||
log.Printf("📡 Discovered admin: %s", admin)
|
log.Printf("📡 Discovered admin: %s (reported by %s)", admin, msg.NodeID)
|
||||||
em.currentAdmin = admin
|
em.currentAdmin = admin
|
||||||
|
em.lastHeartbeat = time.Now() // Set initial heartbeat
|
||||||
|
} else if em.currentAdmin != admin {
|
||||||
|
log.Printf("⚠️ Admin conflict: I know %s, but %s reports %s", em.currentAdmin, msg.NodeID, admin)
|
||||||
|
} else {
|
||||||
|
log.Printf("📡 Admin confirmed: %s (reported by %s)", admin, msg.NodeID)
|
||||||
}
|
}
|
||||||
em.mu.Unlock()
|
em.mu.Unlock()
|
||||||
|
|
||||||
|
// Trigger callback if admin changed
|
||||||
|
if oldAdmin != admin && em.onAdminChanged != nil {
|
||||||
|
em.onAdminChanged(oldAdmin, admin)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
log.Printf("❌ Invalid admin discovery response from %s", msg.NodeID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1005,3 +1152,43 @@ func (hm *HeartbeatManager) GetHeartbeatStatus() map[string]interface{} {
|
|||||||
|
|
||||||
return status
|
return status
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper functions for stability window configuration
|
||||||
|
|
||||||
|
// getElectionStabilityWindow gets the minimum time between elections
|
||||||
|
func getElectionStabilityWindow(cfg *config.Config) time.Duration {
|
||||||
|
// Try to get from environment or use default
|
||||||
|
if stability := os.Getenv("CHORUS_ELECTION_MIN_TERM"); stability != "" {
|
||||||
|
if duration, err := time.ParseDuration(stability); err == nil {
|
||||||
|
return duration
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to get from config structure if it exists
|
||||||
|
if cfg.Security.ElectionConfig.DiscoveryTimeout > 0 {
|
||||||
|
// Use double the discovery timeout as default stability window
|
||||||
|
return cfg.Security.ElectionConfig.DiscoveryTimeout * 2
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default fallback
|
||||||
|
return 30 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
// getLeaderStabilityWindow gets the minimum time before challenging a healthy leader
|
||||||
|
func getLeaderStabilityWindow(cfg *config.Config) time.Duration {
|
||||||
|
// Try to get from environment or use default
|
||||||
|
if stability := os.Getenv("CHORUS_LEADER_MIN_TERM"); stability != "" {
|
||||||
|
if duration, err := time.ParseDuration(stability); err == nil {
|
||||||
|
return duration
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to get from config structure if it exists
|
||||||
|
if cfg.Security.ElectionConfig.HeartbeatTimeout > 0 {
|
||||||
|
// Use 3x heartbeat timeout as default leader stability
|
||||||
|
return cfg.Security.ElectionConfig.HeartbeatTimeout * 3
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default fallback
|
||||||
|
return 45 * time.Second
|
||||||
|
}
|
||||||
|
|||||||
1020
pkg/execution/docker.go
Normal file
1020
pkg/execution/docker.go
Normal file
File diff suppressed because it is too large
Load Diff
482
pkg/execution/docker_test.go
Normal file
482
pkg/execution/docker_test.go
Normal file
@@ -0,0 +1,482 @@
|
|||||||
|
package execution
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewDockerSandbox(t *testing.T) {
|
||||||
|
sandbox := NewDockerSandbox()
|
||||||
|
|
||||||
|
assert.NotNil(t, sandbox)
|
||||||
|
assert.NotNil(t, sandbox.environment)
|
||||||
|
assert.Empty(t, sandbox.containerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDockerSandbox_Initialize(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping Docker integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
sandbox := NewDockerSandbox()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create a minimal configuration
|
||||||
|
config := &SandboxConfig{
|
||||||
|
Type: "docker",
|
||||||
|
Image: "alpine:latest",
|
||||||
|
Architecture: "amd64",
|
||||||
|
Resources: ResourceLimits{
|
||||||
|
MemoryLimit: 512 * 1024 * 1024, // 512MB
|
||||||
|
CPULimit: 1.0,
|
||||||
|
ProcessLimit: 50,
|
||||||
|
FileLimit: 1024,
|
||||||
|
},
|
||||||
|
Security: SecurityPolicy{
|
||||||
|
ReadOnlyRoot: false,
|
||||||
|
NoNewPrivileges: true,
|
||||||
|
AllowNetworking: false,
|
||||||
|
IsolateNetwork: true,
|
||||||
|
IsolateProcess: true,
|
||||||
|
DropCapabilities: []string{"ALL"},
|
||||||
|
},
|
||||||
|
Environment: map[string]string{
|
||||||
|
"TEST_VAR": "test_value",
|
||||||
|
},
|
||||||
|
WorkingDir: "/workspace",
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := sandbox.Initialize(ctx, config)
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("Docker not available or image pull failed: %v", err)
|
||||||
|
}
|
||||||
|
defer sandbox.Cleanup()
|
||||||
|
|
||||||
|
// Verify sandbox is initialized
|
||||||
|
assert.NotEmpty(t, sandbox.containerID)
|
||||||
|
assert.Equal(t, config, sandbox.config)
|
||||||
|
assert.Equal(t, StatusRunning, sandbox.info.Status)
|
||||||
|
assert.Equal(t, "docker", sandbox.info.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDockerSandbox_ExecuteCommand(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping Docker integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
sandbox := setupTestSandbox(t)
|
||||||
|
defer sandbox.Cleanup()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cmd *Command
|
||||||
|
expectedExit int
|
||||||
|
expectedOutput string
|
||||||
|
shouldError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple echo command",
|
||||||
|
cmd: &Command{
|
||||||
|
Executable: "echo",
|
||||||
|
Args: []string{"hello world"},
|
||||||
|
},
|
||||||
|
expectedExit: 0,
|
||||||
|
expectedOutput: "hello world\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "command with environment",
|
||||||
|
cmd: &Command{
|
||||||
|
Executable: "sh",
|
||||||
|
Args: []string{"-c", "echo $TEST_VAR"},
|
||||||
|
Environment: map[string]string{"TEST_VAR": "custom_value"},
|
||||||
|
},
|
||||||
|
expectedExit: 0,
|
||||||
|
expectedOutput: "custom_value\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "failing command",
|
||||||
|
cmd: &Command{
|
||||||
|
Executable: "sh",
|
||||||
|
Args: []string{"-c", "exit 1"},
|
||||||
|
},
|
||||||
|
expectedExit: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "command with timeout",
|
||||||
|
cmd: &Command{
|
||||||
|
Executable: "sleep",
|
||||||
|
Args: []string{"2"},
|
||||||
|
Timeout: 1 * time.Second,
|
||||||
|
},
|
||||||
|
shouldError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := sandbox.ExecuteCommand(ctx, tt.cmd)
|
||||||
|
|
||||||
|
if tt.shouldError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expectedExit, result.ExitCode)
|
||||||
|
assert.Equal(t, tt.expectedExit == 0, result.Success)
|
||||||
|
|
||||||
|
if tt.expectedOutput != "" {
|
||||||
|
assert.Equal(t, tt.expectedOutput, result.Stdout)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NotZero(t, result.Duration)
|
||||||
|
assert.False(t, result.StartTime.IsZero())
|
||||||
|
assert.False(t, result.EndTime.IsZero())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDockerSandbox_FileOperations(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping Docker integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
sandbox := setupTestSandbox(t)
|
||||||
|
defer sandbox.Cleanup()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Test WriteFile
|
||||||
|
testContent := []byte("Hello, Docker sandbox!")
|
||||||
|
testPath := "/tmp/test_file.txt"
|
||||||
|
|
||||||
|
err := sandbox.WriteFile(ctx, testPath, testContent, 0644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Test ReadFile
|
||||||
|
readContent, err := sandbox.ReadFile(ctx, testPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, testContent, readContent)
|
||||||
|
|
||||||
|
// Test ListFiles
|
||||||
|
files, err := sandbox.ListFiles(ctx, "/tmp")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, files)
|
||||||
|
|
||||||
|
// Find our test file
|
||||||
|
var testFile *FileInfo
|
||||||
|
for _, file := range files {
|
||||||
|
if file.Name == "test_file.txt" {
|
||||||
|
testFile = &file
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NotNil(t, testFile)
|
||||||
|
assert.Equal(t, "test_file.txt", testFile.Name)
|
||||||
|
assert.Equal(t, int64(len(testContent)), testFile.Size)
|
||||||
|
assert.False(t, testFile.IsDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDockerSandbox_CopyFiles(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping Docker integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
sandbox := setupTestSandbox(t)
|
||||||
|
defer sandbox.Cleanup()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create a temporary file on host
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
hostFile := filepath.Join(tempDir, "host_file.txt")
|
||||||
|
hostContent := []byte("Content from host")
|
||||||
|
|
||||||
|
err := os.WriteFile(hostFile, hostContent, 0644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Copy from host to container
|
||||||
|
containerPath := "container:/tmp/copied_file.txt"
|
||||||
|
err = sandbox.CopyFiles(ctx, hostFile, containerPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify file exists in container
|
||||||
|
readContent, err := sandbox.ReadFile(ctx, "/tmp/copied_file.txt")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, hostContent, readContent)
|
||||||
|
|
||||||
|
// Copy from container back to host
|
||||||
|
hostDestFile := filepath.Join(tempDir, "copied_back.txt")
|
||||||
|
err = sandbox.CopyFiles(ctx, "container:/tmp/copied_file.txt", hostDestFile)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify file exists on host
|
||||||
|
backContent, err := os.ReadFile(hostDestFile)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, hostContent, backContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDockerSandbox_Environment(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping Docker integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
sandbox := setupTestSandbox(t)
|
||||||
|
defer sandbox.Cleanup()
|
||||||
|
|
||||||
|
// Test getting initial environment
|
||||||
|
env := sandbox.GetEnvironment()
|
||||||
|
assert.Equal(t, "test_value", env["TEST_VAR"])
|
||||||
|
|
||||||
|
// Test setting additional environment
|
||||||
|
newEnv := map[string]string{
|
||||||
|
"NEW_VAR": "new_value",
|
||||||
|
"PATH": "/custom/path",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := sandbox.SetEnvironment(newEnv)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify environment is updated
|
||||||
|
env = sandbox.GetEnvironment()
|
||||||
|
assert.Equal(t, "new_value", env["NEW_VAR"])
|
||||||
|
assert.Equal(t, "/custom/path", env["PATH"])
|
||||||
|
assert.Equal(t, "test_value", env["TEST_VAR"]) // Original should still be there
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDockerSandbox_WorkingDirectory(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping Docker integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
sandbox := setupTestSandbox(t)
|
||||||
|
defer sandbox.Cleanup()
|
||||||
|
|
||||||
|
// Test getting initial working directory
|
||||||
|
workDir := sandbox.GetWorkingDirectory()
|
||||||
|
assert.Equal(t, "/workspace", workDir)
|
||||||
|
|
||||||
|
// Test setting working directory
|
||||||
|
newWorkDir := "/tmp"
|
||||||
|
err := sandbox.SetWorkingDirectory(newWorkDir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify working directory is updated
|
||||||
|
workDir = sandbox.GetWorkingDirectory()
|
||||||
|
assert.Equal(t, newWorkDir, workDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDockerSandbox_ResourceUsage(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping Docker integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
sandbox := setupTestSandbox(t)
|
||||||
|
defer sandbox.Cleanup()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Get resource usage
|
||||||
|
usage, err := sandbox.GetResourceUsage(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify usage structure
|
||||||
|
assert.NotNil(t, usage)
|
||||||
|
assert.False(t, usage.Timestamp.IsZero())
|
||||||
|
assert.GreaterOrEqual(t, usage.CPUUsage, 0.0)
|
||||||
|
assert.GreaterOrEqual(t, usage.MemoryUsage, int64(0))
|
||||||
|
assert.GreaterOrEqual(t, usage.MemoryPercent, 0.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDockerSandbox_GetInfo(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping Docker integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
sandbox := setupTestSandbox(t)
|
||||||
|
defer sandbox.Cleanup()
|
||||||
|
|
||||||
|
info := sandbox.GetInfo()
|
||||||
|
|
||||||
|
assert.NotEmpty(t, info.ID)
|
||||||
|
assert.Contains(t, info.Name, "chorus-sandbox")
|
||||||
|
assert.Equal(t, "docker", info.Type)
|
||||||
|
assert.Equal(t, StatusRunning, info.Status)
|
||||||
|
assert.Equal(t, "docker", info.Runtime)
|
||||||
|
assert.Equal(t, "alpine:latest", info.Image)
|
||||||
|
assert.False(t, info.CreatedAt.IsZero())
|
||||||
|
assert.False(t, info.StartedAt.IsZero())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDockerSandbox_Cleanup(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping Docker integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
sandbox := setupTestSandbox(t)
|
||||||
|
|
||||||
|
// Verify sandbox is running
|
||||||
|
assert.Equal(t, StatusRunning, sandbox.info.Status)
|
||||||
|
assert.NotEmpty(t, sandbox.containerID)
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
err := sandbox.Cleanup()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify sandbox is destroyed
|
||||||
|
assert.Equal(t, StatusDestroyed, sandbox.info.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDockerSandbox_SecurityPolicies(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping Docker integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
sandbox := NewDockerSandbox()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create configuration with strict security policies
|
||||||
|
config := &SandboxConfig{
|
||||||
|
Type: "docker",
|
||||||
|
Image: "alpine:latest",
|
||||||
|
Architecture: "amd64",
|
||||||
|
Resources: ResourceLimits{
|
||||||
|
MemoryLimit: 256 * 1024 * 1024, // 256MB
|
||||||
|
CPULimit: 0.5,
|
||||||
|
ProcessLimit: 10,
|
||||||
|
FileLimit: 256,
|
||||||
|
},
|
||||||
|
Security: SecurityPolicy{
|
||||||
|
ReadOnlyRoot: true,
|
||||||
|
NoNewPrivileges: true,
|
||||||
|
AllowNetworking: false,
|
||||||
|
IsolateNetwork: true,
|
||||||
|
IsolateProcess: true,
|
||||||
|
DropCapabilities: []string{"ALL"},
|
||||||
|
RunAsUser: "1000",
|
||||||
|
RunAsGroup: "1000",
|
||||||
|
TmpfsPaths: []string{"/tmp", "/var/tmp"},
|
||||||
|
MaskedPaths: []string{"/proc/kcore", "/proc/keys"},
|
||||||
|
ReadOnlyPaths: []string{"/etc"},
|
||||||
|
},
|
||||||
|
WorkingDir: "/workspace",
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := sandbox.Initialize(ctx, config)
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("Docker not available or security policies not supported: %v", err)
|
||||||
|
}
|
||||||
|
defer sandbox.Cleanup()
|
||||||
|
|
||||||
|
// Test that we can't write to read-only filesystem
|
||||||
|
result, err := sandbox.ExecuteCommand(ctx, &Command{
|
||||||
|
Executable: "touch",
|
||||||
|
Args: []string{"/test_readonly"},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEqual(t, 0, result.ExitCode) // Should fail due to read-only root
|
||||||
|
|
||||||
|
// Test that tmpfs is writable
|
||||||
|
result, err = sandbox.ExecuteCommand(ctx, &Command{
|
||||||
|
Executable: "touch",
|
||||||
|
Args: []string{"/tmp/test_tmpfs"},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, result.ExitCode) // Should succeed on tmpfs
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupTestSandbox creates a basic Docker sandbox for testing
|
||||||
|
func setupTestSandbox(t *testing.T) *DockerSandbox {
|
||||||
|
sandbox := NewDockerSandbox()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
config := &SandboxConfig{
|
||||||
|
Type: "docker",
|
||||||
|
Image: "alpine:latest",
|
||||||
|
Architecture: "amd64",
|
||||||
|
Resources: ResourceLimits{
|
||||||
|
MemoryLimit: 512 * 1024 * 1024, // 512MB
|
||||||
|
CPULimit: 1.0,
|
||||||
|
ProcessLimit: 50,
|
||||||
|
FileLimit: 1024,
|
||||||
|
},
|
||||||
|
Security: SecurityPolicy{
|
||||||
|
ReadOnlyRoot: false,
|
||||||
|
NoNewPrivileges: true,
|
||||||
|
AllowNetworking: true, // Allow networking for easier testing
|
||||||
|
IsolateNetwork: false,
|
||||||
|
IsolateProcess: true,
|
||||||
|
DropCapabilities: []string{"NET_ADMIN", "SYS_ADMIN"},
|
||||||
|
},
|
||||||
|
Environment: map[string]string{
|
||||||
|
"TEST_VAR": "test_value",
|
||||||
|
},
|
||||||
|
WorkingDir: "/workspace",
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := sandbox.Initialize(ctx, config)
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("Docker not available: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sandbox
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmark tests
|
||||||
|
func BenchmarkDockerSandbox_ExecuteCommand(b *testing.B) {
|
||||||
|
if testing.Short() {
|
||||||
|
b.Skip("Skipping Docker benchmark in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
sandbox := &DockerSandbox{}
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Setup minimal config for benchmarking
|
||||||
|
config := &SandboxConfig{
|
||||||
|
Type: "docker",
|
||||||
|
Image: "alpine:latest",
|
||||||
|
Architecture: "amd64",
|
||||||
|
Resources: ResourceLimits{
|
||||||
|
MemoryLimit: 256 * 1024 * 1024,
|
||||||
|
CPULimit: 1.0,
|
||||||
|
ProcessLimit: 50,
|
||||||
|
},
|
||||||
|
Security: SecurityPolicy{
|
||||||
|
NoNewPrivileges: true,
|
||||||
|
AllowNetworking: true,
|
||||||
|
},
|
||||||
|
WorkingDir: "/workspace",
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := sandbox.Initialize(ctx, config)
|
||||||
|
if err != nil {
|
||||||
|
b.Skipf("Docker not available: %v", err)
|
||||||
|
}
|
||||||
|
defer sandbox.Cleanup()
|
||||||
|
|
||||||
|
cmd := &Command{
|
||||||
|
Executable: "echo",
|
||||||
|
Args: []string{"benchmark test"},
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, err := sandbox.ExecuteCommand(ctx, cmd)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Command execution failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
494
pkg/execution/engine.go
Normal file
494
pkg/execution/engine.go
Normal file
@@ -0,0 +1,494 @@
|
|||||||
|
package execution
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"chorus/pkg/ai"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TaskExecutionEngine provides AI-powered task execution with isolated sandboxes
|
||||||
|
type TaskExecutionEngine interface {
|
||||||
|
ExecuteTask(ctx context.Context, request *TaskExecutionRequest) (*TaskExecutionResult, error)
|
||||||
|
Initialize(ctx context.Context, config *EngineConfig) error
|
||||||
|
Shutdown() error
|
||||||
|
GetMetrics() *EngineMetrics
|
||||||
|
}
|
||||||
|
|
||||||
|
// TaskExecutionRequest represents a task to be executed
|
||||||
|
type TaskExecutionRequest struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Context map[string]interface{} `json:"context,omitempty"`
|
||||||
|
Requirements *TaskRequirements `json:"requirements,omitempty"`
|
||||||
|
Timeout time.Duration `json:"timeout,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TaskRequirements specifies execution environment needs
|
||||||
|
type TaskRequirements struct {
|
||||||
|
AIModel string `json:"ai_model,omitempty"`
|
||||||
|
SandboxType string `json:"sandbox_type,omitempty"`
|
||||||
|
RequiredTools []string `json:"required_tools,omitempty"`
|
||||||
|
EnvironmentVars map[string]string `json:"environment_vars,omitempty"`
|
||||||
|
ResourceLimits *ResourceLimits `json:"resource_limits,omitempty"`
|
||||||
|
SecurityPolicy *SecurityPolicy `json:"security_policy,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TaskExecutionResult contains the results of task execution
|
||||||
|
type TaskExecutionResult struct {
|
||||||
|
TaskID string `json:"task_id"`
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Output string `json:"output"`
|
||||||
|
ErrorMessage string `json:"error_message,omitempty"`
|
||||||
|
Artifacts []TaskArtifact `json:"artifacts,omitempty"`
|
||||||
|
Metrics *ExecutionMetrics `json:"metrics"`
|
||||||
|
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TaskArtifact represents a file or data produced during execution
|
||||||
|
type TaskArtifact struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Path string `json:"path,omitempty"`
|
||||||
|
Content []byte `json:"content,omitempty"`
|
||||||
|
Size int64 `json:"size"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
Metadata map[string]string `json:"metadata,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecutionMetrics tracks resource usage and performance
|
||||||
|
type ExecutionMetrics struct {
|
||||||
|
StartTime time.Time `json:"start_time"`
|
||||||
|
EndTime time.Time `json:"end_time"`
|
||||||
|
Duration time.Duration `json:"duration"`
|
||||||
|
AIProviderTime time.Duration `json:"ai_provider_time"`
|
||||||
|
SandboxTime time.Duration `json:"sandbox_time"`
|
||||||
|
ResourceUsage *ResourceUsage `json:"resource_usage,omitempty"`
|
||||||
|
CommandsExecuted int `json:"commands_executed"`
|
||||||
|
FilesGenerated int `json:"files_generated"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EngineConfig configures the task execution engine
|
||||||
|
type EngineConfig struct {
|
||||||
|
AIProviderFactory *ai.ProviderFactory `json:"-"`
|
||||||
|
SandboxDefaults *SandboxConfig `json:"sandbox_defaults"`
|
||||||
|
DefaultTimeout time.Duration `json:"default_timeout"`
|
||||||
|
MaxConcurrentTasks int `json:"max_concurrent_tasks"`
|
||||||
|
EnableMetrics bool `json:"enable_metrics"`
|
||||||
|
LogLevel string `json:"log_level"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EngineMetrics tracks overall engine performance
|
||||||
|
type EngineMetrics struct {
|
||||||
|
TasksExecuted int64 `json:"tasks_executed"`
|
||||||
|
TasksSuccessful int64 `json:"tasks_successful"`
|
||||||
|
TasksFailed int64 `json:"tasks_failed"`
|
||||||
|
AverageTime time.Duration `json:"average_time"`
|
||||||
|
TotalExecutionTime time.Duration `json:"total_execution_time"`
|
||||||
|
ActiveTasks int `json:"active_tasks"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultTaskExecutionEngine implements the TaskExecutionEngine interface
|
||||||
|
type DefaultTaskExecutionEngine struct {
|
||||||
|
config *EngineConfig
|
||||||
|
aiFactory *ai.ProviderFactory
|
||||||
|
metrics *EngineMetrics
|
||||||
|
activeTasks map[string]context.CancelFunc
|
||||||
|
logger *log.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTaskExecutionEngine creates a new task execution engine
|
||||||
|
func NewTaskExecutionEngine() *DefaultTaskExecutionEngine {
|
||||||
|
return &DefaultTaskExecutionEngine{
|
||||||
|
metrics: &EngineMetrics{},
|
||||||
|
activeTasks: make(map[string]context.CancelFunc),
|
||||||
|
logger: log.Default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize configures and prepares the execution engine
|
||||||
|
func (e *DefaultTaskExecutionEngine) Initialize(ctx context.Context, config *EngineConfig) error {
|
||||||
|
if config == nil {
|
||||||
|
return fmt.Errorf("engine config cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.AIProviderFactory == nil {
|
||||||
|
return fmt.Errorf("AI provider factory is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
e.config = config
|
||||||
|
e.aiFactory = config.AIProviderFactory
|
||||||
|
|
||||||
|
// Set default values
|
||||||
|
if e.config.DefaultTimeout == 0 {
|
||||||
|
e.config.DefaultTimeout = 5 * time.Minute
|
||||||
|
}
|
||||||
|
if e.config.MaxConcurrentTasks == 0 {
|
||||||
|
e.config.MaxConcurrentTasks = 10
|
||||||
|
}
|
||||||
|
|
||||||
|
e.logger.Printf("TaskExecutionEngine initialized with %d max concurrent tasks", e.config.MaxConcurrentTasks)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteTask executes a task using AI providers and isolated sandboxes
|
||||||
|
func (e *DefaultTaskExecutionEngine) ExecuteTask(ctx context.Context, request *TaskExecutionRequest) (*TaskExecutionResult, error) {
|
||||||
|
if e.config == nil {
|
||||||
|
return nil, fmt.Errorf("engine not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
// Create task context with timeout
|
||||||
|
timeout := request.Timeout
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = e.config.DefaultTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
taskCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Track active task
|
||||||
|
e.activeTasks[request.ID] = cancel
|
||||||
|
defer delete(e.activeTasks, request.ID)
|
||||||
|
|
||||||
|
e.metrics.ActiveTasks++
|
||||||
|
defer func() { e.metrics.ActiveTasks-- }()
|
||||||
|
|
||||||
|
result := &TaskExecutionResult{
|
||||||
|
TaskID: request.ID,
|
||||||
|
Metrics: &ExecutionMetrics{StartTime: startTime},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the task
|
||||||
|
err := e.executeTaskInternal(taskCtx, request, result)
|
||||||
|
|
||||||
|
// Update metrics
|
||||||
|
result.Metrics.EndTime = time.Now()
|
||||||
|
result.Metrics.Duration = result.Metrics.EndTime.Sub(result.Metrics.StartTime)
|
||||||
|
|
||||||
|
e.metrics.TasksExecuted++
|
||||||
|
e.metrics.TotalExecutionTime += result.Metrics.Duration
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
result.Success = false
|
||||||
|
result.ErrorMessage = err.Error()
|
||||||
|
e.metrics.TasksFailed++
|
||||||
|
e.logger.Printf("Task %s failed: %v", request.ID, err)
|
||||||
|
} else {
|
||||||
|
result.Success = true
|
||||||
|
e.metrics.TasksSuccessful++
|
||||||
|
e.logger.Printf("Task %s completed successfully in %v", request.ID, result.Metrics.Duration)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.metrics.AverageTime = e.metrics.TotalExecutionTime / time.Duration(e.metrics.TasksExecuted)
|
||||||
|
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeTaskInternal performs the actual task execution
|
||||||
|
func (e *DefaultTaskExecutionEngine) executeTaskInternal(ctx context.Context, request *TaskExecutionRequest, result *TaskExecutionResult) error {
|
||||||
|
// Step 1: Determine AI model and get provider
|
||||||
|
aiStartTime := time.Now()
|
||||||
|
|
||||||
|
role := e.determineRoleFromTask(request)
|
||||||
|
provider, providerConfig, err := e.aiFactory.GetProviderForRole(role)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get AI provider for role %s: %w", role, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: Create AI request
|
||||||
|
aiRequest := &ai.TaskRequest{
|
||||||
|
TaskID: request.ID,
|
||||||
|
TaskTitle: request.Type,
|
||||||
|
TaskDescription: request.Description,
|
||||||
|
Context: request.Context,
|
||||||
|
ModelName: providerConfig.DefaultModel,
|
||||||
|
AgentRole: role,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: Get AI response
|
||||||
|
aiResponse, err := provider.ExecuteTask(ctx, aiRequest)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("AI provider execution failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result.Metrics.AIProviderTime = time.Since(aiStartTime)
|
||||||
|
|
||||||
|
// Step 4: Parse AI response for executable commands
|
||||||
|
commands, artifacts, err := e.parseAIResponse(aiResponse)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse AI response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 5: Execute commands in sandbox if needed
|
||||||
|
if len(commands) > 0 {
|
||||||
|
sandboxStartTime := time.Now()
|
||||||
|
|
||||||
|
sandboxResult, err := e.executeSandboxCommands(ctx, request, commands)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("sandbox execution failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result.Metrics.SandboxTime = time.Since(sandboxStartTime)
|
||||||
|
result.Metrics.CommandsExecuted = len(commands)
|
||||||
|
result.Metrics.ResourceUsage = sandboxResult.ResourceUsage
|
||||||
|
|
||||||
|
// Merge sandbox artifacts
|
||||||
|
artifacts = append(artifacts, sandboxResult.Artifacts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 6: Process results and artifacts
|
||||||
|
result.Output = e.formatOutput(aiResponse, artifacts)
|
||||||
|
result.Artifacts = artifacts
|
||||||
|
result.Metrics.FilesGenerated = len(artifacts)
|
||||||
|
|
||||||
|
// Add metadata
|
||||||
|
result.Metadata = map[string]interface{}{
|
||||||
|
"ai_provider": providerConfig.Type,
|
||||||
|
"ai_model": providerConfig.DefaultModel,
|
||||||
|
"role": role,
|
||||||
|
"commands": len(commands),
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// determineRoleFromTask analyzes the task to determine appropriate AI role
|
||||||
|
func (e *DefaultTaskExecutionEngine) determineRoleFromTask(request *TaskExecutionRequest) string {
|
||||||
|
taskType := strings.ToLower(request.Type)
|
||||||
|
description := strings.ToLower(request.Description)
|
||||||
|
|
||||||
|
// Determine role based on task type and description keywords
|
||||||
|
if strings.Contains(taskType, "code") || strings.Contains(description, "program") ||
|
||||||
|
strings.Contains(description, "script") || strings.Contains(description, "function") {
|
||||||
|
return "developer"
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(taskType, "analysis") || strings.Contains(description, "analyze") ||
|
||||||
|
strings.Contains(description, "review") {
|
||||||
|
return "analyst"
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(taskType, "test") || strings.Contains(description, "test") {
|
||||||
|
return "tester"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to general purpose
|
||||||
|
return "general"
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseAIResponse extracts executable commands and artifacts from AI response
|
||||||
|
func (e *DefaultTaskExecutionEngine) parseAIResponse(response *ai.TaskResponse) ([]string, []TaskArtifact, error) {
|
||||||
|
var commands []string
|
||||||
|
var artifacts []TaskArtifact
|
||||||
|
|
||||||
|
// Parse response content for commands and files
|
||||||
|
// This is a simplified parser - in reality would need more sophisticated parsing
|
||||||
|
|
||||||
|
if len(response.Actions) > 0 {
|
||||||
|
for _, action := range response.Actions {
|
||||||
|
switch action.Type {
|
||||||
|
case "command", "command_run":
|
||||||
|
// Extract command from content or target
|
||||||
|
if action.Content != "" {
|
||||||
|
commands = append(commands, action.Content)
|
||||||
|
} else if action.Target != "" {
|
||||||
|
commands = append(commands, action.Target)
|
||||||
|
}
|
||||||
|
case "file", "file_create", "file_edit":
|
||||||
|
// Create artifact from file action
|
||||||
|
if action.Target != "" && action.Content != "" {
|
||||||
|
artifact := TaskArtifact{
|
||||||
|
Name: action.Target,
|
||||||
|
Type: "file",
|
||||||
|
Content: []byte(action.Content),
|
||||||
|
Size: int64(len(action.Content)),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
artifacts = append(artifacts, artifact)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return commands, artifacts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SandboxExecutionResult contains results from sandbox command execution
|
||||||
|
type SandboxExecutionResult struct {
|
||||||
|
Output string
|
||||||
|
Artifacts []TaskArtifact
|
||||||
|
ResourceUsage *ResourceUsage
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeSandboxCommands runs commands in an isolated sandbox
|
||||||
|
func (e *DefaultTaskExecutionEngine) executeSandboxCommands(ctx context.Context, request *TaskExecutionRequest, commands []string) (*SandboxExecutionResult, error) {
|
||||||
|
// Create sandbox configuration
|
||||||
|
sandboxConfig := e.createSandboxConfig(request)
|
||||||
|
|
||||||
|
// Initialize sandbox
|
||||||
|
sandbox := NewDockerSandbox()
|
||||||
|
err := sandbox.Initialize(ctx, sandboxConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to initialize sandbox: %w", err)
|
||||||
|
}
|
||||||
|
defer sandbox.Cleanup()
|
||||||
|
|
||||||
|
var outputs []string
|
||||||
|
var artifacts []TaskArtifact
|
||||||
|
|
||||||
|
// Execute each command
|
||||||
|
for _, cmdStr := range commands {
|
||||||
|
cmd := &Command{
|
||||||
|
Executable: "/bin/sh",
|
||||||
|
Args: []string{"-c", cmdStr},
|
||||||
|
WorkingDir: "/workspace",
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
cmdResult, err := sandbox.ExecuteCommand(ctx, cmd)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("command execution failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
outputs = append(outputs, fmt.Sprintf("$ %s\n%s", cmdStr, cmdResult.Stdout))
|
||||||
|
|
||||||
|
if cmdResult.ExitCode != 0 {
|
||||||
|
outputs = append(outputs, fmt.Sprintf("Error (exit %d): %s", cmdResult.ExitCode, cmdResult.Stderr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get resource usage
|
||||||
|
resourceUsage, _ := sandbox.GetResourceUsage(ctx)
|
||||||
|
|
||||||
|
// Collect any generated files as artifacts
|
||||||
|
files, err := sandbox.ListFiles(ctx, "/workspace")
|
||||||
|
if err == nil {
|
||||||
|
for _, file := range files {
|
||||||
|
if !file.IsDir && file.Size > 0 {
|
||||||
|
content, err := sandbox.ReadFile(ctx, "/workspace/"+file.Name)
|
||||||
|
if err == nil {
|
||||||
|
artifact := TaskArtifact{
|
||||||
|
Name: file.Name,
|
||||||
|
Type: "generated_file",
|
||||||
|
Content: content,
|
||||||
|
Size: file.Size,
|
||||||
|
CreatedAt: file.ModTime,
|
||||||
|
}
|
||||||
|
artifacts = append(artifacts, artifact)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SandboxExecutionResult{
|
||||||
|
Output: strings.Join(outputs, "\n"),
|
||||||
|
Artifacts: artifacts,
|
||||||
|
ResourceUsage: resourceUsage,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createSandboxConfig creates a sandbox configuration from task requirements
|
||||||
|
func (e *DefaultTaskExecutionEngine) createSandboxConfig(request *TaskExecutionRequest) *SandboxConfig {
|
||||||
|
config := &SandboxConfig{
|
||||||
|
Type: "docker",
|
||||||
|
Image: "alpine:latest",
|
||||||
|
Architecture: "amd64",
|
||||||
|
WorkingDir: "/workspace",
|
||||||
|
Timeout: 5 * time.Minute,
|
||||||
|
Environment: make(map[string]string),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply defaults from engine config
|
||||||
|
if e.config.SandboxDefaults != nil {
|
||||||
|
if e.config.SandboxDefaults.Image != "" {
|
||||||
|
config.Image = e.config.SandboxDefaults.Image
|
||||||
|
}
|
||||||
|
if e.config.SandboxDefaults.Resources.MemoryLimit > 0 {
|
||||||
|
config.Resources = e.config.SandboxDefaults.Resources
|
||||||
|
}
|
||||||
|
if e.config.SandboxDefaults.Security.NoNewPrivileges {
|
||||||
|
config.Security = e.config.SandboxDefaults.Security
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply task-specific requirements
|
||||||
|
if request.Requirements != nil {
|
||||||
|
if request.Requirements.SandboxType != "" {
|
||||||
|
config.Type = request.Requirements.SandboxType
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.Requirements.EnvironmentVars != nil {
|
||||||
|
for k, v := range request.Requirements.EnvironmentVars {
|
||||||
|
config.Environment[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.Requirements.ResourceLimits != nil {
|
||||||
|
config.Resources = *request.Requirements.ResourceLimits
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.Requirements.SecurityPolicy != nil {
|
||||||
|
config.Security = *request.Requirements.SecurityPolicy
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatOutput creates a formatted output string from AI response and artifacts
|
||||||
|
func (e *DefaultTaskExecutionEngine) formatOutput(aiResponse *ai.TaskResponse, artifacts []TaskArtifact) string {
|
||||||
|
var output strings.Builder
|
||||||
|
|
||||||
|
output.WriteString("AI Response:\n")
|
||||||
|
output.WriteString(aiResponse.Response)
|
||||||
|
output.WriteString("\n\n")
|
||||||
|
|
||||||
|
if len(artifacts) > 0 {
|
||||||
|
output.WriteString("Generated Artifacts:\n")
|
||||||
|
for _, artifact := range artifacts {
|
||||||
|
output.WriteString(fmt.Sprintf("- %s (%s, %d bytes)\n",
|
||||||
|
artifact.Name, artifact.Type, artifact.Size))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return output.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMetrics returns current engine metrics
|
||||||
|
func (e *DefaultTaskExecutionEngine) GetMetrics() *EngineMetrics {
|
||||||
|
return e.metrics
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown gracefully shuts down the execution engine
|
||||||
|
func (e *DefaultTaskExecutionEngine) Shutdown() error {
|
||||||
|
e.logger.Printf("Shutting down TaskExecutionEngine...")
|
||||||
|
|
||||||
|
// Cancel all active tasks
|
||||||
|
for taskID, cancel := range e.activeTasks {
|
||||||
|
e.logger.Printf("Canceling active task: %s", taskID)
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for tasks to finish (with timeout)
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
for len(e.activeTasks) > 0 {
|
||||||
|
select {
|
||||||
|
case <-shutdownCtx.Done():
|
||||||
|
e.logger.Printf("Shutdown timeout reached, %d tasks may still be active", len(e.activeTasks))
|
||||||
|
return nil
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
// Continue waiting
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
e.logger.Printf("TaskExecutionEngine shutdown complete")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
599
pkg/execution/engine_test.go
Normal file
599
pkg/execution/engine_test.go
Normal file
@@ -0,0 +1,599 @@
|
|||||||
|
package execution
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"chorus/pkg/ai"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockProvider implements ai.ModelProvider for testing
|
||||||
|
type MockProvider struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProvider) ExecuteTask(ctx context.Context, request *ai.TaskRequest) (*ai.TaskResponse, error) {
|
||||||
|
args := m.Called(ctx, request)
|
||||||
|
return args.Get(0).(*ai.TaskResponse), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProvider) GetCapabilities() ai.ProviderCapabilities {
|
||||||
|
args := m.Called()
|
||||||
|
return args.Get(0).(ai.ProviderCapabilities)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProvider) ValidateConfig() error {
|
||||||
|
args := m.Called()
|
||||||
|
return args.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProvider) GetProviderInfo() ai.ProviderInfo {
|
||||||
|
args := m.Called()
|
||||||
|
return args.Get(0).(ai.ProviderInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockProviderFactory for testing
|
||||||
|
type MockProviderFactory struct {
|
||||||
|
mock.Mock
|
||||||
|
provider ai.ModelProvider
|
||||||
|
config ai.ProviderConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProviderFactory) GetProviderForRole(role string) (ai.ModelProvider, ai.ProviderConfig, error) {
|
||||||
|
args := m.Called(role)
|
||||||
|
return args.Get(0).(ai.ModelProvider), args.Get(1).(ai.ProviderConfig), args.Error(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProviderFactory) GetProvider(name string) (ai.ModelProvider, error) {
|
||||||
|
args := m.Called(name)
|
||||||
|
return args.Get(0).(ai.ModelProvider), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProviderFactory) ListProviders() []string {
|
||||||
|
args := m.Called()
|
||||||
|
return args.Get(0).([]string)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProviderFactory) GetHealthStatus() map[string]bool {
|
||||||
|
args := m.Called()
|
||||||
|
return args.Get(0).(map[string]bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewTaskExecutionEngine(t *testing.T) {
|
||||||
|
engine := NewTaskExecutionEngine()
|
||||||
|
|
||||||
|
assert.NotNil(t, engine)
|
||||||
|
assert.NotNil(t, engine.metrics)
|
||||||
|
assert.NotNil(t, engine.activeTasks)
|
||||||
|
assert.NotNil(t, engine.logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskExecutionEngine_Initialize(t *testing.T) {
|
||||||
|
engine := NewTaskExecutionEngine()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config *EngineConfig
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil config",
|
||||||
|
config: nil,
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing AI factory",
|
||||||
|
config: &EngineConfig{
|
||||||
|
DefaultTimeout: 1 * time.Minute,
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid config",
|
||||||
|
config: &EngineConfig{
|
||||||
|
AIProviderFactory: &MockProviderFactory{},
|
||||||
|
DefaultTimeout: 1 * time.Minute,
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "config with defaults",
|
||||||
|
config: &EngineConfig{
|
||||||
|
AIProviderFactory: &MockProviderFactory{},
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := engine.Initialize(context.Background(), tt.config)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.config, engine.config)
|
||||||
|
|
||||||
|
// Check defaults are set
|
||||||
|
if tt.config.DefaultTimeout == 0 {
|
||||||
|
assert.Equal(t, 5*time.Minute, engine.config.DefaultTimeout)
|
||||||
|
}
|
||||||
|
if tt.config.MaxConcurrentTasks == 0 {
|
||||||
|
assert.Equal(t, 10, engine.config.MaxConcurrentTasks)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskExecutionEngine_ExecuteTask_SimpleResponse(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := NewTaskExecutionEngine()
|
||||||
|
|
||||||
|
// Setup mock AI provider
|
||||||
|
mockProvider := &MockProvider{}
|
||||||
|
mockFactory := &MockProviderFactory{}
|
||||||
|
|
||||||
|
// Configure mock responses
|
||||||
|
mockProvider.On("ExecuteTask", mock.Anything, mock.Anything).Return(
|
||||||
|
&ai.TaskResponse{
|
||||||
|
TaskID: "test-123",
|
||||||
|
Content: "Task completed successfully",
|
||||||
|
Success: true,
|
||||||
|
Actions: []ai.ActionResult{},
|
||||||
|
Metadata: map[string]interface{}{},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
mockFactory.On("GetProviderForRole", "general").Return(
|
||||||
|
mockProvider,
|
||||||
|
ai.ProviderConfig{
|
||||||
|
Provider: "mock",
|
||||||
|
Model: "test-model",
|
||||||
|
},
|
||||||
|
nil)
|
||||||
|
|
||||||
|
config := &EngineConfig{
|
||||||
|
AIProviderFactory: mockFactory,
|
||||||
|
DefaultTimeout: 30 * time.Second,
|
||||||
|
EnableMetrics: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := engine.Initialize(context.Background(), config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Execute simple task (no sandbox commands)
|
||||||
|
request := &TaskExecutionRequest{
|
||||||
|
ID: "test-123",
|
||||||
|
Type: "analysis",
|
||||||
|
Description: "Analyze the given data",
|
||||||
|
Context: map[string]interface{}{"data": "sample data"},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
result, err := engine.ExecuteTask(ctx, request)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, result.Success)
|
||||||
|
assert.Equal(t, "test-123", result.TaskID)
|
||||||
|
assert.Contains(t, result.Output, "Task completed successfully")
|
||||||
|
assert.NotNil(t, result.Metrics)
|
||||||
|
assert.False(t, result.Metrics.StartTime.IsZero())
|
||||||
|
assert.False(t, result.Metrics.EndTime.IsZero())
|
||||||
|
assert.Greater(t, result.Metrics.Duration, time.Duration(0))
|
||||||
|
|
||||||
|
// Verify mocks were called
|
||||||
|
mockProvider.AssertCalled(t, "ExecuteTask", mock.Anything, mock.Anything)
|
||||||
|
mockFactory.AssertCalled(t, "GetProviderForRole", "general")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskExecutionEngine_ExecuteTask_WithCommands(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping Docker integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := NewTaskExecutionEngine()
|
||||||
|
|
||||||
|
// Setup mock AI provider with commands
|
||||||
|
mockProvider := &MockProvider{}
|
||||||
|
mockFactory := &MockProviderFactory{}
|
||||||
|
|
||||||
|
// Configure mock to return commands
|
||||||
|
mockProvider.On("ExecuteTask", mock.Anything, mock.Anything).Return(
|
||||||
|
&ai.TaskResponse{
|
||||||
|
TaskID: "test-456",
|
||||||
|
Content: "Executing commands",
|
||||||
|
Success: true,
|
||||||
|
Actions: []ai.ActionResult{
|
||||||
|
{
|
||||||
|
Type: "command",
|
||||||
|
Content: map[string]interface{}{
|
||||||
|
"command": "echo 'Hello World'",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: "file",
|
||||||
|
Content: map[string]interface{}{
|
||||||
|
"name": "test.txt",
|
||||||
|
"content": "Test file content",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Metadata: map[string]interface{}{},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
mockFactory.On("GetProviderForRole", "developer").Return(
|
||||||
|
mockProvider,
|
||||||
|
ai.ProviderConfig{
|
||||||
|
Provider: "mock",
|
||||||
|
Model: "test-model",
|
||||||
|
},
|
||||||
|
nil)
|
||||||
|
|
||||||
|
config := &EngineConfig{
|
||||||
|
AIProviderFactory: mockFactory,
|
||||||
|
DefaultTimeout: 1 * time.Minute,
|
||||||
|
SandboxDefaults: &SandboxConfig{
|
||||||
|
Type: "docker",
|
||||||
|
Image: "alpine:latest",
|
||||||
|
Resources: ResourceLimits{
|
||||||
|
MemoryLimit: 256 * 1024 * 1024,
|
||||||
|
CPULimit: 0.5,
|
||||||
|
},
|
||||||
|
Security: SecurityPolicy{
|
||||||
|
NoNewPrivileges: true,
|
||||||
|
AllowNetworking: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := engine.Initialize(context.Background(), config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Execute task with commands
|
||||||
|
request := &TaskExecutionRequest{
|
||||||
|
ID: "test-456",
|
||||||
|
Type: "code_generation",
|
||||||
|
Description: "Generate a simple script",
|
||||||
|
Timeout: 2 * time.Minute,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
result, err := engine.ExecuteTask(ctx, request)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
// If Docker is not available, skip this test
|
||||||
|
t.Skipf("Docker not available for sandbox testing: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, result.Success)
|
||||||
|
assert.Equal(t, "test-456", result.TaskID)
|
||||||
|
assert.NotEmpty(t, result.Output)
|
||||||
|
assert.GreaterOrEqual(t, len(result.Artifacts), 1) // At least the file artifact
|
||||||
|
assert.Equal(t, 1, result.Metrics.CommandsExecuted)
|
||||||
|
assert.Greater(t, result.Metrics.SandboxTime, time.Duration(0))
|
||||||
|
|
||||||
|
// Check artifacts
|
||||||
|
var foundTestFile bool
|
||||||
|
for _, artifact := range result.Artifacts {
|
||||||
|
if artifact.Name == "test.txt" {
|
||||||
|
foundTestFile = true
|
||||||
|
assert.Equal(t, "file", artifact.Type)
|
||||||
|
assert.Equal(t, "Test file content", string(artifact.Content))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, foundTestFile, "Expected test.txt artifact not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskExecutionEngine_DetermineRoleFromTask(t *testing.T) {
|
||||||
|
engine := NewTaskExecutionEngine()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
request *TaskExecutionRequest
|
||||||
|
expectedRole string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "code task",
|
||||||
|
request: &TaskExecutionRequest{
|
||||||
|
Type: "code_generation",
|
||||||
|
Description: "Write a function to sort array",
|
||||||
|
},
|
||||||
|
expectedRole: "developer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "analysis task",
|
||||||
|
request: &TaskExecutionRequest{
|
||||||
|
Type: "analysis",
|
||||||
|
Description: "Analyze the performance metrics",
|
||||||
|
},
|
||||||
|
expectedRole: "analyst",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "test task",
|
||||||
|
request: &TaskExecutionRequest{
|
||||||
|
Type: "testing",
|
||||||
|
Description: "Write tests for the function",
|
||||||
|
},
|
||||||
|
expectedRole: "tester",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "program task by description",
|
||||||
|
request: &TaskExecutionRequest{
|
||||||
|
Type: "general",
|
||||||
|
Description: "Create a program that processes data",
|
||||||
|
},
|
||||||
|
expectedRole: "developer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "review task by description",
|
||||||
|
request: &TaskExecutionRequest{
|
||||||
|
Type: "general",
|
||||||
|
Description: "Review the code quality",
|
||||||
|
},
|
||||||
|
expectedRole: "analyst",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "general task",
|
||||||
|
request: &TaskExecutionRequest{
|
||||||
|
Type: "documentation",
|
||||||
|
Description: "Write user documentation",
|
||||||
|
},
|
||||||
|
expectedRole: "general",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
role := engine.determineRoleFromTask(tt.request)
|
||||||
|
assert.Equal(t, tt.expectedRole, role)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskExecutionEngine_ParseAIResponse(t *testing.T) {
|
||||||
|
engine := NewTaskExecutionEngine()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
response *ai.TaskResponse
|
||||||
|
expectedCommands int
|
||||||
|
expectedArtifacts int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "response with commands and files",
|
||||||
|
response: &ai.TaskResponse{
|
||||||
|
Actions: []ai.ActionResult{
|
||||||
|
{
|
||||||
|
Type: "command",
|
||||||
|
Content: map[string]interface{}{
|
||||||
|
"command": "ls -la",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: "command",
|
||||||
|
Content: map[string]interface{}{
|
||||||
|
"command": "echo 'test'",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: "file",
|
||||||
|
Content: map[string]interface{}{
|
||||||
|
"name": "script.sh",
|
||||||
|
"content": "#!/bin/bash\necho 'Hello'",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCommands: 2,
|
||||||
|
expectedArtifacts: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "response with no actions",
|
||||||
|
response: &ai.TaskResponse{
|
||||||
|
Actions: []ai.ActionResult{},
|
||||||
|
},
|
||||||
|
expectedCommands: 0,
|
||||||
|
expectedArtifacts: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "response with unknown action types",
|
||||||
|
response: &ai.TaskResponse{
|
||||||
|
Actions: []ai.ActionResult{
|
||||||
|
{
|
||||||
|
Type: "unknown",
|
||||||
|
Content: map[string]interface{}{
|
||||||
|
"data": "some data",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCommands: 0,
|
||||||
|
expectedArtifacts: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
commands, artifacts, err := engine.parseAIResponse(tt.response)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, commands, tt.expectedCommands)
|
||||||
|
assert.Len(t, artifacts, tt.expectedArtifacts)
|
||||||
|
|
||||||
|
// Validate artifact content if present
|
||||||
|
for _, artifact := range artifacts {
|
||||||
|
assert.NotEmpty(t, artifact.Name)
|
||||||
|
assert.NotEmpty(t, artifact.Type)
|
||||||
|
assert.Greater(t, artifact.Size, int64(0))
|
||||||
|
assert.False(t, artifact.CreatedAt.IsZero())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskExecutionEngine_CreateSandboxConfig(t *testing.T) {
|
||||||
|
engine := NewTaskExecutionEngine()
|
||||||
|
|
||||||
|
// Initialize with default config
|
||||||
|
config := &EngineConfig{
|
||||||
|
AIProviderFactory: &MockProviderFactory{},
|
||||||
|
SandboxDefaults: &SandboxConfig{
|
||||||
|
Image: "ubuntu:20.04",
|
||||||
|
Resources: ResourceLimits{
|
||||||
|
MemoryLimit: 1024 * 1024 * 1024,
|
||||||
|
CPULimit: 2.0,
|
||||||
|
},
|
||||||
|
Security: SecurityPolicy{
|
||||||
|
NoNewPrivileges: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
engine.Initialize(context.Background(), config)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
request *TaskExecutionRequest
|
||||||
|
validate func(t *testing.T, config *SandboxConfig)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic request uses defaults",
|
||||||
|
request: &TaskExecutionRequest{
|
||||||
|
ID: "test",
|
||||||
|
Type: "general",
|
||||||
|
Description: "test task",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, config *SandboxConfig) {
|
||||||
|
assert.Equal(t, "ubuntu:20.04", config.Image)
|
||||||
|
assert.Equal(t, int64(1024*1024*1024), config.Resources.MemoryLimit)
|
||||||
|
assert.Equal(t, 2.0, config.Resources.CPULimit)
|
||||||
|
assert.True(t, config.Security.NoNewPrivileges)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "request with custom requirements",
|
||||||
|
request: &TaskExecutionRequest{
|
||||||
|
ID: "test",
|
||||||
|
Type: "custom",
|
||||||
|
Description: "custom task",
|
||||||
|
Requirements: &TaskRequirements{
|
||||||
|
SandboxType: "container",
|
||||||
|
EnvironmentVars: map[string]string{
|
||||||
|
"ENV_VAR": "test_value",
|
||||||
|
},
|
||||||
|
ResourceLimits: &ResourceLimits{
|
||||||
|
MemoryLimit: 512 * 1024 * 1024,
|
||||||
|
CPULimit: 1.0,
|
||||||
|
},
|
||||||
|
SecurityPolicy: &SecurityPolicy{
|
||||||
|
ReadOnlyRoot: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, config *SandboxConfig) {
|
||||||
|
assert.Equal(t, "container", config.Type)
|
||||||
|
assert.Equal(t, "test_value", config.Environment["ENV_VAR"])
|
||||||
|
assert.Equal(t, int64(512*1024*1024), config.Resources.MemoryLimit)
|
||||||
|
assert.Equal(t, 1.0, config.Resources.CPULimit)
|
||||||
|
assert.True(t, config.Security.ReadOnlyRoot)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
sandboxConfig := engine.createSandboxConfig(tt.request)
|
||||||
|
tt.validate(t, sandboxConfig)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskExecutionEngine_GetMetrics(t *testing.T) {
|
||||||
|
engine := NewTaskExecutionEngine()
|
||||||
|
|
||||||
|
metrics := engine.GetMetrics()
|
||||||
|
|
||||||
|
assert.NotNil(t, metrics)
|
||||||
|
assert.Equal(t, int64(0), metrics.TasksExecuted)
|
||||||
|
assert.Equal(t, int64(0), metrics.TasksSuccessful)
|
||||||
|
assert.Equal(t, int64(0), metrics.TasksFailed)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskExecutionEngine_Shutdown(t *testing.T) {
|
||||||
|
engine := NewTaskExecutionEngine()
|
||||||
|
|
||||||
|
// Initialize engine
|
||||||
|
config := &EngineConfig{
|
||||||
|
AIProviderFactory: &MockProviderFactory{},
|
||||||
|
}
|
||||||
|
err := engine.Initialize(context.Background(), config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add a mock active task
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
engine.activeTasks["test-task"] = cancel
|
||||||
|
|
||||||
|
// Shutdown should cancel active tasks
|
||||||
|
err = engine.Shutdown()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify task was cleaned up
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
// Expected - task was canceled
|
||||||
|
default:
|
||||||
|
t.Error("Expected task context to be canceled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmark tests
|
||||||
|
func BenchmarkTaskExecutionEngine_ExecuteSimpleTask(b *testing.B) {
|
||||||
|
engine := NewTaskExecutionEngine()
|
||||||
|
|
||||||
|
// Setup mock AI provider
|
||||||
|
mockProvider := &MockProvider{}
|
||||||
|
mockFactory := &MockProviderFactory{}
|
||||||
|
|
||||||
|
mockProvider.On("ExecuteTask", mock.Anything, mock.Anything).Return(
|
||||||
|
&ai.TaskResponse{
|
||||||
|
TaskID: "bench",
|
||||||
|
Content: "Benchmark task completed",
|
||||||
|
Success: true,
|
||||||
|
Actions: []ai.ActionResult{},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
mockFactory.On("GetProviderForRole", mock.Anything).Return(
|
||||||
|
mockProvider,
|
||||||
|
ai.ProviderConfig{Provider: "mock", Model: "test"},
|
||||||
|
nil)
|
||||||
|
|
||||||
|
config := &EngineConfig{
|
||||||
|
AIProviderFactory: mockFactory,
|
||||||
|
DefaultTimeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
engine.Initialize(context.Background(), config)
|
||||||
|
|
||||||
|
request := &TaskExecutionRequest{
|
||||||
|
ID: "bench",
|
||||||
|
Type: "benchmark",
|
||||||
|
Description: "Benchmark task",
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, err := engine.ExecuteTask(context.Background(), request)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Task execution failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
415
pkg/execution/sandbox.go
Normal file
415
pkg/execution/sandbox.go
Normal file
@@ -0,0 +1,415 @@
|
|||||||
|
package execution
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExecutionSandbox defines the interface for isolated task execution environments
|
||||||
|
type ExecutionSandbox interface {
|
||||||
|
// Initialize sets up the sandbox environment
|
||||||
|
Initialize(ctx context.Context, config *SandboxConfig) error
|
||||||
|
|
||||||
|
// ExecuteCommand runs a command within the sandbox
|
||||||
|
ExecuteCommand(ctx context.Context, cmd *Command) (*CommandResult, error)
|
||||||
|
|
||||||
|
// CopyFiles copies files between host and sandbox
|
||||||
|
CopyFiles(ctx context.Context, source, dest string) error
|
||||||
|
|
||||||
|
// WriteFile writes content to a file in the sandbox
|
||||||
|
WriteFile(ctx context.Context, path string, content []byte, mode uint32) error
|
||||||
|
|
||||||
|
// ReadFile reads content from a file in the sandbox
|
||||||
|
ReadFile(ctx context.Context, path string) ([]byte, error)
|
||||||
|
|
||||||
|
// ListFiles lists files in a directory within the sandbox
|
||||||
|
ListFiles(ctx context.Context, path string) ([]FileInfo, error)
|
||||||
|
|
||||||
|
// GetWorkingDirectory returns the current working directory in the sandbox
|
||||||
|
GetWorkingDirectory() string
|
||||||
|
|
||||||
|
// SetWorkingDirectory changes the working directory in the sandbox
|
||||||
|
SetWorkingDirectory(path string) error
|
||||||
|
|
||||||
|
// GetEnvironment returns environment variables in the sandbox
|
||||||
|
GetEnvironment() map[string]string
|
||||||
|
|
||||||
|
// SetEnvironment sets environment variables in the sandbox
|
||||||
|
SetEnvironment(env map[string]string) error
|
||||||
|
|
||||||
|
// GetResourceUsage returns current resource usage statistics
|
||||||
|
GetResourceUsage(ctx context.Context) (*ResourceUsage, error)
|
||||||
|
|
||||||
|
// Cleanup destroys the sandbox and cleans up resources
|
||||||
|
Cleanup() error
|
||||||
|
|
||||||
|
// GetInfo returns information about the sandbox
|
||||||
|
GetInfo() SandboxInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// SandboxConfig represents configuration for a sandbox environment
|
||||||
|
type SandboxConfig struct {
|
||||||
|
// Sandbox type and runtime
|
||||||
|
Type string `json:"type"` // docker, vm, process
|
||||||
|
Image string `json:"image"` // Container/VM image
|
||||||
|
Runtime string `json:"runtime"` // docker, containerd, etc.
|
||||||
|
Architecture string `json:"architecture"` // amd64, arm64
|
||||||
|
|
||||||
|
// Resource limits
|
||||||
|
Resources ResourceLimits `json:"resources"`
|
||||||
|
|
||||||
|
// Security settings
|
||||||
|
Security SecurityPolicy `json:"security"`
|
||||||
|
|
||||||
|
// Repository configuration
|
||||||
|
Repository RepositoryConfig `json:"repository"`
|
||||||
|
|
||||||
|
// Network settings
|
||||||
|
Network NetworkConfig `json:"network"`
|
||||||
|
|
||||||
|
// Environment settings
|
||||||
|
Environment map[string]string `json:"environment"`
|
||||||
|
WorkingDir string `json:"working_dir"`
|
||||||
|
|
||||||
|
// Tool and service access
|
||||||
|
Tools []string `json:"tools"` // Available tools in sandbox
|
||||||
|
MCPServers []string `json:"mcp_servers"` // MCP servers to connect to
|
||||||
|
|
||||||
|
// Execution settings
|
||||||
|
Timeout time.Duration `json:"timeout"` // Maximum execution time
|
||||||
|
CleanupDelay time.Duration `json:"cleanup_delay"` // Delay before cleanup
|
||||||
|
|
||||||
|
// Metadata
|
||||||
|
Labels map[string]string `json:"labels"`
|
||||||
|
Annotations map[string]string `json:"annotations"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Command represents a command to execute in the sandbox
|
||||||
|
type Command struct {
|
||||||
|
// Command specification
|
||||||
|
Executable string `json:"executable"`
|
||||||
|
Args []string `json:"args"`
|
||||||
|
WorkingDir string `json:"working_dir"`
|
||||||
|
Environment map[string]string `json:"environment"`
|
||||||
|
|
||||||
|
// Input/Output
|
||||||
|
Stdin io.Reader `json:"-"`
|
||||||
|
StdinContent string `json:"stdin_content"`
|
||||||
|
|
||||||
|
// Execution settings
|
||||||
|
Timeout time.Duration `json:"timeout"`
|
||||||
|
User string `json:"user"`
|
||||||
|
|
||||||
|
// Security settings
|
||||||
|
AllowNetwork bool `json:"allow_network"`
|
||||||
|
AllowWrite bool `json:"allow_write"`
|
||||||
|
RestrictPaths []string `json:"restrict_paths"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CommandResult represents the result of command execution
|
||||||
|
type CommandResult struct {
|
||||||
|
// Exit information
|
||||||
|
ExitCode int `json:"exit_code"`
|
||||||
|
Success bool `json:"success"`
|
||||||
|
|
||||||
|
// Output
|
||||||
|
Stdout string `json:"stdout"`
|
||||||
|
Stderr string `json:"stderr"`
|
||||||
|
Combined string `json:"combined"`
|
||||||
|
|
||||||
|
// Timing
|
||||||
|
StartTime time.Time `json:"start_time"`
|
||||||
|
EndTime time.Time `json:"end_time"`
|
||||||
|
Duration time.Duration `json:"duration"`
|
||||||
|
|
||||||
|
// Resource usage during execution
|
||||||
|
ResourceUsage ResourceUsage `json:"resource_usage"`
|
||||||
|
|
||||||
|
// Error information
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
Signal string `json:"signal,omitempty"`
|
||||||
|
|
||||||
|
// Metadata
|
||||||
|
ProcessID int `json:"process_id,omitempty"`
|
||||||
|
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FileInfo represents information about a file in the sandbox
|
||||||
|
type FileInfo struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Path string `json:"path"`
|
||||||
|
Size int64 `json:"size"`
|
||||||
|
Mode uint32 `json:"mode"`
|
||||||
|
ModTime time.Time `json:"mod_time"`
|
||||||
|
IsDir bool `json:"is_dir"`
|
||||||
|
Owner string `json:"owner"`
|
||||||
|
Group string `json:"group"`
|
||||||
|
Permissions string `json:"permissions"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResourceLimits defines resource constraints for the sandbox
|
||||||
|
type ResourceLimits struct {
|
||||||
|
// CPU limits
|
||||||
|
CPULimit float64 `json:"cpu_limit"` // CPU cores (e.g., 1.5)
|
||||||
|
CPURequest float64 `json:"cpu_request"` // CPU cores requested
|
||||||
|
|
||||||
|
// Memory limits
|
||||||
|
MemoryLimit int64 `json:"memory_limit"` // Bytes
|
||||||
|
MemoryRequest int64 `json:"memory_request"` // Bytes
|
||||||
|
|
||||||
|
// Storage limits
|
||||||
|
DiskLimit int64 `json:"disk_limit"` // Bytes
|
||||||
|
DiskRequest int64 `json:"disk_request"` // Bytes
|
||||||
|
|
||||||
|
// Network limits
|
||||||
|
NetworkInLimit int64 `json:"network_in_limit"` // Bytes/sec
|
||||||
|
NetworkOutLimit int64 `json:"network_out_limit"` // Bytes/sec
|
||||||
|
|
||||||
|
// Process limits
|
||||||
|
ProcessLimit int `json:"process_limit"` // Max processes
|
||||||
|
FileLimit int `json:"file_limit"` // Max open files
|
||||||
|
|
||||||
|
// Time limits
|
||||||
|
WallTimeLimit time.Duration `json:"wall_time_limit"` // Max wall clock time
|
||||||
|
CPUTimeLimit time.Duration `json:"cpu_time_limit"` // Max CPU time
|
||||||
|
}
|
||||||
|
|
||||||
|
// SecurityPolicy defines security constraints and policies
|
||||||
|
type SecurityPolicy struct {
|
||||||
|
// Container security
|
||||||
|
RunAsUser string `json:"run_as_user"`
|
||||||
|
RunAsGroup string `json:"run_as_group"`
|
||||||
|
ReadOnlyRoot bool `json:"read_only_root"`
|
||||||
|
NoNewPrivileges bool `json:"no_new_privileges"`
|
||||||
|
|
||||||
|
// Capabilities
|
||||||
|
AddCapabilities []string `json:"add_capabilities"`
|
||||||
|
DropCapabilities []string `json:"drop_capabilities"`
|
||||||
|
|
||||||
|
// SELinux/AppArmor
|
||||||
|
SELinuxContext string `json:"selinux_context"`
|
||||||
|
AppArmorProfile string `json:"apparmor_profile"`
|
||||||
|
SeccompProfile string `json:"seccomp_profile"`
|
||||||
|
|
||||||
|
// Network security
|
||||||
|
AllowNetworking bool `json:"allow_networking"`
|
||||||
|
AllowedHosts []string `json:"allowed_hosts"`
|
||||||
|
BlockedHosts []string `json:"blocked_hosts"`
|
||||||
|
AllowedPorts []int `json:"allowed_ports"`
|
||||||
|
|
||||||
|
// File system security
|
||||||
|
ReadOnlyPaths []string `json:"read_only_paths"`
|
||||||
|
MaskedPaths []string `json:"masked_paths"`
|
||||||
|
TmpfsPaths []string `json:"tmpfs_paths"`
|
||||||
|
|
||||||
|
// Resource protection
|
||||||
|
PreventEscalation bool `json:"prevent_escalation"`
|
||||||
|
IsolateNetwork bool `json:"isolate_network"`
|
||||||
|
IsolateProcess bool `json:"isolate_process"`
|
||||||
|
|
||||||
|
// Monitoring
|
||||||
|
EnableAuditLog bool `json:"enable_audit_log"`
|
||||||
|
LogSecurityEvents bool `json:"log_security_events"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RepositoryConfig defines how the repository is mounted in the sandbox
|
||||||
|
type RepositoryConfig struct {
|
||||||
|
// Repository source
|
||||||
|
URL string `json:"url"`
|
||||||
|
Branch string `json:"branch"`
|
||||||
|
CommitHash string `json:"commit_hash"`
|
||||||
|
LocalPath string `json:"local_path"`
|
||||||
|
|
||||||
|
// Mount configuration
|
||||||
|
MountPoint string `json:"mount_point"` // Path in sandbox
|
||||||
|
ReadOnly bool `json:"read_only"`
|
||||||
|
|
||||||
|
// Git configuration
|
||||||
|
GitConfig GitConfig `json:"git_config"`
|
||||||
|
|
||||||
|
// File filters
|
||||||
|
IncludeFiles []string `json:"include_files"` // Glob patterns
|
||||||
|
ExcludeFiles []string `json:"exclude_files"` // Glob patterns
|
||||||
|
|
||||||
|
// Access permissions
|
||||||
|
Permissions string `json:"permissions"` // rwx format
|
||||||
|
Owner string `json:"owner"`
|
||||||
|
Group string `json:"group"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GitConfig defines Git configuration within the sandbox
|
||||||
|
type GitConfig struct {
|
||||||
|
UserName string `json:"user_name"`
|
||||||
|
UserEmail string `json:"user_email"`
|
||||||
|
SigningKey string `json:"signing_key"`
|
||||||
|
ConfigValues map[string]string `json:"config_values"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NetworkConfig defines network settings for the sandbox
|
||||||
|
type NetworkConfig struct {
|
||||||
|
// Network isolation
|
||||||
|
Isolated bool `json:"isolated"` // No network access
|
||||||
|
Bridge string `json:"bridge"` // Network bridge
|
||||||
|
|
||||||
|
// DNS settings
|
||||||
|
DNSServers []string `json:"dns_servers"`
|
||||||
|
DNSSearch []string `json:"dns_search"`
|
||||||
|
|
||||||
|
// Proxy settings
|
||||||
|
HTTPProxy string `json:"http_proxy"`
|
||||||
|
HTTPSProxy string `json:"https_proxy"`
|
||||||
|
NoProxy string `json:"no_proxy"`
|
||||||
|
|
||||||
|
// Port mappings
|
||||||
|
PortMappings []PortMapping `json:"port_mappings"`
|
||||||
|
|
||||||
|
// Bandwidth limits
|
||||||
|
IngressLimit int64 `json:"ingress_limit"` // Bytes/sec
|
||||||
|
EgressLimit int64 `json:"egress_limit"` // Bytes/sec
|
||||||
|
}
|
||||||
|
|
||||||
|
// PortMapping defines port forwarding configuration
|
||||||
|
type PortMapping struct {
|
||||||
|
HostPort int `json:"host_port"`
|
||||||
|
ContainerPort int `json:"container_port"`
|
||||||
|
Protocol string `json:"protocol"` // tcp, udp
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResourceUsage represents current resource consumption
|
||||||
|
type ResourceUsage struct {
|
||||||
|
// Timestamp of measurement
|
||||||
|
Timestamp time.Time `json:"timestamp"`
|
||||||
|
|
||||||
|
// CPU usage
|
||||||
|
CPUUsage float64 `json:"cpu_usage"` // Percentage
|
||||||
|
CPUTime time.Duration `json:"cpu_time"` // Total CPU time
|
||||||
|
|
||||||
|
// Memory usage
|
||||||
|
MemoryUsage int64 `json:"memory_usage"` // Bytes
|
||||||
|
MemoryPercent float64 `json:"memory_percent"` // Percentage of limit
|
||||||
|
MemoryPeak int64 `json:"memory_peak"` // Peak usage
|
||||||
|
|
||||||
|
// Disk usage
|
||||||
|
DiskUsage int64 `json:"disk_usage"` // Bytes
|
||||||
|
DiskReads int64 `json:"disk_reads"` // Read operations
|
||||||
|
DiskWrites int64 `json:"disk_writes"` // Write operations
|
||||||
|
|
||||||
|
// Network usage
|
||||||
|
NetworkIn int64 `json:"network_in"` // Bytes received
|
||||||
|
NetworkOut int64 `json:"network_out"` // Bytes sent
|
||||||
|
|
||||||
|
// Process information
|
||||||
|
ProcessCount int `json:"process_count"` // Active processes
|
||||||
|
ThreadCount int `json:"thread_count"` // Active threads
|
||||||
|
FileHandles int `json:"file_handles"` // Open file handles
|
||||||
|
|
||||||
|
// Runtime information
|
||||||
|
Uptime time.Duration `json:"uptime"` // Sandbox uptime
|
||||||
|
}
|
||||||
|
|
||||||
|
// SandboxInfo provides information about a sandbox instance
|
||||||
|
type SandboxInfo struct {
|
||||||
|
// Identification
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
|
||||||
|
// Status
|
||||||
|
Status SandboxStatus `json:"status"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
StartedAt time.Time `json:"started_at"`
|
||||||
|
|
||||||
|
// Runtime information
|
||||||
|
Runtime string `json:"runtime"`
|
||||||
|
Image string `json:"image"`
|
||||||
|
Platform string `json:"platform"`
|
||||||
|
|
||||||
|
// Network information
|
||||||
|
IPAddress string `json:"ip_address"`
|
||||||
|
MACAddress string `json:"mac_address"`
|
||||||
|
Hostname string `json:"hostname"`
|
||||||
|
|
||||||
|
// Resource information
|
||||||
|
AllocatedResources ResourceLimits `json:"allocated_resources"`
|
||||||
|
|
||||||
|
// Configuration
|
||||||
|
Config SandboxConfig `json:"config"`
|
||||||
|
|
||||||
|
// Metadata
|
||||||
|
Labels map[string]string `json:"labels"`
|
||||||
|
Annotations map[string]string `json:"annotations"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SandboxStatus represents the current status of a sandbox
|
||||||
|
type SandboxStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
StatusCreating SandboxStatus = "creating"
|
||||||
|
StatusStarting SandboxStatus = "starting"
|
||||||
|
StatusRunning SandboxStatus = "running"
|
||||||
|
StatusPaused SandboxStatus = "paused"
|
||||||
|
StatusStopping SandboxStatus = "stopping"
|
||||||
|
StatusStopped SandboxStatus = "stopped"
|
||||||
|
StatusFailed SandboxStatus = "failed"
|
||||||
|
StatusDestroyed SandboxStatus = "destroyed"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Common sandbox errors
|
||||||
|
var (
|
||||||
|
ErrSandboxNotFound = &SandboxError{Code: "SANDBOX_NOT_FOUND", Message: "Sandbox not found"}
|
||||||
|
ErrSandboxAlreadyExists = &SandboxError{Code: "SANDBOX_ALREADY_EXISTS", Message: "Sandbox already exists"}
|
||||||
|
ErrSandboxNotRunning = &SandboxError{Code: "SANDBOX_NOT_RUNNING", Message: "Sandbox is not running"}
|
||||||
|
ErrSandboxInitFailed = &SandboxError{Code: "SANDBOX_INIT_FAILED", Message: "Sandbox initialization failed"}
|
||||||
|
ErrCommandExecutionFailed = &SandboxError{Code: "COMMAND_EXECUTION_FAILED", Message: "Command execution failed"}
|
||||||
|
ErrResourceLimitExceeded = &SandboxError{Code: "RESOURCE_LIMIT_EXCEEDED", Message: "Resource limit exceeded"}
|
||||||
|
ErrSecurityViolation = &SandboxError{Code: "SECURITY_VIOLATION", Message: "Security policy violation"}
|
||||||
|
ErrFileOperationFailed = &SandboxError{Code: "FILE_OPERATION_FAILED", Message: "File operation failed"}
|
||||||
|
ErrNetworkAccessDenied = &SandboxError{Code: "NETWORK_ACCESS_DENIED", Message: "Network access denied"}
|
||||||
|
ErrTimeoutExceeded = &SandboxError{Code: "TIMEOUT_EXCEEDED", Message: "Execution timeout exceeded"}
|
||||||
|
)
|
||||||
|
|
||||||
|
// SandboxError represents sandbox-specific errors
|
||||||
|
type SandboxError struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Details string `json:"details,omitempty"`
|
||||||
|
Retryable bool `json:"retryable"`
|
||||||
|
Cause error `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *SandboxError) Error() string {
|
||||||
|
if e.Details != "" {
|
||||||
|
return e.Message + ": " + e.Details
|
||||||
|
}
|
||||||
|
return e.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *SandboxError) Unwrap() error {
|
||||||
|
return e.Cause
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *SandboxError) IsRetryable() bool {
|
||||||
|
return e.Retryable
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSandboxError creates a new sandbox error with details
|
||||||
|
func NewSandboxError(base *SandboxError, details string) *SandboxError {
|
||||||
|
return &SandboxError{
|
||||||
|
Code: base.Code,
|
||||||
|
Message: base.Message,
|
||||||
|
Details: details,
|
||||||
|
Retryable: base.Retryable,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSandboxErrorWithCause creates a new sandbox error with an underlying cause
|
||||||
|
func NewSandboxErrorWithCause(base *SandboxError, details string, cause error) *SandboxError {
|
||||||
|
return &SandboxError{
|
||||||
|
Code: base.Code,
|
||||||
|
Message: base.Message,
|
||||||
|
Details: details,
|
||||||
|
Retryable: base.Retryable,
|
||||||
|
Cause: cause,
|
||||||
|
}
|
||||||
|
}
|
||||||
639
pkg/execution/sandbox_test.go
Normal file
639
pkg/execution/sandbox_test.go
Normal file
@@ -0,0 +1,639 @@
|
|||||||
|
package execution
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSandboxError(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
err *SandboxError
|
||||||
|
expected string
|
||||||
|
retryable bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple error",
|
||||||
|
err: ErrSandboxNotFound,
|
||||||
|
expected: "Sandbox not found",
|
||||||
|
retryable: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error with details",
|
||||||
|
err: NewSandboxError(ErrResourceLimitExceeded, "Memory limit of 1GB exceeded"),
|
||||||
|
expected: "Resource limit exceeded: Memory limit of 1GB exceeded",
|
||||||
|
retryable: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retryable error",
|
||||||
|
err: &SandboxError{
|
||||||
|
Code: "TEMPORARY_FAILURE",
|
||||||
|
Message: "Temporary network failure",
|
||||||
|
Retryable: true,
|
||||||
|
},
|
||||||
|
expected: "Temporary network failure",
|
||||||
|
retryable: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.expected, tt.err.Error())
|
||||||
|
assert.Equal(t, tt.retryable, tt.err.IsRetryable())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSandboxErrorUnwrap(t *testing.T) {
|
||||||
|
baseErr := errors.New("underlying error")
|
||||||
|
sandboxErr := NewSandboxErrorWithCause(ErrCommandExecutionFailed, "command failed", baseErr)
|
||||||
|
|
||||||
|
unwrapped := sandboxErr.Unwrap()
|
||||||
|
assert.Equal(t, baseErr, unwrapped)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSandboxConfig(t *testing.T) {
|
||||||
|
config := &SandboxConfig{
|
||||||
|
Type: "docker",
|
||||||
|
Image: "alpine:latest",
|
||||||
|
Runtime: "docker",
|
||||||
|
Architecture: "amd64",
|
||||||
|
Resources: ResourceLimits{
|
||||||
|
MemoryLimit: 1024 * 1024 * 1024, // 1GB
|
||||||
|
MemoryRequest: 512 * 1024 * 1024, // 512MB
|
||||||
|
CPULimit: 2.0,
|
||||||
|
CPURequest: 1.0,
|
||||||
|
DiskLimit: 10 * 1024 * 1024 * 1024, // 10GB
|
||||||
|
ProcessLimit: 100,
|
||||||
|
FileLimit: 1024,
|
||||||
|
WallTimeLimit: 30 * time.Minute,
|
||||||
|
CPUTimeLimit: 10 * time.Minute,
|
||||||
|
},
|
||||||
|
Security: SecurityPolicy{
|
||||||
|
RunAsUser: "1000",
|
||||||
|
RunAsGroup: "1000",
|
||||||
|
ReadOnlyRoot: true,
|
||||||
|
NoNewPrivileges: true,
|
||||||
|
AddCapabilities: []string{"NET_BIND_SERVICE"},
|
||||||
|
DropCapabilities: []string{"ALL"},
|
||||||
|
SELinuxContext: "unconfined_u:unconfined_r:container_t:s0",
|
||||||
|
AppArmorProfile: "docker-default",
|
||||||
|
SeccompProfile: "runtime/default",
|
||||||
|
AllowNetworking: false,
|
||||||
|
AllowedHosts: []string{"api.example.com"},
|
||||||
|
BlockedHosts: []string{"malicious.com"},
|
||||||
|
AllowedPorts: []int{80, 443},
|
||||||
|
ReadOnlyPaths: []string{"/etc", "/usr"},
|
||||||
|
MaskedPaths: []string{"/proc/kcore", "/proc/keys"},
|
||||||
|
TmpfsPaths: []string{"/tmp", "/var/tmp"},
|
||||||
|
PreventEscalation: true,
|
||||||
|
IsolateNetwork: true,
|
||||||
|
IsolateProcess: true,
|
||||||
|
EnableAuditLog: true,
|
||||||
|
LogSecurityEvents: true,
|
||||||
|
},
|
||||||
|
Repository: RepositoryConfig{
|
||||||
|
URL: "https://github.com/example/repo.git",
|
||||||
|
Branch: "main",
|
||||||
|
LocalPath: "/home/user/repo",
|
||||||
|
MountPoint: "/workspace",
|
||||||
|
ReadOnly: false,
|
||||||
|
GitConfig: GitConfig{
|
||||||
|
UserName: "Test User",
|
||||||
|
UserEmail: "test@example.com",
|
||||||
|
ConfigValues: map[string]string{
|
||||||
|
"core.autocrlf": "input",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
IncludeFiles: []string{"*.go", "*.md"},
|
||||||
|
ExcludeFiles: []string{"*.tmp", "*.log"},
|
||||||
|
Permissions: "755",
|
||||||
|
Owner: "user",
|
||||||
|
Group: "user",
|
||||||
|
},
|
||||||
|
Network: NetworkConfig{
|
||||||
|
Isolated: false,
|
||||||
|
Bridge: "docker0",
|
||||||
|
DNSServers: []string{"8.8.8.8", "1.1.1.1"},
|
||||||
|
DNSSearch: []string{"example.com"},
|
||||||
|
HTTPProxy: "http://proxy:8080",
|
||||||
|
HTTPSProxy: "http://proxy:8080",
|
||||||
|
NoProxy: "localhost,127.0.0.1",
|
||||||
|
PortMappings: []PortMapping{
|
||||||
|
{HostPort: 8080, ContainerPort: 80, Protocol: "tcp"},
|
||||||
|
},
|
||||||
|
IngressLimit: 1024 * 1024, // 1MB/s
|
||||||
|
EgressLimit: 2048 * 1024, // 2MB/s
|
||||||
|
},
|
||||||
|
Environment: map[string]string{
|
||||||
|
"NODE_ENV": "test",
|
||||||
|
"DEBUG": "true",
|
||||||
|
},
|
||||||
|
WorkingDir: "/workspace",
|
||||||
|
Tools: []string{"git", "node", "npm"},
|
||||||
|
MCPServers: []string{"file-server", "web-server"},
|
||||||
|
Timeout: 5 * time.Minute,
|
||||||
|
CleanupDelay: 30 * time.Second,
|
||||||
|
Labels: map[string]string{
|
||||||
|
"app": "chorus",
|
||||||
|
"version": "1.0.0",
|
||||||
|
},
|
||||||
|
Annotations: map[string]string{
|
||||||
|
"description": "Test sandbox configuration",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate required fields
|
||||||
|
assert.NotEmpty(t, config.Type)
|
||||||
|
assert.NotEmpty(t, config.Image)
|
||||||
|
assert.NotEmpty(t, config.Architecture)
|
||||||
|
|
||||||
|
// Validate resource limits
|
||||||
|
assert.Greater(t, config.Resources.MemoryLimit, int64(0))
|
||||||
|
assert.Greater(t, config.Resources.CPULimit, 0.0)
|
||||||
|
|
||||||
|
// Validate security policy
|
||||||
|
assert.NotEmpty(t, config.Security.RunAsUser)
|
||||||
|
assert.True(t, config.Security.NoNewPrivileges)
|
||||||
|
assert.NotEmpty(t, config.Security.DropCapabilities)
|
||||||
|
|
||||||
|
// Validate repository config
|
||||||
|
assert.NotEmpty(t, config.Repository.MountPoint)
|
||||||
|
assert.NotEmpty(t, config.Repository.GitConfig.UserName)
|
||||||
|
|
||||||
|
// Validate network config
|
||||||
|
assert.NotEmpty(t, config.Network.DNSServers)
|
||||||
|
assert.Len(t, config.Network.PortMappings, 1)
|
||||||
|
|
||||||
|
// Validate timeouts
|
||||||
|
assert.Greater(t, config.Timeout, time.Duration(0))
|
||||||
|
assert.Greater(t, config.CleanupDelay, time.Duration(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCommand(t *testing.T) {
|
||||||
|
cmd := &Command{
|
||||||
|
Executable: "python3",
|
||||||
|
Args: []string{"-c", "print('hello world')"},
|
||||||
|
WorkingDir: "/workspace",
|
||||||
|
Environment: map[string]string{"PYTHONPATH": "/custom/path"},
|
||||||
|
StdinContent: "input data",
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
User: "1000",
|
||||||
|
AllowNetwork: true,
|
||||||
|
AllowWrite: true,
|
||||||
|
RestrictPaths: []string{"/etc", "/usr"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate command structure
|
||||||
|
assert.Equal(t, "python3", cmd.Executable)
|
||||||
|
assert.Len(t, cmd.Args, 2)
|
||||||
|
assert.Equal(t, "/workspace", cmd.WorkingDir)
|
||||||
|
assert.Equal(t, "/custom/path", cmd.Environment["PYTHONPATH"])
|
||||||
|
assert.Equal(t, "input data", cmd.StdinContent)
|
||||||
|
assert.Equal(t, 30*time.Second, cmd.Timeout)
|
||||||
|
assert.True(t, cmd.AllowNetwork)
|
||||||
|
assert.True(t, cmd.AllowWrite)
|
||||||
|
assert.Len(t, cmd.RestrictPaths, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCommandResult(t *testing.T) {
|
||||||
|
startTime := time.Now()
|
||||||
|
endTime := startTime.Add(2 * time.Second)
|
||||||
|
|
||||||
|
result := &CommandResult{
|
||||||
|
ExitCode: 0,
|
||||||
|
Success: true,
|
||||||
|
Stdout: "Standard output",
|
||||||
|
Stderr: "Standard error",
|
||||||
|
Combined: "Combined output",
|
||||||
|
StartTime: startTime,
|
||||||
|
EndTime: endTime,
|
||||||
|
Duration: endTime.Sub(startTime),
|
||||||
|
ResourceUsage: ResourceUsage{
|
||||||
|
CPUUsage: 25.5,
|
||||||
|
MemoryUsage: 1024 * 1024, // 1MB
|
||||||
|
},
|
||||||
|
ProcessID: 12345,
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"container_id": "abc123",
|
||||||
|
"image": "alpine:latest",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate result structure
|
||||||
|
assert.Equal(t, 0, result.ExitCode)
|
||||||
|
assert.True(t, result.Success)
|
||||||
|
assert.Equal(t, "Standard output", result.Stdout)
|
||||||
|
assert.Equal(t, "Standard error", result.Stderr)
|
||||||
|
assert.Equal(t, 2*time.Second, result.Duration)
|
||||||
|
assert.Equal(t, 25.5, result.ResourceUsage.CPUUsage)
|
||||||
|
assert.Equal(t, int64(1024*1024), result.ResourceUsage.MemoryUsage)
|
||||||
|
assert.Equal(t, 12345, result.ProcessID)
|
||||||
|
assert.Equal(t, "abc123", result.Metadata["container_id"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileInfo(t *testing.T) {
|
||||||
|
modTime := time.Now()
|
||||||
|
|
||||||
|
fileInfo := FileInfo{
|
||||||
|
Name: "test.txt",
|
||||||
|
Path: "/workspace/test.txt",
|
||||||
|
Size: 1024,
|
||||||
|
Mode: 0644,
|
||||||
|
ModTime: modTime,
|
||||||
|
IsDir: false,
|
||||||
|
Owner: "user",
|
||||||
|
Group: "user",
|
||||||
|
Permissions: "-rw-r--r--",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate file info structure
|
||||||
|
assert.Equal(t, "test.txt", fileInfo.Name)
|
||||||
|
assert.Equal(t, "/workspace/test.txt", fileInfo.Path)
|
||||||
|
assert.Equal(t, int64(1024), fileInfo.Size)
|
||||||
|
assert.Equal(t, uint32(0644), fileInfo.Mode)
|
||||||
|
assert.Equal(t, modTime, fileInfo.ModTime)
|
||||||
|
assert.False(t, fileInfo.IsDir)
|
||||||
|
assert.Equal(t, "user", fileInfo.Owner)
|
||||||
|
assert.Equal(t, "user", fileInfo.Group)
|
||||||
|
assert.Equal(t, "-rw-r--r--", fileInfo.Permissions)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResourceLimits(t *testing.T) {
|
||||||
|
limits := ResourceLimits{
|
||||||
|
CPULimit: 2.5,
|
||||||
|
CPURequest: 1.0,
|
||||||
|
MemoryLimit: 2 * 1024 * 1024 * 1024, // 2GB
|
||||||
|
MemoryRequest: 1 * 1024 * 1024 * 1024, // 1GB
|
||||||
|
DiskLimit: 50 * 1024 * 1024 * 1024, // 50GB
|
||||||
|
DiskRequest: 10 * 1024 * 1024 * 1024, // 10GB
|
||||||
|
NetworkInLimit: 10 * 1024 * 1024, // 10MB/s
|
||||||
|
NetworkOutLimit: 5 * 1024 * 1024, // 5MB/s
|
||||||
|
ProcessLimit: 200,
|
||||||
|
FileLimit: 2048,
|
||||||
|
WallTimeLimit: 1 * time.Hour,
|
||||||
|
CPUTimeLimit: 30 * time.Minute,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate resource limits
|
||||||
|
assert.Equal(t, 2.5, limits.CPULimit)
|
||||||
|
assert.Equal(t, 1.0, limits.CPURequest)
|
||||||
|
assert.Equal(t, int64(2*1024*1024*1024), limits.MemoryLimit)
|
||||||
|
assert.Equal(t, int64(1*1024*1024*1024), limits.MemoryRequest)
|
||||||
|
assert.Equal(t, int64(50*1024*1024*1024), limits.DiskLimit)
|
||||||
|
assert.Equal(t, 200, limits.ProcessLimit)
|
||||||
|
assert.Equal(t, 2048, limits.FileLimit)
|
||||||
|
assert.Equal(t, 1*time.Hour, limits.WallTimeLimit)
|
||||||
|
assert.Equal(t, 30*time.Minute, limits.CPUTimeLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResourceUsage(t *testing.T) {
|
||||||
|
timestamp := time.Now()
|
||||||
|
|
||||||
|
usage := ResourceUsage{
|
||||||
|
Timestamp: timestamp,
|
||||||
|
CPUUsage: 75.5,
|
||||||
|
CPUTime: 15 * time.Minute,
|
||||||
|
MemoryUsage: 512 * 1024 * 1024, // 512MB
|
||||||
|
MemoryPercent: 25.0,
|
||||||
|
MemoryPeak: 768 * 1024 * 1024, // 768MB
|
||||||
|
DiskUsage: 1 * 1024 * 1024 * 1024, // 1GB
|
||||||
|
DiskReads: 1000,
|
||||||
|
DiskWrites: 500,
|
||||||
|
NetworkIn: 10 * 1024 * 1024, // 10MB
|
||||||
|
NetworkOut: 5 * 1024 * 1024, // 5MB
|
||||||
|
ProcessCount: 25,
|
||||||
|
ThreadCount: 100,
|
||||||
|
FileHandles: 50,
|
||||||
|
Uptime: 2 * time.Hour,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate resource usage
|
||||||
|
assert.Equal(t, timestamp, usage.Timestamp)
|
||||||
|
assert.Equal(t, 75.5, usage.CPUUsage)
|
||||||
|
assert.Equal(t, 15*time.Minute, usage.CPUTime)
|
||||||
|
assert.Equal(t, int64(512*1024*1024), usage.MemoryUsage)
|
||||||
|
assert.Equal(t, 25.0, usage.MemoryPercent)
|
||||||
|
assert.Equal(t, int64(768*1024*1024), usage.MemoryPeak)
|
||||||
|
assert.Equal(t, 25, usage.ProcessCount)
|
||||||
|
assert.Equal(t, 100, usage.ThreadCount)
|
||||||
|
assert.Equal(t, 50, usage.FileHandles)
|
||||||
|
assert.Equal(t, 2*time.Hour, usage.Uptime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSandboxInfo(t *testing.T) {
|
||||||
|
createdAt := time.Now()
|
||||||
|
startedAt := createdAt.Add(5 * time.Second)
|
||||||
|
|
||||||
|
info := SandboxInfo{
|
||||||
|
ID: "sandbox-123",
|
||||||
|
Name: "test-sandbox",
|
||||||
|
Type: "docker",
|
||||||
|
Status: StatusRunning,
|
||||||
|
CreatedAt: createdAt,
|
||||||
|
StartedAt: startedAt,
|
||||||
|
Runtime: "docker",
|
||||||
|
Image: "alpine:latest",
|
||||||
|
Platform: "linux/amd64",
|
||||||
|
IPAddress: "172.17.0.2",
|
||||||
|
MACAddress: "02:42:ac:11:00:02",
|
||||||
|
Hostname: "sandbox-123",
|
||||||
|
AllocatedResources: ResourceLimits{
|
||||||
|
MemoryLimit: 1024 * 1024 * 1024, // 1GB
|
||||||
|
CPULimit: 2.0,
|
||||||
|
},
|
||||||
|
Labels: map[string]string{
|
||||||
|
"app": "chorus",
|
||||||
|
},
|
||||||
|
Annotations: map[string]string{
|
||||||
|
"creator": "test",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate sandbox info
|
||||||
|
assert.Equal(t, "sandbox-123", info.ID)
|
||||||
|
assert.Equal(t, "test-sandbox", info.Name)
|
||||||
|
assert.Equal(t, "docker", info.Type)
|
||||||
|
assert.Equal(t, StatusRunning, info.Status)
|
||||||
|
assert.Equal(t, createdAt, info.CreatedAt)
|
||||||
|
assert.Equal(t, startedAt, info.StartedAt)
|
||||||
|
assert.Equal(t, "docker", info.Runtime)
|
||||||
|
assert.Equal(t, "alpine:latest", info.Image)
|
||||||
|
assert.Equal(t, "172.17.0.2", info.IPAddress)
|
||||||
|
assert.Equal(t, "chorus", info.Labels["app"])
|
||||||
|
assert.Equal(t, "test", info.Annotations["creator"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSandboxStatus(t *testing.T) {
|
||||||
|
statuses := []SandboxStatus{
|
||||||
|
StatusCreating,
|
||||||
|
StatusStarting,
|
||||||
|
StatusRunning,
|
||||||
|
StatusPaused,
|
||||||
|
StatusStopping,
|
||||||
|
StatusStopped,
|
||||||
|
StatusFailed,
|
||||||
|
StatusDestroyed,
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedStatuses := []string{
|
||||||
|
"creating",
|
||||||
|
"starting",
|
||||||
|
"running",
|
||||||
|
"paused",
|
||||||
|
"stopping",
|
||||||
|
"stopped",
|
||||||
|
"failed",
|
||||||
|
"destroyed",
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, status := range statuses {
|
||||||
|
assert.Equal(t, expectedStatuses[i], string(status))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPortMapping(t *testing.T) {
|
||||||
|
mapping := PortMapping{
|
||||||
|
HostPort: 8080,
|
||||||
|
ContainerPort: 80,
|
||||||
|
Protocol: "tcp",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 8080, mapping.HostPort)
|
||||||
|
assert.Equal(t, 80, mapping.ContainerPort)
|
||||||
|
assert.Equal(t, "tcp", mapping.Protocol)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGitConfig(t *testing.T) {
|
||||||
|
config := GitConfig{
|
||||||
|
UserName: "Test User",
|
||||||
|
UserEmail: "test@example.com",
|
||||||
|
SigningKey: "ABC123",
|
||||||
|
ConfigValues: map[string]string{
|
||||||
|
"core.autocrlf": "input",
|
||||||
|
"pull.rebase": "true",
|
||||||
|
"init.defaultBranch": "main",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, "Test User", config.UserName)
|
||||||
|
assert.Equal(t, "test@example.com", config.UserEmail)
|
||||||
|
assert.Equal(t, "ABC123", config.SigningKey)
|
||||||
|
assert.Equal(t, "input", config.ConfigValues["core.autocrlf"])
|
||||||
|
assert.Equal(t, "true", config.ConfigValues["pull.rebase"])
|
||||||
|
assert.Equal(t, "main", config.ConfigValues["init.defaultBranch"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockSandbox implements ExecutionSandbox for testing
|
||||||
|
type MockSandbox struct {
|
||||||
|
id string
|
||||||
|
status SandboxStatus
|
||||||
|
workingDir string
|
||||||
|
environment map[string]string
|
||||||
|
shouldFail bool
|
||||||
|
commandResult *CommandResult
|
||||||
|
files []FileInfo
|
||||||
|
resourceUsage *ResourceUsage
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMockSandbox() *MockSandbox {
|
||||||
|
return &MockSandbox{
|
||||||
|
id: "mock-sandbox-123",
|
||||||
|
status: StatusStopped,
|
||||||
|
workingDir: "/workspace",
|
||||||
|
environment: make(map[string]string),
|
||||||
|
files: []FileInfo{},
|
||||||
|
commandResult: &CommandResult{
|
||||||
|
Success: true,
|
||||||
|
ExitCode: 0,
|
||||||
|
Stdout: "mock output",
|
||||||
|
},
|
||||||
|
resourceUsage: &ResourceUsage{
|
||||||
|
CPUUsage: 10.0,
|
||||||
|
MemoryUsage: 100 * 1024 * 1024, // 100MB
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSandbox) Initialize(ctx context.Context, config *SandboxConfig) error {
|
||||||
|
if m.shouldFail {
|
||||||
|
return NewSandboxError(ErrSandboxInitFailed, "mock initialization failed")
|
||||||
|
}
|
||||||
|
m.status = StatusRunning
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSandbox) ExecuteCommand(ctx context.Context, cmd *Command) (*CommandResult, error) {
|
||||||
|
if m.shouldFail {
|
||||||
|
return nil, NewSandboxError(ErrCommandExecutionFailed, "mock command execution failed")
|
||||||
|
}
|
||||||
|
return m.commandResult, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSandbox) CopyFiles(ctx context.Context, source, dest string) error {
|
||||||
|
if m.shouldFail {
|
||||||
|
return NewSandboxError(ErrFileOperationFailed, "mock file copy failed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSandbox) WriteFile(ctx context.Context, path string, content []byte, mode uint32) error {
|
||||||
|
if m.shouldFail {
|
||||||
|
return NewSandboxError(ErrFileOperationFailed, "mock file write failed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSandbox) ReadFile(ctx context.Context, path string) ([]byte, error) {
|
||||||
|
if m.shouldFail {
|
||||||
|
return nil, NewSandboxError(ErrFileOperationFailed, "mock file read failed")
|
||||||
|
}
|
||||||
|
return []byte("mock file content"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSandbox) ListFiles(ctx context.Context, path string) ([]FileInfo, error) {
|
||||||
|
if m.shouldFail {
|
||||||
|
return nil, NewSandboxError(ErrFileOperationFailed, "mock file list failed")
|
||||||
|
}
|
||||||
|
return m.files, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSandbox) GetWorkingDirectory() string {
|
||||||
|
return m.workingDir
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSandbox) SetWorkingDirectory(path string) error {
|
||||||
|
if m.shouldFail {
|
||||||
|
return NewSandboxError(ErrFileOperationFailed, "mock set working directory failed")
|
||||||
|
}
|
||||||
|
m.workingDir = path
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSandbox) GetEnvironment() map[string]string {
|
||||||
|
env := make(map[string]string)
|
||||||
|
for k, v := range m.environment {
|
||||||
|
env[k] = v
|
||||||
|
}
|
||||||
|
return env
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSandbox) SetEnvironment(env map[string]string) error {
|
||||||
|
if m.shouldFail {
|
||||||
|
return NewSandboxError(ErrFileOperationFailed, "mock set environment failed")
|
||||||
|
}
|
||||||
|
for k, v := range env {
|
||||||
|
m.environment[k] = v
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSandbox) GetResourceUsage(ctx context.Context) (*ResourceUsage, error) {
|
||||||
|
if m.shouldFail {
|
||||||
|
return nil, NewSandboxError(ErrSandboxInitFailed, "mock resource usage failed")
|
||||||
|
}
|
||||||
|
return m.resourceUsage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSandbox) Cleanup() error {
|
||||||
|
if m.shouldFail {
|
||||||
|
return NewSandboxError(ErrSandboxInitFailed, "mock cleanup failed")
|
||||||
|
}
|
||||||
|
m.status = StatusDestroyed
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSandbox) GetInfo() SandboxInfo {
|
||||||
|
return SandboxInfo{
|
||||||
|
ID: m.id,
|
||||||
|
Status: m.status,
|
||||||
|
Type: "mock",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockSandbox(t *testing.T) {
|
||||||
|
sandbox := NewMockSandbox()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Test initialization
|
||||||
|
err := sandbox.Initialize(ctx, &SandboxConfig{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, StatusRunning, sandbox.status)
|
||||||
|
|
||||||
|
// Test command execution
|
||||||
|
result, err := sandbox.ExecuteCommand(ctx, &Command{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, result.Success)
|
||||||
|
assert.Equal(t, "mock output", result.Stdout)
|
||||||
|
|
||||||
|
// Test file operations
|
||||||
|
err = sandbox.WriteFile(ctx, "/test.txt", []byte("test"), 0644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
content, err := sandbox.ReadFile(ctx, "/test.txt")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []byte("mock file content"), content)
|
||||||
|
|
||||||
|
files, err := sandbox.ListFiles(ctx, "/")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, files) // Mock returns empty list by default
|
||||||
|
|
||||||
|
// Test environment
|
||||||
|
env := sandbox.GetEnvironment()
|
||||||
|
assert.Empty(t, env)
|
||||||
|
|
||||||
|
err = sandbox.SetEnvironment(map[string]string{"TEST": "value"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
env = sandbox.GetEnvironment()
|
||||||
|
assert.Equal(t, "value", env["TEST"])
|
||||||
|
|
||||||
|
// Test resource usage
|
||||||
|
usage, err := sandbox.GetResourceUsage(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 10.0, usage.CPUUsage)
|
||||||
|
|
||||||
|
// Test cleanup
|
||||||
|
err = sandbox.Cleanup()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, StatusDestroyed, sandbox.status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockSandboxFailure(t *testing.T) {
|
||||||
|
sandbox := NewMockSandbox()
|
||||||
|
sandbox.shouldFail = true
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// All operations should fail when shouldFail is true
|
||||||
|
err := sandbox.Initialize(ctx, &SandboxConfig{})
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
_, err = sandbox.ExecuteCommand(ctx, &Command{})
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
err = sandbox.WriteFile(ctx, "/test.txt", []byte("test"), 0644)
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
_, err = sandbox.ReadFile(ctx, "/test.txt")
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
_, err = sandbox.ListFiles(ctx, "/")
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
err = sandbox.SetWorkingDirectory("/tmp")
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
err = sandbox.SetEnvironment(map[string]string{"TEST": "value"})
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
_, err = sandbox.GetResourceUsage(ctx)
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
err = sandbox.Cleanup()
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
@@ -179,9 +179,11 @@ func (ehc *EnhancedHealthChecks) registerHealthChecks() {
|
|||||||
ehc.manager.RegisterCheck(ehc.createEnhancedPubSubCheck())
|
ehc.manager.RegisterCheck(ehc.createEnhancedPubSubCheck())
|
||||||
}
|
}
|
||||||
|
|
||||||
if ehc.config.EnableDHTProbes {
|
// Temporarily disable DHT health check to prevent shutdown issues
|
||||||
ehc.manager.RegisterCheck(ehc.createEnhancedDHTCheck())
|
// TODO: Fix DHT configuration and re-enable this check
|
||||||
}
|
// if ehc.config.EnableDHTProbes {
|
||||||
|
// ehc.manager.RegisterCheck(ehc.createEnhancedDHTCheck())
|
||||||
|
// }
|
||||||
|
|
||||||
if ehc.config.EnableElectionProbes {
|
if ehc.config.EnableElectionProbes {
|
||||||
ehc.manager.RegisterCheck(ehc.createElectionHealthCheck())
|
ehc.manager.RegisterCheck(ehc.createElectionHealthCheck())
|
||||||
@@ -290,7 +292,7 @@ func (ehc *EnhancedHealthChecks) createElectionHealthCheck() *HealthCheck {
|
|||||||
return &HealthCheck{
|
return &HealthCheck{
|
||||||
Name: "election-health",
|
Name: "election-health",
|
||||||
Description: "Election system health and leadership stability check",
|
Description: "Election system health and leadership stability check",
|
||||||
Enabled: true,
|
Enabled: false, // Temporarily disabled to prevent shutdown loops
|
||||||
Critical: false,
|
Critical: false,
|
||||||
Interval: ehc.config.ElectionProbeInterval,
|
Interval: ehc.config.ElectionProbeInterval,
|
||||||
Timeout: ehc.config.ElectionProbeTimeout,
|
Timeout: ehc.config.ElectionProbeTimeout,
|
||||||
|
|||||||
261
pkg/providers/factory.go
Normal file
261
pkg/providers/factory.go
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"chorus/pkg/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProviderFactory creates task providers for different repository types
|
||||||
|
type ProviderFactory struct {
|
||||||
|
supportedProviders map[string]ProviderCreator
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProviderCreator is a function that creates a provider from config
|
||||||
|
type ProviderCreator func(config *repository.Config) (repository.TaskProvider, error)
|
||||||
|
|
||||||
|
// NewProviderFactory creates a new provider factory with all supported providers
|
||||||
|
func NewProviderFactory() *ProviderFactory {
|
||||||
|
factory := &ProviderFactory{
|
||||||
|
supportedProviders: make(map[string]ProviderCreator),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register all supported providers
|
||||||
|
factory.RegisterProvider("gitea", func(config *repository.Config) (repository.TaskProvider, error) {
|
||||||
|
return NewGiteaProvider(config)
|
||||||
|
})
|
||||||
|
|
||||||
|
factory.RegisterProvider("github", func(config *repository.Config) (repository.TaskProvider, error) {
|
||||||
|
return NewGitHubProvider(config)
|
||||||
|
})
|
||||||
|
|
||||||
|
factory.RegisterProvider("gitlab", func(config *repository.Config) (repository.TaskProvider, error) {
|
||||||
|
return NewGitLabProvider(config)
|
||||||
|
})
|
||||||
|
|
||||||
|
factory.RegisterProvider("mock", func(config *repository.Config) (repository.TaskProvider, error) {
|
||||||
|
return &repository.MockTaskProvider{}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return factory
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterProvider registers a new provider creator
|
||||||
|
func (f *ProviderFactory) RegisterProvider(providerType string, creator ProviderCreator) {
|
||||||
|
f.supportedProviders[strings.ToLower(providerType)] = creator
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateProvider creates a task provider based on the configuration
|
||||||
|
func (f *ProviderFactory) CreateProvider(ctx interface{}, config *repository.Config) (repository.TaskProvider, error) {
|
||||||
|
if config == nil {
|
||||||
|
return nil, fmt.Errorf("configuration cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
providerType := strings.ToLower(config.Provider)
|
||||||
|
if providerType == "" {
|
||||||
|
// Fall back to Type field if Provider is not set
|
||||||
|
providerType = strings.ToLower(config.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
if providerType == "" {
|
||||||
|
return nil, fmt.Errorf("provider type must be specified in config.Provider or config.Type")
|
||||||
|
}
|
||||||
|
|
||||||
|
creator, exists := f.supportedProviders[providerType]
|
||||||
|
if !exists {
|
||||||
|
return nil, fmt.Errorf("unsupported provider type: %s. Supported types: %v",
|
||||||
|
providerType, f.GetSupportedTypes())
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, err := creator(config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create %s provider: %w", providerType, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return provider, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSupportedTypes returns a list of all supported provider types
|
||||||
|
func (f *ProviderFactory) GetSupportedTypes() []string {
|
||||||
|
types := make([]string, 0, len(f.supportedProviders))
|
||||||
|
for providerType := range f.supportedProviders {
|
||||||
|
types = append(types, providerType)
|
||||||
|
}
|
||||||
|
return types
|
||||||
|
}
|
||||||
|
|
||||||
|
// SupportedProviders returns list of supported providers (alias for GetSupportedTypes)
|
||||||
|
func (f *ProviderFactory) SupportedProviders() []string {
|
||||||
|
return f.GetSupportedTypes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateConfig validates a provider configuration
|
||||||
|
func (f *ProviderFactory) ValidateConfig(config *repository.Config) error {
|
||||||
|
if config == nil {
|
||||||
|
return fmt.Errorf("configuration cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
providerType := strings.ToLower(config.Provider)
|
||||||
|
if providerType == "" {
|
||||||
|
providerType = strings.ToLower(config.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
if providerType == "" {
|
||||||
|
return fmt.Errorf("provider type must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if provider type is supported
|
||||||
|
if _, exists := f.supportedProviders[providerType]; !exists {
|
||||||
|
return fmt.Errorf("unsupported provider type: %s", providerType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Provider-specific validation
|
||||||
|
switch providerType {
|
||||||
|
case "gitea":
|
||||||
|
return f.validateGiteaConfig(config)
|
||||||
|
case "github":
|
||||||
|
return f.validateGitHubConfig(config)
|
||||||
|
case "gitlab":
|
||||||
|
return f.validateGitLabConfig(config)
|
||||||
|
case "mock":
|
||||||
|
return nil // Mock provider doesn't need validation
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("validation not implemented for provider type: %s", providerType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateGiteaConfig validates Gitea-specific configuration
|
||||||
|
func (f *ProviderFactory) validateGiteaConfig(config *repository.Config) error {
|
||||||
|
if config.BaseURL == "" {
|
||||||
|
return fmt.Errorf("baseURL is required for Gitea provider")
|
||||||
|
}
|
||||||
|
if config.AccessToken == "" {
|
||||||
|
return fmt.Errorf("accessToken is required for Gitea provider")
|
||||||
|
}
|
||||||
|
if config.Owner == "" {
|
||||||
|
return fmt.Errorf("owner is required for Gitea provider")
|
||||||
|
}
|
||||||
|
if config.Repository == "" {
|
||||||
|
return fmt.Errorf("repository is required for Gitea provider")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateGitHubConfig validates GitHub-specific configuration
|
||||||
|
func (f *ProviderFactory) validateGitHubConfig(config *repository.Config) error {
|
||||||
|
if config.AccessToken == "" {
|
||||||
|
return fmt.Errorf("accessToken is required for GitHub provider")
|
||||||
|
}
|
||||||
|
if config.Owner == "" {
|
||||||
|
return fmt.Errorf("owner is required for GitHub provider")
|
||||||
|
}
|
||||||
|
if config.Repository == "" {
|
||||||
|
return fmt.Errorf("repository is required for GitHub provider")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateGitLabConfig validates GitLab-specific configuration
|
||||||
|
func (f *ProviderFactory) validateGitLabConfig(config *repository.Config) error {
|
||||||
|
if config.AccessToken == "" {
|
||||||
|
return fmt.Errorf("accessToken is required for GitLab provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GitLab requires either owner/repository or project_id in settings
|
||||||
|
if config.Owner != "" && config.Repository != "" {
|
||||||
|
return nil // owner/repo provided
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Settings != nil {
|
||||||
|
if projectID, ok := config.Settings["project_id"].(string); ok && projectID != "" {
|
||||||
|
return nil // project_id provided
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("either owner/repository or project_id in settings is required for GitLab provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProviderInfo returns information about a specific provider
|
||||||
|
func (f *ProviderFactory) GetProviderInfo(providerType string) (*ProviderInfo, error) {
|
||||||
|
providerType = strings.ToLower(providerType)
|
||||||
|
|
||||||
|
if _, exists := f.supportedProviders[providerType]; !exists {
|
||||||
|
return nil, fmt.Errorf("unsupported provider type: %s", providerType)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch providerType {
|
||||||
|
case "gitea":
|
||||||
|
return &ProviderInfo{
|
||||||
|
Name: "Gitea",
|
||||||
|
Type: "gitea",
|
||||||
|
Description: "Gitea self-hosted Git service provider",
|
||||||
|
RequiredFields: []string{"baseURL", "accessToken", "owner", "repository"},
|
||||||
|
OptionalFields: []string{"taskLabel", "inProgressLabel", "completedLabel", "baseBranch", "branchPrefix"},
|
||||||
|
SupportedFeatures: []string{"issues", "labels", "comments", "assignments"},
|
||||||
|
APIDocumentation: "https://docs.gitea.io/en-us/api-usage/",
|
||||||
|
}, nil
|
||||||
|
|
||||||
|
case "github":
|
||||||
|
return &ProviderInfo{
|
||||||
|
Name: "GitHub",
|
||||||
|
Type: "github",
|
||||||
|
Description: "GitHub cloud and enterprise Git service provider",
|
||||||
|
RequiredFields: []string{"accessToken", "owner", "repository"},
|
||||||
|
OptionalFields: []string{"taskLabel", "inProgressLabel", "completedLabel", "baseBranch", "branchPrefix"},
|
||||||
|
SupportedFeatures: []string{"issues", "labels", "comments", "assignments", "projects"},
|
||||||
|
APIDocumentation: "https://docs.github.com/en/rest",
|
||||||
|
}, nil
|
||||||
|
|
||||||
|
case "gitlab":
|
||||||
|
return &ProviderInfo{
|
||||||
|
Name: "GitLab",
|
||||||
|
Type: "gitlab",
|
||||||
|
Description: "GitLab cloud and self-hosted Git service provider",
|
||||||
|
RequiredFields: []string{"accessToken", "owner/repository OR project_id"},
|
||||||
|
OptionalFields: []string{"baseURL", "taskLabel", "inProgressLabel", "completedLabel", "baseBranch", "branchPrefix"},
|
||||||
|
SupportedFeatures: []string{"issues", "labels", "notes", "assignments", "time_tracking", "milestones"},
|
||||||
|
APIDocumentation: "https://docs.gitlab.com/ee/api/",
|
||||||
|
}, nil
|
||||||
|
|
||||||
|
case "mock":
|
||||||
|
return &ProviderInfo{
|
||||||
|
Name: "Mock Provider",
|
||||||
|
Type: "mock",
|
||||||
|
Description: "Mock provider for testing and development",
|
||||||
|
RequiredFields: []string{},
|
||||||
|
OptionalFields: []string{},
|
||||||
|
SupportedFeatures: []string{"basic_operations"},
|
||||||
|
APIDocumentation: "Built-in mock for testing purposes",
|
||||||
|
}, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("provider info not available for: %s", providerType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProviderInfo contains metadata about a provider
|
||||||
|
type ProviderInfo struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
RequiredFields []string `json:"required_fields"`
|
||||||
|
OptionalFields []string `json:"optional_fields"`
|
||||||
|
SupportedFeatures []string `json:"supported_features"`
|
||||||
|
APIDocumentation string `json:"api_documentation"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListProviders returns detailed information about all supported providers
|
||||||
|
func (f *ProviderFactory) ListProviders() ([]*ProviderInfo, error) {
|
||||||
|
providers := make([]*ProviderInfo, 0, len(f.supportedProviders))
|
||||||
|
|
||||||
|
for providerType := range f.supportedProviders {
|
||||||
|
info, err := f.GetProviderInfo(providerType)
|
||||||
|
if err != nil {
|
||||||
|
continue // Skip providers without info
|
||||||
|
}
|
||||||
|
providers = append(providers, info)
|
||||||
|
}
|
||||||
|
|
||||||
|
return providers, nil
|
||||||
|
}
|
||||||
617
pkg/providers/gitea.go
Normal file
617
pkg/providers/gitea.go
Normal file
@@ -0,0 +1,617 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"chorus/pkg/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GiteaProvider implements TaskProvider for Gitea API
|
||||||
|
type GiteaProvider struct {
|
||||||
|
config *repository.Config
|
||||||
|
httpClient *http.Client
|
||||||
|
baseURL string
|
||||||
|
token string
|
||||||
|
owner string
|
||||||
|
repo string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGiteaProvider creates a new Gitea provider
|
||||||
|
func NewGiteaProvider(config *repository.Config) (*GiteaProvider, error) {
|
||||||
|
if config.BaseURL == "" {
|
||||||
|
return nil, fmt.Errorf("base URL is required for Gitea provider")
|
||||||
|
}
|
||||||
|
if config.AccessToken == "" {
|
||||||
|
return nil, fmt.Errorf("access token is required for Gitea provider")
|
||||||
|
}
|
||||||
|
if config.Owner == "" {
|
||||||
|
return nil, fmt.Errorf("owner is required for Gitea provider")
|
||||||
|
}
|
||||||
|
if config.Repository == "" {
|
||||||
|
return nil, fmt.Errorf("repository name is required for Gitea provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure base URL has proper format
|
||||||
|
baseURL := strings.TrimSuffix(config.BaseURL, "/")
|
||||||
|
if !strings.HasPrefix(baseURL, "http") {
|
||||||
|
baseURL = "https://" + baseURL
|
||||||
|
}
|
||||||
|
|
||||||
|
return &GiteaProvider{
|
||||||
|
config: config,
|
||||||
|
baseURL: baseURL,
|
||||||
|
token: config.AccessToken,
|
||||||
|
owner: config.Owner,
|
||||||
|
repo: config.Repository,
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GiteaIssue represents a Gitea issue
|
||||||
|
type GiteaIssue struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Number int `json:"number"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Body string `json:"body"`
|
||||||
|
State string `json:"state"`
|
||||||
|
Labels []GiteaLabel `json:"labels"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
Repository *GiteaRepository `json:"repository"`
|
||||||
|
Assignee *GiteaUser `json:"assignee"`
|
||||||
|
Assignees []GiteaUser `json:"assignees"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GiteaLabel represents a Gitea label
|
||||||
|
type GiteaLabel struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Color string `json:"color"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GiteaRepository represents a Gitea repository
|
||||||
|
type GiteaRepository struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
FullName string `json:"full_name"`
|
||||||
|
Owner *GiteaUser `json:"owner"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GiteaUser represents a Gitea user
|
||||||
|
type GiteaUser struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
FullName string `json:"full_name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GiteaComment represents a Gitea issue comment
|
||||||
|
type GiteaComment struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Body string `json:"body"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
User *GiteaUser `json:"user"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeRequest makes an HTTP request to the Gitea API
|
||||||
|
func (g *GiteaProvider) makeRequest(method, endpoint string, body interface{}) (*http.Response, error) {
|
||||||
|
var reqBody io.Reader
|
||||||
|
|
||||||
|
if body != nil {
|
||||||
|
jsonData, err := json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||||
|
}
|
||||||
|
reqBody = bytes.NewBuffer(jsonData)
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s/api/v1%s", g.baseURL, endpoint)
|
||||||
|
req, err := http.NewRequest(method, url, reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Authorization", "token "+g.token)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := g.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("request failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTasks retrieves tasks (issues) from the Gitea repository
|
||||||
|
func (g *GiteaProvider) GetTasks(projectID int) ([]*repository.Task, error) {
|
||||||
|
// Build query parameters
|
||||||
|
params := url.Values{}
|
||||||
|
params.Add("state", "open")
|
||||||
|
params.Add("type", "issues")
|
||||||
|
params.Add("sort", "created")
|
||||||
|
params.Add("order", "desc")
|
||||||
|
|
||||||
|
// Add task label filter if specified
|
||||||
|
if g.config.TaskLabel != "" {
|
||||||
|
params.Add("labels", g.config.TaskLabel)
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := fmt.Sprintf("/repos/%s/%s/issues?%s", g.owner, g.repo, params.Encode())
|
||||||
|
|
||||||
|
resp, err := g.makeRequest("GET", endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get issues: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var issues []GiteaIssue
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&issues); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode issues: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert Gitea issues to repository tasks
|
||||||
|
tasks := make([]*repository.Task, 0, len(issues))
|
||||||
|
for _, issue := range issues {
|
||||||
|
task := g.issueToTask(&issue)
|
||||||
|
tasks = append(tasks, task)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tasks, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClaimTask claims a task by assigning it to the agent and adding in-progress label
|
||||||
|
func (g *GiteaProvider) ClaimTask(taskNumber int, agentID string) (bool, error) {
|
||||||
|
// First, get the current issue to check its state
|
||||||
|
endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d", g.owner, g.repo, taskNumber)
|
||||||
|
|
||||||
|
resp, err := g.makeRequest("GET", endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to get issue: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return false, fmt.Errorf("issue not found or not accessible")
|
||||||
|
}
|
||||||
|
|
||||||
|
var issue GiteaIssue
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&issue); err != nil {
|
||||||
|
return false, fmt.Errorf("failed to decode issue: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if issue is already assigned
|
||||||
|
if issue.Assignee != nil {
|
||||||
|
return false, fmt.Errorf("issue is already assigned to %s", issue.Assignee.Username)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add in-progress label if specified
|
||||||
|
if g.config.InProgressLabel != "" {
|
||||||
|
err := g.addLabelToIssue(taskNumber, g.config.InProgressLabel)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to add in-progress label: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a comment indicating the task has been claimed
|
||||||
|
comment := fmt.Sprintf("🤖 Task claimed by CHORUS agent `%s`\n\nThis task is now being processed automatically.", agentID)
|
||||||
|
err = g.addCommentToIssue(taskNumber, comment)
|
||||||
|
if err != nil {
|
||||||
|
// Don't fail the claim if comment fails
|
||||||
|
fmt.Printf("Warning: failed to add claim comment: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTaskStatus updates the status of a task
|
||||||
|
func (g *GiteaProvider) UpdateTaskStatus(task *repository.Task, status string, comment string) error {
|
||||||
|
// Add a comment with the status update
|
||||||
|
statusComment := fmt.Sprintf("**Status Update:** %s\n\n%s", status, comment)
|
||||||
|
|
||||||
|
err := g.addCommentToIssue(task.Number, statusComment)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add status comment: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompleteTask completes a task by updating status and adding completion comment
|
||||||
|
func (g *GiteaProvider) CompleteTask(task *repository.Task, result *repository.TaskResult) error {
|
||||||
|
// Create completion comment with results
|
||||||
|
var commentBuffer strings.Builder
|
||||||
|
commentBuffer.WriteString(fmt.Sprintf("✅ **Task Completed Successfully**\n\n"))
|
||||||
|
commentBuffer.WriteString(fmt.Sprintf("**Result:** %s\n\n", result.Message))
|
||||||
|
|
||||||
|
// Add metadata if available
|
||||||
|
if result.Metadata != nil {
|
||||||
|
commentBuffer.WriteString("**Execution Details:**\n")
|
||||||
|
for key, value := range result.Metadata {
|
||||||
|
commentBuffer.WriteString(fmt.Sprintf("- **%s:** %v\n", key, value))
|
||||||
|
}
|
||||||
|
commentBuffer.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
commentBuffer.WriteString("🤖 Completed by CHORUS autonomous agent")
|
||||||
|
|
||||||
|
// Add completion comment
|
||||||
|
err := g.addCommentToIssue(task.Number, commentBuffer.String())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add completion comment: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove in-progress label and add completed label
|
||||||
|
if g.config.InProgressLabel != "" {
|
||||||
|
err := g.removeLabelFromIssue(task.Number, g.config.InProgressLabel)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Warning: failed to remove in-progress label: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if g.config.CompletedLabel != "" {
|
||||||
|
err := g.addLabelToIssue(task.Number, g.config.CompletedLabel)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Warning: failed to add completed label: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the issue if the task was successful
|
||||||
|
if result.Success {
|
||||||
|
err := g.closeIssue(task.Number)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to close issue: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTaskDetails retrieves detailed information about a specific task
|
||||||
|
func (g *GiteaProvider) GetTaskDetails(projectID int, taskNumber int) (*repository.Task, error) {
|
||||||
|
endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d", g.owner, g.repo, taskNumber)
|
||||||
|
|
||||||
|
resp, err := g.makeRequest("GET", endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get issue: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("issue not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
var issue GiteaIssue
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&issue); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode issue: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return g.issueToTask(&issue), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListAvailableTasks lists all available (unassigned) tasks
|
||||||
|
func (g *GiteaProvider) ListAvailableTasks(projectID int) ([]*repository.Task, error) {
|
||||||
|
// Get all open issues without assignees
|
||||||
|
params := url.Values{}
|
||||||
|
params.Add("state", "open")
|
||||||
|
params.Add("type", "issues")
|
||||||
|
params.Add("assigned", "false") // Only unassigned issues
|
||||||
|
|
||||||
|
if g.config.TaskLabel != "" {
|
||||||
|
params.Add("labels", g.config.TaskLabel)
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := fmt.Sprintf("/repos/%s/%s/issues?%s", g.owner, g.repo, params.Encode())
|
||||||
|
|
||||||
|
resp, err := g.makeRequest("GET", endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get available issues: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var issues []GiteaIssue
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&issues); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode issues: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to tasks and filter out assigned ones
|
||||||
|
tasks := make([]*repository.Task, 0, len(issues))
|
||||||
|
for _, issue := range issues {
|
||||||
|
// Skip assigned issues
|
||||||
|
if issue.Assignee != nil || len(issue.Assignees) > 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
task := g.issueToTask(&issue)
|
||||||
|
tasks = append(tasks, task)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tasks, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper methods
|
||||||
|
|
||||||
|
// issueToTask converts a Gitea issue to a repository Task
|
||||||
|
func (g *GiteaProvider) issueToTask(issue *GiteaIssue) *repository.Task {
|
||||||
|
// Extract labels
|
||||||
|
labels := make([]string, len(issue.Labels))
|
||||||
|
for i, label := range issue.Labels {
|
||||||
|
labels[i] = label.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate priority and complexity based on labels and content
|
||||||
|
priority := g.calculatePriority(labels, issue.Title, issue.Body)
|
||||||
|
complexity := g.calculateComplexity(labels, issue.Title, issue.Body)
|
||||||
|
|
||||||
|
// Determine required role and expertise from labels
|
||||||
|
requiredRole := g.determineRequiredRole(labels)
|
||||||
|
requiredExpertise := g.determineRequiredExpertise(labels)
|
||||||
|
|
||||||
|
return &repository.Task{
|
||||||
|
Number: issue.Number,
|
||||||
|
Title: issue.Title,
|
||||||
|
Body: issue.Body,
|
||||||
|
Repository: fmt.Sprintf("%s/%s", g.owner, g.repo),
|
||||||
|
Labels: labels,
|
||||||
|
Priority: priority,
|
||||||
|
Complexity: complexity,
|
||||||
|
Status: issue.State,
|
||||||
|
CreatedAt: issue.CreatedAt,
|
||||||
|
UpdatedAt: issue.UpdatedAt,
|
||||||
|
RequiredRole: requiredRole,
|
||||||
|
RequiredExpertise: requiredExpertise,
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"gitea_id": issue.ID,
|
||||||
|
"provider": "gitea",
|
||||||
|
"repository": issue.Repository,
|
||||||
|
"assignee": issue.Assignee,
|
||||||
|
"assignees": issue.Assignees,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculatePriority determines task priority from labels and content
|
||||||
|
func (g *GiteaProvider) calculatePriority(labels []string, title, body string) int {
|
||||||
|
priority := 5 // default
|
||||||
|
|
||||||
|
for _, label := range labels {
|
||||||
|
switch strings.ToLower(label) {
|
||||||
|
case "priority:critical", "critical", "urgent":
|
||||||
|
priority = 10
|
||||||
|
case "priority:high", "high":
|
||||||
|
priority = 8
|
||||||
|
case "priority:medium", "medium":
|
||||||
|
priority = 5
|
||||||
|
case "priority:low", "low":
|
||||||
|
priority = 2
|
||||||
|
case "bug", "security", "hotfix":
|
||||||
|
priority = max(priority, 7)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Boost priority for urgent keywords in title
|
||||||
|
titleLower := strings.ToLower(title)
|
||||||
|
if strings.Contains(titleLower, "urgent") || strings.Contains(titleLower, "critical") ||
|
||||||
|
strings.Contains(titleLower, "hotfix") || strings.Contains(titleLower, "security") {
|
||||||
|
priority = max(priority, 8)
|
||||||
|
}
|
||||||
|
|
||||||
|
return priority
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateComplexity estimates task complexity from labels and content
|
||||||
|
func (g *GiteaProvider) calculateComplexity(labels []string, title, body string) int {
|
||||||
|
complexity := 3 // default
|
||||||
|
|
||||||
|
for _, label := range labels {
|
||||||
|
switch strings.ToLower(label) {
|
||||||
|
case "complexity:high", "epic", "major":
|
||||||
|
complexity = 8
|
||||||
|
case "complexity:medium":
|
||||||
|
complexity = 5
|
||||||
|
case "complexity:low", "simple", "trivial":
|
||||||
|
complexity = 2
|
||||||
|
case "refactor", "architecture":
|
||||||
|
complexity = max(complexity, 7)
|
||||||
|
case "bug", "hotfix":
|
||||||
|
complexity = max(complexity, 4)
|
||||||
|
case "enhancement", "feature":
|
||||||
|
complexity = max(complexity, 5)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Estimate complexity from body length
|
||||||
|
bodyLength := len(strings.Fields(body))
|
||||||
|
if bodyLength > 200 {
|
||||||
|
complexity = max(complexity, 6)
|
||||||
|
} else if bodyLength > 50 {
|
||||||
|
complexity = max(complexity, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
return complexity
|
||||||
|
}
|
||||||
|
|
||||||
|
// determineRequiredRole determines what agent role is needed for this task
|
||||||
|
func (g *GiteaProvider) determineRequiredRole(labels []string) string {
|
||||||
|
for _, label := range labels {
|
||||||
|
switch strings.ToLower(label) {
|
||||||
|
case "frontend", "ui", "ux", "css", "html", "javascript", "react", "vue":
|
||||||
|
return "frontend-developer"
|
||||||
|
case "backend", "api", "server", "database", "sql":
|
||||||
|
return "backend-developer"
|
||||||
|
case "devops", "infrastructure", "deployment", "docker", "kubernetes":
|
||||||
|
return "devops-engineer"
|
||||||
|
case "security", "authentication", "authorization":
|
||||||
|
return "security-engineer"
|
||||||
|
case "testing", "qa", "quality":
|
||||||
|
return "tester"
|
||||||
|
case "documentation", "docs":
|
||||||
|
return "technical-writer"
|
||||||
|
case "design", "mockup", "wireframe":
|
||||||
|
return "designer"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "developer" // default role
|
||||||
|
}
|
||||||
|
|
||||||
|
// determineRequiredExpertise determines what expertise is needed
|
||||||
|
func (g *GiteaProvider) determineRequiredExpertise(labels []string) []string {
|
||||||
|
expertise := make([]string, 0)
|
||||||
|
expertiseMap := make(map[string]bool) // prevent duplicates
|
||||||
|
|
||||||
|
for _, label := range labels {
|
||||||
|
labelLower := strings.ToLower(label)
|
||||||
|
|
||||||
|
// Programming languages
|
||||||
|
languages := []string{"go", "python", "javascript", "typescript", "java", "rust", "c++", "php"}
|
||||||
|
for _, lang := range languages {
|
||||||
|
if strings.Contains(labelLower, lang) {
|
||||||
|
if !expertiseMap[lang] {
|
||||||
|
expertise = append(expertise, lang)
|
||||||
|
expertiseMap[lang] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Technologies and frameworks
|
||||||
|
technologies := []string{"docker", "kubernetes", "react", "vue", "angular", "nodejs", "django", "flask", "spring"}
|
||||||
|
for _, tech := range technologies {
|
||||||
|
if strings.Contains(labelLower, tech) {
|
||||||
|
if !expertiseMap[tech] {
|
||||||
|
expertise = append(expertise, tech)
|
||||||
|
expertiseMap[tech] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Domain areas
|
||||||
|
domains := []string{"frontend", "backend", "database", "security", "testing", "devops", "api"}
|
||||||
|
for _, domain := range domains {
|
||||||
|
if strings.Contains(labelLower, domain) {
|
||||||
|
if !expertiseMap[domain] {
|
||||||
|
expertise = append(expertise, domain)
|
||||||
|
expertiseMap[domain] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default expertise if none detected
|
||||||
|
if len(expertise) == 0 {
|
||||||
|
expertise = []string{"development", "programming"}
|
||||||
|
}
|
||||||
|
|
||||||
|
return expertise
|
||||||
|
}
|
||||||
|
|
||||||
|
// addLabelToIssue adds a label to an issue
|
||||||
|
func (g *GiteaProvider) addLabelToIssue(issueNumber int, labelName string) error {
|
||||||
|
endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d/labels", g.owner, g.repo, issueNumber)
|
||||||
|
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"labels": []string{labelName},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := g.makeRequest("POST", endpoint, body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
return fmt.Errorf("failed to add label (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeLabelFromIssue removes a label from an issue
|
||||||
|
func (g *GiteaProvider) removeLabelFromIssue(issueNumber int, labelName string) error {
|
||||||
|
endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d/labels/%s", g.owner, g.repo, issueNumber, url.QueryEscape(labelName))
|
||||||
|
|
||||||
|
resp, err := g.makeRequest("DELETE", endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
return fmt.Errorf("failed to remove label (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addCommentToIssue adds a comment to an issue
|
||||||
|
func (g *GiteaProvider) addCommentToIssue(issueNumber int, comment string) error {
|
||||||
|
endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d/comments", g.owner, g.repo, issueNumber)
|
||||||
|
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"body": comment,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := g.makeRequest("POST", endpoint, body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
return fmt.Errorf("failed to add comment (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeIssue closes an issue
|
||||||
|
func (g *GiteaProvider) closeIssue(issueNumber int) error {
|
||||||
|
endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d", g.owner, g.repo, issueNumber)
|
||||||
|
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"state": "closed",
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := g.makeRequest("PATCH", endpoint, body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
return fmt.Errorf("failed to close issue (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// max returns the maximum of two integers
|
||||||
|
func max(a, b int) int {
|
||||||
|
if a > b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
732
pkg/providers/github.go
Normal file
732
pkg/providers/github.go
Normal file
@@ -0,0 +1,732 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"chorus/pkg/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GitHubProvider implements TaskProvider for GitHub API
|
||||||
|
type GitHubProvider struct {
|
||||||
|
config *repository.Config
|
||||||
|
httpClient *http.Client
|
||||||
|
token string
|
||||||
|
owner string
|
||||||
|
repo string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGitHubProvider creates a new GitHub provider
|
||||||
|
func NewGitHubProvider(config *repository.Config) (*GitHubProvider, error) {
|
||||||
|
if config.AccessToken == "" {
|
||||||
|
return nil, fmt.Errorf("access token is required for GitHub provider")
|
||||||
|
}
|
||||||
|
if config.Owner == "" {
|
||||||
|
return nil, fmt.Errorf("owner is required for GitHub provider")
|
||||||
|
}
|
||||||
|
if config.Repository == "" {
|
||||||
|
return nil, fmt.Errorf("repository name is required for GitHub provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &GitHubProvider{
|
||||||
|
config: config,
|
||||||
|
token: config.AccessToken,
|
||||||
|
owner: config.Owner,
|
||||||
|
repo: config.Repository,
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GitHubIssue represents a GitHub issue
|
||||||
|
type GitHubIssue struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Number int `json:"number"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Body string `json:"body"`
|
||||||
|
State string `json:"state"`
|
||||||
|
Labels []GitHubLabel `json:"labels"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
Repository *GitHubRepository `json:"repository,omitempty"`
|
||||||
|
Assignee *GitHubUser `json:"assignee"`
|
||||||
|
Assignees []GitHubUser `json:"assignees"`
|
||||||
|
User *GitHubUser `json:"user"`
|
||||||
|
PullRequest *GitHubPullRequestRef `json:"pull_request,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GitHubLabel represents a GitHub label
|
||||||
|
type GitHubLabel struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Color string `json:"color"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GitHubRepository represents a GitHub repository
|
||||||
|
type GitHubRepository struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
FullName string `json:"full_name"`
|
||||||
|
Owner *GitHubUser `json:"owner"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GitHubUser represents a GitHub user
|
||||||
|
type GitHubUser struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Login string `json:"login"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
AvatarURL string `json:"avatar_url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GitHubPullRequestRef indicates if issue is a PR
|
||||||
|
type GitHubPullRequestRef struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GitHubComment represents a GitHub issue comment
|
||||||
|
type GitHubComment struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Body string `json:"body"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
User *GitHubUser `json:"user"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeRequest makes an HTTP request to the GitHub API
|
||||||
|
func (g *GitHubProvider) makeRequest(method, endpoint string, body interface{}) (*http.Response, error) {
|
||||||
|
var reqBody io.Reader
|
||||||
|
|
||||||
|
if body != nil {
|
||||||
|
jsonData, err := json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||||
|
}
|
||||||
|
reqBody = bytes.NewBuffer(jsonData)
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("https://api.github.com%s", endpoint)
|
||||||
|
req, err := http.NewRequest(method, url, reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Authorization", "token "+g.token)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/vnd.github.v3+json")
|
||||||
|
req.Header.Set("User-Agent", "CHORUS-Agent/1.0")
|
||||||
|
|
||||||
|
resp, err := g.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("request failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTasks retrieves tasks (issues) from the GitHub repository
|
||||||
|
func (g *GitHubProvider) GetTasks(projectID int) ([]*repository.Task, error) {
|
||||||
|
// Build query parameters
|
||||||
|
params := url.Values{}
|
||||||
|
params.Add("state", "open")
|
||||||
|
params.Add("sort", "created")
|
||||||
|
params.Add("direction", "desc")
|
||||||
|
|
||||||
|
// Add task label filter if specified
|
||||||
|
if g.config.TaskLabel != "" {
|
||||||
|
params.Add("labels", g.config.TaskLabel)
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := fmt.Sprintf("/repos/%s/%s/issues?%s", g.owner, g.repo, params.Encode())
|
||||||
|
|
||||||
|
resp, err := g.makeRequest("GET", endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get issues: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var issues []GitHubIssue
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&issues); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode issues: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter out pull requests (GitHub API includes PRs in issues endpoint)
|
||||||
|
tasks := make([]*repository.Task, 0, len(issues))
|
||||||
|
for _, issue := range issues {
|
||||||
|
// Skip pull requests
|
||||||
|
if issue.PullRequest != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
task := g.issueToTask(&issue)
|
||||||
|
tasks = append(tasks, task)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tasks, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClaimTask claims a task by assigning it to the agent and adding in-progress label
|
||||||
|
func (g *GitHubProvider) ClaimTask(taskNumber int, agentID string) (bool, error) {
|
||||||
|
// First, get the current issue to check its state
|
||||||
|
endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d", g.owner, g.repo, taskNumber)
|
||||||
|
|
||||||
|
resp, err := g.makeRequest("GET", endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to get issue: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return false, fmt.Errorf("issue not found or not accessible")
|
||||||
|
}
|
||||||
|
|
||||||
|
var issue GitHubIssue
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&issue); err != nil {
|
||||||
|
return false, fmt.Errorf("failed to decode issue: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if issue is already assigned
|
||||||
|
if issue.Assignee != nil || len(issue.Assignees) > 0 {
|
||||||
|
assigneeName := ""
|
||||||
|
if issue.Assignee != nil {
|
||||||
|
assigneeName = issue.Assignee.Login
|
||||||
|
} else if len(issue.Assignees) > 0 {
|
||||||
|
assigneeName = issue.Assignees[0].Login
|
||||||
|
}
|
||||||
|
return false, fmt.Errorf("issue is already assigned to %s", assigneeName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add in-progress label if specified
|
||||||
|
if g.config.InProgressLabel != "" {
|
||||||
|
err := g.addLabelToIssue(taskNumber, g.config.InProgressLabel)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to add in-progress label: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a comment indicating the task has been claimed
|
||||||
|
comment := fmt.Sprintf("🤖 **Task Claimed by CHORUS Agent**\n\nAgent ID: `%s`\nStatus: Processing\n\nThis task is now being handled automatically by the CHORUS autonomous agent system.", agentID)
|
||||||
|
err = g.addCommentToIssue(taskNumber, comment)
|
||||||
|
if err != nil {
|
||||||
|
// Don't fail the claim if comment fails
|
||||||
|
fmt.Printf("Warning: failed to add claim comment: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTaskStatus updates the status of a task
|
||||||
|
func (g *GitHubProvider) UpdateTaskStatus(task *repository.Task, status string, comment string) error {
|
||||||
|
// Add a comment with the status update
|
||||||
|
statusComment := fmt.Sprintf("📊 **Status Update: %s**\n\n%s\n\n---\n*Updated by CHORUS Agent*", status, comment)
|
||||||
|
|
||||||
|
err := g.addCommentToIssue(task.Number, statusComment)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add status comment: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompleteTask completes a task by updating status and adding completion comment
|
||||||
|
func (g *GitHubProvider) CompleteTask(task *repository.Task, result *repository.TaskResult) error {
|
||||||
|
// Create completion comment with results
|
||||||
|
var commentBuffer strings.Builder
|
||||||
|
commentBuffer.WriteString("✅ **Task Completed Successfully**\n\n")
|
||||||
|
commentBuffer.WriteString(fmt.Sprintf("**Result:** %s\n\n", result.Message))
|
||||||
|
|
||||||
|
// Add metadata if available
|
||||||
|
if result.Metadata != nil {
|
||||||
|
commentBuffer.WriteString("## Execution Details\n\n")
|
||||||
|
for key, value := range result.Metadata {
|
||||||
|
// Format the metadata nicely
|
||||||
|
switch key {
|
||||||
|
case "duration":
|
||||||
|
commentBuffer.WriteString(fmt.Sprintf("- ⏱️ **Duration:** %v\n", value))
|
||||||
|
case "execution_type":
|
||||||
|
commentBuffer.WriteString(fmt.Sprintf("- 🔧 **Execution Type:** %v\n", value))
|
||||||
|
case "commands_executed":
|
||||||
|
commentBuffer.WriteString(fmt.Sprintf("- 🖥️ **Commands Executed:** %v\n", value))
|
||||||
|
case "files_generated":
|
||||||
|
commentBuffer.WriteString(fmt.Sprintf("- 📄 **Files Generated:** %v\n", value))
|
||||||
|
case "ai_provider":
|
||||||
|
commentBuffer.WriteString(fmt.Sprintf("- 🤖 **AI Provider:** %v\n", value))
|
||||||
|
case "ai_model":
|
||||||
|
commentBuffer.WriteString(fmt.Sprintf("- 🧠 **AI Model:** %v\n", value))
|
||||||
|
default:
|
||||||
|
commentBuffer.WriteString(fmt.Sprintf("- **%s:** %v\n", key, value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
commentBuffer.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
commentBuffer.WriteString("---\n🤖 *Completed by CHORUS Autonomous Agent System*")
|
||||||
|
|
||||||
|
// Add completion comment
|
||||||
|
err := g.addCommentToIssue(task.Number, commentBuffer.String())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add completion comment: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove in-progress label and add completed label
|
||||||
|
if g.config.InProgressLabel != "" {
|
||||||
|
err := g.removeLabelFromIssue(task.Number, g.config.InProgressLabel)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Warning: failed to remove in-progress label: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if g.config.CompletedLabel != "" {
|
||||||
|
err := g.addLabelToIssue(task.Number, g.config.CompletedLabel)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Warning: failed to add completed label: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the issue if the task was successful
|
||||||
|
if result.Success {
|
||||||
|
err := g.closeIssue(task.Number)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to close issue: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTaskDetails retrieves detailed information about a specific task
|
||||||
|
func (g *GitHubProvider) GetTaskDetails(projectID int, taskNumber int) (*repository.Task, error) {
|
||||||
|
endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d", g.owner, g.repo, taskNumber)
|
||||||
|
|
||||||
|
resp, err := g.makeRequest("GET", endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get issue: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("issue not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
var issue GitHubIssue
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&issue); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode issue: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip pull requests
|
||||||
|
if issue.PullRequest != nil {
|
||||||
|
return nil, fmt.Errorf("pull requests are not supported as tasks")
|
||||||
|
}
|
||||||
|
|
||||||
|
return g.issueToTask(&issue), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListAvailableTasks lists all available (unassigned) tasks
|
||||||
|
func (g *GitHubProvider) ListAvailableTasks(projectID int) ([]*repository.Task, error) {
|
||||||
|
// GitHub doesn't have a direct "unassigned" filter, so we get open issues and filter
|
||||||
|
params := url.Values{}
|
||||||
|
params.Add("state", "open")
|
||||||
|
params.Add("sort", "created")
|
||||||
|
params.Add("direction", "desc")
|
||||||
|
|
||||||
|
if g.config.TaskLabel != "" {
|
||||||
|
params.Add("labels", g.config.TaskLabel)
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := fmt.Sprintf("/repos/%s/%s/issues?%s", g.owner, g.repo, params.Encode())
|
||||||
|
|
||||||
|
resp, err := g.makeRequest("GET", endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get available issues: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var issues []GitHubIssue
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&issues); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode issues: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter out assigned issues and PRs
|
||||||
|
tasks := make([]*repository.Task, 0, len(issues))
|
||||||
|
for _, issue := range issues {
|
||||||
|
// Skip pull requests
|
||||||
|
if issue.PullRequest != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip assigned issues
|
||||||
|
if issue.Assignee != nil || len(issue.Assignees) > 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
task := g.issueToTask(&issue)
|
||||||
|
tasks = append(tasks, task)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tasks, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper methods
|
||||||
|
|
||||||
|
// issueToTask converts a GitHub issue to a repository Task
|
||||||
|
func (g *GitHubProvider) issueToTask(issue *GitHubIssue) *repository.Task {
|
||||||
|
// Extract labels
|
||||||
|
labels := make([]string, len(issue.Labels))
|
||||||
|
for i, label := range issue.Labels {
|
||||||
|
labels[i] = label.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate priority and complexity based on labels and content
|
||||||
|
priority := g.calculatePriority(labels, issue.Title, issue.Body)
|
||||||
|
complexity := g.calculateComplexity(labels, issue.Title, issue.Body)
|
||||||
|
|
||||||
|
// Determine required role and expertise from labels
|
||||||
|
requiredRole := g.determineRequiredRole(labels)
|
||||||
|
requiredExpertise := g.determineRequiredExpertise(labels)
|
||||||
|
|
||||||
|
return &repository.Task{
|
||||||
|
Number: issue.Number,
|
||||||
|
Title: issue.Title,
|
||||||
|
Body: issue.Body,
|
||||||
|
Repository: fmt.Sprintf("%s/%s", g.owner, g.repo),
|
||||||
|
Labels: labels,
|
||||||
|
Priority: priority,
|
||||||
|
Complexity: complexity,
|
||||||
|
Status: issue.State,
|
||||||
|
CreatedAt: issue.CreatedAt,
|
||||||
|
UpdatedAt: issue.UpdatedAt,
|
||||||
|
RequiredRole: requiredRole,
|
||||||
|
RequiredExpertise: requiredExpertise,
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"github_id": issue.ID,
|
||||||
|
"provider": "github",
|
||||||
|
"repository": issue.Repository,
|
||||||
|
"assignee": issue.Assignee,
|
||||||
|
"assignees": issue.Assignees,
|
||||||
|
"user": issue.User,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculatePriority determines task priority from labels and content
|
||||||
|
func (g *GitHubProvider) calculatePriority(labels []string, title, body string) int {
|
||||||
|
priority := 5 // default
|
||||||
|
|
||||||
|
for _, label := range labels {
|
||||||
|
labelLower := strings.ToLower(label)
|
||||||
|
switch {
|
||||||
|
case strings.Contains(labelLower, "priority") && strings.Contains(labelLower, "critical"):
|
||||||
|
priority = 10
|
||||||
|
case strings.Contains(labelLower, "priority") && strings.Contains(labelLower, "high"):
|
||||||
|
priority = 8
|
||||||
|
case strings.Contains(labelLower, "priority") && strings.Contains(labelLower, "medium"):
|
||||||
|
priority = 5
|
||||||
|
case strings.Contains(labelLower, "priority") && strings.Contains(labelLower, "low"):
|
||||||
|
priority = 2
|
||||||
|
case labelLower == "critical" || labelLower == "urgent":
|
||||||
|
priority = 10
|
||||||
|
case labelLower == "high":
|
||||||
|
priority = 8
|
||||||
|
case labelLower == "bug" || labelLower == "security" || labelLower == "hotfix":
|
||||||
|
priority = max(priority, 7)
|
||||||
|
case labelLower == "enhancement" || labelLower == "feature":
|
||||||
|
priority = max(priority, 5)
|
||||||
|
case labelLower == "good first issue":
|
||||||
|
priority = max(priority, 3)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Boost priority for urgent keywords in title
|
||||||
|
titleLower := strings.ToLower(title)
|
||||||
|
urgentKeywords := []string{"urgent", "critical", "hotfix", "security", "broken", "crash"}
|
||||||
|
for _, keyword := range urgentKeywords {
|
||||||
|
if strings.Contains(titleLower, keyword) {
|
||||||
|
priority = max(priority, 8)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return priority
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateComplexity estimates task complexity from labels and content
|
||||||
|
func (g *GitHubProvider) calculateComplexity(labels []string, title, body string) int {
|
||||||
|
complexity := 3 // default
|
||||||
|
|
||||||
|
for _, label := range labels {
|
||||||
|
labelLower := strings.ToLower(label)
|
||||||
|
switch {
|
||||||
|
case strings.Contains(labelLower, "complexity") && strings.Contains(labelLower, "high"):
|
||||||
|
complexity = 8
|
||||||
|
case strings.Contains(labelLower, "complexity") && strings.Contains(labelLower, "medium"):
|
||||||
|
complexity = 5
|
||||||
|
case strings.Contains(labelLower, "complexity") && strings.Contains(labelLower, "low"):
|
||||||
|
complexity = 2
|
||||||
|
case labelLower == "epic" || labelLower == "major":
|
||||||
|
complexity = 8
|
||||||
|
case labelLower == "refactor" || labelLower == "architecture":
|
||||||
|
complexity = max(complexity, 7)
|
||||||
|
case labelLower == "bug" || labelLower == "hotfix":
|
||||||
|
complexity = max(complexity, 4)
|
||||||
|
case labelLower == "enhancement" || labelLower == "feature":
|
||||||
|
complexity = max(complexity, 5)
|
||||||
|
case labelLower == "good first issue" || labelLower == "beginner":
|
||||||
|
complexity = 2
|
||||||
|
case labelLower == "documentation" || labelLower == "docs":
|
||||||
|
complexity = max(complexity, 3)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Estimate complexity from body length and content
|
||||||
|
bodyLength := len(strings.Fields(body))
|
||||||
|
if bodyLength > 500 {
|
||||||
|
complexity = max(complexity, 7)
|
||||||
|
} else if bodyLength > 200 {
|
||||||
|
complexity = max(complexity, 5)
|
||||||
|
} else if bodyLength > 50 {
|
||||||
|
complexity = max(complexity, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look for complexity indicators in content
|
||||||
|
bodyLower := strings.ToLower(body)
|
||||||
|
complexityIndicators := []string{"refactor", "architecture", "breaking change", "migration", "redesign"}
|
||||||
|
for _, indicator := range complexityIndicators {
|
||||||
|
if strings.Contains(bodyLower, indicator) {
|
||||||
|
complexity = max(complexity, 7)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return complexity
|
||||||
|
}
|
||||||
|
|
||||||
|
// determineRequiredRole determines what agent role is needed for this task
|
||||||
|
func (g *GitHubProvider) determineRequiredRole(labels []string) string {
|
||||||
|
roleKeywords := map[string]string{
|
||||||
|
// Frontend
|
||||||
|
"frontend": "frontend-developer",
|
||||||
|
"ui": "frontend-developer",
|
||||||
|
"ux": "ui-ux-designer",
|
||||||
|
"css": "frontend-developer",
|
||||||
|
"html": "frontend-developer",
|
||||||
|
"javascript": "frontend-developer",
|
||||||
|
"react": "frontend-developer",
|
||||||
|
"vue": "frontend-developer",
|
||||||
|
"angular": "frontend-developer",
|
||||||
|
|
||||||
|
// Backend
|
||||||
|
"backend": "backend-developer",
|
||||||
|
"api": "backend-developer",
|
||||||
|
"server": "backend-developer",
|
||||||
|
"database": "backend-developer",
|
||||||
|
"sql": "backend-developer",
|
||||||
|
|
||||||
|
// DevOps
|
||||||
|
"devops": "devops-engineer",
|
||||||
|
"infrastructure": "devops-engineer",
|
||||||
|
"deployment": "devops-engineer",
|
||||||
|
"docker": "devops-engineer",
|
||||||
|
"kubernetes": "devops-engineer",
|
||||||
|
"ci/cd": "devops-engineer",
|
||||||
|
|
||||||
|
// Security
|
||||||
|
"security": "security-engineer",
|
||||||
|
"authentication": "security-engineer",
|
||||||
|
"authorization": "security-engineer",
|
||||||
|
"vulnerability": "security-engineer",
|
||||||
|
|
||||||
|
// Testing
|
||||||
|
"testing": "tester",
|
||||||
|
"qa": "tester",
|
||||||
|
"test": "tester",
|
||||||
|
|
||||||
|
// Documentation
|
||||||
|
"documentation": "technical-writer",
|
||||||
|
"docs": "technical-writer",
|
||||||
|
|
||||||
|
// Design
|
||||||
|
"design": "ui-ux-designer",
|
||||||
|
"mockup": "ui-ux-designer",
|
||||||
|
"wireframe": "ui-ux-designer",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, label := range labels {
|
||||||
|
labelLower := strings.ToLower(label)
|
||||||
|
for keyword, role := range roleKeywords {
|
||||||
|
if strings.Contains(labelLower, keyword) {
|
||||||
|
return role
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "developer" // default role
|
||||||
|
}
|
||||||
|
|
||||||
|
// determineRequiredExpertise determines what expertise is needed
|
||||||
|
func (g *GitHubProvider) determineRequiredExpertise(labels []string) []string {
|
||||||
|
expertise := make([]string, 0)
|
||||||
|
expertiseMap := make(map[string]bool) // prevent duplicates
|
||||||
|
|
||||||
|
expertiseKeywords := map[string][]string{
|
||||||
|
// Programming languages
|
||||||
|
"go": {"go", "golang"},
|
||||||
|
"python": {"python"},
|
||||||
|
"javascript": {"javascript", "js"},
|
||||||
|
"typescript": {"typescript", "ts"},
|
||||||
|
"java": {"java"},
|
||||||
|
"rust": {"rust"},
|
||||||
|
"c++": {"c++", "cpp"},
|
||||||
|
"c#": {"c#", "csharp"},
|
||||||
|
"php": {"php"},
|
||||||
|
"ruby": {"ruby"},
|
||||||
|
|
||||||
|
// Frontend technologies
|
||||||
|
"react": {"react"},
|
||||||
|
"vue": {"vue", "vuejs"},
|
||||||
|
"angular": {"angular"},
|
||||||
|
"svelte": {"svelte"},
|
||||||
|
|
||||||
|
// Backend frameworks
|
||||||
|
"nodejs": {"nodejs", "node.js", "node"},
|
||||||
|
"django": {"django"},
|
||||||
|
"flask": {"flask"},
|
||||||
|
"spring": {"spring"},
|
||||||
|
"express": {"express"},
|
||||||
|
|
||||||
|
// Databases
|
||||||
|
"postgresql": {"postgresql", "postgres"},
|
||||||
|
"mysql": {"mysql"},
|
||||||
|
"mongodb": {"mongodb", "mongo"},
|
||||||
|
"redis": {"redis"},
|
||||||
|
|
||||||
|
// DevOps tools
|
||||||
|
"docker": {"docker"},
|
||||||
|
"kubernetes": {"kubernetes", "k8s"},
|
||||||
|
"aws": {"aws"},
|
||||||
|
"azure": {"azure"},
|
||||||
|
"gcp": {"gcp", "google cloud"},
|
||||||
|
|
||||||
|
// Other technologies
|
||||||
|
"graphql": {"graphql"},
|
||||||
|
"rest": {"rest", "restful"},
|
||||||
|
"grpc": {"grpc"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, label := range labels {
|
||||||
|
labelLower := strings.ToLower(label)
|
||||||
|
for expertiseArea, keywords := range expertiseKeywords {
|
||||||
|
for _, keyword := range keywords {
|
||||||
|
if strings.Contains(labelLower, keyword) && !expertiseMap[expertiseArea] {
|
||||||
|
expertise = append(expertise, expertiseArea)
|
||||||
|
expertiseMap[expertiseArea] = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default expertise if none detected
|
||||||
|
if len(expertise) == 0 {
|
||||||
|
expertise = []string{"development", "programming"}
|
||||||
|
}
|
||||||
|
|
||||||
|
return expertise
|
||||||
|
}
|
||||||
|
|
||||||
|
// addLabelToIssue adds a label to an issue
|
||||||
|
func (g *GitHubProvider) addLabelToIssue(issueNumber int, labelName string) error {
|
||||||
|
endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d/labels", g.owner, g.repo, issueNumber)
|
||||||
|
|
||||||
|
body := []string{labelName}
|
||||||
|
|
||||||
|
resp, err := g.makeRequest("POST", endpoint, body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
return fmt.Errorf("failed to add label (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeLabelFromIssue removes a label from an issue
|
||||||
|
func (g *GitHubProvider) removeLabelFromIssue(issueNumber int, labelName string) error {
|
||||||
|
endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d/labels/%s", g.owner, g.repo, issueNumber, url.QueryEscape(labelName))
|
||||||
|
|
||||||
|
resp, err := g.makeRequest("DELETE", endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
return fmt.Errorf("failed to remove label (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addCommentToIssue adds a comment to an issue
|
||||||
|
func (g *GitHubProvider) addCommentToIssue(issueNumber int, comment string) error {
|
||||||
|
endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d/comments", g.owner, g.repo, issueNumber)
|
||||||
|
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"body": comment,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := g.makeRequest("POST", endpoint, body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
return fmt.Errorf("failed to add comment (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeIssue closes an issue
|
||||||
|
func (g *GitHubProvider) closeIssue(issueNumber int) error {
|
||||||
|
endpoint := fmt.Sprintf("/repos/%s/%s/issues/%d", g.owner, g.repo, issueNumber)
|
||||||
|
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"state": "closed",
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := g.makeRequest("PATCH", endpoint, body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
return fmt.Errorf("failed to close issue (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
781
pkg/providers/gitlab.go
Normal file
781
pkg/providers/gitlab.go
Normal file
@@ -0,0 +1,781 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"chorus/pkg/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GitLabProvider implements TaskProvider for GitLab API
|
||||||
|
type GitLabProvider struct {
|
||||||
|
config *repository.Config
|
||||||
|
httpClient *http.Client
|
||||||
|
baseURL string
|
||||||
|
token string
|
||||||
|
projectID string // GitLab uses project ID or namespace/project-name
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGitLabProvider creates a new GitLab provider
|
||||||
|
func NewGitLabProvider(config *repository.Config) (*GitLabProvider, error) {
|
||||||
|
if config.AccessToken == "" {
|
||||||
|
return nil, fmt.Errorf("access token is required for GitLab provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to gitlab.com if no base URL provided
|
||||||
|
baseURL := config.BaseURL
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = "https://gitlab.com"
|
||||||
|
}
|
||||||
|
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||||
|
|
||||||
|
// Build project ID from owner/repo if provided, otherwise use settings
|
||||||
|
var projectID string
|
||||||
|
if config.Owner != "" && config.Repository != "" {
|
||||||
|
projectID = url.QueryEscape(fmt.Sprintf("%s/%s", config.Owner, config.Repository))
|
||||||
|
} else if projectIDSetting, ok := config.Settings["project_id"].(string); ok {
|
||||||
|
projectID = projectIDSetting
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("either owner/repository or project_id in settings is required for GitLab provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &GitLabProvider{
|
||||||
|
config: config,
|
||||||
|
baseURL: baseURL,
|
||||||
|
token: config.AccessToken,
|
||||||
|
projectID: projectID,
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GitLabIssue represents a GitLab issue
|
||||||
|
type GitLabIssue struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
IID int `json:"iid"` // Project-specific ID (what users see)
|
||||||
|
Title string `json:"title"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
State string `json:"state"`
|
||||||
|
Labels []string `json:"labels"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
ProjectID int `json:"project_id"`
|
||||||
|
Author *GitLabUser `json:"author"`
|
||||||
|
Assignee *GitLabUser `json:"assignee"`
|
||||||
|
Assignees []GitLabUser `json:"assignees"`
|
||||||
|
WebURL string `json:"web_url"`
|
||||||
|
TimeStats *GitLabTimeStats `json:"time_stats,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GitLabUser represents a GitLab user
|
||||||
|
type GitLabUser struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
AvatarURL string `json:"avatar_url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GitLabTimeStats represents time tracking statistics
|
||||||
|
type GitLabTimeStats struct {
|
||||||
|
TimeEstimate int `json:"time_estimate"`
|
||||||
|
TotalTimeSpent int `json:"total_time_spent"`
|
||||||
|
HumanTimeEstimate string `json:"human_time_estimate"`
|
||||||
|
HumanTotalTimeSpent string `json:"human_total_time_spent"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GitLabNote represents a GitLab issue note (comment)
|
||||||
|
type GitLabNote struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Body string `json:"body"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
Author *GitLabUser `json:"author"`
|
||||||
|
System bool `json:"system"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GitLabProject represents a GitLab project
|
||||||
|
type GitLabProject struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
NameWithNamespace string `json:"name_with_namespace"`
|
||||||
|
PathWithNamespace string `json:"path_with_namespace"`
|
||||||
|
WebURL string `json:"web_url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeRequest makes an HTTP request to the GitLab API
|
||||||
|
func (g *GitLabProvider) makeRequest(method, endpoint string, body interface{}) (*http.Response, error) {
|
||||||
|
var reqBody io.Reader
|
||||||
|
|
||||||
|
if body != nil {
|
||||||
|
jsonData, err := json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||||
|
}
|
||||||
|
reqBody = bytes.NewBuffer(jsonData)
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s/api/v4%s", g.baseURL, endpoint)
|
||||||
|
req, err := http.NewRequest(method, url, reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Private-Token", g.token)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := g.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("request failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTasks retrieves tasks (issues) from the GitLab project
|
||||||
|
func (g *GitLabProvider) GetTasks(projectID int) ([]*repository.Task, error) {
|
||||||
|
// Build query parameters
|
||||||
|
params := url.Values{}
|
||||||
|
params.Add("state", "opened")
|
||||||
|
params.Add("sort", "created_desc")
|
||||||
|
params.Add("per_page", "100") // GitLab default is 20
|
||||||
|
|
||||||
|
// Add task label filter if specified
|
||||||
|
if g.config.TaskLabel != "" {
|
||||||
|
params.Add("labels", g.config.TaskLabel)
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := fmt.Sprintf("/projects/%s/issues?%s", g.projectID, params.Encode())
|
||||||
|
|
||||||
|
resp, err := g.makeRequest("GET", endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get issues: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var issues []GitLabIssue
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&issues); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode issues: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert GitLab issues to repository tasks
|
||||||
|
tasks := make([]*repository.Task, 0, len(issues))
|
||||||
|
for _, issue := range issues {
|
||||||
|
task := g.issueToTask(&issue)
|
||||||
|
tasks = append(tasks, task)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tasks, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClaimTask claims a task by assigning it to the agent and adding in-progress label
|
||||||
|
func (g *GitLabProvider) ClaimTask(taskNumber int, agentID string) (bool, error) {
|
||||||
|
// First, get the current issue to check its state
|
||||||
|
endpoint := fmt.Sprintf("/projects/%s/issues/%d", g.projectID, taskNumber)
|
||||||
|
|
||||||
|
resp, err := g.makeRequest("GET", endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to get issue: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return false, fmt.Errorf("issue not found or not accessible")
|
||||||
|
}
|
||||||
|
|
||||||
|
var issue GitLabIssue
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&issue); err != nil {
|
||||||
|
return false, fmt.Errorf("failed to decode issue: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if issue is already assigned
|
||||||
|
if issue.Assignee != nil || len(issue.Assignees) > 0 {
|
||||||
|
assigneeName := ""
|
||||||
|
if issue.Assignee != nil {
|
||||||
|
assigneeName = issue.Assignee.Username
|
||||||
|
} else if len(issue.Assignees) > 0 {
|
||||||
|
assigneeName = issue.Assignees[0].Username
|
||||||
|
}
|
||||||
|
return false, fmt.Errorf("issue is already assigned to %s", assigneeName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add in-progress label if specified
|
||||||
|
if g.config.InProgressLabel != "" {
|
||||||
|
err := g.addLabelToIssue(taskNumber, g.config.InProgressLabel)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to add in-progress label: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a note indicating the task has been claimed
|
||||||
|
comment := fmt.Sprintf("🤖 **Task Claimed by CHORUS Agent**\n\nAgent ID: `%s` \nStatus: Processing \n\nThis task is now being handled automatically by the CHORUS autonomous agent system.", agentID)
|
||||||
|
err = g.addNoteToIssue(taskNumber, comment)
|
||||||
|
if err != nil {
|
||||||
|
// Don't fail the claim if note fails
|
||||||
|
fmt.Printf("Warning: failed to add claim note: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTaskStatus updates the status of a task
|
||||||
|
func (g *GitLabProvider) UpdateTaskStatus(task *repository.Task, status string, comment string) error {
|
||||||
|
// Add a note with the status update
|
||||||
|
statusComment := fmt.Sprintf("📊 **Status Update: %s**\n\n%s\n\n---\n*Updated by CHORUS Agent*", status, comment)
|
||||||
|
|
||||||
|
err := g.addNoteToIssue(task.Number, statusComment)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add status note: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompleteTask completes a task by updating status and adding completion comment
|
||||||
|
func (g *GitLabProvider) CompleteTask(task *repository.Task, result *repository.TaskResult) error {
|
||||||
|
// Create completion comment with results
|
||||||
|
var commentBuffer strings.Builder
|
||||||
|
commentBuffer.WriteString("✅ **Task Completed Successfully**\n\n")
|
||||||
|
commentBuffer.WriteString(fmt.Sprintf("**Result:** %s\n\n", result.Message))
|
||||||
|
|
||||||
|
// Add metadata if available
|
||||||
|
if result.Metadata != nil {
|
||||||
|
commentBuffer.WriteString("## Execution Details\n\n")
|
||||||
|
for key, value := range result.Metadata {
|
||||||
|
// Format the metadata nicely
|
||||||
|
switch key {
|
||||||
|
case "duration":
|
||||||
|
commentBuffer.WriteString(fmt.Sprintf("- ⏱️ **Duration:** %v\n", value))
|
||||||
|
case "execution_type":
|
||||||
|
commentBuffer.WriteString(fmt.Sprintf("- 🔧 **Execution Type:** %v\n", value))
|
||||||
|
case "commands_executed":
|
||||||
|
commentBuffer.WriteString(fmt.Sprintf("- 🖥️ **Commands Executed:** %v\n", value))
|
||||||
|
case "files_generated":
|
||||||
|
commentBuffer.WriteString(fmt.Sprintf("- 📄 **Files Generated:** %v\n", value))
|
||||||
|
case "ai_provider":
|
||||||
|
commentBuffer.WriteString(fmt.Sprintf("- 🤖 **AI Provider:** %v\n", value))
|
||||||
|
case "ai_model":
|
||||||
|
commentBuffer.WriteString(fmt.Sprintf("- 🧠 **AI Model:** %v\n", value))
|
||||||
|
default:
|
||||||
|
commentBuffer.WriteString(fmt.Sprintf("- **%s:** %v\n", key, value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
commentBuffer.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
commentBuffer.WriteString("---\n🤖 *Completed by CHORUS Autonomous Agent System*")
|
||||||
|
|
||||||
|
// Add completion note
|
||||||
|
err := g.addNoteToIssue(task.Number, commentBuffer.String())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add completion note: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove in-progress label and add completed label
|
||||||
|
if g.config.InProgressLabel != "" {
|
||||||
|
err := g.removeLabelFromIssue(task.Number, g.config.InProgressLabel)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Warning: failed to remove in-progress label: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if g.config.CompletedLabel != "" {
|
||||||
|
err := g.addLabelToIssue(task.Number, g.config.CompletedLabel)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Warning: failed to add completed label: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the issue if the task was successful
|
||||||
|
if result.Success {
|
||||||
|
err := g.closeIssue(task.Number)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to close issue: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTaskDetails retrieves detailed information about a specific task
|
||||||
|
func (g *GitLabProvider) GetTaskDetails(projectID int, taskNumber int) (*repository.Task, error) {
|
||||||
|
endpoint := fmt.Sprintf("/projects/%s/issues/%d", g.projectID, taskNumber)
|
||||||
|
|
||||||
|
resp, err := g.makeRequest("GET", endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get issue: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("issue not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
var issue GitLabIssue
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&issue); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode issue: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return g.issueToTask(&issue), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListAvailableTasks lists all available (unassigned) tasks
|
||||||
|
func (g *GitLabProvider) ListAvailableTasks(projectID int) ([]*repository.Task, error) {
|
||||||
|
// Get open issues without assignees
|
||||||
|
params := url.Values{}
|
||||||
|
params.Add("state", "opened")
|
||||||
|
params.Add("assignee_id", "None") // GitLab filter for unassigned issues
|
||||||
|
params.Add("sort", "created_desc")
|
||||||
|
params.Add("per_page", "100")
|
||||||
|
|
||||||
|
if g.config.TaskLabel != "" {
|
||||||
|
params.Add("labels", g.config.TaskLabel)
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := fmt.Sprintf("/projects/%s/issues?%s", g.projectID, params.Encode())
|
||||||
|
|
||||||
|
resp, err := g.makeRequest("GET", endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get available issues: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var issues []GitLabIssue
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&issues); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode issues: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to tasks
|
||||||
|
tasks := make([]*repository.Task, 0, len(issues))
|
||||||
|
for _, issue := range issues {
|
||||||
|
// Double-check that issue is truly unassigned
|
||||||
|
if issue.Assignee != nil || len(issue.Assignees) > 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
task := g.issueToTask(&issue)
|
||||||
|
tasks = append(tasks, task)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tasks, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper methods
|
||||||
|
|
||||||
|
// issueToTask converts a GitLab issue to a repository Task
|
||||||
|
func (g *GitLabProvider) issueToTask(issue *GitLabIssue) *repository.Task {
|
||||||
|
// Calculate priority and complexity based on labels and content
|
||||||
|
priority := g.calculatePriority(issue.Labels, issue.Title, issue.Description)
|
||||||
|
complexity := g.calculateComplexity(issue.Labels, issue.Title, issue.Description)
|
||||||
|
|
||||||
|
// Determine required role and expertise from labels
|
||||||
|
requiredRole := g.determineRequiredRole(issue.Labels)
|
||||||
|
requiredExpertise := g.determineRequiredExpertise(issue.Labels)
|
||||||
|
|
||||||
|
// Extract project name from projectID
|
||||||
|
repositoryName := strings.Replace(g.projectID, "%2F", "/", -1) // URL decode
|
||||||
|
|
||||||
|
return &repository.Task{
|
||||||
|
Number: issue.IID, // Use IID (project-specific ID) not global ID
|
||||||
|
Title: issue.Title,
|
||||||
|
Body: issue.Description,
|
||||||
|
Repository: repositoryName,
|
||||||
|
Labels: issue.Labels,
|
||||||
|
Priority: priority,
|
||||||
|
Complexity: complexity,
|
||||||
|
Status: issue.State,
|
||||||
|
CreatedAt: issue.CreatedAt,
|
||||||
|
UpdatedAt: issue.UpdatedAt,
|
||||||
|
RequiredRole: requiredRole,
|
||||||
|
RequiredExpertise: requiredExpertise,
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"gitlab_id": issue.ID,
|
||||||
|
"gitlab_iid": issue.IID,
|
||||||
|
"provider": "gitlab",
|
||||||
|
"project_id": issue.ProjectID,
|
||||||
|
"web_url": issue.WebURL,
|
||||||
|
"assignee": issue.Assignee,
|
||||||
|
"assignees": issue.Assignees,
|
||||||
|
"author": issue.Author,
|
||||||
|
"time_stats": issue.TimeStats,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculatePriority determines task priority from labels and content
|
||||||
|
func (g *GitLabProvider) calculatePriority(labels []string, title, body string) int {
|
||||||
|
priority := 5 // default
|
||||||
|
|
||||||
|
for _, label := range labels {
|
||||||
|
labelLower := strings.ToLower(label)
|
||||||
|
switch {
|
||||||
|
case strings.Contains(labelLower, "priority") && strings.Contains(labelLower, "critical"):
|
||||||
|
priority = 10
|
||||||
|
case strings.Contains(labelLower, "priority") && strings.Contains(labelLower, "high"):
|
||||||
|
priority = 8
|
||||||
|
case strings.Contains(labelLower, "priority") && strings.Contains(labelLower, "medium"):
|
||||||
|
priority = 5
|
||||||
|
case strings.Contains(labelLower, "priority") && strings.Contains(labelLower, "low"):
|
||||||
|
priority = 2
|
||||||
|
case labelLower == "critical" || labelLower == "urgent":
|
||||||
|
priority = 10
|
||||||
|
case labelLower == "high":
|
||||||
|
priority = 8
|
||||||
|
case labelLower == "bug" || labelLower == "security" || labelLower == "hotfix":
|
||||||
|
priority = max(priority, 7)
|
||||||
|
case labelLower == "enhancement" || labelLower == "feature":
|
||||||
|
priority = max(priority, 5)
|
||||||
|
case strings.Contains(labelLower, "milestone"):
|
||||||
|
priority = max(priority, 6)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Boost priority for urgent keywords in title
|
||||||
|
titleLower := strings.ToLower(title)
|
||||||
|
urgentKeywords := []string{"urgent", "critical", "hotfix", "security", "broken", "crash", "blocker"}
|
||||||
|
for _, keyword := range urgentKeywords {
|
||||||
|
if strings.Contains(titleLower, keyword) {
|
||||||
|
priority = max(priority, 8)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return priority
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateComplexity estimates task complexity from labels and content
|
||||||
|
func (g *GitLabProvider) calculateComplexity(labels []string, title, body string) int {
|
||||||
|
complexity := 3 // default
|
||||||
|
|
||||||
|
for _, label := range labels {
|
||||||
|
labelLower := strings.ToLower(label)
|
||||||
|
switch {
|
||||||
|
case strings.Contains(labelLower, "complexity") && strings.Contains(labelLower, "high"):
|
||||||
|
complexity = 8
|
||||||
|
case strings.Contains(labelLower, "complexity") && strings.Contains(labelLower, "medium"):
|
||||||
|
complexity = 5
|
||||||
|
case strings.Contains(labelLower, "complexity") && strings.Contains(labelLower, "low"):
|
||||||
|
complexity = 2
|
||||||
|
case labelLower == "epic" || labelLower == "major":
|
||||||
|
complexity = 8
|
||||||
|
case labelLower == "refactor" || labelLower == "architecture":
|
||||||
|
complexity = max(complexity, 7)
|
||||||
|
case labelLower == "bug" || labelLower == "hotfix":
|
||||||
|
complexity = max(complexity, 4)
|
||||||
|
case labelLower == "enhancement" || labelLower == "feature":
|
||||||
|
complexity = max(complexity, 5)
|
||||||
|
case strings.Contains(labelLower, "beginner") || strings.Contains(labelLower, "newcomer"):
|
||||||
|
complexity = 2
|
||||||
|
case labelLower == "documentation" || labelLower == "docs":
|
||||||
|
complexity = max(complexity, 3)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Estimate complexity from body length and content
|
||||||
|
bodyLength := len(strings.Fields(body))
|
||||||
|
if bodyLength > 500 {
|
||||||
|
complexity = max(complexity, 7)
|
||||||
|
} else if bodyLength > 200 {
|
||||||
|
complexity = max(complexity, 5)
|
||||||
|
} else if bodyLength > 50 {
|
||||||
|
complexity = max(complexity, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look for complexity indicators in content
|
||||||
|
bodyLower := strings.ToLower(body)
|
||||||
|
complexityIndicators := []string{
|
||||||
|
"refactor", "architecture", "breaking change", "migration",
|
||||||
|
"redesign", "database schema", "api changes", "infrastructure",
|
||||||
|
}
|
||||||
|
for _, indicator := range complexityIndicators {
|
||||||
|
if strings.Contains(bodyLower, indicator) {
|
||||||
|
complexity = max(complexity, 7)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return complexity
|
||||||
|
}
|
||||||
|
|
||||||
|
// determineRequiredRole determines what agent role is needed for this task
|
||||||
|
func (g *GitLabProvider) determineRequiredRole(labels []string) string {
|
||||||
|
roleKeywords := map[string]string{
|
||||||
|
// Frontend
|
||||||
|
"frontend": "frontend-developer",
|
||||||
|
"ui": "frontend-developer",
|
||||||
|
"ux": "ui-ux-designer",
|
||||||
|
"css": "frontend-developer",
|
||||||
|
"html": "frontend-developer",
|
||||||
|
"javascript": "frontend-developer",
|
||||||
|
"react": "frontend-developer",
|
||||||
|
"vue": "frontend-developer",
|
||||||
|
"angular": "frontend-developer",
|
||||||
|
|
||||||
|
// Backend
|
||||||
|
"backend": "backend-developer",
|
||||||
|
"api": "backend-developer",
|
||||||
|
"server": "backend-developer",
|
||||||
|
"database": "backend-developer",
|
||||||
|
"sql": "backend-developer",
|
||||||
|
|
||||||
|
// DevOps
|
||||||
|
"devops": "devops-engineer",
|
||||||
|
"infrastructure": "devops-engineer",
|
||||||
|
"deployment": "devops-engineer",
|
||||||
|
"docker": "devops-engineer",
|
||||||
|
"kubernetes": "devops-engineer",
|
||||||
|
"ci/cd": "devops-engineer",
|
||||||
|
"pipeline": "devops-engineer",
|
||||||
|
|
||||||
|
// Security
|
||||||
|
"security": "security-engineer",
|
||||||
|
"authentication": "security-engineer",
|
||||||
|
"authorization": "security-engineer",
|
||||||
|
"vulnerability": "security-engineer",
|
||||||
|
|
||||||
|
// Testing
|
||||||
|
"testing": "tester",
|
||||||
|
"qa": "tester",
|
||||||
|
"test": "tester",
|
||||||
|
|
||||||
|
// Documentation
|
||||||
|
"documentation": "technical-writer",
|
||||||
|
"docs": "technical-writer",
|
||||||
|
|
||||||
|
// Design
|
||||||
|
"design": "ui-ux-designer",
|
||||||
|
"mockup": "ui-ux-designer",
|
||||||
|
"wireframe": "ui-ux-designer",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, label := range labels {
|
||||||
|
labelLower := strings.ToLower(label)
|
||||||
|
for keyword, role := range roleKeywords {
|
||||||
|
if strings.Contains(labelLower, keyword) {
|
||||||
|
return role
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "developer" // default role
|
||||||
|
}
|
||||||
|
|
||||||
|
// determineRequiredExpertise determines what expertise is needed
|
||||||
|
func (g *GitLabProvider) determineRequiredExpertise(labels []string) []string {
|
||||||
|
expertise := make([]string, 0)
|
||||||
|
expertiseMap := make(map[string]bool) // prevent duplicates
|
||||||
|
|
||||||
|
expertiseKeywords := map[string][]string{
|
||||||
|
// Programming languages
|
||||||
|
"go": {"go", "golang"},
|
||||||
|
"python": {"python"},
|
||||||
|
"javascript": {"javascript", "js"},
|
||||||
|
"typescript": {"typescript", "ts"},
|
||||||
|
"java": {"java"},
|
||||||
|
"rust": {"rust"},
|
||||||
|
"c++": {"c++", "cpp"},
|
||||||
|
"c#": {"c#", "csharp"},
|
||||||
|
"php": {"php"},
|
||||||
|
"ruby": {"ruby"},
|
||||||
|
|
||||||
|
// Frontend technologies
|
||||||
|
"react": {"react"},
|
||||||
|
"vue": {"vue", "vuejs"},
|
||||||
|
"angular": {"angular"},
|
||||||
|
"svelte": {"svelte"},
|
||||||
|
|
||||||
|
// Backend frameworks
|
||||||
|
"nodejs": {"nodejs", "node.js", "node"},
|
||||||
|
"django": {"django"},
|
||||||
|
"flask": {"flask"},
|
||||||
|
"spring": {"spring"},
|
||||||
|
"express": {"express"},
|
||||||
|
|
||||||
|
// Databases
|
||||||
|
"postgresql": {"postgresql", "postgres"},
|
||||||
|
"mysql": {"mysql"},
|
||||||
|
"mongodb": {"mongodb", "mongo"},
|
||||||
|
"redis": {"redis"},
|
||||||
|
|
||||||
|
// DevOps tools
|
||||||
|
"docker": {"docker"},
|
||||||
|
"kubernetes": {"kubernetes", "k8s"},
|
||||||
|
"aws": {"aws"},
|
||||||
|
"azure": {"azure"},
|
||||||
|
"gcp": {"gcp", "google cloud"},
|
||||||
|
"gitlab-ci": {"gitlab-ci", "ci/cd"},
|
||||||
|
|
||||||
|
// Other technologies
|
||||||
|
"graphql": {"graphql"},
|
||||||
|
"rest": {"rest", "restful"},
|
||||||
|
"grpc": {"grpc"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, label := range labels {
|
||||||
|
labelLower := strings.ToLower(label)
|
||||||
|
for expertiseArea, keywords := range expertiseKeywords {
|
||||||
|
for _, keyword := range keywords {
|
||||||
|
if strings.Contains(labelLower, keyword) && !expertiseMap[expertiseArea] {
|
||||||
|
expertise = append(expertise, expertiseArea)
|
||||||
|
expertiseMap[expertiseArea] = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default expertise if none detected
|
||||||
|
if len(expertise) == 0 {
|
||||||
|
expertise = []string{"development", "programming"}
|
||||||
|
}
|
||||||
|
|
||||||
|
return expertise
|
||||||
|
}
|
||||||
|
|
||||||
|
// addLabelToIssue adds a label to an issue
|
||||||
|
func (g *GitLabProvider) addLabelToIssue(issueNumber int, labelName string) error {
|
||||||
|
// First get the current labels
|
||||||
|
endpoint := fmt.Sprintf("/projects/%s/issues/%d", g.projectID, issueNumber)
|
||||||
|
resp, err := g.makeRequest("GET", endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("failed to get current issue labels")
|
||||||
|
}
|
||||||
|
|
||||||
|
var issue GitLabIssue
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&issue); err != nil {
|
||||||
|
return fmt.Errorf("failed to decode issue: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new label to existing labels
|
||||||
|
labels := append(issue.Labels, labelName)
|
||||||
|
|
||||||
|
// Update the issue with new labels
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"labels": strings.Join(labels, ","),
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err = g.makeRequest("PUT", endpoint, body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
return fmt.Errorf("failed to add label (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeLabelFromIssue removes a label from an issue
|
||||||
|
func (g *GitLabProvider) removeLabelFromIssue(issueNumber int, labelName string) error {
|
||||||
|
// First get the current labels
|
||||||
|
endpoint := fmt.Sprintf("/projects/%s/issues/%d", g.projectID, issueNumber)
|
||||||
|
resp, err := g.makeRequest("GET", endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("failed to get current issue labels")
|
||||||
|
}
|
||||||
|
|
||||||
|
var issue GitLabIssue
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&issue); err != nil {
|
||||||
|
return fmt.Errorf("failed to decode issue: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the specified label
|
||||||
|
var newLabels []string
|
||||||
|
for _, label := range issue.Labels {
|
||||||
|
if label != labelName {
|
||||||
|
newLabels = append(newLabels, label)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the issue with new labels
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"labels": strings.Join(newLabels, ","),
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err = g.makeRequest("PUT", endpoint, body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
return fmt.Errorf("failed to remove label (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addNoteToIssue adds a note (comment) to an issue
|
||||||
|
func (g *GitLabProvider) addNoteToIssue(issueNumber int, note string) error {
|
||||||
|
endpoint := fmt.Sprintf("/projects/%s/issues/%d/notes", g.projectID, issueNumber)
|
||||||
|
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"body": note,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := g.makeRequest("POST", endpoint, body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
return fmt.Errorf("failed to add note (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeIssue closes an issue
|
||||||
|
func (g *GitLabProvider) closeIssue(issueNumber int) error {
|
||||||
|
endpoint := fmt.Sprintf("/projects/%s/issues/%d", g.projectID, issueNumber)
|
||||||
|
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"state_event": "close",
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := g.makeRequest("PUT", endpoint, body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
return fmt.Errorf("failed to close issue (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
698
pkg/providers/provider_test.go
Normal file
698
pkg/providers/provider_test.go
Normal file
@@ -0,0 +1,698 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"chorus/pkg/repository"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test Gitea Provider
|
||||||
|
func TestGiteaProvider_NewGiteaProvider(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config *repository.Config
|
||||||
|
expectError bool
|
||||||
|
errorMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid config",
|
||||||
|
config: &repository.Config{
|
||||||
|
BaseURL: "https://gitea.example.com",
|
||||||
|
AccessToken: "test-token",
|
||||||
|
Owner: "testowner",
|
||||||
|
Repository: "testrepo",
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing base URL",
|
||||||
|
config: &repository.Config{
|
||||||
|
AccessToken: "test-token",
|
||||||
|
Owner: "testowner",
|
||||||
|
Repository: "testrepo",
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "base URL is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing access token",
|
||||||
|
config: &repository.Config{
|
||||||
|
BaseURL: "https://gitea.example.com",
|
||||||
|
Owner: "testowner",
|
||||||
|
Repository: "testrepo",
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "access token is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing owner",
|
||||||
|
config: &repository.Config{
|
||||||
|
BaseURL: "https://gitea.example.com",
|
||||||
|
AccessToken: "test-token",
|
||||||
|
Repository: "testrepo",
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "owner is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing repository",
|
||||||
|
config: &repository.Config{
|
||||||
|
BaseURL: "https://gitea.example.com",
|
||||||
|
AccessToken: "test-token",
|
||||||
|
Owner: "testowner",
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "repository name is required",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
provider, err := NewGiteaProvider(tt.config)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||||
|
assert.Nil(t, provider)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, provider)
|
||||||
|
assert.Equal(t, tt.config.AccessToken, provider.token)
|
||||||
|
assert.Equal(t, tt.config.Owner, provider.owner)
|
||||||
|
assert.Equal(t, tt.config.Repository, provider.repo)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGiteaProvider_GetTasks(t *testing.T) {
|
||||||
|
// Create a mock Gitea server
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
assert.Equal(t, "GET", r.Method)
|
||||||
|
assert.Contains(t, r.URL.Path, "/api/v1/repos/testowner/testrepo/issues")
|
||||||
|
assert.Equal(t, "token test-token", r.Header.Get("Authorization"))
|
||||||
|
|
||||||
|
// Mock response
|
||||||
|
issues := []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"number": 42,
|
||||||
|
"title": "Test Issue 1",
|
||||||
|
"body": "This is a test issue",
|
||||||
|
"state": "open",
|
||||||
|
"labels": []map[string]interface{}{
|
||||||
|
{"id": 1, "name": "bug", "color": "d73a4a"},
|
||||||
|
},
|
||||||
|
"created_at": "2023-01-01T12:00:00Z",
|
||||||
|
"updated_at": "2023-01-01T12:00:00Z",
|
||||||
|
"repository": map[string]interface{}{
|
||||||
|
"id": 1,
|
||||||
|
"name": "testrepo",
|
||||||
|
"full_name": "testowner/testrepo",
|
||||||
|
},
|
||||||
|
"assignee": nil,
|
||||||
|
"assignees": []interface{}{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(issues)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
config := &repository.Config{
|
||||||
|
BaseURL: server.URL,
|
||||||
|
AccessToken: "test-token",
|
||||||
|
Owner: "testowner",
|
||||||
|
Repository: "testrepo",
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, err := NewGiteaProvider(config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tasks, err := provider.GetTasks(1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, tasks, 1)
|
||||||
|
assert.Equal(t, 42, tasks[0].Number)
|
||||||
|
assert.Equal(t, "Test Issue 1", tasks[0].Title)
|
||||||
|
assert.Equal(t, "This is a test issue", tasks[0].Body)
|
||||||
|
assert.Equal(t, "testowner/testrepo", tasks[0].Repository)
|
||||||
|
assert.Equal(t, []string{"bug"}, tasks[0].Labels)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test GitHub Provider
|
||||||
|
func TestGitHubProvider_NewGitHubProvider(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config *repository.Config
|
||||||
|
expectError bool
|
||||||
|
errorMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid config",
|
||||||
|
config: &repository.Config{
|
||||||
|
AccessToken: "test-token",
|
||||||
|
Owner: "testowner",
|
||||||
|
Repository: "testrepo",
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing access token",
|
||||||
|
config: &repository.Config{
|
||||||
|
Owner: "testowner",
|
||||||
|
Repository: "testrepo",
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "access token is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing owner",
|
||||||
|
config: &repository.Config{
|
||||||
|
AccessToken: "test-token",
|
||||||
|
Repository: "testrepo",
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "owner is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing repository",
|
||||||
|
config: &repository.Config{
|
||||||
|
AccessToken: "test-token",
|
||||||
|
Owner: "testowner",
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "repository name is required",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
provider, err := NewGitHubProvider(tt.config)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||||
|
assert.Nil(t, provider)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, provider)
|
||||||
|
assert.Equal(t, tt.config.AccessToken, provider.token)
|
||||||
|
assert.Equal(t, tt.config.Owner, provider.owner)
|
||||||
|
assert.Equal(t, tt.config.Repository, provider.repo)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGitHubProvider_GetTasks(t *testing.T) {
|
||||||
|
// Create a mock GitHub server
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
assert.Equal(t, "GET", r.Method)
|
||||||
|
assert.Contains(t, r.URL.Path, "/repos/testowner/testrepo/issues")
|
||||||
|
assert.Equal(t, "token test-token", r.Header.Get("Authorization"))
|
||||||
|
|
||||||
|
// Mock response (GitHub API format)
|
||||||
|
issues := []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"id": 123456789,
|
||||||
|
"number": 42,
|
||||||
|
"title": "Test GitHub Issue",
|
||||||
|
"body": "This is a test GitHub issue",
|
||||||
|
"state": "open",
|
||||||
|
"labels": []map[string]interface{}{
|
||||||
|
{"id": 1, "name": "enhancement", "color": "a2eeef"},
|
||||||
|
},
|
||||||
|
"created_at": "2023-01-01T12:00:00Z",
|
||||||
|
"updated_at": "2023-01-01T12:00:00Z",
|
||||||
|
"assignee": nil,
|
||||||
|
"assignees": []interface{}{},
|
||||||
|
"user": map[string]interface{}{
|
||||||
|
"id": 1,
|
||||||
|
"login": "testuser",
|
||||||
|
"name": "Test User",
|
||||||
|
},
|
||||||
|
"pull_request": nil, // Not a PR
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(issues)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Override the GitHub API URL for testing
|
||||||
|
config := &repository.Config{
|
||||||
|
AccessToken: "test-token",
|
||||||
|
Owner: "testowner",
|
||||||
|
Repository: "testrepo",
|
||||||
|
BaseURL: server.URL, // This won't be used in real GitHub provider, but for testing we modify the URL in the provider
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, err := NewGitHubProvider(config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// For testing, we need to create a modified provider that uses our test server
|
||||||
|
testProvider := &GitHubProvider{
|
||||||
|
config: config,
|
||||||
|
token: config.AccessToken,
|
||||||
|
owner: config.Owner,
|
||||||
|
repo: config.Repository,
|
||||||
|
httpClient: provider.httpClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
// We can't easily test GitHub provider without modifying the URL, so we'll test the factory instead
|
||||||
|
assert.Equal(t, "test-token", provider.token)
|
||||||
|
assert.Equal(t, "testowner", provider.owner)
|
||||||
|
assert.Equal(t, "testrepo", provider.repo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test GitLab Provider
|
||||||
|
func TestGitLabProvider_NewGitLabProvider(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config *repository.Config
|
||||||
|
expectError bool
|
||||||
|
errorMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid config with owner/repo",
|
||||||
|
config: &repository.Config{
|
||||||
|
AccessToken: "test-token",
|
||||||
|
Owner: "testowner",
|
||||||
|
Repository: "testrepo",
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid config with project ID",
|
||||||
|
config: &repository.Config{
|
||||||
|
AccessToken: "test-token",
|
||||||
|
Settings: map[string]interface{}{
|
||||||
|
"project_id": "123",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing access token",
|
||||||
|
config: &repository.Config{
|
||||||
|
Owner: "testowner",
|
||||||
|
Repository: "testrepo",
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "access token is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing owner/repo and project_id",
|
||||||
|
config: &repository.Config{
|
||||||
|
AccessToken: "test-token",
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
errorMsg: "either owner/repository or project_id",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
provider, err := NewGitLabProvider(tt.config)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||||
|
assert.Nil(t, provider)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, provider)
|
||||||
|
assert.Equal(t, tt.config.AccessToken, provider.token)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Provider Factory
|
||||||
|
func TestProviderFactory_CreateProvider(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config *repository.Config
|
||||||
|
expectedType string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "create gitea provider",
|
||||||
|
config: &repository.Config{
|
||||||
|
Provider: "gitea",
|
||||||
|
BaseURL: "https://gitea.example.com",
|
||||||
|
AccessToken: "test-token",
|
||||||
|
Owner: "testowner",
|
||||||
|
Repository: "testrepo",
|
||||||
|
},
|
||||||
|
expectedType: "*providers.GiteaProvider",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "create github provider",
|
||||||
|
config: &repository.Config{
|
||||||
|
Provider: "github",
|
||||||
|
AccessToken: "test-token",
|
||||||
|
Owner: "testowner",
|
||||||
|
Repository: "testrepo",
|
||||||
|
},
|
||||||
|
expectedType: "*providers.GitHubProvider",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "create gitlab provider",
|
||||||
|
config: &repository.Config{
|
||||||
|
Provider: "gitlab",
|
||||||
|
AccessToken: "test-token",
|
||||||
|
Owner: "testowner",
|
||||||
|
Repository: "testrepo",
|
||||||
|
},
|
||||||
|
expectedType: "*providers.GitLabProvider",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "create mock provider",
|
||||||
|
config: &repository.Config{
|
||||||
|
Provider: "mock",
|
||||||
|
},
|
||||||
|
expectedType: "*repository.MockTaskProvider",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unsupported provider",
|
||||||
|
config: &repository.Config{
|
||||||
|
Provider: "unsupported",
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil config",
|
||||||
|
config: nil,
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
provider, err := factory.CreateProvider(nil, tt.config)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, provider)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, provider)
|
||||||
|
// Note: We can't easily test exact type without reflection, so we just ensure it's not nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderFactory_ValidateConfig(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config *repository.Config
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid gitea config",
|
||||||
|
config: &repository.Config{
|
||||||
|
Provider: "gitea",
|
||||||
|
BaseURL: "https://gitea.example.com",
|
||||||
|
AccessToken: "test-token",
|
||||||
|
Owner: "testowner",
|
||||||
|
Repository: "testrepo",
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid gitea config - missing baseURL",
|
||||||
|
config: &repository.Config{
|
||||||
|
Provider: "gitea",
|
||||||
|
AccessToken: "test-token",
|
||||||
|
Owner: "testowner",
|
||||||
|
Repository: "testrepo",
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid github config",
|
||||||
|
config: &repository.Config{
|
||||||
|
Provider: "github",
|
||||||
|
AccessToken: "test-token",
|
||||||
|
Owner: "testowner",
|
||||||
|
Repository: "testrepo",
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid github config - missing token",
|
||||||
|
config: &repository.Config{
|
||||||
|
Provider: "github",
|
||||||
|
Owner: "testowner",
|
||||||
|
Repository: "testrepo",
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid mock config",
|
||||||
|
config: &repository.Config{
|
||||||
|
Provider: "mock",
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := factory.ValidateConfig(tt.config)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderFactory_GetSupportedTypes(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
types := factory.GetSupportedTypes()
|
||||||
|
|
||||||
|
assert.Contains(t, types, "gitea")
|
||||||
|
assert.Contains(t, types, "github")
|
||||||
|
assert.Contains(t, types, "gitlab")
|
||||||
|
assert.Contains(t, types, "mock")
|
||||||
|
assert.Len(t, types, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderFactory_GetProviderInfo(t *testing.T) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
|
||||||
|
info, err := factory.GetProviderInfo("gitea")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "Gitea", info.Name)
|
||||||
|
assert.Equal(t, "gitea", info.Type)
|
||||||
|
assert.Contains(t, info.RequiredFields, "baseURL")
|
||||||
|
assert.Contains(t, info.RequiredFields, "accessToken")
|
||||||
|
|
||||||
|
// Test unsupported provider
|
||||||
|
_, err = factory.GetProviderInfo("unsupported")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test priority and complexity calculation
|
||||||
|
func TestPriorityComplexityCalculation(t *testing.T) {
|
||||||
|
provider := &GiteaProvider{} // We can test these methods with any provider
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
labels []string
|
||||||
|
title string
|
||||||
|
body string
|
||||||
|
expectedPriority int
|
||||||
|
expectedComplexity int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "critical bug",
|
||||||
|
labels: []string{"critical", "bug"},
|
||||||
|
title: "Critical security vulnerability",
|
||||||
|
body: "This is a critical security issue that needs immediate attention",
|
||||||
|
expectedPriority: 10,
|
||||||
|
expectedComplexity: 7,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "simple enhancement",
|
||||||
|
labels: []string{"enhancement", "good first issue"},
|
||||||
|
title: "Add help text to button",
|
||||||
|
body: "Small UI improvement",
|
||||||
|
expectedPriority: 5,
|
||||||
|
expectedComplexity: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex refactor",
|
||||||
|
labels: []string{"refactor", "epic"},
|
||||||
|
title: "Refactor authentication system",
|
||||||
|
body: string(make([]byte, 1000)), // Long body
|
||||||
|
expectedPriority: 5,
|
||||||
|
expectedComplexity: 8,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
priority := provider.calculatePriority(tt.labels, tt.title, tt.body)
|
||||||
|
complexity := provider.calculateComplexity(tt.labels, tt.title, tt.body)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedPriority, priority)
|
||||||
|
assert.Equal(t, tt.expectedComplexity, complexity)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test role determination
|
||||||
|
func TestRoleDetermination(t *testing.T) {
|
||||||
|
provider := &GiteaProvider{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
labels []string
|
||||||
|
expectedRole string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "frontend task",
|
||||||
|
labels: []string{"frontend", "ui"},
|
||||||
|
expectedRole: "frontend-developer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "backend task",
|
||||||
|
labels: []string{"backend", "api"},
|
||||||
|
expectedRole: "backend-developer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "devops task",
|
||||||
|
labels: []string{"devops", "deployment"},
|
||||||
|
expectedRole: "devops-engineer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "security task",
|
||||||
|
labels: []string{"security", "vulnerability"},
|
||||||
|
expectedRole: "security-engineer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "testing task",
|
||||||
|
labels: []string{"testing", "qa"},
|
||||||
|
expectedRole: "tester",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "documentation task",
|
||||||
|
labels: []string{"documentation"},
|
||||||
|
expectedRole: "technical-writer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "design task",
|
||||||
|
labels: []string{"design", "mockup"},
|
||||||
|
expectedRole: "ui-ux-designer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "generic task",
|
||||||
|
labels: []string{"bug"},
|
||||||
|
expectedRole: "developer",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
role := provider.determineRequiredRole(tt.labels)
|
||||||
|
assert.Equal(t, tt.expectedRole, role)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test expertise determination
|
||||||
|
func TestExpertiseDetermination(t *testing.T) {
|
||||||
|
provider := &GiteaProvider{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
labels []string
|
||||||
|
expectedExpertise []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "go programming",
|
||||||
|
labels: []string{"go", "backend"},
|
||||||
|
expectedExpertise: []string{"backend"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "react frontend",
|
||||||
|
labels: []string{"react", "javascript"},
|
||||||
|
expectedExpertise: []string{"javascript"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "docker devops",
|
||||||
|
labels: []string{"docker", "kubernetes"},
|
||||||
|
expectedExpertise: []string{"docker", "kubernetes"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no specific labels",
|
||||||
|
labels: []string{"bug", "minor"},
|
||||||
|
expectedExpertise: []string{"development", "programming"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
expertise := provider.determineRequiredExpertise(tt.labels)
|
||||||
|
// Check if all expected expertise areas are present
|
||||||
|
for _, expected := range tt.expectedExpertise {
|
||||||
|
assert.Contains(t, expertise, expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmark tests
|
||||||
|
func BenchmarkGiteaProvider_CalculatePriority(b *testing.B) {
|
||||||
|
provider := &GiteaProvider{}
|
||||||
|
labels := []string{"critical", "bug", "security"}
|
||||||
|
title := "Critical security vulnerability in authentication"
|
||||||
|
body := "This is a detailed description of a critical security vulnerability that affects user authentication and needs immediate attention."
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
provider.calculatePriority(labels, title, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkProviderFactory_CreateProvider(b *testing.B) {
|
||||||
|
factory := NewProviderFactory()
|
||||||
|
config := &repository.Config{
|
||||||
|
Provider: "mock",
|
||||||
|
AccessToken: "test-token",
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
provider, err := factory.CreateProvider(nil, config)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to create provider: %v", err)
|
||||||
|
}
|
||||||
|
_ = provider
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -147,17 +147,28 @@ func (m *DefaultTaskMatcher) ScoreTaskForAgent(task *Task, agentInfo *AgentInfo)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DefaultProviderFactory provides a default implementation of ProviderFactory
|
// DefaultProviderFactory provides a default implementation of ProviderFactory
|
||||||
type DefaultProviderFactory struct{}
|
// This is now a wrapper around the real provider factory
|
||||||
|
type DefaultProviderFactory struct {
|
||||||
|
factory ProviderFactory
|
||||||
|
}
|
||||||
|
|
||||||
// CreateProvider creates a task provider (stub implementation)
|
// NewDefaultProviderFactory creates a new default provider factory
|
||||||
|
func NewDefaultProviderFactory() *DefaultProviderFactory {
|
||||||
|
// This will be replaced by importing the providers factory
|
||||||
|
// For now, return a stub that creates mock providers
|
||||||
|
return &DefaultProviderFactory{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateProvider creates a task provider
|
||||||
func (f *DefaultProviderFactory) CreateProvider(ctx interface{}, config *Config) (TaskProvider, error) {
|
func (f *DefaultProviderFactory) CreateProvider(ctx interface{}, config *Config) (TaskProvider, error) {
|
||||||
// In a real implementation, this would create GitHub, GitLab, etc. providers
|
// For backward compatibility, fall back to mock if no real factory is available
|
||||||
|
// In production, this should be replaced with the real provider factory
|
||||||
return &MockTaskProvider{}, nil
|
return &MockTaskProvider{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSupportedTypes returns supported repository types
|
// GetSupportedTypes returns supported repository types
|
||||||
func (f *DefaultProviderFactory) GetSupportedTypes() []string {
|
func (f *DefaultProviderFactory) GetSupportedTypes() []string {
|
||||||
return []string{"github", "gitlab", "mock"}
|
return []string{"github", "gitlab", "gitea", "mock"}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SupportedProviders returns list of supported providers
|
// SupportedProviders returns list of supported providers
|
||||||
|
|||||||
1
vendor/github.com/Microsoft/go-winio/.gitattributes
generated
vendored
Normal file
1
vendor/github.com/Microsoft/go-winio/.gitattributes
generated
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
* text=auto eol=lf
|
||||||
10
vendor/github.com/Microsoft/go-winio/.gitignore
generated
vendored
Normal file
10
vendor/github.com/Microsoft/go-winio/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
.vscode/
|
||||||
|
|
||||||
|
*.exe
|
||||||
|
|
||||||
|
# testing
|
||||||
|
testdata
|
||||||
|
|
||||||
|
# go workspaces
|
||||||
|
go.work
|
||||||
|
go.work.sum
|
||||||
147
vendor/github.com/Microsoft/go-winio/.golangci.yml
generated
vendored
Normal file
147
vendor/github.com/Microsoft/go-winio/.golangci.yml
generated
vendored
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
linters:
|
||||||
|
enable:
|
||||||
|
# style
|
||||||
|
- containedctx # struct contains a context
|
||||||
|
- dupl # duplicate code
|
||||||
|
- errname # erorrs are named correctly
|
||||||
|
- nolintlint # "//nolint" directives are properly explained
|
||||||
|
- revive # golint replacement
|
||||||
|
- unconvert # unnecessary conversions
|
||||||
|
- wastedassign
|
||||||
|
|
||||||
|
# bugs, performance, unused, etc ...
|
||||||
|
- contextcheck # function uses a non-inherited context
|
||||||
|
- errorlint # errors not wrapped for 1.13
|
||||||
|
- exhaustive # check exhaustiveness of enum switch statements
|
||||||
|
- gofmt # files are gofmt'ed
|
||||||
|
- gosec # security
|
||||||
|
- nilerr # returns nil even with non-nil error
|
||||||
|
- thelper # test helpers without t.Helper()
|
||||||
|
- unparam # unused function params
|
||||||
|
|
||||||
|
issues:
|
||||||
|
exclude-dirs:
|
||||||
|
- pkg/etw/sample
|
||||||
|
|
||||||
|
exclude-rules:
|
||||||
|
# err is very often shadowed in nested scopes
|
||||||
|
- linters:
|
||||||
|
- govet
|
||||||
|
text: '^shadow: declaration of "err" shadows declaration'
|
||||||
|
|
||||||
|
# ignore long lines for skip autogen directives
|
||||||
|
- linters:
|
||||||
|
- revive
|
||||||
|
text: "^line-length-limit: "
|
||||||
|
source: "^//(go:generate|sys) "
|
||||||
|
|
||||||
|
#TODO: remove after upgrading to go1.18
|
||||||
|
# ignore comment spacing for nolint and sys directives
|
||||||
|
- linters:
|
||||||
|
- revive
|
||||||
|
text: "^comment-spacings: no space between comment delimiter and comment text"
|
||||||
|
source: "//(cspell:|nolint:|sys |todo)"
|
||||||
|
|
||||||
|
# not on go 1.18 yet, so no any
|
||||||
|
- linters:
|
||||||
|
- revive
|
||||||
|
text: "^use-any: since GO 1.18 'interface{}' can be replaced by 'any'"
|
||||||
|
|
||||||
|
# allow unjustified ignores of error checks in defer statements
|
||||||
|
- linters:
|
||||||
|
- nolintlint
|
||||||
|
text: "^directive `//nolint:errcheck` should provide explanation"
|
||||||
|
source: '^\s*defer '
|
||||||
|
|
||||||
|
# allow unjustified ignores of error lints for io.EOF
|
||||||
|
- linters:
|
||||||
|
- nolintlint
|
||||||
|
text: "^directive `//nolint:errorlint` should provide explanation"
|
||||||
|
source: '[=|!]= io.EOF'
|
||||||
|
|
||||||
|
|
||||||
|
linters-settings:
|
||||||
|
exhaustive:
|
||||||
|
default-signifies-exhaustive: true
|
||||||
|
govet:
|
||||||
|
enable-all: true
|
||||||
|
disable:
|
||||||
|
# struct order is often for Win32 compat
|
||||||
|
# also, ignore pointer bytes/GC issues for now until performance becomes an issue
|
||||||
|
- fieldalignment
|
||||||
|
nolintlint:
|
||||||
|
require-explanation: true
|
||||||
|
require-specific: true
|
||||||
|
revive:
|
||||||
|
# revive is more configurable than static check, so likely the preferred alternative to static-check
|
||||||
|
# (once the perf issue is solved: https://github.com/golangci/golangci-lint/issues/2997)
|
||||||
|
enable-all-rules:
|
||||||
|
true
|
||||||
|
# https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md
|
||||||
|
rules:
|
||||||
|
# rules with required arguments
|
||||||
|
- name: argument-limit
|
||||||
|
disabled: true
|
||||||
|
- name: banned-characters
|
||||||
|
disabled: true
|
||||||
|
- name: cognitive-complexity
|
||||||
|
disabled: true
|
||||||
|
- name: cyclomatic
|
||||||
|
disabled: true
|
||||||
|
- name: file-header
|
||||||
|
disabled: true
|
||||||
|
- name: function-length
|
||||||
|
disabled: true
|
||||||
|
- name: function-result-limit
|
||||||
|
disabled: true
|
||||||
|
- name: max-public-structs
|
||||||
|
disabled: true
|
||||||
|
# geneally annoying rules
|
||||||
|
- name: add-constant # complains about any and all strings and integers
|
||||||
|
disabled: true
|
||||||
|
- name: confusing-naming # we frequently use "Foo()" and "foo()" together
|
||||||
|
disabled: true
|
||||||
|
- name: flag-parameter # excessive, and a common idiom we use
|
||||||
|
disabled: true
|
||||||
|
- name: unhandled-error # warns over common fmt.Print* and io.Close; rely on errcheck instead
|
||||||
|
disabled: true
|
||||||
|
# general config
|
||||||
|
- name: line-length-limit
|
||||||
|
arguments:
|
||||||
|
- 140
|
||||||
|
- name: var-naming
|
||||||
|
arguments:
|
||||||
|
- []
|
||||||
|
- - CID
|
||||||
|
- CRI
|
||||||
|
- CTRD
|
||||||
|
- DACL
|
||||||
|
- DLL
|
||||||
|
- DOS
|
||||||
|
- ETW
|
||||||
|
- FSCTL
|
||||||
|
- GCS
|
||||||
|
- GMSA
|
||||||
|
- HCS
|
||||||
|
- HV
|
||||||
|
- IO
|
||||||
|
- LCOW
|
||||||
|
- LDAP
|
||||||
|
- LPAC
|
||||||
|
- LTSC
|
||||||
|
- MMIO
|
||||||
|
- NT
|
||||||
|
- OCI
|
||||||
|
- PMEM
|
||||||
|
- PWSH
|
||||||
|
- RX
|
||||||
|
- SACl
|
||||||
|
- SID
|
||||||
|
- SMB
|
||||||
|
- TX
|
||||||
|
- VHD
|
||||||
|
- VHDX
|
||||||
|
- VMID
|
||||||
|
- VPCI
|
||||||
|
- WCOW
|
||||||
|
- WIM
|
||||||
1
vendor/github.com/Microsoft/go-winio/CODEOWNERS
generated
vendored
Normal file
1
vendor/github.com/Microsoft/go-winio/CODEOWNERS
generated
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
* @microsoft/containerplat
|
||||||
22
vendor/github.com/Microsoft/go-winio/LICENSE
generated
vendored
Normal file
22
vendor/github.com/Microsoft/go-winio/LICENSE
generated
vendored
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2015 Microsoft
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
|
||||||
89
vendor/github.com/Microsoft/go-winio/README.md
generated
vendored
Normal file
89
vendor/github.com/Microsoft/go-winio/README.md
generated
vendored
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
# go-winio [](https://github.com/microsoft/go-winio/actions/workflows/ci.yml)
|
||||||
|
|
||||||
|
This repository contains utilities for efficiently performing Win32 IO operations in
|
||||||
|
Go. Currently, this is focused on accessing named pipes and other file handles, and
|
||||||
|
for using named pipes as a net transport.
|
||||||
|
|
||||||
|
This code relies on IO completion ports to avoid blocking IO on system threads, allowing Go
|
||||||
|
to reuse the thread to schedule another goroutine. This limits support to Windows Vista and
|
||||||
|
newer operating systems. This is similar to the implementation of network sockets in Go's net
|
||||||
|
package.
|
||||||
|
|
||||||
|
Please see the LICENSE file for licensing information.
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
This project welcomes contributions and suggestions.
|
||||||
|
Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that
|
||||||
|
you have the right to, and actually do, grant us the rights to use your contribution.
|
||||||
|
For details, visit [Microsoft CLA](https://cla.microsoft.com).
|
||||||
|
|
||||||
|
When you submit a pull request, a CLA-bot will automatically determine whether you need to
|
||||||
|
provide a CLA and decorate the PR appropriately (e.g., label, comment).
|
||||||
|
Simply follow the instructions provided by the bot.
|
||||||
|
You will only need to do this once across all repos using our CLA.
|
||||||
|
|
||||||
|
Additionally, the pull request pipeline requires the following steps to be performed before
|
||||||
|
mergining.
|
||||||
|
|
||||||
|
### Code Sign-Off
|
||||||
|
|
||||||
|
We require that contributors sign their commits using [`git commit --signoff`][git-commit-s]
|
||||||
|
to certify they either authored the work themselves or otherwise have permission to use it in this project.
|
||||||
|
|
||||||
|
A range of commits can be signed off using [`git rebase --signoff`][git-rebase-s].
|
||||||
|
|
||||||
|
Please see [the developer certificate](https://developercertificate.org) for more info,
|
||||||
|
as well as to make sure that you can attest to the rules listed.
|
||||||
|
Our CI uses the DCO Github app to ensure that all commits in a given PR are signed-off.
|
||||||
|
|
||||||
|
### Linting
|
||||||
|
|
||||||
|
Code must pass a linting stage, which uses [`golangci-lint`][lint].
|
||||||
|
The linting settings are stored in [`.golangci.yaml`](./.golangci.yaml), and can be run
|
||||||
|
automatically with VSCode by adding the following to your workspace or folder settings:
|
||||||
|
|
||||||
|
```json
|
||||||
|
"go.lintTool": "golangci-lint",
|
||||||
|
"go.lintOnSave": "package",
|
||||||
|
```
|
||||||
|
|
||||||
|
Additional editor [integrations options are also available][lint-ide].
|
||||||
|
|
||||||
|
Alternatively, `golangci-lint` can be [installed locally][lint-install] and run from the repo root:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# use . or specify a path to only lint a package
|
||||||
|
# to show all lint errors, use flags "--max-issues-per-linter=0 --max-same-issues=0"
|
||||||
|
> golangci-lint run ./...
|
||||||
|
```
|
||||||
|
|
||||||
|
### Go Generate
|
||||||
|
|
||||||
|
The pipeline checks that auto-generated code, via `go generate`, are up to date.
|
||||||
|
|
||||||
|
This can be done for the entire repo:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
> go generate ./...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Code of Conduct
|
||||||
|
|
||||||
|
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||||
|
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
||||||
|
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
||||||
|
|
||||||
|
## Special Thanks
|
||||||
|
|
||||||
|
Thanks to [natefinch][natefinch] for the inspiration for this library.
|
||||||
|
See [npipe](https://github.com/natefinch/npipe) for another named pipe implementation.
|
||||||
|
|
||||||
|
[lint]: https://golangci-lint.run/
|
||||||
|
[lint-ide]: https://golangci-lint.run/usage/integrations/#editor-integration
|
||||||
|
[lint-install]: https://golangci-lint.run/usage/install/#local-installation
|
||||||
|
|
||||||
|
[git-commit-s]: https://git-scm.com/docs/git-commit#Documentation/git-commit.txt--s
|
||||||
|
[git-rebase-s]: https://git-scm.com/docs/git-rebase#Documentation/git-rebase.txt---signoff
|
||||||
|
|
||||||
|
[natefinch]: https://github.com/natefinch
|
||||||
41
vendor/github.com/Microsoft/go-winio/SECURITY.md
generated
vendored
Normal file
41
vendor/github.com/Microsoft/go-winio/SECURITY.md
generated
vendored
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.7 BLOCK -->
|
||||||
|
|
||||||
|
## Security
|
||||||
|
|
||||||
|
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
|
||||||
|
|
||||||
|
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
|
||||||
|
|
||||||
|
## Reporting Security Issues
|
||||||
|
|
||||||
|
**Please do not report security vulnerabilities through public GitHub issues.**
|
||||||
|
|
||||||
|
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
|
||||||
|
|
||||||
|
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
|
||||||
|
|
||||||
|
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
|
||||||
|
|
||||||
|
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
|
||||||
|
|
||||||
|
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
|
||||||
|
* Full paths of source file(s) related to the manifestation of the issue
|
||||||
|
* The location of the affected source code (tag/branch/commit or direct URL)
|
||||||
|
* Any special configuration required to reproduce the issue
|
||||||
|
* Step-by-step instructions to reproduce the issue
|
||||||
|
* Proof-of-concept or exploit code (if possible)
|
||||||
|
* Impact of the issue, including how an attacker might exploit the issue
|
||||||
|
|
||||||
|
This information will help us triage your report more quickly.
|
||||||
|
|
||||||
|
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
|
||||||
|
|
||||||
|
## Preferred Languages
|
||||||
|
|
||||||
|
We prefer all communications to be in English.
|
||||||
|
|
||||||
|
## Policy
|
||||||
|
|
||||||
|
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
|
||||||
|
|
||||||
|
<!-- END MICROSOFT SECURITY.MD BLOCK -->
|
||||||
287
vendor/github.com/Microsoft/go-winio/backup.go
generated
vendored
Normal file
287
vendor/github.com/Microsoft/go-winio/backup.go
generated
vendored
Normal file
@@ -0,0 +1,287 @@
|
|||||||
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package winio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"unicode/utf16"
|
||||||
|
|
||||||
|
"github.com/Microsoft/go-winio/internal/fs"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
//sys backupRead(h windows.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupRead
|
||||||
|
//sys backupWrite(h windows.Handle, b []byte, bytesWritten *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupWrite
|
||||||
|
|
||||||
|
const (
|
||||||
|
BackupData = uint32(iota + 1)
|
||||||
|
BackupEaData
|
||||||
|
BackupSecurity
|
||||||
|
BackupAlternateData
|
||||||
|
BackupLink
|
||||||
|
BackupPropertyData
|
||||||
|
BackupObjectId //revive:disable-line:var-naming ID, not Id
|
||||||
|
BackupReparseData
|
||||||
|
BackupSparseBlock
|
||||||
|
BackupTxfsData
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
StreamSparseAttributes = uint32(8)
|
||||||
|
)
|
||||||
|
|
||||||
|
//nolint:revive // var-naming: ALL_CAPS
|
||||||
|
const (
|
||||||
|
WRITE_DAC = windows.WRITE_DAC
|
||||||
|
WRITE_OWNER = windows.WRITE_OWNER
|
||||||
|
ACCESS_SYSTEM_SECURITY = windows.ACCESS_SYSTEM_SECURITY
|
||||||
|
)
|
||||||
|
|
||||||
|
// BackupHeader represents a backup stream of a file.
|
||||||
|
type BackupHeader struct {
|
||||||
|
//revive:disable-next-line:var-naming ID, not Id
|
||||||
|
Id uint32 // The backup stream ID
|
||||||
|
Attributes uint32 // Stream attributes
|
||||||
|
Size int64 // The size of the stream in bytes
|
||||||
|
Name string // The name of the stream (for BackupAlternateData only).
|
||||||
|
Offset int64 // The offset of the stream in the file (for BackupSparseBlock only).
|
||||||
|
}
|
||||||
|
|
||||||
|
type win32StreamID struct {
|
||||||
|
StreamID uint32
|
||||||
|
Attributes uint32
|
||||||
|
Size uint64
|
||||||
|
NameSize uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// BackupStreamReader reads from a stream produced by the BackupRead Win32 API and produces a series
|
||||||
|
// of BackupHeader values.
|
||||||
|
type BackupStreamReader struct {
|
||||||
|
r io.Reader
|
||||||
|
bytesLeft int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBackupStreamReader produces a BackupStreamReader from any io.Reader.
|
||||||
|
func NewBackupStreamReader(r io.Reader) *BackupStreamReader {
|
||||||
|
return &BackupStreamReader{r, 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next returns the next backup stream and prepares for calls to Read(). It skips the remainder of the current stream if
|
||||||
|
// it was not completely read.
|
||||||
|
func (r *BackupStreamReader) Next() (*BackupHeader, error) {
|
||||||
|
if r.bytesLeft > 0 { //nolint:nestif // todo: flatten this
|
||||||
|
if s, ok := r.r.(io.Seeker); ok {
|
||||||
|
// Make sure Seek on io.SeekCurrent sometimes succeeds
|
||||||
|
// before trying the actual seek.
|
||||||
|
if _, err := s.Seek(0, io.SeekCurrent); err == nil {
|
||||||
|
if _, err = s.Seek(r.bytesLeft, io.SeekCurrent); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
r.bytesLeft = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := io.Copy(io.Discard, r); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var wsi win32StreamID
|
||||||
|
if err := binary.Read(r.r, binary.LittleEndian, &wsi); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
hdr := &BackupHeader{
|
||||||
|
Id: wsi.StreamID,
|
||||||
|
Attributes: wsi.Attributes,
|
||||||
|
Size: int64(wsi.Size),
|
||||||
|
}
|
||||||
|
if wsi.NameSize != 0 {
|
||||||
|
name := make([]uint16, int(wsi.NameSize/2))
|
||||||
|
if err := binary.Read(r.r, binary.LittleEndian, name); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
hdr.Name = windows.UTF16ToString(name)
|
||||||
|
}
|
||||||
|
if wsi.StreamID == BackupSparseBlock {
|
||||||
|
if err := binary.Read(r.r, binary.LittleEndian, &hdr.Offset); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
hdr.Size -= 8
|
||||||
|
}
|
||||||
|
r.bytesLeft = hdr.Size
|
||||||
|
return hdr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read reads from the current backup stream.
|
||||||
|
func (r *BackupStreamReader) Read(b []byte) (int, error) {
|
||||||
|
if r.bytesLeft == 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
if int64(len(b)) > r.bytesLeft {
|
||||||
|
b = b[:r.bytesLeft]
|
||||||
|
}
|
||||||
|
n, err := r.r.Read(b)
|
||||||
|
r.bytesLeft -= int64(n)
|
||||||
|
if err == io.EOF {
|
||||||
|
err = io.ErrUnexpectedEOF
|
||||||
|
} else if r.bytesLeft == 0 && err == nil {
|
||||||
|
err = io.EOF
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// BackupStreamWriter writes a stream compatible with the BackupWrite Win32 API.
|
||||||
|
type BackupStreamWriter struct {
|
||||||
|
w io.Writer
|
||||||
|
bytesLeft int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBackupStreamWriter produces a BackupStreamWriter on top of an io.Writer.
|
||||||
|
func NewBackupStreamWriter(w io.Writer) *BackupStreamWriter {
|
||||||
|
return &BackupStreamWriter{w, 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteHeader writes the next backup stream header and prepares for calls to Write().
|
||||||
|
func (w *BackupStreamWriter) WriteHeader(hdr *BackupHeader) error {
|
||||||
|
if w.bytesLeft != 0 {
|
||||||
|
return fmt.Errorf("missing %d bytes", w.bytesLeft)
|
||||||
|
}
|
||||||
|
name := utf16.Encode([]rune(hdr.Name))
|
||||||
|
wsi := win32StreamID{
|
||||||
|
StreamID: hdr.Id,
|
||||||
|
Attributes: hdr.Attributes,
|
||||||
|
Size: uint64(hdr.Size),
|
||||||
|
NameSize: uint32(len(name) * 2),
|
||||||
|
}
|
||||||
|
if hdr.Id == BackupSparseBlock {
|
||||||
|
// Include space for the int64 block offset
|
||||||
|
wsi.Size += 8
|
||||||
|
}
|
||||||
|
if err := binary.Write(w.w, binary.LittleEndian, &wsi); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(name) != 0 {
|
||||||
|
if err := binary.Write(w.w, binary.LittleEndian, name); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hdr.Id == BackupSparseBlock {
|
||||||
|
if err := binary.Write(w.w, binary.LittleEndian, hdr.Offset); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.bytesLeft = hdr.Size
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write writes to the current backup stream.
|
||||||
|
func (w *BackupStreamWriter) Write(b []byte) (int, error) {
|
||||||
|
if w.bytesLeft < int64(len(b)) {
|
||||||
|
return 0, fmt.Errorf("too many bytes by %d", int64(len(b))-w.bytesLeft)
|
||||||
|
}
|
||||||
|
n, err := w.w.Write(b)
|
||||||
|
w.bytesLeft -= int64(n)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// BackupFileReader provides an io.ReadCloser interface on top of the BackupRead Win32 API.
|
||||||
|
type BackupFileReader struct {
|
||||||
|
f *os.File
|
||||||
|
includeSecurity bool
|
||||||
|
ctx uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBackupFileReader returns a new BackupFileReader from a file handle. If includeSecurity is true,
|
||||||
|
// Read will attempt to read the security descriptor of the file.
|
||||||
|
func NewBackupFileReader(f *os.File, includeSecurity bool) *BackupFileReader {
|
||||||
|
r := &BackupFileReader{f, includeSecurity, 0}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read reads a backup stream from the file by calling the Win32 API BackupRead().
|
||||||
|
func (r *BackupFileReader) Read(b []byte) (int, error) {
|
||||||
|
var bytesRead uint32
|
||||||
|
err := backupRead(windows.Handle(r.f.Fd()), b, &bytesRead, false, r.includeSecurity, &r.ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, &os.PathError{Op: "BackupRead", Path: r.f.Name(), Err: err}
|
||||||
|
}
|
||||||
|
runtime.KeepAlive(r.f)
|
||||||
|
if bytesRead == 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
return int(bytesRead), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close frees Win32 resources associated with the BackupFileReader. It does not close
|
||||||
|
// the underlying file.
|
||||||
|
func (r *BackupFileReader) Close() error {
|
||||||
|
if r.ctx != 0 {
|
||||||
|
_ = backupRead(windows.Handle(r.f.Fd()), nil, nil, true, false, &r.ctx)
|
||||||
|
runtime.KeepAlive(r.f)
|
||||||
|
r.ctx = 0
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BackupFileWriter provides an io.WriteCloser interface on top of the BackupWrite Win32 API.
|
||||||
|
type BackupFileWriter struct {
|
||||||
|
f *os.File
|
||||||
|
includeSecurity bool
|
||||||
|
ctx uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBackupFileWriter returns a new BackupFileWriter from a file handle. If includeSecurity is true,
|
||||||
|
// Write() will attempt to restore the security descriptor from the stream.
|
||||||
|
func NewBackupFileWriter(f *os.File, includeSecurity bool) *BackupFileWriter {
|
||||||
|
w := &BackupFileWriter{f, includeSecurity, 0}
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write restores a portion of the file using the provided backup stream.
|
||||||
|
func (w *BackupFileWriter) Write(b []byte) (int, error) {
|
||||||
|
var bytesWritten uint32
|
||||||
|
err := backupWrite(windows.Handle(w.f.Fd()), b, &bytesWritten, false, w.includeSecurity, &w.ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, &os.PathError{Op: "BackupWrite", Path: w.f.Name(), Err: err}
|
||||||
|
}
|
||||||
|
runtime.KeepAlive(w.f)
|
||||||
|
if int(bytesWritten) != len(b) {
|
||||||
|
return int(bytesWritten), errors.New("not all bytes could be written")
|
||||||
|
}
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close frees Win32 resources associated with the BackupFileWriter. It does not
|
||||||
|
// close the underlying file.
|
||||||
|
func (w *BackupFileWriter) Close() error {
|
||||||
|
if w.ctx != 0 {
|
||||||
|
_ = backupWrite(windows.Handle(w.f.Fd()), nil, nil, true, false, &w.ctx)
|
||||||
|
runtime.KeepAlive(w.f)
|
||||||
|
w.ctx = 0
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenForBackup opens a file or directory, potentially skipping access checks if the backup
|
||||||
|
// or restore privileges have been acquired.
|
||||||
|
//
|
||||||
|
// If the file opened was a directory, it cannot be used with Readdir().
|
||||||
|
func OpenForBackup(path string, access uint32, share uint32, createmode uint32) (*os.File, error) {
|
||||||
|
h, err := fs.CreateFile(path,
|
||||||
|
fs.AccessMask(access),
|
||||||
|
fs.FileShareMode(share),
|
||||||
|
nil,
|
||||||
|
fs.FileCreationDisposition(createmode),
|
||||||
|
fs.FILE_FLAG_BACKUP_SEMANTICS|fs.FILE_FLAG_OPEN_REPARSE_POINT,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
err = &os.PathError{Op: "open", Path: path, Err: err}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return os.NewFile(uintptr(h), path), nil
|
||||||
|
}
|
||||||
22
vendor/github.com/Microsoft/go-winio/doc.go
generated
vendored
Normal file
22
vendor/github.com/Microsoft/go-winio/doc.go
generated
vendored
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
// This package provides utilities for efficiently performing Win32 IO operations in Go.
|
||||||
|
// Currently, this package is provides support for genreal IO and management of
|
||||||
|
// - named pipes
|
||||||
|
// - files
|
||||||
|
// - [Hyper-V sockets]
|
||||||
|
//
|
||||||
|
// This code is similar to Go's [net] package, and uses IO completion ports to avoid
|
||||||
|
// blocking IO on system threads, allowing Go to reuse the thread to schedule other goroutines.
|
||||||
|
//
|
||||||
|
// This limits support to Windows Vista and newer operating systems.
|
||||||
|
//
|
||||||
|
// Additionally, this package provides support for:
|
||||||
|
// - creating and managing GUIDs
|
||||||
|
// - writing to [ETW]
|
||||||
|
// - opening and manageing VHDs
|
||||||
|
// - parsing [Windows Image files]
|
||||||
|
// - auto-generating Win32 API code
|
||||||
|
//
|
||||||
|
// [Hyper-V sockets]: https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service
|
||||||
|
// [ETW]: https://docs.microsoft.com/en-us/windows-hardware/drivers/devtest/event-tracing-for-windows--etw-
|
||||||
|
// [Windows Image files]: https://docs.microsoft.com/en-us/windows-hardware/manufacture/desktop/work-with-windows-images
|
||||||
|
package winio
|
||||||
137
vendor/github.com/Microsoft/go-winio/ea.go
generated
vendored
Normal file
137
vendor/github.com/Microsoft/go-winio/ea.go
generated
vendored
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
package winio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fileFullEaInformation struct {
|
||||||
|
NextEntryOffset uint32
|
||||||
|
Flags uint8
|
||||||
|
NameLength uint8
|
||||||
|
ValueLength uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
fileFullEaInformationSize = binary.Size(&fileFullEaInformation{})
|
||||||
|
|
||||||
|
errInvalidEaBuffer = errors.New("invalid extended attribute buffer")
|
||||||
|
errEaNameTooLarge = errors.New("extended attribute name too large")
|
||||||
|
errEaValueTooLarge = errors.New("extended attribute value too large")
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExtendedAttribute represents a single Windows EA.
|
||||||
|
type ExtendedAttribute struct {
|
||||||
|
Name string
|
||||||
|
Value []byte
|
||||||
|
Flags uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseEa(b []byte) (ea ExtendedAttribute, nb []byte, err error) {
|
||||||
|
var info fileFullEaInformation
|
||||||
|
err = binary.Read(bytes.NewReader(b), binary.LittleEndian, &info)
|
||||||
|
if err != nil {
|
||||||
|
err = errInvalidEaBuffer
|
||||||
|
return ea, nb, err
|
||||||
|
}
|
||||||
|
|
||||||
|
nameOffset := fileFullEaInformationSize
|
||||||
|
nameLen := int(info.NameLength)
|
||||||
|
valueOffset := nameOffset + int(info.NameLength) + 1
|
||||||
|
valueLen := int(info.ValueLength)
|
||||||
|
nextOffset := int(info.NextEntryOffset)
|
||||||
|
if valueLen+valueOffset > len(b) || nextOffset < 0 || nextOffset > len(b) {
|
||||||
|
err = errInvalidEaBuffer
|
||||||
|
return ea, nb, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ea.Name = string(b[nameOffset : nameOffset+nameLen])
|
||||||
|
ea.Value = b[valueOffset : valueOffset+valueLen]
|
||||||
|
ea.Flags = info.Flags
|
||||||
|
if info.NextEntryOffset != 0 {
|
||||||
|
nb = b[info.NextEntryOffset:]
|
||||||
|
}
|
||||||
|
return ea, nb, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeExtendedAttributes decodes a list of EAs from a FILE_FULL_EA_INFORMATION
|
||||||
|
// buffer retrieved from BackupRead, ZwQueryEaFile, etc.
|
||||||
|
func DecodeExtendedAttributes(b []byte) (eas []ExtendedAttribute, err error) {
|
||||||
|
for len(b) != 0 {
|
||||||
|
ea, nb, err := parseEa(b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
eas = append(eas, ea)
|
||||||
|
b = nb
|
||||||
|
}
|
||||||
|
return eas, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeEa(buf *bytes.Buffer, ea *ExtendedAttribute, last bool) error {
|
||||||
|
if int(uint8(len(ea.Name))) != len(ea.Name) {
|
||||||
|
return errEaNameTooLarge
|
||||||
|
}
|
||||||
|
if int(uint16(len(ea.Value))) != len(ea.Value) {
|
||||||
|
return errEaValueTooLarge
|
||||||
|
}
|
||||||
|
entrySize := uint32(fileFullEaInformationSize + len(ea.Name) + 1 + len(ea.Value))
|
||||||
|
withPadding := (entrySize + 3) &^ 3
|
||||||
|
nextOffset := uint32(0)
|
||||||
|
if !last {
|
||||||
|
nextOffset = withPadding
|
||||||
|
}
|
||||||
|
info := fileFullEaInformation{
|
||||||
|
NextEntryOffset: nextOffset,
|
||||||
|
Flags: ea.Flags,
|
||||||
|
NameLength: uint8(len(ea.Name)),
|
||||||
|
ValueLength: uint16(len(ea.Value)),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := binary.Write(buf, binary.LittleEndian, &info)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = buf.Write([]byte(ea.Name))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = buf.WriteByte(0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = buf.Write(ea.Value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = buf.Write([]byte{0, 0, 0}[0 : withPadding-entrySize])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeExtendedAttributes encodes a list of EAs into a FILE_FULL_EA_INFORMATION
|
||||||
|
// buffer for use with BackupWrite, ZwSetEaFile, etc.
|
||||||
|
func EncodeExtendedAttributes(eas []ExtendedAttribute) ([]byte, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
for i := range eas {
|
||||||
|
last := false
|
||||||
|
if i == len(eas)-1 {
|
||||||
|
last = true
|
||||||
|
}
|
||||||
|
|
||||||
|
err := writeEa(&buf, &eas[i], last)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
320
vendor/github.com/Microsoft/go-winio/file.go
generated
vendored
Normal file
320
vendor/github.com/Microsoft/go-winio/file.go
generated
vendored
Normal file
@@ -0,0 +1,320 @@
|
|||||||
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package winio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
//sys cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) = CancelIoEx
|
||||||
|
//sys createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) = CreateIoCompletionPort
|
||||||
|
//sys getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
|
||||||
|
//sys setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
|
||||||
|
//sys wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrFileClosed = errors.New("file has already been closed")
|
||||||
|
ErrTimeout = &timeoutError{}
|
||||||
|
)
|
||||||
|
|
||||||
|
type timeoutError struct{}
|
||||||
|
|
||||||
|
func (*timeoutError) Error() string { return "i/o timeout" }
|
||||||
|
func (*timeoutError) Timeout() bool { return true }
|
||||||
|
func (*timeoutError) Temporary() bool { return true }
|
||||||
|
|
||||||
|
type timeoutChan chan struct{}
|
||||||
|
|
||||||
|
var ioInitOnce sync.Once
|
||||||
|
var ioCompletionPort windows.Handle
|
||||||
|
|
||||||
|
// ioResult contains the result of an asynchronous IO operation.
|
||||||
|
type ioResult struct {
|
||||||
|
bytes uint32
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ioOperation represents an outstanding asynchronous Win32 IO.
|
||||||
|
type ioOperation struct {
|
||||||
|
o windows.Overlapped
|
||||||
|
ch chan ioResult
|
||||||
|
}
|
||||||
|
|
||||||
|
func initIO() {
|
||||||
|
h, err := createIoCompletionPort(windows.InvalidHandle, 0, 0, 0xffffffff)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
ioCompletionPort = h
|
||||||
|
go ioCompletionProcessor(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
|
||||||
|
// It takes ownership of this handle and will close it if it is garbage collected.
|
||||||
|
type win32File struct {
|
||||||
|
handle windows.Handle
|
||||||
|
wg sync.WaitGroup
|
||||||
|
wgLock sync.RWMutex
|
||||||
|
closing atomic.Bool
|
||||||
|
socket bool
|
||||||
|
readDeadline deadlineHandler
|
||||||
|
writeDeadline deadlineHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
type deadlineHandler struct {
|
||||||
|
setLock sync.Mutex
|
||||||
|
channel timeoutChan
|
||||||
|
channelLock sync.RWMutex
|
||||||
|
timer *time.Timer
|
||||||
|
timedout atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeWin32File makes a new win32File from an existing file handle.
|
||||||
|
func makeWin32File(h windows.Handle) (*win32File, error) {
|
||||||
|
f := &win32File{handle: h}
|
||||||
|
ioInitOnce.Do(initIO)
|
||||||
|
_, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = setFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
f.readDeadline.channel = make(timeoutChan)
|
||||||
|
f.writeDeadline.channel = make(timeoutChan)
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: use NewOpenFile instead.
|
||||||
|
func MakeOpenFile(h syscall.Handle) (io.ReadWriteCloser, error) {
|
||||||
|
return NewOpenFile(windows.Handle(h))
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOpenFile(h windows.Handle) (io.ReadWriteCloser, error) {
|
||||||
|
// If we return the result of makeWin32File directly, it can result in an
|
||||||
|
// interface-wrapped nil, rather than a nil interface value.
|
||||||
|
f, err := makeWin32File(h)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeHandle closes the resources associated with a Win32 handle.
|
||||||
|
func (f *win32File) closeHandle() {
|
||||||
|
f.wgLock.Lock()
|
||||||
|
// Atomically set that we are closing, releasing the resources only once.
|
||||||
|
if !f.closing.Swap(true) {
|
||||||
|
f.wgLock.Unlock()
|
||||||
|
// cancel all IO and wait for it to complete
|
||||||
|
_ = cancelIoEx(f.handle, nil)
|
||||||
|
f.wg.Wait()
|
||||||
|
// at this point, no new IO can start
|
||||||
|
windows.Close(f.handle)
|
||||||
|
f.handle = 0
|
||||||
|
} else {
|
||||||
|
f.wgLock.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes a win32File.
|
||||||
|
func (f *win32File) Close() error {
|
||||||
|
f.closeHandle()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsClosed checks if the file has been closed.
|
||||||
|
func (f *win32File) IsClosed() bool {
|
||||||
|
return f.closing.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepareIO prepares for a new IO operation.
|
||||||
|
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
|
||||||
|
func (f *win32File) prepareIO() (*ioOperation, error) {
|
||||||
|
f.wgLock.RLock()
|
||||||
|
if f.closing.Load() {
|
||||||
|
f.wgLock.RUnlock()
|
||||||
|
return nil, ErrFileClosed
|
||||||
|
}
|
||||||
|
f.wg.Add(1)
|
||||||
|
f.wgLock.RUnlock()
|
||||||
|
c := &ioOperation{}
|
||||||
|
c.ch = make(chan ioResult)
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ioCompletionProcessor processes completed async IOs forever.
|
||||||
|
func ioCompletionProcessor(h windows.Handle) {
|
||||||
|
for {
|
||||||
|
var bytes uint32
|
||||||
|
var key uintptr
|
||||||
|
var op *ioOperation
|
||||||
|
err := getQueuedCompletionStatus(h, &bytes, &key, &op, windows.INFINITE)
|
||||||
|
if op == nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
op.ch <- ioResult{bytes, err}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// todo: helsaawy - create an asyncIO version that takes a context
|
||||||
|
|
||||||
|
// asyncIO processes the return value from ReadFile or WriteFile, blocking until
|
||||||
|
// the operation has actually completed.
|
||||||
|
func (f *win32File) asyncIO(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
|
||||||
|
if err != windows.ERROR_IO_PENDING { //nolint:errorlint // err is Errno
|
||||||
|
return int(bytes), err
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.closing.Load() {
|
||||||
|
_ = cancelIoEx(f.handle, &c.o)
|
||||||
|
}
|
||||||
|
|
||||||
|
var timeout timeoutChan
|
||||||
|
if d != nil {
|
||||||
|
d.channelLock.Lock()
|
||||||
|
timeout = d.channel
|
||||||
|
d.channelLock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
var r ioResult
|
||||||
|
select {
|
||||||
|
case r = <-c.ch:
|
||||||
|
err = r.err
|
||||||
|
if err == windows.ERROR_OPERATION_ABORTED { //nolint:errorlint // err is Errno
|
||||||
|
if f.closing.Load() {
|
||||||
|
err = ErrFileClosed
|
||||||
|
}
|
||||||
|
} else if err != nil && f.socket {
|
||||||
|
// err is from Win32. Query the overlapped structure to get the winsock error.
|
||||||
|
var bytes, flags uint32
|
||||||
|
err = wsaGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
|
||||||
|
}
|
||||||
|
case <-timeout:
|
||||||
|
_ = cancelIoEx(f.handle, &c.o)
|
||||||
|
r = <-c.ch
|
||||||
|
err = r.err
|
||||||
|
if err == windows.ERROR_OPERATION_ABORTED { //nolint:errorlint // err is Errno
|
||||||
|
err = ErrTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// runtime.KeepAlive is needed, as c is passed via native
|
||||||
|
// code to ioCompletionProcessor, c must remain alive
|
||||||
|
// until the channel read is complete.
|
||||||
|
// todo: (de)allocate *ioOperation via win32 heap functions, instead of needing to KeepAlive?
|
||||||
|
runtime.KeepAlive(c)
|
||||||
|
return int(r.bytes), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read reads from a file handle.
|
||||||
|
func (f *win32File) Read(b []byte) (int, error) {
|
||||||
|
c, err := f.prepareIO()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer f.wg.Done()
|
||||||
|
|
||||||
|
if f.readDeadline.timedout.Load() {
|
||||||
|
return 0, ErrTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
var bytes uint32
|
||||||
|
err = windows.ReadFile(f.handle, b, &bytes, &c.o)
|
||||||
|
n, err := f.asyncIO(c, &f.readDeadline, bytes, err)
|
||||||
|
runtime.KeepAlive(b)
|
||||||
|
|
||||||
|
// Handle EOF conditions.
|
||||||
|
if err == nil && n == 0 && len(b) != 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
} else if err == windows.ERROR_BROKEN_PIPE { //nolint:errorlint // err is Errno
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write writes to a file handle.
|
||||||
|
func (f *win32File) Write(b []byte) (int, error) {
|
||||||
|
c, err := f.prepareIO()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer f.wg.Done()
|
||||||
|
|
||||||
|
if f.writeDeadline.timedout.Load() {
|
||||||
|
return 0, ErrTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
var bytes uint32
|
||||||
|
err = windows.WriteFile(f.handle, b, &bytes, &c.o)
|
||||||
|
n, err := f.asyncIO(c, &f.writeDeadline, bytes, err)
|
||||||
|
runtime.KeepAlive(b)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *win32File) SetReadDeadline(deadline time.Time) error {
|
||||||
|
return f.readDeadline.set(deadline)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *win32File) SetWriteDeadline(deadline time.Time) error {
|
||||||
|
return f.writeDeadline.set(deadline)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *win32File) Flush() error {
|
||||||
|
return windows.FlushFileBuffers(f.handle)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *win32File) Fd() uintptr {
|
||||||
|
return uintptr(f.handle)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *deadlineHandler) set(deadline time.Time) error {
|
||||||
|
d.setLock.Lock()
|
||||||
|
defer d.setLock.Unlock()
|
||||||
|
|
||||||
|
if d.timer != nil {
|
||||||
|
if !d.timer.Stop() {
|
||||||
|
<-d.channel
|
||||||
|
}
|
||||||
|
d.timer = nil
|
||||||
|
}
|
||||||
|
d.timedout.Store(false)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-d.channel:
|
||||||
|
d.channelLock.Lock()
|
||||||
|
d.channel = make(chan struct{})
|
||||||
|
d.channelLock.Unlock()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if deadline.IsZero() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
timeoutIO := func() {
|
||||||
|
d.timedout.Store(true)
|
||||||
|
close(d.channel)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
duration := deadline.Sub(now)
|
||||||
|
if deadline.After(now) {
|
||||||
|
// Deadline is in the future, set a timer to wait
|
||||||
|
d.timer = time.AfterFunc(duration, timeoutIO)
|
||||||
|
} else {
|
||||||
|
// Deadline is in the past. Cancel all pending IO now.
|
||||||
|
timeoutIO()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
106
vendor/github.com/Microsoft/go-winio/fileinfo.go
generated
vendored
Normal file
106
vendor/github.com/Microsoft/go-winio/fileinfo.go
generated
vendored
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package winio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FileBasicInfo contains file access time and file attributes information.
|
||||||
|
type FileBasicInfo struct {
|
||||||
|
CreationTime, LastAccessTime, LastWriteTime, ChangeTime windows.Filetime
|
||||||
|
FileAttributes uint32
|
||||||
|
_ uint32 // padding
|
||||||
|
}
|
||||||
|
|
||||||
|
// alignedFileBasicInfo is a FileBasicInfo, but aligned to uint64 by containing
|
||||||
|
// uint64 rather than windows.Filetime. Filetime contains two uint32s. uint64
|
||||||
|
// alignment is necessary to pass this as FILE_BASIC_INFO.
|
||||||
|
type alignedFileBasicInfo struct {
|
||||||
|
CreationTime, LastAccessTime, LastWriteTime, ChangeTime uint64
|
||||||
|
FileAttributes uint32
|
||||||
|
_ uint32 // padding
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFileBasicInfo retrieves times and attributes for a file.
|
||||||
|
func GetFileBasicInfo(f *os.File) (*FileBasicInfo, error) {
|
||||||
|
bi := &alignedFileBasicInfo{}
|
||||||
|
if err := windows.GetFileInformationByHandleEx(
|
||||||
|
windows.Handle(f.Fd()),
|
||||||
|
windows.FileBasicInfo,
|
||||||
|
(*byte)(unsafe.Pointer(bi)),
|
||||||
|
uint32(unsafe.Sizeof(*bi)),
|
||||||
|
); err != nil {
|
||||||
|
return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err}
|
||||||
|
}
|
||||||
|
runtime.KeepAlive(f)
|
||||||
|
// Reinterpret the alignedFileBasicInfo as a FileBasicInfo so it matches the
|
||||||
|
// public API of this module. The data may be unnecessarily aligned.
|
||||||
|
return (*FileBasicInfo)(unsafe.Pointer(bi)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFileBasicInfo sets times and attributes for a file.
|
||||||
|
func SetFileBasicInfo(f *os.File, bi *FileBasicInfo) error {
|
||||||
|
// Create an alignedFileBasicInfo based on a FileBasicInfo. The copy is
|
||||||
|
// suitable to pass to GetFileInformationByHandleEx.
|
||||||
|
biAligned := *(*alignedFileBasicInfo)(unsafe.Pointer(bi))
|
||||||
|
if err := windows.SetFileInformationByHandle(
|
||||||
|
windows.Handle(f.Fd()),
|
||||||
|
windows.FileBasicInfo,
|
||||||
|
(*byte)(unsafe.Pointer(&biAligned)),
|
||||||
|
uint32(unsafe.Sizeof(biAligned)),
|
||||||
|
); err != nil {
|
||||||
|
return &os.PathError{Op: "SetFileInformationByHandle", Path: f.Name(), Err: err}
|
||||||
|
}
|
||||||
|
runtime.KeepAlive(f)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FileStandardInfo contains extended information for the file.
|
||||||
|
// FILE_STANDARD_INFO in WinBase.h
|
||||||
|
// https://docs.microsoft.com/en-us/windows/win32/api/winbase/ns-winbase-file_standard_info
|
||||||
|
type FileStandardInfo struct {
|
||||||
|
AllocationSize, EndOfFile int64
|
||||||
|
NumberOfLinks uint32
|
||||||
|
DeletePending, Directory bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFileStandardInfo retrieves ended information for the file.
|
||||||
|
func GetFileStandardInfo(f *os.File) (*FileStandardInfo, error) {
|
||||||
|
si := &FileStandardInfo{}
|
||||||
|
if err := windows.GetFileInformationByHandleEx(windows.Handle(f.Fd()),
|
||||||
|
windows.FileStandardInfo,
|
||||||
|
(*byte)(unsafe.Pointer(si)),
|
||||||
|
uint32(unsafe.Sizeof(*si))); err != nil {
|
||||||
|
return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err}
|
||||||
|
}
|
||||||
|
runtime.KeepAlive(f)
|
||||||
|
return si, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FileIDInfo contains the volume serial number and file ID for a file. This pair should be
|
||||||
|
// unique on a system.
|
||||||
|
type FileIDInfo struct {
|
||||||
|
VolumeSerialNumber uint64
|
||||||
|
FileID [16]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFileID retrieves the unique (volume, file ID) pair for a file.
|
||||||
|
func GetFileID(f *os.File) (*FileIDInfo, error) {
|
||||||
|
fileID := &FileIDInfo{}
|
||||||
|
if err := windows.GetFileInformationByHandleEx(
|
||||||
|
windows.Handle(f.Fd()),
|
||||||
|
windows.FileIdInfo,
|
||||||
|
(*byte)(unsafe.Pointer(fileID)),
|
||||||
|
uint32(unsafe.Sizeof(*fileID)),
|
||||||
|
); err != nil {
|
||||||
|
return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err}
|
||||||
|
}
|
||||||
|
runtime.KeepAlive(f)
|
||||||
|
return fileID, nil
|
||||||
|
}
|
||||||
582
vendor/github.com/Microsoft/go-winio/hvsock.go
generated
vendored
Normal file
582
vendor/github.com/Microsoft/go-winio/hvsock.go
generated
vendored
Normal file
@@ -0,0 +1,582 @@
|
|||||||
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package winio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
|
"github.com/Microsoft/go-winio/internal/socket"
|
||||||
|
"github.com/Microsoft/go-winio/pkg/guid"
|
||||||
|
)
|
||||||
|
|
||||||
|
const afHVSock = 34 // AF_HYPERV
|
||||||
|
|
||||||
|
// Well known Service and VM IDs
|
||||||
|
// https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service#vmid-wildcards
|
||||||
|
|
||||||
|
// HvsockGUIDWildcard is the wildcard VmId for accepting connections from all partitions.
|
||||||
|
func HvsockGUIDWildcard() guid.GUID { // 00000000-0000-0000-0000-000000000000
|
||||||
|
return guid.GUID{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HvsockGUIDBroadcast is the wildcard VmId for broadcasting sends to all partitions.
|
||||||
|
func HvsockGUIDBroadcast() guid.GUID { // ffffffff-ffff-ffff-ffff-ffffffffffff
|
||||||
|
return guid.GUID{
|
||||||
|
Data1: 0xffffffff,
|
||||||
|
Data2: 0xffff,
|
||||||
|
Data3: 0xffff,
|
||||||
|
Data4: [8]uint8{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HvsockGUIDLoopback is the Loopback VmId for accepting connections to the same partition as the connector.
|
||||||
|
func HvsockGUIDLoopback() guid.GUID { // e0e16197-dd56-4a10-9195-5ee7a155a838
|
||||||
|
return guid.GUID{
|
||||||
|
Data1: 0xe0e16197,
|
||||||
|
Data2: 0xdd56,
|
||||||
|
Data3: 0x4a10,
|
||||||
|
Data4: [8]uint8{0x91, 0x95, 0x5e, 0xe7, 0xa1, 0x55, 0xa8, 0x38},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HvsockGUIDSiloHost is the address of a silo's host partition:
|
||||||
|
// - The silo host of a hosted silo is the utility VM.
|
||||||
|
// - The silo host of a silo on a physical host is the physical host.
|
||||||
|
func HvsockGUIDSiloHost() guid.GUID { // 36bd0c5c-7276-4223-88ba-7d03b654c568
|
||||||
|
return guid.GUID{
|
||||||
|
Data1: 0x36bd0c5c,
|
||||||
|
Data2: 0x7276,
|
||||||
|
Data3: 0x4223,
|
||||||
|
Data4: [8]byte{0x88, 0xba, 0x7d, 0x03, 0xb6, 0x54, 0xc5, 0x68},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HvsockGUIDChildren is the wildcard VmId for accepting connections from the connector's child partitions.
|
||||||
|
func HvsockGUIDChildren() guid.GUID { // 90db8b89-0d35-4f79-8ce9-49ea0ac8b7cd
|
||||||
|
return guid.GUID{
|
||||||
|
Data1: 0x90db8b89,
|
||||||
|
Data2: 0xd35,
|
||||||
|
Data3: 0x4f79,
|
||||||
|
Data4: [8]uint8{0x8c, 0xe9, 0x49, 0xea, 0xa, 0xc8, 0xb7, 0xcd},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HvsockGUIDParent is the wildcard VmId for accepting connections from the connector's parent partition.
|
||||||
|
// Listening on this VmId accepts connection from:
|
||||||
|
// - Inside silos: silo host partition.
|
||||||
|
// - Inside hosted silo: host of the VM.
|
||||||
|
// - Inside VM: VM host.
|
||||||
|
// - Physical host: Not supported.
|
||||||
|
func HvsockGUIDParent() guid.GUID { // a42e7cda-d03f-480c-9cc2-a4de20abb878
|
||||||
|
return guid.GUID{
|
||||||
|
Data1: 0xa42e7cda,
|
||||||
|
Data2: 0xd03f,
|
||||||
|
Data3: 0x480c,
|
||||||
|
Data4: [8]uint8{0x9c, 0xc2, 0xa4, 0xde, 0x20, 0xab, 0xb8, 0x78},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// hvsockVsockServiceTemplate is the Service GUID used for the VSOCK protocol.
|
||||||
|
func hvsockVsockServiceTemplate() guid.GUID { // 00000000-facb-11e6-bd58-64006a7986d3
|
||||||
|
return guid.GUID{
|
||||||
|
Data2: 0xfacb,
|
||||||
|
Data3: 0x11e6,
|
||||||
|
Data4: [8]uint8{0xbd, 0x58, 0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// An HvsockAddr is an address for a AF_HYPERV socket.
|
||||||
|
type HvsockAddr struct {
|
||||||
|
VMID guid.GUID
|
||||||
|
ServiceID guid.GUID
|
||||||
|
}
|
||||||
|
|
||||||
|
type rawHvsockAddr struct {
|
||||||
|
Family uint16
|
||||||
|
_ uint16
|
||||||
|
VMID guid.GUID
|
||||||
|
ServiceID guid.GUID
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ socket.RawSockaddr = &rawHvsockAddr{}
|
||||||
|
|
||||||
|
// Network returns the address's network name, "hvsock".
|
||||||
|
func (*HvsockAddr) Network() string {
|
||||||
|
return "hvsock"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (addr *HvsockAddr) String() string {
|
||||||
|
return fmt.Sprintf("%s:%s", &addr.VMID, &addr.ServiceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// VsockServiceID returns an hvsock service ID corresponding to the specified AF_VSOCK port.
|
||||||
|
func VsockServiceID(port uint32) guid.GUID {
|
||||||
|
g := hvsockVsockServiceTemplate() // make a copy
|
||||||
|
g.Data1 = port
|
||||||
|
return g
|
||||||
|
}
|
||||||
|
|
||||||
|
func (addr *HvsockAddr) raw() rawHvsockAddr {
|
||||||
|
return rawHvsockAddr{
|
||||||
|
Family: afHVSock,
|
||||||
|
VMID: addr.VMID,
|
||||||
|
ServiceID: addr.ServiceID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) {
|
||||||
|
addr.VMID = raw.VMID
|
||||||
|
addr.ServiceID = raw.ServiceID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sockaddr returns a pointer to and the size of this struct.
|
||||||
|
//
|
||||||
|
// Implements the [socket.RawSockaddr] interface, and allows use in
|
||||||
|
// [socket.Bind] and [socket.ConnectEx].
|
||||||
|
func (r *rawHvsockAddr) Sockaddr() (unsafe.Pointer, int32, error) {
|
||||||
|
return unsafe.Pointer(r), int32(unsafe.Sizeof(rawHvsockAddr{})), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sockaddr interface allows use with `sockets.Bind()` and `.ConnectEx()`.
|
||||||
|
func (r *rawHvsockAddr) FromBytes(b []byte) error {
|
||||||
|
n := int(unsafe.Sizeof(rawHvsockAddr{}))
|
||||||
|
|
||||||
|
if len(b) < n {
|
||||||
|
return fmt.Errorf("got %d, want %d: %w", len(b), n, socket.ErrBufferSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(unsafe.Slice((*byte)(unsafe.Pointer(r)), n), b[:n])
|
||||||
|
if r.Family != afHVSock {
|
||||||
|
return fmt.Errorf("got %d, want %d: %w", r.Family, afHVSock, socket.ErrAddrFamily)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HvsockListener is a socket listener for the AF_HYPERV address family.
|
||||||
|
type HvsockListener struct {
|
||||||
|
sock *win32File
|
||||||
|
addr HvsockAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ net.Listener = &HvsockListener{}
|
||||||
|
|
||||||
|
// HvsockConn is a connected socket of the AF_HYPERV address family.
|
||||||
|
type HvsockConn struct {
|
||||||
|
sock *win32File
|
||||||
|
local, remote HvsockAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ net.Conn = &HvsockConn{}
|
||||||
|
|
||||||
|
func newHVSocket() (*win32File, error) {
|
||||||
|
fd, err := windows.Socket(afHVSock, windows.SOCK_STREAM, 1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, os.NewSyscallError("socket", err)
|
||||||
|
}
|
||||||
|
f, err := makeWin32File(fd)
|
||||||
|
if err != nil {
|
||||||
|
windows.Close(fd)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
f.socket = true
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenHvsock listens for connections on the specified hvsock address.
|
||||||
|
func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) {
|
||||||
|
l := &HvsockListener{addr: *addr}
|
||||||
|
|
||||||
|
var sock *win32File
|
||||||
|
sock, err = newHVSocket()
|
||||||
|
if err != nil {
|
||||||
|
return nil, l.opErr("listen", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
_ = sock.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
sa := addr.raw()
|
||||||
|
err = socket.Bind(sock.handle, &sa)
|
||||||
|
if err != nil {
|
||||||
|
return nil, l.opErr("listen", os.NewSyscallError("socket", err))
|
||||||
|
}
|
||||||
|
err = windows.Listen(sock.handle, 16)
|
||||||
|
if err != nil {
|
||||||
|
return nil, l.opErr("listen", os.NewSyscallError("listen", err))
|
||||||
|
}
|
||||||
|
return &HvsockListener{sock: sock, addr: *addr}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *HvsockListener) opErr(op string, err error) error {
|
||||||
|
return &net.OpError{Op: op, Net: "hvsock", Addr: &l.addr, Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Addr returns the listener's network address.
|
||||||
|
func (l *HvsockListener) Addr() net.Addr {
|
||||||
|
return &l.addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accept waits for the next connection and returns it.
|
||||||
|
func (l *HvsockListener) Accept() (_ net.Conn, err error) {
|
||||||
|
sock, err := newHVSocket()
|
||||||
|
if err != nil {
|
||||||
|
return nil, l.opErr("accept", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if sock != nil {
|
||||||
|
sock.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
c, err := l.sock.prepareIO()
|
||||||
|
if err != nil {
|
||||||
|
return nil, l.opErr("accept", err)
|
||||||
|
}
|
||||||
|
defer l.sock.wg.Done()
|
||||||
|
|
||||||
|
// AcceptEx, per documentation, requires an extra 16 bytes per address.
|
||||||
|
//
|
||||||
|
// https://docs.microsoft.com/en-us/windows/win32/api/mswsock/nf-mswsock-acceptex
|
||||||
|
const addrlen = uint32(16 + unsafe.Sizeof(rawHvsockAddr{}))
|
||||||
|
var addrbuf [addrlen * 2]byte
|
||||||
|
|
||||||
|
var bytes uint32
|
||||||
|
err = windows.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0 /* rxdatalen */, addrlen, addrlen, &bytes, &c.o)
|
||||||
|
if _, err = l.sock.asyncIO(c, nil, bytes, err); err != nil {
|
||||||
|
return nil, l.opErr("accept", os.NewSyscallError("acceptex", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &HvsockConn{
|
||||||
|
sock: sock,
|
||||||
|
}
|
||||||
|
// The local address returned in the AcceptEx buffer is the same as the Listener socket's
|
||||||
|
// address. However, the service GUID reported by GetSockName is different from the Listeners
|
||||||
|
// socket, and is sometimes the same as the local address of the socket that dialed the
|
||||||
|
// address, with the service GUID.Data1 incremented, but othertimes is different.
|
||||||
|
// todo: does the local address matter? is the listener's address or the actual address appropriate?
|
||||||
|
conn.local.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[0])))
|
||||||
|
conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen])))
|
||||||
|
|
||||||
|
// initialize the accepted socket and update its properties with those of the listening socket
|
||||||
|
if err = windows.Setsockopt(sock.handle,
|
||||||
|
windows.SOL_SOCKET, windows.SO_UPDATE_ACCEPT_CONTEXT,
|
||||||
|
(*byte)(unsafe.Pointer(&l.sock.handle)), int32(unsafe.Sizeof(l.sock.handle))); err != nil {
|
||||||
|
return nil, conn.opErr("accept", os.NewSyscallError("setsockopt", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
sock = nil
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the listener, causing any pending Accept calls to fail.
|
||||||
|
func (l *HvsockListener) Close() error {
|
||||||
|
return l.sock.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// HvsockDialer configures and dials a Hyper-V Socket (ie, [HvsockConn]).
|
||||||
|
type HvsockDialer struct {
|
||||||
|
// Deadline is the time the Dial operation must connect before erroring.
|
||||||
|
Deadline time.Time
|
||||||
|
|
||||||
|
// Retries is the number of additional connects to try if the connection times out, is refused,
|
||||||
|
// or the host is unreachable
|
||||||
|
Retries uint
|
||||||
|
|
||||||
|
// RetryWait is the time to wait after a connection error to retry
|
||||||
|
RetryWait time.Duration
|
||||||
|
|
||||||
|
rt *time.Timer // redial wait timer
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial the Hyper-V socket at addr.
|
||||||
|
//
|
||||||
|
// See [HvsockDialer.Dial] for more information.
|
||||||
|
func Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) {
|
||||||
|
return (&HvsockDialer{}).Dial(ctx, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial attempts to connect to the Hyper-V socket at addr, and returns a connection if successful.
|
||||||
|
// Will attempt (HvsockDialer).Retries if dialing fails, waiting (HvsockDialer).RetryWait between
|
||||||
|
// retries.
|
||||||
|
//
|
||||||
|
// Dialing can be cancelled either by providing (HvsockDialer).Deadline, or cancelling ctx.
|
||||||
|
func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) {
|
||||||
|
op := "dial"
|
||||||
|
// create the conn early to use opErr()
|
||||||
|
conn = &HvsockConn{
|
||||||
|
remote: *addr,
|
||||||
|
}
|
||||||
|
|
||||||
|
if !d.Deadline.IsZero() {
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
ctx, cancel = context.WithDeadline(ctx, d.Deadline)
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// preemptive timeout/cancellation check
|
||||||
|
if err = ctx.Err(); err != nil {
|
||||||
|
return nil, conn.opErr(op, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sock, err := newHVSocket()
|
||||||
|
if err != nil {
|
||||||
|
return nil, conn.opErr(op, err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if sock != nil {
|
||||||
|
sock.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
sa := addr.raw()
|
||||||
|
err = socket.Bind(sock.handle, &sa)
|
||||||
|
if err != nil {
|
||||||
|
return nil, conn.opErr(op, os.NewSyscallError("bind", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := sock.prepareIO()
|
||||||
|
if err != nil {
|
||||||
|
return nil, conn.opErr(op, err)
|
||||||
|
}
|
||||||
|
defer sock.wg.Done()
|
||||||
|
var bytes uint32
|
||||||
|
for i := uint(0); i <= d.Retries; i++ {
|
||||||
|
err = socket.ConnectEx(
|
||||||
|
sock.handle,
|
||||||
|
&sa,
|
||||||
|
nil, // sendBuf
|
||||||
|
0, // sendDataLen
|
||||||
|
&bytes,
|
||||||
|
(*windows.Overlapped)(unsafe.Pointer(&c.o)))
|
||||||
|
_, err = sock.asyncIO(c, nil, bytes, err)
|
||||||
|
if i < d.Retries && canRedial(err) {
|
||||||
|
if err = d.redialWait(ctx); err == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, conn.opErr(op, os.NewSyscallError("connectex", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// update the connection properties, so shutdown can be used
|
||||||
|
if err = windows.Setsockopt(
|
||||||
|
sock.handle,
|
||||||
|
windows.SOL_SOCKET,
|
||||||
|
windows.SO_UPDATE_CONNECT_CONTEXT,
|
||||||
|
nil, // optvalue
|
||||||
|
0, // optlen
|
||||||
|
); err != nil {
|
||||||
|
return nil, conn.opErr(op, os.NewSyscallError("setsockopt", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the local name
|
||||||
|
var sal rawHvsockAddr
|
||||||
|
err = socket.GetSockName(sock.handle, &sal)
|
||||||
|
if err != nil {
|
||||||
|
return nil, conn.opErr(op, os.NewSyscallError("getsockname", err))
|
||||||
|
}
|
||||||
|
conn.local.fromRaw(&sal)
|
||||||
|
|
||||||
|
// one last check for timeout, since asyncIO doesn't check the context
|
||||||
|
if err = ctx.Err(); err != nil {
|
||||||
|
return nil, conn.opErr(op, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.sock = sock
|
||||||
|
sock = nil
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// redialWait waits before attempting to redial, resetting the timer as appropriate.
|
||||||
|
func (d *HvsockDialer) redialWait(ctx context.Context) (err error) {
|
||||||
|
if d.RetryWait == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if d.rt == nil {
|
||||||
|
d.rt = time.NewTimer(d.RetryWait)
|
||||||
|
} else {
|
||||||
|
// should already be stopped and drained
|
||||||
|
d.rt.Reset(d.RetryWait)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case <-d.rt.C:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// stop and drain the timer
|
||||||
|
if !d.rt.Stop() {
|
||||||
|
<-d.rt.C
|
||||||
|
}
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// assumes error is a plain, unwrapped windows.Errno provided by direct syscall.
|
||||||
|
func canRedial(err error) bool {
|
||||||
|
//nolint:errorlint // guaranteed to be an Errno
|
||||||
|
switch err {
|
||||||
|
case windows.WSAECONNREFUSED, windows.WSAENETUNREACH, windows.WSAETIMEDOUT,
|
||||||
|
windows.ERROR_CONNECTION_REFUSED, windows.ERROR_CONNECTION_UNAVAIL:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *HvsockConn) opErr(op string, err error) error {
|
||||||
|
// translate from "file closed" to "socket closed"
|
||||||
|
if errors.Is(err, ErrFileClosed) {
|
||||||
|
err = socket.ErrSocketClosed
|
||||||
|
}
|
||||||
|
return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *HvsockConn) Read(b []byte) (int, error) {
|
||||||
|
c, err := conn.sock.prepareIO()
|
||||||
|
if err != nil {
|
||||||
|
return 0, conn.opErr("read", err)
|
||||||
|
}
|
||||||
|
defer conn.sock.wg.Done()
|
||||||
|
buf := windows.WSABuf{Buf: &b[0], Len: uint32(len(b))}
|
||||||
|
var flags, bytes uint32
|
||||||
|
err = windows.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil)
|
||||||
|
n, err := conn.sock.asyncIO(c, &conn.sock.readDeadline, bytes, err)
|
||||||
|
if err != nil {
|
||||||
|
var eno windows.Errno
|
||||||
|
if errors.As(err, &eno) {
|
||||||
|
err = os.NewSyscallError("wsarecv", eno)
|
||||||
|
}
|
||||||
|
return 0, conn.opErr("read", err)
|
||||||
|
} else if n == 0 {
|
||||||
|
err = io.EOF
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *HvsockConn) Write(b []byte) (int, error) {
|
||||||
|
t := 0
|
||||||
|
for len(b) != 0 {
|
||||||
|
n, err := conn.write(b)
|
||||||
|
if err != nil {
|
||||||
|
return t + n, err
|
||||||
|
}
|
||||||
|
t += n
|
||||||
|
b = b[n:]
|
||||||
|
}
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *HvsockConn) write(b []byte) (int, error) {
|
||||||
|
c, err := conn.sock.prepareIO()
|
||||||
|
if err != nil {
|
||||||
|
return 0, conn.opErr("write", err)
|
||||||
|
}
|
||||||
|
defer conn.sock.wg.Done()
|
||||||
|
buf := windows.WSABuf{Buf: &b[0], Len: uint32(len(b))}
|
||||||
|
var bytes uint32
|
||||||
|
err = windows.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil)
|
||||||
|
n, err := conn.sock.asyncIO(c, &conn.sock.writeDeadline, bytes, err)
|
||||||
|
if err != nil {
|
||||||
|
var eno windows.Errno
|
||||||
|
if errors.As(err, &eno) {
|
||||||
|
err = os.NewSyscallError("wsasend", eno)
|
||||||
|
}
|
||||||
|
return 0, conn.opErr("write", err)
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the socket connection, failing any pending read or write calls.
|
||||||
|
func (conn *HvsockConn) Close() error {
|
||||||
|
return conn.sock.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *HvsockConn) IsClosed() bool {
|
||||||
|
return conn.sock.IsClosed()
|
||||||
|
}
|
||||||
|
|
||||||
|
// shutdown disables sending or receiving on a socket.
|
||||||
|
func (conn *HvsockConn) shutdown(how int) error {
|
||||||
|
if conn.IsClosed() {
|
||||||
|
return socket.ErrSocketClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
err := windows.Shutdown(conn.sock.handle, how)
|
||||||
|
if err != nil {
|
||||||
|
// If the connection was closed, shutdowns fail with "not connected"
|
||||||
|
if errors.Is(err, windows.WSAENOTCONN) ||
|
||||||
|
errors.Is(err, windows.WSAESHUTDOWN) {
|
||||||
|
err = socket.ErrSocketClosed
|
||||||
|
}
|
||||||
|
return os.NewSyscallError("shutdown", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseRead shuts down the read end of the socket, preventing future read operations.
|
||||||
|
func (conn *HvsockConn) CloseRead() error {
|
||||||
|
err := conn.shutdown(windows.SHUT_RD)
|
||||||
|
if err != nil {
|
||||||
|
return conn.opErr("closeread", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseWrite shuts down the write end of the socket, preventing future write operations and
|
||||||
|
// notifying the other endpoint that no more data will be written.
|
||||||
|
func (conn *HvsockConn) CloseWrite() error {
|
||||||
|
err := conn.shutdown(windows.SHUT_WR)
|
||||||
|
if err != nil {
|
||||||
|
return conn.opErr("closewrite", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr returns the local address of the connection.
|
||||||
|
func (conn *HvsockConn) LocalAddr() net.Addr {
|
||||||
|
return &conn.local
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoteAddr returns the remote address of the connection.
|
||||||
|
func (conn *HvsockConn) RemoteAddr() net.Addr {
|
||||||
|
return &conn.remote
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeadline implements the net.Conn SetDeadline method.
|
||||||
|
func (conn *HvsockConn) SetDeadline(t time.Time) error {
|
||||||
|
// todo: implement `SetDeadline` for `win32File`
|
||||||
|
if err := conn.SetReadDeadline(t); err != nil {
|
||||||
|
return fmt.Errorf("set read deadline: %w", err)
|
||||||
|
}
|
||||||
|
if err := conn.SetWriteDeadline(t); err != nil {
|
||||||
|
return fmt.Errorf("set write deadline: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadDeadline implements the net.Conn SetReadDeadline method.
|
||||||
|
func (conn *HvsockConn) SetReadDeadline(t time.Time) error {
|
||||||
|
return conn.sock.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetWriteDeadline implements the net.Conn SetWriteDeadline method.
|
||||||
|
func (conn *HvsockConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
return conn.sock.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
2
vendor/github.com/Microsoft/go-winio/internal/fs/doc.go
generated
vendored
Normal file
2
vendor/github.com/Microsoft/go-winio/internal/fs/doc.go
generated
vendored
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
// This package contains Win32 filesystem functionality.
|
||||||
|
package fs
|
||||||
262
vendor/github.com/Microsoft/go-winio/internal/fs/fs.go
generated
vendored
Normal file
262
vendor/github.com/Microsoft/go-winio/internal/fs/fs.go
generated
vendored
Normal file
@@ -0,0 +1,262 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package fs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
|
"github.com/Microsoft/go-winio/internal/stringbuffer"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go fs.go
|
||||||
|
|
||||||
|
// https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew
|
||||||
|
//sys CreateFile(name string, access AccessMask, mode FileShareMode, sa *windows.SecurityAttributes, createmode FileCreationDisposition, attrs FileFlagOrAttribute, templatefile windows.Handle) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateFileW
|
||||||
|
|
||||||
|
const NullHandle windows.Handle = 0
|
||||||
|
|
||||||
|
// AccessMask defines standard, specific, and generic rights.
|
||||||
|
//
|
||||||
|
// Used with CreateFile and NtCreateFile (and co.).
|
||||||
|
//
|
||||||
|
// Bitmask:
|
||||||
|
// 3 3 2 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1
|
||||||
|
// 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0
|
||||||
|
// +---------------+---------------+-------------------------------+
|
||||||
|
// |G|G|G|G|Resvd|A| StandardRights| SpecificRights |
|
||||||
|
// |R|W|E|A| |S| | |
|
||||||
|
// +-+-------------+---------------+-------------------------------+
|
||||||
|
//
|
||||||
|
// GR Generic Read
|
||||||
|
// GW Generic Write
|
||||||
|
// GE Generic Exectue
|
||||||
|
// GA Generic All
|
||||||
|
// Resvd Reserved
|
||||||
|
// AS Access Security System
|
||||||
|
//
|
||||||
|
// https://learn.microsoft.com/en-us/windows/win32/secauthz/access-mask
|
||||||
|
//
|
||||||
|
// https://learn.microsoft.com/en-us/windows/win32/secauthz/generic-access-rights
|
||||||
|
//
|
||||||
|
// https://learn.microsoft.com/en-us/windows/win32/fileio/file-access-rights-constants
|
||||||
|
type AccessMask = windows.ACCESS_MASK
|
||||||
|
|
||||||
|
//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
|
||||||
|
const (
|
||||||
|
// Not actually any.
|
||||||
|
//
|
||||||
|
// For CreateFile: "query certain metadata such as file, directory, or device attributes without accessing that file or device"
|
||||||
|
// https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew#parameters
|
||||||
|
FILE_ANY_ACCESS AccessMask = 0
|
||||||
|
|
||||||
|
GENERIC_READ AccessMask = 0x8000_0000
|
||||||
|
GENERIC_WRITE AccessMask = 0x4000_0000
|
||||||
|
GENERIC_EXECUTE AccessMask = 0x2000_0000
|
||||||
|
GENERIC_ALL AccessMask = 0x1000_0000
|
||||||
|
ACCESS_SYSTEM_SECURITY AccessMask = 0x0100_0000
|
||||||
|
|
||||||
|
// Specific Object Access
|
||||||
|
// from ntioapi.h
|
||||||
|
|
||||||
|
FILE_READ_DATA AccessMask = (0x0001) // file & pipe
|
||||||
|
FILE_LIST_DIRECTORY AccessMask = (0x0001) // directory
|
||||||
|
|
||||||
|
FILE_WRITE_DATA AccessMask = (0x0002) // file & pipe
|
||||||
|
FILE_ADD_FILE AccessMask = (0x0002) // directory
|
||||||
|
|
||||||
|
FILE_APPEND_DATA AccessMask = (0x0004) // file
|
||||||
|
FILE_ADD_SUBDIRECTORY AccessMask = (0x0004) // directory
|
||||||
|
FILE_CREATE_PIPE_INSTANCE AccessMask = (0x0004) // named pipe
|
||||||
|
|
||||||
|
FILE_READ_EA AccessMask = (0x0008) // file & directory
|
||||||
|
FILE_READ_PROPERTIES AccessMask = FILE_READ_EA
|
||||||
|
|
||||||
|
FILE_WRITE_EA AccessMask = (0x0010) // file & directory
|
||||||
|
FILE_WRITE_PROPERTIES AccessMask = FILE_WRITE_EA
|
||||||
|
|
||||||
|
FILE_EXECUTE AccessMask = (0x0020) // file
|
||||||
|
FILE_TRAVERSE AccessMask = (0x0020) // directory
|
||||||
|
|
||||||
|
FILE_DELETE_CHILD AccessMask = (0x0040) // directory
|
||||||
|
|
||||||
|
FILE_READ_ATTRIBUTES AccessMask = (0x0080) // all
|
||||||
|
|
||||||
|
FILE_WRITE_ATTRIBUTES AccessMask = (0x0100) // all
|
||||||
|
|
||||||
|
FILE_ALL_ACCESS AccessMask = (STANDARD_RIGHTS_REQUIRED | SYNCHRONIZE | 0x1FF)
|
||||||
|
FILE_GENERIC_READ AccessMask = (STANDARD_RIGHTS_READ | FILE_READ_DATA | FILE_READ_ATTRIBUTES | FILE_READ_EA | SYNCHRONIZE)
|
||||||
|
FILE_GENERIC_WRITE AccessMask = (STANDARD_RIGHTS_WRITE | FILE_WRITE_DATA | FILE_WRITE_ATTRIBUTES | FILE_WRITE_EA | FILE_APPEND_DATA | SYNCHRONIZE)
|
||||||
|
FILE_GENERIC_EXECUTE AccessMask = (STANDARD_RIGHTS_EXECUTE | FILE_READ_ATTRIBUTES | FILE_EXECUTE | SYNCHRONIZE)
|
||||||
|
|
||||||
|
SPECIFIC_RIGHTS_ALL AccessMask = 0x0000FFFF
|
||||||
|
|
||||||
|
// Standard Access
|
||||||
|
// from ntseapi.h
|
||||||
|
|
||||||
|
DELETE AccessMask = 0x0001_0000
|
||||||
|
READ_CONTROL AccessMask = 0x0002_0000
|
||||||
|
WRITE_DAC AccessMask = 0x0004_0000
|
||||||
|
WRITE_OWNER AccessMask = 0x0008_0000
|
||||||
|
SYNCHRONIZE AccessMask = 0x0010_0000
|
||||||
|
|
||||||
|
STANDARD_RIGHTS_REQUIRED AccessMask = 0x000F_0000
|
||||||
|
|
||||||
|
STANDARD_RIGHTS_READ AccessMask = READ_CONTROL
|
||||||
|
STANDARD_RIGHTS_WRITE AccessMask = READ_CONTROL
|
||||||
|
STANDARD_RIGHTS_EXECUTE AccessMask = READ_CONTROL
|
||||||
|
|
||||||
|
STANDARD_RIGHTS_ALL AccessMask = 0x001F_0000
|
||||||
|
)
|
||||||
|
|
||||||
|
type FileShareMode uint32
|
||||||
|
|
||||||
|
//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
|
||||||
|
const (
|
||||||
|
FILE_SHARE_NONE FileShareMode = 0x00
|
||||||
|
FILE_SHARE_READ FileShareMode = 0x01
|
||||||
|
FILE_SHARE_WRITE FileShareMode = 0x02
|
||||||
|
FILE_SHARE_DELETE FileShareMode = 0x04
|
||||||
|
FILE_SHARE_VALID_FLAGS FileShareMode = 0x07
|
||||||
|
)
|
||||||
|
|
||||||
|
type FileCreationDisposition uint32
|
||||||
|
|
||||||
|
//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
|
||||||
|
const (
|
||||||
|
// from winbase.h
|
||||||
|
|
||||||
|
CREATE_NEW FileCreationDisposition = 0x01
|
||||||
|
CREATE_ALWAYS FileCreationDisposition = 0x02
|
||||||
|
OPEN_EXISTING FileCreationDisposition = 0x03
|
||||||
|
OPEN_ALWAYS FileCreationDisposition = 0x04
|
||||||
|
TRUNCATE_EXISTING FileCreationDisposition = 0x05
|
||||||
|
)
|
||||||
|
|
||||||
|
// Create disposition values for NtCreate*
|
||||||
|
type NTFileCreationDisposition uint32
|
||||||
|
|
||||||
|
//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
|
||||||
|
const (
|
||||||
|
// From ntioapi.h
|
||||||
|
|
||||||
|
FILE_SUPERSEDE NTFileCreationDisposition = 0x00
|
||||||
|
FILE_OPEN NTFileCreationDisposition = 0x01
|
||||||
|
FILE_CREATE NTFileCreationDisposition = 0x02
|
||||||
|
FILE_OPEN_IF NTFileCreationDisposition = 0x03
|
||||||
|
FILE_OVERWRITE NTFileCreationDisposition = 0x04
|
||||||
|
FILE_OVERWRITE_IF NTFileCreationDisposition = 0x05
|
||||||
|
FILE_MAXIMUM_DISPOSITION NTFileCreationDisposition = 0x05
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreateFile and co. take flags or attributes together as one parameter.
|
||||||
|
// Define alias until we can use generics to allow both
|
||||||
|
//
|
||||||
|
// https://learn.microsoft.com/en-us/windows/win32/fileio/file-attribute-constants
|
||||||
|
type FileFlagOrAttribute uint32
|
||||||
|
|
||||||
|
//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
|
||||||
|
const (
|
||||||
|
// from winnt.h
|
||||||
|
|
||||||
|
FILE_FLAG_WRITE_THROUGH FileFlagOrAttribute = 0x8000_0000
|
||||||
|
FILE_FLAG_OVERLAPPED FileFlagOrAttribute = 0x4000_0000
|
||||||
|
FILE_FLAG_NO_BUFFERING FileFlagOrAttribute = 0x2000_0000
|
||||||
|
FILE_FLAG_RANDOM_ACCESS FileFlagOrAttribute = 0x1000_0000
|
||||||
|
FILE_FLAG_SEQUENTIAL_SCAN FileFlagOrAttribute = 0x0800_0000
|
||||||
|
FILE_FLAG_DELETE_ON_CLOSE FileFlagOrAttribute = 0x0400_0000
|
||||||
|
FILE_FLAG_BACKUP_SEMANTICS FileFlagOrAttribute = 0x0200_0000
|
||||||
|
FILE_FLAG_POSIX_SEMANTICS FileFlagOrAttribute = 0x0100_0000
|
||||||
|
FILE_FLAG_OPEN_REPARSE_POINT FileFlagOrAttribute = 0x0020_0000
|
||||||
|
FILE_FLAG_OPEN_NO_RECALL FileFlagOrAttribute = 0x0010_0000
|
||||||
|
FILE_FLAG_FIRST_PIPE_INSTANCE FileFlagOrAttribute = 0x0008_0000
|
||||||
|
)
|
||||||
|
|
||||||
|
// NtCreate* functions take a dedicated CreateOptions parameter.
|
||||||
|
//
|
||||||
|
// https://learn.microsoft.com/en-us/windows/win32/api/Winternl/nf-winternl-ntcreatefile
|
||||||
|
//
|
||||||
|
// https://learn.microsoft.com/en-us/windows/win32/devnotes/nt-create-named-pipe-file
|
||||||
|
type NTCreateOptions uint32
|
||||||
|
|
||||||
|
//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
|
||||||
|
const (
|
||||||
|
// From ntioapi.h
|
||||||
|
|
||||||
|
FILE_DIRECTORY_FILE NTCreateOptions = 0x0000_0001
|
||||||
|
FILE_WRITE_THROUGH NTCreateOptions = 0x0000_0002
|
||||||
|
FILE_SEQUENTIAL_ONLY NTCreateOptions = 0x0000_0004
|
||||||
|
FILE_NO_INTERMEDIATE_BUFFERING NTCreateOptions = 0x0000_0008
|
||||||
|
|
||||||
|
FILE_SYNCHRONOUS_IO_ALERT NTCreateOptions = 0x0000_0010
|
||||||
|
FILE_SYNCHRONOUS_IO_NONALERT NTCreateOptions = 0x0000_0020
|
||||||
|
FILE_NON_DIRECTORY_FILE NTCreateOptions = 0x0000_0040
|
||||||
|
FILE_CREATE_TREE_CONNECTION NTCreateOptions = 0x0000_0080
|
||||||
|
|
||||||
|
FILE_COMPLETE_IF_OPLOCKED NTCreateOptions = 0x0000_0100
|
||||||
|
FILE_NO_EA_KNOWLEDGE NTCreateOptions = 0x0000_0200
|
||||||
|
FILE_DISABLE_TUNNELING NTCreateOptions = 0x0000_0400
|
||||||
|
FILE_RANDOM_ACCESS NTCreateOptions = 0x0000_0800
|
||||||
|
|
||||||
|
FILE_DELETE_ON_CLOSE NTCreateOptions = 0x0000_1000
|
||||||
|
FILE_OPEN_BY_FILE_ID NTCreateOptions = 0x0000_2000
|
||||||
|
FILE_OPEN_FOR_BACKUP_INTENT NTCreateOptions = 0x0000_4000
|
||||||
|
FILE_NO_COMPRESSION NTCreateOptions = 0x0000_8000
|
||||||
|
)
|
||||||
|
|
||||||
|
type FileSQSFlag = FileFlagOrAttribute
|
||||||
|
|
||||||
|
//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
|
||||||
|
const (
|
||||||
|
// from winbase.h
|
||||||
|
|
||||||
|
SECURITY_ANONYMOUS FileSQSFlag = FileSQSFlag(SecurityAnonymous << 16)
|
||||||
|
SECURITY_IDENTIFICATION FileSQSFlag = FileSQSFlag(SecurityIdentification << 16)
|
||||||
|
SECURITY_IMPERSONATION FileSQSFlag = FileSQSFlag(SecurityImpersonation << 16)
|
||||||
|
SECURITY_DELEGATION FileSQSFlag = FileSQSFlag(SecurityDelegation << 16)
|
||||||
|
|
||||||
|
SECURITY_SQOS_PRESENT FileSQSFlag = 0x0010_0000
|
||||||
|
SECURITY_VALID_SQOS_FLAGS FileSQSFlag = 0x001F_0000
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetFinalPathNameByHandle flags
|
||||||
|
//
|
||||||
|
// https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-getfinalpathnamebyhandlew#parameters
|
||||||
|
type GetFinalPathFlag uint32
|
||||||
|
|
||||||
|
//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
|
||||||
|
const (
|
||||||
|
GetFinalPathDefaultFlag GetFinalPathFlag = 0x0
|
||||||
|
|
||||||
|
FILE_NAME_NORMALIZED GetFinalPathFlag = 0x0
|
||||||
|
FILE_NAME_OPENED GetFinalPathFlag = 0x8
|
||||||
|
|
||||||
|
VOLUME_NAME_DOS GetFinalPathFlag = 0x0
|
||||||
|
VOLUME_NAME_GUID GetFinalPathFlag = 0x1
|
||||||
|
VOLUME_NAME_NT GetFinalPathFlag = 0x2
|
||||||
|
VOLUME_NAME_NONE GetFinalPathFlag = 0x4
|
||||||
|
)
|
||||||
|
|
||||||
|
// getFinalPathNameByHandle facilitates calling the Windows API GetFinalPathNameByHandle
|
||||||
|
// with the given handle and flags. It transparently takes care of creating a buffer of the
|
||||||
|
// correct size for the call.
|
||||||
|
//
|
||||||
|
// https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-getfinalpathnamebyhandlew
|
||||||
|
func GetFinalPathNameByHandle(h windows.Handle, flags GetFinalPathFlag) (string, error) {
|
||||||
|
b := stringbuffer.NewWString()
|
||||||
|
//TODO: can loop infinitely if Win32 keeps returning the same (or a larger) n?
|
||||||
|
for {
|
||||||
|
n, err := windows.GetFinalPathNameByHandle(h, b.Pointer(), b.Cap(), uint32(flags))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
// If the buffer wasn't large enough, n will be the total size needed (including null terminator).
|
||||||
|
// Resize and try again.
|
||||||
|
if n > b.Cap() {
|
||||||
|
b.ResizeTo(n)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// If the buffer is large enough, n will be the size not including the null terminator.
|
||||||
|
// Convert to a Go string and return.
|
||||||
|
return b.String(), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
12
vendor/github.com/Microsoft/go-winio/internal/fs/security.go
generated
vendored
Normal file
12
vendor/github.com/Microsoft/go-winio/internal/fs/security.go
generated
vendored
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package fs
|
||||||
|
|
||||||
|
// https://learn.microsoft.com/en-us/windows/win32/api/winnt/ne-winnt-security_impersonation_level
|
||||||
|
type SecurityImpersonationLevel int32 // C default enums underlying type is `int`, which is Go `int32`
|
||||||
|
|
||||||
|
// Impersonation levels
|
||||||
|
const (
|
||||||
|
SecurityAnonymous SecurityImpersonationLevel = 0
|
||||||
|
SecurityIdentification SecurityImpersonationLevel = 1
|
||||||
|
SecurityImpersonation SecurityImpersonationLevel = 2
|
||||||
|
SecurityDelegation SecurityImpersonationLevel = 3
|
||||||
|
)
|
||||||
61
vendor/github.com/Microsoft/go-winio/internal/fs/zsyscall_windows.go
generated
vendored
Normal file
61
vendor/github.com/Microsoft/go-winio/internal/fs/zsyscall_windows.go
generated
vendored
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT.
|
||||||
|
|
||||||
|
package fs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ unsafe.Pointer
|
||||||
|
|
||||||
|
// Do the interface allocations only once for common
|
||||||
|
// Errno values.
|
||||||
|
const (
|
||||||
|
errnoERROR_IO_PENDING = 997
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||||
|
errERROR_EINVAL error = syscall.EINVAL
|
||||||
|
)
|
||||||
|
|
||||||
|
// errnoErr returns common boxed Errno values, to prevent
|
||||||
|
// allocations at runtime.
|
||||||
|
func errnoErr(e syscall.Errno) error {
|
||||||
|
switch e {
|
||||||
|
case 0:
|
||||||
|
return errERROR_EINVAL
|
||||||
|
case errnoERROR_IO_PENDING:
|
||||||
|
return errERROR_IO_PENDING
|
||||||
|
}
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||||
|
|
||||||
|
procCreateFileW = modkernel32.NewProc("CreateFileW")
|
||||||
|
)
|
||||||
|
|
||||||
|
func CreateFile(name string, access AccessMask, mode FileShareMode, sa *windows.SecurityAttributes, createmode FileCreationDisposition, attrs FileFlagOrAttribute, templatefile windows.Handle) (handle windows.Handle, err error) {
|
||||||
|
var _p0 *uint16
|
||||||
|
_p0, err = syscall.UTF16PtrFromString(name)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return _CreateFile(_p0, access, mode, sa, createmode, attrs, templatefile)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _CreateFile(name *uint16, access AccessMask, mode FileShareMode, sa *windows.SecurityAttributes, createmode FileCreationDisposition, attrs FileFlagOrAttribute, templatefile windows.Handle) (handle windows.Handle, err error) {
|
||||||
|
r0, _, e1 := syscall.SyscallN(procCreateFileW.Addr(), uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile))
|
||||||
|
handle = windows.Handle(r0)
|
||||||
|
if handle == windows.InvalidHandle {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
20
vendor/github.com/Microsoft/go-winio/internal/socket/rawaddr.go
generated
vendored
Normal file
20
vendor/github.com/Microsoft/go-winio/internal/socket/rawaddr.go
generated
vendored
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package socket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RawSockaddr allows structs to be used with [Bind] and [ConnectEx]. The
|
||||||
|
// struct must meet the Win32 sockaddr requirements specified here:
|
||||||
|
// https://docs.microsoft.com/en-us/windows/win32/winsock/sockaddr-2
|
||||||
|
//
|
||||||
|
// Specifically, the struct size must be least larger than an int16 (unsigned short)
|
||||||
|
// for the address family.
|
||||||
|
type RawSockaddr interface {
|
||||||
|
// Sockaddr returns a pointer to the RawSockaddr and its struct size, allowing
|
||||||
|
// for the RawSockaddr's data to be overwritten by syscalls (if necessary).
|
||||||
|
//
|
||||||
|
// It is the callers responsibility to validate that the values are valid; invalid
|
||||||
|
// pointers or size can cause a panic.
|
||||||
|
Sockaddr() (unsafe.Pointer, int32, error)
|
||||||
|
}
|
||||||
177
vendor/github.com/Microsoft/go-winio/internal/socket/socket.go
generated
vendored
Normal file
177
vendor/github.com/Microsoft/go-winio/internal/socket/socket.go
generated
vendored
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package socket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/Microsoft/go-winio/pkg/guid"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go socket.go
|
||||||
|
|
||||||
|
//sys getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getsockname
|
||||||
|
//sys getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getpeername
|
||||||
|
//sys bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind
|
||||||
|
|
||||||
|
const socketError = uintptr(^uint32(0))
|
||||||
|
|
||||||
|
var (
|
||||||
|
// todo(helsaawy): create custom error types to store the desired vs actual size and addr family?
|
||||||
|
|
||||||
|
ErrBufferSize = errors.New("buffer size")
|
||||||
|
ErrAddrFamily = errors.New("address family")
|
||||||
|
ErrInvalidPointer = errors.New("invalid pointer")
|
||||||
|
ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed)
|
||||||
|
)
|
||||||
|
|
||||||
|
// todo(helsaawy): replace these with generics, ie: GetSockName[S RawSockaddr](s windows.Handle) (S, error)
|
||||||
|
|
||||||
|
// GetSockName writes the local address of socket s to the [RawSockaddr] rsa.
|
||||||
|
// If rsa is not large enough, the [windows.WSAEFAULT] is returned.
|
||||||
|
func GetSockName(s windows.Handle, rsa RawSockaddr) error {
|
||||||
|
ptr, l, err := rsa.Sockaddr()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// although getsockname returns WSAEFAULT if the buffer is too small, it does not set
|
||||||
|
// &l to the correct size, so--apart from doubling the buffer repeatedly--there is no remedy
|
||||||
|
return getsockname(s, ptr, &l)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerName returns the remote address the socket is connected to.
|
||||||
|
//
|
||||||
|
// See [GetSockName] for more information.
|
||||||
|
func GetPeerName(s windows.Handle, rsa RawSockaddr) error {
|
||||||
|
ptr, l, err := rsa.Sockaddr()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return getpeername(s, ptr, &l)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Bind(s windows.Handle, rsa RawSockaddr) (err error) {
|
||||||
|
ptr, l, err := rsa.Sockaddr()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return bind(s, ptr, l)
|
||||||
|
}
|
||||||
|
|
||||||
|
// "golang.org/x/sys/windows".ConnectEx and .Bind only accept internal implementations of the
|
||||||
|
// their sockaddr interface, so they cannot be used with HvsockAddr
|
||||||
|
// Replicate functionality here from
|
||||||
|
// https://cs.opensource.google/go/x/sys/+/master:windows/syscall_windows.go
|
||||||
|
|
||||||
|
// The function pointers to `AcceptEx`, `ConnectEx` and `GetAcceptExSockaddrs` must be loaded at
|
||||||
|
// runtime via a WSAIoctl call:
|
||||||
|
// https://docs.microsoft.com/en-us/windows/win32/api/Mswsock/nc-mswsock-lpfn_connectex#remarks
|
||||||
|
|
||||||
|
type runtimeFunc struct {
|
||||||
|
id guid.GUID
|
||||||
|
once sync.Once
|
||||||
|
addr uintptr
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *runtimeFunc) Load() error {
|
||||||
|
f.once.Do(func() {
|
||||||
|
var s windows.Handle
|
||||||
|
s, f.err = windows.Socket(windows.AF_INET, windows.SOCK_STREAM, windows.IPPROTO_TCP)
|
||||||
|
if f.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer windows.CloseHandle(s) //nolint:errcheck
|
||||||
|
|
||||||
|
var n uint32
|
||||||
|
f.err = windows.WSAIoctl(s,
|
||||||
|
windows.SIO_GET_EXTENSION_FUNCTION_POINTER,
|
||||||
|
(*byte)(unsafe.Pointer(&f.id)),
|
||||||
|
uint32(unsafe.Sizeof(f.id)),
|
||||||
|
(*byte)(unsafe.Pointer(&f.addr)),
|
||||||
|
uint32(unsafe.Sizeof(f.addr)),
|
||||||
|
&n,
|
||||||
|
nil, // overlapped
|
||||||
|
0, // completionRoutine
|
||||||
|
)
|
||||||
|
})
|
||||||
|
return f.err
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// todo: add `AcceptEx` and `GetAcceptExSockaddrs`
|
||||||
|
WSAID_CONNECTEX = guid.GUID{ //revive:disable-line:var-naming ALL_CAPS
|
||||||
|
Data1: 0x25a207b9,
|
||||||
|
Data2: 0xddf3,
|
||||||
|
Data3: 0x4660,
|
||||||
|
Data4: [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e},
|
||||||
|
}
|
||||||
|
|
||||||
|
connectExFunc = runtimeFunc{id: WSAID_CONNECTEX}
|
||||||
|
)
|
||||||
|
|
||||||
|
func ConnectEx(
|
||||||
|
fd windows.Handle,
|
||||||
|
rsa RawSockaddr,
|
||||||
|
sendBuf *byte,
|
||||||
|
sendDataLen uint32,
|
||||||
|
bytesSent *uint32,
|
||||||
|
overlapped *windows.Overlapped,
|
||||||
|
) error {
|
||||||
|
if err := connectExFunc.Load(); err != nil {
|
||||||
|
return fmt.Errorf("failed to load ConnectEx function pointer: %w", err)
|
||||||
|
}
|
||||||
|
ptr, n, err := rsa.Sockaddr()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BOOL LpfnConnectex(
|
||||||
|
// [in] SOCKET s,
|
||||||
|
// [in] const sockaddr *name,
|
||||||
|
// [in] int namelen,
|
||||||
|
// [in, optional] PVOID lpSendBuffer,
|
||||||
|
// [in] DWORD dwSendDataLength,
|
||||||
|
// [out] LPDWORD lpdwBytesSent,
|
||||||
|
// [in] LPOVERLAPPED lpOverlapped
|
||||||
|
// )
|
||||||
|
|
||||||
|
func connectEx(
|
||||||
|
s windows.Handle,
|
||||||
|
name unsafe.Pointer,
|
||||||
|
namelen int32,
|
||||||
|
sendBuf *byte,
|
||||||
|
sendDataLen uint32,
|
||||||
|
bytesSent *uint32,
|
||||||
|
overlapped *windows.Overlapped,
|
||||||
|
) (err error) {
|
||||||
|
r1, _, e1 := syscall.SyscallN(connectExFunc.addr,
|
||||||
|
uintptr(s),
|
||||||
|
uintptr(name),
|
||||||
|
uintptr(namelen),
|
||||||
|
uintptr(unsafe.Pointer(sendBuf)),
|
||||||
|
uintptr(sendDataLen),
|
||||||
|
uintptr(unsafe.Pointer(bytesSent)),
|
||||||
|
uintptr(unsafe.Pointer(overlapped)),
|
||||||
|
)
|
||||||
|
|
||||||
|
if r1 == 0 {
|
||||||
|
if e1 != 0 {
|
||||||
|
err = error(e1)
|
||||||
|
} else {
|
||||||
|
err = syscall.EINVAL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
69
vendor/github.com/Microsoft/go-winio/internal/socket/zsyscall_windows.go
generated
vendored
Normal file
69
vendor/github.com/Microsoft/go-winio/internal/socket/zsyscall_windows.go
generated
vendored
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT.
|
||||||
|
|
||||||
|
package socket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ unsafe.Pointer
|
||||||
|
|
||||||
|
// Do the interface allocations only once for common
|
||||||
|
// Errno values.
|
||||||
|
const (
|
||||||
|
errnoERROR_IO_PENDING = 997
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||||
|
errERROR_EINVAL error = syscall.EINVAL
|
||||||
|
)
|
||||||
|
|
||||||
|
// errnoErr returns common boxed Errno values, to prevent
|
||||||
|
// allocations at runtime.
|
||||||
|
func errnoErr(e syscall.Errno) error {
|
||||||
|
switch e {
|
||||||
|
case 0:
|
||||||
|
return errERROR_EINVAL
|
||||||
|
case errnoERROR_IO_PENDING:
|
||||||
|
return errERROR_IO_PENDING
|
||||||
|
}
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
|
||||||
|
|
||||||
|
procbind = modws2_32.NewProc("bind")
|
||||||
|
procgetpeername = modws2_32.NewProc("getpeername")
|
||||||
|
procgetsockname = modws2_32.NewProc("getsockname")
|
||||||
|
)
|
||||||
|
|
||||||
|
func bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procbind.Addr(), uintptr(s), uintptr(name), uintptr(namelen))
|
||||||
|
if r1 == socketError {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procgetpeername.Addr(), uintptr(s), uintptr(name), uintptr(unsafe.Pointer(namelen)))
|
||||||
|
if r1 == socketError {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procgetsockname.Addr(), uintptr(s), uintptr(name), uintptr(unsafe.Pointer(namelen)))
|
||||||
|
if r1 == socketError {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
132
vendor/github.com/Microsoft/go-winio/internal/stringbuffer/wstring.go
generated
vendored
Normal file
132
vendor/github.com/Microsoft/go-winio/internal/stringbuffer/wstring.go
generated
vendored
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
package stringbuffer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"unicode/utf16"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO: worth exporting and using in mkwinsyscall?
|
||||||
|
|
||||||
|
// Uint16BufferSize is the buffer size in the pool, chosen somewhat arbitrarily to accommodate
|
||||||
|
// large path strings:
|
||||||
|
// MAX_PATH (260) + size of volume GUID prefix (49) + null terminator = 310.
|
||||||
|
const MinWStringCap = 310
|
||||||
|
|
||||||
|
// use *[]uint16 since []uint16 creates an extra allocation where the slice header
|
||||||
|
// is copied to heap and then referenced via pointer in the interface header that sync.Pool
|
||||||
|
// stores.
|
||||||
|
var pathPool = sync.Pool{ // if go1.18+ adds Pool[T], use that to store []uint16 directly
|
||||||
|
New: func() interface{} {
|
||||||
|
b := make([]uint16, MinWStringCap)
|
||||||
|
return &b
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBuffer() []uint16 { return *(pathPool.Get().(*[]uint16)) }
|
||||||
|
|
||||||
|
// freeBuffer copies the slice header data, and puts a pointer to that in the pool.
|
||||||
|
// This avoids taking a pointer to the slice header in WString, which can be set to nil.
|
||||||
|
func freeBuffer(b []uint16) { pathPool.Put(&b) }
|
||||||
|
|
||||||
|
// WString is a wide string buffer ([]uint16) meant for storing UTF-16 encoded strings
|
||||||
|
// for interacting with Win32 APIs.
|
||||||
|
// Sizes are specified as uint32 and not int.
|
||||||
|
//
|
||||||
|
// It is not thread safe.
|
||||||
|
type WString struct {
|
||||||
|
// type-def allows casting to []uint16 directly, use struct to prevent that and allow adding fields in the future.
|
||||||
|
|
||||||
|
// raw buffer
|
||||||
|
b []uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWString returns a [WString] allocated from a shared pool with an
|
||||||
|
// initial capacity of at least [MinWStringCap].
|
||||||
|
// Since the buffer may have been previously used, its contents are not guaranteed to be empty.
|
||||||
|
//
|
||||||
|
// The buffer should be freed via [WString.Free]
|
||||||
|
func NewWString() *WString {
|
||||||
|
return &WString{
|
||||||
|
b: newBuffer(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *WString) Free() {
|
||||||
|
if b.empty() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
freeBuffer(b.b)
|
||||||
|
b.b = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResizeTo grows the buffer to at least c and returns the new capacity, freeing the
|
||||||
|
// previous buffer back into pool.
|
||||||
|
func (b *WString) ResizeTo(c uint32) uint32 {
|
||||||
|
// already sufficient (or n is 0)
|
||||||
|
if c <= b.Cap() {
|
||||||
|
return b.Cap()
|
||||||
|
}
|
||||||
|
|
||||||
|
if c <= MinWStringCap {
|
||||||
|
c = MinWStringCap
|
||||||
|
}
|
||||||
|
// allocate at-least double buffer size, as is done in [bytes.Buffer] and other places
|
||||||
|
if c <= 2*b.Cap() {
|
||||||
|
c = 2 * b.Cap()
|
||||||
|
}
|
||||||
|
|
||||||
|
b2 := make([]uint16, c)
|
||||||
|
if !b.empty() {
|
||||||
|
copy(b2, b.b)
|
||||||
|
freeBuffer(b.b)
|
||||||
|
}
|
||||||
|
b.b = b2
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Buffer returns the underlying []uint16 buffer.
|
||||||
|
func (b *WString) Buffer() []uint16 {
|
||||||
|
if b.empty() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return b.b
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pointer returns a pointer to the first uint16 in the buffer.
|
||||||
|
// If the [WString.Free] has already been called, the pointer will be nil.
|
||||||
|
func (b *WString) Pointer() *uint16 {
|
||||||
|
if b.empty() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &b.b[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the returns the UTF-8 encoding of the UTF-16 string in the buffer.
|
||||||
|
//
|
||||||
|
// It assumes that the data is null-terminated.
|
||||||
|
func (b *WString) String() string {
|
||||||
|
// Using [windows.UTF16ToString] would require importing "golang.org/x/sys/windows"
|
||||||
|
// and would make this code Windows-only, which makes no sense.
|
||||||
|
// So copy UTF16ToString code into here.
|
||||||
|
// If other windows-specific code is added, switch to [windows.UTF16ToString]
|
||||||
|
|
||||||
|
s := b.b
|
||||||
|
for i, v := range s {
|
||||||
|
if v == 0 {
|
||||||
|
s = s[:i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return string(utf16.Decode(s))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cap returns the underlying buffer capacity.
|
||||||
|
func (b *WString) Cap() uint32 {
|
||||||
|
if b.empty() {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return b.cap()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *WString) cap() uint32 { return uint32(cap(b.b)) }
|
||||||
|
func (b *WString) empty() bool { return b == nil || b.cap() == 0 }
|
||||||
586
vendor/github.com/Microsoft/go-winio/pipe.go
generated
vendored
Normal file
586
vendor/github.com/Microsoft/go-winio/pipe.go
generated
vendored
Normal file
@@ -0,0 +1,586 @@
|
|||||||
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package winio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
|
"github.com/Microsoft/go-winio/internal/fs"
|
||||||
|
)
|
||||||
|
|
||||||
|
//sys connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) = ConnectNamedPipe
|
||||||
|
//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateNamedPipeW
|
||||||
|
//sys disconnectNamedPipe(pipe windows.Handle) (err error) = DisconnectNamedPipe
|
||||||
|
//sys getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo
|
||||||
|
//sys getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
|
||||||
|
//sys ntCreateNamedPipeFile(pipe *windows.Handle, access ntAccessMask, oa *objectAttributes, iosb *ioStatusBlock, share ntFileShareMode, disposition ntFileCreationDisposition, options ntFileOptions, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntStatus) = ntdll.NtCreateNamedPipeFile
|
||||||
|
//sys rtlNtStatusToDosError(status ntStatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb
|
||||||
|
//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntStatus) = ntdll.RtlDosPathNameToNtPathName_U
|
||||||
|
//sys rtlDefaultNpAcl(dacl *uintptr) (status ntStatus) = ntdll.RtlDefaultNpAcl
|
||||||
|
|
||||||
|
type PipeConn interface {
|
||||||
|
net.Conn
|
||||||
|
Disconnect() error
|
||||||
|
Flush() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// type aliases for mkwinsyscall code
|
||||||
|
type (
|
||||||
|
ntAccessMask = fs.AccessMask
|
||||||
|
ntFileShareMode = fs.FileShareMode
|
||||||
|
ntFileCreationDisposition = fs.NTFileCreationDisposition
|
||||||
|
ntFileOptions = fs.NTCreateOptions
|
||||||
|
)
|
||||||
|
|
||||||
|
type ioStatusBlock struct {
|
||||||
|
Status, Information uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
// typedef struct _OBJECT_ATTRIBUTES {
|
||||||
|
// ULONG Length;
|
||||||
|
// HANDLE RootDirectory;
|
||||||
|
// PUNICODE_STRING ObjectName;
|
||||||
|
// ULONG Attributes;
|
||||||
|
// PVOID SecurityDescriptor;
|
||||||
|
// PVOID SecurityQualityOfService;
|
||||||
|
// } OBJECT_ATTRIBUTES;
|
||||||
|
//
|
||||||
|
// https://learn.microsoft.com/en-us/windows/win32/api/ntdef/ns-ntdef-_object_attributes
|
||||||
|
type objectAttributes struct {
|
||||||
|
Length uintptr
|
||||||
|
RootDirectory uintptr
|
||||||
|
ObjectName *unicodeString
|
||||||
|
Attributes uintptr
|
||||||
|
SecurityDescriptor *securityDescriptor
|
||||||
|
SecurityQoS uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
type unicodeString struct {
|
||||||
|
Length uint16
|
||||||
|
MaximumLength uint16
|
||||||
|
Buffer uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
// typedef struct _SECURITY_DESCRIPTOR {
|
||||||
|
// BYTE Revision;
|
||||||
|
// BYTE Sbz1;
|
||||||
|
// SECURITY_DESCRIPTOR_CONTROL Control;
|
||||||
|
// PSID Owner;
|
||||||
|
// PSID Group;
|
||||||
|
// PACL Sacl;
|
||||||
|
// PACL Dacl;
|
||||||
|
// } SECURITY_DESCRIPTOR, *PISECURITY_DESCRIPTOR;
|
||||||
|
//
|
||||||
|
// https://learn.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-security_descriptor
|
||||||
|
type securityDescriptor struct {
|
||||||
|
Revision byte
|
||||||
|
Sbz1 byte
|
||||||
|
Control uint16
|
||||||
|
Owner uintptr
|
||||||
|
Group uintptr
|
||||||
|
Sacl uintptr //revive:disable-line:var-naming SACL, not Sacl
|
||||||
|
Dacl uintptr //revive:disable-line:var-naming DACL, not Dacl
|
||||||
|
}
|
||||||
|
|
||||||
|
type ntStatus int32
|
||||||
|
|
||||||
|
func (status ntStatus) Err() error {
|
||||||
|
if status >= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return rtlNtStatusToDosError(status)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrPipeListenerClosed is returned for pipe operations on listeners that have been closed.
|
||||||
|
ErrPipeListenerClosed = net.ErrClosed
|
||||||
|
|
||||||
|
errPipeWriteClosed = errors.New("pipe has been closed for write")
|
||||||
|
)
|
||||||
|
|
||||||
|
type win32Pipe struct {
|
||||||
|
*win32File
|
||||||
|
path string
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ PipeConn = (*win32Pipe)(nil)
|
||||||
|
|
||||||
|
type win32MessageBytePipe struct {
|
||||||
|
win32Pipe
|
||||||
|
writeClosed bool
|
||||||
|
readEOF bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type pipeAddress string
|
||||||
|
|
||||||
|
func (f *win32Pipe) LocalAddr() net.Addr {
|
||||||
|
return pipeAddress(f.path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *win32Pipe) RemoteAddr() net.Addr {
|
||||||
|
return pipeAddress(f.path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *win32Pipe) SetDeadline(t time.Time) error {
|
||||||
|
if err := f.SetReadDeadline(t); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return f.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *win32Pipe) Disconnect() error {
|
||||||
|
return disconnectNamedPipe(f.win32File.handle)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseWrite closes the write side of a message pipe in byte mode.
|
||||||
|
func (f *win32MessageBytePipe) CloseWrite() error {
|
||||||
|
if f.writeClosed {
|
||||||
|
return errPipeWriteClosed
|
||||||
|
}
|
||||||
|
err := f.win32File.Flush()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = f.win32File.Write(nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
f.writeClosed = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
|
||||||
|
// they are used to implement CloseWrite().
|
||||||
|
func (f *win32MessageBytePipe) Write(b []byte) (int, error) {
|
||||||
|
if f.writeClosed {
|
||||||
|
return 0, errPipeWriteClosed
|
||||||
|
}
|
||||||
|
if len(b) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return f.win32File.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message
|
||||||
|
// mode pipe will return io.EOF, as will all subsequent reads.
|
||||||
|
func (f *win32MessageBytePipe) Read(b []byte) (int, error) {
|
||||||
|
if f.readEOF {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
n, err := f.win32File.Read(b)
|
||||||
|
if err == io.EOF { //nolint:errorlint
|
||||||
|
// If this was the result of a zero-byte read, then
|
||||||
|
// it is possible that the read was due to a zero-size
|
||||||
|
// message. Since we are simulating CloseWrite with a
|
||||||
|
// zero-byte message, ensure that all future Read() calls
|
||||||
|
// also return EOF.
|
||||||
|
f.readEOF = true
|
||||||
|
} else if err == windows.ERROR_MORE_DATA { //nolint:errorlint // err is Errno
|
||||||
|
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode
|
||||||
|
// and the message still has more bytes. Treat this as a success, since
|
||||||
|
// this package presents all named pipes as byte streams.
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pipeAddress) Network() string {
|
||||||
|
return "pipe"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s pipeAddress) String() string {
|
||||||
|
return string(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout.
|
||||||
|
func tryDialPipe(ctx context.Context, path *string, access fs.AccessMask, impLevel PipeImpLevel) (windows.Handle, error) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return windows.Handle(0), ctx.Err()
|
||||||
|
default:
|
||||||
|
h, err := fs.CreateFile(*path,
|
||||||
|
access,
|
||||||
|
0, // mode
|
||||||
|
nil, // security attributes
|
||||||
|
fs.OPEN_EXISTING,
|
||||||
|
fs.FILE_FLAG_OVERLAPPED|fs.SECURITY_SQOS_PRESENT|fs.FileSQSFlag(impLevel),
|
||||||
|
0, // template file handle
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
if err != windows.ERROR_PIPE_BUSY { //nolint:errorlint // err is Errno
|
||||||
|
return h, &os.PathError{Err: err, Op: "open", Path: *path}
|
||||||
|
}
|
||||||
|
// Wait 10 msec and try again. This is a rather simplistic
|
||||||
|
// view, as we always try each 10 milliseconds.
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialPipe connects to a named pipe by path, timing out if the connection
|
||||||
|
// takes longer than the specified duration. If timeout is nil, then we use
|
||||||
|
// a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
|
||||||
|
func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
|
||||||
|
var absTimeout time.Time
|
||||||
|
if timeout != nil {
|
||||||
|
absTimeout = time.Now().Add(*timeout)
|
||||||
|
} else {
|
||||||
|
absTimeout = time.Now().Add(2 * time.Second)
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithDeadline(context.Background(), absTimeout)
|
||||||
|
defer cancel()
|
||||||
|
conn, err := DialPipeContext(ctx, path)
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return nil, ErrTimeout
|
||||||
|
}
|
||||||
|
return conn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
|
||||||
|
// cancellation or timeout.
|
||||||
|
func DialPipeContext(ctx context.Context, path string) (net.Conn, error) {
|
||||||
|
return DialPipeAccess(ctx, path, uint32(fs.GENERIC_READ|fs.GENERIC_WRITE))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PipeImpLevel is an enumeration of impersonation levels that may be set
|
||||||
|
// when calling DialPipeAccessImpersonation.
|
||||||
|
type PipeImpLevel uint32
|
||||||
|
|
||||||
|
const (
|
||||||
|
PipeImpLevelAnonymous = PipeImpLevel(fs.SECURITY_ANONYMOUS)
|
||||||
|
PipeImpLevelIdentification = PipeImpLevel(fs.SECURITY_IDENTIFICATION)
|
||||||
|
PipeImpLevelImpersonation = PipeImpLevel(fs.SECURITY_IMPERSONATION)
|
||||||
|
PipeImpLevelDelegation = PipeImpLevel(fs.SECURITY_DELEGATION)
|
||||||
|
)
|
||||||
|
|
||||||
|
// DialPipeAccess attempts to connect to a named pipe by `path` with `access` until `ctx`
|
||||||
|
// cancellation or timeout.
|
||||||
|
func DialPipeAccess(ctx context.Context, path string, access uint32) (net.Conn, error) {
|
||||||
|
return DialPipeAccessImpLevel(ctx, path, access, PipeImpLevelAnonymous)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialPipeAccessImpLevel attempts to connect to a named pipe by `path` with
|
||||||
|
// `access` at `impLevel` until `ctx` cancellation or timeout. The other
|
||||||
|
// DialPipe* implementations use PipeImpLevelAnonymous.
|
||||||
|
func DialPipeAccessImpLevel(ctx context.Context, path string, access uint32, impLevel PipeImpLevel) (net.Conn, error) {
|
||||||
|
var err error
|
||||||
|
var h windows.Handle
|
||||||
|
h, err = tryDialPipe(ctx, &path, fs.AccessMask(access), impLevel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var flags uint32
|
||||||
|
err = getNamedPipeInfo(h, &flags, nil, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := makeWin32File(h)
|
||||||
|
if err != nil {
|
||||||
|
windows.Close(h)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the pipe is in message mode, return a message byte pipe, which
|
||||||
|
// supports CloseWrite().
|
||||||
|
if flags&windows.PIPE_TYPE_MESSAGE != 0 {
|
||||||
|
return &win32MessageBytePipe{
|
||||||
|
win32Pipe: win32Pipe{win32File: f, path: path},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return &win32Pipe{win32File: f, path: path}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type acceptResponse struct {
|
||||||
|
f *win32File
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type win32PipeListener struct {
|
||||||
|
firstHandle windows.Handle
|
||||||
|
path string
|
||||||
|
config PipeConfig
|
||||||
|
acceptCh chan (chan acceptResponse)
|
||||||
|
closeCh chan int
|
||||||
|
doneCh chan int
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (windows.Handle, error) {
|
||||||
|
path16, err := windows.UTF16FromString(path)
|
||||||
|
if err != nil {
|
||||||
|
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
var oa objectAttributes
|
||||||
|
oa.Length = unsafe.Sizeof(oa)
|
||||||
|
|
||||||
|
var ntPath unicodeString
|
||||||
|
if err := rtlDosPathNameToNtPathName(&path16[0],
|
||||||
|
&ntPath,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
).Err(); err != nil {
|
||||||
|
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||||
|
}
|
||||||
|
defer windows.LocalFree(windows.Handle(ntPath.Buffer)) //nolint:errcheck
|
||||||
|
oa.ObjectName = &ntPath
|
||||||
|
oa.Attributes = windows.OBJ_CASE_INSENSITIVE
|
||||||
|
|
||||||
|
// The security descriptor is only needed for the first pipe.
|
||||||
|
if first {
|
||||||
|
if sd != nil {
|
||||||
|
//todo: does `sdb` need to be allocated on the heap, or can go allocate it?
|
||||||
|
l := uint32(len(sd))
|
||||||
|
sdb, err := windows.LocalAlloc(0, l)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("LocalAlloc for security descriptor with of length %d: %w", l, err)
|
||||||
|
}
|
||||||
|
defer windows.LocalFree(windows.Handle(sdb)) //nolint:errcheck
|
||||||
|
copy((*[0xffff]byte)(unsafe.Pointer(sdb))[:], sd)
|
||||||
|
oa.SecurityDescriptor = (*securityDescriptor)(unsafe.Pointer(sdb))
|
||||||
|
} else {
|
||||||
|
// Construct the default named pipe security descriptor.
|
||||||
|
var dacl uintptr
|
||||||
|
if err := rtlDefaultNpAcl(&dacl).Err(); err != nil {
|
||||||
|
return 0, fmt.Errorf("getting default named pipe ACL: %w", err)
|
||||||
|
}
|
||||||
|
defer windows.LocalFree(windows.Handle(dacl)) //nolint:errcheck
|
||||||
|
|
||||||
|
sdb := &securityDescriptor{
|
||||||
|
Revision: 1,
|
||||||
|
Control: windows.SE_DACL_PRESENT,
|
||||||
|
Dacl: dacl,
|
||||||
|
}
|
||||||
|
oa.SecurityDescriptor = sdb
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
typ := uint32(windows.FILE_PIPE_REJECT_REMOTE_CLIENTS)
|
||||||
|
if c.MessageMode {
|
||||||
|
typ |= windows.FILE_PIPE_MESSAGE_TYPE
|
||||||
|
}
|
||||||
|
|
||||||
|
disposition := fs.FILE_OPEN
|
||||||
|
access := fs.GENERIC_READ | fs.GENERIC_WRITE | fs.SYNCHRONIZE
|
||||||
|
if first {
|
||||||
|
disposition = fs.FILE_CREATE
|
||||||
|
// By not asking for read or write access, the named pipe file system
|
||||||
|
// will put this pipe into an initially disconnected state, blocking
|
||||||
|
// client connections until the next call with first == false.
|
||||||
|
access = fs.SYNCHRONIZE
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := int64(-50 * 10000) // 50ms
|
||||||
|
|
||||||
|
var (
|
||||||
|
h windows.Handle
|
||||||
|
iosb ioStatusBlock
|
||||||
|
)
|
||||||
|
err = ntCreateNamedPipeFile(&h,
|
||||||
|
access,
|
||||||
|
&oa,
|
||||||
|
&iosb,
|
||||||
|
fs.FILE_SHARE_READ|fs.FILE_SHARE_WRITE,
|
||||||
|
disposition,
|
||||||
|
0,
|
||||||
|
typ,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0xffffffff,
|
||||||
|
uint32(c.InputBufferSize),
|
||||||
|
uint32(c.OutputBufferSize),
|
||||||
|
&timeout).Err()
|
||||||
|
if err != nil {
|
||||||
|
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
runtime.KeepAlive(ntPath)
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *win32PipeListener) makeServerPipe() (*win32File, error) {
|
||||||
|
h, err := makeServerPipeHandle(l.path, nil, &l.config, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
f, err := makeWin32File(h)
|
||||||
|
if err != nil {
|
||||||
|
windows.Close(h)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *win32PipeListener) makeConnectedServerPipe() (*win32File, error) {
|
||||||
|
p, err := l.makeServerPipe()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the client to connect.
|
||||||
|
ch := make(chan error)
|
||||||
|
go func(p *win32File) {
|
||||||
|
ch <- connectPipe(p)
|
||||||
|
}(p)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err = <-ch:
|
||||||
|
if err != nil {
|
||||||
|
p.Close()
|
||||||
|
p = nil
|
||||||
|
}
|
||||||
|
case <-l.closeCh:
|
||||||
|
// Abort the connect request by closing the handle.
|
||||||
|
p.Close()
|
||||||
|
p = nil
|
||||||
|
err = <-ch
|
||||||
|
if err == nil || err == ErrFileClosed { //nolint:errorlint // err is Errno
|
||||||
|
err = ErrPipeListenerClosed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *win32PipeListener) listenerRoutine() {
|
||||||
|
closed := false
|
||||||
|
for !closed {
|
||||||
|
select {
|
||||||
|
case <-l.closeCh:
|
||||||
|
closed = true
|
||||||
|
case responseCh := <-l.acceptCh:
|
||||||
|
var (
|
||||||
|
p *win32File
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
for {
|
||||||
|
p, err = l.makeConnectedServerPipe()
|
||||||
|
// If the connection was immediately closed by the client, try
|
||||||
|
// again.
|
||||||
|
if err != windows.ERROR_NO_DATA { //nolint:errorlint // err is Errno
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
responseCh <- acceptResponse{p, err}
|
||||||
|
closed = err == ErrPipeListenerClosed //nolint:errorlint // err is Errno
|
||||||
|
}
|
||||||
|
}
|
||||||
|
windows.Close(l.firstHandle)
|
||||||
|
l.firstHandle = 0
|
||||||
|
// Notify Close() and Accept() callers that the handle has been closed.
|
||||||
|
close(l.doneCh)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PipeConfig contain configuration for the pipe listener.
|
||||||
|
type PipeConfig struct {
|
||||||
|
// SecurityDescriptor contains a Windows security descriptor in SDDL format.
|
||||||
|
SecurityDescriptor string
|
||||||
|
|
||||||
|
// MessageMode determines whether the pipe is in byte or message mode. In either
|
||||||
|
// case the pipe is read in byte mode by default. The only practical difference in
|
||||||
|
// this implementation is that CloseWrite() is only supported for message mode pipes;
|
||||||
|
// CloseWrite() is implemented as a zero-byte write, but zero-byte writes are only
|
||||||
|
// transferred to the reader (and returned as io.EOF in this implementation)
|
||||||
|
// when the pipe is in message mode.
|
||||||
|
MessageMode bool
|
||||||
|
|
||||||
|
// InputBufferSize specifies the size of the input buffer, in bytes.
|
||||||
|
InputBufferSize int32
|
||||||
|
|
||||||
|
// OutputBufferSize specifies the size of the output buffer, in bytes.
|
||||||
|
OutputBufferSize int32
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe.
|
||||||
|
// The pipe must not already exist.
|
||||||
|
func ListenPipe(path string, c *PipeConfig) (net.Listener, error) {
|
||||||
|
var (
|
||||||
|
sd []byte
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if c == nil {
|
||||||
|
c = &PipeConfig{}
|
||||||
|
}
|
||||||
|
if c.SecurityDescriptor != "" {
|
||||||
|
sd, err = SddlToSecurityDescriptor(c.SecurityDescriptor)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h, err := makeServerPipeHandle(path, sd, c, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
l := &win32PipeListener{
|
||||||
|
firstHandle: h,
|
||||||
|
path: path,
|
||||||
|
config: *c,
|
||||||
|
acceptCh: make(chan (chan acceptResponse)),
|
||||||
|
closeCh: make(chan int),
|
||||||
|
doneCh: make(chan int),
|
||||||
|
}
|
||||||
|
go l.listenerRoutine()
|
||||||
|
return l, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func connectPipe(p *win32File) error {
|
||||||
|
c, err := p.prepareIO()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer p.wg.Done()
|
||||||
|
|
||||||
|
err = connectNamedPipe(p.handle, &c.o)
|
||||||
|
_, err = p.asyncIO(c, nil, 0, err)
|
||||||
|
if err != nil && err != windows.ERROR_PIPE_CONNECTED { //nolint:errorlint // err is Errno
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *win32PipeListener) Accept() (net.Conn, error) {
|
||||||
|
ch := make(chan acceptResponse)
|
||||||
|
select {
|
||||||
|
case l.acceptCh <- ch:
|
||||||
|
response := <-ch
|
||||||
|
err := response.err
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if l.config.MessageMode {
|
||||||
|
return &win32MessageBytePipe{
|
||||||
|
win32Pipe: win32Pipe{win32File: response.f, path: l.path},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return &win32Pipe{win32File: response.f, path: l.path}, nil
|
||||||
|
case <-l.doneCh:
|
||||||
|
return nil, ErrPipeListenerClosed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *win32PipeListener) Close() error {
|
||||||
|
select {
|
||||||
|
case l.closeCh <- 1:
|
||||||
|
<-l.doneCh
|
||||||
|
case <-l.doneCh:
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *win32PipeListener) Addr() net.Addr {
|
||||||
|
return pipeAddress(l.path)
|
||||||
|
}
|
||||||
232
vendor/github.com/Microsoft/go-winio/pkg/guid/guid.go
generated
vendored
Normal file
232
vendor/github.com/Microsoft/go-winio/pkg/guid/guid.go
generated
vendored
Normal file
@@ -0,0 +1,232 @@
|
|||||||
|
// Package guid provides a GUID type. The backing structure for a GUID is
|
||||||
|
// identical to that used by the golang.org/x/sys/windows GUID type.
|
||||||
|
// There are two main binary encodings used for a GUID, the big-endian encoding,
|
||||||
|
// and the Windows (mixed-endian) encoding. See here for details:
|
||||||
|
// https://en.wikipedia.org/wiki/Universally_unique_identifier#Encoding
|
||||||
|
package guid
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha1" //nolint:gosec // not used for secure application
|
||||||
|
"encoding"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:generate go run golang.org/x/tools/cmd/stringer -type=Variant -trimprefix=Variant -linecomment
|
||||||
|
|
||||||
|
// Variant specifies which GUID variant (or "type") of the GUID. It determines
|
||||||
|
// how the entirety of the rest of the GUID is interpreted.
|
||||||
|
type Variant uint8
|
||||||
|
|
||||||
|
// The variants specified by RFC 4122 section 4.1.1.
|
||||||
|
const (
|
||||||
|
// VariantUnknown specifies a GUID variant which does not conform to one of
|
||||||
|
// the variant encodings specified in RFC 4122.
|
||||||
|
VariantUnknown Variant = iota
|
||||||
|
VariantNCS
|
||||||
|
VariantRFC4122 // RFC 4122
|
||||||
|
VariantMicrosoft
|
||||||
|
VariantFuture
|
||||||
|
)
|
||||||
|
|
||||||
|
// Version specifies how the bits in the GUID were generated. For instance, a
|
||||||
|
// version 4 GUID is randomly generated, and a version 5 is generated from the
|
||||||
|
// hash of an input string.
|
||||||
|
type Version uint8
|
||||||
|
|
||||||
|
func (v Version) String() string {
|
||||||
|
return strconv.FormatUint(uint64(v), 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ = (encoding.TextMarshaler)(GUID{})
|
||||||
|
var _ = (encoding.TextUnmarshaler)(&GUID{})
|
||||||
|
|
||||||
|
// NewV4 returns a new version 4 (pseudorandom) GUID, as defined by RFC 4122.
|
||||||
|
func NewV4() (GUID, error) {
|
||||||
|
var b [16]byte
|
||||||
|
if _, err := rand.Read(b[:]); err != nil {
|
||||||
|
return GUID{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
g := FromArray(b)
|
||||||
|
g.setVersion(4) // Version 4 means randomly generated.
|
||||||
|
g.setVariant(VariantRFC4122)
|
||||||
|
|
||||||
|
return g, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewV5 returns a new version 5 (generated from a string via SHA-1 hashing)
|
||||||
|
// GUID, as defined by RFC 4122. The RFC is unclear on the encoding of the name,
|
||||||
|
// and the sample code treats it as a series of bytes, so we do the same here.
|
||||||
|
//
|
||||||
|
// Some implementations, such as those found on Windows, treat the name as a
|
||||||
|
// big-endian UTF16 stream of bytes. If that is desired, the string can be
|
||||||
|
// encoded as such before being passed to this function.
|
||||||
|
func NewV5(namespace GUID, name []byte) (GUID, error) {
|
||||||
|
b := sha1.New() //nolint:gosec // not used for secure application
|
||||||
|
namespaceBytes := namespace.ToArray()
|
||||||
|
b.Write(namespaceBytes[:])
|
||||||
|
b.Write(name)
|
||||||
|
|
||||||
|
a := [16]byte{}
|
||||||
|
copy(a[:], b.Sum(nil))
|
||||||
|
|
||||||
|
g := FromArray(a)
|
||||||
|
g.setVersion(5) // Version 5 means generated from a string.
|
||||||
|
g.setVariant(VariantRFC4122)
|
||||||
|
|
||||||
|
return g, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func fromArray(b [16]byte, order binary.ByteOrder) GUID {
|
||||||
|
var g GUID
|
||||||
|
g.Data1 = order.Uint32(b[0:4])
|
||||||
|
g.Data2 = order.Uint16(b[4:6])
|
||||||
|
g.Data3 = order.Uint16(b[6:8])
|
||||||
|
copy(g.Data4[:], b[8:16])
|
||||||
|
return g
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g GUID) toArray(order binary.ByteOrder) [16]byte {
|
||||||
|
b := [16]byte{}
|
||||||
|
order.PutUint32(b[0:4], g.Data1)
|
||||||
|
order.PutUint16(b[4:6], g.Data2)
|
||||||
|
order.PutUint16(b[6:8], g.Data3)
|
||||||
|
copy(b[8:16], g.Data4[:])
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromArray constructs a GUID from a big-endian encoding array of 16 bytes.
|
||||||
|
func FromArray(b [16]byte) GUID {
|
||||||
|
return fromArray(b, binary.BigEndian)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToArray returns an array of 16 bytes representing the GUID in big-endian
|
||||||
|
// encoding.
|
||||||
|
func (g GUID) ToArray() [16]byte {
|
||||||
|
return g.toArray(binary.BigEndian)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromWindowsArray constructs a GUID from a Windows encoding array of bytes.
|
||||||
|
func FromWindowsArray(b [16]byte) GUID {
|
||||||
|
return fromArray(b, binary.LittleEndian)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToWindowsArray returns an array of 16 bytes representing the GUID in Windows
|
||||||
|
// encoding.
|
||||||
|
func (g GUID) ToWindowsArray() [16]byte {
|
||||||
|
return g.toArray(binary.LittleEndian)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g GUID) String() string {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"%08x-%04x-%04x-%04x-%012x",
|
||||||
|
g.Data1,
|
||||||
|
g.Data2,
|
||||||
|
g.Data3,
|
||||||
|
g.Data4[:2],
|
||||||
|
g.Data4[2:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromString parses a string containing a GUID and returns the GUID. The only
|
||||||
|
// format currently supported is the `xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx`
|
||||||
|
// format.
|
||||||
|
func FromString(s string) (GUID, error) {
|
||||||
|
if len(s) != 36 {
|
||||||
|
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||||
|
}
|
||||||
|
if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' {
|
||||||
|
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
var g GUID
|
||||||
|
|
||||||
|
data1, err := strconv.ParseUint(s[0:8], 16, 32)
|
||||||
|
if err != nil {
|
||||||
|
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||||
|
}
|
||||||
|
g.Data1 = uint32(data1)
|
||||||
|
|
||||||
|
data2, err := strconv.ParseUint(s[9:13], 16, 16)
|
||||||
|
if err != nil {
|
||||||
|
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||||
|
}
|
||||||
|
g.Data2 = uint16(data2)
|
||||||
|
|
||||||
|
data3, err := strconv.ParseUint(s[14:18], 16, 16)
|
||||||
|
if err != nil {
|
||||||
|
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||||
|
}
|
||||||
|
g.Data3 = uint16(data3)
|
||||||
|
|
||||||
|
for i, x := range []int{19, 21, 24, 26, 28, 30, 32, 34} {
|
||||||
|
v, err := strconv.ParseUint(s[x:x+2], 16, 8)
|
||||||
|
if err != nil {
|
||||||
|
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||||
|
}
|
||||||
|
g.Data4[i] = uint8(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
return g, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GUID) setVariant(v Variant) {
|
||||||
|
d := g.Data4[0]
|
||||||
|
switch v {
|
||||||
|
case VariantNCS:
|
||||||
|
d = (d & 0x7f)
|
||||||
|
case VariantRFC4122:
|
||||||
|
d = (d & 0x3f) | 0x80
|
||||||
|
case VariantMicrosoft:
|
||||||
|
d = (d & 0x1f) | 0xc0
|
||||||
|
case VariantFuture:
|
||||||
|
d = (d & 0x0f) | 0xe0
|
||||||
|
case VariantUnknown:
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("invalid variant: %d", v))
|
||||||
|
}
|
||||||
|
g.Data4[0] = d
|
||||||
|
}
|
||||||
|
|
||||||
|
// Variant returns the GUID variant, as defined in RFC 4122.
|
||||||
|
func (g GUID) Variant() Variant {
|
||||||
|
b := g.Data4[0]
|
||||||
|
if b&0x80 == 0 {
|
||||||
|
return VariantNCS
|
||||||
|
} else if b&0xc0 == 0x80 {
|
||||||
|
return VariantRFC4122
|
||||||
|
} else if b&0xe0 == 0xc0 {
|
||||||
|
return VariantMicrosoft
|
||||||
|
} else if b&0xe0 == 0xe0 {
|
||||||
|
return VariantFuture
|
||||||
|
}
|
||||||
|
return VariantUnknown
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GUID) setVersion(v Version) {
|
||||||
|
g.Data3 = (g.Data3 & 0x0fff) | (uint16(v) << 12)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Version returns the GUID version, as defined in RFC 4122.
|
||||||
|
func (g GUID) Version() Version {
|
||||||
|
return Version((g.Data3 & 0xF000) >> 12)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalText returns the textual representation of the GUID.
|
||||||
|
func (g GUID) MarshalText() ([]byte, error) {
|
||||||
|
return []byte(g.String()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalText takes the textual representation of a GUID, and unmarhals it
|
||||||
|
// into this GUID.
|
||||||
|
func (g *GUID) UnmarshalText(text []byte) error {
|
||||||
|
g2, err := FromString(string(text))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*g = g2
|
||||||
|
return nil
|
||||||
|
}
|
||||||
16
vendor/github.com/Microsoft/go-winio/pkg/guid/guid_nonwindows.go
generated
vendored
Normal file
16
vendor/github.com/Microsoft/go-winio/pkg/guid/guid_nonwindows.go
generated
vendored
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
//go:build !windows
|
||||||
|
// +build !windows
|
||||||
|
|
||||||
|
package guid
|
||||||
|
|
||||||
|
// GUID represents a GUID/UUID. It has the same structure as
|
||||||
|
// golang.org/x/sys/windows.GUID so that it can be used with functions expecting
|
||||||
|
// that type. It is defined as its own type as that is only available to builds
|
||||||
|
// targeted at `windows`. The representation matches that used by native Windows
|
||||||
|
// code.
|
||||||
|
type GUID struct {
|
||||||
|
Data1 uint32
|
||||||
|
Data2 uint16
|
||||||
|
Data3 uint16
|
||||||
|
Data4 [8]byte
|
||||||
|
}
|
||||||
13
vendor/github.com/Microsoft/go-winio/pkg/guid/guid_windows.go
generated
vendored
Normal file
13
vendor/github.com/Microsoft/go-winio/pkg/guid/guid_windows.go
generated
vendored
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package guid
|
||||||
|
|
||||||
|
import "golang.org/x/sys/windows"
|
||||||
|
|
||||||
|
// GUID represents a GUID/UUID. It has the same structure as
|
||||||
|
// golang.org/x/sys/windows.GUID so that it can be used with functions expecting
|
||||||
|
// that type. It is defined as its own type so that stringification and
|
||||||
|
// marshaling can be supported. The representation matches that used by native
|
||||||
|
// Windows code.
|
||||||
|
type GUID windows.GUID
|
||||||
27
vendor/github.com/Microsoft/go-winio/pkg/guid/variant_string.go
generated
vendored
Normal file
27
vendor/github.com/Microsoft/go-winio/pkg/guid/variant_string.go
generated
vendored
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
// Code generated by "stringer -type=Variant -trimprefix=Variant -linecomment"; DO NOT EDIT.
|
||||||
|
|
||||||
|
package guid
|
||||||
|
|
||||||
|
import "strconv"
|
||||||
|
|
||||||
|
func _() {
|
||||||
|
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||||
|
// Re-run the stringer command to generate them again.
|
||||||
|
var x [1]struct{}
|
||||||
|
_ = x[VariantUnknown-0]
|
||||||
|
_ = x[VariantNCS-1]
|
||||||
|
_ = x[VariantRFC4122-2]
|
||||||
|
_ = x[VariantMicrosoft-3]
|
||||||
|
_ = x[VariantFuture-4]
|
||||||
|
}
|
||||||
|
|
||||||
|
const _Variant_name = "UnknownNCSRFC 4122MicrosoftFuture"
|
||||||
|
|
||||||
|
var _Variant_index = [...]uint8{0, 7, 10, 18, 27, 33}
|
||||||
|
|
||||||
|
func (i Variant) String() string {
|
||||||
|
if i >= Variant(len(_Variant_index)-1) {
|
||||||
|
return "Variant(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||||
|
}
|
||||||
|
return _Variant_name[_Variant_index[i]:_Variant_index[i+1]]
|
||||||
|
}
|
||||||
196
vendor/github.com/Microsoft/go-winio/privilege.go
generated
vendored
Normal file
196
vendor/github.com/Microsoft/go-winio/privilege.go
generated
vendored
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package winio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"unicode/utf16"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
//sys adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) [true] = advapi32.AdjustTokenPrivileges
|
||||||
|
//sys impersonateSelf(level uint32) (err error) = advapi32.ImpersonateSelf
|
||||||
|
//sys revertToSelf() (err error) = advapi32.RevertToSelf
|
||||||
|
//sys openThreadToken(thread windows.Handle, accessMask uint32, openAsSelf bool, token *windows.Token) (err error) = advapi32.OpenThreadToken
|
||||||
|
//sys getCurrentThread() (h windows.Handle) = GetCurrentThread
|
||||||
|
//sys lookupPrivilegeValue(systemName string, name string, luid *uint64) (err error) = advapi32.LookupPrivilegeValueW
|
||||||
|
//sys lookupPrivilegeName(systemName string, luid *uint64, buffer *uint16, size *uint32) (err error) = advapi32.LookupPrivilegeNameW
|
||||||
|
//sys lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) = advapi32.LookupPrivilegeDisplayNameW
|
||||||
|
|
||||||
|
const (
|
||||||
|
//revive:disable-next-line:var-naming ALL_CAPS
|
||||||
|
SE_PRIVILEGE_ENABLED = windows.SE_PRIVILEGE_ENABLED
|
||||||
|
|
||||||
|
//revive:disable-next-line:var-naming ALL_CAPS
|
||||||
|
ERROR_NOT_ALL_ASSIGNED windows.Errno = windows.ERROR_NOT_ALL_ASSIGNED
|
||||||
|
|
||||||
|
SeBackupPrivilege = "SeBackupPrivilege"
|
||||||
|
SeRestorePrivilege = "SeRestorePrivilege"
|
||||||
|
SeSecurityPrivilege = "SeSecurityPrivilege"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
privNames = make(map[string]uint64)
|
||||||
|
privNameMutex sync.Mutex
|
||||||
|
)
|
||||||
|
|
||||||
|
// PrivilegeError represents an error enabling privileges.
|
||||||
|
type PrivilegeError struct {
|
||||||
|
privileges []uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *PrivilegeError) Error() string {
|
||||||
|
s := "Could not enable privilege "
|
||||||
|
if len(e.privileges) > 1 {
|
||||||
|
s = "Could not enable privileges "
|
||||||
|
}
|
||||||
|
for i, p := range e.privileges {
|
||||||
|
if i != 0 {
|
||||||
|
s += ", "
|
||||||
|
}
|
||||||
|
s += `"`
|
||||||
|
s += getPrivilegeName(p)
|
||||||
|
s += `"`
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunWithPrivilege enables a single privilege for a function call.
|
||||||
|
func RunWithPrivilege(name string, fn func() error) error {
|
||||||
|
return RunWithPrivileges([]string{name}, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunWithPrivileges enables privileges for a function call.
|
||||||
|
func RunWithPrivileges(names []string, fn func() error) error {
|
||||||
|
privileges, err := mapPrivileges(names)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
runtime.LockOSThread()
|
||||||
|
defer runtime.UnlockOSThread()
|
||||||
|
token, err := newThreadToken()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer releaseThreadToken(token)
|
||||||
|
err = adjustPrivileges(token, privileges, SE_PRIVILEGE_ENABLED)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fn()
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapPrivileges(names []string) ([]uint64, error) {
|
||||||
|
privileges := make([]uint64, 0, len(names))
|
||||||
|
privNameMutex.Lock()
|
||||||
|
defer privNameMutex.Unlock()
|
||||||
|
for _, name := range names {
|
||||||
|
p, ok := privNames[name]
|
||||||
|
if !ok {
|
||||||
|
err := lookupPrivilegeValue("", name, &p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
privNames[name] = p
|
||||||
|
}
|
||||||
|
privileges = append(privileges, p)
|
||||||
|
}
|
||||||
|
return privileges, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnableProcessPrivileges enables privileges globally for the process.
|
||||||
|
func EnableProcessPrivileges(names []string) error {
|
||||||
|
return enableDisableProcessPrivilege(names, SE_PRIVILEGE_ENABLED)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableProcessPrivileges disables privileges globally for the process.
|
||||||
|
func DisableProcessPrivileges(names []string) error {
|
||||||
|
return enableDisableProcessPrivilege(names, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func enableDisableProcessPrivilege(names []string, action uint32) error {
|
||||||
|
privileges, err := mapPrivileges(names)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
p := windows.CurrentProcess()
|
||||||
|
var token windows.Token
|
||||||
|
err = windows.OpenProcessToken(p, windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, &token)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer token.Close()
|
||||||
|
return adjustPrivileges(token, privileges, action)
|
||||||
|
}
|
||||||
|
|
||||||
|
func adjustPrivileges(token windows.Token, privileges []uint64, action uint32) error {
|
||||||
|
var b bytes.Buffer
|
||||||
|
_ = binary.Write(&b, binary.LittleEndian, uint32(len(privileges)))
|
||||||
|
for _, p := range privileges {
|
||||||
|
_ = binary.Write(&b, binary.LittleEndian, p)
|
||||||
|
_ = binary.Write(&b, binary.LittleEndian, action)
|
||||||
|
}
|
||||||
|
prevState := make([]byte, b.Len())
|
||||||
|
reqSize := uint32(0)
|
||||||
|
success, err := adjustTokenPrivileges(token, false, &b.Bytes()[0], uint32(len(prevState)), &prevState[0], &reqSize)
|
||||||
|
if !success {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err == ERROR_NOT_ALL_ASSIGNED { //nolint:errorlint // err is Errno
|
||||||
|
return &PrivilegeError{privileges}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPrivilegeName(luid uint64) string {
|
||||||
|
var nameBuffer [256]uint16
|
||||||
|
bufSize := uint32(len(nameBuffer))
|
||||||
|
err := lookupPrivilegeName("", &luid, &nameBuffer[0], &bufSize)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Sprintf("<unknown privilege %d>", luid)
|
||||||
|
}
|
||||||
|
|
||||||
|
var displayNameBuffer [256]uint16
|
||||||
|
displayBufSize := uint32(len(displayNameBuffer))
|
||||||
|
var langID uint32
|
||||||
|
err = lookupPrivilegeDisplayName("", &nameBuffer[0], &displayNameBuffer[0], &displayBufSize, &langID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Sprintf("<unknown privilege %s>", string(utf16.Decode(nameBuffer[:bufSize])))
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(utf16.Decode(displayNameBuffer[:displayBufSize]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func newThreadToken() (windows.Token, error) {
|
||||||
|
err := impersonateSelf(windows.SecurityImpersonation)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var token windows.Token
|
||||||
|
err = openThreadToken(getCurrentThread(), windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, false, &token)
|
||||||
|
if err != nil {
|
||||||
|
rerr := revertToSelf()
|
||||||
|
if rerr != nil {
|
||||||
|
panic(rerr)
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func releaseThreadToken(h windows.Token) {
|
||||||
|
err := revertToSelf()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
h.Close()
|
||||||
|
}
|
||||||
131
vendor/github.com/Microsoft/go-winio/reparse.go
generated
vendored
Normal file
131
vendor/github.com/Microsoft/go-winio/reparse.go
generated
vendored
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package winio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"unicode/utf16"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
reparseTagMountPoint = 0xA0000003
|
||||||
|
reparseTagSymlink = 0xA000000C
|
||||||
|
)
|
||||||
|
|
||||||
|
type reparseDataBuffer struct {
|
||||||
|
ReparseTag uint32
|
||||||
|
ReparseDataLength uint16
|
||||||
|
Reserved uint16
|
||||||
|
SubstituteNameOffset uint16
|
||||||
|
SubstituteNameLength uint16
|
||||||
|
PrintNameOffset uint16
|
||||||
|
PrintNameLength uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReparsePoint describes a Win32 symlink or mount point.
|
||||||
|
type ReparsePoint struct {
|
||||||
|
Target string
|
||||||
|
IsMountPoint bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnsupportedReparsePointError is returned when trying to decode a non-symlink or
|
||||||
|
// mount point reparse point.
|
||||||
|
type UnsupportedReparsePointError struct {
|
||||||
|
Tag uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *UnsupportedReparsePointError) Error() string {
|
||||||
|
return fmt.Sprintf("unsupported reparse point %x", e.Tag)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeReparsePoint decodes a Win32 REPARSE_DATA_BUFFER structure containing either a symlink
|
||||||
|
// or a mount point.
|
||||||
|
func DecodeReparsePoint(b []byte) (*ReparsePoint, error) {
|
||||||
|
tag := binary.LittleEndian.Uint32(b[0:4])
|
||||||
|
return DecodeReparsePointData(tag, b[8:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func DecodeReparsePointData(tag uint32, b []byte) (*ReparsePoint, error) {
|
||||||
|
isMountPoint := false
|
||||||
|
switch tag {
|
||||||
|
case reparseTagMountPoint:
|
||||||
|
isMountPoint = true
|
||||||
|
case reparseTagSymlink:
|
||||||
|
default:
|
||||||
|
return nil, &UnsupportedReparsePointError{tag}
|
||||||
|
}
|
||||||
|
nameOffset := 8 + binary.LittleEndian.Uint16(b[4:6])
|
||||||
|
if !isMountPoint {
|
||||||
|
nameOffset += 4
|
||||||
|
}
|
||||||
|
nameLength := binary.LittleEndian.Uint16(b[6:8])
|
||||||
|
name := make([]uint16, nameLength/2)
|
||||||
|
err := binary.Read(bytes.NewReader(b[nameOffset:nameOffset+nameLength]), binary.LittleEndian, &name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &ReparsePoint{string(utf16.Decode(name)), isMountPoint}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isDriveLetter(c byte) bool {
|
||||||
|
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeReparsePoint encodes a Win32 REPARSE_DATA_BUFFER structure describing a symlink or
|
||||||
|
// mount point.
|
||||||
|
func EncodeReparsePoint(rp *ReparsePoint) []byte {
|
||||||
|
// Generate an NT path and determine if this is a relative path.
|
||||||
|
var ntTarget string
|
||||||
|
relative := false
|
||||||
|
if strings.HasPrefix(rp.Target, `\\?\`) {
|
||||||
|
ntTarget = `\??\` + rp.Target[4:]
|
||||||
|
} else if strings.HasPrefix(rp.Target, `\\`) {
|
||||||
|
ntTarget = `\??\UNC\` + rp.Target[2:]
|
||||||
|
} else if len(rp.Target) >= 2 && isDriveLetter(rp.Target[0]) && rp.Target[1] == ':' {
|
||||||
|
ntTarget = `\??\` + rp.Target
|
||||||
|
} else {
|
||||||
|
ntTarget = rp.Target
|
||||||
|
relative = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// The paths must be NUL-terminated even though they are counted strings.
|
||||||
|
target16 := utf16.Encode([]rune(rp.Target + "\x00"))
|
||||||
|
ntTarget16 := utf16.Encode([]rune(ntTarget + "\x00"))
|
||||||
|
|
||||||
|
size := int(unsafe.Sizeof(reparseDataBuffer{})) - 8
|
||||||
|
size += len(ntTarget16)*2 + len(target16)*2
|
||||||
|
|
||||||
|
tag := uint32(reparseTagMountPoint)
|
||||||
|
if !rp.IsMountPoint {
|
||||||
|
tag = reparseTagSymlink
|
||||||
|
size += 4 // Add room for symlink flags
|
||||||
|
}
|
||||||
|
|
||||||
|
data := reparseDataBuffer{
|
||||||
|
ReparseTag: tag,
|
||||||
|
ReparseDataLength: uint16(size),
|
||||||
|
SubstituteNameOffset: 0,
|
||||||
|
SubstituteNameLength: uint16((len(ntTarget16) - 1) * 2),
|
||||||
|
PrintNameOffset: uint16(len(ntTarget16) * 2),
|
||||||
|
PrintNameLength: uint16((len(target16) - 1) * 2),
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
_ = binary.Write(&b, binary.LittleEndian, &data)
|
||||||
|
if !rp.IsMountPoint {
|
||||||
|
flags := uint32(0)
|
||||||
|
if relative {
|
||||||
|
flags |= 1
|
||||||
|
}
|
||||||
|
_ = binary.Write(&b, binary.LittleEndian, flags)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = binary.Write(&b, binary.LittleEndian, ntTarget16)
|
||||||
|
_ = binary.Write(&b, binary.LittleEndian, target16)
|
||||||
|
return b.Bytes()
|
||||||
|
}
|
||||||
133
vendor/github.com/Microsoft/go-winio/sd.go
generated
vendored
Normal file
133
vendor/github.com/Microsoft/go-winio/sd.go
generated
vendored
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package winio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
//sys lookupAccountName(systemName *uint16, accountName string, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) = advapi32.LookupAccountNameW
|
||||||
|
//sys lookupAccountSid(systemName *uint16, sid *byte, name *uint16, nameSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) = advapi32.LookupAccountSidW
|
||||||
|
//sys convertSidToStringSid(sid *byte, str **uint16) (err error) = advapi32.ConvertSidToStringSidW
|
||||||
|
//sys convertStringSidToSid(str *uint16, sid **byte) (err error) = advapi32.ConvertStringSidToSidW
|
||||||
|
|
||||||
|
type AccountLookupError struct {
|
||||||
|
Name string
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *AccountLookupError) Error() string {
|
||||||
|
if e.Name == "" {
|
||||||
|
return "lookup account: empty account name specified"
|
||||||
|
}
|
||||||
|
var s string
|
||||||
|
switch {
|
||||||
|
case errors.Is(e.Err, windows.ERROR_INVALID_SID):
|
||||||
|
s = "the security ID structure is invalid"
|
||||||
|
case errors.Is(e.Err, windows.ERROR_NONE_MAPPED):
|
||||||
|
s = "not found"
|
||||||
|
default:
|
||||||
|
s = e.Err.Error()
|
||||||
|
}
|
||||||
|
return "lookup account " + e.Name + ": " + s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *AccountLookupError) Unwrap() error { return e.Err }
|
||||||
|
|
||||||
|
type SddlConversionError struct {
|
||||||
|
Sddl string
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *SddlConversionError) Error() string {
|
||||||
|
return "convert " + e.Sddl + ": " + e.Err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *SddlConversionError) Unwrap() error { return e.Err }
|
||||||
|
|
||||||
|
// LookupSidByName looks up the SID of an account by name
|
||||||
|
//
|
||||||
|
//revive:disable-next-line:var-naming SID, not Sid
|
||||||
|
func LookupSidByName(name string) (sid string, err error) {
|
||||||
|
if name == "" {
|
||||||
|
return "", &AccountLookupError{name, windows.ERROR_NONE_MAPPED}
|
||||||
|
}
|
||||||
|
|
||||||
|
var sidSize, sidNameUse, refDomainSize uint32
|
||||||
|
err = lookupAccountName(nil, name, nil, &sidSize, nil, &refDomainSize, &sidNameUse)
|
||||||
|
if err != nil && err != windows.ERROR_INSUFFICIENT_BUFFER { //nolint:errorlint // err is Errno
|
||||||
|
return "", &AccountLookupError{name, err}
|
||||||
|
}
|
||||||
|
sidBuffer := make([]byte, sidSize)
|
||||||
|
refDomainBuffer := make([]uint16, refDomainSize)
|
||||||
|
err = lookupAccountName(nil, name, &sidBuffer[0], &sidSize, &refDomainBuffer[0], &refDomainSize, &sidNameUse)
|
||||||
|
if err != nil {
|
||||||
|
return "", &AccountLookupError{name, err}
|
||||||
|
}
|
||||||
|
var strBuffer *uint16
|
||||||
|
err = convertSidToStringSid(&sidBuffer[0], &strBuffer)
|
||||||
|
if err != nil {
|
||||||
|
return "", &AccountLookupError{name, err}
|
||||||
|
}
|
||||||
|
sid = windows.UTF16ToString((*[0xffff]uint16)(unsafe.Pointer(strBuffer))[:])
|
||||||
|
_, _ = windows.LocalFree(windows.Handle(unsafe.Pointer(strBuffer)))
|
||||||
|
return sid, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LookupNameBySid looks up the name of an account by SID
|
||||||
|
//
|
||||||
|
//revive:disable-next-line:var-naming SID, not Sid
|
||||||
|
func LookupNameBySid(sid string) (name string, err error) {
|
||||||
|
if sid == "" {
|
||||||
|
return "", &AccountLookupError{sid, windows.ERROR_NONE_MAPPED}
|
||||||
|
}
|
||||||
|
|
||||||
|
sidBuffer, err := windows.UTF16PtrFromString(sid)
|
||||||
|
if err != nil {
|
||||||
|
return "", &AccountLookupError{sid, err}
|
||||||
|
}
|
||||||
|
|
||||||
|
var sidPtr *byte
|
||||||
|
if err = convertStringSidToSid(sidBuffer, &sidPtr); err != nil {
|
||||||
|
return "", &AccountLookupError{sid, err}
|
||||||
|
}
|
||||||
|
defer windows.LocalFree(windows.Handle(unsafe.Pointer(sidPtr))) //nolint:errcheck
|
||||||
|
|
||||||
|
var nameSize, refDomainSize, sidNameUse uint32
|
||||||
|
err = lookupAccountSid(nil, sidPtr, nil, &nameSize, nil, &refDomainSize, &sidNameUse)
|
||||||
|
if err != nil && err != windows.ERROR_INSUFFICIENT_BUFFER { //nolint:errorlint // err is Errno
|
||||||
|
return "", &AccountLookupError{sid, err}
|
||||||
|
}
|
||||||
|
|
||||||
|
nameBuffer := make([]uint16, nameSize)
|
||||||
|
refDomainBuffer := make([]uint16, refDomainSize)
|
||||||
|
err = lookupAccountSid(nil, sidPtr, &nameBuffer[0], &nameSize, &refDomainBuffer[0], &refDomainSize, &sidNameUse)
|
||||||
|
if err != nil {
|
||||||
|
return "", &AccountLookupError{sid, err}
|
||||||
|
}
|
||||||
|
|
||||||
|
name = windows.UTF16ToString(nameBuffer)
|
||||||
|
return name, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
|
||||||
|
sd, err := windows.SecurityDescriptorFromString(sddl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &SddlConversionError{Sddl: sddl, Err: err}
|
||||||
|
}
|
||||||
|
b := unsafe.Slice((*byte)(unsafe.Pointer(sd)), sd.Length())
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SecurityDescriptorToSddl(sd []byte) (string, error) {
|
||||||
|
if l := int(unsafe.Sizeof(windows.SECURITY_DESCRIPTOR{})); len(sd) < l {
|
||||||
|
return "", fmt.Errorf("SecurityDescriptor (%d) smaller than expected (%d): %w", len(sd), l, windows.ERROR_INCORRECT_SIZE)
|
||||||
|
}
|
||||||
|
s := (*windows.SECURITY_DESCRIPTOR)(unsafe.Pointer(&sd[0]))
|
||||||
|
return s.String(), nil
|
||||||
|
}
|
||||||
5
vendor/github.com/Microsoft/go-winio/syscall.go
generated
vendored
Normal file
5
vendor/github.com/Microsoft/go-winio/syscall.go
generated
vendored
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package winio
|
||||||
|
|
||||||
|
//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go ./*.go
|
||||||
378
vendor/github.com/Microsoft/go-winio/zsyscall_windows.go
generated
vendored
Normal file
378
vendor/github.com/Microsoft/go-winio/zsyscall_windows.go
generated
vendored
Normal file
@@ -0,0 +1,378 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT.
|
||||||
|
|
||||||
|
package winio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ unsafe.Pointer
|
||||||
|
|
||||||
|
// Do the interface allocations only once for common
|
||||||
|
// Errno values.
|
||||||
|
const (
|
||||||
|
errnoERROR_IO_PENDING = 997
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||||
|
errERROR_EINVAL error = syscall.EINVAL
|
||||||
|
)
|
||||||
|
|
||||||
|
// errnoErr returns common boxed Errno values, to prevent
|
||||||
|
// allocations at runtime.
|
||||||
|
func errnoErr(e syscall.Errno) error {
|
||||||
|
switch e {
|
||||||
|
case 0:
|
||||||
|
return errERROR_EINVAL
|
||||||
|
case errnoERROR_IO_PENDING:
|
||||||
|
return errERROR_IO_PENDING
|
||||||
|
}
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
|
||||||
|
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||||
|
modntdll = windows.NewLazySystemDLL("ntdll.dll")
|
||||||
|
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
|
||||||
|
|
||||||
|
procAdjustTokenPrivileges = modadvapi32.NewProc("AdjustTokenPrivileges")
|
||||||
|
procConvertSidToStringSidW = modadvapi32.NewProc("ConvertSidToStringSidW")
|
||||||
|
procConvertStringSidToSidW = modadvapi32.NewProc("ConvertStringSidToSidW")
|
||||||
|
procImpersonateSelf = modadvapi32.NewProc("ImpersonateSelf")
|
||||||
|
procLookupAccountNameW = modadvapi32.NewProc("LookupAccountNameW")
|
||||||
|
procLookupAccountSidW = modadvapi32.NewProc("LookupAccountSidW")
|
||||||
|
procLookupPrivilegeDisplayNameW = modadvapi32.NewProc("LookupPrivilegeDisplayNameW")
|
||||||
|
procLookupPrivilegeNameW = modadvapi32.NewProc("LookupPrivilegeNameW")
|
||||||
|
procLookupPrivilegeValueW = modadvapi32.NewProc("LookupPrivilegeValueW")
|
||||||
|
procOpenThreadToken = modadvapi32.NewProc("OpenThreadToken")
|
||||||
|
procRevertToSelf = modadvapi32.NewProc("RevertToSelf")
|
||||||
|
procBackupRead = modkernel32.NewProc("BackupRead")
|
||||||
|
procBackupWrite = modkernel32.NewProc("BackupWrite")
|
||||||
|
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
|
||||||
|
procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
|
||||||
|
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
|
||||||
|
procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW")
|
||||||
|
procDisconnectNamedPipe = modkernel32.NewProc("DisconnectNamedPipe")
|
||||||
|
procGetCurrentThread = modkernel32.NewProc("GetCurrentThread")
|
||||||
|
procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW")
|
||||||
|
procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo")
|
||||||
|
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
|
||||||
|
procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes")
|
||||||
|
procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile")
|
||||||
|
procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl")
|
||||||
|
procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U")
|
||||||
|
procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb")
|
||||||
|
procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
|
||||||
|
)
|
||||||
|
|
||||||
|
func adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) {
|
||||||
|
var _p0 uint32
|
||||||
|
if releaseAll {
|
||||||
|
_p0 = 1
|
||||||
|
}
|
||||||
|
r0, _, e1 := syscall.SyscallN(procAdjustTokenPrivileges.Addr(), uintptr(token), uintptr(_p0), uintptr(unsafe.Pointer(input)), uintptr(outputSize), uintptr(unsafe.Pointer(output)), uintptr(unsafe.Pointer(requiredSize)))
|
||||||
|
success = r0 != 0
|
||||||
|
if true {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertSidToStringSid(sid *byte, str **uint16) (err error) {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procConvertSidToStringSidW.Addr(), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(str)))
|
||||||
|
if r1 == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertStringSidToSid(str *uint16, sid **byte) (err error) {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procConvertStringSidToSidW.Addr(), uintptr(unsafe.Pointer(str)), uintptr(unsafe.Pointer(sid)))
|
||||||
|
if r1 == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func impersonateSelf(level uint32) (err error) {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procImpersonateSelf.Addr(), uintptr(level))
|
||||||
|
if r1 == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func lookupAccountName(systemName *uint16, accountName string, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) {
|
||||||
|
var _p0 *uint16
|
||||||
|
_p0, err = syscall.UTF16PtrFromString(accountName)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return _lookupAccountName(systemName, _p0, sid, sidSize, refDomain, refDomainSize, sidNameUse)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _lookupAccountName(systemName *uint16, accountName *uint16, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procLookupAccountNameW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(accountName)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(sidSize)), uintptr(unsafe.Pointer(refDomain)), uintptr(unsafe.Pointer(refDomainSize)), uintptr(unsafe.Pointer(sidNameUse)))
|
||||||
|
if r1 == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func lookupAccountSid(systemName *uint16, sid *byte, name *uint16, nameSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procLookupAccountSidW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(nameSize)), uintptr(unsafe.Pointer(refDomain)), uintptr(unsafe.Pointer(refDomainSize)), uintptr(unsafe.Pointer(sidNameUse)))
|
||||||
|
if r1 == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) {
|
||||||
|
var _p0 *uint16
|
||||||
|
_p0, err = syscall.UTF16PtrFromString(systemName)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return _lookupPrivilegeDisplayName(_p0, name, buffer, size, languageId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _lookupPrivilegeDisplayName(systemName *uint16, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procLookupPrivilegeDisplayNameW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(buffer)), uintptr(unsafe.Pointer(size)), uintptr(unsafe.Pointer(languageId)))
|
||||||
|
if r1 == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func lookupPrivilegeName(systemName string, luid *uint64, buffer *uint16, size *uint32) (err error) {
|
||||||
|
var _p0 *uint16
|
||||||
|
_p0, err = syscall.UTF16PtrFromString(systemName)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return _lookupPrivilegeName(_p0, luid, buffer, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _lookupPrivilegeName(systemName *uint16, luid *uint64, buffer *uint16, size *uint32) (err error) {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procLookupPrivilegeNameW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(luid)), uintptr(unsafe.Pointer(buffer)), uintptr(unsafe.Pointer(size)))
|
||||||
|
if r1 == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func lookupPrivilegeValue(systemName string, name string, luid *uint64) (err error) {
|
||||||
|
var _p0 *uint16
|
||||||
|
_p0, err = syscall.UTF16PtrFromString(systemName)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var _p1 *uint16
|
||||||
|
_p1, err = syscall.UTF16PtrFromString(name)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return _lookupPrivilegeValue(_p0, _p1, luid)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _lookupPrivilegeValue(systemName *uint16, name *uint16, luid *uint64) (err error) {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procLookupPrivilegeValueW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(luid)))
|
||||||
|
if r1 == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func openThreadToken(thread windows.Handle, accessMask uint32, openAsSelf bool, token *windows.Token) (err error) {
|
||||||
|
var _p0 uint32
|
||||||
|
if openAsSelf {
|
||||||
|
_p0 = 1
|
||||||
|
}
|
||||||
|
r1, _, e1 := syscall.SyscallN(procOpenThreadToken.Addr(), uintptr(thread), uintptr(accessMask), uintptr(_p0), uintptr(unsafe.Pointer(token)))
|
||||||
|
if r1 == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func revertToSelf() (err error) {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procRevertToSelf.Addr())
|
||||||
|
if r1 == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func backupRead(h windows.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) {
|
||||||
|
var _p0 *byte
|
||||||
|
if len(b) > 0 {
|
||||||
|
_p0 = &b[0]
|
||||||
|
}
|
||||||
|
var _p1 uint32
|
||||||
|
if abort {
|
||||||
|
_p1 = 1
|
||||||
|
}
|
||||||
|
var _p2 uint32
|
||||||
|
if processSecurity {
|
||||||
|
_p2 = 1
|
||||||
|
}
|
||||||
|
r1, _, e1 := syscall.SyscallN(procBackupRead.Addr(), uintptr(h), uintptr(unsafe.Pointer(_p0)), uintptr(len(b)), uintptr(unsafe.Pointer(bytesRead)), uintptr(_p1), uintptr(_p2), uintptr(unsafe.Pointer(context)))
|
||||||
|
if r1 == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func backupWrite(h windows.Handle, b []byte, bytesWritten *uint32, abort bool, processSecurity bool, context *uintptr) (err error) {
|
||||||
|
var _p0 *byte
|
||||||
|
if len(b) > 0 {
|
||||||
|
_p0 = &b[0]
|
||||||
|
}
|
||||||
|
var _p1 uint32
|
||||||
|
if abort {
|
||||||
|
_p1 = 1
|
||||||
|
}
|
||||||
|
var _p2 uint32
|
||||||
|
if processSecurity {
|
||||||
|
_p2 = 1
|
||||||
|
}
|
||||||
|
r1, _, e1 := syscall.SyscallN(procBackupWrite.Addr(), uintptr(h), uintptr(unsafe.Pointer(_p0)), uintptr(len(b)), uintptr(unsafe.Pointer(bytesWritten)), uintptr(_p1), uintptr(_p2), uintptr(unsafe.Pointer(context)))
|
||||||
|
if r1 == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procCancelIoEx.Addr(), uintptr(file), uintptr(unsafe.Pointer(o)))
|
||||||
|
if r1 == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procConnectNamedPipe.Addr(), uintptr(pipe), uintptr(unsafe.Pointer(o)))
|
||||||
|
if r1 == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) {
|
||||||
|
r0, _, e1 := syscall.SyscallN(procCreateIoCompletionPort.Addr(), uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount))
|
||||||
|
newport = windows.Handle(r0)
|
||||||
|
if newport == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
|
||||||
|
var _p0 *uint16
|
||||||
|
_p0, err = syscall.UTF16PtrFromString(name)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
|
||||||
|
r0, _, e1 := syscall.SyscallN(procCreateNamedPipeW.Addr(), uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)))
|
||||||
|
handle = windows.Handle(r0)
|
||||||
|
if handle == windows.InvalidHandle {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func disconnectNamedPipe(pipe windows.Handle) (err error) {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procDisconnectNamedPipe.Addr(), uintptr(pipe))
|
||||||
|
if r1 == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCurrentThread() (h windows.Handle) {
|
||||||
|
r0, _, _ := syscall.SyscallN(procGetCurrentThread.Addr())
|
||||||
|
h = windows.Handle(r0)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procGetNamedPipeHandleStateW.Addr(), uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize))
|
||||||
|
if r1 == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procGetNamedPipeInfo.Addr(), uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)))
|
||||||
|
if r1 == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procGetQueuedCompletionStatus.Addr(), uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout))
|
||||||
|
if r1 == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procSetFileCompletionNotificationModes.Addr(), uintptr(h), uintptr(flags))
|
||||||
|
if r1 == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func ntCreateNamedPipeFile(pipe *windows.Handle, access ntAccessMask, oa *objectAttributes, iosb *ioStatusBlock, share ntFileShareMode, disposition ntFileCreationDisposition, options ntFileOptions, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntStatus) {
|
||||||
|
r0, _, _ := syscall.SyscallN(procNtCreateNamedPipeFile.Addr(), uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)))
|
||||||
|
status = ntStatus(r0)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func rtlDefaultNpAcl(dacl *uintptr) (status ntStatus) {
|
||||||
|
r0, _, _ := syscall.SyscallN(procRtlDefaultNpAcl.Addr(), uintptr(unsafe.Pointer(dacl)))
|
||||||
|
status = ntStatus(r0)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntStatus) {
|
||||||
|
r0, _, _ := syscall.SyscallN(procRtlDosPathNameToNtPathName_U.Addr(), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(ntName)), uintptr(filePart), uintptr(reserved))
|
||||||
|
status = ntStatus(r0)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func rtlNtStatusToDosError(status ntStatus) (winerr error) {
|
||||||
|
r0, _, _ := syscall.SyscallN(procRtlNtStatusToDosErrorNoTeb.Addr(), uintptr(status))
|
||||||
|
if r0 != 0 {
|
||||||
|
winerr = syscall.Errno(r0)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) {
|
||||||
|
var _p0 uint32
|
||||||
|
if wait {
|
||||||
|
_p0 = 1
|
||||||
|
}
|
||||||
|
r1, _, e1 := syscall.SyscallN(procWSAGetOverlappedResult.Addr(), uintptr(h), uintptr(unsafe.Pointer(o)), uintptr(unsafe.Pointer(bytes)), uintptr(_p0), uintptr(unsafe.Pointer(flags)))
|
||||||
|
if r1 == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
191
vendor/github.com/containerd/errdefs/LICENSE
generated
vendored
Normal file
191
vendor/github.com/containerd/errdefs/LICENSE
generated
vendored
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
|
||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
https://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
Copyright The containerd Authors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
13
vendor/github.com/containerd/errdefs/README.md
generated
vendored
Normal file
13
vendor/github.com/containerd/errdefs/README.md
generated
vendored
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# errdefs
|
||||||
|
|
||||||
|
A Go package for defining and checking common containerd errors.
|
||||||
|
|
||||||
|
## Project details
|
||||||
|
|
||||||
|
**errdefs** is a containerd sub-project, licensed under the [Apache 2.0 license](./LICENSE).
|
||||||
|
As a containerd sub-project, you will find the:
|
||||||
|
* [Project governance](https://github.com/containerd/project/blob/main/GOVERNANCE.md),
|
||||||
|
* [Maintainers](https://github.com/containerd/project/blob/main/MAINTAINERS),
|
||||||
|
* and [Contributing guidelines](https://github.com/containerd/project/blob/main/CONTRIBUTING.md)
|
||||||
|
|
||||||
|
information in our [`containerd/project`](https://github.com/containerd/project) repository.
|
||||||
443
vendor/github.com/containerd/errdefs/errors.go
generated
vendored
Normal file
443
vendor/github.com/containerd/errdefs/errors.go
generated
vendored
Normal file
@@ -0,0 +1,443 @@
|
|||||||
|
/*
|
||||||
|
Copyright The containerd Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Package errdefs defines the common errors used throughout containerd
|
||||||
|
// packages.
|
||||||
|
//
|
||||||
|
// Use with fmt.Errorf to add context to an error.
|
||||||
|
//
|
||||||
|
// To detect an error class, use the IsXXX functions to tell whether an error
|
||||||
|
// is of a certain type.
|
||||||
|
package errdefs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Definitions of common error types used throughout containerd. All containerd
|
||||||
|
// errors returned by most packages will map into one of these errors classes.
|
||||||
|
// Packages should return errors of these types when they want to instruct a
|
||||||
|
// client to take a particular action.
|
||||||
|
//
|
||||||
|
// These errors map closely to grpc errors.
|
||||||
|
var (
|
||||||
|
ErrUnknown = errUnknown{}
|
||||||
|
ErrInvalidArgument = errInvalidArgument{}
|
||||||
|
ErrNotFound = errNotFound{}
|
||||||
|
ErrAlreadyExists = errAlreadyExists{}
|
||||||
|
ErrPermissionDenied = errPermissionDenied{}
|
||||||
|
ErrResourceExhausted = errResourceExhausted{}
|
||||||
|
ErrFailedPrecondition = errFailedPrecondition{}
|
||||||
|
ErrConflict = errConflict{}
|
||||||
|
ErrNotModified = errNotModified{}
|
||||||
|
ErrAborted = errAborted{}
|
||||||
|
ErrOutOfRange = errOutOfRange{}
|
||||||
|
ErrNotImplemented = errNotImplemented{}
|
||||||
|
ErrInternal = errInternal{}
|
||||||
|
ErrUnavailable = errUnavailable{}
|
||||||
|
ErrDataLoss = errDataLoss{}
|
||||||
|
ErrUnauthenticated = errUnauthorized{}
|
||||||
|
)
|
||||||
|
|
||||||
|
// cancelled maps to Moby's "ErrCancelled"
|
||||||
|
type cancelled interface {
|
||||||
|
Cancelled()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsCanceled returns true if the error is due to `context.Canceled`.
|
||||||
|
func IsCanceled(err error) bool {
|
||||||
|
return errors.Is(err, context.Canceled) || isInterface[cancelled](err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type errUnknown struct{}
|
||||||
|
|
||||||
|
func (errUnknown) Error() string { return "unknown" }
|
||||||
|
|
||||||
|
func (errUnknown) Unknown() {}
|
||||||
|
|
||||||
|
func (e errUnknown) WithMessage(msg string) error {
|
||||||
|
return customMessage{e, msg}
|
||||||
|
}
|
||||||
|
|
||||||
|
// unknown maps to Moby's "ErrUnknown"
|
||||||
|
type unknown interface {
|
||||||
|
Unknown()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsUnknown returns true if the error is due to an unknown error,
|
||||||
|
// unhandled condition or unexpected response.
|
||||||
|
func IsUnknown(err error) bool {
|
||||||
|
return errors.Is(err, errUnknown{}) || isInterface[unknown](err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type errInvalidArgument struct{}
|
||||||
|
|
||||||
|
func (errInvalidArgument) Error() string { return "invalid argument" }
|
||||||
|
|
||||||
|
func (errInvalidArgument) InvalidParameter() {}
|
||||||
|
|
||||||
|
func (e errInvalidArgument) WithMessage(msg string) error {
|
||||||
|
return customMessage{e, msg}
|
||||||
|
}
|
||||||
|
|
||||||
|
// invalidParameter maps to Moby's "ErrInvalidParameter"
|
||||||
|
type invalidParameter interface {
|
||||||
|
InvalidParameter()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsInvalidArgument returns true if the error is due to an invalid argument
|
||||||
|
func IsInvalidArgument(err error) bool {
|
||||||
|
return errors.Is(err, ErrInvalidArgument) || isInterface[invalidParameter](err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// deadlineExceed maps to Moby's "ErrDeadline"
|
||||||
|
type deadlineExceeded interface {
|
||||||
|
DeadlineExceeded()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsDeadlineExceeded returns true if the error is due to
|
||||||
|
// `context.DeadlineExceeded`.
|
||||||
|
func IsDeadlineExceeded(err error) bool {
|
||||||
|
return errors.Is(err, context.DeadlineExceeded) || isInterface[deadlineExceeded](err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type errNotFound struct{}
|
||||||
|
|
||||||
|
func (errNotFound) Error() string { return "not found" }
|
||||||
|
|
||||||
|
func (errNotFound) NotFound() {}
|
||||||
|
|
||||||
|
func (e errNotFound) WithMessage(msg string) error {
|
||||||
|
return customMessage{e, msg}
|
||||||
|
}
|
||||||
|
|
||||||
|
// notFound maps to Moby's "ErrNotFound"
|
||||||
|
type notFound interface {
|
||||||
|
NotFound()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNotFound returns true if the error is due to a missing object
|
||||||
|
func IsNotFound(err error) bool {
|
||||||
|
return errors.Is(err, ErrNotFound) || isInterface[notFound](err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type errAlreadyExists struct{}
|
||||||
|
|
||||||
|
func (errAlreadyExists) Error() string { return "already exists" }
|
||||||
|
|
||||||
|
func (errAlreadyExists) AlreadyExists() {}
|
||||||
|
|
||||||
|
func (e errAlreadyExists) WithMessage(msg string) error {
|
||||||
|
return customMessage{e, msg}
|
||||||
|
}
|
||||||
|
|
||||||
|
type alreadyExists interface {
|
||||||
|
AlreadyExists()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAlreadyExists returns true if the error is due to an already existing
|
||||||
|
// metadata item
|
||||||
|
func IsAlreadyExists(err error) bool {
|
||||||
|
return errors.Is(err, ErrAlreadyExists) || isInterface[alreadyExists](err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type errPermissionDenied struct{}
|
||||||
|
|
||||||
|
func (errPermissionDenied) Error() string { return "permission denied" }
|
||||||
|
|
||||||
|
func (errPermissionDenied) Forbidden() {}
|
||||||
|
|
||||||
|
func (e errPermissionDenied) WithMessage(msg string) error {
|
||||||
|
return customMessage{e, msg}
|
||||||
|
}
|
||||||
|
|
||||||
|
// forbidden maps to Moby's "ErrForbidden"
|
||||||
|
type forbidden interface {
|
||||||
|
Forbidden()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsPermissionDenied returns true if the error is due to permission denied
|
||||||
|
// or forbidden (403) response
|
||||||
|
func IsPermissionDenied(err error) bool {
|
||||||
|
return errors.Is(err, ErrPermissionDenied) || isInterface[forbidden](err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type errResourceExhausted struct{}
|
||||||
|
|
||||||
|
func (errResourceExhausted) Error() string { return "resource exhausted" }
|
||||||
|
|
||||||
|
func (errResourceExhausted) ResourceExhausted() {}
|
||||||
|
|
||||||
|
func (e errResourceExhausted) WithMessage(msg string) error {
|
||||||
|
return customMessage{e, msg}
|
||||||
|
}
|
||||||
|
|
||||||
|
type resourceExhausted interface {
|
||||||
|
ResourceExhausted()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsResourceExhausted returns true if the error is due to
|
||||||
|
// a lack of resources or too many attempts.
|
||||||
|
func IsResourceExhausted(err error) bool {
|
||||||
|
return errors.Is(err, errResourceExhausted{}) || isInterface[resourceExhausted](err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type errFailedPrecondition struct{}
|
||||||
|
|
||||||
|
func (e errFailedPrecondition) Error() string { return "failed precondition" }
|
||||||
|
|
||||||
|
func (errFailedPrecondition) FailedPrecondition() {}
|
||||||
|
|
||||||
|
func (e errFailedPrecondition) WithMessage(msg string) error {
|
||||||
|
return customMessage{e, msg}
|
||||||
|
}
|
||||||
|
|
||||||
|
type failedPrecondition interface {
|
||||||
|
FailedPrecondition()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsFailedPrecondition returns true if an operation could not proceed due to
|
||||||
|
// the lack of a particular condition
|
||||||
|
func IsFailedPrecondition(err error) bool {
|
||||||
|
return errors.Is(err, errFailedPrecondition{}) || isInterface[failedPrecondition](err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type errConflict struct{}
|
||||||
|
|
||||||
|
func (errConflict) Error() string { return "conflict" }
|
||||||
|
|
||||||
|
func (errConflict) Conflict() {}
|
||||||
|
|
||||||
|
func (e errConflict) WithMessage(msg string) error {
|
||||||
|
return customMessage{e, msg}
|
||||||
|
}
|
||||||
|
|
||||||
|
// conflict maps to Moby's "ErrConflict"
|
||||||
|
type conflict interface {
|
||||||
|
Conflict()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsConflict returns true if an operation could not proceed due to
|
||||||
|
// a conflict.
|
||||||
|
func IsConflict(err error) bool {
|
||||||
|
return errors.Is(err, errConflict{}) || isInterface[conflict](err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type errNotModified struct{}
|
||||||
|
|
||||||
|
func (errNotModified) Error() string { return "not modified" }
|
||||||
|
|
||||||
|
func (errNotModified) NotModified() {}
|
||||||
|
|
||||||
|
func (e errNotModified) WithMessage(msg string) error {
|
||||||
|
return customMessage{e, msg}
|
||||||
|
}
|
||||||
|
|
||||||
|
// notModified maps to Moby's "ErrNotModified"
|
||||||
|
type notModified interface {
|
||||||
|
NotModified()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNotModified returns true if an operation could not proceed due
|
||||||
|
// to an object not modified from a previous state.
|
||||||
|
func IsNotModified(err error) bool {
|
||||||
|
return errors.Is(err, errNotModified{}) || isInterface[notModified](err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type errAborted struct{}
|
||||||
|
|
||||||
|
func (errAborted) Error() string { return "aborted" }
|
||||||
|
|
||||||
|
func (errAborted) Aborted() {}
|
||||||
|
|
||||||
|
func (e errAborted) WithMessage(msg string) error {
|
||||||
|
return customMessage{e, msg}
|
||||||
|
}
|
||||||
|
|
||||||
|
type aborted interface {
|
||||||
|
Aborted()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAborted returns true if an operation was aborted.
|
||||||
|
func IsAborted(err error) bool {
|
||||||
|
return errors.Is(err, errAborted{}) || isInterface[aborted](err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type errOutOfRange struct{}
|
||||||
|
|
||||||
|
func (errOutOfRange) Error() string { return "out of range" }
|
||||||
|
|
||||||
|
func (errOutOfRange) OutOfRange() {}
|
||||||
|
|
||||||
|
func (e errOutOfRange) WithMessage(msg string) error {
|
||||||
|
return customMessage{e, msg}
|
||||||
|
}
|
||||||
|
|
||||||
|
type outOfRange interface {
|
||||||
|
OutOfRange()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsOutOfRange returns true if an operation could not proceed due
|
||||||
|
// to data being out of the expected range.
|
||||||
|
func IsOutOfRange(err error) bool {
|
||||||
|
return errors.Is(err, errOutOfRange{}) || isInterface[outOfRange](err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type errNotImplemented struct{}
|
||||||
|
|
||||||
|
func (errNotImplemented) Error() string { return "not implemented" }
|
||||||
|
|
||||||
|
func (errNotImplemented) NotImplemented() {}
|
||||||
|
|
||||||
|
func (e errNotImplemented) WithMessage(msg string) error {
|
||||||
|
return customMessage{e, msg}
|
||||||
|
}
|
||||||
|
|
||||||
|
// notImplemented maps to Moby's "ErrNotImplemented"
|
||||||
|
type notImplemented interface {
|
||||||
|
NotImplemented()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNotImplemented returns true if the error is due to not being implemented
|
||||||
|
func IsNotImplemented(err error) bool {
|
||||||
|
return errors.Is(err, errNotImplemented{}) || isInterface[notImplemented](err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type errInternal struct{}
|
||||||
|
|
||||||
|
func (errInternal) Error() string { return "internal" }
|
||||||
|
|
||||||
|
func (errInternal) System() {}
|
||||||
|
|
||||||
|
func (e errInternal) WithMessage(msg string) error {
|
||||||
|
return customMessage{e, msg}
|
||||||
|
}
|
||||||
|
|
||||||
|
// system maps to Moby's "ErrSystem"
|
||||||
|
type system interface {
|
||||||
|
System()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsInternal returns true if the error returns to an internal or system error
|
||||||
|
func IsInternal(err error) bool {
|
||||||
|
return errors.Is(err, errInternal{}) || isInterface[system](err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type errUnavailable struct{}
|
||||||
|
|
||||||
|
func (errUnavailable) Error() string { return "unavailable" }
|
||||||
|
|
||||||
|
func (errUnavailable) Unavailable() {}
|
||||||
|
|
||||||
|
func (e errUnavailable) WithMessage(msg string) error {
|
||||||
|
return customMessage{e, msg}
|
||||||
|
}
|
||||||
|
|
||||||
|
// unavailable maps to Moby's "ErrUnavailable"
|
||||||
|
type unavailable interface {
|
||||||
|
Unavailable()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsUnavailable returns true if the error is due to a resource being unavailable
|
||||||
|
func IsUnavailable(err error) bool {
|
||||||
|
return errors.Is(err, errUnavailable{}) || isInterface[unavailable](err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type errDataLoss struct{}
|
||||||
|
|
||||||
|
func (errDataLoss) Error() string { return "data loss" }
|
||||||
|
|
||||||
|
func (errDataLoss) DataLoss() {}
|
||||||
|
|
||||||
|
func (e errDataLoss) WithMessage(msg string) error {
|
||||||
|
return customMessage{e, msg}
|
||||||
|
}
|
||||||
|
|
||||||
|
// dataLoss maps to Moby's "ErrDataLoss"
|
||||||
|
type dataLoss interface {
|
||||||
|
DataLoss()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsDataLoss returns true if data during an operation was lost or corrupted
|
||||||
|
func IsDataLoss(err error) bool {
|
||||||
|
return errors.Is(err, errDataLoss{}) || isInterface[dataLoss](err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type errUnauthorized struct{}
|
||||||
|
|
||||||
|
func (errUnauthorized) Error() string { return "unauthorized" }
|
||||||
|
|
||||||
|
func (errUnauthorized) Unauthorized() {}
|
||||||
|
|
||||||
|
func (e errUnauthorized) WithMessage(msg string) error {
|
||||||
|
return customMessage{e, msg}
|
||||||
|
}
|
||||||
|
|
||||||
|
// unauthorized maps to Moby's "ErrUnauthorized"
|
||||||
|
type unauthorized interface {
|
||||||
|
Unauthorized()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsUnauthorized returns true if the error indicates that the user was
|
||||||
|
// unauthenticated or unauthorized.
|
||||||
|
func IsUnauthorized(err error) bool {
|
||||||
|
return errors.Is(err, errUnauthorized{}) || isInterface[unauthorized](err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isInterface[T any](err error) bool {
|
||||||
|
for {
|
||||||
|
switch x := err.(type) {
|
||||||
|
case T:
|
||||||
|
return true
|
||||||
|
case customMessage:
|
||||||
|
err = x.err
|
||||||
|
case interface{ Unwrap() error }:
|
||||||
|
err = x.Unwrap()
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case interface{ Unwrap() []error }:
|
||||||
|
for _, err := range x.Unwrap() {
|
||||||
|
if isInterface[T](err) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// customMessage is used to provide a defined error with a custom message.
|
||||||
|
// The message is not wrapped but can be compared by the `Is(error) bool` interface.
|
||||||
|
type customMessage struct {
|
||||||
|
err error
|
||||||
|
msg string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c customMessage) Is(err error) bool {
|
||||||
|
return c.err == err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c customMessage) As(target any) bool {
|
||||||
|
return errors.As(c.err, target)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c customMessage) Error() string {
|
||||||
|
return c.msg
|
||||||
|
}
|
||||||
191
vendor/github.com/containerd/errdefs/pkg/LICENSE
generated
vendored
Normal file
191
vendor/github.com/containerd/errdefs/pkg/LICENSE
generated
vendored
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
|
||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
https://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
Copyright The containerd Authors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
96
vendor/github.com/containerd/errdefs/pkg/errhttp/http.go
generated
vendored
Normal file
96
vendor/github.com/containerd/errdefs/pkg/errhttp/http.go
generated
vendored
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
/*
|
||||||
|
Copyright The containerd Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Package errhttp provides utility functions for translating errors to
|
||||||
|
// and from a HTTP context.
|
||||||
|
//
|
||||||
|
// The functions ToHTTP and ToNative can be used to map server-side and
|
||||||
|
// client-side errors to the correct types.
|
||||||
|
package errhttp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/containerd/errdefs"
|
||||||
|
"github.com/containerd/errdefs/pkg/internal/cause"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ToHTTP returns the best status code for the given error
|
||||||
|
func ToHTTP(err error) int {
|
||||||
|
switch {
|
||||||
|
case errdefs.IsNotFound(err):
|
||||||
|
return http.StatusNotFound
|
||||||
|
case errdefs.IsInvalidArgument(err):
|
||||||
|
return http.StatusBadRequest
|
||||||
|
case errdefs.IsConflict(err):
|
||||||
|
return http.StatusConflict
|
||||||
|
case errdefs.IsNotModified(err):
|
||||||
|
return http.StatusNotModified
|
||||||
|
case errdefs.IsFailedPrecondition(err):
|
||||||
|
return http.StatusPreconditionFailed
|
||||||
|
case errdefs.IsUnauthorized(err):
|
||||||
|
return http.StatusUnauthorized
|
||||||
|
case errdefs.IsPermissionDenied(err):
|
||||||
|
return http.StatusForbidden
|
||||||
|
case errdefs.IsResourceExhausted(err):
|
||||||
|
return http.StatusTooManyRequests
|
||||||
|
case errdefs.IsInternal(err):
|
||||||
|
return http.StatusInternalServerError
|
||||||
|
case errdefs.IsNotImplemented(err):
|
||||||
|
return http.StatusNotImplemented
|
||||||
|
case errdefs.IsUnavailable(err):
|
||||||
|
return http.StatusServiceUnavailable
|
||||||
|
case errdefs.IsUnknown(err):
|
||||||
|
var unexpected cause.ErrUnexpectedStatus
|
||||||
|
if errors.As(err, &unexpected) && unexpected.Status >= 200 && unexpected.Status < 600 {
|
||||||
|
return unexpected.Status
|
||||||
|
}
|
||||||
|
return http.StatusInternalServerError
|
||||||
|
default:
|
||||||
|
return http.StatusInternalServerError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToNative returns the error best matching the HTTP status code
|
||||||
|
func ToNative(statusCode int) error {
|
||||||
|
switch statusCode {
|
||||||
|
case http.StatusNotFound:
|
||||||
|
return errdefs.ErrNotFound
|
||||||
|
case http.StatusBadRequest:
|
||||||
|
return errdefs.ErrInvalidArgument
|
||||||
|
case http.StatusConflict:
|
||||||
|
return errdefs.ErrConflict
|
||||||
|
case http.StatusPreconditionFailed:
|
||||||
|
return errdefs.ErrFailedPrecondition
|
||||||
|
case http.StatusUnauthorized:
|
||||||
|
return errdefs.ErrUnauthenticated
|
||||||
|
case http.StatusForbidden:
|
||||||
|
return errdefs.ErrPermissionDenied
|
||||||
|
case http.StatusNotModified:
|
||||||
|
return errdefs.ErrNotModified
|
||||||
|
case http.StatusTooManyRequests:
|
||||||
|
return errdefs.ErrResourceExhausted
|
||||||
|
case http.StatusInternalServerError:
|
||||||
|
return errdefs.ErrInternal
|
||||||
|
case http.StatusNotImplemented:
|
||||||
|
return errdefs.ErrNotImplemented
|
||||||
|
case http.StatusServiceUnavailable:
|
||||||
|
return errdefs.ErrUnavailable
|
||||||
|
default:
|
||||||
|
return cause.ErrUnexpectedStatus{Status: statusCode}
|
||||||
|
}
|
||||||
|
}
|
||||||
33
vendor/github.com/containerd/errdefs/pkg/internal/cause/cause.go
generated
vendored
Normal file
33
vendor/github.com/containerd/errdefs/pkg/internal/cause/cause.go
generated
vendored
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
/*
|
||||||
|
Copyright The containerd Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Package cause is used to define root causes for errors
|
||||||
|
// common to errors packages like grpc and http.
|
||||||
|
package cause
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type ErrUnexpectedStatus struct {
|
||||||
|
Status int
|
||||||
|
}
|
||||||
|
|
||||||
|
const UnexpectedStatusPrefix = "unexpected status "
|
||||||
|
|
||||||
|
func (e ErrUnexpectedStatus) Error() string {
|
||||||
|
return fmt.Sprintf("%s%d", UnexpectedStatusPrefix, e.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ErrUnexpectedStatus) Unknown() {}
|
||||||
147
vendor/github.com/containerd/errdefs/resolve.go
generated
vendored
Normal file
147
vendor/github.com/containerd/errdefs/resolve.go
generated
vendored
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
/*
|
||||||
|
Copyright The containerd Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package errdefs
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// Resolve returns the first error found in the error chain which matches an
|
||||||
|
// error defined in this package or context error. A raw, unwrapped error is
|
||||||
|
// returned or ErrUnknown if no matching error is found.
|
||||||
|
//
|
||||||
|
// This is useful for determining a response code based on the outermost wrapped
|
||||||
|
// error rather than the original cause. For example, a not found error deep
|
||||||
|
// in the code may be wrapped as an invalid argument. When determining status
|
||||||
|
// code from Is* functions, the depth or ordering of the error is not
|
||||||
|
// considered.
|
||||||
|
//
|
||||||
|
// The search order is depth first, a wrapped error returned from any part of
|
||||||
|
// the chain from `Unwrap() error` will be returned before any joined errors
|
||||||
|
// as returned by `Unwrap() []error`.
|
||||||
|
func Resolve(err error) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
err = firstError(err)
|
||||||
|
if err == nil {
|
||||||
|
err = ErrUnknown
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstError(err error) error {
|
||||||
|
for {
|
||||||
|
switch err {
|
||||||
|
case ErrUnknown,
|
||||||
|
ErrInvalidArgument,
|
||||||
|
ErrNotFound,
|
||||||
|
ErrAlreadyExists,
|
||||||
|
ErrPermissionDenied,
|
||||||
|
ErrResourceExhausted,
|
||||||
|
ErrFailedPrecondition,
|
||||||
|
ErrConflict,
|
||||||
|
ErrNotModified,
|
||||||
|
ErrAborted,
|
||||||
|
ErrOutOfRange,
|
||||||
|
ErrNotImplemented,
|
||||||
|
ErrInternal,
|
||||||
|
ErrUnavailable,
|
||||||
|
ErrDataLoss,
|
||||||
|
ErrUnauthenticated,
|
||||||
|
context.DeadlineExceeded,
|
||||||
|
context.Canceled:
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
switch e := err.(type) {
|
||||||
|
case customMessage:
|
||||||
|
err = e.err
|
||||||
|
case unknown:
|
||||||
|
return ErrUnknown
|
||||||
|
case invalidParameter:
|
||||||
|
return ErrInvalidArgument
|
||||||
|
case notFound:
|
||||||
|
return ErrNotFound
|
||||||
|
case alreadyExists:
|
||||||
|
return ErrAlreadyExists
|
||||||
|
case forbidden:
|
||||||
|
return ErrPermissionDenied
|
||||||
|
case resourceExhausted:
|
||||||
|
return ErrResourceExhausted
|
||||||
|
case failedPrecondition:
|
||||||
|
return ErrFailedPrecondition
|
||||||
|
case conflict:
|
||||||
|
return ErrConflict
|
||||||
|
case notModified:
|
||||||
|
return ErrNotModified
|
||||||
|
case aborted:
|
||||||
|
return ErrAborted
|
||||||
|
case errOutOfRange:
|
||||||
|
return ErrOutOfRange
|
||||||
|
case notImplemented:
|
||||||
|
return ErrNotImplemented
|
||||||
|
case system:
|
||||||
|
return ErrInternal
|
||||||
|
case unavailable:
|
||||||
|
return ErrUnavailable
|
||||||
|
case dataLoss:
|
||||||
|
return ErrDataLoss
|
||||||
|
case unauthorized:
|
||||||
|
return ErrUnauthenticated
|
||||||
|
case deadlineExceeded:
|
||||||
|
return context.DeadlineExceeded
|
||||||
|
case cancelled:
|
||||||
|
return context.Canceled
|
||||||
|
case interface{ Unwrap() error }:
|
||||||
|
err = e.Unwrap()
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case interface{ Unwrap() []error }:
|
||||||
|
for _, ue := range e.Unwrap() {
|
||||||
|
if fe := firstError(ue); fe != nil {
|
||||||
|
return fe
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case interface{ Is(error) bool }:
|
||||||
|
for _, target := range []error{ErrUnknown,
|
||||||
|
ErrInvalidArgument,
|
||||||
|
ErrNotFound,
|
||||||
|
ErrAlreadyExists,
|
||||||
|
ErrPermissionDenied,
|
||||||
|
ErrResourceExhausted,
|
||||||
|
ErrFailedPrecondition,
|
||||||
|
ErrConflict,
|
||||||
|
ErrNotModified,
|
||||||
|
ErrAborted,
|
||||||
|
ErrOutOfRange,
|
||||||
|
ErrNotImplemented,
|
||||||
|
ErrInternal,
|
||||||
|
ErrUnavailable,
|
||||||
|
ErrDataLoss,
|
||||||
|
ErrUnauthenticated,
|
||||||
|
context.DeadlineExceeded,
|
||||||
|
context.Canceled} {
|
||||||
|
if e.Is(target) {
|
||||||
|
return target
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
1
vendor/github.com/distribution/reference/.gitattributes
generated
vendored
Normal file
1
vendor/github.com/distribution/reference/.gitattributes
generated
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
*.go text eol=lf
|
||||||
2
vendor/github.com/distribution/reference/.gitignore
generated
vendored
Normal file
2
vendor/github.com/distribution/reference/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
# Cover profiles
|
||||||
|
*.out
|
||||||
18
vendor/github.com/distribution/reference/.golangci.yml
generated
vendored
Normal file
18
vendor/github.com/distribution/reference/.golangci.yml
generated
vendored
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
linters:
|
||||||
|
enable:
|
||||||
|
- bodyclose
|
||||||
|
- dupword # Checks for duplicate words in the source code
|
||||||
|
- gofmt
|
||||||
|
- goimports
|
||||||
|
- ineffassign
|
||||||
|
- misspell
|
||||||
|
- revive
|
||||||
|
- staticcheck
|
||||||
|
- unconvert
|
||||||
|
- unused
|
||||||
|
- vet
|
||||||
|
disable:
|
||||||
|
- errcheck
|
||||||
|
|
||||||
|
run:
|
||||||
|
deadline: 2m
|
||||||
5
vendor/github.com/distribution/reference/CODE-OF-CONDUCT.md
generated
vendored
Normal file
5
vendor/github.com/distribution/reference/CODE-OF-CONDUCT.md
generated
vendored
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# Code of Conduct
|
||||||
|
|
||||||
|
We follow the [CNCF Code of Conduct](https://github.com/cncf/foundation/blob/main/code-of-conduct.md).
|
||||||
|
|
||||||
|
Please contact the [CNCF Code of Conduct Committee](mailto:conduct@cncf.io) in order to report violations of the Code of Conduct.
|
||||||
114
vendor/github.com/distribution/reference/CONTRIBUTING.md
generated
vendored
Normal file
114
vendor/github.com/distribution/reference/CONTRIBUTING.md
generated
vendored
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
# Contributing to the reference library
|
||||||
|
|
||||||
|
## Community help
|
||||||
|
|
||||||
|
If you need help, please ask in the [#distribution](https://cloud-native.slack.com/archives/C01GVR8SY4R) channel on CNCF community slack.
|
||||||
|
[Click here for an invite to the CNCF community slack](https://slack.cncf.io/)
|
||||||
|
|
||||||
|
## Reporting security issues
|
||||||
|
|
||||||
|
The maintainers take security seriously. If you discover a security
|
||||||
|
issue, please bring it to their attention right away!
|
||||||
|
|
||||||
|
Please **DO NOT** file a public issue, instead send your report privately to
|
||||||
|
[cncf-distribution-security@lists.cncf.io](mailto:cncf-distribution-security@lists.cncf.io).
|
||||||
|
|
||||||
|
## Reporting an issue properly
|
||||||
|
|
||||||
|
By following these simple rules you will get better and faster feedback on your issue.
|
||||||
|
|
||||||
|
- search the bugtracker for an already reported issue
|
||||||
|
|
||||||
|
### If you found an issue that describes your problem:
|
||||||
|
|
||||||
|
- please read other user comments first, and confirm this is the same issue: a given error condition might be indicative of different problems - you may also find a workaround in the comments
|
||||||
|
- please refrain from adding "same thing here" or "+1" comments
|
||||||
|
- you don't need to comment on an issue to get notified of updates: just hit the "subscribe" button
|
||||||
|
- comment if you have some new, technical and relevant information to add to the case
|
||||||
|
- __DO NOT__ comment on closed issues or merged PRs. If you think you have a related problem, open up a new issue and reference the PR or issue.
|
||||||
|
|
||||||
|
### If you have not found an existing issue that describes your problem:
|
||||||
|
|
||||||
|
1. create a new issue, with a succinct title that describes your issue:
|
||||||
|
- bad title: "It doesn't work with my docker"
|
||||||
|
- good title: "Private registry push fail: 400 error with E_INVALID_DIGEST"
|
||||||
|
2. copy the output of (or similar for other container tools):
|
||||||
|
- `docker version`
|
||||||
|
- `docker info`
|
||||||
|
- `docker exec <registry-container> registry --version`
|
||||||
|
3. copy the command line you used to launch your Registry
|
||||||
|
4. restart your docker daemon in debug mode (add `-D` to the daemon launch arguments)
|
||||||
|
5. reproduce your problem and get your docker daemon logs showing the error
|
||||||
|
6. if relevant, copy your registry logs that show the error
|
||||||
|
7. provide any relevant detail about your specific Registry configuration (e.g., storage backend used)
|
||||||
|
8. indicate if you are using an enterprise proxy, Nginx, or anything else between you and your Registry
|
||||||
|
|
||||||
|
## Contributing Code
|
||||||
|
|
||||||
|
Contributions should be made via pull requests. Pull requests will be reviewed
|
||||||
|
by one or more maintainers or reviewers and merged when acceptable.
|
||||||
|
|
||||||
|
You should follow the basic GitHub workflow:
|
||||||
|
|
||||||
|
1. Use your own [fork](https://help.github.com/en/articles/about-forks)
|
||||||
|
2. Create your [change](https://github.com/containerd/project/blob/master/CONTRIBUTING.md#successful-changes)
|
||||||
|
3. Test your code
|
||||||
|
4. [Commit](https://github.com/containerd/project/blob/master/CONTRIBUTING.md#commit-messages) your work, always [sign your commits](https://github.com/containerd/project/blob/master/CONTRIBUTING.md#commit-messages)
|
||||||
|
5. Push your change to your fork and create a [Pull Request](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/creating-a-pull-request-from-a-fork)
|
||||||
|
|
||||||
|
Refer to [containerd's contribution guide](https://github.com/containerd/project/blob/master/CONTRIBUTING.md#successful-changes)
|
||||||
|
for tips on creating a successful contribution.
|
||||||
|
|
||||||
|
## Sign your work
|
||||||
|
|
||||||
|
The sign-off is a simple line at the end of the explanation for the patch. Your
|
||||||
|
signature certifies that you wrote the patch or otherwise have the right to pass
|
||||||
|
it on as an open-source patch. The rules are pretty simple: if you can certify
|
||||||
|
the below (from [developercertificate.org](http://developercertificate.org/)):
|
||||||
|
|
||||||
|
```
|
||||||
|
Developer Certificate of Origin
|
||||||
|
Version 1.1
|
||||||
|
|
||||||
|
Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
|
||||||
|
660 York Street, Suite 102,
|
||||||
|
San Francisco, CA 94110 USA
|
||||||
|
|
||||||
|
Everyone is permitted to copy and distribute verbatim copies of this
|
||||||
|
license document, but changing it is not allowed.
|
||||||
|
|
||||||
|
Developer's Certificate of Origin 1.1
|
||||||
|
|
||||||
|
By making a contribution to this project, I certify that:
|
||||||
|
|
||||||
|
(a) The contribution was created in whole or in part by me and I
|
||||||
|
have the right to submit it under the open source license
|
||||||
|
indicated in the file; or
|
||||||
|
|
||||||
|
(b) The contribution is based upon previous work that, to the best
|
||||||
|
of my knowledge, is covered under an appropriate open source
|
||||||
|
license and I have the right under that license to submit that
|
||||||
|
work with modifications, whether created in whole or in part
|
||||||
|
by me, under the same open source license (unless I am
|
||||||
|
permitted to submit under a different license), as indicated
|
||||||
|
in the file; or
|
||||||
|
|
||||||
|
(c) The contribution was provided directly to me by some other
|
||||||
|
person who certified (a), (b) or (c) and I have not modified
|
||||||
|
it.
|
||||||
|
|
||||||
|
(d) I understand and agree that this project and the contribution
|
||||||
|
are public and that a record of the contribution (including all
|
||||||
|
personal information I submit with it, including my sign-off) is
|
||||||
|
maintained indefinitely and may be redistributed consistent with
|
||||||
|
this project or the open source license(s) involved.
|
||||||
|
```
|
||||||
|
|
||||||
|
Then you just add a line to every git commit message:
|
||||||
|
|
||||||
|
Signed-off-by: Joe Smith <joe.smith@email.com>
|
||||||
|
|
||||||
|
Use your real name (sorry, no pseudonyms or anonymous contributions.)
|
||||||
|
|
||||||
|
If you set your `user.name` and `user.email` git configs, you can sign your
|
||||||
|
commit automatically with `git commit -s`.
|
||||||
144
vendor/github.com/distribution/reference/GOVERNANCE.md
generated
vendored
Normal file
144
vendor/github.com/distribution/reference/GOVERNANCE.md
generated
vendored
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
# distribution/reference Project Governance
|
||||||
|
|
||||||
|
Distribution [Code of Conduct](./CODE-OF-CONDUCT.md) can be found here.
|
||||||
|
|
||||||
|
For specific guidance on practical contribution steps please
|
||||||
|
see our [CONTRIBUTING.md](./CONTRIBUTING.md) guide.
|
||||||
|
|
||||||
|
## Maintainership
|
||||||
|
|
||||||
|
There are different types of maintainers, with different responsibilities, but
|
||||||
|
all maintainers have 3 things in common:
|
||||||
|
|
||||||
|
1) They share responsibility in the project's success.
|
||||||
|
2) They have made a long-term, recurring time investment to improve the project.
|
||||||
|
3) They spend that time doing whatever needs to be done, not necessarily what
|
||||||
|
is the most interesting or fun.
|
||||||
|
|
||||||
|
Maintainers are often under-appreciated, because their work is harder to appreciate.
|
||||||
|
It's easy to appreciate a really cool and technically advanced feature. It's harder
|
||||||
|
to appreciate the absence of bugs, the slow but steady improvement in stability,
|
||||||
|
or the reliability of a release process. But those things distinguish a good
|
||||||
|
project from a great one.
|
||||||
|
|
||||||
|
## Reviewers
|
||||||
|
|
||||||
|
A reviewer is a core role within the project.
|
||||||
|
They share in reviewing issues and pull requests and their LGTM counts towards the
|
||||||
|
required LGTM count to merge a code change into the project.
|
||||||
|
|
||||||
|
Reviewers are part of the organization but do not have write access.
|
||||||
|
Becoming a reviewer is a core aspect in the journey to becoming a maintainer.
|
||||||
|
|
||||||
|
## Adding maintainers
|
||||||
|
|
||||||
|
Maintainers are first and foremost contributors that have shown they are
|
||||||
|
committed to the long term success of a project. Contributors wanting to become
|
||||||
|
maintainers are expected to be deeply involved in contributing code, pull
|
||||||
|
request review, and triage of issues in the project for more than three months.
|
||||||
|
|
||||||
|
Just contributing does not make you a maintainer, it is about building trust
|
||||||
|
with the current maintainers of the project and being a person that they can
|
||||||
|
depend on and trust to make decisions in the best interest of the project.
|
||||||
|
|
||||||
|
Periodically, the existing maintainers curate a list of contributors that have
|
||||||
|
shown regular activity on the project over the prior months. From this list,
|
||||||
|
maintainer candidates are selected and proposed in a pull request or a
|
||||||
|
maintainers communication channel.
|
||||||
|
|
||||||
|
After a candidate has been announced to the maintainers, the existing
|
||||||
|
maintainers are given five business days to discuss the candidate, raise
|
||||||
|
objections and cast their vote. Votes may take place on the communication
|
||||||
|
channel or via pull request comment. Candidates must be approved by at least 66%
|
||||||
|
of the current maintainers by adding their vote on the mailing list. The
|
||||||
|
reviewer role has the same process but only requires 33% of current maintainers.
|
||||||
|
Only maintainers of the repository that the candidate is proposed for are
|
||||||
|
allowed to vote.
|
||||||
|
|
||||||
|
If a candidate is approved, a maintainer will contact the candidate to invite
|
||||||
|
the candidate to open a pull request that adds the contributor to the
|
||||||
|
MAINTAINERS file. The voting process may take place inside a pull request if a
|
||||||
|
maintainer has already discussed the candidacy with the candidate and a
|
||||||
|
maintainer is willing to be a sponsor by opening the pull request. The candidate
|
||||||
|
becomes a maintainer once the pull request is merged.
|
||||||
|
|
||||||
|
## Stepping down policy
|
||||||
|
|
||||||
|
Life priorities, interests, and passions can change. If you're a maintainer but
|
||||||
|
feel you must remove yourself from the list, inform other maintainers that you
|
||||||
|
intend to step down, and if possible, help find someone to pick up your work.
|
||||||
|
At the very least, ensure your work can be continued where you left off.
|
||||||
|
|
||||||
|
After you've informed other maintainers, create a pull request to remove
|
||||||
|
yourself from the MAINTAINERS file.
|
||||||
|
|
||||||
|
## Removal of inactive maintainers
|
||||||
|
|
||||||
|
Similar to the procedure for adding new maintainers, existing maintainers can
|
||||||
|
be removed from the list if they do not show significant activity on the
|
||||||
|
project. Periodically, the maintainers review the list of maintainers and their
|
||||||
|
activity over the last three months.
|
||||||
|
|
||||||
|
If a maintainer has shown insufficient activity over this period, a neutral
|
||||||
|
person will contact the maintainer to ask if they want to continue being
|
||||||
|
a maintainer. If the maintainer decides to step down as a maintainer, they
|
||||||
|
open a pull request to be removed from the MAINTAINERS file.
|
||||||
|
|
||||||
|
If the maintainer wants to remain a maintainer, but is unable to perform the
|
||||||
|
required duties they can be removed with a vote of at least 66% of the current
|
||||||
|
maintainers. In this case, maintainers should first propose the change to
|
||||||
|
maintainers via the maintainers communication channel, then open a pull request
|
||||||
|
for voting. The voting period is five business days. The voting pull request
|
||||||
|
should not come as a surpise to any maintainer and any discussion related to
|
||||||
|
performance must not be discussed on the pull request.
|
||||||
|
|
||||||
|
## How are decisions made?
|
||||||
|
|
||||||
|
Docker distribution is an open-source project with an open design philosophy.
|
||||||
|
This means that the repository is the source of truth for EVERY aspect of the
|
||||||
|
project, including its philosophy, design, road map, and APIs. *If it's part of
|
||||||
|
the project, it's in the repo. If it's in the repo, it's part of the project.*
|
||||||
|
|
||||||
|
As a result, all decisions can be expressed as changes to the repository. An
|
||||||
|
implementation change is a change to the source code. An API change is a change
|
||||||
|
to the API specification. A philosophy change is a change to the philosophy
|
||||||
|
manifesto, and so on.
|
||||||
|
|
||||||
|
All decisions affecting distribution, big and small, follow the same 3 steps:
|
||||||
|
|
||||||
|
* Step 1: Open a pull request. Anyone can do this.
|
||||||
|
|
||||||
|
* Step 2: Discuss the pull request. Anyone can do this.
|
||||||
|
|
||||||
|
* Step 3: Merge or refuse the pull request. Who does this depends on the nature
|
||||||
|
of the pull request and which areas of the project it affects.
|
||||||
|
|
||||||
|
## Helping contributors with the DCO
|
||||||
|
|
||||||
|
The [DCO or `Sign your work`](./CONTRIBUTING.md#sign-your-work)
|
||||||
|
requirement is not intended as a roadblock or speed bump.
|
||||||
|
|
||||||
|
Some contributors are not as familiar with `git`, or have used a web
|
||||||
|
based editor, and thus asking them to `git commit --amend -s` is not the best
|
||||||
|
way forward.
|
||||||
|
|
||||||
|
In this case, maintainers can update the commits based on clause (c) of the DCO.
|
||||||
|
The most trivial way for a contributor to allow the maintainer to do this, is to
|
||||||
|
add a DCO signature in a pull requests's comment, or a maintainer can simply
|
||||||
|
note that the change is sufficiently trivial that it does not substantially
|
||||||
|
change the existing contribution - i.e., a spelling change.
|
||||||
|
|
||||||
|
When you add someone's DCO, please also add your own to keep a log.
|
||||||
|
|
||||||
|
## I'm a maintainer. Should I make pull requests too?
|
||||||
|
|
||||||
|
Yes. Nobody should ever push to master directly. All changes should be
|
||||||
|
made through a pull request.
|
||||||
|
|
||||||
|
## Conflict Resolution
|
||||||
|
|
||||||
|
If you have a technical dispute that you feel has reached an impasse with a
|
||||||
|
subset of the community, any contributor may open an issue, specifically
|
||||||
|
calling for a resolution vote of the current core maintainers to resolve the
|
||||||
|
dispute. The same voting quorums required (2/3) for adding and removing
|
||||||
|
maintainers will apply to conflict resolution.
|
||||||
202
vendor/github.com/distribution/reference/LICENSE
generated
vendored
Normal file
202
vendor/github.com/distribution/reference/LICENSE
generated
vendored
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "{}"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright {yyyy} {name of copyright owner}
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
|
||||||
26
vendor/github.com/distribution/reference/MAINTAINERS
generated
vendored
Normal file
26
vendor/github.com/distribution/reference/MAINTAINERS
generated
vendored
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# Distribution project maintainers & reviewers
|
||||||
|
#
|
||||||
|
# See GOVERNANCE.md for maintainer versus reviewer roles
|
||||||
|
#
|
||||||
|
# MAINTAINERS (cncf-distribution-maintainers@lists.cncf.io)
|
||||||
|
# GitHub ID, Name, Email address
|
||||||
|
"chrispat","Chris Patterson","chrispat@github.com"
|
||||||
|
"clarkbw","Bryan Clark","clarkbw@github.com"
|
||||||
|
"corhere","Cory Snider","csnider@mirantis.com"
|
||||||
|
"deleteriousEffect","Hayley Swimelar","hswimelar@gitlab.com"
|
||||||
|
"heww","He Weiwei","hweiwei@vmware.com"
|
||||||
|
"joaodrp","João Pereira","jpereira@gitlab.com"
|
||||||
|
"justincormack","Justin Cormack","justin.cormack@docker.com"
|
||||||
|
"squizzi","Kyle Squizzato","ksquizzato@mirantis.com"
|
||||||
|
"milosgajdos","Milos Gajdos","milosthegajdos@gmail.com"
|
||||||
|
"sargun","Sargun Dhillon","sargun@sargun.me"
|
||||||
|
"wy65701436","Wang Yan","wangyan@vmware.com"
|
||||||
|
"stevelasker","Steve Lasker","steve.lasker@microsoft.com"
|
||||||
|
#
|
||||||
|
# REVIEWERS
|
||||||
|
# GitHub ID, Name, Email address
|
||||||
|
"dmcgowan","Derek McGowan","derek@mcgstyle.net"
|
||||||
|
"stevvooe","Stephen Day","stevvooe@gmail.com"
|
||||||
|
"thajeztah","Sebastiaan van Stijn","github@gone.nl"
|
||||||
|
"DavidSpek", "David van der Spek", "vanderspek.david@gmail.com"
|
||||||
|
"Jamstah", "James Hewitt", "james.hewitt@gmail.com"
|
||||||
25
vendor/github.com/distribution/reference/Makefile
generated
vendored
Normal file
25
vendor/github.com/distribution/reference/Makefile
generated
vendored
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
# Project packages.
|
||||||
|
PACKAGES=$(shell go list ./...)
|
||||||
|
|
||||||
|
# Flags passed to `go test`
|
||||||
|
BUILDFLAGS ?=
|
||||||
|
TESTFLAGS ?=
|
||||||
|
|
||||||
|
.PHONY: all build test coverage
|
||||||
|
.DEFAULT: all
|
||||||
|
|
||||||
|
all: build
|
||||||
|
|
||||||
|
build: ## no binaries to build, so just check compilation suceeds
|
||||||
|
go build ${BUILDFLAGS} ./...
|
||||||
|
|
||||||
|
test: ## run tests
|
||||||
|
go test ${TESTFLAGS} ./...
|
||||||
|
|
||||||
|
coverage: ## generate coverprofiles from the unit tests
|
||||||
|
rm -f coverage.txt
|
||||||
|
go test ${TESTFLAGS} -cover -coverprofile=cover.out ./...
|
||||||
|
|
||||||
|
.PHONY: help
|
||||||
|
help:
|
||||||
|
@awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[a-zA-Z_\/%-]+:.*?##/ { printf " \033[36m%-27s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST)
|
||||||
30
vendor/github.com/distribution/reference/README.md
generated
vendored
Normal file
30
vendor/github.com/distribution/reference/README.md
generated
vendored
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# Distribution reference
|
||||||
|
|
||||||
|
Go library to handle references to container images.
|
||||||
|
|
||||||
|
<img src="/distribution-logo.svg" width="200px" />
|
||||||
|
|
||||||
|
[](https://github.com/distribution/reference/actions?query=workflow%3ACI)
|
||||||
|
[](https://pkg.go.dev/github.com/distribution/reference)
|
||||||
|
[](LICENSE)
|
||||||
|
[](https://codecov.io/gh/distribution/reference)
|
||||||
|
[](https://app.fossa.com/projects/custom%2B162%2Fgithub.com%2Fdistribution%2Freference?ref=badge_shield)
|
||||||
|
|
||||||
|
This repository contains a library for handling references to container images held in container registries. Please see [godoc](https://pkg.go.dev/github.com/distribution/reference) for details.
|
||||||
|
|
||||||
|
## Contribution
|
||||||
|
|
||||||
|
Please see [CONTRIBUTING.md](CONTRIBUTING.md) for details on how to contribute
|
||||||
|
issues, fixes, and patches to this project.
|
||||||
|
|
||||||
|
## Communication
|
||||||
|
|
||||||
|
For async communication and long running discussions please use issues and pull requests on the github repo.
|
||||||
|
This will be the best place to discuss design and implementation.
|
||||||
|
|
||||||
|
For sync communication we have a #distribution channel in the [CNCF Slack](https://slack.cncf.io/)
|
||||||
|
that everyone is welcome to join and chat about development.
|
||||||
|
|
||||||
|
## Licenses
|
||||||
|
|
||||||
|
The distribution codebase is released under the [Apache 2.0 license](LICENSE).
|
||||||
7
vendor/github.com/distribution/reference/SECURITY.md
generated
vendored
Normal file
7
vendor/github.com/distribution/reference/SECURITY.md
generated
vendored
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# Security Policy
|
||||||
|
|
||||||
|
## Reporting a Vulnerability
|
||||||
|
|
||||||
|
The maintainers take security seriously. If you discover a security issue, please bring it to their attention right away!
|
||||||
|
|
||||||
|
Please DO NOT file a public issue, instead send your report privately to cncf-distribution-security@lists.cncf.io.
|
||||||
1
vendor/github.com/distribution/reference/distribution-logo.svg
generated
vendored
Normal file
1
vendor/github.com/distribution/reference/distribution-logo.svg
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 8.6 KiB |
42
vendor/github.com/distribution/reference/helpers.go
generated
vendored
Normal file
42
vendor/github.com/distribution/reference/helpers.go
generated
vendored
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package reference
|
||||||
|
|
||||||
|
import "path"
|
||||||
|
|
||||||
|
// IsNameOnly returns true if reference only contains a repo name.
|
||||||
|
func IsNameOnly(ref Named) bool {
|
||||||
|
if _, ok := ref.(NamedTagged); ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if _, ok := ref.(Canonical); ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// FamiliarName returns the familiar name string
|
||||||
|
// for the given named, familiarizing if needed.
|
||||||
|
func FamiliarName(ref Named) string {
|
||||||
|
if nn, ok := ref.(normalizedNamed); ok {
|
||||||
|
return nn.Familiar().Name()
|
||||||
|
}
|
||||||
|
return ref.Name()
|
||||||
|
}
|
||||||
|
|
||||||
|
// FamiliarString returns the familiar string representation
|
||||||
|
// for the given reference, familiarizing if needed.
|
||||||
|
func FamiliarString(ref Reference) string {
|
||||||
|
if nn, ok := ref.(normalizedNamed); ok {
|
||||||
|
return nn.Familiar().String()
|
||||||
|
}
|
||||||
|
return ref.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// FamiliarMatch reports whether ref matches the specified pattern.
|
||||||
|
// See [path.Match] for supported patterns.
|
||||||
|
func FamiliarMatch(pattern string, ref Reference) (bool, error) {
|
||||||
|
matched, err := path.Match(pattern, FamiliarString(ref))
|
||||||
|
if namedRef, isNamed := ref.(Named); isNamed && !matched {
|
||||||
|
matched, _ = path.Match(pattern, FamiliarName(namedRef))
|
||||||
|
}
|
||||||
|
return matched, err
|
||||||
|
}
|
||||||
255
vendor/github.com/distribution/reference/normalize.go
generated
vendored
Normal file
255
vendor/github.com/distribution/reference/normalize.go
generated
vendored
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
package reference
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/opencontainers/go-digest"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// legacyDefaultDomain is the legacy domain for Docker Hub (which was
|
||||||
|
// originally named "the Docker Index"). This domain is still used for
|
||||||
|
// authentication and image search, which were part of the "v1" Docker
|
||||||
|
// registry specification.
|
||||||
|
//
|
||||||
|
// This domain will continue to be supported, but there are plans to consolidate
|
||||||
|
// legacy domains to new "canonical" domains. Once those domains are decided
|
||||||
|
// on, we must update the normalization functions, but preserve compatibility
|
||||||
|
// with existing installs, clients, and user configuration.
|
||||||
|
legacyDefaultDomain = "index.docker.io"
|
||||||
|
|
||||||
|
// defaultDomain is the default domain used for images on Docker Hub.
|
||||||
|
// It is used to normalize "familiar" names to canonical names, for example,
|
||||||
|
// to convert "ubuntu" to "docker.io/library/ubuntu:latest".
|
||||||
|
//
|
||||||
|
// Note that actual domain of Docker Hub's registry is registry-1.docker.io.
|
||||||
|
// This domain will continue to be supported, but there are plans to consolidate
|
||||||
|
// legacy domains to new "canonical" domains. Once those domains are decided
|
||||||
|
// on, we must update the normalization functions, but preserve compatibility
|
||||||
|
// with existing installs, clients, and user configuration.
|
||||||
|
defaultDomain = "docker.io"
|
||||||
|
|
||||||
|
// officialRepoPrefix is the namespace used for official images on Docker Hub.
|
||||||
|
// It is used to normalize "familiar" names to canonical names, for example,
|
||||||
|
// to convert "ubuntu" to "docker.io/library/ubuntu:latest".
|
||||||
|
officialRepoPrefix = "library/"
|
||||||
|
|
||||||
|
// defaultTag is the default tag if no tag is provided.
|
||||||
|
defaultTag = "latest"
|
||||||
|
)
|
||||||
|
|
||||||
|
// normalizedNamed represents a name which has been
|
||||||
|
// normalized and has a familiar form. A familiar name
|
||||||
|
// is what is used in Docker UI. An example normalized
|
||||||
|
// name is "docker.io/library/ubuntu" and corresponding
|
||||||
|
// familiar name of "ubuntu".
|
||||||
|
type normalizedNamed interface {
|
||||||
|
Named
|
||||||
|
Familiar() Named
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseNormalizedNamed parses a string into a named reference
|
||||||
|
// transforming a familiar name from Docker UI to a fully
|
||||||
|
// qualified reference. If the value may be an identifier
|
||||||
|
// use ParseAnyReference.
|
||||||
|
func ParseNormalizedNamed(s string) (Named, error) {
|
||||||
|
if ok := anchoredIdentifierRegexp.MatchString(s); ok {
|
||||||
|
return nil, fmt.Errorf("invalid repository name (%s), cannot specify 64-byte hexadecimal strings", s)
|
||||||
|
}
|
||||||
|
domain, remainder := splitDockerDomain(s)
|
||||||
|
var remote string
|
||||||
|
if tagSep := strings.IndexRune(remainder, ':'); tagSep > -1 {
|
||||||
|
remote = remainder[:tagSep]
|
||||||
|
} else {
|
||||||
|
remote = remainder
|
||||||
|
}
|
||||||
|
if strings.ToLower(remote) != remote {
|
||||||
|
return nil, fmt.Errorf("invalid reference format: repository name (%s) must be lowercase", remote)
|
||||||
|
}
|
||||||
|
|
||||||
|
ref, err := Parse(domain + "/" + remainder)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
named, isNamed := ref.(Named)
|
||||||
|
if !isNamed {
|
||||||
|
return nil, fmt.Errorf("reference %s has no name", ref.String())
|
||||||
|
}
|
||||||
|
return named, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// namedTaggedDigested is a reference that has both a tag and a digest.
|
||||||
|
type namedTaggedDigested interface {
|
||||||
|
NamedTagged
|
||||||
|
Digested
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseDockerRef normalizes the image reference following the docker convention,
|
||||||
|
// which allows for references to contain both a tag and a digest. It returns a
|
||||||
|
// reference that is either tagged or digested. For references containing both
|
||||||
|
// a tag and a digest, it returns a digested reference. For example, the following
|
||||||
|
// reference:
|
||||||
|
//
|
||||||
|
// docker.io/library/busybox:latest@sha256:7cc4b5aefd1d0cadf8d97d4350462ba51c694ebca145b08d7d41b41acc8db5aa
|
||||||
|
//
|
||||||
|
// Is returned as a digested reference (with the ":latest" tag removed):
|
||||||
|
//
|
||||||
|
// docker.io/library/busybox@sha256:7cc4b5aefd1d0cadf8d97d4350462ba51c694ebca145b08d7d41b41acc8db5aa
|
||||||
|
//
|
||||||
|
// References that are already "tagged" or "digested" are returned unmodified:
|
||||||
|
//
|
||||||
|
// // Already a digested reference
|
||||||
|
// docker.io/library/busybox@sha256:7cc4b5aefd1d0cadf8d97d4350462ba51c694ebca145b08d7d41b41acc8db5aa
|
||||||
|
//
|
||||||
|
// // Already a named reference
|
||||||
|
// docker.io/library/busybox:latest
|
||||||
|
func ParseDockerRef(ref string) (Named, error) {
|
||||||
|
named, err := ParseNormalizedNamed(ref)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if canonical, ok := named.(namedTaggedDigested); ok {
|
||||||
|
// The reference is both tagged and digested; only return digested.
|
||||||
|
newNamed, err := WithName(canonical.Name())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return WithDigest(newNamed, canonical.Digest())
|
||||||
|
}
|
||||||
|
return TagNameOnly(named), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitDockerDomain splits a repository name to domain and remote-name.
|
||||||
|
// If no valid domain is found, the default domain is used. Repository name
|
||||||
|
// needs to be already validated before.
|
||||||
|
func splitDockerDomain(name string) (domain, remoteName string) {
|
||||||
|
maybeDomain, maybeRemoteName, ok := strings.Cut(name, "/")
|
||||||
|
if !ok {
|
||||||
|
// Fast-path for single element ("familiar" names), such as "ubuntu"
|
||||||
|
// or "ubuntu:latest". Familiar names must be handled separately, to
|
||||||
|
// prevent them from being handled as "hostname:port".
|
||||||
|
//
|
||||||
|
// Canonicalize them as "docker.io/library/name[:tag]"
|
||||||
|
|
||||||
|
// FIXME(thaJeztah): account for bare "localhost" or "example.com" names, which SHOULD be considered a domain.
|
||||||
|
return defaultDomain, officialRepoPrefix + name
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case maybeDomain == localhost:
|
||||||
|
// localhost is a reserved namespace and always considered a domain.
|
||||||
|
domain, remoteName = maybeDomain, maybeRemoteName
|
||||||
|
case maybeDomain == legacyDefaultDomain:
|
||||||
|
// canonicalize the Docker Hub and legacy "Docker Index" domains.
|
||||||
|
domain, remoteName = defaultDomain, maybeRemoteName
|
||||||
|
case strings.ContainsAny(maybeDomain, ".:"):
|
||||||
|
// Likely a domain or IP-address:
|
||||||
|
//
|
||||||
|
// - contains a "." (e.g., "example.com" or "127.0.0.1")
|
||||||
|
// - contains a ":" (e.g., "example:5000", "::1", or "[::1]:5000")
|
||||||
|
domain, remoteName = maybeDomain, maybeRemoteName
|
||||||
|
case strings.ToLower(maybeDomain) != maybeDomain:
|
||||||
|
// Uppercase namespaces are not allowed, so if the first element
|
||||||
|
// is not lowercase, we assume it to be a domain-name.
|
||||||
|
domain, remoteName = maybeDomain, maybeRemoteName
|
||||||
|
default:
|
||||||
|
// None of the above: it's not a domain, so use the default, and
|
||||||
|
// use the name input the remote-name.
|
||||||
|
domain, remoteName = defaultDomain, name
|
||||||
|
}
|
||||||
|
|
||||||
|
if domain == defaultDomain && !strings.ContainsRune(remoteName, '/') {
|
||||||
|
// Canonicalize "familiar" names, but only on Docker Hub, not
|
||||||
|
// on other domains:
|
||||||
|
//
|
||||||
|
// "docker.io/ubuntu[:tag]" => "docker.io/library/ubuntu[:tag]"
|
||||||
|
remoteName = officialRepoPrefix + remoteName
|
||||||
|
}
|
||||||
|
|
||||||
|
return domain, remoteName
|
||||||
|
}
|
||||||
|
|
||||||
|
// familiarizeName returns a shortened version of the name familiar
|
||||||
|
// to the Docker UI. Familiar names have the default domain
|
||||||
|
// "docker.io" and "library/" repository prefix removed.
|
||||||
|
// For example, "docker.io/library/redis" will have the familiar
|
||||||
|
// name "redis" and "docker.io/dmcgowan/myapp" will be "dmcgowan/myapp".
|
||||||
|
// Returns a familiarized named only reference.
|
||||||
|
func familiarizeName(named namedRepository) repository {
|
||||||
|
repo := repository{
|
||||||
|
domain: named.Domain(),
|
||||||
|
path: named.Path(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if repo.domain == defaultDomain {
|
||||||
|
repo.domain = ""
|
||||||
|
// Handle official repositories which have the pattern "library/<official repo name>"
|
||||||
|
if strings.HasPrefix(repo.path, officialRepoPrefix) {
|
||||||
|
// TODO(thaJeztah): this check may be too strict, as it assumes the
|
||||||
|
// "library/" namespace does not have nested namespaces. While this
|
||||||
|
// is true (currently), technically it would be possible for Docker
|
||||||
|
// Hub to use those (e.g. "library/distros/ubuntu:latest").
|
||||||
|
// See https://github.com/distribution/distribution/pull/3769#issuecomment-1302031785.
|
||||||
|
if remainder := strings.TrimPrefix(repo.path, officialRepoPrefix); !strings.ContainsRune(remainder, '/') {
|
||||||
|
repo.path = remainder
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return repo
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r reference) Familiar() Named {
|
||||||
|
return reference{
|
||||||
|
namedRepository: familiarizeName(r.namedRepository),
|
||||||
|
tag: r.tag,
|
||||||
|
digest: r.digest,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r repository) Familiar() Named {
|
||||||
|
return familiarizeName(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t taggedReference) Familiar() Named {
|
||||||
|
return taggedReference{
|
||||||
|
namedRepository: familiarizeName(t.namedRepository),
|
||||||
|
tag: t.tag,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c canonicalReference) Familiar() Named {
|
||||||
|
return canonicalReference{
|
||||||
|
namedRepository: familiarizeName(c.namedRepository),
|
||||||
|
digest: c.digest,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TagNameOnly adds the default tag "latest" to a reference if it only has
|
||||||
|
// a repo name.
|
||||||
|
func TagNameOnly(ref Named) Named {
|
||||||
|
if IsNameOnly(ref) {
|
||||||
|
namedTagged, err := WithTag(ref, defaultTag)
|
||||||
|
if err != nil {
|
||||||
|
// Default tag must be valid, to create a NamedTagged
|
||||||
|
// type with non-validated input the WithTag function
|
||||||
|
// should be used instead
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return namedTagged
|
||||||
|
}
|
||||||
|
return ref
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseAnyReference parses a reference string as a possible identifier,
|
||||||
|
// full digest, or familiar name.
|
||||||
|
func ParseAnyReference(ref string) (Reference, error) {
|
||||||
|
if ok := anchoredIdentifierRegexp.MatchString(ref); ok {
|
||||||
|
return digestReference("sha256:" + ref), nil
|
||||||
|
}
|
||||||
|
if dgst, err := digest.Parse(ref); err == nil {
|
||||||
|
return digestReference(dgst), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ParseNormalizedNamed(ref)
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user