"""
Custom middleware for HCFS API.

Provides authentication, logging, error handling, and security features.
"""

import gzip
import json
import time
import uuid
from datetime import datetime, timedelta
from typing import Optional

import jwt
import structlog
from fastapi import HTTPException, Request, Response, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
# BaseHTTPMiddleware is provided by Starlette; FastAPI does not ship a
# ``fastapi.middleware.base`` module.
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

logger = structlog.get_logger()


class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Middleware for comprehensive request/response logging.

    Assigns a UUID4 request ID to every request, stores it on
    ``request.state.request_id``, logs the start and completion of the
    request, and echoes the ID back in the ``X-Request-ID`` response header.
    """

    def __init__(self, app, log_body: bool = False):
        """Store the wrapped app and the ``log_body`` flag.

        Args:
            app: The ASGI application being wrapped.
            log_body: Stored on the instance; ``dispatch`` does not consult it.
        """
        super().__init__(app)
        self.log_body = log_body

    async def dispatch(self, request: Request, call_next):
        """Log the request, delegate downstream, then log the response."""
        # Generate request ID and expose it to downstream handlers
        request_id = str(uuid.uuid4())
        request.state.request_id = request_id

        # A monotonic clock keeps the measured duration immune to wall-clock
        # adjustments made while the request is in flight.
        started = time.perf_counter()

        logger.info(
            "Request started",
            request_id=request_id,
            method=request.method,
            url=str(request.url),
            client_ip=request.client.host if request.client else None,
            user_agent=request.headers.get("user-agent"),
        )

        # Call the next middleware/endpoint
        response = await call_next(request)

        duration = time.perf_counter() - started

        logger.info(
            "Request completed",
            request_id=request_id,
            status_code=response.status_code,
            duration_ms=round(duration * 1000, 2),
        )

        # Add request ID to response headers
        response.headers["X-Request-ID"] = request_id
        return response


class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Middleware for consistent error handling and formatting."""

    async def dispatch(self, request: Request, call_next):
        """Run the downstream app and convert unexpected errors to a 500.

        ``HTTPException`` is re-raised untouched; any other ``Exception`` is
        logged and replaced by a generic JSON error body that does not expose
        the exception text to the client.
        """
        try:
            return await call_next(request)
        except HTTPException:
            # FastAPI HTTPExceptions are handled by FastAPI itself; a bare
            # ``raise`` keeps the original traceback intact.
            raise
        except Exception as e:
            # Falls back to 'unknown' when no earlier middleware set an ID.
            request_id = getattr(request.state, 'request_id', 'unknown')
            logger.error(
                "Unhandled exception",
                request_id=request_id,
                error=str(e),
                error_type=type(e).__name__,
                method=request.method,
                url=str(request.url),
                exc_info=True,
            )

            # Return consistent error response
            return JSONResponse(
                status_code=500,
                content={
                    "success": False,
                    "error": "Internal server error",
                    "error_details": [{"message": "An unexpected error occurred"}],
                    "timestamp": datetime.utcnow().isoformat(),
                    "request_id": request_id,
                    "api_version": "v1",
                },
            )


