From 2fd9a969505b314e3cf8618fdae581cba984bf13 Mon Sep 17 00:00:00 2001 From: anthonyrawlins Date: Mon, 13 Oct 2025 17:04:00 +1100 Subject: [PATCH] Add Sequential Thinking compatibility server and JWKS support --- cmd/seqthink-wrapper/main.go | 2 +- deploy/seqthink/Dockerfile | 35 ++------ deploy/seqthink/entrypoint.sh | 4 +- deploy/seqthink/mcp_server.py | 160 ++++++++++++++++++++++++++++++++++ deploy/seqthink/mcp_stub.py | 70 --------------- pkg/seqthink/policy/jwt.go | 83 +++++++++++++----- 6 files changed, 232 insertions(+), 122 deletions(-) create mode 100644 deploy/seqthink/mcp_server.py delete mode 100644 deploy/seqthink/mcp_stub.py diff --git a/cmd/seqthink-wrapper/main.go b/cmd/seqthink-wrapper/main.go index 6d206f3..c616745 100644 --- a/cmd/seqthink-wrapper/main.go +++ b/cmd/seqthink-wrapper/main.go @@ -54,7 +54,7 @@ func main() { log.Info(). Str("port", cfg.Port). Str("mcp_url", cfg.MCPLocalURL). - Str("version", "0.1.0-beta1"). + Str("version", "0.1.0-beta2"). Msg("šŸš€ Starting Sequential Thinking Age Wrapper") // Create MCP client diff --git a/deploy/seqthink/Dockerfile b/deploy/seqthink/Dockerfile index f66c01f..9c76a35 100644 --- a/deploy/seqthink/Dockerfile +++ b/deploy/seqthink/Dockerfile @@ -1,28 +1,6 @@ # Sequential Thinking Age-Encrypted Wrapper -# Beat 1: Plaintext skeleton - encryption added in Beat 2 -# Stage 1: Build Go wrapper -FROM golang:1.23-alpine AS go-builder - -WORKDIR /build - -# Install build dependencies -RUN apk add --no-cache git make - -# Copy go mod files -COPY go.mod go.sum ./ -RUN go mod download - -# Copy source code -COPY . . - -# Build the wrapper binary -RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo \ - -ldflags '-w -s -extldflags "-static"' \ - -o seqthink-wrapper \ - ./cmd/seqthink-wrapper - -# Stage 2: Build Python MCP server +# Stage 1: Build Python MCP server FROM python:3.11-slim AS python-builder WORKDIR /mcp @@ -35,16 +13,17 @@ RUN pip install --no-cache-dir \ uvicorn[standard]==0.27.0 \ pydantic==2.5.3 -# Copy MCP server stub (to be replaced with real implementation) -COPY deploy/seqthink/mcp_stub.py /mcp/server.py +# Copy MCP compatibility server +COPY deploy/seqthink/mcp_server.py /mcp/server.py -# Stage 3: Runtime +# Stage 2: Runtime FROM debian:bookworm-slim # Install runtime dependencies RUN apt-get update && \ apt-get install -y --no-install-recommends \ ca-certificates \ + curl \ python3 \ python3-pip && \ apt-get clean && \ @@ -59,8 +38,8 @@ RUN pip3 install --no-cache-dir --break-system-packages \ # Create non-root user RUN useradd -r -u 1000 -m -s /bin/bash seqthink -# Copy binaries -COPY --from=go-builder /build/seqthink-wrapper /usr/local/bin/ +# Copy wrapper binary built on host (GOWORK=off GOOS=linux go build ...) +COPY deploy/seqthink/bin/seqthink-wrapper /usr/local/bin/seqthink-wrapper COPY --from=python-builder /mcp/server.py /opt/mcp/server.py # Copy entrypoint diff --git a/deploy/seqthink/entrypoint.sh b/deploy/seqthink/entrypoint.sh index 34a1840..1f65732 100644 --- a/deploy/seqthink/entrypoint.sh +++ b/deploy/seqthink/entrypoint.sh @@ -1,10 +1,10 @@ #!/bin/bash set -e -echo "šŸš€ Starting Sequential Thinking Age Wrapper (Beat 1)" +echo "šŸš€ Starting Sequential Thinking Age Wrapper" # Start MCP server on loopback -echo "šŸ“” Starting MCP server on 127.0.0.1:8000..." +echo "šŸ“” Starting Sequential Thinking MCP compatibility server on 127.0.0.1:8000..." python3 /opt/mcp/server.py & MCP_PID=$! diff --git a/deploy/seqthink/mcp_server.py b/deploy/seqthink/mcp_server.py new file mode 100644 index 0000000..812f4cc --- /dev/null +++ b/deploy/seqthink/mcp_server.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +"""Sequential Thinking MCP compatibility server (HTTP wrapper).""" + +from __future__ import annotations + +import json +import logging +import os +from typing import Any, Dict, List, Optional + +from fastapi import FastAPI, HTTPException +import uvicorn +from pydantic import BaseModel, Field, validator + +logging.basicConfig(level=logging.INFO, format="%(message)s") +logger = logging.getLogger("seqthink") + + +class ToolRequest(BaseModel): + tool: str + payload: Dict[str, Any] + + @validator("tool") + def validate_tool(cls, value: str) -> str: + allowed = { + "sequentialthinking", + "mcp__sequential-thinking__sequentialthinking", + } + if value not in allowed: + raise ValueError(f"Unknown tool '{value}'") + return value + + +class ToolResponse(BaseModel): + result: Optional[Dict[str, Any]] = None + error: Optional[str] = None + + +class ThoughtData(BaseModel): + thought: str + thoughtNumber: int = Field(..., ge=1) + totalThoughts: int = Field(..., ge=1) + nextThoughtNeeded: bool + isRevision: Optional[bool] = False + revisesThought: Optional[int] = Field(default=None, ge=1) + branchFromThought: Optional[int] = Field(default=None, ge=1) + branchId: Optional[str] = None + needsMoreThoughts: Optional[bool] = None + + @validator("totalThoughts") + def normalize_total(cls, value: int, values: Dict[str, Any]) -> int: + thought_number = values.get("thoughtNumber") + if thought_number is not None and value < thought_number: + return thought_number + return value + + +class SequentialThinkingEngine: + """Replicates the upstream sequential thinking MCP behaviour.""" + + def __init__(self) -> None: + self._thought_history: List[ThoughtData] = [] + self._branches: Dict[str, List[ThoughtData]] = {} + env = os.environ.get("DISABLE_THOUGHT_LOGGING", "") + self._disable_logging = env.lower() == "true" + + def _record_branch(self, data: ThoughtData) -> None: + if data.branchFromThought and data.branchId: + self._branches.setdefault(data.branchId, []).append(data) + + def _log_thought(self, data: ThoughtData) -> None: + if self._disable_logging: + return + + header = [] + if data.isRevision: + header.append("šŸ”„ Revision") + if data.revisesThought: + header.append(f"(revising thought {data.revisesThought})") + elif data.branchFromThought: + header.append("🌿 Branch") + header.append(f"(from thought {data.branchFromThought})") + if data.branchId: + header.append(f"[ID: {data.branchId}]") + else: + header.append("šŸ’­ Thought") + + header.append(f"{data.thoughtNumber}/{data.totalThoughts}") + header_line = " ".join(part for part in header if part) + + border_width = max(len(header_line), len(data.thought)) + 4 + border = "─" * border_width + message = ( + f"\nā”Œ{border}┐\n" + f"│ {header_line.ljust(border_width - 2)} │\n" + f"ā”œ{border}┤\n" + f"│ {data.thought.ljust(border_width - 2)} │\n" + f"ā””{border}ā”˜" + ) + logger.error(message) + + def process(self, payload: Dict[str, Any]) -> Dict[str, Any]: + try: + thought = ThoughtData(**payload) + except Exception as exc: # pylint: disable=broad-except + logger.exception("Invalid thought payload") + return { + "content": [ + { + "type": "text", + "text": json.dumps({"error": str(exc)}, indent=2), + } + ], + "isError": True, + } + + self._thought_history.append(thought) + self._record_branch(thought) + self._log_thought(thought) + + response_payload = { + "thoughtNumber": thought.thoughtNumber, + "totalThoughts": thought.totalThoughts, + "nextThoughtNeeded": thought.nextThoughtNeeded, + "branches": list(self._branches.keys()), + "thoughtHistoryLength": len(self._thought_history), + } + + return { + "content": [ + { + "type": "text", + "text": json.dumps(response_payload, indent=2), + } + ] + } + + +engine = SequentialThinkingEngine() +app = FastAPI(title="Sequential Thinking MCP Compatibility Server") + + +@app.get("/health") +def health() -> Dict[str, str]: + return {"status": "ok"} + + +@app.post("/mcp/tool") +def call_tool(request: ToolRequest) -> ToolResponse: + try: + result = engine.process(request.payload) + if result.get("isError"): + return ToolResponse(error=result["content"][0]["text"]) + return ToolResponse(result=result) + except Exception as exc: # pylint: disable=broad-except + raise HTTPException(status_code=400, detail=str(exc)) from exc + + +if __name__ == "__main__": + uvicorn.run(app, host="127.0.0.1", port=8000, log_level="info") diff --git a/deploy/seqthink/mcp_stub.py b/deploy/seqthink/mcp_stub.py deleted file mode 100644 index 931e9a8..0000000 --- a/deploy/seqthink/mcp_stub.py +++ /dev/null @@ -1,70 +0,0 @@ -#!/usr/bin/env python3 -""" -Sequential Thinking MCP Server Stub (Beat 1) - -This is a minimal implementation for testing the wrapper infrastructure. -In later beats, this will be replaced with the full Sequential Thinking MCP server. -""" - -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel -from typing import Dict, Any, Optional -import uvicorn - -app = FastAPI(title="Sequential Thinking MCP Server Stub") - - -class ToolRequest(BaseModel): - tool: str - payload: Dict[str, Any] - - -class ToolResponse(BaseModel): - result: Optional[Any] = None - error: Optional[str] = None - - -@app.get("/health") -async def health(): - """Health check endpoint""" - return {"status": "ok"} - - -@app.post("/mcp/tool") -async def call_tool(request: ToolRequest) -> ToolResponse: - """ - Tool call endpoint - stub implementation - - In Beat 1, this just echoes back the request to verify the wrapper works. - Later beats will implement the actual Sequential Thinking logic. - """ - if request.tool != "mcp__sequential-thinking__sequentialthinking": - return ToolResponse( - error=f"Unknown tool: {request.tool}" - ) - - # Stub response for Sequential Thinking tool - payload = request.payload - thought_number = payload.get("thoughtNumber", 1) - total_thoughts = payload.get("totalThoughts", 5) - thought = payload.get("thought", "") - next_thought_needed = payload.get("nextThoughtNeeded", True) - - return ToolResponse( - result={ - "thoughtNumber": thought_number, - "totalThoughts": total_thoughts, - "thought": thought, - "nextThoughtNeeded": next_thought_needed, - "message": "Beat 1 stub - Sequential Thinking not yet implemented" - } - ) - - -if __name__ == "__main__": - uvicorn.run( - app, - host="127.0.0.1", - port=8000, - log_level="info" - ) diff --git a/pkg/seqthink/policy/jwt.go b/pkg/seqthink/policy/jwt.go index 025c3eb..ac3a391 100644 --- a/pkg/seqthink/policy/jwt.go +++ b/pkg/seqthink/policy/jwt.go @@ -2,6 +2,7 @@ package policy import ( "context" + "crypto/ed25519" "crypto/rsa" "encoding/base64" "encoding/json" @@ -38,6 +39,8 @@ type JWK struct { Use string `json:"use"` N string `json:"n"` E string `json:"e"` + X string `json:"x"` + Crv string `json:"crv"` } // Validator validates JWT tokens @@ -45,7 +48,7 @@ type Validator struct { jwksURL string requiredScope string httpClient *http.Client - keys map[string]*rsa.PublicKey + keys map[string]interface{} keysMutex sync.RWMutex lastFetch time.Time cacheDuration time.Duration @@ -59,7 +62,7 @@ func NewValidator(jwksURL, requiredScope string) *Validator { httpClient: &http.Client{ Timeout: 10 * time.Second, }, - keys: make(map[string]*rsa.PublicKey), + keys: make(map[string]interface{}), cacheDuration: 1 * time.Hour, // Cache JWKS for 1 hour } } @@ -68,11 +71,6 @@ func NewValidator(jwksURL, requiredScope string) *Validator { func (v *Validator) ValidateToken(tokenString string) (*Claims, error) { // Parse token token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { - // Verify signing algorithm - if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) - } - // Get key ID from header kid, ok := token.Header["kid"].(string) if !ok { @@ -85,7 +83,22 @@ func (v *Validator) ValidateToken(tokenString string) (*Claims, error) { return nil, fmt.Errorf("get public key: %w", err) } - return publicKey, nil + switch token.Method.(type) { + case *jwt.SigningMethodRSA, *jwt.SigningMethodRSAPSS: + rsaKey, ok := publicKey.(*rsa.PublicKey) + if !ok { + return nil, fmt.Errorf("expected RSA public key for kid %s", kid) + } + return rsaKey, nil + case *jwt.SigningMethodEd25519: + edKey, ok := publicKey.(ed25519.PublicKey) + if !ok { + return nil, fmt.Errorf("expected Ed25519 public key for kid %s", kid) + } + return edKey, nil + default: + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } }) if err != nil { @@ -140,7 +153,7 @@ func (v *Validator) hasRequiredScope(claims *Claims) bool { } // getPublicKey retrieves a public key by kid, fetching JWKS if needed -func (v *Validator) getPublicKey(kid string) (*rsa.PublicKey, error) { +func (v *Validator) getPublicKey(kid string) (interface{}, error) { // Check if cache is expired v.keysMutex.RLock() cacheExpired := time.Since(v.lastFetch) > v.cacheDuration @@ -201,20 +214,30 @@ func (v *Validator) fetchJWKS() error { } // Parse and cache all keys - newKeys := make(map[string]*rsa.PublicKey) + newKeys := make(map[string]interface{}) for _, jwk := range jwks.Keys { - if jwk.Kty != "RSA" { - log.Warn().Str("kid", jwk.Kid).Str("kty", jwk.Kty).Msg("Skipping non-RSA key") - continue + switch jwk.Kty { + case "RSA": + publicKey, err := jwk.toRSAPublicKey() + if err != nil { + log.Error().Err(err).Str("kid", jwk.Kid).Msg("Failed to parse RSA JWK") + continue + } + newKeys[jwk.Kid] = publicKey + case "OKP": + if strings.EqualFold(jwk.Crv, "Ed25519") { + publicKey, err := jwk.toEd25519PublicKey() + if err != nil { + log.Error().Err(err).Str("kid", jwk.Kid).Msg("Failed to parse Ed25519 JWK") + continue + } + newKeys[jwk.Kid] = publicKey + } else { + log.Warn().Str("kid", jwk.Kid).Str("crv", jwk.Crv).Msg("Skipping unsupported OKP curve") + } + default: + log.Warn().Str("kid", jwk.Kid).Str("kty", jwk.Kty).Msg("Skipping unsupported key type") } - - publicKey, err := jwk.toRSAPublicKey() - if err != nil { - log.Error().Err(err).Str("kid", jwk.Kid).Msg("Failed to parse JWK") - continue - } - - newKeys[jwk.Kid] = publicKey } if len(newKeys) == 0 { @@ -261,6 +284,24 @@ func (jwk *JWK) toRSAPublicKey() (*rsa.PublicKey, error) { return publicKey, nil } +// toEd25519PublicKey converts a JWK to an Ed25519 public key +func (jwk *JWK) toEd25519PublicKey() (ed25519.PublicKey, error) { + if jwk.X == "" { + return nil, fmt.Errorf("missing x coordinate for Ed25519 key") + } + + xBytes, err := base64URLDecode(jwk.X) + if err != nil { + return nil, fmt.Errorf("decode x: %w", err) + } + + if len(xBytes) != ed25519.PublicKeySize { + return nil, fmt.Errorf("invalid Ed25519 public key length: expected %d, got %d", ed25519.PublicKeySize, len(xBytes)) + } + + return ed25519.PublicKey(xBytes), nil +} + // parseScopes splits a space-separated scope string func parseScopes(scopeString string) []string { if scopeString == "" {