Phase 2 build initial
This commit is contained in:
365
hcfs-python/hcfs/api/middleware.py
Normal file
365
hcfs-python/hcfs/api/middleware.py
Normal file
@@ -0,0 +1,365 @@
|
||||
"""
|
||||
Custom middleware for HCFS API.
|
||||
|
||||
Provides authentication, logging, error handling, and security features.
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
import json
|
||||
from typing import Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import Request, Response, HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from fastapi.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
import jwt
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware for comprehensive request/response logging."""
|
||||
|
||||
def __init__(self, app, log_body: bool = False):
|
||||
super().__init__(app)
|
||||
self.log_body = log_body
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Generate request ID
|
||||
request_id = str(uuid.uuid4())
|
||||
request.state.request_id = request_id
|
||||
|
||||
# Start timing
|
||||
start_time = time.time()
|
||||
|
||||
# Log request
|
||||
logger.info(
|
||||
"Request started",
|
||||
request_id=request_id,
|
||||
method=request.method,
|
||||
url=str(request.url),
|
||||
client_ip=request.client.host if request.client else None,
|
||||
user_agent=request.headers.get("user-agent"),
|
||||
)
|
||||
|
||||
# Call the next middleware/endpoint
|
||||
response = await call_next(request)
|
||||
|
||||
# Calculate duration
|
||||
duration = time.time() - start_time
|
||||
|
||||
# Log response
|
||||
logger.info(
|
||||
"Request completed",
|
||||
request_id=request_id,
|
||||
status_code=response.status_code,
|
||||
duration_ms=round(duration * 1000, 2),
|
||||
)
|
||||
|
||||
# Add request ID to response headers
|
||||
response.headers["X-Request-ID"] = request_id
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware for consistent error handling and formatting."""
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
try:
|
||||
response = await call_next(request)
|
||||
return response
|
||||
except HTTPException as e:
|
||||
# FastAPI HTTPExceptions are handled by FastAPI itself
|
||||
raise e
|
||||
except Exception as e:
|
||||
# Log unexpected errors
|
||||
request_id = getattr(request.state, 'request_id', 'unknown')
|
||||
logger.error(
|
||||
"Unhandled exception",
|
||||
request_id=request_id,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
method=request.method,
|
||||
url=str(request.url),
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# Return consistent error response
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"error": "Internal server error",
|
||||
"error_details": [{"message": "An unexpected error occurred"}],
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"request_id": request_id,
|
||||
"api_version": "v1"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to add security headers."""
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
|
||||
# Add security headers
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
response.headers["Permissions-Policy"] = "geolocation=(), microphone=(), camera=()"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class JWTAuthenticationManager:
|
||||
"""JWT-based authentication manager."""
|
||||
|
||||
def __init__(self, secret_key: str, algorithm: str = "HS256", token_expire_minutes: int = 30):
|
||||
self.secret_key = secret_key
|
||||
self.algorithm = algorithm
|
||||
self.token_expire_minutes = token_expire_minutes
|
||||
|
||||
def create_access_token(self, data: dict) -> str:
|
||||
"""Create JWT access token."""
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + timedelta(minutes=self.token_expire_minutes)
|
||||
to_encode.update({"exp": expire, "iat": datetime.utcnow()})
|
||||
|
||||
return jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
|
||||
|
||||
def verify_token(self, token: str) -> Optional[dict]:
|
||||
"""Verify and decode JWT token."""
|
||||
try:
|
||||
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
|
||||
return payload
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token has expired",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except jwt.JWTError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
class APIKeyManager:
|
||||
"""API key-based authentication manager."""
|
||||
|
||||
def __init__(self):
|
||||
# In production, store these in a database
|
||||
self.api_keys = {
|
||||
"dev-key-123": {
|
||||
"name": "Development Key",
|
||||
"scopes": ["read", "write"],
|
||||
"rate_limit": 1000,
|
||||
"created_at": datetime.utcnow(),
|
||||
"last_used": None
|
||||
}
|
||||
}
|
||||
|
||||
def validate_api_key(self, api_key: str) -> Optional[dict]:
|
||||
"""Validate API key and return key info."""
|
||||
key_info = self.api_keys.get(api_key)
|
||||
if key_info:
|
||||
# Update last used timestamp
|
||||
key_info["last_used"] = datetime.utcnow()
|
||||
return key_info
|
||||
return None
|
||||
|
||||
|
||||
class AuthenticationMiddleware(BaseHTTPMiddleware):
|
||||
"""Authentication middleware supporting multiple auth methods."""
|
||||
|
||||
def __init__(self, app, jwt_manager: JWTAuthenticationManager = None, api_key_manager: APIKeyManager = None):
|
||||
super().__init__(app)
|
||||
self.jwt_manager = jwt_manager
|
||||
self.api_key_manager = api_key_manager or APIKeyManager()
|
||||
|
||||
# Paths that don't require authentication
|
||||
self.public_paths = {
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/favicon.ico"
|
||||
}
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Skip authentication for public paths
|
||||
if any(request.url.path.startswith(path) for path in self.public_paths):
|
||||
return await call_next(request)
|
||||
|
||||
# Extract authentication credentials
|
||||
auth_header = request.headers.get("Authorization")
|
||||
api_key_header = request.headers.get("X-API-Key")
|
||||
|
||||
user_info = None
|
||||
|
||||
# Try JWT authentication first
|
||||
if auth_header and auth_header.startswith("Bearer ") and self.jwt_manager:
|
||||
token = auth_header[7:] # Remove "Bearer " prefix
|
||||
try:
|
||||
payload = self.jwt_manager.verify_token(token)
|
||||
user_info = {
|
||||
"user_id": payload.get("sub"),
|
||||
"username": payload.get("username"),
|
||||
"scopes": payload.get("scopes", []),
|
||||
"auth_method": "jwt"
|
||||
}
|
||||
except HTTPException:
|
||||
pass # Try other auth methods
|
||||
|
||||
# Try API key authentication
|
||||
if not user_info and api_key_header:
|
||||
key_info = self.api_key_manager.validate_api_key(api_key_header)
|
||||
if key_info:
|
||||
user_info = {
|
||||
"user_id": f"api_key_{api_key_header[:8]}",
|
||||
"username": key_info["name"],
|
||||
"scopes": key_info["scopes"],
|
||||
"auth_method": "api_key",
|
||||
"rate_limit": key_info["rate_limit"]
|
||||
}
|
||||
|
||||
# If no valid authentication found
|
||||
if not user_info:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"success": False,
|
||||
"error": "Authentication required",
|
||||
"error_details": [{"message": "Valid API key or JWT token required"}],
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"api_version": "v1"
|
||||
}
|
||||
)
|
||||
|
||||
# Add user info to request state
|
||||
request.state.user = user_info
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class RateLimitingMiddleware(BaseHTTPMiddleware):
|
||||
"""Custom rate limiting middleware."""
|
||||
|
||||
def __init__(self, app, default_rate_limit: int = 100):
|
||||
super().__init__(app)
|
||||
self.default_rate_limit = default_rate_limit
|
||||
self.request_counts = {} # In production, use Redis
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Get user identifier
|
||||
user_info = getattr(request.state, 'user', None)
|
||||
if user_info:
|
||||
user_id = user_info["user_id"]
|
||||
rate_limit = user_info.get("rate_limit", self.default_rate_limit)
|
||||
else:
|
||||
user_id = request.client.host if request.client else "anonymous"
|
||||
rate_limit = self.default_rate_limit
|
||||
|
||||
# Current minute window
|
||||
current_minute = int(time.time() // 60)
|
||||
key = f"{user_id}:{current_minute}"
|
||||
|
||||
# Increment request count
|
||||
current_count = self.request_counts.get(key, 0) + 1
|
||||
self.request_counts[key] = current_count
|
||||
|
||||
# Clean up old entries (simple cleanup)
|
||||
if len(self.request_counts) > 10000:
|
||||
old_keys = [k for k in self.request_counts.keys()
|
||||
if int(k.split(':')[1]) < current_minute - 5]
|
||||
for old_key in old_keys:
|
||||
del self.request_counts[old_key]
|
||||
|
||||
# Check rate limit
|
||||
if current_count > rate_limit:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"success": False,
|
||||
"error": "Rate limit exceeded",
|
||||
"error_details": [{"message": f"Rate limit of {rate_limit} requests per minute exceeded"}],
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"retry_after": 60 - (int(time.time()) % 60)
|
||||
},
|
||||
headers={
|
||||
"Retry-After": str(60 - (int(time.time()) % 60)),
|
||||
"X-RateLimit-Limit": str(rate_limit),
|
||||
"X-RateLimit-Remaining": str(max(0, rate_limit - current_count)),
|
||||
"X-RateLimit-Reset": str((current_minute + 1) * 60)
|
||||
}
|
||||
)
|
||||
|
||||
# Add rate limit headers to response
|
||||
response = await call_next(request)
|
||||
response.headers["X-RateLimit-Limit"] = str(rate_limit)
|
||||
response.headers["X-RateLimit-Remaining"] = str(max(0, rate_limit - current_count))
|
||||
response.headers["X-RateLimit-Reset"] = str((current_minute + 1) * 60)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class CompressionMiddleware(BaseHTTPMiddleware):
|
||||
"""Custom compression middleware with configurable settings."""
|
||||
|
||||
def __init__(self, app, minimum_size: int = 1000, compression_level: int = 6):
|
||||
super().__init__(app)
|
||||
self.minimum_size = minimum_size
|
||||
self.compression_level = compression_level
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
|
||||
# Check if client accepts gzip
|
||||
accept_encoding = request.headers.get("accept-encoding", "")
|
||||
if "gzip" not in accept_encoding:
|
||||
return response
|
||||
|
||||
# Check content type and size
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if not any(ct in content_type for ct in ["application/json", "text/", "application/javascript"]):
|
||||
return response
|
||||
|
||||
# Get response body
|
||||
body = b""
|
||||
async for chunk in response.body_iterator:
|
||||
body += chunk
|
||||
|
||||
# Compress if body is large enough
|
||||
if len(body) >= self.minimum_size:
|
||||
import gzip
|
||||
compressed_body = gzip.compress(body, compresslevel=self.compression_level)
|
||||
|
||||
# Create new response with compressed body
|
||||
from starlette.responses import Response
|
||||
return Response(
|
||||
content=compressed_body,
|
||||
status_code=response.status_code,
|
||||
headers={
|
||||
**dict(response.headers),
|
||||
"content-encoding": "gzip",
|
||||
"content-length": str(len(compressed_body))
|
||||
}
|
||||
)
|
||||
|
||||
# Return original response if not compressed
|
||||
from starlette.responses import Response
|
||||
return Response(
|
||||
content=body,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers)
|
||||
)
|
||||
Reference in New Issue
Block a user