Phase 2 build initial

This commit is contained in:
Claude Code
2025-07-30 09:34:16 +10:00
parent 8f19eaab25
commit a6ee31f237
68 changed files with 18055 additions and 3 deletions

View File

@@ -0,0 +1 @@
"""HCFS API components."""

View File

@@ -0,0 +1,288 @@
"""
Configuration management for HCFS API.
Handles environment-based configuration with validation and defaults.
"""
import os
from typing import List, Optional, Dict, Any
from pydantic import BaseSettings, Field, validator
from pathlib import Path
class DatabaseConfig(BaseSettings):
    """Database configuration settings."""

    # Locations of the main SQLite file and of the separate vector database.
    db_path: str = Field(default="hcfs_production.db", description="Path to SQLite database")
    vector_db_path: str = Field(default="hcfs_vectors_production.db", description="Path to vector database")

    # Connection-pool sizing; the timeout is expressed in seconds.
    pool_size: int = Field(default=10, description="Database connection pool size")
    max_overflow: int = Field(default=20, description="Maximum connection overflow")
    pool_timeout: int = Field(default=30, description="Connection pool timeout in seconds")

    # Performance-related settings.
    cache_size: int = Field(default=1000, description="Database cache size")
    enable_wal_mode: bool = Field(default=True, description="Enable SQLite WAL mode")
    synchronous_mode: str = Field(default="NORMAL", description="SQLite synchronous mode")

    class Config:
        # Prefix of the environment variables that populate the fields above.
        env_prefix = "HCFS_DB_"
class EmbeddingConfig(BaseSettings):
    """Embedding system configuration."""

    # Which embedding model to use and how it is cached / batched.
    model_name: str = Field(default="mini", description="Embedding model to use")
    cache_size: int = Field(default=2000, description="Embedding cache size")
    batch_size: int = Field(default=32, description="Batch processing size")

    # Worker-thread count and per-operation timeout.
    max_workers: int = Field(default=4, description="Maximum worker threads")
    timeout_seconds: int = Field(default=300, description="Operation timeout")

    # Vector store parameters.
    vector_dimension: int = Field(default=384, description="Vector dimension")
    similarity_threshold: float = Field(default=0.0, description="Default similarity threshold")

    class Config:
        # Prefix of the environment variables that populate the fields above.
        env_prefix = "HCFS_EMBEDDING_"
class APIConfig(BaseSettings):
    """API server configuration."""

    # Where the server listens and how many worker processes it runs.
    host: str = Field(default="0.0.0.0", description="Server host")
    port: int = Field(default=8000, description="Server port")
    workers: int = Field(default=1, description="Number of worker processes")

    # JWT signing parameters.
    # NOTE(review): the default secret is a placeholder; deployments are
    # expected to override it through HCFS_API_SECRET_KEY.
    secret_key: str = Field(default="dev-secret-key-change-in-production", description="JWT secret key")
    algorithm: str = Field(default="HS256", description="JWT algorithm")
    token_expire_minutes: int = Field(default=30, description="JWT token expiration time")

    # Cross-origin request policy.
    cors_origins: List[str] = Field(
        default=["http://localhost:3000", "http://localhost:8080"],
        description="Allowed CORS origins"
    )
    cors_credentials: bool = Field(default=True, description="Allow credentials in CORS")

    # Request-rate limits.
    rate_limit_requests: int = Field(default=100, description="Requests per minute")
    rate_limit_burst: int = Field(default=20, description="Burst requests allowed")

    # Switches for optional server features.
    enable_auth: bool = Field(default=True, description="Enable authentication")
    enable_websocket: bool = Field(default=True, description="Enable WebSocket support")
    enable_metrics: bool = Field(default=True, description="Enable Prometheus metrics")
    enable_docs: bool = Field(default=True, description="Enable API documentation")

    class Config:
        # Prefix of the environment variables that populate the fields above.
        env_prefix = "HCFS_API_"
class MonitoringConfig(BaseSettings):
    """Monitoring and observability configuration."""

    # Logging: level, output format (json/text) and an optional file target.
    log_level: str = Field(default="INFO", description="Logging level")
    log_format: str = Field(default="json", description="Log format (json/text)")
    log_file: Optional[str] = Field(default=None, description="Log file path")

    # Metrics collection.
    metrics_enabled: bool = Field(default=True, description="Enable metrics collection")
    metrics_port: int = Field(default=9090, description="Metrics server port")

    # Health checks; the interval is in seconds.
    health_check_interval: int = Field(default=30, description="Health check interval in seconds")
    health_check_timeout: int = Field(default=5, description="Health check timeout")

    # Distributed tracing, off by default.
    tracing_enabled: bool = Field(default=False, description="Enable distributed tracing")
    tracing_sample_rate: float = Field(default=0.1, description="Tracing sample rate")
    jaeger_endpoint: Optional[str] = Field(default=None, description="Jaeger endpoint")

    class Config:
        # Prefix of the environment variables that populate the fields above.
        env_prefix = "HCFS_MONITORING_"
class RedisConfig(BaseSettings):
    """Redis configuration for caching and rate limiting."""

    # How to reach the Redis server.
    host: str = Field(default="localhost", description="Redis host")
    port: int = Field(default=6379, description="Redis port")
    db: int = Field(default=0, description="Redis database number")
    password: Optional[str] = Field(default=None, description="Redis password")

    # Connection-pool limits; the socket timeout is in seconds.
    max_connections: int = Field(default=20, description="Maximum Redis connections")
    socket_timeout: int = Field(default=5, description="Socket timeout in seconds")

    # Cache behaviour: default TTL (seconds) and the key namespace prefix.
    default_ttl: int = Field(default=3600, description="Default cache TTL in seconds")
    key_prefix: str = Field(default="hcfs:", description="Redis key prefix")

    class Config:
        # Prefix of the environment variables that populate the fields above.
        env_prefix = "HCFS_REDIS_"
class SecurityConfig(BaseSettings):
    """Security configuration."""

    # Authentication requirements and the header that carries the API key.
    require_auth: bool = Field(default=True, description="Require authentication")
    api_key_header: str = Field(default="X-API-Key", description="API key header name")

    # Rate limiting switch and counter backend (memory/redis).
    rate_limit_enabled: bool = Field(default=True, description="Enable rate limiting")
    rate_limit_storage: str = Field(default="memory", description="Rate limit storage (memory/redis)")

    # Transport security.
    force_https: bool = Field(default=False, description="Force HTTPS in production")
    hsts_max_age: int = Field(default=31536000, description="HSTS max age")

    # Request size limits; the size default is 10 MiB expressed in bytes.
    max_request_size: int = Field(default=10 * 1024 * 1024, description="Maximum request size in bytes")
    max_query_params: int = Field(default=100, description="Maximum query parameters")

    # Content types accepted from clients.
    allowed_content_types: List[str] = Field(
        default=["application/json", "application/x-www-form-urlencoded", "multipart/form-data"],
        description="Allowed content types"
    )

    class Config:
        # Prefix of the environment variables that populate the fields above.
        env_prefix = "HCFS_SECURITY_"
