Add Sequential Thinking compatibility server and JWKS support
This commit is contained in:
@@ -54,7 +54,7 @@ func main() {
|
||||
log.Info().
|
||||
Str("port", cfg.Port).
|
||||
Str("mcp_url", cfg.MCPLocalURL).
|
||||
Str("version", "0.1.0-beta1").
|
||||
Str("version", "0.1.0-beta2").
|
||||
Msg("🚀 Starting Sequential Thinking Age Wrapper")
|
||||
|
||||
// Create MCP client
|
||||
|
||||
@@ -1,28 +1,6 @@
|
||||
# Sequential Thinking Age-Encrypted Wrapper
|
||||
# Beat 1: Plaintext skeleton - encryption added in Beat 2
|
||||
|
||||
# Stage 1: Build Go wrapper
|
||||
FROM golang:1.23-alpine AS go-builder
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
# Install build dependencies
|
||||
RUN apk add --no-cache git make
|
||||
|
||||
# Copy go mod files
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
# Copy source code
|
||||
COPY . .
|
||||
|
||||
# Build the wrapper binary
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo \
|
||||
-ldflags '-w -s -extldflags "-static"' \
|
||||
-o seqthink-wrapper \
|
||||
./cmd/seqthink-wrapper
|
||||
|
||||
# Stage 2: Build Python MCP server
|
||||
# Stage 1: Build Python MCP server
|
||||
FROM python:3.11-slim AS python-builder
|
||||
|
||||
WORKDIR /mcp
|
||||
@@ -35,16 +13,17 @@ RUN pip install --no-cache-dir \
|
||||
uvicorn[standard]==0.27.0 \
|
||||
pydantic==2.5.3
|
||||
|
||||
# Copy MCP server stub (to be replaced with real implementation)
|
||||
COPY deploy/seqthink/mcp_stub.py /mcp/server.py
|
||||
# Copy MCP compatibility server
|
||||
COPY deploy/seqthink/mcp_server.py /mcp/server.py
|
||||
|
||||
# Stage 3: Runtime
|
||||
# Stage 2: Runtime
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
# Install runtime dependencies
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
curl \
|
||||
python3 \
|
||||
python3-pip && \
|
||||
apt-get clean && \
|
||||
@@ -59,8 +38,8 @@ RUN pip3 install --no-cache-dir --break-system-packages \
|
||||
# Create non-root user
|
||||
RUN useradd -r -u 1000 -m -s /bin/bash seqthink
|
||||
|
||||
# Copy binaries
|
||||
COPY --from=go-builder /build/seqthink-wrapper /usr/local/bin/
|
||||
# Copy wrapper binary built on host (GOWORK=off GOOS=linux go build ...)
|
||||
COPY deploy/seqthink/bin/seqthink-wrapper /usr/local/bin/seqthink-wrapper
|
||||
COPY --from=python-builder /mcp/server.py /opt/mcp/server.py
|
||||
|
||||
# Copy entrypoint
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
echo "🚀 Starting Sequential Thinking Age Wrapper (Beat 1)"
|
||||
echo "🚀 Starting Sequential Thinking Age Wrapper"
|
||||
|
||||
# Start MCP server on loopback
|
||||
echo "📡 Starting MCP server on 127.0.0.1:8000..."
|
||||
echo "📡 Starting Sequential Thinking MCP compatibility server on 127.0.0.1:8000..."
|
||||
python3 /opt/mcp/server.py &
|
||||
MCP_PID=$!
|
||||
|
||||
|
||||
160
deploy/seqthink/mcp_server.py
Normal file
160
deploy/seqthink/mcp_server.py
Normal file
@@ -0,0 +1,160 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Sequential Thinking MCP compatibility server (HTTP wrapper)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
import uvicorn
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
||||
logger = logging.getLogger("seqthink")
|
||||
|
||||
|
||||
class ToolRequest(BaseModel):
|
||||
tool: str
|
||||
payload: Dict[str, Any]
|
||||
|
||||
@validator("tool")
|
||||
def validate_tool(cls, value: str) -> str:
|
||||
allowed = {
|
||||
"sequentialthinking",
|
||||
"mcp__sequential-thinking__sequentialthinking",
|
||||
}
|
||||
if value not in allowed:
|
||||
raise ValueError(f"Unknown tool '{value}'")
|
||||
return value
|
||||
|
||||
|
||||
class ToolResponse(BaseModel):
|
||||
result: Optional[Dict[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class ThoughtData(BaseModel):
|
||||
thought: str
|
||||
thoughtNumber: int = Field(..., ge=1)
|
||||
totalThoughts: int = Field(..., ge=1)
|
||||
nextThoughtNeeded: bool
|
||||
isRevision: Optional[bool] = False
|
||||
revisesThought: Optional[int] = Field(default=None, ge=1)
|
||||
branchFromThought: Optional[int] = Field(default=None, ge=1)
|
||||
branchId: Optional[str] = None
|
||||
needsMoreThoughts: Optional[bool] = None
|
||||
|
||||
@validator("totalThoughts")
|
||||
def normalize_total(cls, value: int, values: Dict[str, Any]) -> int:
|
||||
thought_number = values.get("thoughtNumber")
|
||||
if thought_number is not None and value < thought_number:
|
||||
return thought_number
|
||||
return value
|
||||
|
||||
|
||||
class SequentialThinkingEngine:
|
||||
"""Replicates the upstream sequential thinking MCP behaviour."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._thought_history: List[ThoughtData] = []
|
||||
self._branches: Dict[str, List[ThoughtData]] = {}
|
||||
env = os.environ.get("DISABLE_THOUGHT_LOGGING", "")
|
||||
self._disable_logging = env.lower() == "true"
|
||||
|
||||
def _record_branch(self, data: ThoughtData) -> None:
|
||||
if data.branchFromThought and data.branchId:
|
||||
self._branches.setdefault(data.branchId, []).append(data)
|
||||
|
||||
def _log_thought(self, data: ThoughtData) -> None:
|
||||
if self._disable_logging:
|
||||
return
|
||||
|
||||
header = []
|
||||
if data.isRevision:
|
||||
header.append("🔄 Revision")
|
||||
if data.revisesThought:
|
||||
header.append(f"(revising thought {data.revisesThought})")
|
||||
elif data.branchFromThought:
|
||||
header.append("🌿 Branch")
|
||||
header.append(f"(from thought {data.branchFromThought})")
|
||||
if data.branchId:
|
||||
header.append(f"[ID: {data.branchId}]")
|
||||
else:
|
||||
header.append("💭 Thought")
|
||||
|
||||
header.append(f"{data.thoughtNumber}/{data.totalThoughts}")
|
||||
header_line = " ".join(part for part in header if part)
|
||||
|
||||
border_width = max(len(header_line), len(data.thought)) + 4
|
||||
border = "─" * border_width
|
||||
message = (
|
||||
f"\n┌{border}┐\n"
|
||||
f"│ {header_line.ljust(border_width - 2)} │\n"
|
||||
f"├{border}┤\n"
|
||||
f"│ {data.thought.ljust(border_width - 2)} │\n"
|
||||
f"└{border}┘"
|
||||
)
|
||||
logger.error(message)
|
||||
|
||||
def process(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
try:
|
||||
thought = ThoughtData(**payload)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
logger.exception("Invalid thought payload")
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps({"error": str(exc)}, indent=2),
|
||||
}
|
||||
],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
self._thought_history.append(thought)
|
||||
self._record_branch(thought)
|
||||
self._log_thought(thought)
|
||||
|
||||
response_payload = {
|
||||
"thoughtNumber": thought.thoughtNumber,
|
||||
"totalThoughts": thought.totalThoughts,
|
||||
"nextThoughtNeeded": thought.nextThoughtNeeded,
|
||||
"branches": list(self._branches.keys()),
|
||||
"thoughtHistoryLength": len(self._thought_history),
|
||||
}
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(response_payload, indent=2),
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
engine = SequentialThinkingEngine()
|
||||
app = FastAPI(title="Sequential Thinking MCP Compatibility Server")
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health() -> Dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.post("/mcp/tool")
|
||||
def call_tool(request: ToolRequest) -> ToolResponse:
|
||||
try:
|
||||
result = engine.process(request.payload)
|
||||
if result.get("isError"):
|
||||
return ToolResponse(error=result["content"][0]["text"])
|
||||
return ToolResponse(result=result)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="127.0.0.1", port=8000, log_level="info")
|
||||
@@ -1,70 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Sequential Thinking MCP Server Stub (Beat 1)
|
||||
|
||||
This is a minimal implementation for testing the wrapper infrastructure.
|
||||
In later beats, this will be replaced with the full Sequential Thinking MCP server.
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import Dict, Any, Optional
|
||||
import uvicorn
|
||||
|
||||
app = FastAPI(title="Sequential Thinking MCP Server Stub")
|
||||
|
||||
|
||||
class ToolRequest(BaseModel):
|
||||
tool: str
|
||||
payload: Dict[str, Any]
|
||||
|
||||
|
||||
class ToolResponse(BaseModel):
|
||||
result: Optional[Any] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
"""Health check endpoint"""
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.post("/mcp/tool")
|
||||
async def call_tool(request: ToolRequest) -> ToolResponse:
|
||||
"""
|
||||
Tool call endpoint - stub implementation
|
||||
|
||||
In Beat 1, this just echoes back the request to verify the wrapper works.
|
||||
Later beats will implement the actual Sequential Thinking logic.
|
||||
"""
|
||||
if request.tool != "mcp__sequential-thinking__sequentialthinking":
|
||||
return ToolResponse(
|
||||
error=f"Unknown tool: {request.tool}"
|
||||
)
|
||||
|
||||
# Stub response for Sequential Thinking tool
|
||||
payload = request.payload
|
||||
thought_number = payload.get("thoughtNumber", 1)
|
||||
total_thoughts = payload.get("totalThoughts", 5)
|
||||
thought = payload.get("thought", "")
|
||||
next_thought_needed = payload.get("nextThoughtNeeded", True)
|
||||
|
||||
return ToolResponse(
|
||||
result={
|
||||
"thoughtNumber": thought_number,
|
||||
"totalThoughts": total_thoughts,
|
||||
"thought": thought,
|
||||
"nextThoughtNeeded": next_thought_needed,
|
||||
"message": "Beat 1 stub - Sequential Thinking not yet implemented"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
app,
|
||||
host="127.0.0.1",
|
||||
port=8000,
|
||||
log_level="info"
|
||||
)
|
||||
@@ -2,6 +2,7 @@ package policy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
@@ -38,6 +39,8 @@ type JWK struct {
|
||||
Use string `json:"use"`
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
X string `json:"x"`
|
||||
Crv string `json:"crv"`
|
||||
}
|
||||
|
||||
// Validator validates JWT tokens
|
||||
@@ -45,7 +48,7 @@ type Validator struct {
|
||||
jwksURL string
|
||||
requiredScope string
|
||||
httpClient *http.Client
|
||||
keys map[string]*rsa.PublicKey
|
||||
keys map[string]interface{}
|
||||
keysMutex sync.RWMutex
|
||||
lastFetch time.Time
|
||||
cacheDuration time.Duration
|
||||
@@ -59,7 +62,7 @@ func NewValidator(jwksURL, requiredScope string) *Validator {
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
keys: make(map[string]*rsa.PublicKey),
|
||||
keys: make(map[string]interface{}),
|
||||
cacheDuration: 1 * time.Hour, // Cache JWKS for 1 hour
|
||||
}
|
||||
}
|
||||
@@ -68,11 +71,6 @@ func NewValidator(jwksURL, requiredScope string) *Validator {
|
||||
func (v *Validator) ValidateToken(tokenString string) (*Claims, error) {
|
||||
// Parse token
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
// Verify signing algorithm
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
|
||||
// Get key ID from header
|
||||
kid, ok := token.Header["kid"].(string)
|
||||
if !ok {
|
||||
@@ -85,7 +83,22 @@ func (v *Validator) ValidateToken(tokenString string) (*Claims, error) {
|
||||
return nil, fmt.Errorf("get public key: %w", err)
|
||||
}
|
||||
|
||||
return publicKey, nil
|
||||
switch token.Method.(type) {
|
||||
case *jwt.SigningMethodRSA, *jwt.SigningMethodRSAPSS:
|
||||
rsaKey, ok := publicKey.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected RSA public key for kid %s", kid)
|
||||
}
|
||||
return rsaKey, nil
|
||||
case *jwt.SigningMethodEd25519:
|
||||
edKey, ok := publicKey.(ed25519.PublicKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected Ed25519 public key for kid %s", kid)
|
||||
}
|
||||
return edKey, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
@@ -140,7 +153,7 @@ func (v *Validator) hasRequiredScope(claims *Claims) bool {
|
||||
}
|
||||
|
||||
// getPublicKey retrieves a public key by kid, fetching JWKS if needed
|
||||
func (v *Validator) getPublicKey(kid string) (*rsa.PublicKey, error) {
|
||||
func (v *Validator) getPublicKey(kid string) (interface{}, error) {
|
||||
// Check if cache is expired
|
||||
v.keysMutex.RLock()
|
||||
cacheExpired := time.Since(v.lastFetch) > v.cacheDuration
|
||||
@@ -201,20 +214,30 @@ func (v *Validator) fetchJWKS() error {
|
||||
}
|
||||
|
||||
// Parse and cache all keys
|
||||
newKeys := make(map[string]*rsa.PublicKey)
|
||||
newKeys := make(map[string]interface{})
|
||||
for _, jwk := range jwks.Keys {
|
||||
if jwk.Kty != "RSA" {
|
||||
log.Warn().Str("kid", jwk.Kid).Str("kty", jwk.Kty).Msg("Skipping non-RSA key")
|
||||
continue
|
||||
}
|
||||
|
||||
switch jwk.Kty {
|
||||
case "RSA":
|
||||
publicKey, err := jwk.toRSAPublicKey()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("kid", jwk.Kid).Msg("Failed to parse JWK")
|
||||
log.Error().Err(err).Str("kid", jwk.Kid).Msg("Failed to parse RSA JWK")
|
||||
continue
|
||||
}
|
||||
|
||||
newKeys[jwk.Kid] = publicKey
|
||||
case "OKP":
|
||||
if strings.EqualFold(jwk.Crv, "Ed25519") {
|
||||
publicKey, err := jwk.toEd25519PublicKey()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("kid", jwk.Kid).Msg("Failed to parse Ed25519 JWK")
|
||||
continue
|
||||
}
|
||||
newKeys[jwk.Kid] = publicKey
|
||||
} else {
|
||||
log.Warn().Str("kid", jwk.Kid).Str("crv", jwk.Crv).Msg("Skipping unsupported OKP curve")
|
||||
}
|
||||
default:
|
||||
log.Warn().Str("kid", jwk.Kid).Str("kty", jwk.Kty).Msg("Skipping unsupported key type")
|
||||
}
|
||||
}
|
||||
|
||||
if len(newKeys) == 0 {
|
||||
@@ -261,6 +284,24 @@ func (jwk *JWK) toRSAPublicKey() (*rsa.PublicKey, error) {
|
||||
return publicKey, nil
|
||||
}
|
||||
|
||||
// toEd25519PublicKey converts a JWK to an Ed25519 public key
|
||||
func (jwk *JWK) toEd25519PublicKey() (ed25519.PublicKey, error) {
|
||||
if jwk.X == "" {
|
||||
return nil, fmt.Errorf("missing x coordinate for Ed25519 key")
|
||||
}
|
||||
|
||||
xBytes, err := base64URLDecode(jwk.X)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode x: %w", err)
|
||||
}
|
||||
|
||||
if len(xBytes) != ed25519.PublicKeySize {
|
||||
return nil, fmt.Errorf("invalid Ed25519 public key length: expected %d, got %d", ed25519.PublicKeySize, len(xBytes))
|
||||
}
|
||||
|
||||
return ed25519.PublicKey(xBytes), nil
|
||||
}
|
||||
|
||||
// parseScopes splits a space-separated scope string
|
||||
func parseScopes(scopeString string) []string {
|
||||
if scopeString == "" {
|
||||
|
||||
Reference in New Issue
Block a user