class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Middleware to add security headers."""

    # Header name -> value, applied to (and overriding on) every response.
    _SECURITY_HEADERS = {
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        "X-XSS-Protection": "1; mode=block",
        "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
        "Referrer-Policy": "strict-origin-when-cross-origin",
        "Permissions-Policy": "geolocation=(), microphone=(), camera=()",
    }

    async def dispatch(self, request: Request, call_next):
        """Delegate downstream, then set each security header."""
        response = await call_next(request)
        for header_name, header_value in self._SECURITY_HEADERS.items():
            response.headers[header_name] = header_value
        return response


class JWTAuthenticationManager:
    """JWT-based authentication manager."""

    def __init__(self, secret_key: str, algorithm: str = "HS256",
                 token_expire_minutes: int = 30):
        """Store the signing key, algorithm and token lifetime (minutes)."""
        self.secret_key = secret_key
        self.algorithm = algorithm
        self.token_expire_minutes = token_expire_minutes

    def create_access_token(self, data: dict) -> str:
        """Create JWT access token.

        Args:
            data: Claims to embed. The dict is copied, not mutated; any
                ``exp``/``iat`` entries it contains are overwritten.

        Returns:
            The encoded token produced by ``jwt.encode``.
        """
        # One clock read so that ``exp - iat`` is exactly the configured
        # lifetime.
        issued_at = datetime.utcnow()
        to_encode = data.copy()
        to_encode.update({
            "exp": issued_at + timedelta(minutes=self.token_expire_minutes),
            "iat": issued_at,
        })
        return jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)

    def verify_token(self, token: str) -> Optional[dict]:
        """Verify and decode JWT token.

        Returns:
            The decoded payload.

        Raises:
            HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
                the token is expired or otherwise fails validation.
        """
        try:
            return jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except jwt.ExpiredSignatureError as exc:
            # Must be caught before InvalidTokenError, its base class.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token has expired",
                headers={"WWW-Authenticate": "Bearer"},
            ) from exc
        except jwt.InvalidTokenError as exc:
            # PyJWT's base class for validation failures (it exposes no
            # ``JWTError`` attribute).
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Could not validate credentials",
                headers={"WWW-Authenticate": "Bearer"},
            ) from exc


class APIKeyManager:
    """API key-based authentication manager."""

    def __init__(self):
        """Initialise the in-memory key table."""
        # In production, store these in a database
        # NOTE(review): this hard-coded development key is always accepted;
        # confirm it is never enabled in a deployed environment.
        self.api_keys = {
            "dev-key-123": {
                "name": "Development Key",
                "scopes": ["read", "write"],
                "rate_limit": 1000,
                "created_at": datetime.utcnow(),
                "last_used": None,
            }
        }

    def validate_api_key(self, api_key: str) -> Optional[dict]:
        """Validate API key and return key info.

        Returns:
            The stored key record (with ``last_used`` refreshed) when the key
            is known, otherwise ``None``.
        """
        key_info = self.api_keys.get(api_key)
        if not key_info:
            return None
        # Update last used timestamp
        key_info["last_used"] = datetime.utcnow()
        return key_info


class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Authentication middleware supporting multiple auth methods."""

    _BEARER_PREFIX = "Bearer "

    def __init__(self, app,
                 jwt_manager: Optional[JWTAuthenticationManager] = None,
                 api_key_manager: Optional[APIKeyManager] = None):
        """Configure the available authentication back-ends.

        Args:
            app: The ASGI application being wrapped.
            jwt_manager: Verifier for bearer tokens; JWT authentication is
                skipped when this is ``None``.
            api_key_manager: Validator for ``X-API-Key``; a default
                ``APIKeyManager`` is created when omitted.
        """
        super().__init__(app)
        self.jwt_manager = jwt_manager
        self.api_key_manager = api_key_manager or APIKeyManager()

        # Paths that don't require authentication
        self.public_paths = {
            "/health",
            "/metrics",
            "/docs",
            "/redoc",
            "/openapi.json",
            "/favicon.ico",
        }

    def _is_public_path(self, path: str) -> bool:
        """Return True when *path* is a public path or lies beneath one.

        Matching is done on a path-segment boundary so that, for example,
        ``/docs/oauth2-redirect`` is public while ``/docs-internal`` is not.
        """
        return any(
            path == public or path.startswith(public.rstrip("/") + "/")
            for public in self.public_paths
        )

    async def dispatch(self, request: Request, call_next):
        """Authenticate the request or answer with a 401 JSON body.

        On success the resolved identity is stored on ``request.state.user``.
        """
        # Skip authentication for public paths
        if self._is_public_path(request.url.path):
            return await call_next(request)

        # Extract authentication credentials
        auth_header = request.headers.get("Authorization")
        api_key_header = request.headers.get("X-API-Key")

        user_info = None

        # Try JWT authentication first
        if (auth_header and auth_header.startswith(self._BEARER_PREFIX)
                and self.jwt_manager):
            token = auth_header[len(self._BEARER_PREFIX):]
            try:
                payload = self.jwt_manager.verify_token(token)
            except HTTPException:
                # Deliberate: a bad token falls through to the API-key check.
                pass
            else:
                user_info = {
                    "user_id": payload.get("sub"),
                    "username": payload.get("username"),
                    "scopes": payload.get("scopes", []),
                    "auth_method": "jwt",
                }

        # Try API key authentication
        if not user_info and api_key_header:
            key_info = self.api_key_manager.validate_api_key(api_key_header)
            if key_info:
                user_info = {
                    # NOTE(review): embeds the first 8 characters of the key
                    # in an identifier that other middleware may log.
                    "user_id": f"api_key_{api_key_header[:8]}",
                    "username": key_info["name"],
                    "scopes": key_info["scopes"],
                    "auth_method": "api_key",
                    "rate_limit": key_info["rate_limit"],
                }

        # If no valid authentication found
        if not user_info:
            return JSONResponse(
                status_code=401,
                content={
                    "success": False,
                    "error": "Authentication required",
                    "error_details": [
                        {"message": "Valid API key or JWT token required"}
                    ],
                    "timestamp": datetime.utcnow().isoformat(),
                    "api_version": "v1",
                },
            )

        # Add user info to request state
        request.state.user = user_info
        return await call_next(request)