class HCFSConfig(BaseSettings):
    """Main HCFS configuration combining all subsystem configs."""

    # Environment
    environment: str = Field(default="development", description="Environment (development/staging/production)")
    debug: bool = Field(default=False, description="Enable debug mode")

    # Application info
    app_name: str = Field(default="HCFS API", description="Application name")
    app_version: str = Field(default="2.0.0", description="Application version")
    app_description: str = Field(default="Context-Aware Hierarchical Context File System API", description="App description")

    # Configuration file path
    config_file: Optional[str] = Field(default=None, description="Path to configuration file")

    # Subsystem configurations; each is built by its own settings class.
    database: DatabaseConfig = Field(default_factory=DatabaseConfig)
    embedding: EmbeddingConfig = Field(default_factory=EmbeddingConfig)
    api: APIConfig = Field(default_factory=APIConfig)
    monitoring: MonitoringConfig = Field(default_factory=MonitoringConfig)
    redis: RedisConfig = Field(default_factory=RedisConfig)
    security: SecurityConfig = Field(default_factory=SecurityConfig)

    class Config:
        env_prefix = "HCFS_"
        env_file = ".env"
        env_file_encoding = "utf-8"

    @validator('environment')
    def validate_environment(cls, v):
        """Validate environment value.

        Raises:
            ValueError: if ``v`` is not development, staging or production.
        """
        allowed = ['development', 'staging', 'production']
        if v not in allowed:
            raise ValueError(f'Environment must be one of: {allowed}')
        return v

    @validator('debug')
    def validate_debug_in_production(cls, v, values):
        """Ensure debug is disabled in production.

        Raises:
            ValueError: if debug is enabled while environment is production.
        """
        # `environment` is declared before `debug`, so it is looked up in
        # `values`; .get() tolerates it being absent.
        if values.get('environment') == 'production' and v:
            raise ValueError('Debug mode cannot be enabled in production')
        return v

    def is_production(self) -> bool:
        """Check if running in production environment."""
        return self.environment == 'production'

    def is_development(self) -> bool:
        """Check if running in development environment."""
        return self.environment == 'development'

    def get_database_url(self) -> str:
        """Get database URL (``sqlite:///`` followed by the configured path)."""
        return f"sqlite:///{self.database.db_path}"

    def get_redis_url(self) -> str:
        """Get Redis URL.

        The password, when set, is percent-encoded so that characters such as
        ``@``, ``:`` or ``/`` cannot corrupt the URL structure.
        """
        from urllib.parse import quote

        location = f"{self.redis.host}:{self.redis.port}/{self.redis.db}"
        if self.redis.password:
            return f"redis://:{quote(self.redis.password, safe='')}@{location}"
        return f"redis://{location}"

    def load_from_file(self, config_path: str) -> None:
        """Load configuration from YAML file.

        Top-level keys that match an attribute of this object are applied;
        other keys are ignored. A mapping given for a subsystem (for example
        ``database:``) is merged into the existing sub-config and re-validated
        instead of replacing the sub-config object with a plain dict.

        Args:
            config_path: Path of the YAML file to read.

        Raises:
            FileNotFoundError: if ``config_path`` does not exist.
            ValueError: if the file's top level is not a mapping, or if a
                subsystem mapping fails validation by its settings class.
        """
        import yaml

        config_file = Path(config_path)
        if not config_file.exists():
            raise FileNotFoundError(f"Configuration file not found: {config_path}")

        with open(config_file, 'r', encoding='utf-8') as f:
            config_data = yaml.safe_load(f)

        # An empty YAML document parses to None: nothing to apply.
        if config_data is None:
            return
        if not isinstance(config_data, dict):
            raise ValueError(f"Configuration file must contain a mapping: {config_path}")

        # Update configuration
        for key, value in config_data.items():
            if not hasattr(self, key):
                continue
            current = getattr(self, key)
            if isinstance(current, BaseSettings) and isinstance(value, dict):
                # Keep the typed sub-config: overlay the file values on the
                # current ones and rebuild it through its own class.
                value = type(current)(**{**current.dict(), **value})
            setattr(self, key, value)

    def to_dict(self) -> Dict[str, Any]:
        """Convert configuration to dictionary."""
        return self.dict()

    def save_to_file(self, config_path: str) -> None:
        """Save configuration to YAML file.

        Args:
            config_path: Destination path; an existing file is overwritten.
        """
        import yaml

        config_data = self.to_dict()
        with open(config_path, 'w', encoding='utf-8') as f:
            yaml.dump(config_data, f, default_flow_style=False, indent=2)
# Global configuration instance, created when this module is imported.
# get_config() returns it and load_config() updates it in place.
config = HCFSConfig()
def get_config() -> HCFSConfig:
    """Return the module-level ``config`` object.

    Returns:
        The shared HCFSConfig instance; no copy is made.
    """
    return config
def load_config(config_path: Optional[str] = None, **overrides) -> HCFSConfig:
    """Update the global configuration from a file and/or keyword overrides.

    Args:
        config_path: Optional YAML file passed to ``config.load_from_file``.
        **overrides: Attribute values applied after the file; names that the
            configuration object does not have are skipped.

    Returns:
        The global configuration instance (mutated in place).
    """
    global config

    # File values go first so explicit overrides win.
    if config_path:
        config.load_from_file(config_path)

    for name, new_value in overrides.items():
        if not hasattr(config, name):
            continue
        setattr(config, name, new_value)

    return config
def create_config_template(output_path: str = "hcfs_config.yaml") -> None:
    """Write a freshly constructed HCFSConfig to ``output_path`` as YAML.

    Args:
        output_path: File to create; a confirmation line is printed afterwards.
    """
    HCFSConfig().save_to_file(output_path)
    print(f"Configuration template created: {output_path}")
if __name__ == "__main__":
    # Running the module as a script writes a configuration template file.
    create_config_template()

View File

@@ -0,0 +1,365 @@
"""
Custom middleware for HCFS API.
Provides authentication, logging, error handling, and security features.
"""
import time
import uuid
import json
from typing import Optional
from datetime import datetime, timedelta
from fastapi import Request, Response, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
import jwt
import structlog
# Module-level structlog logger used by the middleware classes below.
logger = structlog.get_logger()
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log the start and completion of each request and tag it with an ID."""

    def __init__(self, app, log_body: bool = False):
        super().__init__(app)
        # Stored for callers; nothing in this class reads it.
        self.log_body = log_body

    async def dispatch(self, request: Request, call_next):
        """Assign a request ID, log both ends of the request, and time it.

        The ID is stored on ``request.state.request_id`` and echoed back in
        the ``X-Request-ID`` response header.
        """
        request_id = str(uuid.uuid4())
        request.state.request_id = request_id

        started_at = time.time()

        peer = request.client
        logger.info(
            "Request started",
            request_id=request_id,
            method=request.method,
            url=str(request.url),
            client_ip=peer.host if peer else None,
            user_agent=request.headers.get("user-agent"),
        )

        # Hand over to the next middleware/endpoint.
        response = await call_next(request)

        elapsed_ms = round((time.time() - started_at) * 1000, 2)
        logger.info(
            "Request completed",
            request_id=request_id,
            status_code=response.status_code,
            duration_ms=elapsed_ms,
        )

        response.headers["X-Request-ID"] = request_id
        return response
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Turn unexpected exceptions into a uniform JSON 500 response."""

    async def dispatch(self, request: Request, call_next):
        """Run the downstream handler; log and mask any non-HTTP exception."""
        try:
            return await call_next(request)
        except HTTPException:
            # HTTPExceptions are left for FastAPI's own handling.
            raise
        except Exception as exc:
            # The ID is only present if an earlier middleware set it.
            request_id = getattr(request.state, 'request_id', 'unknown')
            logger.error(
                "Unhandled exception",
                request_id=request_id,
                error=str(exc),
                error_type=type(exc).__name__,
                method=request.method,
                url=str(request.url),
                exc_info=True
            )

            # The client sees a generic message, never the exception text.
            payload = {
                "success": False,
                "error": "Internal server error",
                "error_details": [{"message": "An unexpected error occurred"}],
                "timestamp": datetime.utcnow().isoformat(),
                "request_id": request_id,
                "api_version": "v1"
            }
            return JSONResponse(status_code=500, content=payload)
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach a fixed set of security headers to every response."""

    async def dispatch(self, request: Request, call_next):
        """Call the downstream handler, then set the headers on its response."""
        response = await call_next(request)

        security_headers = {
            "X-Content-Type-Options": "nosniff",
            "X-Frame-Options": "DENY",
            "X-XSS-Protection": "1; mode=block",
            "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
            "Referrer-Policy": "strict-origin-when-cross-origin",
            "Permissions-Policy": "geolocation=(), microphone=(), camera=()",
        }
        # Assignment overwrites any value the endpoint set for these names.
        for header_name, header_value in security_headers.items():
            response.headers[header_name] = header_value

        return response
class JWTAuthenticationManager:
    """JWT-based authentication manager.

    Creates and verifies signed tokens with the ``jwt`` module imported by
    this file, using one shared secret and algorithm.
    """

    def __init__(self, secret_key: str, algorithm: str = "HS256", token_expire_minutes: int = 30):
        self.secret_key = secret_key
        self.algorithm = algorithm
        self.token_expire_minutes = token_expire_minutes

    def create_access_token(self, data: dict) -> str:
        """Create JWT access token.

        Args:
            data: Claims to embed; copied, so the caller's dict is not mutated.

        Returns:
            The encoded token with ``exp`` and ``iat`` claims added.
        """
        # One clock read so that exp is exactly iat + lifetime.
        issued_at = datetime.utcnow()
        to_encode = data.copy()
        to_encode.update({
            "exp": issued_at + timedelta(minutes=self.token_expire_minutes),
            "iat": issued_at,
        })
        return jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)

    def verify_token(self, token: str) -> Optional[dict]:
        """Verify and decode JWT token.

        Returns:
            The decoded payload.

        Raises:
            HTTPException: 401 when the token has expired or cannot be
                validated.
        """
        try:
            return jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except jwt.ExpiredSignatureError:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token has expired",
                headers={"WWW-Authenticate": "Bearer"},
            )
        except jwt.InvalidTokenError:
            # PyJWT's base class for decode failures. The previous
            # `jwt.JWTError` is not an attribute of this module, so a bad
            # token raised AttributeError instead of producing a 401.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Could not validate credentials",
                headers={"WWW-Authenticate": "Bearer"},
            )
class APIKeyManager:
    """Look up API keys in an in-memory table and track their last use."""

    def __init__(self):
        # In production, store these in a database.
        # NOTE(review): this built-in development key is accepted by every
        # instance; confirm it is removed or replaced before deployment.
        dev_key_record = {
            "name": "Development Key",
            "scopes": ["read", "write"],
            "rate_limit": 1000,
            "created_at": datetime.utcnow(),
            "last_used": None
        }
        self.api_keys = {"dev-key-123": dev_key_record}

    def validate_api_key(self, api_key: str) -> Optional[dict]:
        """Return the stored record for ``api_key``, or None if it is unknown.

        A successful lookup stamps the record's ``last_used`` with the current
        UTC time; the returned dict is the stored object, not a copy.
        """
        record = self.api_keys.get(api_key)
        if not record:
            return None
        record["last_used"] = datetime.utcnow()
        return record
class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Authentication middleware supporting multiple auth methods.

    A request is accepted when it carries a valid ``Authorization: Bearer``
    JWT (only if a ``jwt_manager`` was supplied) or a valid ``X-API-Key``
    header. On success the caller's details are stored on
    ``request.state.user``; otherwise a JSON 401 response is returned.
    """

    def __init__(self, app, jwt_manager: JWTAuthenticationManager = None, api_key_manager: APIKeyManager = None):
        super().__init__(app)
        self.jwt_manager = jwt_manager
        self.api_key_manager = api_key_manager or APIKeyManager()

        # Paths that don't require authentication
        self.public_paths = {
            "/health",
            "/metrics",
            "/docs",
            "/redoc",
            "/openapi.json",
            "/favicon.ico"
        }

    def _is_public_path(self, path: str) -> bool:
        """Return True if ``path`` is a public path or lies beneath one.

        Matching is on whole path segments: ``/docs`` and ``/docs/x`` are
        public, ``/docs-internal`` is not. A bare prefix match would let any
        route whose name merely starts with a public path skip authentication.
        """
        return any(
            path == public or path.startswith(public + "/")
            for public in self.public_paths
        )

    async def dispatch(self, request: Request, call_next):
        """Authenticate the request or short-circuit with a 401 response."""
        # Skip authentication for public paths
        if self._is_public_path(request.url.path):
            return await call_next(request)

        # Extract authentication credentials
        auth_header = request.headers.get("Authorization")
        api_key_header = request.headers.get("X-API-Key")

        user_info = None

        # Try JWT authentication first
        if auth_header and auth_header.startswith("Bearer ") and self.jwt_manager:
            token = auth_header[7:]  # Remove "Bearer " prefix
            try:
                payload = self.jwt_manager.verify_token(token)
                user_info = {
                    "user_id": payload.get("sub"),
                    "username": payload.get("username"),
                    "scopes": payload.get("scopes", []),
                    "auth_method": "jwt"
                }
            except HTTPException:
                # Deliberate: a rejected token falls through to the API-key
                # check rather than failing the request outright.
                pass

        # Try API key authentication
        if not user_info and api_key_header:
            key_info = self.api_key_manager.validate_api_key(api_key_header)
            if key_info:
                user_info = {
                    # NOTE(review): exposes the first 8 characters of the key
                    # as the user id; keys sharing that prefix share an id.
                    "user_id": f"api_key_{api_key_header[:8]}",
                    "username": key_info["name"],
                    "scopes": key_info["scopes"],
                    "auth_method": "api_key",
                    "rate_limit": key_info["rate_limit"]
                }

        # If no valid authentication found
        if not user_info:
            return JSONResponse(
                status_code=401,
                content={
                    "success": False,
                    "error": "Authentication required",
                    "error_details": [{"message": "Valid API key or JWT token required"}],
                    "timestamp": datetime.utcnow().isoformat(),
                    "api_version": "v1"
                }
            )

        # Add user info to request state
        request.state.user = user_info

        return await call_next(request)
