365 lines
13 KiB
Python
365 lines
13 KiB
Python
"""
|
|
Custom middleware for HCFS API.
|
|
|
|
Provides request logging, error handling, security headers, JWT and API-key authentication, rate limiting, and gzip response compression.
|
|
"""
|
|
|
|
import time
|
|
import uuid
|
|
import json
|
|
from typing import Optional
|
|
from datetime import datetime, timedelta
|
|
|
|
from fastapi import Request, Response, HTTPException, status
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
from fastapi.middleware.base import BaseHTTPMiddleware
|
|
from starlette.responses import JSONResponse
|
|
import jwt
|
|
import structlog
|
|
|
|
# Module-level structlog logger used by the middleware classes below; events
# are logged with keyword context (e.g. ``request_id=...``).
logger = structlog.get_logger()
|
|
|
|
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Middleware for comprehensive request/response logging.

    Assigns every request a fresh UUID4 request ID, stores it on
    ``request.state.request_id`` and echoes it in the ``X-Request-ID``
    response header. Logs when the request starts and when it completes
    (or fails), including the elapsed time in milliseconds.
    """

    def __init__(self, app, log_body: bool = False):
        """
        Args:
            app: The ASGI application being wrapped.
            log_body: Stored on the instance; not currently read by
                ``dispatch``.
        """
        super().__init__(app)
        self.log_body = log_body

    async def dispatch(self, request: Request, call_next):
        request_id = str(uuid.uuid4())
        request.state.request_id = request_id

        # perf_counter is monotonic, so durations are immune to wall-clock
        # adjustments (NTP steps, DST changes) that time.time() would see.
        start = time.perf_counter()

        logger.info(
            "Request started",
            request_id=request_id,
            method=request.method,
            url=str(request.url),
            client_ip=request.client.host if request.client else None,
            user_agent=request.headers.get("user-agent"),
        )

        try:
            response = await call_next(request)
        except Exception as exc:
            # Without this, a request that raises leaves a "started" entry
            # with no matching end entry. Re-raise so outer layers still
            # handle it.
            logger.error(
                "Request failed",
                request_id=request_id,
                error_type=type(exc).__name__,
                duration_ms=round((time.perf_counter() - start) * 1000, 2),
            )
            raise

        logger.info(
            "Request completed",
            request_id=request_id,
            status_code=response.status_code,
            duration_ms=round((time.perf_counter() - start) * 1000, 2),
        )

        response.headers["X-Request-ID"] = request_id
        return response
|
|
|
|
|
|
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Middleware for consistent error handling and formatting.

    ``HTTPException`` is re-raised untouched. Any other exception escaping
    the downstream app is logged with its traceback and converted into a
    generic 500 JSON body, so internal error details never reach the client.
    """

    async def dispatch(self, request: Request, call_next):
        try:
            return await call_next(request)
        except HTTPException:
            # FastAPI HTTPExceptions are handled by FastAPI itself. A bare
            # ``raise`` re-raises without adding this frame to the traceback.
            raise
        except Exception as e:
            # request_id is set on request.state by RequestLoggingMiddleware
            # when that middleware wraps this one; otherwise "unknown".
            request_id = getattr(request.state, 'request_id', 'unknown')
            logger.error(
                "Unhandled exception",
                request_id=request_id,
                error=str(e),
                error_type=type(e).__name__,
                method=request.method,
                url=str(request.url),
                exc_info=True,
            )

            # Deliberately generic message: do not leak str(e) to clients.
            return JSONResponse(
                status_code=500,
                content={
                    "success": False,
                    "error": "Internal server error",
                    "error_details": [{"message": "An unexpected error occurred"}],
                    "timestamp": datetime.utcnow().isoformat(),
                    "request_id": request_id,
                    "api_version": "v1",
                },
            )
