Add Sequential Thinking compatibility server and JWKS support

This commit is contained in:
anthonyrawlins
2025-10-13 17:04:00 +11:00
parent c99def17d7
commit 2fd9a96950
6 changed files with 232 additions and 122 deletions

View File

@@ -54,7 +54,7 @@ func main() {
log.Info().
Str("port", cfg.Port).
Str("mcp_url", cfg.MCPLocalURL).
Str("version", "0.1.0-beta1").
Str("version", "0.1.0-beta2").
Msg("🚀 Starting Sequential Thinking Age Wrapper")
// Create MCP client

View File

@@ -1,28 +1,6 @@
# Sequential Thinking Age-Encrypted Wrapper
# Beat 1: Plaintext skeleton - encryption added in Beat 2
# Stage 1: Build Go wrapper
FROM golang:1.23-alpine AS go-builder
WORKDIR /build
# Install build dependencies
RUN apk add --no-cache git make
# Copy go mod files
COPY go.mod go.sum ./
RUN go mod download
# Copy source code
COPY . .
# Build the wrapper binary
RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo \
-ldflags '-w -s -extldflags "-static"' \
-o seqthink-wrapper \
./cmd/seqthink-wrapper
# Stage 2: Build Python MCP server
# Stage 1: Build Python MCP server
FROM python:3.11-slim AS python-builder
WORKDIR /mcp
@@ -35,16 +13,17 @@ RUN pip install --no-cache-dir \
uvicorn[standard]==0.27.0 \
pydantic==2.5.3
# Copy MCP server stub (to be replaced with real implementation)
COPY deploy/seqthink/mcp_stub.py /mcp/server.py
# Copy MCP compatibility server
COPY deploy/seqthink/mcp_server.py /mcp/server.py
# Stage 3: Runtime
# Stage 2: Runtime
FROM debian:bookworm-slim
# Install runtime dependencies
RUN apt-get update && \
apt-get install -y --no-install-recommends \
ca-certificates \
curl \
python3 \
python3-pip && \
apt-get clean && \
@@ -59,8 +38,8 @@ RUN pip3 install --no-cache-dir --break-system-packages \
# Create non-root user
RUN useradd -r -u 1000 -m -s /bin/bash seqthink
# Copy binaries
COPY --from=go-builder /build/seqthink-wrapper /usr/local/bin/
# Copy wrapper binary built on host (GOWORK=off GOOS=linux go build ...)
COPY deploy/seqthink/bin/seqthink-wrapper /usr/local/bin/seqthink-wrapper
COPY --from=python-builder /mcp/server.py /opt/mcp/server.py
# Copy entrypoint

View File

@@ -1,10 +1,10 @@
#!/bin/bash
set -e
echo "🚀 Starting Sequential Thinking Age Wrapper (Beat 1)"
echo "🚀 Starting Sequential Thinking Age Wrapper"
# Start MCP server on loopback
echo "📡 Starting MCP server on 127.0.0.1:8000..."
echo "📡 Starting Sequential Thinking MCP compatibility server on 127.0.0.1:8000..."
python3 /opt/mcp/server.py &
MCP_PID=$!

View File

@@ -0,0 +1,160 @@
#!/usr/bin/env python3
"""Sequential Thinking MCP compatibility server (HTTP wrapper)."""
from __future__ import annotations
import json
import logging
import os
from typing import Any, Dict, List, Optional
from fastapi import FastAPI, HTTPException
import uvicorn
from pydantic import BaseModel, Field, validator
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger("seqthink")
class ToolRequest(BaseModel):
tool: str
payload: Dict[str, Any]
@validator("tool")
def validate_tool(cls, value: str) -> str:
allowed = {
"sequentialthinking",
"mcp__sequential-thinking__sequentialthinking",
}
if value not in allowed:
raise ValueError(f"Unknown tool '{value}'")
return value
class ToolResponse(BaseModel):
result: Optional[Dict[str, Any]] = None
error: Optional[str] = None
class ThoughtData(BaseModel):
thought: str
thoughtNumber: int = Field(..., ge=1)
totalThoughts: int = Field(..., ge=1)
nextThoughtNeeded: bool
isRevision: Optional[bool] = False
revisesThought: Optional[int] = Field(default=None, ge=1)
branchFromThought: Optional[int] = Field(default=None, ge=1)
branchId: Optional[str] = None
needsMoreThoughts: Optional[bool] = None
@validator("totalThoughts")
def normalize_total(cls, value: int, values: Dict[str, Any]) -> int:
thought_number = values.get("thoughtNumber")
if thought_number is not None and value < thought_number:
return thought_number
return value
class SequentialThinkingEngine:
"""Replicates the upstream sequential thinking MCP behaviour."""
def __init__(self) -> None:
self._thought_history: List[ThoughtData] = []
self._branches: Dict[str, List[ThoughtData]] = {}
env = os.environ.get("DISABLE_THOUGHT_LOGGING", "")
self._disable_logging = env.lower() == "true"
def _record_branch(self, data: ThoughtData) -> None:
if data.branchFromThought and data.branchId:
self._branches.setdefault(data.branchId, []).append(data)
def _log_thought(self, data: ThoughtData) -> None:
if self._disable_logging:
return
header = []
if data.isRevision:
header.append("🔄 Revision")
if data.revisesThought:
header.append(f"(revising thought {data.revisesThought})")
elif data.branchFromThought:
header.append("🌿 Branch")
header.append(f"(from thought {data.branchFromThought})")
if data.branchId:
header.append(f"[ID: {data.branchId}]")
else:
header.append("💭 Thought")
header.append(f"{data.thoughtNumber}/{data.totalThoughts}")
header_line = " ".join(part for part in header if part)
border_width = max(len(header_line), len(data.thought)) + 4
border = "" * border_width
message = (
f"\n{border}\n"
f"{header_line.ljust(border_width - 2)}\n"
f"{border}\n"
f"{data.thought.ljust(border_width - 2)}\n"
f"{border}"
)
logger.error(message)
def process(self, payload: Dict[str, Any]) -> Dict[str, Any]:
try:
thought = ThoughtData(**payload)
except Exception as exc: # pylint: disable=broad-except
logger.exception("Invalid thought payload")
return {
"content": [
{
"type": "text",
"text": json.dumps({"error": str(exc)}, indent=2),
}
],
"isError": True,
}
self._thought_history.append(thought)
self._record_branch(thought)
self._log_thought(thought)
response_payload = {
"thoughtNumber": thought.thoughtNumber,
"totalThoughts": thought.totalThoughts,
"nextThoughtNeeded": thought.nextThoughtNeeded,
"branches": list(self._branches.keys()),
"thoughtHistoryLength": len(self._thought_history),
}
return {
"content": [
{
"type": "text",
"text": json.dumps(response_payload, indent=2),
}
]
}
engine = SequentialThinkingEngine()
app = FastAPI(title="Sequential Thinking MCP Compatibility Server")
@app.get("/health")
def health() -> Dict[str, str]:
return {"status": "ok"}
@app.post("/mcp/tool")
def call_tool(request: ToolRequest) -> ToolResponse:
try:
result = engine.process(request.payload)
if result.get("isError"):
return ToolResponse(error=result["content"][0]["text"])
return ToolResponse(result=result)
except Exception as exc: # pylint: disable=broad-except
raise HTTPException(status_code=400, detail=str(exc)) from exc
if __name__ == "__main__":
uvicorn.run(app, host="127.0.0.1", port=8000, log_level="info")

View File

@@ -1,70 +0,0 @@
#!/usr/bin/env python3
"""
Sequential Thinking MCP Server Stub (Beat 1)
This is a minimal implementation for testing the wrapper infrastructure.
In later beats, this will be replaced with the full Sequential Thinking MCP server.
"""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Dict, Any, Optional
import uvicorn
app = FastAPI(title="Sequential Thinking MCP Server Stub")
class ToolRequest(BaseModel):
tool: str
payload: Dict[str, Any]
class ToolResponse(BaseModel):
result: Optional[Any] = None
error: Optional[str] = None
@app.get("/health")
async def health():
"""Health check endpoint"""
return {"status": "ok"}
@app.post("/mcp/tool")
async def call_tool(request: ToolRequest) -> ToolResponse:
"""
Tool call endpoint - stub implementation
In Beat 1, this just echoes back the request to verify the wrapper works.
Later beats will implement the actual Sequential Thinking logic.
"""
if request.tool != "mcp__sequential-thinking__sequentialthinking":
return ToolResponse(
error=f"Unknown tool: {request.tool}"
)
# Stub response for Sequential Thinking tool
payload = request.payload
thought_number = payload.get("thoughtNumber", 1)
total_thoughts = payload.get("totalThoughts", 5)
thought = payload.get("thought", "")
next_thought_needed = payload.get("nextThoughtNeeded", True)
return ToolResponse(
result={
"thoughtNumber": thought_number,
"totalThoughts": total_thoughts,
"thought": thought,
"nextThoughtNeeded": next_thought_needed,
"message": "Beat 1 stub - Sequential Thinking not yet implemented"
}
)
if __name__ == "__main__":
uvicorn.run(
app,
host="127.0.0.1",
port=8000,
log_level="info"
)

View File

@@ -2,6 +2,7 @@ package policy
import (
"context"
"crypto/ed25519"
"crypto/rsa"
"encoding/base64"
"encoding/json"
@@ -38,6 +39,8 @@ type JWK struct {
Use string `json:"use"`
N string `json:"n"`
E string `json:"e"`
X string `json:"x"`
Crv string `json:"crv"`
}
// Validator validates JWT tokens
@@ -45,7 +48,7 @@ type Validator struct {
jwksURL string
requiredScope string
httpClient *http.Client
keys map[string]*rsa.PublicKey
keys map[string]interface{}
keysMutex sync.RWMutex
lastFetch time.Time
cacheDuration time.Duration
@@ -59,7 +62,7 @@ func NewValidator(jwksURL, requiredScope string) *Validator {
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
keys: make(map[string]*rsa.PublicKey),
keys: make(map[string]interface{}),
cacheDuration: 1 * time.Hour, // Cache JWKS for 1 hour
}
}
@@ -68,11 +71,6 @@ func NewValidator(jwksURL, requiredScope string) *Validator {
func (v *Validator) ValidateToken(tokenString string) (*Claims, error) {
// Parse token
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
// Verify signing algorithm
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
// Get key ID from header
kid, ok := token.Header["kid"].(string)
if !ok {
@@ -85,7 +83,22 @@ func (v *Validator) ValidateToken(tokenString string) (*Claims, error) {
return nil, fmt.Errorf("get public key: %w", err)
}
return publicKey, nil
switch token.Method.(type) {
case *jwt.SigningMethodRSA, *jwt.SigningMethodRSAPSS:
rsaKey, ok := publicKey.(*rsa.PublicKey)
if !ok {
return nil, fmt.Errorf("expected RSA public key for kid %s", kid)
}
return rsaKey, nil
case *jwt.SigningMethodEd25519:
edKey, ok := publicKey.(ed25519.PublicKey)
if !ok {
return nil, fmt.Errorf("expected Ed25519 public key for kid %s", kid)
}
return edKey, nil
default:
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
})
if err != nil {
@@ -140,7 +153,7 @@ func (v *Validator) hasRequiredScope(claims *Claims) bool {
}
// getPublicKey retrieves a public key by kid, fetching JWKS if needed
func (v *Validator) getPublicKey(kid string) (*rsa.PublicKey, error) {
func (v *Validator) getPublicKey(kid string) (interface{}, error) {
// Check if cache is expired
v.keysMutex.RLock()
cacheExpired := time.Since(v.lastFetch) > v.cacheDuration
@@ -201,20 +214,30 @@ func (v *Validator) fetchJWKS() error {
}
// Parse and cache all keys
newKeys := make(map[string]*rsa.PublicKey)
newKeys := make(map[string]interface{})
for _, jwk := range jwks.Keys {
if jwk.Kty != "RSA" {
log.Warn().Str("kid", jwk.Kid).Str("kty", jwk.Kty).Msg("Skipping non-RSA key")
continue
}
switch jwk.Kty {
case "RSA":
publicKey, err := jwk.toRSAPublicKey()
if err != nil {
log.Error().Err(err).Str("kid", jwk.Kid).Msg("Failed to parse JWK")
log.Error().Err(err).Str("kid", jwk.Kid).Msg("Failed to parse RSA JWK")
continue
}
newKeys[jwk.Kid] = publicKey
case "OKP":
if strings.EqualFold(jwk.Crv, "Ed25519") {
publicKey, err := jwk.toEd25519PublicKey()
if err != nil {
log.Error().Err(err).Str("kid", jwk.Kid).Msg("Failed to parse Ed25519 JWK")
continue
}
newKeys[jwk.Kid] = publicKey
} else {
log.Warn().Str("kid", jwk.Kid).Str("crv", jwk.Crv).Msg("Skipping unsupported OKP curve")
}
default:
log.Warn().Str("kid", jwk.Kid).Str("kty", jwk.Kty).Msg("Skipping unsupported key type")
}
}
if len(newKeys) == 0 {
@@ -261,6 +284,24 @@ func (jwk *JWK) toRSAPublicKey() (*rsa.PublicKey, error) {
return publicKey, nil
}
// toEd25519PublicKey converts a JWK to an Ed25519 public key
func (jwk *JWK) toEd25519PublicKey() (ed25519.PublicKey, error) {
if jwk.X == "" {
return nil, fmt.Errorf("missing x coordinate for Ed25519 key")
}
xBytes, err := base64URLDecode(jwk.X)
if err != nil {
return nil, fmt.Errorf("decode x: %w", err)
}
if len(xBytes) != ed25519.PublicKeySize {
return nil, fmt.Errorf("invalid Ed25519 public key length: expected %d, got %d", ed25519.PublicKeySize, len(xBytes))
}
return ed25519.PublicKey(xBytes), nil
}
// parseScopes splits a space-separated scope string
func parseScopes(scopeString string) []string {
if scopeString == "" {