Files
HCFS/hcfs-python/hcfs/api/middleware.py
2025-07-30 09:34:16 +10:00

365 lines
13 KiB
Python

"""
Custom middleware for HCFS API.
Provides request logging, error handling, security headers, JWT and API-key authentication, rate limiting, and gzip response compression.
"""
import time
import uuid
import json
from typing import Optional
from datetime import datetime, timedelta
from fastapi import Request, Response, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
import jwt
import structlog
# Module-level structlog logger; RequestLoggingMiddleware and
# ErrorHandlingMiddleware emit their log events through it.
logger = structlog.get_logger()
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Middleware for comprehensive request/response logging.

    Assigns every request a fresh UUID4 request ID, stores it on
    ``request.state.request_id`` and echoes it back in the ``X-Request-ID``
    response header. Logs a "Request started" event before delegating and a
    "Request completed" event (status code + elapsed milliseconds) afterwards.
    If the downstream app raises, a "Request failed" event with the elapsed
    time is logged and the exception is re-raised unchanged.
    """

    def __init__(self, app, log_body: bool = False):
        """Wrap *app*.

        Args:
            app: The ASGI application to wrap.
            log_body: Stored as ``self.log_body``. NOTE(review): this class
                never reads it, so request/response bodies are not logged.
        """
        super().__init__(app)
        self.log_body = log_body

    async def dispatch(self, request: Request, call_next):
        """Log the request, delegate to *call_next*, then log the outcome."""
        request_id = str(uuid.uuid4())
        request.state.request_id = request_id

        # perf_counter is monotonic, so measured durations are not skewed by
        # wall-clock adjustments the way time.time() differences can be.
        start_time = time.perf_counter()

        logger.info(
            "Request started",
            request_id=request_id,
            method=request.method,
            url=str(request.url),
            client_ip=request.client.host if request.client else None,
            user_agent=request.headers.get("user-agent"),
        )

        try:
            response = await call_next(request)
        except Exception as exc:
            # Record the outcome so every "Request started" entry is paired
            # with a terminal event; the error itself propagates unchanged.
            logger.error(
                "Request failed",
                request_id=request_id,
                error_type=type(exc).__name__,
                duration_ms=round((time.perf_counter() - start_time) * 1000, 2),
            )
            raise

        logger.info(
            "Request completed",
            request_id=request_id,
            status_code=response.status_code,
            duration_ms=round((time.perf_counter() - start_time) * 1000, 2),
        )

        response.headers["X-Request-ID"] = request_id
        return response
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Middleware for consistent error handling and formatting.

    ``HTTPException`` is re-raised untouched. Any other exception escaping
    the downstream app is logged with its traceback and converted into a
    generic JSON 500 body, so internal error details are not sent to the
    client.
    """

    async def dispatch(self, request: Request, call_next):
        """Delegate to *call_next*, turning unexpected exceptions into a 500."""
        try:
            return await call_next(request)
        except HTTPException:
            # Leave HTTPExceptions to FastAPI's own handling; a bare ``raise``
            # re-raises with the original traceback intact.
            raise
        except Exception as e:
            # request_id is set by RequestLoggingMiddleware when that
            # middleware wraps this one; otherwise use a placeholder.
            request_id = getattr(request.state, "request_id", "unknown")
            logger.error(
                "Unhandled exception",
                request_id=request_id,
                error=str(e),
                error_type=type(e).__name__,
                method=request.method,
                url=str(request.url),
                exc_info=True,
            )
            # Consistent error envelope; deliberately generic message.
            return JSONResponse(
                status_code=500,
                content={
                    "success": False,
                    "error": "Internal server error",
                    "error_details": [{"message": "An unexpected error occurred"}],
                    "timestamp": datetime.utcnow().isoformat(),
                    "request_id": request_id,
                    "api_version": "v1",
                },
            )
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Stamp a fixed set of security-related HTTP headers onto every response."""

    async def dispatch(self, request: Request, call_next):
        """Run the downstream app, then set each hardening header on its response."""
        hardening_headers = (
            ("X-Content-Type-Options", "nosniff"),
            ("X-Frame-Options", "DENY"),
            ("X-XSS-Protection", "1; mode=block"),
            ("Strict-Transport-Security", "max-age=31536000; includeSubDomains"),
            ("Referrer-Policy", "strict-origin-when-cross-origin"),
            ("Permissions-Policy", "geolocation=(), microphone=(), camera=()"),
        )
        outgoing = await call_next(request)
        for header_name, header_value in hardening_headers:
            outgoing.headers[header_name] = header_value
        return outgoing
class JWTAuthenticationManager:
    """JWT-based authentication manager.

    Issues and verifies tokens signed with ``secret_key`` via ``jwt.encode``
    and ``jwt.decode``.
    """

    def __init__(self, secret_key: str, algorithm: str = "HS256", token_expire_minutes: int = 30):
        """Configure signing.

        Args:
            secret_key: Key used both to sign and to verify tokens.
            algorithm: Signing algorithm name passed to the jwt library.
            token_expire_minutes: Lifetime of tokens from create_access_token.
        """
        self.secret_key = secret_key
        self.algorithm = algorithm
        self.token_expire_minutes = token_expire_minutes

    def create_access_token(self, data: dict) -> str:
        """Create a signed JWT access token.

        *data* is copied (the caller's dict is not modified), then ``iat``
        (now, UTC) and ``exp`` (now + token_expire_minutes) claims are added.

        Returns:
            The encoded token.
        """
        now = datetime.utcnow()
        to_encode = data.copy()
        # One timestamp for both claims so exp - iat equals the lifetime.
        to_encode.update({
            "exp": now + timedelta(minutes=self.token_expire_minutes),
            "iat": now,
        })
        return jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)

    def verify_token(self, token: str) -> Optional[dict]:
        """Verify and decode a JWT.

        Returns:
            The decoded payload.

        Raises:
            HTTPException: 401 "Token has expired" when the token's ``exp``
                has passed; 401 "Could not validate credentials" for any
                other invalid token (bad signature, malformed, ...).
        """
        try:
            return jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except jwt.ExpiredSignatureError as exc:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token has expired",
                headers={"WWW-Authenticate": "Bearer"},
            ) from exc
        except jwt.InvalidTokenError as exc:
            # InvalidTokenError is PyJWT's base class for token validation
            # failures; ExpiredSignatureError subclasses it, hence the order.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Could not validate credentials",
                headers={"WWW-Authenticate": "Bearer"},
            ) from exc
class APIKeyManager:
    """API key-based authentication manager backed by an in-memory dict.

    ``self.api_keys`` maps each API key string to an info dict. Entries are
    expected to carry ``name``, ``scopes`` and ``rate_limit`` (these are
    read by AuthenticationMiddleware); ``last_used`` is updated on every
    successful validation.
    """

    def __init__(self, api_keys: Optional[dict] = None):
        """Create the key store.

        Args:
            api_keys: Optional mapping of API key -> key info dict. When
                omitted, a single hard-coded development key is registered.
                Pass an explicit mapping (e.g. loaded from a database) in
                production so the development key is never accepted; an
                empty mapping disables API-key auth entirely.
        """
        if api_keys is not None:
            # Shallow copy: later additions/removals by the caller do not
            # leak in, but the per-key info dicts are shared.
            self.api_keys = dict(api_keys)
        else:
            # WARNING: well-known development credential with read/write
            # scopes; do not rely on this default outside development.
            self.api_keys = {
                "dev-key-123": {
                    "name": "Development Key",
                    "scopes": ["read", "write"],
                    "rate_limit": 1000,
                    "created_at": datetime.utcnow(),
                    "last_used": None,
                }
            }

    def validate_api_key(self, api_key: str) -> Optional[dict]:
        """Validate *api_key* and return its info dict.

        On success the entry's ``last_used`` is set to the current UTC time
        and the (live, mutable) info dict is returned.

        Returns:
            The key's info dict, or None for an empty/unknown key.
        """
        if not api_key:
            # Never treat a missing/empty header value as a credential.
            return None
        key_info = self.api_keys.get(api_key)
        if not key_info:
            return None
        key_info["last_used"] = datetime.utcnow()
        return key_info
class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Authentication middleware supporting JWT bearer tokens and API keys.

    Requests to a public path (or a sub-path of one) pass straight through.
    Otherwise a JWT from ``Authorization: Bearer <token>`` is tried first
    (only when a ``jwt_manager`` was supplied), then the ``X-API-Key``
    header. On success the caller's identity dict is stored on
    ``request.state.user``; otherwise a JSON 401 response is returned.
    """

    def __init__(
        self,
        app,
        jwt_manager: Optional[JWTAuthenticationManager] = None,
        api_key_manager: Optional[APIKeyManager] = None,
    ):
        """Wrap *app*.

        Args:
            app: The ASGI application to wrap.
            jwt_manager: Verifies bearer tokens; JWT auth is skipped if None.
            api_key_manager: Validates API keys; defaults to a new
                APIKeyManager().
        """
        super().__init__(app)
        self.jwt_manager = jwt_manager
        self.api_key_manager = api_key_manager or APIKeyManager()
        # Paths that don't require authentication (matched per path segment).
        self.public_paths = {
            "/health",
            "/metrics",
            "/docs",
            "/redoc",
            "/openapi.json",
            "/favicon.ico",
        }

    def _is_public_path(self, path: str) -> bool:
        """Return True if *path* equals a public path or lies beneath one.

        Matching is by whole segments: ``/docs`` and ``/docs/oauth2-redirect``
        are public, but ``/docsadmin`` and ``/healthz`` are not, so a plain
        string prefix cannot be used to bypass authentication.
        """
        return any(
            path == public or path.startswith(public + "/")
            for public in self.public_paths
        )

    def _authenticate_jwt(self, auth_header: Optional[str]) -> Optional[dict]:
        """Return user info for a valid bearer token, or None."""
        if not auth_header or not self.jwt_manager:
            return None
        scheme, _, token = auth_header.partition(" ")
        token = token.strip()
        # Authentication scheme names are case-insensitive (RFC 7235 §2.1).
        if scheme.lower() != "bearer" or not token:
            return None
        try:
            payload = self.jwt_manager.verify_token(token)
        except HTTPException:
            # Invalid or expired token: fall back to the other auth methods.
            return None
        if not payload:
            return None
        return {
            "user_id": payload.get("sub"),
            "username": payload.get("username"),
            "scopes": payload.get("scopes", []),
            "auth_method": "jwt",
        }

    def _authenticate_api_key(self, api_key: Optional[str]) -> Optional[dict]:
        """Return user info for a valid API key, or None."""
        if not api_key:
            return None
        key_info = self.api_key_manager.validate_api_key(api_key)
        if not key_info:
            return None
        return {
            # NOTE(review): embeds the first 8 characters of the secret key
            # in the user id (which may end up in logs); consider a digest.
            "user_id": f"api_key_{api_key[:8]}",
            "username": key_info["name"],
            "scopes": key_info["scopes"],
            "auth_method": "api_key",
            "rate_limit": key_info["rate_limit"],
        }

    async def dispatch(self, request: Request, call_next):
        """Authenticate the request or short-circuit with a 401."""
        if self._is_public_path(request.url.path):
            return await call_next(request)

        # JWT first; the API key is only consulted if JWT auth did not succeed.
        user_info = self._authenticate_jwt(
            request.headers.get("Authorization")
        ) or self._authenticate_api_key(request.headers.get("X-API-Key"))

        if not user_info:
            return JSONResponse(
                status_code=401,
                content={
                    "success": False,
                    "error": "Authentication required",
                    "error_details": [{"message": "Valid API key or JWT token required"}],
                    "timestamp": datetime.utcnow().isoformat(),
                    "api_version": "v1",
                },
                # A 401 must tell the client which scheme to use (RFC 7235 §3.1).
                headers={"WWW-Authenticate": "Bearer"},
            )

        request.state.user = user_info
        return await call_next(request)
class RateLimitingMiddleware(BaseHTTPMiddleware):
    """Custom fixed-window (per clock minute) rate limiting middleware.

    Requests are counted per caller per minute in an in-process dict. The
    caller is ``request.state.user["user_id"]`` when an earlier middleware
    set ``request.state.user``; otherwise the client host (or "anonymous").
    The per-minute limit is that user's ``rate_limit`` entry when present,
    else ``default_rate_limit``. Over-limit requests get a JSON 429 response;
    all others get ``X-RateLimit-*`` headers added.
    """

    # Stale-window cleanup only runs once the table grows beyond this size.
    _MAX_TRACKED_KEYS = 10000
    # Windows more than this many minutes old are dropped during cleanup.
    _STALE_AFTER_MINUTES = 5

    def __init__(self, app, default_rate_limit: int = 100):
        """Wrap *app*.

        Args:
            app: The ASGI application to wrap.
            default_rate_limit: Requests per minute for callers whose user
                info carries no ``rate_limit``.
        """
        super().__init__(app)
        self.default_rate_limit = default_rate_limit
        # "<user_id>:<minute>" -> request count. In production, use Redis.
        self.request_counts = {}
        # Minute in which cleanup last ran, so it runs at most once a minute.
        self._last_prune_minute = None

    def _prune_stale_windows(self, current_minute: int) -> None:
        """Delete counters for windows older than _STALE_AFTER_MINUTES."""
        cutoff = current_minute - self._STALE_AFTER_MINUTES
        # Split from the right: user ids may themselves contain ':' (e.g. an
        # IPv6 client host such as "::1"), but the minute is always last.
        stale_keys = [
            key for key in self.request_counts
            if int(key.rsplit(":", 1)[1]) < cutoff
        ]
        for key in stale_keys:
            del self.request_counts[key]

    async def dispatch(self, request: Request, call_next):
        """Count the request against its caller's window and enforce the limit."""
        user_info = getattr(request.state, "user", None)
        if user_info:
            user_id = user_info["user_id"]
            rate_limit = user_info.get("rate_limit", self.default_rate_limit)
        else:
            user_id = request.client.host if request.client else "anonymous"
            rate_limit = self.default_rate_limit

        # Read the clock once so the window, reset time and Retry-After agree.
        now = time.time()
        current_minute = int(now // 60)
        key = f"{user_id}:{current_minute}"

        current_count = self.request_counts.get(key, 0) + 1
        self.request_counts[key] = current_count

        # Without the per-minute gate, a table full of live keys would be
        # rescanned on every single request.
        if (
            len(self.request_counts) > self._MAX_TRACKED_KEYS
            and self._last_prune_minute != current_minute
        ):
            self._prune_stale_windows(current_minute)
            self._last_prune_minute = current_minute

        remaining = max(0, rate_limit - current_count)
        reset_at = (current_minute + 1) * 60

        if current_count > rate_limit:
            retry_after = 60 - (int(now) % 60)
            return JSONResponse(
                status_code=429,
                content={
                    "success": False,
                    "error": "Rate limit exceeded",
                    "error_details": [{"message": f"Rate limit of {rate_limit} requests per minute exceeded"}],
                    "timestamp": datetime.utcnow().isoformat(),
                    "retry_after": retry_after,
                },
                headers={
                    "Retry-After": str(retry_after),
                    "X-RateLimit-Limit": str(rate_limit),
                    "X-RateLimit-Remaining": str(remaining),
                    "X-RateLimit-Reset": str(reset_at),
                },
            )

        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(rate_limit)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        response.headers["X-RateLimit-Reset"] = str(reset_at)
        return response
class CompressionMiddleware(BaseHTTPMiddleware):
    """Custom gzip compression middleware with configurable settings.

    A response is compressed when the request's Accept-Encoding mentions
    gzip, the response is not already content-encoded, its content type is
    JSON, JavaScript or text (but not ``text/event-stream``, which must keep
    streaming), and its buffered body is at least ``minimum_size`` bytes.
    """

    # Content-type fragments that mark a response as compressible.
    _COMPRESSIBLE_TYPES = ("application/json", "text/", "application/javascript")

    def __init__(self, app, minimum_size: int = 1000, compression_level: int = 6):
        """Wrap *app*.

        Args:
            app: The ASGI application to wrap.
            minimum_size: Smallest body (in bytes) worth compressing.
            compression_level: gzip level passed to ``gzip.compress``.
        """
        super().__init__(app)
        self.minimum_size = minimum_size
        self.compression_level = compression_level

    @staticmethod
    def _replay(original, body: bytes, raw_headers) -> Response:
        """Build a Response with *body*, *original*'s status and *raw_headers*.

        Raw header pairs are used instead of ``dict(response.headers)`` so
        repeated headers (e.g. several ``Set-Cookie`` lines) are preserved.
        """
        replacement = Response(content=body, status_code=original.status_code)
        replacement.raw_headers = raw_headers
        return replacement

    async def dispatch(self, request: Request, call_next):
        """Gzip the downstream response when eligible."""
        response = await call_next(request)

        if "gzip" not in request.headers.get("accept-encoding", ""):
            return response
        # Never double-encode a body something downstream already encoded.
        if "content-encoding" in response.headers:
            return response
        content_type = response.headers.get("content-type", "")
        if not any(ct in content_type for ct in self._COMPRESSIBLE_TYPES):
            return response
        # Server-sent events are an open-ended stream; buffering blocks them.
        if "text/event-stream" in content_type:
            return response

        # Buffer the whole body; str chunks are encoded rather than crashing.
        chunks = []
        async for chunk in response.body_iterator:
            chunks.append(chunk.encode("utf-8") if isinstance(chunk, str) else chunk)
        body = b"".join(chunks)

        if len(body) < self.minimum_size:
            # Not worth compressing: replay the buffered body with the
            # original headers (including any Content-Length) untouched.
            return self._replay(response, body, list(response.raw_headers))

        import gzip

        compressed_body = gzip.compress(body, compresslevel=self.compression_level)

        headers = [
            (name, value)
            for name, value in response.raw_headers
            if name.lower() not in (b"content-length", b"vary")
        ]
        # Caches must key on Accept-Encoding, or they may serve gzip to
        # clients that cannot decode it.
        vary_values = [
            value for name, value in response.raw_headers if name.lower() == b"vary"
        ]
        vary_values.append(b"Accept-Encoding")
        headers.append((b"vary", b", ".join(vary_values)))
        headers.append((b"content-encoding", b"gzip"))
        headers.append((b"content-length", str(len(compressed_body)).encode("latin-1")))
        return self._replay(response, compressed_body, headers)