|
|
|
|
|
|
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Middleware to add security headers.

    Stamps a fixed set of hardening headers onto every response, overwriting
    any same-named header the downstream app may already have set.
    """

    # Header name -> value, applied in this order to every response.
    _SECURITY_HEADERS = {
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        "X-XSS-Protection": "1; mode=block",
        "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
        "Referrer-Policy": "strict-origin-when-cross-origin",
        "Permissions-Policy": "geolocation=(), microphone=(), camera=()",
    }

    async def dispatch(self, request: Request, call_next):
        outgoing = await call_next(request)
        for header_name, header_value in self._SECURITY_HEADERS.items():
            outgoing.headers[header_name] = header_value
        return outgoing
|
|
|
|
|
|
class JWTAuthenticationManager:
    """JWT-based authentication manager.

    Signs and verifies tokens with a shared secret using the ``jwt`` module
    (PyJWT API: ``jwt.encode`` / ``jwt.decode``).
    """

    def __init__(self, secret_key: str, algorithm: str = "HS256", token_expire_minutes: int = 30):
        """
        Args:
            secret_key: Key used both to sign and to verify tokens.
            algorithm: Signing algorithm; also the only one accepted on decode.
            token_expire_minutes: Lifetime of tokens issued by
                ``create_access_token``.
        """
        self.secret_key = secret_key
        self.algorithm = algorithm
        self.token_expire_minutes = token_expire_minutes

    def create_access_token(self, data: dict) -> str:
        """Create JWT access token.

        Copies *data* (the caller's dict is not modified) and adds an ``iat``
        claim (now, UTC) and an ``exp`` claim (now + ``token_expire_minutes``).
        """
        to_encode = data.copy()
        # Read the clock once so exp - iat is exactly the configured lifetime.
        issued_at = datetime.utcnow()
        to_encode.update({
            "exp": issued_at + timedelta(minutes=self.token_expire_minutes),
            "iat": issued_at,
        })

        token = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
        # PyJWT < 2.0 returns bytes; normalize so the ``str`` return type holds.
        if isinstance(token, bytes):
            token = token.decode("ascii")
        return token

    def verify_token(self, token: str) -> Optional[dict]:
        """Verify and decode JWT token.

        Returns:
            The decoded claims dict.

        Raises:
            HTTPException: 401 if the token has expired or is otherwise
                invalid (bad signature, malformed, disallowed algorithm, ...).
        """
        try:
            return jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except jwt.ExpiredSignatureError:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token has expired",
                headers={"WWW-Authenticate": "Bearer"},
            )
        except jwt.InvalidTokenError:
            # PyJWT's base class for decode failures. The previous
            # ``jwt.JWTError`` is python-jose's name and does not exist on
            # PyJWT, so every non-expiry failure raised AttributeError (500).
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Could not validate credentials",
                headers={"WWW-Authenticate": "Bearer"},
            )
|
|
|
|
|
|
class APIKeyManager:
    """API key-based authentication manager.

    Holds an in-memory mapping of API key -> metadata dict. Each entry has
    ``name``, ``scopes``, ``rate_limit``, ``created_at`` and ``last_used``
    fields.
    """

    def __init__(self, api_keys: Optional[dict] = None):
        """
        Args:
            api_keys: Optional key store to use instead of the built-in
                development key. The mapping is used as-is (not copied),
                so ``last_used`` updates are visible to the caller. Each
                entry must provide at least ``name``, ``scopes`` and
                ``rate_limit``, which AuthenticationMiddleware reads.
        """
        if api_keys is None:
            # Development default only. In production, load keys from a
            # database/secret store and pass them in instead.
            api_keys = {
                "dev-key-123": {
                    "name": "Development Key",
                    "scopes": ["read", "write"],
                    "rate_limit": 1000,
                    "created_at": datetime.utcnow(),
                    "last_used": None,
                }
            }
        self.api_keys = api_keys

    def validate_api_key(self, api_key: str) -> Optional[dict]:
        """Validate API key and return key info.

        Returns:
            The stored metadata dict, after stamping ``last_used`` with the
            current naive UTC time. Returns ``None`` if *api_key* is empty
            or unknown.
        """
        if not api_key:
            return None
        key_info = self.api_keys.get(api_key)
        if not key_info:
            return None
        key_info["last_used"] = datetime.utcnow()
        return key_info
|
|
|
|
|
|
class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Authentication middleware supporting multiple auth methods.

    Requests to public paths pass straight through. Every other request must
    carry one of two credentials, or a 401 JSON error is returned:

    - ``Authorization: Bearer <jwt>``, tried first and only when a JWT
      manager is configured;
    - ``X-API-Key``.

    On success, the resolved identity dict is stored on
    ``request.state.user``.
    """

    def __init__(
        self,
        app,
        jwt_manager: Optional[JWTAuthenticationManager] = None,
        api_key_manager: Optional[APIKeyManager] = None,
    ):
        super().__init__(app)
        self.jwt_manager = jwt_manager
        self.api_key_manager = api_key_manager or APIKeyManager()

        # Paths that don't require authentication. Each entry matches itself
        # exactly plus anything nested beneath it (e.g. "/docs/oauth2-redirect").
        self.public_paths = {
            "/health",
            "/metrics",
            "/docs",
            "/redoc",
            "/openapi.json",
            "/favicon.ico",
        }

    def _is_public_path(self, path: str) -> bool:
        """Return True if *path* is a public path or nested beneath one."""
        # A bare startswith() would also exempt unrelated routes such as
        # "/health-admin" or "/docsearch" from authentication.
        return any(
            path == public or path.startswith(public.rstrip("/") + "/")
            for public in self.public_paths
        )

    def _authenticate_jwt(self, auth_header: Optional[str]) -> Optional[dict]:
        """Return user info from a valid Bearer JWT, or None."""
        if not auth_header or self.jwt_manager is None:
            return None
        scheme, _, token = auth_header.partition(" ")
        token = token.strip()
        # Authentication scheme names are case-insensitive (RFC 7235).
        if scheme.lower() != "bearer" or not token:
            return None
        try:
            payload = self.jwt_manager.verify_token(token)
        except HTTPException:
            return None  # Invalid/expired token: fall back to other methods.
        if not payload:
            return None
        return {
            "user_id": payload.get("sub"),
            "username": payload.get("username"),
            "scopes": payload.get("scopes", []),
            "auth_method": "jwt",
        }

    def _authenticate_api_key(self, api_key: Optional[str]) -> Optional[dict]:
        """Return user info for a valid API key, or None."""
        if not api_key:
            return None
        key_info = self.api_key_manager.validate_api_key(api_key)
        if not key_info:
            return None
        return {
            "user_id": f"api_key_{api_key[:8]}",
            "username": key_info["name"],
            "scopes": key_info["scopes"],
            "auth_method": "api_key",
            "rate_limit": key_info["rate_limit"],
        }

    async def dispatch(self, request: Request, call_next):
        if self._is_public_path(request.url.path):
            return await call_next(request)

        user_info = self._authenticate_jwt(request.headers.get("Authorization"))
        if not user_info:
            user_info = self._authenticate_api_key(request.headers.get("X-API-Key"))

        if not user_info:
            return JSONResponse(
                status_code=401,
                content={
                    "success": False,
                    "error": "Authentication required",
                    "error_details": [{"message": "Valid API key or JWT token required"}],
                    "timestamp": datetime.utcnow().isoformat(),
                    "api_version": "v1",
                },
                # RFC 7235: a 401 must advertise an accepted auth scheme.
                headers={"WWW-Authenticate": "Bearer"},
            )

        request.state.user = user_info
        return await call_next(request)
|
|
|
|
|
|
class RateLimitingMiddleware(BaseHTTPMiddleware):
    """Custom rate limiting middleware.

    A fixed-window limiter that counts requests per client per wall-clock
    minute in an in-process dict. The client is identified by
    ``request.state.user["user_id"]`` when an upstream middleware has set
    ``request.state.user``, and by the client IP otherwise. An authenticated
    user may carry its own ``rate_limit``.
    """

    # Prune stale windows once the counter table grows past this many keys.
    _MAX_TRACKED_KEYS = 10000

    def __init__(self, app, default_rate_limit: int = 100):
        super().__init__(app)
        self.default_rate_limit = default_rate_limit
        # "<client_id>:<minute>" -> request count. In production, use Redis.
        self.request_counts = {}

    def _identify_client(self, request: Request):
        """Return ``(client_id, rate_limit)`` for *request*."""
        client_ip = request.client.host if request.client else "anonymous"
        user_info = getattr(request.state, 'user', None)
        if not user_info:
            return client_ip, self.default_rate_limit

        client_id = user_info.get("user_id")
        if client_id is None:
            # E.g. a JWT without "sub": don't lump all such users into one
            # shared "None" bucket.
            client_id = client_ip
        rate_limit = user_info.get("rate_limit")
        if rate_limit is None:
            rate_limit = self.default_rate_limit
        return client_id, rate_limit

    def _prune(self, current_minute: int) -> None:
        """Drop counters belonging to windows before *current_minute*.

        Only the current window is ever read, so older entries are dead weight.
        """
        # rpartition: the client id may itself contain ':' (e.g. IPv6 "::1").
        # The old split(':')[1] parsed the wrong field and raised ValueError.
        stale = [
            key for key in self.request_counts
            if int(key.rpartition(":")[2]) < current_minute
        ]
        for key in stale:
            del self.request_counts[key]

    async def dispatch(self, request: Request, call_next):
        client_id, rate_limit = self._identify_client(request)

        # Read the clock once so the window, reset time and Retry-After agree
        # even when a request straddles a minute boundary.
        now = int(time.time())
        current_minute = now // 60
        window_reset = (current_minute + 1) * 60
        key = f"{client_id}:{current_minute}"

        current_count = self.request_counts.get(key, 0) + 1
        self.request_counts[key] = current_count

        if len(self.request_counts) > self._MAX_TRACKED_KEYS:
            self._prune(current_minute)

        rate_headers = {
            "X-RateLimit-Limit": str(rate_limit),
            "X-RateLimit-Remaining": str(max(0, rate_limit - current_count)),
            "X-RateLimit-Reset": str(window_reset),
        }

        if current_count > rate_limit:
            retry_after = window_reset - now
            return JSONResponse(
                status_code=429,
                content={
                    "success": False,
                    "error": "Rate limit exceeded",
                    "error_details": [{"message": f"Rate limit of {rate_limit} requests per minute exceeded"}],
                    "timestamp": datetime.utcnow().isoformat(),
                    "retry_after": retry_after,
                },
                headers={"Retry-After": str(retry_after), **rate_headers},
            )

        response = await call_next(request)
        for header_name, header_value in rate_headers.items():
            response.headers[header_name] = header_value
        return response
|
|
|
|
|
|
class CompressionMiddleware(BaseHTTPMiddleware):
    """Custom compression middleware with configurable settings.

    Gzip-compresses JSON, JavaScript and ``text/*`` responses whose body is at
    least ``minimum_size`` bytes, provided the client advertises gzip in
    ``Accept-Encoding`` and the response is not already content-encoded.
    """

    # Content types eligible for compression (substring match on Content-Type).
    _COMPRESSIBLE_TYPES = ("application/json", "text/", "application/javascript")

    def __init__(self, app, minimum_size: int = 1000, compression_level: int = 6):
        """
        Args:
            app: The ASGI application being wrapped.
            minimum_size: Smallest body, in bytes, worth compressing.
            compression_level: Level passed to ``gzip.compress``.
        """
        super().__init__(app)
        self.minimum_size = minimum_size
        self.compression_level = compression_level

    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)

        # Content-coding tokens are case-insensitive.
        accept_encoding = request.headers.get("accept-encoding", "").lower()
        if "gzip" not in accept_encoding:
            return response

        content_type = response.headers.get("content-type", "")
        if not any(ct in content_type for ct in self._COMPRESSIBLE_TYPES):
            return response

        # An upstream layer already encoded the body; gzipping it again
        # would produce a payload the client cannot decode.
        if "content-encoding" in response.headers:
            return response

        # Collect chunks and join once; repeated ``bytes +=`` is quadratic.
        chunks = [chunk async for chunk in response.body_iterator]
        body = b"".join(chunks)

        # The body iterator is now consumed, so a new response is always
        # needed.
        if len(body) < self.minimum_size:
            return self._rebuild(response, body)

        import gzip

        compressed_body = gzip.compress(body, compresslevel=self.compression_level)
        return self._rebuild(
            response,
            compressed_body,
            extra_headers=[
                (b"content-encoding", b"gzip"),
                # Caches must key on Accept-Encoding now that the body varies.
                (b"vary", b"Accept-Encoding"),
            ],
        )

    @staticmethod
    def _rebuild(response, body: bytes, extra_headers=()):
        """Return a new Response with *body* plus *response*'s status and headers.

        Headers are copied from ``raw_headers`` so repeated headers (such as
        several ``Set-Cookie``) survive. ``dict(response.headers)`` keeps only
        one value per name.
        """
        new_response = Response(content=body, status_code=response.status_code)
        # Response() has already set content-length for the new body, so the
        # stale original value is skipped.
        new_response.raw_headers.extend(
            (name, value)
            for name, value in response.raw_headers
            if name.lower() != b"content-length"
        )
        new_response.raw_headers.extend(extra_headers)
        return new_response