Phase 2 build initial
This commit is contained in:
		
							
								
								
									
										365
									
								
								hcfs-python/hcfs/api/middleware.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										365
									
								
								hcfs-python/hcfs/api/middleware.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,365 @@ | ||||
| """ | ||||
| Custom middleware for HCFS API. | ||||
|  | ||||
| Provides authentication, logging, error handling, and security features. | ||||
| """ | ||||
|  | ||||
| import time | ||||
| import uuid | ||||
| import json | ||||
| from typing import Optional | ||||
| from datetime import datetime, timedelta | ||||
|  | ||||
| from fastapi import Request, Response, HTTPException, status | ||||
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | ||||
| from fastapi.middleware.base import BaseHTTPMiddleware | ||||
| from starlette.responses import JSONResponse | ||||
| import jwt | ||||
| import structlog | ||||
|  | ||||
| logger = structlog.get_logger() | ||||
|  | ||||
| class RequestLoggingMiddleware(BaseHTTPMiddleware): | ||||
|     """Middleware for comprehensive request/response logging.""" | ||||
|      | ||||
|     def __init__(self, app, log_body: bool = False): | ||||
|         super().__init__(app) | ||||
|         self.log_body = log_body | ||||
|      | ||||
|     async def dispatch(self, request: Request, call_next): | ||||
|         # Generate request ID | ||||
|         request_id = str(uuid.uuid4()) | ||||
|         request.state.request_id = request_id | ||||
|          | ||||
|         # Start timing | ||||
|         start_time = time.time() | ||||
|          | ||||
|         # Log request | ||||
|         logger.info( | ||||
|             "Request started", | ||||
|             request_id=request_id, | ||||
|             method=request.method, | ||||
|             url=str(request.url), | ||||
|             client_ip=request.client.host if request.client else None, | ||||
|             user_agent=request.headers.get("user-agent"), | ||||
|         ) | ||||
|          | ||||
|         # Call the next middleware/endpoint | ||||
|         response = await call_next(request) | ||||
|          | ||||
|         # Calculate duration | ||||
|         duration = time.time() - start_time | ||||
|          | ||||
|         # Log response | ||||
|         logger.info( | ||||
|             "Request completed", | ||||
|             request_id=request_id, | ||||
|             status_code=response.status_code, | ||||
|             duration_ms=round(duration * 1000, 2), | ||||
|         ) | ||||
|          | ||||
|         # Add request ID to response headers | ||||
|         response.headers["X-Request-ID"] = request_id | ||||
|          | ||||
|         return response | ||||
|  | ||||
|  | ||||
| class ErrorHandlingMiddleware(BaseHTTPMiddleware): | ||||
|     """Middleware for consistent error handling and formatting.""" | ||||
|      | ||||
|     async def dispatch(self, request: Request, call_next): | ||||
|         try: | ||||
|             response = await call_next(request) | ||||
|             return response | ||||
|         except HTTPException as e: | ||||
|             # FastAPI HTTPExceptions are handled by FastAPI itself | ||||
|             raise e | ||||
|         except Exception as e: | ||||
|             # Log unexpected errors | ||||
|             request_id = getattr(request.state, 'request_id', 'unknown') | ||||
|             logger.error( | ||||
|                 "Unhandled exception", | ||||
|                 request_id=request_id, | ||||
|                 error=str(e), | ||||
|                 error_type=type(e).__name__, | ||||
|                 method=request.method, | ||||
|                 url=str(request.url), | ||||
|                 exc_info=True | ||||
|             ) | ||||
|              | ||||
|             # Return consistent error response | ||||
|             return JSONResponse( | ||||
|                 status_code=500, | ||||
|                 content={ | ||||
|                     "success": False, | ||||
|                     "error": "Internal server error", | ||||
|                     "error_details": [{"message": "An unexpected error occurred"}], | ||||
|                     "timestamp": datetime.utcnow().isoformat(), | ||||
|                     "request_id": request_id, | ||||
|                     "api_version": "v1" | ||||
|                 } | ||||
|             ) | ||||
|  | ||||
|  | ||||
| class SecurityHeadersMiddleware(BaseHTTPMiddleware): | ||||
|     """Middleware to add security headers.""" | ||||
|      | ||||
|     async def dispatch(self, request: Request, call_next): | ||||
|         response = await call_next(request) | ||||
|          | ||||
|         # Add security headers | ||||
|         response.headers["X-Content-Type-Options"] = "nosniff" | ||||
|         response.headers["X-Frame-Options"] = "DENY" | ||||
|         response.headers["X-XSS-Protection"] = "1; mode=block" | ||||
|         response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains" | ||||
|         response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" | ||||
|         response.headers["Permissions-Policy"] = "geolocation=(), microphone=(), camera=()" | ||||
|          | ||||
|         return response | ||||
|  | ||||
|  | ||||
| class JWTAuthenticationManager: | ||||
|     """JWT-based authentication manager.""" | ||||
|      | ||||
|     def __init__(self, secret_key: str, algorithm: str = "HS256", token_expire_minutes: int = 30): | ||||
|         self.secret_key = secret_key | ||||
|         self.algorithm = algorithm | ||||
|         self.token_expire_minutes = token_expire_minutes | ||||
|      | ||||
|     def create_access_token(self, data: dict) -> str: | ||||
|         """Create JWT access token.""" | ||||
|         to_encode = data.copy() | ||||
|         expire = datetime.utcnow() + timedelta(minutes=self.token_expire_minutes) | ||||
|         to_encode.update({"exp": expire, "iat": datetime.utcnow()}) | ||||
|          | ||||
|         return jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm) | ||||
|      | ||||
|     def verify_token(self, token: str) -> Optional[dict]: | ||||
|         """Verify and decode JWT token.""" | ||||
|         try: | ||||
|             payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm]) | ||||
|             return payload | ||||
|         except jwt.ExpiredSignatureError: | ||||
|             raise HTTPException( | ||||
|                 status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|                 detail="Token has expired", | ||||
|                 headers={"WWW-Authenticate": "Bearer"}, | ||||
|             ) | ||||
|         except jwt.JWTError: | ||||
|             raise HTTPException( | ||||
|                 status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|                 detail="Could not validate credentials", | ||||
|                 headers={"WWW-Authenticate": "Bearer"}, | ||||
|             ) | ||||
|  | ||||
|  | ||||
| class APIKeyManager: | ||||
|     """API key-based authentication manager.""" | ||||
|      | ||||
|     def __init__(self): | ||||
|         # In production, store these in a database | ||||
|         self.api_keys = { | ||||
|             "dev-key-123": { | ||||
|                 "name": "Development Key", | ||||
|                 "scopes": ["read", "write"], | ||||
|                 "rate_limit": 1000, | ||||
|                 "created_at": datetime.utcnow(), | ||||
|                 "last_used": None | ||||
|             } | ||||
|         } | ||||
|      | ||||
|     def validate_api_key(self, api_key: str) -> Optional[dict]: | ||||
|         """Validate API key and return key info.""" | ||||
|         key_info = self.api_keys.get(api_key) | ||||
|         if key_info: | ||||
|             # Update last used timestamp | ||||
|             key_info["last_used"] = datetime.utcnow() | ||||
|             return key_info | ||||
|         return None | ||||
|  | ||||
|  | ||||
| class AuthenticationMiddleware(BaseHTTPMiddleware): | ||||
|     """Authentication middleware supporting multiple auth methods.""" | ||||
|      | ||||
|     def __init__(self, app, jwt_manager: JWTAuthenticationManager = None, api_key_manager: APIKeyManager = None): | ||||
|         super().__init__(app) | ||||
|         self.jwt_manager = jwt_manager | ||||
|         self.api_key_manager = api_key_manager or APIKeyManager() | ||||
|          | ||||
|         # Paths that don't require authentication | ||||
|         self.public_paths = { | ||||
|             "/health", | ||||
|             "/metrics", | ||||
|             "/docs", | ||||
|             "/redoc", | ||||
|             "/openapi.json", | ||||
|             "/favicon.ico" | ||||
|         } | ||||
|      | ||||
|     async def dispatch(self, request: Request, call_next): | ||||
|         # Skip authentication for public paths | ||||
|         if any(request.url.path.startswith(path) for path in self.public_paths): | ||||
|             return await call_next(request) | ||||
|          | ||||
|         # Extract authentication credentials | ||||
|         auth_header = request.headers.get("Authorization") | ||||
|         api_key_header = request.headers.get("X-API-Key") | ||||
|          | ||||
|         user_info = None | ||||
|          | ||||
|         # Try JWT authentication first | ||||
|         if auth_header and auth_header.startswith("Bearer ") and self.jwt_manager: | ||||
|             token = auth_header[7:]  # Remove "Bearer " prefix | ||||
|             try: | ||||
|                 payload = self.jwt_manager.verify_token(token) | ||||
|                 user_info = { | ||||
|                     "user_id": payload.get("sub"), | ||||
|                     "username": payload.get("username"), | ||||
|                     "scopes": payload.get("scopes", []), | ||||
|                     "auth_method": "jwt" | ||||
|                 } | ||||
|             except HTTPException: | ||||
|                 pass  # Try other auth methods | ||||
|          | ||||
|         # Try API key authentication | ||||
|         if not user_info and api_key_header: | ||||
|             key_info = self.api_key_manager.validate_api_key(api_key_header) | ||||
|             if key_info: | ||||
|                 user_info = { | ||||
|                     "user_id": f"api_key_{api_key_header[:8]}", | ||||
|                     "username": key_info["name"], | ||||
|                     "scopes": key_info["scopes"], | ||||
|                     "auth_method": "api_key", | ||||
|                     "rate_limit": key_info["rate_limit"] | ||||
|                 } | ||||
|          | ||||
|         # If no valid authentication found | ||||
|         if not user_info: | ||||
|             return JSONResponse( | ||||
|                 status_code=401, | ||||
|                 content={ | ||||
|                     "success": False, | ||||
|                     "error": "Authentication required", | ||||
|                     "error_details": [{"message": "Valid API key or JWT token required"}], | ||||
|                     "timestamp": datetime.utcnow().isoformat(), | ||||
|                     "api_version": "v1" | ||||
|                 } | ||||
|             ) | ||||
|          | ||||
|         # Add user info to request state | ||||
|         request.state.user = user_info | ||||
|          | ||||
|         return await call_next(request) | ||||
|  | ||||
|  | ||||
| class RateLimitingMiddleware(BaseHTTPMiddleware): | ||||
|     """Custom rate limiting middleware.""" | ||||
|      | ||||
|     def __init__(self, app, default_rate_limit: int = 100): | ||||
|         super().__init__(app) | ||||
|         self.default_rate_limit = default_rate_limit | ||||
|         self.request_counts = {}  # In production, use Redis | ||||
|      | ||||
|     async def dispatch(self, request: Request, call_next): | ||||
|         # Get user identifier | ||||
|         user_info = getattr(request.state, 'user', None) | ||||
|         if user_info: | ||||
|             user_id = user_info["user_id"] | ||||
|             rate_limit = user_info.get("rate_limit", self.default_rate_limit) | ||||
|         else: | ||||
|             user_id = request.client.host if request.client else "anonymous" | ||||
|             rate_limit = self.default_rate_limit | ||||
|          | ||||
|         # Current minute window | ||||
|         current_minute = int(time.time() // 60) | ||||
|         key = f"{user_id}:{current_minute}" | ||||
|          | ||||
|         # Increment request count | ||||
|         current_count = self.request_counts.get(key, 0) + 1 | ||||
|         self.request_counts[key] = current_count | ||||
|          | ||||
|         # Clean up old entries (simple cleanup) | ||||
|         if len(self.request_counts) > 10000: | ||||
|             old_keys = [k for k in self.request_counts.keys()  | ||||
|                        if int(k.split(':')[1]) < current_minute - 5] | ||||
|             for old_key in old_keys: | ||||
|                 del self.request_counts[old_key] | ||||
|          | ||||
|         # Check rate limit | ||||
|         if current_count > rate_limit: | ||||
|             return JSONResponse( | ||||
|                 status_code=429, | ||||
|                 content={ | ||||
|                     "success": False, | ||||
|                     "error": "Rate limit exceeded", | ||||
|                     "error_details": [{"message": f"Rate limit of {rate_limit} requests per minute exceeded"}], | ||||
|                     "timestamp": datetime.utcnow().isoformat(), | ||||
|                     "retry_after": 60 - (int(time.time()) % 60) | ||||
|                 }, | ||||
|                 headers={ | ||||
|                     "Retry-After": str(60 - (int(time.time()) % 60)), | ||||
|                     "X-RateLimit-Limit": str(rate_limit), | ||||
|                     "X-RateLimit-Remaining": str(max(0, rate_limit - current_count)), | ||||
|                     "X-RateLimit-Reset": str((current_minute + 1) * 60) | ||||
|                 } | ||||
|             ) | ||||
|          | ||||
|         # Add rate limit headers to response | ||||
|         response = await call_next(request) | ||||
|         response.headers["X-RateLimit-Limit"] = str(rate_limit) | ||||
|         response.headers["X-RateLimit-Remaining"] = str(max(0, rate_limit - current_count)) | ||||
|         response.headers["X-RateLimit-Reset"] = str((current_minute + 1) * 60) | ||||
|          | ||||
|         return response | ||||
|  | ||||
|  | ||||
| class CompressionMiddleware(BaseHTTPMiddleware): | ||||
|     """Custom compression middleware with configurable settings.""" | ||||
|      | ||||
|     def __init__(self, app, minimum_size: int = 1000, compression_level: int = 6): | ||||
|         super().__init__(app) | ||||
|         self.minimum_size = minimum_size | ||||
|         self.compression_level = compression_level | ||||
|      | ||||
|     async def dispatch(self, request: Request, call_next): | ||||
|         response = await call_next(request) | ||||
|          | ||||
|         # Check if client accepts gzip | ||||
|         accept_encoding = request.headers.get("accept-encoding", "") | ||||
|         if "gzip" not in accept_encoding: | ||||
|             return response | ||||
|          | ||||
|         # Check content type and size | ||||
|         content_type = response.headers.get("content-type", "") | ||||
|         if not any(ct in content_type for ct in ["application/json", "text/", "application/javascript"]): | ||||
|             return response | ||||
|          | ||||
|         # Get response body | ||||
|         body = b"" | ||||
|         async for chunk in response.body_iterator: | ||||
|             body += chunk | ||||
|          | ||||
|         # Compress if body is large enough | ||||
|         if len(body) >= self.minimum_size: | ||||
|             import gzip | ||||
|             compressed_body = gzip.compress(body, compresslevel=self.compression_level) | ||||
|              | ||||
|             # Create new response with compressed body | ||||
|             from starlette.responses import Response | ||||
|             return Response( | ||||
|                 content=compressed_body, | ||||
|                 status_code=response.status_code, | ||||
|                 headers={ | ||||
|                     **dict(response.headers), | ||||
|                     "content-encoding": "gzip", | ||||
|                     "content-length": str(len(compressed_body)) | ||||
|                 } | ||||
|             ) | ||||
|          | ||||
|         # Return original response if not compressed | ||||
|         from starlette.responses import Response | ||||
|         return Response( | ||||
|             content=body, | ||||
|             status_code=response.status_code, | ||||
|             headers=dict(response.headers) | ||||
|         ) | ||||
		Reference in New Issue
	
	Block a user
	 Claude Code
					Claude Code