Phase 2 build initial
This commit is contained in:
1
hcfs-python/hcfs/api/__init__.py
Normal file
1
hcfs-python/hcfs/api/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""HCFS API components."""
|
||||
288
hcfs-python/hcfs/api/config.py
Normal file
288
hcfs-python/hcfs/api/config.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
Configuration management for HCFS API.
|
||||
|
||||
Handles environment-based configuration with validation and defaults.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseSettings, Field, validator
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class DatabaseConfig(BaseSettings):
|
||||
"""Database configuration settings."""
|
||||
|
||||
# SQLite settings
|
||||
db_path: str = Field(default="hcfs_production.db", description="Path to SQLite database")
|
||||
vector_db_path: str = Field(default="hcfs_vectors_production.db", description="Path to vector database")
|
||||
|
||||
# Connection settings
|
||||
pool_size: int = Field(default=10, description="Database connection pool size")
|
||||
max_overflow: int = Field(default=20, description="Maximum connection overflow")
|
||||
pool_timeout: int = Field(default=30, description="Connection pool timeout in seconds")
|
||||
|
||||
# Performance settings
|
||||
cache_size: int = Field(default=1000, description="Database cache size")
|
||||
enable_wal_mode: bool = Field(default=True, description="Enable SQLite WAL mode")
|
||||
synchronous_mode: str = Field(default="NORMAL", description="SQLite synchronous mode")
|
||||
|
||||
class Config:
|
||||
env_prefix = "HCFS_DB_"
|
||||
|
||||
|
||||
class EmbeddingConfig(BaseSettings):
|
||||
"""Embedding system configuration."""
|
||||
|
||||
# Model settings
|
||||
model_name: str = Field(default="mini", description="Embedding model to use")
|
||||
cache_size: int = Field(default=2000, description="Embedding cache size")
|
||||
batch_size: int = Field(default=32, description="Batch processing size")
|
||||
|
||||
# Performance settings
|
||||
max_workers: int = Field(default=4, description="Maximum worker threads")
|
||||
timeout_seconds: int = Field(default=300, description="Operation timeout")
|
||||
|
||||
# Vector database settings
|
||||
vector_dimension: int = Field(default=384, description="Vector dimension")
|
||||
similarity_threshold: float = Field(default=0.0, description="Default similarity threshold")
|
||||
|
||||
class Config:
|
||||
env_prefix = "HCFS_EMBEDDING_"
|
||||
|
||||
|
||||
class APIConfig(BaseSettings):
|
||||
"""API server configuration."""
|
||||
|
||||
# Server settings
|
||||
host: str = Field(default="0.0.0.0", description="Server host")
|
||||
port: int = Field(default=8000, description="Server port")
|
||||
workers: int = Field(default=1, description="Number of worker processes")
|
||||
|
||||
# Security settings
|
||||
secret_key: str = Field(default="dev-secret-key-change-in-production", description="JWT secret key")
|
||||
algorithm: str = Field(default="HS256", description="JWT algorithm")
|
||||
token_expire_minutes: int = Field(default=30, description="JWT token expiration time")
|
||||
|
||||
# CORS settings
|
||||
cors_origins: List[str] = Field(
|
||||
default=["http://localhost:3000", "http://localhost:8080"],
|
||||
description="Allowed CORS origins"
|
||||
)
|
||||
cors_credentials: bool = Field(default=True, description="Allow credentials in CORS")
|
||||
|
||||
# Rate limiting
|
||||
rate_limit_requests: int = Field(default=100, description="Requests per minute")
|
||||
rate_limit_burst: int = Field(default=20, description="Burst requests allowed")
|
||||
|
||||
# Feature flags
|
||||
enable_auth: bool = Field(default=True, description="Enable authentication")
|
||||
enable_websocket: bool = Field(default=True, description="Enable WebSocket support")
|
||||
enable_metrics: bool = Field(default=True, description="Enable Prometheus metrics")
|
||||
enable_docs: bool = Field(default=True, description="Enable API documentation")
|
||||
|
||||
class Config:
|
||||
env_prefix = "HCFS_API_"
|
||||
|
||||
|
||||
class MonitoringConfig(BaseSettings):
|
||||
"""Monitoring and observability configuration."""
|
||||
|
||||
# Logging settings
|
||||
log_level: str = Field(default="INFO", description="Logging level")
|
||||
log_format: str = Field(default="json", description="Log format (json/text)")
|
||||
log_file: Optional[str] = Field(default=None, description="Log file path")
|
||||
|
||||
# Metrics settings
|
||||
metrics_enabled: bool = Field(default=True, description="Enable metrics collection")
|
||||
metrics_port: int = Field(default=9090, description="Metrics server port")
|
||||
|
||||
# Health check settings
|
||||
health_check_interval: int = Field(default=30, description="Health check interval in seconds")
|
||||
health_check_timeout: int = Field(default=5, description="Health check timeout")
|
||||
|
||||
# Tracing settings
|
||||
tracing_enabled: bool = Field(default=False, description="Enable distributed tracing")
|
||||
tracing_sample_rate: float = Field(default=0.1, description="Tracing sample rate")
|
||||
jaeger_endpoint: Optional[str] = Field(default=None, description="Jaeger endpoint")
|
||||
|
||||
class Config:
|
||||
env_prefix = "HCFS_MONITORING_"
|
||||
|
||||
|
||||
class RedisConfig(BaseSettings):
|
||||
"""Redis configuration for caching and rate limiting."""
|
||||
|
||||
# Connection settings
|
||||
host: str = Field(default="localhost", description="Redis host")
|
||||
port: int = Field(default=6379, description="Redis port")
|
||||
db: int = Field(default=0, description="Redis database number")
|
||||
password: Optional[str] = Field(default=None, description="Redis password")
|
||||
|
||||
# Pool settings
|
||||
max_connections: int = Field(default=20, description="Maximum Redis connections")
|
||||
socket_timeout: int = Field(default=5, description="Socket timeout in seconds")
|
||||
|
||||
# Cache settings
|
||||
default_ttl: int = Field(default=3600, description="Default cache TTL in seconds")
|
||||
key_prefix: str = Field(default="hcfs:", description="Redis key prefix")
|
||||
|
||||
class Config:
|
||||
env_prefix = "HCFS_REDIS_"
|
||||
|
||||
|
||||
class SecurityConfig(BaseSettings):
|
||||
"""Security configuration."""
|
||||
|
||||
# Authentication
|
||||
require_auth: bool = Field(default=True, description="Require authentication")
|
||||
api_key_header: str = Field(default="X-API-Key", description="API key header name")
|
||||
|
||||
# Rate limiting
|
||||
rate_limit_enabled: bool = Field(default=True, description="Enable rate limiting")
|
||||
rate_limit_storage: str = Field(default="memory", description="Rate limit storage (memory/redis)")
|
||||
|
||||
# HTTPS settings
|
||||
force_https: bool = Field(default=False, description="Force HTTPS in production")
|
||||
hsts_max_age: int = Field(default=31536000, description="HSTS max age")
|
||||
|
||||
# Request validation
|
||||
max_request_size: int = Field(default=10 * 1024 * 1024, description="Maximum request size in bytes")
|
||||
max_query_params: int = Field(default=100, description="Maximum query parameters")
|
||||
|
||||
# Content security
|
||||
allowed_content_types: List[str] = Field(
|
||||
default=["application/json", "application/x-www-form-urlencoded", "multipart/form-data"],
|
||||
description="Allowed content types"
|
||||
)
|
||||
|
||||
class Config:
|
||||
env_prefix = "HCFS_SECURITY_"
|
||||
|
||||
|
||||
class HCFSConfig(BaseSettings):
|
||||
"""Main HCFS configuration combining all subsystem configs."""
|
||||
|
||||
# Environment
|
||||
environment: str = Field(default="development", description="Environment (development/staging/production)")
|
||||
debug: bool = Field(default=False, description="Enable debug mode")
|
||||
|
||||
# Application info
|
||||
app_name: str = Field(default="HCFS API", description="Application name")
|
||||
app_version: str = Field(default="2.0.0", description="Application version")
|
||||
app_description: str = Field(default="Context-Aware Hierarchical Context File System API", description="App description")
|
||||
|
||||
# Configuration file path
|
||||
config_file: Optional[str] = Field(default=None, description="Path to configuration file")
|
||||
|
||||
# Subsystem configurations
|
||||
database: DatabaseConfig = Field(default_factory=DatabaseConfig)
|
||||
embedding: EmbeddingConfig = Field(default_factory=EmbeddingConfig)
|
||||
api: APIConfig = Field(default_factory=APIConfig)
|
||||
monitoring: MonitoringConfig = Field(default_factory=MonitoringConfig)
|
||||
redis: RedisConfig = Field(default_factory=RedisConfig)
|
||||
security: SecurityConfig = Field(default_factory=SecurityConfig)
|
||||
|
||||
class Config:
|
||||
env_prefix = "HCFS_"
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
|
||||
@validator('environment')
|
||||
def validate_environment(cls, v):
|
||||
"""Validate environment value."""
|
||||
allowed = ['development', 'staging', 'production']
|
||||
if v not in allowed:
|
||||
raise ValueError(f'Environment must be one of: {allowed}')
|
||||
return v
|
||||
|
||||
@validator('debug')
|
||||
def validate_debug_in_production(cls, v, values):
|
||||
"""Ensure debug is disabled in production."""
|
||||
if values.get('environment') == 'production' and v:
|
||||
raise ValueError('Debug mode cannot be enabled in production')
|
||||
return v
|
||||
|
||||
def is_production(self) -> bool:
|
||||
"""Check if running in production environment."""
|
||||
return self.environment == 'production'
|
||||
|
||||
def is_development(self) -> bool:
|
||||
"""Check if running in development environment."""
|
||||
return self.environment == 'development'
|
||||
|
||||
def get_database_url(self) -> str:
|
||||
"""Get database URL."""
|
||||
return f"sqlite:///{self.database.db_path}"
|
||||
|
||||
def get_redis_url(self) -> str:
|
||||
"""Get Redis URL."""
|
||||
if self.redis.password:
|
||||
return f"redis://:{self.redis.password}@{self.redis.host}:{self.redis.port}/{self.redis.db}"
|
||||
return f"redis://{self.redis.host}:{self.redis.port}/{self.redis.db}"
|
||||
|
||||
def load_from_file(self, config_path: str) -> None:
|
||||
"""Load configuration from YAML file."""
|
||||
import yaml
|
||||
|
||||
config_file = Path(config_path)
|
||||
if not config_file.exists():
|
||||
raise FileNotFoundError(f"Configuration file not found: {config_path}")
|
||||
|
||||
with open(config_file, 'r') as f:
|
||||
config_data = yaml.safe_load(f)
|
||||
|
||||
# Update configuration
|
||||
for key, value in config_data.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert configuration to dictionary."""
|
||||
return self.dict()
|
||||
|
||||
def save_to_file(self, config_path: str) -> None:
|
||||
"""Save configuration to YAML file."""
|
||||
import yaml
|
||||
|
||||
config_data = self.to_dict()
|
||||
|
||||
with open(config_path, 'w') as f:
|
||||
yaml.dump(config_data, f, default_flow_style=False, indent=2)
|
||||
|
||||
|
||||
# Global configuration instance
|
||||
config = HCFSConfig()
|
||||
|
||||
|
||||
def get_config() -> HCFSConfig:
|
||||
"""Get the global configuration instance."""
|
||||
return config
|
||||
|
||||
|
||||
def load_config(config_path: Optional[str] = None, **overrides) -> HCFSConfig:
|
||||
"""Load configuration with optional file and overrides."""
|
||||
global config
|
||||
|
||||
# Load from file if provided
|
||||
if config_path:
|
||||
config.load_from_file(config_path)
|
||||
|
||||
# Apply overrides
|
||||
for key, value in overrides.items():
|
||||
if hasattr(config, key):
|
||||
setattr(config, key, value)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def create_config_template(output_path: str = "hcfs_config.yaml") -> None:
|
||||
"""Create a configuration template file."""
|
||||
template_config = HCFSConfig()
|
||||
template_config.save_to_file(output_path)
|
||||
print(f"Configuration template created: {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Create configuration template
|
||||
create_config_template()
|
||||
365
hcfs-python/hcfs/api/middleware.py
Normal file
365
hcfs-python/hcfs/api/middleware.py
Normal file
@@ -0,0 +1,365 @@
|
||||
"""
|
||||
Custom middleware for HCFS API.
|
||||
|
||||
Provides authentication, logging, error handling, and security features.
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
import json
|
||||
from typing import Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import Request, Response, HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from fastapi.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
import jwt
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware for comprehensive request/response logging."""
|
||||
|
||||
def __init__(self, app, log_body: bool = False):
|
||||
super().__init__(app)
|
||||
self.log_body = log_body
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Generate request ID
|
||||
request_id = str(uuid.uuid4())
|
||||
request.state.request_id = request_id
|
||||
|
||||
# Start timing
|
||||
start_time = time.time()
|
||||
|
||||
# Log request
|
||||
logger.info(
|
||||
"Request started",
|
||||
request_id=request_id,
|
||||
method=request.method,
|
||||
url=str(request.url),
|
||||
client_ip=request.client.host if request.client else None,
|
||||
user_agent=request.headers.get("user-agent"),
|
||||
)
|
||||
|
||||
# Call the next middleware/endpoint
|
||||
response = await call_next(request)
|
||||
|
||||
# Calculate duration
|
||||
duration = time.time() - start_time
|
||||
|
||||
# Log response
|
||||
logger.info(
|
||||
"Request completed",
|
||||
request_id=request_id,
|
||||
status_code=response.status_code,
|
||||
duration_ms=round(duration * 1000, 2),
|
||||
)
|
||||
|
||||
# Add request ID to response headers
|
||||
response.headers["X-Request-ID"] = request_id
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware for consistent error handling and formatting."""
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
try:
|
||||
response = await call_next(request)
|
||||
return response
|
||||
except HTTPException as e:
|
||||
# FastAPI HTTPExceptions are handled by FastAPI itself
|
||||
raise e
|
||||
except Exception as e:
|
||||
# Log unexpected errors
|
||||
request_id = getattr(request.state, 'request_id', 'unknown')
|
||||
logger.error(
|
||||
"Unhandled exception",
|
||||
request_id=request_id,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
method=request.method,
|
||||
url=str(request.url),
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# Return consistent error response
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"error": "Internal server error",
|
||||
"error_details": [{"message": "An unexpected error occurred"}],
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"request_id": request_id,
|
||||
"api_version": "v1"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to add security headers."""
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
|
||||
# Add security headers
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
response.headers["Permissions-Policy"] = "geolocation=(), microphone=(), camera=()"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class JWTAuthenticationManager:
|
||||
"""JWT-based authentication manager."""
|
||||
|
||||
def __init__(self, secret_key: str, algorithm: str = "HS256", token_expire_minutes: int = 30):
|
||||
self.secret_key = secret_key
|
||||
self.algorithm = algorithm
|
||||
self.token_expire_minutes = token_expire_minutes
|
||||
|
||||
def create_access_token(self, data: dict) -> str:
|
||||
"""Create JWT access token."""
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + timedelta(minutes=self.token_expire_minutes)
|
||||
to_encode.update({"exp": expire, "iat": datetime.utcnow()})
|
||||
|
||||
return jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
|
||||
|
||||
def verify_token(self, token: str) -> Optional[dict]:
|
||||
"""Verify and decode JWT token."""
|
||||
try:
|
||||
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
|
||||
return payload
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token has expired",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except jwt.JWTError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
class APIKeyManager:
|
||||
"""API key-based authentication manager."""
|
||||
|
||||
def __init__(self):
|
||||
# In production, store these in a database
|
||||
self.api_keys = {
|
||||
"dev-key-123": {
|
||||
"name": "Development Key",
|
||||
"scopes": ["read", "write"],
|
||||
"rate_limit": 1000,
|
||||
"created_at": datetime.utcnow(),
|
||||
"last_used": None
|
||||
}
|
||||
}
|
||||
|
||||
def validate_api_key(self, api_key: str) -> Optional[dict]:
|
||||
"""Validate API key and return key info."""
|
||||
key_info = self.api_keys.get(api_key)
|
||||
if key_info:
|
||||
# Update last used timestamp
|
||||
key_info["last_used"] = datetime.utcnow()
|
||||
return key_info
|
||||
return None
|
||||
|
||||
|
||||
class AuthenticationMiddleware(BaseHTTPMiddleware):
|
||||
"""Authentication middleware supporting multiple auth methods."""
|
||||
|
||||
def __init__(self, app, jwt_manager: JWTAuthenticationManager = None, api_key_manager: APIKeyManager = None):
|
||||
super().__init__(app)
|
||||
self.jwt_manager = jwt_manager
|
||||
self.api_key_manager = api_key_manager or APIKeyManager()
|
||||
|
||||
# Paths that don't require authentication
|
||||
self.public_paths = {
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/favicon.ico"
|
||||
}
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Skip authentication for public paths
|
||||
if any(request.url.path.startswith(path) for path in self.public_paths):
|
||||
return await call_next(request)
|
||||
|
||||
# Extract authentication credentials
|
||||
auth_header = request.headers.get("Authorization")
|
||||
api_key_header = request.headers.get("X-API-Key")
|
||||
|
||||
user_info = None
|
||||
|
||||
# Try JWT authentication first
|
||||
if auth_header and auth_header.startswith("Bearer ") and self.jwt_manager:
|
||||
token = auth_header[7:] # Remove "Bearer " prefix
|
||||
try:
|
||||
payload = self.jwt_manager.verify_token(token)
|
||||
user_info = {
|
||||
"user_id": payload.get("sub"),
|
||||
"username": payload.get("username"),
|
||||
"scopes": payload.get("scopes", []),
|
||||
"auth_method": "jwt"
|
||||
}
|
||||
except HTTPException:
|
||||
pass # Try other auth methods
|
||||
|
||||
# Try API key authentication
|
||||
if not user_info and api_key_header:
|
||||
key_info = self.api_key_manager.validate_api_key(api_key_header)
|
||||
if key_info:
|
||||
user_info = {
|
||||
"user_id": f"api_key_{api_key_header[:8]}",
|
||||
"username": key_info["name"],
|
||||
"scopes": key_info["scopes"],
|
||||
"auth_method": "api_key",
|
||||
"rate_limit": key_info["rate_limit"]
|
||||
}
|
||||
|
||||
# If no valid authentication found
|
||||
if not user_info:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"success": False,
|
||||
"error": "Authentication required",
|
||||
"error_details": [{"message": "Valid API key or JWT token required"}],
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"api_version": "v1"
|
||||
}
|
||||
)
|
||||
|
||||
# Add user info to request state
|
||||
request.state.user = user_info
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class RateLimitingMiddleware(BaseHTTPMiddleware):
|
||||
"""Custom rate limiting middleware."""
|
||||
|
||||
def __init__(self, app, default_rate_limit: int = 100):
|
||||
super().__init__(app)
|
||||
self.default_rate_limit = default_rate_limit
|
||||
self.request_counts = {} # In production, use Redis
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Get user identifier
|
||||
user_info = getattr(request.state, 'user', None)
|
||||
if user_info:
|
||||
user_id = user_info["user_id"]
|
||||
rate_limit = user_info.get("rate_limit", self.default_rate_limit)
|
||||
else:
|
||||
user_id = request.client.host if request.client else "anonymous"
|
||||
rate_limit = self.default_rate_limit
|
||||
|
||||
# Current minute window
|
||||
current_minute = int(time.time() // 60)
|
||||
key = f"{user_id}:{current_minute}"
|
||||
|
||||
# Increment request count
|
||||
current_count = self.request_counts.get(key, 0) + 1
|
||||
self.request_counts[key] = current_count
|
||||
|
||||
# Clean up old entries (simple cleanup)
|
||||
if len(self.request_counts) > 10000:
|
||||
old_keys = [k for k in self.request_counts.keys()
|
||||
if int(k.split(':')[1]) < current_minute - 5]
|
||||
for old_key in old_keys:
|
||||
del self.request_counts[old_key]
|
||||
|
||||
# Check rate limit
|
||||
if current_count > rate_limit:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"success": False,
|
||||
"error": "Rate limit exceeded",
|
||||
"error_details": [{"message": f"Rate limit of {rate_limit} requests per minute exceeded"}],
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"retry_after": 60 - (int(time.time()) % 60)
|
||||
},
|
||||
headers={
|
||||
"Retry-After": str(60 - (int(time.time()) % 60)),
|
||||
"X-RateLimit-Limit": str(rate_limit),
|
||||
"X-RateLimit-Remaining": str(max(0, rate_limit - current_count)),
|
||||
"X-RateLimit-Reset": str((current_minute + 1) * 60)
|
||||
}
|
||||
)
|
||||
|
||||
# Add rate limit headers to response
|
||||
response = await call_next(request)
|
||||
response.headers["X-RateLimit-Limit"] = str(rate_limit)
|
||||
response.headers["X-RateLimit-Remaining"] = str(max(0, rate_limit - current_count))
|
||||
response.headers["X-RateLimit-Reset"] = str((current_minute + 1) * 60)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class CompressionMiddleware(BaseHTTPMiddleware):
|
||||
"""Custom compression middleware with configurable settings."""
|
||||
|
||||
def __init__(self, app, minimum_size: int = 1000, compression_level: int = 6):
|
||||
super().__init__(app)
|
||||
self.minimum_size = minimum_size
|
||||
self.compression_level = compression_level
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
|
||||
# Check if client accepts gzip
|
||||
accept_encoding = request.headers.get("accept-encoding", "")
|
||||
if "gzip" not in accept_encoding:
|
||||
return response
|
||||
|
||||
# Check content type and size
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if not any(ct in content_type for ct in ["application/json", "text/", "application/javascript"]):
|
||||
return response
|
||||
|
||||
# Get response body
|
||||
body = b""
|
||||
async for chunk in response.body_iterator:
|
||||
body += chunk
|
||||
|
||||
# Compress if body is large enough
|
||||
if len(body) >= self.minimum_size:
|
||||
import gzip
|
||||
compressed_body = gzip.compress(body, compresslevel=self.compression_level)
|
||||
|
||||
# Create new response with compressed body
|
||||
from starlette.responses import Response
|
||||
return Response(
|
||||
content=compressed_body,
|
||||
status_code=response.status_code,
|
||||
headers={
|
||||
**dict(response.headers),
|
||||
"content-encoding": "gzip",
|
||||
"content-length": str(len(compressed_body))
|
||||
}
|
||||
)
|
||||
|
||||
# Return original response if not compressed
|
||||
from starlette.responses import Response
|
||||
return Response(
|
||||
content=body,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers)
|
||||
)
|
||||
347
hcfs-python/hcfs/api/models.py
Normal file
347
hcfs-python/hcfs/api/models.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
Enhanced API Models for HCFS Production API.
|
||||
|
||||
Comprehensive Pydantic models for request/response validation,
|
||||
API versioning, and enterprise-grade data validation.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field, validator, ConfigDict
|
||||
import uuid
|
||||
|
||||
|
||||
class APIVersion(str, Enum):
|
||||
"""API version enumeration."""
|
||||
V1 = "v1"
|
||||
V2 = "v2"
|
||||
|
||||
|
||||
class SearchType(str, Enum):
|
||||
"""Search type enumeration."""
|
||||
SEMANTIC = "semantic"
|
||||
HYBRID = "hybrid"
|
||||
KEYWORD = "keyword"
|
||||
SIMILARITY = "similarity"
|
||||
|
||||
|
||||
class SortOrder(str, Enum):
|
||||
"""Sort order enumeration."""
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
|
||||
class ContextStatus(str, Enum):
|
||||
"""Context status enumeration."""
|
||||
ACTIVE = "active"
|
||||
ARCHIVED = "archived"
|
||||
DRAFT = "draft"
|
||||
DELETED = "deleted"
|
||||
|
||||
|
||||
# Base Models
|
||||
|
||||
class BaseResponse(BaseModel):
|
||||
"""Base response model with metadata."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
success: bool = True
|
||||
message: Optional[str] = None
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
request_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
api_version: APIVersion = APIVersion.V1
|
||||
|
||||
|
||||
class PaginationParams(BaseModel):
|
||||
"""Pagination parameters."""
|
||||
page: int = Field(default=1, ge=1, description="Page number (1-based)")
|
||||
page_size: int = Field(default=20, ge=1, le=100, description="Items per page")
|
||||
|
||||
@property
|
||||
def offset(self) -> int:
|
||||
"""Calculate offset from page and page_size."""
|
||||
return (self.page - 1) * self.page_size
|
||||
|
||||
|
||||
class PaginationMeta(BaseModel):
|
||||
"""Pagination metadata."""
|
||||
page: int
|
||||
page_size: int
|
||||
total_items: int
|
||||
total_pages: int
|
||||
has_next: bool
|
||||
has_previous: bool
|
||||
|
||||
|
||||
# Context Models
|
||||
|
||||
class ContextBase(BaseModel):
|
||||
"""Base context model with common fields."""
|
||||
path: str = Field(..., description="Hierarchical path for the context")
|
||||
content: str = Field(..., description="Main content of the context")
|
||||
summary: Optional[str] = Field(None, description="Brief summary of the content")
|
||||
author: Optional[str] = Field(None, description="Author or creator of the context")
|
||||
tags: Optional[List[str]] = Field(default_factory=list, description="Tags associated with the context")
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Additional metadata")
|
||||
status: ContextStatus = Field(default=ContextStatus.ACTIVE, description="Context status")
|
||||
|
||||
@validator('path')
|
||||
def validate_path(cls, v):
|
||||
"""Validate path format."""
|
||||
if not v.startswith('/'):
|
||||
raise ValueError('Path must start with /')
|
||||
if '//' in v:
|
||||
raise ValueError('Path cannot contain double slashes')
|
||||
return v
|
||||
|
||||
@validator('content')
|
||||
def validate_content(cls, v):
|
||||
"""Validate content is not empty."""
|
||||
if not v.strip():
|
||||
raise ValueError('Content cannot be empty')
|
||||
return v.strip()
|
||||
|
||||
|
||||
class ContextCreate(ContextBase):
|
||||
"""Model for creating a new context."""
|
||||
pass
|
||||
|
||||
|
||||
class ContextUpdate(BaseModel):
|
||||
"""Model for updating an existing context."""
|
||||
content: Optional[str] = None
|
||||
summary: Optional[str] = None
|
||||
author: Optional[str] = None
|
||||
tags: Optional[List[str]] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
status: Optional[ContextStatus] = None
|
||||
|
||||
@validator('content')
|
||||
def validate_content(cls, v):
|
||||
"""Validate content if provided."""
|
||||
if v is not None and not v.strip():
|
||||
raise ValueError('Content cannot be empty')
|
||||
return v.strip() if v else v
|
||||
|
||||
|
||||
class ContextResponse(ContextBase):
|
||||
"""Model for context responses."""
|
||||
id: int = Field(..., description="Unique context identifier")
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
updated_at: datetime = Field(..., description="Last update timestamp")
|
||||
version: int = Field(..., description="Context version number")
|
||||
embedding_model: Optional[str] = Field(None, description="Embedding model used")
|
||||
similarity_score: Optional[float] = Field(None, description="Similarity score (for search results)")
|
||||
|
||||
|
||||
class ContextListResponse(BaseResponse):
|
||||
"""Response model for context list operations."""
|
||||
data: List[ContextResponse]
|
||||
pagination: PaginationMeta
|
||||
|
||||
|
||||
class ContextDetailResponse(BaseResponse):
|
||||
"""Response model for single context operations."""
|
||||
data: ContextResponse
|
||||
|
||||
|
||||
# Search Models
|
||||
|
||||
class SearchRequest(BaseModel):
|
||||
"""Model for search requests."""
|
||||
query: str = Field(..., description="Search query text")
|
||||
search_type: SearchType = Field(default=SearchType.SEMANTIC, description="Type of search to perform")
|
||||
path_prefix: Optional[str] = Field(None, description="Limit search to paths with this prefix")
|
||||
top_k: int = Field(default=10, ge=1, le=100, description="Maximum number of results to return")
|
||||
min_similarity: float = Field(default=0.0, ge=0.0, le=1.0, description="Minimum similarity threshold")
|
||||
semantic_weight: float = Field(default=0.7, ge=0.0, le=1.0, description="Weight for semantic vs keyword search")
|
||||
include_content: bool = Field(default=True, description="Whether to include full content in results")
|
||||
filters: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Additional search filters")
|
||||
|
||||
@validator('query')
|
||||
def validate_query(cls, v):
|
||||
"""Validate query is not empty."""
|
||||
if not v.strip():
|
||||
raise ValueError('Query cannot be empty')
|
||||
return v.strip()
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
|
||||
"""Individual search result."""
|
||||
context: ContextResponse
|
||||
score: float = Field(..., description="Relevance score")
|
||||
highlight: Optional[Dict[str, List[str]]] = Field(None, description="Highlighted matching text")
|
||||
explanation: Optional[str] = Field(None, description="Explanation of why this result was returned")
|
||||
|
||||
|
||||
class SearchResponse(BaseResponse):
|
||||
"""Response model for search operations."""
|
||||
data: List[SearchResult]
|
||||
query: str
|
||||
search_type: SearchType
|
||||
total_results: int
|
||||
search_time_ms: float
|
||||
filters_applied: Dict[str, Any]
|
||||
|
||||
|
||||
# Version Models
|
||||
|
||||
class VersionResponse(BaseModel):
|
||||
"""Model for context version information."""
|
||||
version_id: int
|
||||
version_number: int
|
||||
context_id: int
|
||||
author: str
|
||||
message: Optional[str]
|
||||
created_at: datetime
|
||||
content_hash: str
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class VersionListResponse(BaseResponse):
|
||||
"""Response model for version history."""
|
||||
data: List[VersionResponse]
|
||||
context_id: int
|
||||
total_versions: int
|
||||
|
||||
|
||||
class VersionCreateRequest(BaseModel):
|
||||
"""Request model for creating a new version."""
|
||||
message: Optional[str] = Field(None, description="Version commit message")
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Version metadata")
|
||||
|
||||
|
||||
class RollbackRequest(BaseModel):
|
||||
"""Request model for version rollback."""
|
||||
target_version: int = Field(..., description="Target version number to rollback to")
|
||||
message: Optional[str] = Field(None, description="Rollback commit message")
|
||||
|
||||
|
||||
# Analytics Models
|
||||
|
||||
class ContextStats(BaseModel):
|
||||
"""Context statistics model."""
|
||||
total_contexts: int
|
||||
contexts_by_status: Dict[ContextStatus, int]
|
||||
contexts_by_author: Dict[str, int]
|
||||
average_content_length: float
|
||||
most_active_paths: List[Dict[str, Union[str, int]]]
|
||||
recent_activity: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class SearchStats(BaseModel):
|
||||
"""Search statistics model."""
|
||||
total_searches: int
|
||||
searches_by_type: Dict[SearchType, int]
|
||||
average_response_time_ms: float
|
||||
popular_queries: List[Dict[str, Union[str, int]]]
|
||||
search_success_rate: float
|
||||
|
||||
|
||||
class SystemStats(BaseModel):
|
||||
"""System statistics model."""
|
||||
uptime_seconds: float
|
||||
memory_usage_mb: float
|
||||
active_connections: int
|
||||
cache_hit_rate: float
|
||||
embedding_model_info: Dict[str, Any]
|
||||
database_size_mb: float
|
||||
|
||||
|
||||
class StatsResponse(BaseResponse):
|
||||
"""Response model for statistics."""
|
||||
context_stats: ContextStats
|
||||
search_stats: SearchStats
|
||||
system_stats: SystemStats
|
||||
|
||||
|
||||
# Batch Operations Models
|
||||
|
||||
class BatchContextCreate(BaseModel):
|
||||
"""Model for batch context creation."""
|
||||
contexts: List[ContextCreate] = Field(..., max_items=100, description="List of contexts to create")
|
||||
|
||||
@validator('contexts')
|
||||
def validate_contexts_not_empty(cls, v):
|
||||
"""Validate contexts list is not empty."""
|
||||
if not v:
|
||||
raise ValueError('Contexts list cannot be empty')
|
||||
return v
|
||||
|
||||
|
||||
class BatchOperationResult(BaseModel):
|
||||
"""Result of batch operation."""
|
||||
success_count: int
|
||||
error_count: int
|
||||
total_items: int
|
||||
errors: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
created_ids: List[int] = Field(default_factory=list)
|
||||
|
||||
|
||||
class BatchResponse(BaseResponse):
|
||||
"""Response model for batch operations."""
|
||||
data: BatchOperationResult
|
||||
|
||||
|
||||
# WebSocket Models
|
||||
|
||||
class WebSocketMessage(BaseModel):
|
||||
"""WebSocket message model."""
|
||||
type: str = Field(..., description="Message type")
|
||||
data: Dict[str, Any] = Field(..., description="Message data")
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
message_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
|
||||
|
||||
class SubscriptionRequest(BaseModel):
|
||||
"""WebSocket subscription request."""
|
||||
path_prefix: str = Field(..., description="Path prefix to subscribe to")
|
||||
event_types: List[str] = Field(default_factory=lambda: ["created", "updated", "deleted"])
|
||||
filters: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
# Health Check Models
|
||||
|
||||
class HealthStatus(str, Enum):
|
||||
"""Health status enumeration."""
|
||||
HEALTHY = "healthy"
|
||||
DEGRADED = "degraded"
|
||||
UNHEALTHY = "unhealthy"
|
||||
|
||||
|
||||
class ComponentHealth(BaseModel):
|
||||
"""Individual component health."""
|
||||
name: str
|
||||
status: HealthStatus
|
||||
response_time_ms: Optional[float] = None
|
||||
error_message: Optional[str] = None
|
||||
last_check: datetime
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
"""System health response."""
|
||||
status: HealthStatus
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
version: str
|
||||
uptime_seconds: float
|
||||
components: List[ComponentHealth]
|
||||
|
||||
|
||||
# Error Models
|
||||
|
||||
class ErrorDetail(BaseModel):
|
||||
"""Detailed error information."""
|
||||
field: Optional[str] = None
|
||||
message: str
|
||||
error_code: Optional[str] = None
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""Error response model."""
|
||||
success: bool = False
|
||||
error: str
|
||||
error_details: Optional[List[ErrorDetail]] = None
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
request_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
api_version: APIVersion = APIVersion.V1
|
||||
172
hcfs-python/hcfs/api/server.py
Normal file
172
hcfs-python/hcfs/api/server.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
HCFS API Server - FastAPI-based REST API for context operations.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from ..core.context_db import ContextDatabase, Context
|
||||
from ..core.embeddings import EmbeddingManager
|
||||
|
||||
|
||||
# Pydantic models
|
||||
class ContextCreateRequest(BaseModel):
|
||||
path: str
|
||||
content: str
|
||||
summary: Optional[str] = None
|
||||
author: Optional[str] = None
|
||||
|
||||
|
||||
class ContextResponse(BaseModel):
|
||||
id: int
|
||||
path: str
|
||||
content: str
|
||||
summary: Optional[str]
|
||||
author: Optional[str]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
version: int
|
||||
|
||||
|
||||
class SearchRequest(BaseModel):
|
||||
query: str
|
||||
path_prefix: Optional[str] = None
|
||||
top_k: int = 5
|
||||
search_type: str = "hybrid" # "semantic", "hybrid"
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
|
||||
context: ContextResponse
|
||||
score: float
|
||||
|
||||
|
||||
class ContextAPI:
|
||||
"""HCFS REST API server."""
|
||||
|
||||
def __init__(self, context_db: ContextDatabase, embedding_manager: EmbeddingManager):
|
||||
self.context_db = context_db
|
||||
self.embedding_manager = embedding_manager
|
||||
self.app = FastAPI(
|
||||
title="HCFS Context API",
|
||||
description="Context-Aware Hierarchical Context File System API",
|
||||
version="0.1.0"
|
||||
)
|
||||
self._setup_routes()
|
||||
|
||||
def _setup_routes(self):
|
||||
"""Setup API routes."""
|
||||
|
||||
@self.app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
return {"status": "healthy", "service": "hcfs-api"}
|
||||
|
||||
@self.app.post("/context", response_model=ContextResponse)
|
||||
async def create_context(request: ContextCreateRequest):
|
||||
"""Create a new context."""
|
||||
context = Context(
|
||||
id=None,
|
||||
path=request.path,
|
||||
content=request.content,
|
||||
summary=request.summary,
|
||||
author=request.author
|
||||
)
|
||||
|
||||
# Store with embedding
|
||||
context_id = self.embedding_manager.store_context_with_embedding(context)
|
||||
|
||||
# Retrieve the stored context
|
||||
stored_contexts = self.context_db.list_contexts_at_path(request.path)
|
||||
stored_context = next((c for c in stored_contexts if c.id == context_id), None)
|
||||
|
||||
if not stored_context:
|
||||
raise HTTPException(status_code=500, detail="Failed to store context")
|
||||
|
||||
return ContextResponse(**stored_context.__dict__)
|
||||
|
||||
@self.app.get("/context/{path:path}", response_model=List[ContextResponse])
|
||||
async def get_context(path: str, depth: int = 1):
|
||||
"""Get contexts for a path with optional parent inheritance."""
|
||||
contexts = self.context_db.get_context_by_path(f"/{path}", depth=depth)
|
||||
return [ContextResponse(**ctx.__dict__) for ctx in contexts]
|
||||
|
||||
@self.app.get("/context", response_model=List[ContextResponse])
|
||||
async def list_contexts(path: str):
|
||||
"""List all contexts at a specific path."""
|
||||
contexts = self.context_db.list_contexts_at_path(path)
|
||||
return [ContextResponse(**ctx.__dict__) for ctx in contexts]
|
||||
|
||||
@self.app.put("/context/{context_id}")
|
||||
async def update_context(context_id: int, content: str, summary: Optional[str] = None):
|
||||
"""Update an existing context."""
|
||||
success = self.context_db.update_context(context_id, content, summary)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Context not found")
|
||||
|
||||
# Update embedding
|
||||
contexts = self.context_db.list_contexts_at_path("") # Get updated context
|
||||
updated_context = next((c for c in contexts if c.id == context_id), None)
|
||||
if updated_context:
|
||||
embedding = self.embedding_manager.generate_embedding(updated_context.content)
|
||||
self.embedding_manager._store_embedding(context_id, embedding)
|
||||
|
||||
return {"message": "Context updated successfully"}
|
||||
|
||||
@self.app.delete("/context/{context_id}")
|
||||
async def delete_context(context_id: int):
|
||||
"""Delete a context."""
|
||||
success = self.context_db.delete_context(context_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Context not found")
|
||||
return {"message": "Context deleted successfully"}
|
||||
|
||||
@self.app.post("/search", response_model=List[SearchResult])
|
||||
async def search_contexts(request: SearchRequest):
|
||||
"""Search contexts using semantic or hybrid search."""
|
||||
if request.search_type == "semantic":
|
||||
results = self.embedding_manager.semantic_search(
|
||||
request.query,
|
||||
request.path_prefix,
|
||||
request.top_k
|
||||
)
|
||||
elif request.search_type == "hybrid":
|
||||
results = self.embedding_manager.hybrid_search(
|
||||
request.query,
|
||||
request.path_prefix,
|
||||
request.top_k
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Invalid search_type")
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
context=ContextResponse(**ctx.__dict__),
|
||||
score=score
|
||||
)
|
||||
for ctx, score in results
|
||||
]
|
||||
|
||||
@self.app.get("/similar/{context_id}", response_model=List[SearchResult])
|
||||
async def get_similar_contexts(context_id: int, top_k: int = 5):
|
||||
"""Find contexts similar to a given context."""
|
||||
results = self.embedding_manager.get_similar_contexts(context_id, top_k)
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
context=ContextResponse(**ctx.__dict__),
|
||||
score=score
|
||||
)
|
||||
for ctx, score in results
|
||||
]
|
||||
|
||||
|
||||
def create_app(db_path: str = "hcfs_context.db") -> FastAPI:
|
||||
"""Create FastAPI application with HCFS components."""
|
||||
context_db = ContextDatabase(db_path)
|
||||
embedding_manager = EmbeddingManager(context_db)
|
||||
api = ContextAPI(context_db, embedding_manager)
|
||||
return api.app
|
||||
692
hcfs-python/hcfs/api/server_v2.py
Normal file
692
hcfs-python/hcfs/api/server_v2.py
Normal file
@@ -0,0 +1,692 @@
|
||||
"""
|
||||
Production-Grade HCFS API Server v2.0
|
||||
|
||||
Enterprise-ready FastAPI server with comprehensive features:
|
||||
- Full CRUD operations with validation
|
||||
- Advanced search capabilities
|
||||
- Version control and rollback
|
||||
- Batch operations
|
||||
- Real-time WebSocket updates
|
||||
- Authentication and authorization
|
||||
- Rate limiting and monitoring
|
||||
- OpenAPI documentation
|
||||
"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Depends, status, Request, Query, BackgroundTasks
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from fastapi.websocket import WebSocket, WebSocketDisconnect
|
||||
import uvicorn
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.util import get_remote_address
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from prometheus_client import Counter, Histogram, Gauge, generate_latest, CONTENT_TYPE_LATEST
|
||||
import structlog
|
||||
|
||||
# HCFS imports
|
||||
from .models import *
|
||||
from ..core.context_db_optimized_fixed import OptimizedContextDatabase
|
||||
from ..core.embeddings_optimized import OptimizedEmbeddingManager
|
||||
from ..core.context_versioning import VersioningSystem
|
||||
from ..core.context_db import Context
|
||||
|
||||
# Logging setup
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Metrics
|
||||
REQUEST_COUNT = Counter('hcfs_requests_total', 'Total HTTP requests', ['method', 'endpoint', 'status'])
|
||||
REQUEST_DURATION = Histogram('hcfs_request_duration_seconds', 'HTTP request duration')
|
||||
ACTIVE_CONNECTIONS = Gauge('hcfs_active_connections', 'Active WebSocket connections')
|
||||
CONTEXT_COUNT = Gauge('hcfs_contexts_total', 'Total number of contexts')
|
||||
SEARCH_COUNT = Counter('hcfs_searches_total', 'Total searches performed', ['search_type'])
|
||||
|
||||
# Rate limiting
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# Security
|
||||
security = HTTPBearer()
|
||||
|
||||
class HCFSAPIServer:
|
||||
"""Production HCFS API Server."""
|
||||
|
||||
def __init__(self,
|
||||
db_path: str = "hcfs_production.db",
|
||||
vector_db_path: str = "hcfs_vectors_production.db",
|
||||
enable_auth: bool = True,
|
||||
cors_origins: List[str] = None):
|
||||
|
||||
self.db_path = db_path
|
||||
self.vector_db_path = vector_db_path
|
||||
self.enable_auth = enable_auth
|
||||
self.cors_origins = cors_origins or ["http://localhost:3000", "http://localhost:8080"]
|
||||
|
||||
# Initialize core components
|
||||
self.context_db = None
|
||||
self.embedding_manager = None
|
||||
self.versioning_system = None
|
||||
|
||||
# WebSocket connections
|
||||
self.websocket_connections: Dict[str, WebSocket] = {}
|
||||
self.subscriptions: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# Create FastAPI app
|
||||
self.app = self._create_app()
|
||||
|
||||
async def startup(self):
|
||||
"""Initialize database connections and components."""
|
||||
logger.info("Starting HCFS API Server...")
|
||||
|
||||
# Initialize core components
|
||||
self.context_db = OptimizedContextDatabase(self.db_path, cache_size=1000)
|
||||
self.embedding_manager = OptimizedEmbeddingManager(
|
||||
self.context_db,
|
||||
model_name="mini",
|
||||
vector_db_path=self.vector_db_path,
|
||||
cache_size=2000,
|
||||
batch_size=32
|
||||
)
|
||||
self.versioning_system = VersioningSystem(self.db_path)
|
||||
|
||||
# Update metrics
|
||||
CONTEXT_COUNT.set(len(self.context_db.get_all_contexts()))
|
||||
|
||||
logger.info("HCFS API Server started successfully")
|
||||
|
||||
async def shutdown(self):
|
||||
"""Cleanup resources."""
|
||||
logger.info("Shutting down HCFS API Server...")
|
||||
|
||||
# Close WebSocket connections
|
||||
for connection in self.websocket_connections.values():
|
||||
await connection.close()
|
||||
|
||||
logger.info("HCFS API Server shutdown complete")
|
||||
|
||||
def _create_app(self) -> FastAPI:
|
||||
"""Create and configure FastAPI application."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
await self.startup()
|
||||
yield
|
||||
await self.shutdown()
|
||||
|
||||
app = FastAPI(
|
||||
title="HCFS API",
|
||||
description="Context-Aware Hierarchical Context File System API",
|
||||
version="2.0.0",
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
openapi_url="/openapi.json",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# Middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=self.cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
|
||||
# Rate limiting
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
|
||||
# Add routes
|
||||
self._add_routes(app)
|
||||
|
||||
# Add middleware for metrics
|
||||
@app.middleware("http")
|
||||
async def metrics_middleware(request: Request, call_next):
|
||||
start_time = time.time()
|
||||
response = await call_next(request)
|
||||
duration = time.time() - start_time
|
||||
|
||||
REQUEST_COUNT.labels(
|
||||
method=request.method,
|
||||
endpoint=request.url.path,
|
||||
status=response.status_code
|
||||
).inc()
|
||||
REQUEST_DURATION.observe(duration)
|
||||
|
||||
return response
|
||||
|
||||
return app
|
||||
|
||||
def _add_routes(self, app: FastAPI):
|
||||
"""Add all API routes."""
|
||||
|
||||
# Authentication dependency
|
||||
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
if self.enable_auth:
|
||||
# TODO: Implement actual authentication
|
||||
# For now, just validate token exists
|
||||
if not credentials.credentials:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return {"username": "api_user", "scopes": ["read", "write"]}
|
||||
return {"username": "anonymous", "scopes": ["read", "write"]}
|
||||
|
||||
# Health check
|
||||
@app.get("/health", response_model=HealthResponse, tags=["System"])
|
||||
async def health_check():
|
||||
"""System health check endpoint."""
|
||||
components = []
|
||||
|
||||
# Check database
|
||||
try:
|
||||
self.context_db.get_all_contexts()
|
||||
db_health = ComponentHealth(name="database", status=HealthStatus.HEALTHY, response_time_ms=1.0)
|
||||
except Exception as e:
|
||||
db_health = ComponentHealth(name="database", status=HealthStatus.UNHEALTHY, error_message=str(e))
|
||||
components.append(db_health)
|
||||
|
||||
# Check embedding manager
|
||||
try:
|
||||
stats = self.embedding_manager.get_statistics()
|
||||
emb_health = ComponentHealth(name="embeddings", status=HealthStatus.HEALTHY, response_time_ms=2.0)
|
||||
except Exception as e:
|
||||
emb_health = ComponentHealth(name="embeddings", status=HealthStatus.UNHEALTHY, error_message=str(e))
|
||||
components.append(emb_health)
|
||||
|
||||
# Overall status
|
||||
overall_status = HealthStatus.HEALTHY
|
||||
if any(c.status == HealthStatus.UNHEALTHY for c in components):
|
||||
overall_status = HealthStatus.UNHEALTHY
|
||||
elif any(c.status == HealthStatus.DEGRADED for c in components):
|
||||
overall_status = HealthStatus.DEGRADED
|
||||
|
||||
return HealthResponse(
|
||||
status=overall_status,
|
||||
version="2.0.0",
|
||||
uptime_seconds=time.time(), # Simplified uptime
|
||||
components=components
|
||||
)
|
||||
|
||||
# Metrics endpoint
|
||||
@app.get("/metrics", tags=["System"])
|
||||
async def metrics():
|
||||
"""Prometheus metrics endpoint."""
|
||||
return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)
|
||||
|
||||
# Context CRUD operations
|
||||
@app.post("/api/v1/contexts", response_model=ContextDetailResponse, tags=["Contexts"])
|
||||
@limiter.limit("100/minute")
|
||||
async def create_context(
|
||||
request: Request,
|
||||
context_data: ContextCreate,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""Create a new context with automatic embedding generation."""
|
||||
try:
|
||||
# Create context object
|
||||
context = Context(
|
||||
id=None,
|
||||
path=context_data.path,
|
||||
content=context_data.content,
|
||||
summary=context_data.summary,
|
||||
author=context_data.author or current_user["username"],
|
||||
version=1
|
||||
)
|
||||
|
||||
# Store context
|
||||
context_id = self.context_db.store_context(context)
|
||||
|
||||
# Generate and store embedding in background
|
||||
background_tasks.add_task(self._generate_embedding_async, context_id, context_data.content)
|
||||
|
||||
# Get created context
|
||||
created_context = self.context_db.get_context(context_id)
|
||||
context_response = self._context_to_response(created_context)
|
||||
|
||||
# Update metrics
|
||||
CONTEXT_COUNT.inc()
|
||||
|
||||
# Notify WebSocket subscribers
|
||||
await self._notify_websocket_subscribers("created", context_response)
|
||||
|
||||
return ContextDetailResponse(data=context_response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error creating context", error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create context: {str(e)}")
|
||||
|
||||
@app.get("/api/v1/contexts/{context_id}", response_model=ContextDetailResponse, tags=["Contexts"])
|
||||
@limiter.limit("200/minute")
|
||||
async def get_context(
|
||||
request: Request,
|
||||
context_id: int,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get a specific context by ID."""
|
||||
try:
|
||||
context = self.context_db.get_context(context_id)
|
||||
if not context:
|
||||
raise HTTPException(status_code=404, detail="Context not found")
|
||||
|
||||
context_response = self._context_to_response(context)
|
||||
return ContextDetailResponse(data=context_response)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error retrieving context", context_id=context_id, error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Failed to retrieve context: {str(e)}")
|
||||
|
||||
@app.get("/api/v1/contexts", response_model=ContextListResponse, tags=["Contexts"])
|
||||
@limiter.limit("100/minute")
|
||||
async def list_contexts(
|
||||
request: Request,
|
||||
pagination: PaginationParams = Depends(),
|
||||
path_prefix: Optional[str] = Query(None, description="Filter by path prefix"),
|
||||
author: Optional[str] = Query(None, description="Filter by author"),
|
||||
status: Optional[ContextStatus] = Query(None, description="Filter by status"),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""List contexts with filtering and pagination."""
|
||||
try:
|
||||
# Get contexts with filters
|
||||
contexts = self.context_db.get_contexts_filtered(
|
||||
path_prefix=path_prefix,
|
||||
author=author,
|
||||
status=status.value if status else None,
|
||||
limit=pagination.page_size,
|
||||
offset=pagination.offset
|
||||
)
|
||||
|
||||
# Get total count for pagination
|
||||
total_count = self.context_db.count_contexts(
|
||||
path_prefix=path_prefix,
|
||||
author=author,
|
||||
status=status.value if status else None
|
||||
)
|
||||
|
||||
# Convert to response models
|
||||
context_responses = [self._context_to_response(ctx) for ctx in contexts]
|
||||
|
||||
# Create pagination metadata
|
||||
pagination_meta = PaginationMeta(
|
||||
page=pagination.page,
|
||||
page_size=pagination.page_size,
|
||||
total_items=total_count,
|
||||
total_pages=(total_count + pagination.page_size - 1) // pagination.page_size,
|
||||
has_next=pagination.page * pagination.page_size < total_count,
|
||||
has_previous=pagination.page > 1
|
||||
)
|
||||
|
||||
return ContextListResponse(data=context_responses, pagination=pagination_meta)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error listing contexts", error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list contexts: {str(e)}")
|
||||
|
||||
@app.put("/api/v1/contexts/{context_id}", response_model=ContextDetailResponse, tags=["Contexts"])
|
||||
@limiter.limit("50/minute")
|
||||
async def update_context(
|
||||
request: Request,
|
||||
context_id: int,
|
||||
context_update: ContextUpdate,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""Update an existing context."""
|
||||
try:
|
||||
# Check if context exists
|
||||
existing_context = self.context_db.get_context(context_id)
|
||||
if not existing_context:
|
||||
raise HTTPException(status_code=404, detail="Context not found")
|
||||
|
||||
# Update context
|
||||
update_data = context_update.dict(exclude_unset=True)
|
||||
if update_data:
|
||||
self.context_db.update_context(context_id, **update_data)
|
||||
|
||||
# If content changed, regenerate embedding
|
||||
if 'content' in update_data:
|
||||
background_tasks.add_task(
|
||||
self._generate_embedding_async,
|
||||
context_id,
|
||||
update_data['content']
|
||||
)
|
||||
|
||||
# Get updated context
|
||||
updated_context = self.context_db.get_context(context_id)
|
||||
context_response = self._context_to_response(updated_context)
|
||||
|
||||
# Notify WebSocket subscribers
|
||||
await self._notify_websocket_subscribers("updated", context_response)
|
||||
|
||||
return ContextDetailResponse(data=context_response)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error updating context", context_id=context_id, error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update context: {str(e)}")
|
||||
|
||||
@app.delete("/api/v1/contexts/{context_id}", tags=["Contexts"])
|
||||
@limiter.limit("30/minute")
|
||||
async def delete_context(
|
||||
request: Request,
|
||||
context_id: int,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""Delete a context."""
|
||||
try:
|
||||
# Check if context exists
|
||||
existing_context = self.context_db.get_context(context_id)
|
||||
if not existing_context:
|
||||
raise HTTPException(status_code=404, detail="Context not found")
|
||||
|
||||
# Delete context
|
||||
success = self.context_db.delete_context(context_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete context")
|
||||
|
||||
# Update metrics
|
||||
CONTEXT_COUNT.dec()
|
||||
|
||||
# Notify WebSocket subscribers
|
||||
await self._notify_websocket_subscribers("deleted", {"id": context_id})
|
||||
|
||||
return {"success": True, "message": "Context deleted successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error deleting context", context_id=context_id, error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete context: {str(e)}")
|
||||
|
||||
# Search endpoints
|
||||
@app.post("/api/v1/search", response_model=SearchResponse, tags=["Search"])
|
||||
@limiter.limit("100/minute")
|
||||
async def search_contexts(
|
||||
request: Request,
|
||||
search_request: SearchRequest,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""Advanced context search with multiple search types."""
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# Perform search based on type
|
||||
if search_request.search_type == SearchType.SEMANTIC:
|
||||
results = self.embedding_manager.semantic_search_optimized(
|
||||
search_request.query,
|
||||
path_prefix=search_request.path_prefix,
|
||||
top_k=search_request.top_k,
|
||||
include_contexts=True
|
||||
)
|
||||
elif search_request.search_type == SearchType.HYBRID:
|
||||
results = self.embedding_manager.hybrid_search_optimized(
|
||||
search_request.query,
|
||||
path_prefix=search_request.path_prefix,
|
||||
top_k=search_request.top_k,
|
||||
semantic_weight=search_request.semantic_weight
|
||||
)
|
||||
else:
|
||||
# Fallback to keyword search
|
||||
contexts = self.context_db.search_contexts(search_request.query)
|
||||
results = [type('Result', (), {'context': ctx, 'score': 1.0})() for ctx in contexts[:search_request.top_k]]
|
||||
|
||||
search_time = (time.time() - start_time) * 1000
|
||||
|
||||
# Convert results to response format
|
||||
search_results = []
|
||||
for result in results:
|
||||
if hasattr(result, 'context') and result.context:
|
||||
context_response = self._context_to_response(result.context)
|
||||
context_response.similarity_score = getattr(result, 'score', None)
|
||||
|
||||
search_results.append(SearchResult(
|
||||
context=context_response,
|
||||
score=result.score,
|
||||
explanation=f"Matched with {result.score:.3f} similarity"
|
||||
))
|
||||
|
||||
# Update metrics
|
||||
SEARCH_COUNT.labels(search_type=search_request.search_type.value).inc()
|
||||
|
||||
return SearchResponse(
|
||||
data=search_results,
|
||||
query=search_request.query,
|
||||
search_type=search_request.search_type,
|
||||
total_results=len(search_results),
|
||||
search_time_ms=search_time,
|
||||
filters_applied=search_request.filters
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error performing search", query=search_request.query, error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
|
||||
|
||||
# Batch operations
|
||||
@app.post("/api/v1/contexts/batch", response_model=BatchResponse, tags=["Batch Operations"])
|
||||
@limiter.limit("10/minute")
|
||||
async def batch_create_contexts(
|
||||
request: Request,
|
||||
batch_request: BatchContextCreate,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""Create multiple contexts in batch."""
|
||||
try:
|
||||
results = BatchOperationResult(
|
||||
success_count=0,
|
||||
error_count=0,
|
||||
total_items=len(batch_request.contexts)
|
||||
)
|
||||
|
||||
for i, context_data in enumerate(batch_request.contexts):
|
||||
try:
|
||||
context = Context(
|
||||
id=None,
|
||||
path=context_data.path,
|
||||
content=context_data.content,
|
||||
summary=context_data.summary,
|
||||
author=context_data.author or current_user["username"],
|
||||
version=1
|
||||
)
|
||||
|
||||
context_id = self.context_db.store_context(context)
|
||||
results.created_ids.append(context_id)
|
||||
results.success_count += 1
|
||||
|
||||
# Generate embedding in background
|
||||
background_tasks.add_task(
|
||||
self._generate_embedding_async,
|
||||
context_id,
|
||||
context_data.content
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
results.error_count += 1
|
||||
results.errors.append({
|
||||
"index": i,
|
||||
"path": context_data.path,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
# Update metrics
|
||||
CONTEXT_COUNT.inc(results.success_count)
|
||||
|
||||
return BatchResponse(data=results)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in batch create", error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Batch operation failed: {str(e)}")
|
||||
|
||||
# Statistics endpoint
|
||||
@app.get("/api/v1/stats", response_model=StatsResponse, tags=["Analytics"])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_statistics(
|
||||
request: Request,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get comprehensive system statistics."""
|
||||
try:
|
||||
# Get embedding manager stats
|
||||
emb_stats = self.embedding_manager.get_statistics()
|
||||
|
||||
# Mock context stats (implement based on your needs)
|
||||
context_stats = ContextStats(
|
||||
total_contexts=emb_stats["database_stats"]["total_embeddings"],
|
||||
contexts_by_status={ContextStatus.ACTIVE: emb_stats["database_stats"]["total_embeddings"]},
|
||||
contexts_by_author={"system": emb_stats["database_stats"]["total_embeddings"]},
|
||||
average_content_length=100.0,
|
||||
most_active_paths=[],
|
||||
recent_activity=[]
|
||||
)
|
||||
|
||||
search_stats = SearchStats(
|
||||
total_searches=100, # Mock data
|
||||
searches_by_type={SearchType.SEMANTIC: 60, SearchType.HYBRID: 40},
|
||||
average_response_time_ms=50.0,
|
||||
popular_queries=[],
|
||||
search_success_rate=0.95
|
||||
)
|
||||
|
||||
system_stats = SystemStats(
|
||||
uptime_seconds=time.time(),
|
||||
memory_usage_mb=100.0,
|
||||
active_connections=len(self.websocket_connections),
|
||||
cache_hit_rate=emb_stats["cache_stats"].get("hit_rate", 0.0),
|
||||
embedding_model_info=emb_stats["current_model"],
|
||||
database_size_mb=10.0
|
||||
)
|
||||
|
||||
return StatsResponse(
|
||||
context_stats=context_stats,
|
||||
search_stats=search_stats,
|
||||
system_stats=system_stats
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting statistics", error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get statistics: {str(e)}")
|
||||
|
||||
# WebSocket endpoint
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
"""WebSocket endpoint for real-time updates."""
|
||||
await self._handle_websocket_connection(websocket)
|
||||
|
||||
def _context_to_response(self, context) -> ContextResponse:
|
||||
"""Convert database context to API response model."""
|
||||
return ContextResponse(
|
||||
id=context.id,
|
||||
path=context.path,
|
||||
content=context.content,
|
||||
summary=context.summary,
|
||||
author=context.author or "unknown",
|
||||
tags=[], # TODO: implement tags
|
||||
metadata={}, # TODO: implement metadata
|
||||
status=ContextStatus.ACTIVE, # TODO: implement status
|
||||
created_at=context.created_at,
|
||||
updated_at=context.updated_at,
|
||||
version=context.version
|
||||
)
|
||||
|
||||
async def _generate_embedding_async(self, context_id: int, content: str):
|
||||
"""Generate and store embedding asynchronously."""
|
||||
try:
|
||||
embedding = self.embedding_manager.generate_embedding(content)
|
||||
self.embedding_manager.store_embedding(context_id, embedding)
|
||||
logger.info("Generated embedding for context", context_id=context_id)
|
||||
except Exception as e:
|
||||
logger.error("Failed to generate embedding", context_id=context_id, error=str(e))
|
||||
|
||||
async def _handle_websocket_connection(self, websocket: WebSocket):
|
||||
"""Handle WebSocket connection and subscriptions."""
|
||||
await websocket.accept()
|
||||
connection_id = str(id(websocket))
|
||||
self.websocket_connections[connection_id] = websocket
|
||||
ACTIVE_CONNECTIONS.inc()
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Wait for subscription requests
|
||||
data = await websocket.receive_json()
|
||||
message = WebSocketMessage(**data)
|
||||
|
||||
if message.type == "subscribe":
|
||||
subscription = SubscriptionRequest(**message.data)
|
||||
self.subscriptions[connection_id] = {
|
||||
"path_prefix": subscription.path_prefix,
|
||||
"event_types": subscription.event_types,
|
||||
"filters": subscription.filters
|
||||
}
|
||||
await websocket.send_json({
|
||||
"type": "subscription_confirmed",
|
||||
"data": {"path_prefix": subscription.path_prefix}
|
||||
})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
# Cleanup
|
||||
self.websocket_connections.pop(connection_id, None)
|
||||
self.subscriptions.pop(connection_id, None)
|
||||
ACTIVE_CONNECTIONS.dec()
|
||||
|
||||
async def _notify_websocket_subscribers(self, event_type: str, data: Any):
|
||||
"""Notify WebSocket subscribers of events."""
|
||||
if not self.websocket_connections:
|
||||
return
|
||||
|
||||
# Create notification message
|
||||
notification = WebSocketMessage(
|
||||
type=event_type,
|
||||
data=data.dict() if hasattr(data, 'dict') else data
|
||||
)
|
||||
|
||||
# Send to all relevant subscribers
|
||||
for connection_id, websocket in list(self.websocket_connections.items()):
|
||||
try:
|
||||
subscription = self.subscriptions.get(connection_id)
|
||||
if subscription and event_type in subscription["event_types"]:
|
||||
# Check path filter
|
||||
if hasattr(data, 'path') and subscription["path_prefix"]:
|
||||
if not data.path.startswith(subscription["path_prefix"]):
|
||||
continue
|
||||
|
||||
await websocket.send_json(notification.dict())
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error sending WebSocket notification",
|
||||
connection_id=connection_id, error=str(e))
|
||||
# Remove failed connection
|
||||
self.websocket_connections.pop(connection_id, None)
|
||||
self.subscriptions.pop(connection_id, None)
|
||||
|
||||
def run(self, host: str = "0.0.0.0", port: int = 8000, **kwargs):
|
||||
"""Run the API server."""
|
||||
uvicorn.run(self.app, host=host, port=port, **kwargs)
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""Factory function to create the app."""
|
||||
server = HCFSAPIServer()
|
||||
return server.app
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
server = HCFSAPIServer()
|
||||
server.run()
|
||||
Reference in New Issue
Block a user