class RateLimitingMiddleware(BaseHTTPMiddleware):
    """Custom rate limiting middleware.

    Counts requests per caller in fixed one-minute windows held in an
    in-process dict.
    """

    def __init__(self, app, default_rate_limit: int = 100):
        """Store the default per-minute limit and create the counter table."""
        super().__init__(app)
        self.default_rate_limit = default_rate_limit
        self.request_counts = {}  # In production, use Redis

    async def dispatch(self, request: Request, call_next):
        """Count the request and reject it with 429 once over the limit."""
        # Get user identifier
        user_info = getattr(request.state, 'user', None)
        if user_info:
            user_id = user_info["user_id"]
            rate_limit = user_info.get("rate_limit", self.default_rate_limit)
        else:
            user_id = request.client.host if request.client else "anonymous"
            rate_limit = self.default_rate_limit

        # Single clock read so the window, reset time and Retry-After agree.
        now = int(time.time())
        current_minute = now // 60
        window_reset = (current_minute + 1) * 60
        key = f"{user_id}:{current_minute}"

        # Increment request count
        current_count = self.request_counts.get(key, 0) + 1
        self.request_counts[key] = current_count

        # Clean up old entries (simple cleanup)
        if len(self.request_counts) > 10000:
            cutoff = current_minute - 5
            # rsplit: the identifier itself may contain ':' (e.g. an IPv6
            # client address); the minute is always the final field.
            stale_keys = [
                k for k in self.request_counts
                if int(k.rsplit(':', 1)[1]) < cutoff
            ]
            for stale_key in stale_keys:
                del self.request_counts[stale_key]

        remaining = str(max(0, rate_limit - current_count))

        # Check rate limit
        if current_count > rate_limit:
            retry_after = window_reset - now
            return JSONResponse(
                status_code=429,
                content={
                    "success": False,
                    "error": "Rate limit exceeded",
                    "error_details": [{
                        "message": f"Rate limit of {rate_limit} requests per minute exceeded"
                    }],
                    "timestamp": datetime.utcnow().isoformat(),
                    "retry_after": retry_after,
                },
                headers={
                    "Retry-After": str(retry_after),
                    "X-RateLimit-Limit": str(rate_limit),
                    "X-RateLimit-Remaining": remaining,
                    "X-RateLimit-Reset": str(window_reset),
                },
            )

        # Add rate limit headers to response
        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(rate_limit)
        response.headers["X-RateLimit-Remaining"] = remaining
        response.headers["X-RateLimit-Reset"] = str(window_reset)
        return response


class CompressionMiddleware(BaseHTTPMiddleware):
    """Custom compression middleware with configurable settings."""

    # Substrings of Content-Type that mark a response as worth compressing.
    _COMPRESSIBLE_TYPES = ("application/json", "text/", "application/javascript")

    def __init__(self, app, minimum_size: int = 1000, compression_level: int = 6):
        """Store the size threshold (bytes) and the gzip compression level."""
        super().__init__(app)
        self.minimum_size = minimum_size
        self.compression_level = compression_level

    async def dispatch(self, request: Request, call_next):
        """Gzip the response body when the client and content allow it.

        The response is passed through untouched when the client does not
        advertise gzip, when it is already content-encoded, or when its
        content type is not in the compressible list. Otherwise the body is
        buffered and compressed if it reaches ``minimum_size`` bytes.
        """
        response = await call_next(request)

        # Check if client accepts gzip
        accept_encoding = request.headers.get("accept-encoding", "")
        if "gzip" not in accept_encoding.lower():
            return response

        # Never re-encode a body that already carries a content coding.
        if "content-encoding" in response.headers:
            return response

        # Check content type
        content_type = response.headers.get("content-type", "")
        if not any(ct in content_type for ct in self._COMPRESSIBLE_TYPES):
            return response

        # Buffer the body; join once instead of repeated concatenation.
        body = b"".join([chunk async for chunk in response.body_iterator])

        if len(body) >= self.minimum_size:
            payload = gzip.compress(body, compresslevel=self.compression_level)
            return self._rebuild(response, payload, gzip_encoded=True)

        # The iterator is consumed, so the body must be re-wrapped even when
        # it is sent uncompressed.
        return self._rebuild(response, body, gzip_encoded=False)

    @staticmethod
    def _rebuild(original, payload: bytes, gzip_encoded: bool) -> Response:
        """Build a plain Response around *payload*, keeping *original*'s headers."""
        rebuilt = Response(content=payload, status_code=original.status_code)
        # Copy the raw header list so repeated headers (e.g. several
        # Set-Cookie lines) survive; a dict would keep only one of them.
        rebuilt.raw_headers = list(original.raw_headers)
        rebuilt.headers["content-length"] = str(len(payload))
        if gzip_encoded:
            rebuilt.headers["content-encoding"] = "gzip"
            # Tell caches that the representation depends on Accept-Encoding.
            rebuilt.headers.add_vary_header("Accept-Encoding")
        return rebuilt