class RateLimitingMiddleware(BaseHTTPMiddleware):
    """Custom rate limiting middleware.

    Counts requests per caller in one-minute windows, in process memory, and
    answers 429 once a caller exceeds its limit for the current window.
    """

    def __init__(self, app, default_rate_limit: int = 100):
        super().__init__(app)
        self.default_rate_limit = default_rate_limit
        # Maps "<user_id>:<minute>" to a request count.
        self.request_counts = {}  # In production, use Redis

    async def dispatch(self, request: Request, call_next):
        """Count the request, then reject it or forward it with limit headers."""
        # Identify the caller: the authenticated user if one was attached to
        # the request state, otherwise the client address.
        user_info = getattr(request.state, 'user', None)
        if user_info:
            user_id = user_info["user_id"]
            rate_limit = user_info.get("rate_limit", self.default_rate_limit)
        else:
            user_id = request.client.host if request.client else "anonymous"
            rate_limit = self.default_rate_limit

        # Read the clock once so every derived value agrees, even when the
        # request straddles a minute boundary.
        now = time.time()
        current_minute = int(now // 60)
        key = f"{user_id}:{current_minute}"

        # Increment request count
        current_count = self.request_counts.get(key, 0) + 1
        self.request_counts[key] = current_count

        # Clean up old entries (simple cleanup)
        if len(self.request_counts) > 10000:
            cutoff = current_minute - 5
            # rsplit: the minute is always the last ':'-separated field, while
            # the user id may itself contain ':' (e.g. an IPv6 client host).
            stale_keys = [
                k for k in self.request_counts
                if int(k.rsplit(':', 1)[1]) < cutoff
            ]
            for stale_key in stale_keys:
                del self.request_counts[stale_key]

        remaining = str(max(0, rate_limit - current_count))
        reset_at = str((current_minute + 1) * 60)

        # Check rate limit
        if current_count > rate_limit:
            retry_after = 60 - (int(now) % 60)
            return JSONResponse(
                status_code=429,
                content={
                    "success": False,
                    "error": "Rate limit exceeded",
                    "error_details": [{"message": f"Rate limit of {rate_limit} requests per minute exceeded"}],
                    "timestamp": datetime.utcnow().isoformat(),
                    "retry_after": retry_after
                },
                headers={
                    "Retry-After": str(retry_after),
                    "X-RateLimit-Limit": str(rate_limit),
                    "X-RateLimit-Remaining": remaining,
                    "X-RateLimit-Reset": reset_at
                }
            )

        # Add rate limit headers to response
        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(rate_limit)
        response.headers["X-RateLimit-Remaining"] = remaining
        response.headers["X-RateLimit-Reset"] = reset_at

        return response
class CompressionMiddleware(BaseHTTPMiddleware):
    """Custom compression middleware with configurable settings.

    Gzips JSON, text and JavaScript responses of at least ``minimum_size``
    bytes when the client advertises gzip support.
    """

    def __init__(self, app, minimum_size: int = 1000, compression_level: int = 6):
        super().__init__(app)
        self.minimum_size = minimum_size
        self.compression_level = compression_level

    async def dispatch(self, request: Request, call_next):
        """Return the downstream response, gzip-compressed when eligible."""
        import gzip
        from starlette.responses import Response

        response = await call_next(request)

        # Check if client accepts gzip
        accept_encoding = request.headers.get("accept-encoding", "")
        if "gzip" not in accept_encoding:
            return response

        # A body that already carries an encoding must not be encoded again.
        if "content-encoding" in response.headers:
            return response

        # Check content type
        content_type = response.headers.get("content-type", "")
        if not any(ct in content_type for ct in ["application/json", "text/", "application/javascript"]):
            return response

        # Drain the body; joining once avoids quadratic bytes concatenation.
        body = b"".join([chunk async for chunk in response.body_iterator])

        # NOTE(review): dict() keeps one value per header name, so repeated
        # headers (e.g. several Set-Cookie) collapse to a single entry.
        headers = dict(response.headers)

        # Compress if body is large enough
        if len(body) >= self.minimum_size:
            body = gzip.compress(body, compresslevel=self.compression_level)
            headers["content-encoding"] = "gzip"
            headers["content-length"] = str(len(body))
            # Tell caches that the representation depends on Accept-Encoding.
            vary = headers.get("vary", "")
            if "accept-encoding" not in vary.lower():
                headers["vary"] = f"{vary}, Accept-Encoding" if vary else "Accept-Encoding"

        # The iterator is consumed, so a new response is built either way.
        return Response(
            content=body,
            status_code=response.status_code,
            headers=headers
        )

View File

@@ -0,0 +1,347 @@
"""
Enhanced API Models for HCFS Production API.
Comprehensive Pydantic models for request/response validation,
API versioning, and enterprise-grade data validation.
"""
from typing import List, Optional, Dict, Any, Union
from datetime import datetime
from enum import Enum
from pydantic import BaseModel, Field, validator, ConfigDict
import uuid
class APIVersion(str, Enum):
    """API version enumeration."""

    # Members are plain strings (str mixin), so they compare equal to and
    # serialize as their literal values.
    V1 = "v1"
    V2 = "v2"
class SearchType(str, Enum):
    """Search type enumeration."""

    # String-valued members; each value is the wire representation.
    SEMANTIC = "semantic"
    HYBRID = "hybrid"
    KEYWORD = "keyword"
    SIMILARITY = "similarity"
class SortOrder(str, Enum):
    """Sort order enumeration."""

    # String-valued members: ascending and descending.
    ASC = "asc"
    DESC = "desc"
class ContextStatus(str, Enum):
    """Context status enumeration."""

    # String-valued members; each value is the wire representation.
    ACTIVE = "active"
    ARCHIVED = "archived"
    DRAFT = "draft"
    DELETED = "deleted"
# Base Models
class BaseResponse(BaseModel):
    """Base response model with metadata."""

    # from_attributes lets the model be populated from attribute-bearing objects.
    model_config = ConfigDict(from_attributes=True)

    success: bool = True
    message: Optional[str] = None
    # Both defaults are generated per instance: creation time (naive UTC)
    # and a random UUID4 string.
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    request_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    api_version: APIVersion = APIVersion.V1
class PaginationParams(BaseModel):
    """Pagination parameters."""

    # page is 1-based; page_size is constrained to the range 1..100.
    page: int = Field(default=1, ge=1, description="Page number (1-based)")
    page_size: int = Field(default=20, ge=1, le=100, description="Items per page")

    @property
    def offset(self) -> int:
        """Calculate offset from page and page_size."""
        pages_to_skip = self.page - 1
        return pages_to_skip * self.page_size
class PaginationMeta(BaseModel):
    """Pagination metadata."""

    # All fields are required; the model stores values and computes nothing.
    page: int
    page_size: int
    total_items: int
    total_pages: int
    has_next: bool
    has_previous: bool
# Context Models
class ContextBase(BaseModel):
    """Base context model with common fields."""

    # Required fields.
    path: str = Field(..., description="Hierarchical path for the context")
    content: str = Field(..., description="Main content of the context")

    # Optional descriptive fields; list/dict defaults are created per instance.
    summary: Optional[str] = Field(None, description="Brief summary of the content")
    author: Optional[str] = Field(None, description="Author or creator of the context")
    tags: Optional[List[str]] = Field(default_factory=list, description="Tags associated with the context")
    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Additional metadata")
    status: ContextStatus = Field(default=ContextStatus.ACTIVE, description="Context status")

    @validator('path')
    def validate_path(cls, v):
        """Validate path format."""
        # Paths are absolute and must not contain empty segments.
        starts_at_root = v.startswith('/')
        if not starts_at_root:
            raise ValueError('Path must start with /')
        has_empty_segment = '//' in v
        if has_empty_segment:
            raise ValueError('Path cannot contain double slashes')
        return v

    @validator('content')
    def validate_content(cls, v):
        """Validate content is not empty."""
        # The stored value is the whitespace-trimmed content.
        trimmed = v.strip()
        if not trimmed:
            raise ValueError('Content cannot be empty')
        return trimmed
class ContextCreate(ContextBase):
    """Model for creating a new context."""

    # Adds nothing to ContextBase; exists as a distinct type for create requests.
    pass
class ContextUpdate(BaseModel):
    """Model for updating an existing context."""

    # Every field defaults to None, so any subset may be supplied.
    content: Optional[str] = None
    summary: Optional[str] = None
    author: Optional[str] = None
    tags: Optional[List[str]] = None
    metadata: Optional[Dict[str, Any]] = None
    status: Optional[ContextStatus] = None

    @validator('content')
    def validate_content(cls, v):
        """Validate content if provided."""
        # None passes through untouched.
        if v is None:
            return v
        # Blank content is rejected; otherwise the trimmed text is stored.
        trimmed = v.strip()
        if not trimmed:
            raise ValueError('Content cannot be empty')
        return trimmed
class ContextResponse(ContextBase):
    """Model for context responses."""

    # Required identity and bookkeeping fields added on top of ContextBase.
    id: int = Field(..., description="Unique context identifier")
    created_at: datetime = Field(..., description="Creation timestamp")
    updated_at: datetime = Field(..., description="Last update timestamp")
    version: int = Field(..., description="Context version number")

    # Optional enrichment; both default to None.
    embedding_model: Optional[str] = Field(None, description="Embedding model used")
    similarity_score: Optional[float] = Field(None, description="Similarity score (for search results)")
class ContextListResponse(BaseResponse):
    """Response model for context list operations."""

    # The listed contexts together with their pagination metadata.
    data: List[ContextResponse]
    pagination: PaginationMeta
class ContextDetailResponse(BaseResponse):
    """Response model for single context operations."""

    # The single context being returned.
    data: ContextResponse
# Search Models
class SearchRequest(BaseModel):
    """Model for search requests."""

    # What to search for and how.
    query: str = Field(..., description="Search query text")
    search_type: SearchType = Field(default=SearchType.SEMANTIC, description="Type of search to perform")
    path_prefix: Optional[str] = Field(None, description="Limit search to paths with this prefix")

    # Result shaping: top_k is limited to 1..100, both weights to 0.0..1.0.
    top_k: int = Field(default=10, ge=1, le=100, description="Maximum number of results to return")
    min_similarity: float = Field(default=0.0, ge=0.0, le=1.0, description="Minimum similarity threshold")
    semantic_weight: float = Field(default=0.7, ge=0.0, le=1.0, description="Weight for semantic vs keyword search")
    include_content: bool = Field(default=True, description="Whether to include full content in results")
    filters: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Additional search filters")

    @validator('query')
    def validate_query(cls, v):
        """Validate query is not empty."""
        # The stored value is the whitespace-trimmed query.
        trimmed = v.strip()
        if not trimmed:
            raise ValueError('Query cannot be empty')
        return trimmed
class SearchResult(BaseModel):
    """Individual search result."""

    # The matched context and its required relevance score.
    context: ContextResponse
    score: float = Field(..., description="Relevance score")

    # Optional presentation aids; both default to None.
    highlight: Optional[Dict[str, List[str]]] = Field(None, description="Highlighted matching text")
    explanation: Optional[str] = Field(None, description="Explanation of why this result was returned")
class SearchResponse(BaseResponse):
    """Response model for search operations."""

    # The hits.
    data: List[SearchResult]

    # Echo of the request plus execution statistics; all required.
    query: str
    search_type: SearchType
    total_results: int
    search_time_ms: float
    filters_applied: Dict[str, Any]
# Version Models
class VersionResponse(BaseModel):
    """Model for context version information."""

    # Identity of the version and of the context it belongs to.
    version_id: int
    version_number: int
    context_id: int

    # Who created it, with what message, and when.
    author: str
    message: Optional[str]
    created_at: datetime

    # Hash of the content plus optional free-form metadata.
    content_hash: str
    metadata: Optional[Dict[str, Any]] = None
class VersionListResponse(BaseResponse):
    """Response model for version history."""

    # The versions, the context they belong to, and a count supplied by the caller.
    data: List[VersionResponse]
    context_id: int
    total_versions: int
class VersionCreateRequest(BaseModel):
    """Request model for creating a new version."""

    # Both fields are optional; metadata defaults to a fresh empty dict.
    message: Optional[str] = Field(None, description="Version commit message")
    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Version metadata")
class RollbackRequest(BaseModel):
    """Request model for version rollback."""

    # target_version is required; the commit message is optional.
    target_version: int = Field(..., description="Target version number to rollback to")
    message: Optional[str] = Field(None, description="Rollback commit message")
# Analytics Models
class ContextStats(BaseModel):
    """Context statistics model."""

    # Counts, overall and broken down by status and by author.
    total_contexts: int
    contexts_by_status: Dict[ContextStatus, int]
    contexts_by_author: Dict[str, int]

    average_content_length: float

    # Lists of free-form records; their keys are not fixed by this model.
    most_active_paths: List[Dict[str, Union[str, int]]]
    recent_activity: List[Dict[str, Any]]
class SearchStats(BaseModel):
    """Search statistics model."""

    # Search volume, overall and per search type.
    total_searches: int
    searches_by_type: Dict[SearchType, int]

    average_response_time_ms: float

    # Free-form records; their keys are not fixed by this model.
    popular_queries: List[Dict[str, Union[str, int]]]
    search_success_rate: float
class SystemStats(BaseModel):
    """System statistics model."""

    # Units are carried by the field-name suffixes (_seconds, _mb).
    uptime_seconds: float
    memory_usage_mb: float
    active_connections: int
    cache_hit_rate: float
    embedding_model_info: Dict[str, Any]
    database_size_mb: float
class StatsResponse(BaseResponse):
    """Response model for statistics."""

    # One section per statistics model; all three are required.
    context_stats: ContextStats
    search_stats: SearchStats
    system_stats: SystemStats
# Batch Operations Models
class BatchContextCreate(BaseModel):
    """Model for batch context creation."""

    # At most 100 contexts per batch (max_items); the validator below
    # rejects an empty list.
    contexts: List[ContextCreate] = Field(..., max_items=100, description="List of contexts to create")

    @validator('contexts')
    def validate_contexts_not_empty(cls, v):
        """Validate contexts list is not empty."""
        if v:
            return v
        raise ValueError('Contexts list cannot be empty')
class BatchOperationResult(BaseModel):
    """Result of batch operation."""

    # Required tallies.
    success_count: int
    error_count: int
    total_items: int

    # Optional details; each defaults to a fresh empty list.
    errors: List[Dict[str, Any]] = Field(default_factory=list)
    created_ids: List[int] = Field(default_factory=list)
class BatchResponse(BaseResponse):
    """Response model for batch operations."""

    # Outcome of the batch operation.
    data: BatchOperationResult
# WebSocket Models
class WebSocketMessage(BaseModel):
    """WebSocket message model."""

    # Required envelope: a type tag and a payload dict.
    type: str = Field(..., description="Message type")
    data: Dict[str, Any] = Field(..., description="Message data")

    # Generated per instance: creation time (naive UTC) and a UUID4 string.
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    message_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
class SubscriptionRequest(BaseModel):
    """WebSocket subscription request."""

    path_prefix: str = Field(..., description="Path prefix to subscribe to")
    # Defaults to all three event types; the list is built per instance.
    event_types: List[str] = Field(default_factory=lambda: ["created", "updated", "deleted"])
    filters: Optional[Dict[str, Any]] = Field(default_factory=dict)
# Health Check Models
class HealthStatus(str, Enum):
    """Health status enumeration."""

    # String-valued members; each value is the wire representation.
    HEALTHY = "healthy"
    DEGRADED = "degraded"
    UNHEALTHY = "unhealthy"
class ComponentHealth(BaseModel):
    """Individual component health."""

    # Which component and its current status.
    name: str
    status: HealthStatus

    # Optional measurements; both default to None.
    response_time_ms: Optional[float] = None
    error_message: Optional[str] = None

    # Required timestamp of the most recent check.
    last_check: datetime
class HealthResponse(BaseModel):
    """System health response."""

    status: HealthStatus
    # Generated per instance at creation time (naive UTC).
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    version: str
    uptime_seconds: float
    # Per-component health entries.
    components: List[ComponentHealth]
# Error Models
class ErrorDetail(BaseModel):
    """Detailed error information."""

    # Only the message is required; field and error_code default to None.
    field: Optional[str] = None
    message: str
    error_code: Optional[str] = None
class ErrorResponse(BaseModel):
    """Error response model."""

    # success defaults to False, unlike BaseResponse where it defaults to True.
    success: bool = False
    error: str
    error_details: Optional[List[ErrorDetail]] = None

    # Generated per instance: creation time (naive UTC) and a UUID4 string.
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    request_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    api_version: APIVersion = APIVersion.V1

View File

@@ -0,0 +1,172 @@
"""
HCFS API Server - FastAPI-based REST API for context operations.
"""
from typing import List, Optional
from datetime import datetime
from pydantic import BaseModel
from fastapi import FastAPI, HTTPException, Depends
from fastapi.responses import JSONResponse
from ..core.context_db import ContextDatabase, Context
from ..core.embeddings import EmbeddingManager
# Pydantic models
class ContextCreateRequest(BaseModel):
    # Request body for POST /context: path and content are required,
    # summary and author are optional.
    path: str
    content: str
    summary: Optional[str] = None
    author: Optional[str] = None
class ContextResponse(BaseModel):
    # Stored context as returned by the context endpoints in this module.
    id: int
    path: str
    content: str
    summary: Optional[str]
    author: Optional[str]
    created_at: datetime
    updated_at: datetime
    version: int
class SearchRequest(BaseModel):
    # Request body for POST /search; only the query text is required.
    query: str
    path_prefix: Optional[str] = None
    top_k: int = 5
    search_type: str = "hybrid"  # "semantic", "hybrid"
class SearchResult(BaseModel):
    # One search hit: the matched context and its score.
    context: ContextResponse
    score: float
class ContextAPI:
    """HCFS REST API server.

    Wraps a ``ContextDatabase`` and an ``EmbeddingManager`` in a FastAPI
    application that is available as ``self.app``. All routes are registered
    while the instance is constructed.
    """

    def __init__(self, context_db: ContextDatabase, embedding_manager: EmbeddingManager):
        """Keep the collaborators and build the FastAPI application.

        Args:
            context_db: Database used to list, update and delete contexts.
            embedding_manager: Component used to store contexts with their
                embeddings and to run searches.
        """
        self.context_db = context_db
        self.embedding_manager = embedding_manager
        self.app = FastAPI(
            title="HCFS Context API",
            description="Context-Aware Hierarchical Context File System API",
            version="0.1.0"
        )
        self._setup_routes()

    @staticmethod
    def _to_search_results(results) -> List[SearchResult]:
        """Convert an iterable of ``(context, score)`` pairs into response models."""
        return [
            SearchResult(context=ContextResponse(**ctx.__dict__), score=score)
            for ctx, score in results
        ]

    def _setup_routes(self):
        """Setup API routes."""

        @self.app.get("/health")
        async def health_check():
            """Health check endpoint."""
            return {"status": "healthy", "service": "hcfs-api"}

        @self.app.post("/context", response_model=ContextResponse)
        async def create_context(request: ContextCreateRequest):
            """Create a new context.

            Raises:
                HTTPException: 500 when the context cannot be read back
                    after it was stored.
            """
            context = Context(
                id=None,
                path=request.path,
                content=request.content,
                summary=request.summary,
                author=request.author
            )

            # Store with embedding
            context_id = self.embedding_manager.store_context_with_embedding(context)

            # Read the stored row back so the response carries the values
            # held by the database (id, timestamps, version).
            stored_contexts = self.context_db.list_contexts_at_path(request.path)
            stored_context = next((c for c in stored_contexts if c.id == context_id), None)
            if not stored_context:
                raise HTTPException(status_code=500, detail="Failed to store context")

            return ContextResponse(**stored_context.__dict__)

        @self.app.get("/context/{path:path}", response_model=List[ContextResponse])
        async def get_context(path: str, depth: int = 1):
            """Get contexts for a path with optional parent inheritance."""
            # The captured path has no leading slash, so one is prepended
            # before the lookup.
            contexts = self.context_db.get_context_by_path(f"/{path}", depth=depth)
            return [ContextResponse(**ctx.__dict__) for ctx in contexts]

        @self.app.get("/context", response_model=List[ContextResponse])
        async def list_contexts(path: str):
            """List all contexts at a specific path."""
            contexts = self.context_db.list_contexts_at_path(path)
            return [ContextResponse(**ctx.__dict__) for ctx in contexts]

        @self.app.put("/context/{context_id}")
        async def update_context(context_id: int, content: str, summary: Optional[str] = None):
            """Update an existing context and refresh its embedding.

            Raises:
                HTTPException: 404 when the database reports that nothing
                    was updated.
            """
            success = self.context_db.update_context(context_id, content, summary)
            if not success:
                raise HTTPException(status_code=404, detail="Context not found")

            # The row now holds exactly `content`, so embed that directly.
            # The previous implementation re-read the row through
            # list_contexts_at_path("") and skipped the refresh whenever the
            # context was not returned for the empty path, which left a
            # stale embedding behind without any error.
            embedding = self.embedding_manager.generate_embedding(content)
            self.embedding_manager._store_embedding(context_id, embedding)

            return {"message": "Context updated successfully"}

        @self.app.delete("/context/{context_id}")
        async def delete_context(context_id: int):
            """Delete a context.

            Raises:
                HTTPException: 404 when the database reports that nothing
                    was deleted.
            """
            # NOTE(review): only the database row is removed here; whether
            # the stored embedding is cleaned up elsewhere should be confirmed.
            success = self.context_db.delete_context(context_id)
            if not success:
                raise HTTPException(status_code=404, detail="Context not found")
            return {"message": "Context deleted successfully"}

        @self.app.post("/search", response_model=List[SearchResult])
        async def search_contexts(request: SearchRequest):
            """Search contexts using semantic or hybrid search.

            Raises:
                HTTPException: 400 when ``search_type`` is neither
                    ``"semantic"`` nor ``"hybrid"``.
            """
            if request.search_type == "semantic":
                results = self.embedding_manager.semantic_search(
                    request.query,
                    request.path_prefix,
                    request.top_k
                )
            elif request.search_type == "hybrid":
                results = self.embedding_manager.hybrid_search(
                    request.query,
                    request.path_prefix,
                    request.top_k
                )
            else:
                raise HTTPException(status_code=400, detail="Invalid search_type")

            return self._to_search_results(results)

        @self.app.get("/similar/{context_id}", response_model=List[SearchResult])
        async def get_similar_contexts(context_id: int, top_k: int = 5):
            """Find contexts similar to a given context."""
            results = self.embedding_manager.get_similar_contexts(context_id, top_k)
            return self._to_search_results(results)
def create_app(db_path: str = "hcfs_context.db") -> FastAPI:
    """Build the FastAPI application for the database at ``db_path``.

    A ``ContextDatabase`` is constructed from ``db_path``, an
    ``EmbeddingManager`` is constructed on top of it, and both are handed to
    ``ContextAPI``; that object's ``app`` attribute is returned.
    """
    database = ContextDatabase(db_path)
    embeddings = EmbeddingManager(database)
    return ContextAPI(database, embeddings).app

View File

@@ -0,0 +1,692 @@
"""
Production-Grade HCFS API Server v2.0
Enterprise-ready FastAPI server with comprehensive features:
- Full CRUD operations with validation
- Advanced search capabilities
- Version control and rollback
- Batch operations
- Real-time WebSocket updates
- Authentication and authorization
- Rate limiting and monitoring
- OpenAPI documentation
"""
import asyncio
import logging
import time
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from types import SimpleNamespace
from typing import Any, Dict, List, Optional

import structlog
import uvicorn
from fastapi import (
    BackgroundTasks,
    Depends,
    FastAPI,
    HTTPException,
    Query,
    Request,
    Response,
    status,
)
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
# The module is ``fastapi.websockets`` (plural); ``fastapi.websocket`` does
# not exist and fails at import time.
from fastapi.websockets import WebSocket, WebSocketDisconnect
from prometheus_client import CONTENT_TYPE_LATEST, Counter, Gauge, Histogram, generate_latest
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

# HCFS imports
# The star import stays ahead of the explicit project imports so that those
# explicit names keep taking precedence, as in the original ordering.
from .models import *
from ..core.context_db_optimized_fixed import OptimizedContextDatabase
from ..core.embeddings_optimized import OptimizedEmbeddingManager
from ..core.context_versioning import VersioningSystem
from ..core.context_db import Context
# Logging setup: the standard library root logger is configured at INFO and
# this module logs through a structlog logger.
logging.basicConfig(level=logging.INFO)
logger = structlog.get_logger()

# Metrics (Prometheus collectors, created once at import time)
REQUEST_COUNT = Counter(
    name='hcfs_requests_total',
    documentation='Total HTTP requests',
    labelnames=['method', 'endpoint', 'status'],
)
REQUEST_DURATION = Histogram(
    name='hcfs_request_duration_seconds',
    documentation='HTTP request duration',
)
ACTIVE_CONNECTIONS = Gauge(
    name='hcfs_active_connections',
    documentation='Active WebSocket connections',
)
CONTEXT_COUNT = Gauge(
    name='hcfs_contexts_total',
    documentation='Total number of contexts',
)
SEARCH_COUNT = Counter(
    name='hcfs_searches_total',
    documentation='Total searches performed',
    labelnames=['search_type'],
)

# Rate limiting, keyed by the value get_remote_address returns for a request
limiter = Limiter(key_func=get_remote_address)

# Security: HTTP bearer-token scheme
security = HTTPBearer()
class HCFSAPIServer:
    """Production HCFS API Server.

    Builds a FastAPI application (``self.app``) with CORS, gzip, rate
    limiting and Prometheus metrics, registers the REST and WebSocket
    routes, and creates the database, embedding and versioning components
    when the application starts up.
    """

    def __init__(self,
                 db_path: str = "hcfs_production.db",
                 vector_db_path: str = "hcfs_vectors_production.db",
                 enable_auth: bool = True,
                 cors_origins: Optional[List[str]] = None):
        """Store the configuration and build the FastAPI application.

        Args:
            db_path: Path handed to the context database and the
                versioning system.
            vector_db_path: Path handed to the embedding manager.
            enable_auth: When true, requests must carry a bearer token.
            cors_origins: Origins allowed by the CORS middleware; two
                localhost origins are used when omitted or empty.
        """
        self.db_path = db_path
        self.vector_db_path = vector_db_path
        self.enable_auth = enable_auth
        self.cors_origins = cors_origins or ["http://localhost:3000", "http://localhost:8080"]

        # Core components are created in startup(), not here.
        self.context_db = None
        self.embedding_manager = None
        self.versioning_system = None

        # WebSocket connections and their subscriptions, keyed by connection id.
        self.websocket_connections: Dict[str, WebSocket] = {}
        self.subscriptions: Dict[str, Dict[str, Any]] = {}

        # Reference point for uptime reporting; reset once startup() completes.
        self._started_at = time.monotonic()

        # Create FastAPI app
        self.app = self._create_app()

    async def startup(self):
        """Initialize database connections and components."""
        logger.info("Starting HCFS API Server...")

        # Initialize core components
        self.context_db = OptimizedContextDatabase(self.db_path, cache_size=1000)
        self.embedding_manager = OptimizedEmbeddingManager(
            self.context_db,
            model_name="mini",
            vector_db_path=self.vector_db_path,
            cache_size=2000,
            batch_size=32
        )
        self.versioning_system = VersioningSystem(self.db_path)

        # Update metrics
        CONTEXT_COUNT.set(len(self.context_db.get_all_contexts()))

        self._started_at = time.monotonic()
        logger.info("HCFS API Server started successfully")

    async def shutdown(self):
        """Cleanup resources by closing every open WebSocket connection."""
        logger.info("Shutting down HCFS API Server...")

        # Iterate over a snapshot: connection handlers remove their own entry
        # from the dict when they finish, which can happen while close() is
        # being awaited here.
        for connection_id, connection in list(self.websocket_connections.items()):
            try:
                await connection.close()
            except Exception as e:
                # Best effort: one broken socket must not keep the remaining
                # connections from being closed.
                logger.warning("Error closing WebSocket connection",
                               connection_id=connection_id, error=str(e))

        logger.info("HCFS API Server shutdown complete")

    def _create_app(self) -> FastAPI:
        """Create and configure FastAPI application."""

        @asynccontextmanager
        async def lifespan(app: FastAPI):
            await self.startup()
            yield
            await self.shutdown()

        app = FastAPI(
            title="HCFS API",
            description="Context-Aware Hierarchical Context File System API",
            version="2.0.0",
            docs_url="/docs",
            redoc_url="/redoc",
            openapi_url="/openapi.json",
            lifespan=lifespan
        )

        # Middleware
        app.add_middleware(
            CORSMiddleware,
            allow_origins=self.cors_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        app.add_middleware(GZipMiddleware, minimum_size=1000)

        # Rate limiting
        app.state.limiter = limiter
        app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

        # Add routes
        self._add_routes(app)

        # Add middleware for metrics
        @app.middleware("http")
        async def metrics_middleware(request: Request, call_next):
            start_time = time.time()
            response = await call_next(request)
            duration = time.time() - start_time

            # Label by the matched route template (for example
            # "/api/v1/contexts/{context_id}") instead of the raw URL so that
            # every distinct id does not create a new time series. The raw
            # path is used when no route information is present in the scope.
            route = request.scope.get("route")
            endpoint = getattr(route, "path", None) or request.url.path

            REQUEST_COUNT.labels(
                method=request.method,
                endpoint=endpoint,
                status=response.status_code
            ).inc()
            REQUEST_DURATION.observe(duration)
            return response

        return app

    @staticmethod
    def _probe_component(name: str, probe):
        """Run ``probe`` and describe the outcome as a ``ComponentHealth``.

        The component is reported healthy, with the measured call time in
        milliseconds, when ``probe()`` returns; it is reported unhealthy,
        with the exception text, when ``probe()`` raises.
        """
        started = time.perf_counter()
        try:
            probe()
        except Exception as e:
            return ComponentHealth(name=name, status=HealthStatus.UNHEALTHY, error_message=str(e))
        elapsed_ms = (time.perf_counter() - started) * 1000
        return ComponentHealth(name=name, status=HealthStatus.HEALTHY, response_time_ms=elapsed_ms)

    def _add_routes(self, app: FastAPI):
        """Add all API routes."""

        # auto_error=False lets the dependency below decide what a missing
        # header means; the default scheme rejects such requests by itself,
        # even when authentication is disabled.
        bearer_scheme = HTTPBearer(auto_error=False)

        # Authentication dependency
        async def get_current_user(
            credentials: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme)
        ):
            if not self.enable_auth:
                return {"username": "anonymous", "scopes": ["read", "write"]}

            # TODO: Implement actual authentication
            # SECURITY: for now any non-empty bearer token is accepted.
            if credentials is None or not credentials.credentials:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Invalid authentication credentials",
                    headers={"WWW-Authenticate": "Bearer"},
                )
            return {"username": "api_user", "scopes": ["read", "write"]}

        # Health check
        @app.get("/health", response_model=HealthResponse, tags=["System"])
        async def health_check():
            """System health check endpoint."""
            # The lambdas defer the attribute access so that components that
            # have not been created yet are reported as unhealthy instead of
            # raising outside the probe.
            components = [
                self._probe_component("database", lambda: self.context_db.get_all_contexts()),
                self._probe_component("embeddings", lambda: self.embedding_manager.get_statistics()),
            ]

            # Overall status: the worst component status wins.
            overall_status = HealthStatus.HEALTHY
            if any(c.status == HealthStatus.UNHEALTHY for c in components):
                overall_status = HealthStatus.UNHEALTHY
            elif any(c.status == HealthStatus.DEGRADED for c in components):
                overall_status = HealthStatus.DEGRADED

            return HealthResponse(
                status=overall_status,
                version="2.0.0",
                uptime_seconds=time.monotonic() - self._started_at,
                components=components
            )

        # Metrics endpoint
        @app.get("/metrics", tags=["System"])
        async def metrics():
            """Prometheus metrics endpoint."""
            return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)

        # Context CRUD operations
        @app.post("/api/v1/contexts", response_model=ContextDetailResponse, tags=["Contexts"])
        @limiter.limit("100/minute")
        async def create_context(
            request: Request,
            context_data: ContextCreate,
            background_tasks: BackgroundTasks,
            current_user: dict = Depends(get_current_user)
        ):
            """Create a new context with automatic embedding generation."""
            try:
                # Create context object; the author falls back to the caller.
                context = Context(
                    id=None,
                    path=context_data.path,
                    content=context_data.content,
                    summary=context_data.summary,
                    author=context_data.author or current_user["username"],
                    version=1
                )

                # Store context
                context_id = self.context_db.store_context(context)

                # Generate and store embedding in background
                background_tasks.add_task(self._generate_embedding_async, context_id, context_data.content)

                # Get created context
                created_context = self.context_db.get_context(context_id)
                context_response = self._context_to_response(created_context)

                # Update metrics
                CONTEXT_COUNT.inc()

                # Notify WebSocket subscribers
                await self._notify_websocket_subscribers("created", context_response)

                return ContextDetailResponse(data=context_response)

            except Exception as e:
                logger.error("Error creating context", error=str(e))
                raise HTTPException(status_code=500, detail=f"Failed to create context: {str(e)}") from e

        @app.get("/api/v1/contexts/{context_id}", response_model=ContextDetailResponse, tags=["Contexts"])
        @limiter.limit("200/minute")
        async def get_context(
            request: Request,
            context_id: int,
            current_user: dict = Depends(get_current_user)
        ):
            """Get a specific context by ID."""
            try:
                context = self.context_db.get_context(context_id)
                if not context:
                    raise HTTPException(status_code=404, detail="Context not found")

                context_response = self._context_to_response(context)
                return ContextDetailResponse(data=context_response)

            except HTTPException:
                # Keep the 404 above from being rewritten into a 500 below.
                raise
            except Exception as e:
                logger.error("Error retrieving context", context_id=context_id, error=str(e))
                raise HTTPException(status_code=500, detail=f"Failed to retrieve context: {str(e)}") from e

        @app.get("/api/v1/contexts", response_model=ContextListResponse, tags=["Contexts"])
        @limiter.limit("100/minute")
        async def list_contexts(
            request: Request,
            pagination: PaginationParams = Depends(),
            path_prefix: Optional[str] = Query(None, description="Filter by path prefix"),
            author: Optional[str] = Query(None, description="Filter by author"),
            status: Optional[ContextStatus] = Query(None, description="Filter by status"),
            current_user: dict = Depends(get_current_user)
        ):
            """List contexts with filtering and pagination."""
            # NOTE: inside this handler ``status`` is the query parameter and
            # hides the ``fastapi.status`` module.
            try:
                status_value = status.value if status else None

                # Get contexts with filters
                contexts = self.context_db.get_contexts_filtered(
                    path_prefix=path_prefix,
                    author=author,
                    status=status_value,
                    limit=pagination.page_size,
                    offset=pagination.offset
                )

                # Get total count for pagination
                total_count = self.context_db.count_contexts(
                    path_prefix=path_prefix,
                    author=author,
                    status=status_value
                )

                # Convert to response models
                context_responses = [self._context_to_response(ctx) for ctx in contexts]

                # Create pagination metadata; total_pages is a ceiling division.
                pagination_meta = PaginationMeta(
                    page=pagination.page,
                    page_size=pagination.page_size,
                    total_items=total_count,
                    total_pages=(total_count + pagination.page_size - 1) // pagination.page_size,
                    has_next=pagination.page * pagination.page_size < total_count,
                    has_previous=pagination.page > 1
                )

                return ContextListResponse(data=context_responses, pagination=pagination_meta)

            except Exception as e:
                logger.error("Error listing contexts", error=str(e))
                raise HTTPException(status_code=500, detail=f"Failed to list contexts: {str(e)}") from e

        @app.put("/api/v1/contexts/{context_id}", response_model=ContextDetailResponse, tags=["Contexts"])
        @limiter.limit("50/minute")
        async def update_context(
            request: Request,
            context_id: int,
            context_update: ContextUpdate,
            background_tasks: BackgroundTasks,
            current_user: dict = Depends(get_current_user)
        ):
            """Update an existing context."""
            try:
                # Check if context exists
                existing_context = self.context_db.get_context(context_id)
                if not existing_context:
                    raise HTTPException(status_code=404, detail="Context not found")

                # Update context with only the fields the client actually sent.
                update_data = context_update.dict(exclude_unset=True)
                if update_data:
                    self.context_db.update_context(context_id, **update_data)

                # If content changed, regenerate embedding
                if 'content' in update_data:
                    background_tasks.add_task(
                        self._generate_embedding_async,
                        context_id,
                        update_data['content']
                    )

                # Get updated context
                updated_context = self.context_db.get_context(context_id)
                context_response = self._context_to_response(updated_context)

                # Notify WebSocket subscribers
                await self._notify_websocket_subscribers("updated", context_response)

                return ContextDetailResponse(data=context_response)

            except HTTPException:
                raise
            except Exception as e:
                logger.error("Error updating context", context_id=context_id, error=str(e))
                raise HTTPException(status_code=500, detail=f"Failed to update context: {str(e)}") from e

        @app.delete("/api/v1/contexts/{context_id}", tags=["Contexts"])
        @limiter.limit("30/minute")
        async def delete_context(
            request: Request,
            context_id: int,
            current_user: dict = Depends(get_current_user)
        ):
            """Delete a context."""
            try:
                # Check if context exists
                existing_context = self.context_db.get_context(context_id)
                if not existing_context:
                    raise HTTPException(status_code=404, detail="Context not found")

                # Delete context
                success = self.context_db.delete_context(context_id)
                if not success:
                    raise HTTPException(status_code=500, detail="Failed to delete context")

                # Update metrics
                CONTEXT_COUNT.dec()

                # Notify WebSocket subscribers
                await self._notify_websocket_subscribers("deleted", {"id": context_id})

                return {"success": True, "message": "Context deleted successfully"}

            except HTTPException:
                raise
            except Exception as e:
                logger.error("Error deleting context", context_id=context_id, error=str(e))
                raise HTTPException(status_code=500, detail=f"Failed to delete context: {str(e)}") from e

        # Search endpoints
        @app.post("/api/v1/search", response_model=SearchResponse, tags=["Search"])
        @limiter.limit("100/minute")
        async def search_contexts(
            request: Request,
            search_request: SearchRequest,
            current_user: dict = Depends(get_current_user)
        ):
            """Advanced context search with multiple search types."""
            try:
                start_time = time.time()

                # Perform search based on type
                if search_request.search_type == SearchType.SEMANTIC:
                    results = self.embedding_manager.semantic_search_optimized(
                        search_request.query,
                        path_prefix=search_request.path_prefix,
                        top_k=search_request.top_k,
                        include_contexts=True
                    )
                elif search_request.search_type == SearchType.HYBRID:
                    results = self.embedding_manager.hybrid_search_optimized(
                        search_request.query,
                        path_prefix=search_request.path_prefix,
                        top_k=search_request.top_k,
                        semantic_weight=search_request.semantic_weight
                    )
                else:
                    # Fallback to keyword search; every hit gets a score of
                    # 1.0 and the same attributes the other branches expose.
                    contexts = self.context_db.search_contexts(search_request.query)
                    results = [
                        SimpleNamespace(context=ctx, score=1.0)
                        for ctx in contexts[:search_request.top_k]
                    ]

                search_time = (time.time() - start_time) * 1000

                # Convert results to response format; results without a
                # context are skipped.
                search_results = []
                for result in results:
                    matched_context = getattr(result, 'context', None)
                    if not matched_context:
                        continue
                    score = result.score
                    context_response = self._context_to_response(matched_context)
                    context_response.similarity_score = score
                    search_results.append(SearchResult(
                        context=context_response,
                        score=score,
                        explanation=f"Matched with {score:.3f} similarity"
                    ))

                # Update metrics
                SEARCH_COUNT.labels(search_type=search_request.search_type.value).inc()

                return SearchResponse(
                    data=search_results,
                    query=search_request.query,
                    search_type=search_request.search_type,
                    total_results=len(search_results),
                    search_time_ms=search_time,
                    filters_applied=search_request.filters
                )

            except Exception as e:
                logger.error("Error performing search", query=search_request.query, error=str(e))
                raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}") from e

        # Batch operations
        @app.post("/api/v1/contexts/batch", response_model=BatchResponse, tags=["Batch Operations"])
        @limiter.limit("10/minute")
        async def batch_create_contexts(
            request: Request,
            batch_request: BatchContextCreate,
            background_tasks: BackgroundTasks,
            current_user: dict = Depends(get_current_user)
        ):
            """Create multiple contexts in batch.

            Items are stored one by one; a failing item is recorded in the
            result's error list and does not stop the remaining items.
            """
            try:
                results = BatchOperationResult(
                    success_count=0,
                    error_count=0,
                    total_items=len(batch_request.contexts)
                )

                for i, context_data in enumerate(batch_request.contexts):
                    try:
                        context = Context(
                            id=None,
                            path=context_data.path,
                            content=context_data.content,
                            summary=context_data.summary,
                            author=context_data.author or current_user["username"],
                            version=1
                        )

                        context_id = self.context_db.store_context(context)
                        results.created_ids.append(context_id)
                        results.success_count += 1

                        # Generate embedding in background
                        background_tasks.add_task(
                            self._generate_embedding_async,
                            context_id,
                            context_data.content
                        )

                    except Exception as e:
                        results.error_count += 1
                        results.errors.append({
                            "index": i,
                            "path": context_data.path,
                            "error": str(e)
                        })

                # Update metrics
                CONTEXT_COUNT.inc(results.success_count)

                return BatchResponse(data=results)

            except Exception as e:
                logger.error("Error in batch create", error=str(e))
                raise HTTPException(status_code=500, detail=f"Batch operation failed: {str(e)}") from e

        # Statistics endpoint
        @app.get("/api/v1/stats", response_model=StatsResponse, tags=["Analytics"])
        @limiter.limit("30/minute")
        async def get_statistics(
            request: Request,
            current_user: dict = Depends(get_current_user)
        ):
            """Get comprehensive system statistics.

            Only the embedding count, cache hit rate, model info, connection
            count and uptime are real; the remaining figures are placeholder
            values.
            """
            try:
                # Get embedding manager stats
                emb_stats = self.embedding_manager.get_statistics()
                total_embeddings = emb_stats["database_stats"]["total_embeddings"]

                # Mock context stats (implement based on your needs)
                context_stats = ContextStats(
                    total_contexts=total_embeddings,
                    contexts_by_status={ContextStatus.ACTIVE: total_embeddings},
                    contexts_by_author={"system": total_embeddings},
                    average_content_length=100.0,
                    most_active_paths=[],
                    recent_activity=[]
                )

                search_stats = SearchStats(
                    total_searches=100,  # Mock data
                    searches_by_type={SearchType.SEMANTIC: 60, SearchType.HYBRID: 40},
                    average_response_time_ms=50.0,
                    popular_queries=[],
                    search_success_rate=0.95
                )

                system_stats = SystemStats(
                    uptime_seconds=time.monotonic() - self._started_at,
                    memory_usage_mb=100.0,
                    active_connections=len(self.websocket_connections),
                    cache_hit_rate=emb_stats["cache_stats"].get("hit_rate", 0.0),
                    embedding_model_info=emb_stats["current_model"],
                    database_size_mb=10.0
                )

                return StatsResponse(
                    context_stats=context_stats,
                    search_stats=search_stats,
                    system_stats=system_stats
                )

            except Exception as e:
                logger.error("Error getting statistics", error=str(e))
                raise HTTPException(status_code=500, detail=f"Failed to get statistics: {str(e)}") from e

        # WebSocket endpoint
        @app.websocket("/ws")
        async def websocket_endpoint(websocket: WebSocket):
            """WebSocket endpoint for real-time updates."""
            # NOTE(review): no authentication dependency is applied here.
            await self._handle_websocket_connection(websocket)

    def _context_to_response(self, context) -> ContextResponse:
        """Convert database context to API response model.

        Tags, metadata and status are not read from ``context`` yet; fixed
        placeholder values are returned for them.
        """
        return ContextResponse(
            id=context.id,
            path=context.path,
            content=context.content,
            summary=context.summary,
            author=context.author or "unknown",
            tags=[],  # TODO: implement tags
            metadata={},  # TODO: implement metadata
            status=ContextStatus.ACTIVE,  # TODO: implement status
            created_at=context.created_at,
            updated_at=context.updated_at,
            version=context.version
        )

    async def _generate_embedding_async(self, context_id: int, content: str):
        """Generate and store embedding for ``content``.

        Failures are logged and swallowed because this runs as a background
        task after the response has been produced.
        """
        # NOTE(review): both manager calls are plain synchronous calls, so
        # the event loop is occupied while they run. Consider moving them to
        # a worker thread once the manager's thread-safety is confirmed.
        try:
            embedding = self.embedding_manager.generate_embedding(content)
            self.embedding_manager.store_embedding(context_id, embedding)
            logger.info("Generated embedding for context", context_id=context_id)
        except Exception as e:
            logger.error("Failed to generate embedding", context_id=context_id, error=str(e))

    async def _handle_websocket_connection(self, websocket: WebSocket):
        """Handle WebSocket connection and subscriptions.

        The connection is registered under an id derived from the socket
        object, "subscribe" messages replace its subscription, and the
        registration is removed again when the loop ends for any reason.
        """
        await websocket.accept()
        connection_id = str(id(websocket))
        self.websocket_connections[connection_id] = websocket
        ACTIVE_CONNECTIONS.inc()

        try:
            while True:
                # Wait for subscription requests
                data = await websocket.receive_json()
                message = WebSocketMessage(**data)

                if message.type == "subscribe":
                    subscription = SubscriptionRequest(**message.data)
                    self.subscriptions[connection_id] = {
                        "path_prefix": subscription.path_prefix,
                        "event_types": subscription.event_types,
                        "filters": subscription.filters
                    }
                    await websocket.send_json({
                        "type": "subscription_confirmed",
                        "data": {"path_prefix": subscription.path_prefix}
                    })

        except WebSocketDisconnect:
            pass
        finally:
            # Cleanup
            self.websocket_connections.pop(connection_id, None)
            self.subscriptions.pop(connection_id, None)
            ACTIVE_CONNECTIONS.dec()

    async def _notify_websocket_subscribers(self, event_type: str, data: Any):
        """Notify WebSocket subscribers of events.

        Args:
            event_type: Event name, matched against each subscription's
                ``event_types``.
            data: Event payload; objects with a ``dict()`` method are
                converted with it, anything else is sent as given.
        """
        if not self.websocket_connections:
            return

        # Create notification message
        notification = WebSocketMessage(
            type=event_type,
            data=data.dict() if hasattr(data, 'dict') else data
        )

        # send_json() serializes with the standard json module, which cannot
        # handle datetime or plain Enum values. Encode once up front so such
        # payloads are delivered instead of failing and dropping the
        # subscriber in the error handler below.
        payload = jsonable_encoder(notification.dict())

        # Send to all relevant subscribers. The snapshot allows failed
        # connections to be removed while iterating.
        for connection_id, websocket in list(self.websocket_connections.items()):
            try:
                subscription = self.subscriptions.get(connection_id)
                if not subscription or event_type not in subscription["event_types"]:
                    continue

                # Check path filter (only for payloads that expose a path).
                prefix = subscription["path_prefix"]
                if prefix and hasattr(data, 'path') and not data.path.startswith(prefix):
                    continue

                await websocket.send_json(payload)

            except Exception as e:
                logger.error("Error sending WebSocket notification",
                             connection_id=connection_id, error=str(e))
                # Remove failed connection
                self.websocket_connections.pop(connection_id, None)
                self.subscriptions.pop(connection_id, None)

    def run(self, host: str = "0.0.0.0", port: int = 8000, **kwargs):
        """Run the API server.

        Args:
            host: Interface to bind; the default listens on all interfaces.
            port: TCP port to listen on.
            **kwargs: Passed through to ``uvicorn.run``.
        """
        uvicorn.run(self.app, host=host, port=port, **kwargs)
def create_app() -> FastAPI:
    """Factory function to create the app.

    Instantiates ``HCFSAPIServer`` with its default arguments and returns
    the FastAPI application that server built.
    """
    return HCFSAPIServer().app
if __name__ == "__main__":
    # Script entry point: serve the API with the server's default settings.
    HCFSAPIServer().run()