Phase 2 build initial
21
hcfs-python/hcfs/__init__.py
Normal file
@@ -0,0 +1,21 @@
"""
HCFS - Context-Aware Hierarchical Context File System

A virtual filesystem that maps hierarchical paths to context blobs,
enabling AI agents to navigate and manage context at different scopes.
"""

__version__ = "0.1.0"
__author__ = "Tony"

from .core.context_db import ContextDatabase
from .core.filesystem import HCFSFilesystem
from .core.embeddings import EmbeddingManager
from .api.server import ContextAPI

__all__ = [
    "ContextDatabase",
    "HCFSFilesystem",
    "EmbeddingManager",
    "ContextAPI",
]
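A minimal sketch of using these package-level exports, mirroring the create_app() wiring in hcfs/api/server.py further down this commit; the database filename is illustrative only.

# Hypothetical quick-start using only the names exported above.
from hcfs import ContextDatabase, EmbeddingManager

db = ContextDatabase("hcfs_context.db")   # SQLite-backed context store (path is an example)
embeddings = EmbeddingManager(db)          # semantic/hybrid search layered on the same store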
1
hcfs-python/hcfs/api/__init__.py
Normal file
@@ -0,0 +1 @@
"""HCFS API components."""
288
hcfs-python/hcfs/api/config.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
Configuration management for HCFS API.
|
||||
|
||||
Handles environment-based configuration with validation and defaults.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseSettings, Field, validator
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class DatabaseConfig(BaseSettings):
|
||||
"""Database configuration settings."""
|
||||
|
||||
# SQLite settings
|
||||
db_path: str = Field(default="hcfs_production.db", description="Path to SQLite database")
|
||||
vector_db_path: str = Field(default="hcfs_vectors_production.db", description="Path to vector database")
|
||||
|
||||
# Connection settings
|
||||
pool_size: int = Field(default=10, description="Database connection pool size")
|
||||
max_overflow: int = Field(default=20, description="Maximum connection overflow")
|
||||
pool_timeout: int = Field(default=30, description="Connection pool timeout in seconds")
|
||||
|
||||
# Performance settings
|
||||
cache_size: int = Field(default=1000, description="Database cache size")
|
||||
enable_wal_mode: bool = Field(default=True, description="Enable SQLite WAL mode")
|
||||
synchronous_mode: str = Field(default="NORMAL", description="SQLite synchronous mode")
|
||||
|
||||
class Config:
|
||||
env_prefix = "HCFS_DB_"
|
||||
|
||||
|
||||
class EmbeddingConfig(BaseSettings):
|
||||
"""Embedding system configuration."""
|
||||
|
||||
# Model settings
|
||||
model_name: str = Field(default="mini", description="Embedding model to use")
|
||||
cache_size: int = Field(default=2000, description="Embedding cache size")
|
||||
batch_size: int = Field(default=32, description="Batch processing size")
|
||||
|
||||
# Performance settings
|
||||
max_workers: int = Field(default=4, description="Maximum worker threads")
|
||||
timeout_seconds: int = Field(default=300, description="Operation timeout")
|
||||
|
||||
# Vector database settings
|
||||
vector_dimension: int = Field(default=384, description="Vector dimension")
|
||||
similarity_threshold: float = Field(default=0.0, description="Default similarity threshold")
|
||||
|
||||
class Config:
|
||||
env_prefix = "HCFS_EMBEDDING_"
|
||||
|
||||
|
||||
class APIConfig(BaseSettings):
|
||||
"""API server configuration."""
|
||||
|
||||
# Server settings
|
||||
host: str = Field(default="0.0.0.0", description="Server host")
|
||||
port: int = Field(default=8000, description="Server port")
|
||||
workers: int = Field(default=1, description="Number of worker processes")
|
||||
|
||||
# Security settings
|
||||
secret_key: str = Field(default="dev-secret-key-change-in-production", description="JWT secret key")
|
||||
algorithm: str = Field(default="HS256", description="JWT algorithm")
|
||||
token_expire_minutes: int = Field(default=30, description="JWT token expiration time")
|
||||
|
||||
# CORS settings
|
||||
cors_origins: List[str] = Field(
|
||||
default=["http://localhost:3000", "http://localhost:8080"],
|
||||
description="Allowed CORS origins"
|
||||
)
|
||||
cors_credentials: bool = Field(default=True, description="Allow credentials in CORS")
|
||||
|
||||
# Rate limiting
|
||||
rate_limit_requests: int = Field(default=100, description="Requests per minute")
|
||||
rate_limit_burst: int = Field(default=20, description="Burst requests allowed")
|
||||
|
||||
# Feature flags
|
||||
enable_auth: bool = Field(default=True, description="Enable authentication")
|
||||
enable_websocket: bool = Field(default=True, description="Enable WebSocket support")
|
||||
enable_metrics: bool = Field(default=True, description="Enable Prometheus metrics")
|
||||
enable_docs: bool = Field(default=True, description="Enable API documentation")
|
||||
|
||||
class Config:
|
||||
env_prefix = "HCFS_API_"
|
||||
|
||||
|
||||
class MonitoringConfig(BaseSettings):
|
||||
"""Monitoring and observability configuration."""
|
||||
|
||||
# Logging settings
|
||||
log_level: str = Field(default="INFO", description="Logging level")
|
||||
log_format: str = Field(default="json", description="Log format (json/text)")
|
||||
log_file: Optional[str] = Field(default=None, description="Log file path")
|
||||
|
||||
# Metrics settings
|
||||
metrics_enabled: bool = Field(default=True, description="Enable metrics collection")
|
||||
metrics_port: int = Field(default=9090, description="Metrics server port")
|
||||
|
||||
# Health check settings
|
||||
health_check_interval: int = Field(default=30, description="Health check interval in seconds")
|
||||
health_check_timeout: int = Field(default=5, description="Health check timeout")
|
||||
|
||||
# Tracing settings
|
||||
tracing_enabled: bool = Field(default=False, description="Enable distributed tracing")
|
||||
tracing_sample_rate: float = Field(default=0.1, description="Tracing sample rate")
|
||||
jaeger_endpoint: Optional[str] = Field(default=None, description="Jaeger endpoint")
|
||||
|
||||
class Config:
|
||||
env_prefix = "HCFS_MONITORING_"
|
||||
|
||||
|
||||
class RedisConfig(BaseSettings):
|
||||
"""Redis configuration for caching and rate limiting."""
|
||||
|
||||
# Connection settings
|
||||
host: str = Field(default="localhost", description="Redis host")
|
||||
port: int = Field(default=6379, description="Redis port")
|
||||
db: int = Field(default=0, description="Redis database number")
|
||||
password: Optional[str] = Field(default=None, description="Redis password")
|
||||
|
||||
# Pool settings
|
||||
max_connections: int = Field(default=20, description="Maximum Redis connections")
|
||||
socket_timeout: int = Field(default=5, description="Socket timeout in seconds")
|
||||
|
||||
# Cache settings
|
||||
default_ttl: int = Field(default=3600, description="Default cache TTL in seconds")
|
||||
key_prefix: str = Field(default="hcfs:", description="Redis key prefix")
|
||||
|
||||
class Config:
|
||||
env_prefix = "HCFS_REDIS_"
|
||||
|
||||
|
||||
class SecurityConfig(BaseSettings):
|
||||
"""Security configuration."""
|
||||
|
||||
# Authentication
|
||||
require_auth: bool = Field(default=True, description="Require authentication")
|
||||
api_key_header: str = Field(default="X-API-Key", description="API key header name")
|
||||
|
||||
# Rate limiting
|
||||
rate_limit_enabled: bool = Field(default=True, description="Enable rate limiting")
|
||||
rate_limit_storage: str = Field(default="memory", description="Rate limit storage (memory/redis)")
|
||||
|
||||
# HTTPS settings
|
||||
force_https: bool = Field(default=False, description="Force HTTPS in production")
|
||||
hsts_max_age: int = Field(default=31536000, description="HSTS max age")
|
||||
|
||||
# Request validation
|
||||
max_request_size: int = Field(default=10 * 1024 * 1024, description="Maximum request size in bytes")
|
||||
max_query_params: int = Field(default=100, description="Maximum query parameters")
|
||||
|
||||
# Content security
|
||||
allowed_content_types: List[str] = Field(
|
||||
default=["application/json", "application/x-www-form-urlencoded", "multipart/form-data"],
|
||||
description="Allowed content types"
|
||||
)
|
||||
|
||||
class Config:
|
||||
env_prefix = "HCFS_SECURITY_"
|
||||
|
||||
|
||||
class HCFSConfig(BaseSettings):
|
||||
"""Main HCFS configuration combining all subsystem configs."""
|
||||
|
||||
# Environment
|
||||
environment: str = Field(default="development", description="Environment (development/staging/production)")
|
||||
debug: bool = Field(default=False, description="Enable debug mode")
|
||||
|
||||
# Application info
|
||||
app_name: str = Field(default="HCFS API", description="Application name")
|
||||
app_version: str = Field(default="2.0.0", description="Application version")
|
||||
app_description: str = Field(default="Context-Aware Hierarchical Context File System API", description="App description")
|
||||
|
||||
# Configuration file path
|
||||
config_file: Optional[str] = Field(default=None, description="Path to configuration file")
|
||||
|
||||
# Subsystem configurations
|
||||
database: DatabaseConfig = Field(default_factory=DatabaseConfig)
|
||||
embedding: EmbeddingConfig = Field(default_factory=EmbeddingConfig)
|
||||
api: APIConfig = Field(default_factory=APIConfig)
|
||||
monitoring: MonitoringConfig = Field(default_factory=MonitoringConfig)
|
||||
redis: RedisConfig = Field(default_factory=RedisConfig)
|
||||
security: SecurityConfig = Field(default_factory=SecurityConfig)
|
||||
|
||||
class Config:
|
||||
env_prefix = "HCFS_"
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
|
||||
@validator('environment')
|
||||
def validate_environment(cls, v):
|
||||
"""Validate environment value."""
|
||||
allowed = ['development', 'staging', 'production']
|
||||
if v not in allowed:
|
||||
raise ValueError(f'Environment must be one of: {allowed}')
|
||||
return v
|
||||
|
||||
@validator('debug')
|
||||
def validate_debug_in_production(cls, v, values):
|
||||
"""Ensure debug is disabled in production."""
|
||||
if values.get('environment') == 'production' and v:
|
||||
raise ValueError('Debug mode cannot be enabled in production')
|
||||
return v
|
||||
|
||||
def is_production(self) -> bool:
|
||||
"""Check if running in production environment."""
|
||||
return self.environment == 'production'
|
||||
|
||||
def is_development(self) -> bool:
|
||||
"""Check if running in development environment."""
|
||||
return self.environment == 'development'
|
||||
|
||||
def get_database_url(self) -> str:
|
||||
"""Get database URL."""
|
||||
return f"sqlite:///{self.database.db_path}"
|
||||
|
||||
def get_redis_url(self) -> str:
|
||||
"""Get Redis URL."""
|
||||
if self.redis.password:
|
||||
return f"redis://:{self.redis.password}@{self.redis.host}:{self.redis.port}/{self.redis.db}"
|
||||
return f"redis://{self.redis.host}:{self.redis.port}/{self.redis.db}"
|
||||
|
||||
def load_from_file(self, config_path: str) -> None:
|
||||
"""Load configuration from YAML file."""
|
||||
import yaml
|
||||
|
||||
config_file = Path(config_path)
|
||||
if not config_file.exists():
|
||||
raise FileNotFoundError(f"Configuration file not found: {config_path}")
|
||||
|
||||
with open(config_file, 'r') as f:
|
||||
config_data = yaml.safe_load(f)
|
||||
|
||||
# Update configuration
|
||||
for key, value in config_data.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert configuration to dictionary."""
|
||||
return self.dict()
|
||||
|
||||
def save_to_file(self, config_path: str) -> None:
|
||||
"""Save configuration to YAML file."""
|
||||
import yaml
|
||||
|
||||
config_data = self.to_dict()
|
||||
|
||||
with open(config_path, 'w') as f:
|
||||
yaml.dump(config_data, f, default_flow_style=False, indent=2)
|
||||
|
||||
|
||||
# Global configuration instance
|
||||
config = HCFSConfig()
|
||||
|
||||
|
||||
def get_config() -> HCFSConfig:
|
||||
"""Get the global configuration instance."""
|
||||
return config
|
||||
|
||||
|
||||
def load_config(config_path: Optional[str] = None, **overrides) -> HCFSConfig:
|
||||
"""Load configuration with optional file and overrides."""
|
||||
global config
|
||||
|
||||
# Load from file if provided
|
||||
if config_path:
|
||||
config.load_from_file(config_path)
|
||||
|
||||
# Apply overrides
|
||||
for key, value in overrides.items():
|
||||
if hasattr(config, key):
|
||||
setattr(config, key, value)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def create_config_template(output_path: str = "hcfs_config.yaml") -> None:
|
||||
"""Create a configuration template file."""
|
||||
template_config = HCFSConfig()
|
||||
template_config.save_to_file(output_path)
|
||||
print(f"Configuration template created: {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Create configuration template
|
||||
create_config_template()
|
||||
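A short sketch of driving these settings classes from the environment. The variable names follow the env_prefix values declared above; the nested-override behaviour relies on pydantic v1 BaseSettings semantics (on pydantic v2 the import at the top of config.py would need pydantic-settings), and the file paths are illustrative assumptions.

# Hypothetical usage of hcfs.api.config, using only names defined above.
import os
from hcfs.api.config import HCFSConfig

# Each subsystem reads its own prefix, e.g. HCFS_API_PORT -> APIConfig.port,
# HCFS_DB_DB_PATH -> DatabaseConfig.db_path (set before instantiation).
os.environ["HCFS_API_PORT"] = "9000"
os.environ["HCFS_DB_DB_PATH"] = "/tmp/hcfs_dev.db"

cfg = HCFSConfig(environment="staging")   # validator rejects values outside development/staging/production
print(cfg.api.port)                        # 9000
print(cfg.get_database_url())              # sqlite:////tmp/hcfs_dev.db
print(cfg.is_production())                 # False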
365
hcfs-python/hcfs/api/middleware.py
Normal file
@@ -0,0 +1,365 @@
"""
Custom middleware for HCFS API.

Provides authentication, logging, error handling, and security features.
"""

import time
import uuid
import json
from typing import Optional
from datetime import datetime, timedelta

from fastapi import Request, Response, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
import jwt
import structlog

logger = structlog.get_logger()


class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Middleware for comprehensive request/response logging."""

    def __init__(self, app, log_body: bool = False):
        super().__init__(app)
        self.log_body = log_body

    async def dispatch(self, request: Request, call_next):
        # Generate request ID
        request_id = str(uuid.uuid4())
        request.state.request_id = request_id

        # Start timing
        start_time = time.time()

        # Log request
        logger.info(
            "Request started",
            request_id=request_id,
            method=request.method,
            url=str(request.url),
            client_ip=request.client.host if request.client else None,
            user_agent=request.headers.get("user-agent"),
        )

        # Call the next middleware/endpoint
        response = await call_next(request)

        # Calculate duration
        duration = time.time() - start_time

        # Log response
        logger.info(
            "Request completed",
            request_id=request_id,
            status_code=response.status_code,
            duration_ms=round(duration * 1000, 2),
        )

        # Add request ID to response headers
        response.headers["X-Request-ID"] = request_id

        return response


class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Middleware for consistent error handling and formatting."""

    async def dispatch(self, request: Request, call_next):
        try:
            response = await call_next(request)
            return response
        except HTTPException as e:
            # FastAPI HTTPExceptions are handled by FastAPI itself
            raise e
        except Exception as e:
            # Log unexpected errors
            request_id = getattr(request.state, 'request_id', 'unknown')
            logger.error(
                "Unhandled exception",
                request_id=request_id,
                error=str(e),
                error_type=type(e).__name__,
                method=request.method,
                url=str(request.url),
                exc_info=True
            )

            # Return consistent error response
            return JSONResponse(
                status_code=500,
                content={
                    "success": False,
                    "error": "Internal server error",
                    "error_details": [{"message": "An unexpected error occurred"}],
                    "timestamp": datetime.utcnow().isoformat(),
                    "request_id": request_id,
                    "api_version": "v1"
                }
            )


class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Middleware to add security headers."""

    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)

        # Add security headers
        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["X-Frame-Options"] = "DENY"
        response.headers["X-XSS-Protection"] = "1; mode=block"
        response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
        response.headers["Permissions-Policy"] = "geolocation=(), microphone=(), camera=()"

        return response


class JWTAuthenticationManager:
    """JWT-based authentication manager."""

    def __init__(self, secret_key: str, algorithm: str = "HS256", token_expire_minutes: int = 30):
        self.secret_key = secret_key
        self.algorithm = algorithm
        self.token_expire_minutes = token_expire_minutes

    def create_access_token(self, data: dict) -> str:
        """Create JWT access token."""
        to_encode = data.copy()
        expire = datetime.utcnow() + timedelta(minutes=self.token_expire_minutes)
        to_encode.update({"exp": expire, "iat": datetime.utcnow()})

        return jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)

    def verify_token(self, token: str) -> Optional[dict]:
        """Verify and decode JWT token."""
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
            return payload
        except jwt.ExpiredSignatureError:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token has expired",
                headers={"WWW-Authenticate": "Bearer"},
            )
        except jwt.JWTError:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Could not validate credentials",
                headers={"WWW-Authenticate": "Bearer"},
            )


class APIKeyManager:
    """API key-based authentication manager."""

    def __init__(self):
        # In production, store these in a database
        self.api_keys = {
            "dev-key-123": {
                "name": "Development Key",
                "scopes": ["read", "write"],
                "rate_limit": 1000,
                "created_at": datetime.utcnow(),
                "last_used": None
            }
        }

    def validate_api_key(self, api_key: str) -> Optional[dict]:
        """Validate API key and return key info."""
        key_info = self.api_keys.get(api_key)
        if key_info:
            # Update last used timestamp
            key_info["last_used"] = datetime.utcnow()
            return key_info
        return None


class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Authentication middleware supporting multiple auth methods."""

    def __init__(self, app, jwt_manager: JWTAuthenticationManager = None, api_key_manager: APIKeyManager = None):
        super().__init__(app)
        self.jwt_manager = jwt_manager
        self.api_key_manager = api_key_manager or APIKeyManager()

        # Paths that don't require authentication
        self.public_paths = {
            "/health",
            "/metrics",
            "/docs",
            "/redoc",
            "/openapi.json",
            "/favicon.ico"
        }

    async def dispatch(self, request: Request, call_next):
        # Skip authentication for public paths
        if any(request.url.path.startswith(path) for path in self.public_paths):
            return await call_next(request)

        # Extract authentication credentials
        auth_header = request.headers.get("Authorization")
        api_key_header = request.headers.get("X-API-Key")

        user_info = None

        # Try JWT authentication first
        if auth_header and auth_header.startswith("Bearer ") and self.jwt_manager:
            token = auth_header[7:]  # Remove "Bearer " prefix
            try:
                payload = self.jwt_manager.verify_token(token)
                user_info = {
                    "user_id": payload.get("sub"),
                    "username": payload.get("username"),
                    "scopes": payload.get("scopes", []),
                    "auth_method": "jwt"
                }
            except HTTPException:
                pass  # Try other auth methods

        # Try API key authentication
        if not user_info and api_key_header:
            key_info = self.api_key_manager.validate_api_key(api_key_header)
            if key_info:
                user_info = {
                    "user_id": f"api_key_{api_key_header[:8]}",
                    "username": key_info["name"],
                    "scopes": key_info["scopes"],
                    "auth_method": "api_key",
                    "rate_limit": key_info["rate_limit"]
                }

        # If no valid authentication found
        if not user_info:
            return JSONResponse(
                status_code=401,
                content={
                    "success": False,
                    "error": "Authentication required",
                    "error_details": [{"message": "Valid API key or JWT token required"}],
                    "timestamp": datetime.utcnow().isoformat(),
                    "api_version": "v1"
                }
            )

        # Add user info to request state
        request.state.user = user_info

        return await call_next(request)


class RateLimitingMiddleware(BaseHTTPMiddleware):
    """Custom rate limiting middleware."""

    def __init__(self, app, default_rate_limit: int = 100):
        super().__init__(app)
        self.default_rate_limit = default_rate_limit
        self.request_counts = {}  # In production, use Redis

    async def dispatch(self, request: Request, call_next):
        # Get user identifier
        user_info = getattr(request.state, 'user', None)
        if user_info:
            user_id = user_info["user_id"]
            rate_limit = user_info.get("rate_limit", self.default_rate_limit)
        else:
            user_id = request.client.host if request.client else "anonymous"
            rate_limit = self.default_rate_limit

        # Current minute window
        current_minute = int(time.time() // 60)
        key = f"{user_id}:{current_minute}"

        # Increment request count
        current_count = self.request_counts.get(key, 0) + 1
        self.request_counts[key] = current_count

        # Clean up old entries (simple cleanup)
        if len(self.request_counts) > 10000:
            old_keys = [k for k in self.request_counts.keys()
                        if int(k.split(':')[1]) < current_minute - 5]
            for old_key in old_keys:
                del self.request_counts[old_key]

        # Check rate limit
        if current_count > rate_limit:
            return JSONResponse(
                status_code=429,
                content={
                    "success": False,
                    "error": "Rate limit exceeded",
                    "error_details": [{"message": f"Rate limit of {rate_limit} requests per minute exceeded"}],
                    "timestamp": datetime.utcnow().isoformat(),
                    "retry_after": 60 - (int(time.time()) % 60)
                },
                headers={
                    "Retry-After": str(60 - (int(time.time()) % 60)),
                    "X-RateLimit-Limit": str(rate_limit),
                    "X-RateLimit-Remaining": str(max(0, rate_limit - current_count)),
                    "X-RateLimit-Reset": str((current_minute + 1) * 60)
                }
            )

        # Add rate limit headers to response
        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(rate_limit)
        response.headers["X-RateLimit-Remaining"] = str(max(0, rate_limit - current_count))
        response.headers["X-RateLimit-Reset"] = str((current_minute + 1) * 60)

        return response


class CompressionMiddleware(BaseHTTPMiddleware):
    """Custom compression middleware with configurable settings."""

    def __init__(self, app, minimum_size: int = 1000, compression_level: int = 6):
        super().__init__(app)
        self.minimum_size = minimum_size
        self.compression_level = compression_level

    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)

        # Check if client accepts gzip
        accept_encoding = request.headers.get("accept-encoding", "")
        if "gzip" not in accept_encoding:
            return response

        # Check content type and size
        content_type = response.headers.get("content-type", "")
        if not any(ct in content_type for ct in ["application/json", "text/", "application/javascript"]):
            return response

        # Get response body
        body = b""
        async for chunk in response.body_iterator:
            body += chunk

        # Compress if body is large enough
        if len(body) >= self.minimum_size:
            import gzip
            compressed_body = gzip.compress(body, compresslevel=self.compression_level)

            # Create new response with compressed body
            from starlette.responses import Response
            return Response(
                content=compressed_body,
                status_code=response.status_code,
                headers={
                    **dict(response.headers),
                    "content-encoding": "gzip",
                    "content-length": str(len(compressed_body))
                }
            )

        # Return original response if not compressed
        from starlette.responses import Response
        return Response(
            content=body,
            status_code=response.status_code,
            headers=dict(response.headers)
        )
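A sketch of wiring these middleware classes and the JWT manager into a FastAPI app. The ordering shown and the demo secret are illustrative assumptions, not something this commit prescribes; Starlette applies middleware in reverse add order, so the logger added last runs outermost.

# Hypothetical wiring of hcfs.api.middleware into a FastAPI application.
from fastapi import FastAPI
from hcfs.api.middleware import (
    RequestLoggingMiddleware,
    ErrorHandlingMiddleware,
    SecurityHeadersMiddleware,
    AuthenticationMiddleware,
    RateLimitingMiddleware,
    JWTAuthenticationManager,
)

app = FastAPI()
jwt_manager = JWTAuthenticationManager(secret_key="change-me", token_expire_minutes=30)

# Added first = innermost; the request logger added last sees every request first.
app.add_middleware(RateLimitingMiddleware, default_rate_limit=100)
app.add_middleware(AuthenticationMiddleware, jwt_manager=jwt_manager)
app.add_middleware(SecurityHeadersMiddleware)
app.add_middleware(ErrorHandlingMiddleware)
app.add_middleware(RequestLoggingMiddleware)

# Issue a bearer token for a client using the manager defined above.
token = jwt_manager.create_access_token({"sub": "42", "username": "alice", "scopes": ["read"]})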
347
hcfs-python/hcfs/api/models.py
Normal file
@@ -0,0 +1,347 @@
"""
Enhanced API Models for HCFS Production API.

Comprehensive Pydantic models for request/response validation,
API versioning, and enterprise-grade data validation.
"""

from typing import List, Optional, Dict, Any, Union
from datetime import datetime
from enum import Enum
from pydantic import BaseModel, Field, validator, ConfigDict
import uuid


class APIVersion(str, Enum):
    """API version enumeration."""
    V1 = "v1"
    V2 = "v2"


class SearchType(str, Enum):
    """Search type enumeration."""
    SEMANTIC = "semantic"
    HYBRID = "hybrid"
    KEYWORD = "keyword"
    SIMILARITY = "similarity"


class SortOrder(str, Enum):
    """Sort order enumeration."""
    ASC = "asc"
    DESC = "desc"


class ContextStatus(str, Enum):
    """Context status enumeration."""
    ACTIVE = "active"
    ARCHIVED = "archived"
    DRAFT = "draft"
    DELETED = "deleted"


# Base Models

class BaseResponse(BaseModel):
    """Base response model with metadata."""
    model_config = ConfigDict(from_attributes=True)

    success: bool = True
    message: Optional[str] = None
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    request_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    api_version: APIVersion = APIVersion.V1


class PaginationParams(BaseModel):
    """Pagination parameters."""
    page: int = Field(default=1, ge=1, description="Page number (1-based)")
    page_size: int = Field(default=20, ge=1, le=100, description="Items per page")

    @property
    def offset(self) -> int:
        """Calculate offset from page and page_size."""
        return (self.page - 1) * self.page_size


class PaginationMeta(BaseModel):
    """Pagination metadata."""
    page: int
    page_size: int
    total_items: int
    total_pages: int
    has_next: bool
    has_previous: bool


# Context Models

class ContextBase(BaseModel):
    """Base context model with common fields."""
    path: str = Field(..., description="Hierarchical path for the context")
    content: str = Field(..., description="Main content of the context")
    summary: Optional[str] = Field(None, description="Brief summary of the content")
    author: Optional[str] = Field(None, description="Author or creator of the context")
    tags: Optional[List[str]] = Field(default_factory=list, description="Tags associated with the context")
    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Additional metadata")
    status: ContextStatus = Field(default=ContextStatus.ACTIVE, description="Context status")

    @validator('path')
    def validate_path(cls, v):
        """Validate path format."""
        if not v.startswith('/'):
            raise ValueError('Path must start with /')
        if '//' in v:
            raise ValueError('Path cannot contain double slashes')
        return v

    @validator('content')
    def validate_content(cls, v):
        """Validate content is not empty."""
        if not v.strip():
            raise ValueError('Content cannot be empty')
        return v.strip()


class ContextCreate(ContextBase):
    """Model for creating a new context."""
    pass


class ContextUpdate(BaseModel):
    """Model for updating an existing context."""
    content: Optional[str] = None
    summary: Optional[str] = None
    author: Optional[str] = None
    tags: Optional[List[str]] = None
    metadata: Optional[Dict[str, Any]] = None
    status: Optional[ContextStatus] = None

    @validator('content')
    def validate_content(cls, v):
        """Validate content if provided."""
        if v is not None and not v.strip():
            raise ValueError('Content cannot be empty')
        return v.strip() if v else v


class ContextResponse(ContextBase):
    """Model for context responses."""
    id: int = Field(..., description="Unique context identifier")
    created_at: datetime = Field(..., description="Creation timestamp")
    updated_at: datetime = Field(..., description="Last update timestamp")
    version: int = Field(..., description="Context version number")
    embedding_model: Optional[str] = Field(None, description="Embedding model used")
    similarity_score: Optional[float] = Field(None, description="Similarity score (for search results)")


class ContextListResponse(BaseResponse):
    """Response model for context list operations."""
    data: List[ContextResponse]
    pagination: PaginationMeta


class ContextDetailResponse(BaseResponse):
    """Response model for single context operations."""
    data: ContextResponse


# Search Models

class SearchRequest(BaseModel):
    """Model for search requests."""
    query: str = Field(..., description="Search query text")
    search_type: SearchType = Field(default=SearchType.SEMANTIC, description="Type of search to perform")
    path_prefix: Optional[str] = Field(None, description="Limit search to paths with this prefix")
    top_k: int = Field(default=10, ge=1, le=100, description="Maximum number of results to return")
    min_similarity: float = Field(default=0.0, ge=0.0, le=1.0, description="Minimum similarity threshold")
    semantic_weight: float = Field(default=0.7, ge=0.0, le=1.0, description="Weight for semantic vs keyword search")
    include_content: bool = Field(default=True, description="Whether to include full content in results")
    filters: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Additional search filters")

    @validator('query')
    def validate_query(cls, v):
        """Validate query is not empty."""
        if not v.strip():
            raise ValueError('Query cannot be empty')
        return v.strip()


class SearchResult(BaseModel):
    """Individual search result."""
    context: ContextResponse
    score: float = Field(..., description="Relevance score")
    highlight: Optional[Dict[str, List[str]]] = Field(None, description="Highlighted matching text")
    explanation: Optional[str] = Field(None, description="Explanation of why this result was returned")


class SearchResponse(BaseResponse):
    """Response model for search operations."""
    data: List[SearchResult]
    query: str
    search_type: SearchType
    total_results: int
    search_time_ms: float
    filters_applied: Dict[str, Any]


# Version Models

class VersionResponse(BaseModel):
    """Model for context version information."""
    version_id: int
    version_number: int
    context_id: int
    author: str
    message: Optional[str]
    created_at: datetime
    content_hash: str
    metadata: Optional[Dict[str, Any]] = None


class VersionListResponse(BaseResponse):
    """Response model for version history."""
    data: List[VersionResponse]
    context_id: int
    total_versions: int


class VersionCreateRequest(BaseModel):
    """Request model for creating a new version."""
    message: Optional[str] = Field(None, description="Version commit message")
    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Version metadata")


class RollbackRequest(BaseModel):
    """Request model for version rollback."""
    target_version: int = Field(..., description="Target version number to rollback to")
    message: Optional[str] = Field(None, description="Rollback commit message")


# Analytics Models

class ContextStats(BaseModel):
    """Context statistics model."""
    total_contexts: int
    contexts_by_status: Dict[ContextStatus, int]
    contexts_by_author: Dict[str, int]
    average_content_length: float
    most_active_paths: List[Dict[str, Union[str, int]]]
    recent_activity: List[Dict[str, Any]]


class SearchStats(BaseModel):
    """Search statistics model."""
    total_searches: int
    searches_by_type: Dict[SearchType, int]
    average_response_time_ms: float
    popular_queries: List[Dict[str, Union[str, int]]]
    search_success_rate: float


class SystemStats(BaseModel):
    """System statistics model."""
    uptime_seconds: float
    memory_usage_mb: float
    active_connections: int
    cache_hit_rate: float
    embedding_model_info: Dict[str, Any]
    database_size_mb: float


class StatsResponse(BaseResponse):
    """Response model for statistics."""
    context_stats: ContextStats
    search_stats: SearchStats
    system_stats: SystemStats


# Batch Operations Models

class BatchContextCreate(BaseModel):
    """Model for batch context creation."""
    contexts: List[ContextCreate] = Field(..., max_items=100, description="List of contexts to create")

    @validator('contexts')
    def validate_contexts_not_empty(cls, v):
        """Validate contexts list is not empty."""
        if not v:
            raise ValueError('Contexts list cannot be empty')
        return v


class BatchOperationResult(BaseModel):
    """Result of batch operation."""
    success_count: int
    error_count: int
    total_items: int
    errors: List[Dict[str, Any]] = Field(default_factory=list)
    created_ids: List[int] = Field(default_factory=list)


class BatchResponse(BaseResponse):
    """Response model for batch operations."""
    data: BatchOperationResult


# WebSocket Models

class WebSocketMessage(BaseModel):
    """WebSocket message model."""
    type: str = Field(..., description="Message type")
    data: Dict[str, Any] = Field(..., description="Message data")
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    message_id: str = Field(default_factory=lambda: str(uuid.uuid4()))


class SubscriptionRequest(BaseModel):
    """WebSocket subscription request."""
    path_prefix: str = Field(..., description="Path prefix to subscribe to")
    event_types: List[str] = Field(default_factory=lambda: ["created", "updated", "deleted"])
    filters: Optional[Dict[str, Any]] = Field(default_factory=dict)


# Health Check Models

class HealthStatus(str, Enum):
    """Health status enumeration."""
    HEALTHY = "healthy"
    DEGRADED = "degraded"
    UNHEALTHY = "unhealthy"


class ComponentHealth(BaseModel):
    """Individual component health."""
    name: str
    status: HealthStatus
    response_time_ms: Optional[float] = None
    error_message: Optional[str] = None
    last_check: datetime


class HealthResponse(BaseModel):
    """System health response."""
    status: HealthStatus
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    version: str
    uptime_seconds: float
    components: List[ComponentHealth]


# Error Models

class ErrorDetail(BaseModel):
    """Detailed error information."""
    field: Optional[str] = None
    message: str
    error_code: Optional[str] = None


class ErrorResponse(BaseModel):
    """Error response model."""
    success: bool = False
    error: str
    error_details: Optional[List[ErrorDetail]] = None
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    request_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    api_version: APIVersion = APIVersion.V1
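A brief illustration of how the request models above validate input. The error text and stripping behaviour come from the validators defined in this file; the surrounding script and its sample values are only a sketch.

# Hypothetical demonstration of the validators in hcfs.api.models.
from pydantic import ValidationError
from hcfs.api.models import ContextCreate, SearchRequest, SearchType

ctx = ContextCreate(path="/projects/hcfs", content="Phase 2 API design notes")
print(ctx.status)            # ContextStatus.ACTIVE (default)

try:
    ContextCreate(path="projects//hcfs", content="x")
except ValidationError as exc:
    print(exc)               # "Path must start with /" fires before the double-slash check

search = SearchRequest(query="  hybrid retrieval  ", search_type=SearchType.HYBRID, top_k=5)
print(search.query)          # "hybrid retrieval" -- the validator strips surrounding whitespace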
172
hcfs-python/hcfs/api/server.py
Normal file
@@ -0,0 +1,172 @@
"""
HCFS API Server - FastAPI-based REST API for context operations.
"""

from typing import List, Optional
from datetime import datetime
from pydantic import BaseModel

from fastapi import FastAPI, HTTPException, Depends
from fastapi.responses import JSONResponse

from ..core.context_db import ContextDatabase, Context
from ..core.embeddings import EmbeddingManager


# Pydantic models
class ContextCreateRequest(BaseModel):
    path: str
    content: str
    summary: Optional[str] = None
    author: Optional[str] = None


class ContextResponse(BaseModel):
    id: int
    path: str
    content: str
    summary: Optional[str]
    author: Optional[str]
    created_at: datetime
    updated_at: datetime
    version: int


class SearchRequest(BaseModel):
    query: str
    path_prefix: Optional[str] = None
    top_k: int = 5
    search_type: str = "hybrid"  # "semantic", "hybrid"


class SearchResult(BaseModel):
    context: ContextResponse
    score: float


class ContextAPI:
    """HCFS REST API server."""

    def __init__(self, context_db: ContextDatabase, embedding_manager: EmbeddingManager):
        self.context_db = context_db
        self.embedding_manager = embedding_manager
        self.app = FastAPI(
            title="HCFS Context API",
            description="Context-Aware Hierarchical Context File System API",
            version="0.1.0"
        )
        self._setup_routes()

    def _setup_routes(self):
        """Setup API routes."""

        @self.app.get("/health")
        async def health_check():
            """Health check endpoint."""
            return {"status": "healthy", "service": "hcfs-api"}

        @self.app.post("/context", response_model=ContextResponse)
        async def create_context(request: ContextCreateRequest):
            """Create a new context."""
            context = Context(
                id=None,
                path=request.path,
                content=request.content,
                summary=request.summary,
                author=request.author
            )

            # Store with embedding
            context_id = self.embedding_manager.store_context_with_embedding(context)

            # Retrieve the stored context
            stored_contexts = self.context_db.list_contexts_at_path(request.path)
            stored_context = next((c for c in stored_contexts if c.id == context_id), None)

            if not stored_context:
                raise HTTPException(status_code=500, detail="Failed to store context")

            return ContextResponse(**stored_context.__dict__)

        @self.app.get("/context/{path:path}", response_model=List[ContextResponse])
        async def get_context(path: str, depth: int = 1):
            """Get contexts for a path with optional parent inheritance."""
            contexts = self.context_db.get_context_by_path(f"/{path}", depth=depth)
            return [ContextResponse(**ctx.__dict__) for ctx in contexts]

        @self.app.get("/context", response_model=List[ContextResponse])
        async def list_contexts(path: str):
            """List all contexts at a specific path."""
            contexts = self.context_db.list_contexts_at_path(path)
            return [ContextResponse(**ctx.__dict__) for ctx in contexts]

        @self.app.put("/context/{context_id}")
        async def update_context(context_id: int, content: str, summary: Optional[str] = None):
            """Update an existing context."""
            success = self.context_db.update_context(context_id, content, summary)
            if not success:
                raise HTTPException(status_code=404, detail="Context not found")

            # Update embedding
            contexts = self.context_db.list_contexts_at_path("")  # Get updated context
            updated_context = next((c for c in contexts if c.id == context_id), None)
            if updated_context:
                embedding = self.embedding_manager.generate_embedding(updated_context.content)
                self.embedding_manager._store_embedding(context_id, embedding)

            return {"message": "Context updated successfully"}

        @self.app.delete("/context/{context_id}")
        async def delete_context(context_id: int):
            """Delete a context."""
            success = self.context_db.delete_context(context_id)
            if not success:
                raise HTTPException(status_code=404, detail="Context not found")
            return {"message": "Context deleted successfully"}

        @self.app.post("/search", response_model=List[SearchResult])
        async def search_contexts(request: SearchRequest):
            """Search contexts using semantic or hybrid search."""
            if request.search_type == "semantic":
                results = self.embedding_manager.semantic_search(
                    request.query,
                    request.path_prefix,
                    request.top_k
                )
            elif request.search_type == "hybrid":
                results = self.embedding_manager.hybrid_search(
                    request.query,
                    request.path_prefix,
                    request.top_k
                )
            else:
                raise HTTPException(status_code=400, detail="Invalid search_type")

            return [
                SearchResult(
                    context=ContextResponse(**ctx.__dict__),
                    score=score
                )
                for ctx, score in results
            ]

        @self.app.get("/similar/{context_id}", response_model=List[SearchResult])
        async def get_similar_contexts(context_id: int, top_k: int = 5):
            """Find contexts similar to a given context."""
            results = self.embedding_manager.get_similar_contexts(context_id, top_k)

            return [
                SearchResult(
                    context=ContextResponse(**ctx.__dict__),
                    score=score
                )
                for ctx, score in results
            ]


def create_app(db_path: str = "hcfs_context.db") -> FastAPI:
    """Create FastAPI application with HCFS components."""
    context_db = ContextDatabase(db_path)
    embedding_manager = EmbeddingManager(context_db)
    api = ContextAPI(context_db, embedding_manager)
    return api.app
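A sketch of serving the v1 API defined above. The entry-point filename, port, and example request are assumptions about deployment, not part of the commit; it also assumes the hcfs.core modules imported by create_app() are installed.

# Hypothetical entry point (e.g. run_api.py) using create_app() from above.
import uvicorn
from hcfs.api.server import create_app

app = create_app(db_path="hcfs_context.db")

if __name__ == "__main__":
    # Exposes POST /context, GET /context/{path}, POST /search, GET /similar/{id}, etc.
    uvicorn.run(app, host="127.0.0.1", port=8000)

A context could then be created with, for example: curl -X POST http://127.0.0.1:8000/context -H "Content-Type: application/json" -d '{"path": "/notes/phase2", "content": "hello"}'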
692
hcfs-python/hcfs/api/server_v2.py
Normal file
@@ -0,0 +1,692 @@
|
||||
"""
|
||||
Production-Grade HCFS API Server v2.0
|
||||
|
||||
Enterprise-ready FastAPI server with comprehensive features:
|
||||
- Full CRUD operations with validation
|
||||
- Advanced search capabilities
|
||||
- Version control and rollback
|
||||
- Batch operations
|
||||
- Real-time WebSocket updates
|
||||
- Authentication and authorization
|
||||
- Rate limiting and monitoring
|
||||
- OpenAPI documentation
|
||||
"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Depends, status, Request, Query, BackgroundTasks
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from fastapi.websocket import WebSocket, WebSocketDisconnect
|
||||
import uvicorn
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.util import get_remote_address
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from prometheus_client import Counter, Histogram, Gauge, generate_latest, CONTENT_TYPE_LATEST
|
||||
import structlog
|
||||
|
||||
# HCFS imports
|
||||
from .models import *
|
||||
from ..core.context_db_optimized_fixed import OptimizedContextDatabase
|
||||
from ..core.embeddings_optimized import OptimizedEmbeddingManager
|
||||
from ..core.context_versioning import VersioningSystem
|
||||
from ..core.context_db import Context
|
||||
|
||||
# Logging setup
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Metrics
|
||||
REQUEST_COUNT = Counter('hcfs_requests_total', 'Total HTTP requests', ['method', 'endpoint', 'status'])
|
||||
REQUEST_DURATION = Histogram('hcfs_request_duration_seconds', 'HTTP request duration')
|
||||
ACTIVE_CONNECTIONS = Gauge('hcfs_active_connections', 'Active WebSocket connections')
|
||||
CONTEXT_COUNT = Gauge('hcfs_contexts_total', 'Total number of contexts')
|
||||
SEARCH_COUNT = Counter('hcfs_searches_total', 'Total searches performed', ['search_type'])
|
||||
|
||||
# Rate limiting
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# Security
|
||||
security = HTTPBearer()
|
||||
|
||||
class HCFSAPIServer:
|
||||
"""Production HCFS API Server."""
|
||||
|
||||
def __init__(self,
|
||||
db_path: str = "hcfs_production.db",
|
||||
vector_db_path: str = "hcfs_vectors_production.db",
|
||||
enable_auth: bool = True,
|
||||
cors_origins: List[str] = None):
|
||||
|
||||
self.db_path = db_path
|
||||
self.vector_db_path = vector_db_path
|
||||
self.enable_auth = enable_auth
|
||||
self.cors_origins = cors_origins or ["http://localhost:3000", "http://localhost:8080"]
|
||||
|
||||
# Initialize core components
|
||||
self.context_db = None
|
||||
self.embedding_manager = None
|
||||
self.versioning_system = None
|
||||
|
||||
# WebSocket connections
|
||||
self.websocket_connections: Dict[str, WebSocket] = {}
|
||||
self.subscriptions: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# Create FastAPI app
|
||||
self.app = self._create_app()
|
||||
|
||||
async def startup(self):
|
||||
"""Initialize database connections and components."""
|
||||
logger.info("Starting HCFS API Server...")
|
||||
|
||||
# Initialize core components
|
||||
self.context_db = OptimizedContextDatabase(self.db_path, cache_size=1000)
|
||||
self.embedding_manager = OptimizedEmbeddingManager(
|
||||
self.context_db,
|
||||
model_name="mini",
|
||||
vector_db_path=self.vector_db_path,
|
||||
cache_size=2000,
|
||||
batch_size=32
|
||||
)
|
||||
self.versioning_system = VersioningSystem(self.db_path)
|
||||
|
||||
# Update metrics
|
||||
CONTEXT_COUNT.set(len(self.context_db.get_all_contexts()))
|
||||
|
||||
logger.info("HCFS API Server started successfully")
|
||||
|
||||
async def shutdown(self):
|
||||
"""Cleanup resources."""
|
||||
logger.info("Shutting down HCFS API Server...")
|
||||
|
||||
# Close WebSocket connections
|
||||
for connection in self.websocket_connections.values():
|
||||
await connection.close()
|
||||
|
||||
logger.info("HCFS API Server shutdown complete")
|
||||
|
||||
def _create_app(self) -> FastAPI:
|
||||
"""Create and configure FastAPI application."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
await self.startup()
|
||||
yield
|
||||
await self.shutdown()
|
||||
|
||||
app = FastAPI(
|
||||
title="HCFS API",
|
||||
description="Context-Aware Hierarchical Context File System API",
|
||||
version="2.0.0",
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
openapi_url="/openapi.json",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# Middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=self.cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
|
||||
# Rate limiting
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
|
||||
# Add routes
|
||||
self._add_routes(app)
|
||||
|
||||
# Add middleware for metrics
|
||||
@app.middleware("http")
|
||||
async def metrics_middleware(request: Request, call_next):
|
||||
start_time = time.time()
|
||||
response = await call_next(request)
|
||||
duration = time.time() - start_time
|
||||
|
||||
REQUEST_COUNT.labels(
|
||||
method=request.method,
|
||||
endpoint=request.url.path,
|
||||
status=response.status_code
|
||||
).inc()
|
||||
REQUEST_DURATION.observe(duration)
|
||||
|
||||
return response
|
||||
|
||||
return app
|
||||
|
||||
def _add_routes(self, app: FastAPI):
|
||||
"""Add all API routes."""
|
||||
|
||||
# Authentication dependency
|
||||
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
if self.enable_auth:
|
||||
# TODO: Implement actual authentication
|
||||
# For now, just validate token exists
|
||||
if not credentials.credentials:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return {"username": "api_user", "scopes": ["read", "write"]}
|
||||
return {"username": "anonymous", "scopes": ["read", "write"]}
|
||||
|
||||
# Health check
|
||||
@app.get("/health", response_model=HealthResponse, tags=["System"])
|
||||
async def health_check():
|
||||
"""System health check endpoint."""
|
||||
components = []
|
||||
|
||||
# Check database
|
||||
try:
|
||||
self.context_db.get_all_contexts()
|
||||
db_health = ComponentHealth(name="database", status=HealthStatus.HEALTHY, response_time_ms=1.0)
|
||||
except Exception as e:
|
||||
db_health = ComponentHealth(name="database", status=HealthStatus.UNHEALTHY, error_message=str(e))
|
||||
components.append(db_health)
|
||||
|
||||
# Check embedding manager
|
||||
try:
|
||||
stats = self.embedding_manager.get_statistics()
|
||||
emb_health = ComponentHealth(name="embeddings", status=HealthStatus.HEALTHY, response_time_ms=2.0)
|
||||
except Exception as e:
|
||||
emb_health = ComponentHealth(name="embeddings", status=HealthStatus.UNHEALTHY, error_message=str(e))
|
||||
components.append(emb_health)
|
||||
|
||||
# Overall status
|
||||
overall_status = HealthStatus.HEALTHY
|
||||
if any(c.status == HealthStatus.UNHEALTHY for c in components):
|
||||
overall_status = HealthStatus.UNHEALTHY
|
||||
elif any(c.status == HealthStatus.DEGRADED for c in components):
|
||||
overall_status = HealthStatus.DEGRADED
|
||||
|
||||
return HealthResponse(
|
||||
status=overall_status,
|
||||
version="2.0.0",
|
||||
uptime_seconds=time.time(), # Simplified uptime
|
||||
components=components
|
||||
)
|
||||
|
||||
# Metrics endpoint
|
||||
@app.get("/metrics", tags=["System"])
|
||||
async def metrics():
|
||||
"""Prometheus metrics endpoint."""
|
||||
return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)
|
||||
|
||||
# Context CRUD operations
|
||||
@app.post("/api/v1/contexts", response_model=ContextDetailResponse, tags=["Contexts"])
|
||||
@limiter.limit("100/minute")
|
||||
async def create_context(
|
||||
request: Request,
|
||||
context_data: ContextCreate,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""Create a new context with automatic embedding generation."""
|
||||
try:
|
||||
# Create context object
|
||||
context = Context(
|
||||
id=None,
|
||||
path=context_data.path,
|
||||
content=context_data.content,
|
||||
summary=context_data.summary,
|
||||
author=context_data.author or current_user["username"],
|
||||
version=1
|
||||
)
|
||||
|
||||
# Store context
|
||||
context_id = self.context_db.store_context(context)
|
||||
|
||||
# Generate and store embedding in background
|
||||
background_tasks.add_task(self._generate_embedding_async, context_id, context_data.content)
|
||||
|
||||
# Get created context
|
||||
created_context = self.context_db.get_context(context_id)
|
||||
context_response = self._context_to_response(created_context)
|
||||
|
||||
# Update metrics
|
||||
CONTEXT_COUNT.inc()
|
||||
|
||||
# Notify WebSocket subscribers
|
||||
await self._notify_websocket_subscribers("created", context_response)
|
||||
|
||||
return ContextDetailResponse(data=context_response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error creating context", error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create context: {str(e)}")
|
||||
|
||||
@app.get("/api/v1/contexts/{context_id}", response_model=ContextDetailResponse, tags=["Contexts"])
|
||||
@limiter.limit("200/minute")
|
||||
async def get_context(
|
||||
request: Request,
|
||||
context_id: int,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get a specific context by ID."""
|
||||
try:
|
||||
context = self.context_db.get_context(context_id)
|
||||
if not context:
|
||||
raise HTTPException(status_code=404, detail="Context not found")
|
||||
|
||||
context_response = self._context_to_response(context)
|
||||
return ContextDetailResponse(data=context_response)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error retrieving context", context_id=context_id, error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Failed to retrieve context: {str(e)}")
|
||||
|
||||
@app.get("/api/v1/contexts", response_model=ContextListResponse, tags=["Contexts"])
|
||||
@limiter.limit("100/minute")
|
||||
async def list_contexts(
|
||||
request: Request,
|
||||
pagination: PaginationParams = Depends(),
|
||||
path_prefix: Optional[str] = Query(None, description="Filter by path prefix"),
|
||||
author: Optional[str] = Query(None, description="Filter by author"),
|
||||
status: Optional[ContextStatus] = Query(None, description="Filter by status"),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""List contexts with filtering and pagination."""
|
||||
try:
|
||||
# Get contexts with filters
|
||||
contexts = self.context_db.get_contexts_filtered(
|
||||
path_prefix=path_prefix,
|
||||
author=author,
|
||||
status=status.value if status else None,
|
||||
limit=pagination.page_size,
|
||||
offset=pagination.offset
|
||||
)
|
||||
|
||||
# Get total count for pagination
|
||||
total_count = self.context_db.count_contexts(
|
||||
path_prefix=path_prefix,
|
||||
author=author,
|
||||
status=status.value if status else None
|
||||
)
|
||||
|
||||
# Convert to response models
|
||||
context_responses = [self._context_to_response(ctx) for ctx in contexts]
|
||||
|
||||
# Create pagination metadata
|
||||
pagination_meta = PaginationMeta(
|
||||
page=pagination.page,
|
||||
page_size=pagination.page_size,
|
||||
total_items=total_count,
|
||||
total_pages=(total_count + pagination.page_size - 1) // pagination.page_size,
|
||||
has_next=pagination.page * pagination.page_size < total_count,
|
||||
has_previous=pagination.page > 1
|
||||
)
|
||||
|
||||
return ContextListResponse(data=context_responses, pagination=pagination_meta)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error listing contexts", error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list contexts: {str(e)}")
|
||||
|
||||
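# Illustrative client-side sketch for the listing endpoint above. Query
# parameter names follow PaginationParams and the filters in the signature;
# the URL and token are placeholders.
import httpx

resp = httpx.get(
    "http://localhost:8000/api/v1/contexts",
    params={"page": 1, "page_size": 20, "path_prefix": "/projects", "author": "tony"},
    headers={"Authorization": "Bearer <token>"},
)
resp.raise_for_status()
body = resp.json()
for item in body["data"]:
    print(item["id"], item["path"])
print("total pages:", body["pagination"]["total_pages"])
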
@app.put("/api/v1/contexts/{context_id}", response_model=ContextDetailResponse, tags=["Contexts"])
|
||||
@limiter.limit("50/minute")
|
||||
async def update_context(
|
||||
request: Request,
|
||||
context_id: int,
|
||||
context_update: ContextUpdate,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""Update an existing context."""
|
||||
try:
|
||||
# Check if context exists
|
||||
existing_context = self.context_db.get_context(context_id)
|
||||
if not existing_context:
|
||||
raise HTTPException(status_code=404, detail="Context not found")
|
||||
|
||||
# Update context
|
||||
update_data = context_update.dict(exclude_unset=True)
|
||||
if update_data:
|
||||
self.context_db.update_context(context_id, **update_data)
|
||||
|
||||
# If content changed, regenerate embedding
|
||||
if 'content' in update_data:
|
||||
background_tasks.add_task(
|
||||
self._generate_embedding_async,
|
||||
context_id,
|
||||
update_data['content']
|
||||
)
|
||||
|
||||
# Get updated context
|
||||
updated_context = self.context_db.get_context(context_id)
|
||||
context_response = self._context_to_response(updated_context)
|
||||
|
||||
# Notify WebSocket subscribers
|
||||
await self._notify_websocket_subscribers("updated", context_response)
|
||||
|
||||
return ContextDetailResponse(data=context_response)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error updating context", context_id=context_id, error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update context: {str(e)}")
|
||||
|
||||
@app.delete("/api/v1/contexts/{context_id}", tags=["Contexts"])
|
||||
@limiter.limit("30/minute")
|
||||
async def delete_context(
|
||||
request: Request,
|
||||
context_id: int,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""Delete a context."""
|
||||
try:
|
||||
# Check if context exists
|
||||
existing_context = self.context_db.get_context(context_id)
|
||||
if not existing_context:
|
||||
raise HTTPException(status_code=404, detail="Context not found")
|
||||
|
||||
# Delete context
|
||||
success = self.context_db.delete_context(context_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete context")
|
||||
|
||||
# Update metrics
|
||||
CONTEXT_COUNT.dec()
|
||||
|
||||
# Notify WebSocket subscribers
|
||||
await self._notify_websocket_subscribers("deleted", {"id": context_id})
|
||||
|
||||
return {"success": True, "message": "Context deleted successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error deleting context", context_id=context_id, error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete context: {str(e)}")
|
||||
|
||||
# Search endpoints
@app.post("/api/v1/search", response_model=SearchResponse, tags=["Search"])
@limiter.limit("100/minute")
async def search_contexts(
request: Request,
search_request: SearchRequest,
current_user: dict = Depends(get_current_user)
):
"""Advanced context search with multiple search types."""
try:
start_time = time.time()

# Perform search based on type
if search_request.search_type == SearchType.SEMANTIC:
results = self.embedding_manager.semantic_search_optimized(
search_request.query,
path_prefix=search_request.path_prefix,
top_k=search_request.top_k,
include_contexts=True
)
elif search_request.search_type == SearchType.HYBRID:
results = self.embedding_manager.hybrid_search_optimized(
search_request.query,
path_prefix=search_request.path_prefix,
top_k=search_request.top_k,
semantic_weight=search_request.semantic_weight
)
else:
# Fallback to keyword search
contexts = self.context_db.search_contexts(search_request.query)
results = [type('Result', (), {'context': ctx, 'score': 1.0})() for ctx in contexts[:search_request.top_k]]

search_time = (time.time() - start_time) * 1000

# Convert results to response format
search_results = []
for result in results:
if hasattr(result, 'context') and result.context:
context_response = self._context_to_response(result.context)
context_response.similarity_score = getattr(result, 'score', None)

search_results.append(SearchResult(
context=context_response,
score=result.score,
explanation=f"Matched with {result.score:.3f} similarity"
))

# Update metrics
SEARCH_COUNT.labels(search_type=search_request.search_type.value).inc()

return SearchResponse(
data=search_results,
query=search_request.query,
search_type=search_request.search_type,
total_results=len(search_results),
search_time_ms=search_time,
filters_applied=search_request.filters
)

except Exception as e:
logger.error("Error performing search", query=search_request.query, error=str(e))
raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")

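# Illustrative client-side sketch for the search endpoint above. The payload
# fields mirror SearchRequest as it is used in the handler; the URL and token
# are placeholders.
import httpx

payload = {
    "query": "sqlite write-ahead logging",
    "search_type": "hybrid",      # "semantic", "hybrid", or keyword fallback
    "path_prefix": "/projects",
    "top_k": 5,
    "semantic_weight": 0.7,
}
resp = httpx.post(
    "http://localhost:8000/api/v1/search",
    json=payload,
    headers={"Authorization": "Bearer <token>"},
)
resp.raise_for_status()
for hit in resp.json()["data"]:
    print(f"{hit['score']:.3f}  {hit['context']['path']}")
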
# Batch operations
@app.post("/api/v1/contexts/batch", response_model=BatchResponse, tags=["Batch Operations"])
@limiter.limit("10/minute")
async def batch_create_contexts(
request: Request,
batch_request: BatchContextCreate,
background_tasks: BackgroundTasks,
current_user: dict = Depends(get_current_user)
):
"""Create multiple contexts in batch."""
try:
results = BatchOperationResult(
success_count=0,
error_count=0,
total_items=len(batch_request.contexts)
)

for i, context_data in enumerate(batch_request.contexts):
try:
context = Context(
id=None,
path=context_data.path,
content=context_data.content,
summary=context_data.summary,
author=context_data.author or current_user["username"],
version=1
)

context_id = self.context_db.store_context(context)
results.created_ids.append(context_id)
results.success_count += 1

# Generate embedding in background
background_tasks.add_task(
self._generate_embedding_async,
context_id,
context_data.content
)

except Exception as e:
results.error_count += 1
results.errors.append({
"index": i,
"path": context_data.path,
"error": str(e)
})

# Update metrics
CONTEXT_COUNT.inc(results.success_count)

return BatchResponse(data=results)

except Exception as e:
logger.error("Error in batch create", error=str(e))
raise HTTPException(status_code=500, detail=f"Batch operation failed: {str(e)}")

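# Illustrative client-side sketch for the batch endpoint above; the item fields
# mirror the per-item Context construction in the handler, and the URL and
# token are placeholders.
import httpx

batch = {
    "contexts": [
        {"path": "/projects/hcfs", "content": "Project-level notes."},
        {"path": "/projects/hcfs/core", "content": "Core module notes."},
    ]
}
resp = httpx.post(
    "http://localhost:8000/api/v1/contexts/batch",
    json=batch,
    headers={"Authorization": "Bearer <token>"},
)
resp.raise_for_status()
result = resp.json()["data"]
print(result["success_count"], "created,", result["error_count"], "failed")
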
# Statistics endpoint
@app.get("/api/v1/stats", response_model=StatsResponse, tags=["Analytics"])
@limiter.limit("30/minute")
async def get_statistics(
request: Request,
current_user: dict = Depends(get_current_user)
):
"""Get comprehensive system statistics."""
try:
# Get embedding manager stats
emb_stats = self.embedding_manager.get_statistics()

# Mock context stats (implement based on your needs)
context_stats = ContextStats(
total_contexts=emb_stats["database_stats"]["total_embeddings"],
contexts_by_status={ContextStatus.ACTIVE: emb_stats["database_stats"]["total_embeddings"]},
contexts_by_author={"system": emb_stats["database_stats"]["total_embeddings"]},
average_content_length=100.0,
most_active_paths=[],
recent_activity=[]
)

search_stats = SearchStats(
total_searches=100, # Mock data
searches_by_type={SearchType.SEMANTIC: 60, SearchType.HYBRID: 40},
average_response_time_ms=50.0,
popular_queries=[],
search_success_rate=0.95
)

system_stats = SystemStats(
uptime_seconds=time.time(),  # placeholder: epoch timestamp, not true uptime
memory_usage_mb=100.0,
active_connections=len(self.websocket_connections),
cache_hit_rate=emb_stats["cache_stats"].get("hit_rate", 0.0),
embedding_model_info=emb_stats["current_model"],
database_size_mb=10.0
)

return StatsResponse(
context_stats=context_stats,
search_stats=search_stats,
system_stats=system_stats
)

except Exception as e:
logger.error("Error getting statistics", error=str(e))
raise HTTPException(status_code=500, detail=f"Failed to get statistics: {str(e)}")

# WebSocket endpoint
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
"""WebSocket endpoint for real-time updates."""
await self._handle_websocket_connection(websocket)

def _context_to_response(self, context) -> ContextResponse:
"""Convert database context to API response model."""
return ContextResponse(
id=context.id,
path=context.path,
content=context.content,
summary=context.summary,
author=context.author or "unknown",
tags=[], # TODO: implement tags
metadata={}, # TODO: implement metadata
status=ContextStatus.ACTIVE, # TODO: implement status
created_at=context.created_at,
updated_at=context.updated_at,
version=context.version
)

async def _generate_embedding_async(self, context_id: int, content: str):
"""Generate and store embedding asynchronously."""
try:
embedding = self.embedding_manager.generate_embedding(content)
self.embedding_manager.store_embedding(context_id, embedding)
logger.info("Generated embedding for context", context_id=context_id)
except Exception as e:
logger.error("Failed to generate embedding", context_id=context_id, error=str(e))

async def _handle_websocket_connection(self, websocket: WebSocket):
"""Handle WebSocket connection and subscriptions."""
await websocket.accept()
connection_id = str(id(websocket))
self.websocket_connections[connection_id] = websocket
ACTIVE_CONNECTIONS.inc()

try:
while True:
# Wait for subscription requests
data = await websocket.receive_json()
message = WebSocketMessage(**data)

if message.type == "subscribe":
subscription = SubscriptionRequest(**message.data)
self.subscriptions[connection_id] = {
"path_prefix": subscription.path_prefix,
"event_types": subscription.event_types,
"filters": subscription.filters
}
await websocket.send_json({
"type": "subscription_confirmed",
"data": {"path_prefix": subscription.path_prefix}
})

except WebSocketDisconnect:
pass
finally:
# Cleanup
self.websocket_connections.pop(connection_id, None)
self.subscriptions.pop(connection_id, None)
ACTIVE_CONNECTIONS.dec()

async def _notify_websocket_subscribers(self, event_type: str, data: Any):
"""Notify WebSocket subscribers of events."""
if not self.websocket_connections:
return

# Create notification message
notification = WebSocketMessage(
type=event_type,
data=data.dict() if hasattr(data, 'dict') else data
)

# Send to all relevant subscribers
for connection_id, websocket in list(self.websocket_connections.items()):
try:
subscription = self.subscriptions.get(connection_id)
if subscription and event_type in subscription["event_types"]:
# Check path filter
if hasattr(data, 'path') and subscription["path_prefix"]:
if not data.path.startswith(subscription["path_prefix"]):
continue

await websocket.send_json(notification.dict())

except Exception as e:
logger.error("Error sending WebSocket notification",
connection_id=connection_id, error=str(e))
# Remove failed connection
self.websocket_connections.pop(connection_id, None)
self.subscriptions.pop(connection_id, None)

def run(self, host: str = "0.0.0.0", port: int = 8000, **kwargs):
"""Run the API server."""
uvicorn.run(self.app, host=host, port=port, **kwargs)


def create_app() -> FastAPI:
"""Factory function to create the app."""
server = HCFSAPIServer()
return server.app


if __name__ == "__main__":
server = HCFSAPIServer()
server.run()
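
# Illustrative subscriber for the /ws endpoint and the notification helper
# above, using the third-party "websockets" package (an assumption; any
# WebSocket client works). The subscription payload mirrors SubscriptionRequest.
import asyncio
import json
import websockets

async def watch(path_prefix: str = "/projects"):
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        await ws.send(json.dumps({
            "type": "subscribe",
            "data": {
                "path_prefix": path_prefix,
                "event_types": ["created", "updated", "deleted"],
                "filters": {},
            },
        }))
        async for raw in ws:
            event = json.loads(raw)
            print(event["type"], event.get("data"))

# e.g. asyncio.run(watch("/projects"))
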
164
hcfs-python/hcfs/cli.py
Normal file
164
hcfs-python/hcfs/cli.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
HCFS Command Line Interface
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import signal
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
import pyfuse3
|
||||
import uvicorn
|
||||
|
||||
from .core.context_db import ContextDatabase, Context
|
||||
from .core.filesystem import HCFSFilesystem
|
||||
from .core.embeddings import EmbeddingManager
|
||||
from .api.server import create_app
|
||||
|
||||
|
||||
@click.group()
|
||||
def main():
|
||||
"""HCFS - Context-Aware Hierarchical Context File System"""
|
||||
pass
|
||||
|
||||
|
||||
@main.command()
|
||||
@click.option("--mount-point", "-m", required=True, help="Mount point for HCFS")
|
||||
@click.option("--db-path", "-d", default="hcfs_context.db", help="Database path")
|
||||
@click.option("--foreground", "-f", is_flag=True, help="Run in foreground")
|
||||
def mount(mount_point: str, db_path: str, foreground: bool):
|
||||
"""Mount HCFS filesystem."""
|
||||
|
||||
async def run_filesystem():
|
||||
"""Run the FUSE filesystem."""
|
||||
context_db = ContextDatabase(db_path)
|
||||
fs = HCFSFilesystem(context_db, mount_point)
|
||||
|
||||
fuse_options = set(pyfuse3.default_options)
|
||||
fuse_options.add('fsname=hcfs')
|
||||
|
||||
if foreground:
|
||||
fuse_options.add('debug')
|
||||
|
||||
pyfuse3.init(fs, mount_point, fuse_options)
|
||||
|
||||
try:
|
||||
click.echo(f"HCFS mounted at {mount_point}")
|
||||
click.echo(f"Database: {db_path}")
|
||||
click.echo("Press Ctrl+C to unmount...")
|
||||
|
||||
await pyfuse3.main()
|
||||
except KeyboardInterrupt:
|
||||
click.echo("\\nUnmounting HCFS...")
|
||||
finally:
|
||||
pyfuse3.close(unmount=True)
|
||||
|
||||
try:
|
||||
asyncio.run(run_filesystem())
|
||||
except Exception as e:
|
||||
click.echo(f"Error: {e}", err=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@main.command()
|
||||
@click.option("--db-path", "-d", default="hcfs_context.db", help="Database path")
|
||||
@click.option("--host", default="127.0.0.1", help="API server host")
|
||||
@click.option("--port", default=8000, help="API server port")
|
||||
def serve(db_path: str, host: str, port: int):
|
||||
"""Start HCFS API server."""
|
||||
app = create_app(db_path)
|
||||
|
||||
click.echo(f"Starting HCFS API server on {host}:{port}")
|
||||
click.echo(f"Database: {db_path}")
|
||||
click.echo(f"API docs: http://{host}:{port}/docs")
|
||||
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
@main.command()
|
||||
@click.option("--db-path", "-d", default="hcfs_context.db", help="Database path")
|
||||
@click.argument("path")
|
||||
@click.argument("content")
|
||||
@click.option("--author", "-a", help="Context author")
|
||||
@click.option("--summary", "-s", help="Context summary")
|
||||
def push(db_path: str, path: str, content: str, author: Optional[str], summary: Optional[str]):
|
||||
"""Push context to a path."""
|
||||
context_db = ContextDatabase(db_path)
|
||||
embedding_manager = EmbeddingManager(context_db)
|
||||
|
||||
context = Context(
|
||||
id=None,
|
||||
path=path,
|
||||
content=content,
|
||||
summary=summary,
|
||||
author=author or "cli_user"
|
||||
)
|
||||
|
||||
context_id = embedding_manager.store_context_with_embedding(context)
|
||||
click.echo(f"Context stored with ID: {context_id}")
|
||||
|
||||
|
||||
@main.command()
|
||||
@click.option("--db-path", "-d", default="hcfs_context.db", help="Database path")
|
||||
@click.argument("path")
|
||||
@click.option("--depth", default=1, help="Inheritance depth")
|
||||
def get(db_path: str, path: str, depth: int):
|
||||
"""Get contexts for a path."""
|
||||
context_db = ContextDatabase(db_path)
|
||||
contexts = context_db.get_context_by_path(path, depth=depth)
|
||||
|
||||
if not contexts:
|
||||
click.echo("No contexts found for path")
|
||||
return
|
||||
|
||||
for ctx in contexts:
|
||||
click.echo(f"\\n--- Context ID: {ctx.id} ---")
|
||||
click.echo(f"Path: {ctx.path}")
|
||||
click.echo(f"Author: {ctx.author}")
|
||||
click.echo(f"Created: {ctx.created_at}")
|
||||
click.echo(f"Content: {ctx.content}")
|
||||
if ctx.summary:
|
||||
click.echo(f"Summary: {ctx.summary}")
|
||||
|
||||
|
||||
@main.command()
|
||||
@click.option("--db-path", "-d", default="hcfs_context.db", help="Database path")
|
||||
@click.argument("query")
|
||||
@click.option("--path-prefix", "-p", help="Path prefix filter")
|
||||
@click.option("--top-k", "-k", default=5, help="Number of results")
|
||||
@click.option("--search-type", "-t", default="hybrid",
|
||||
type=click.Choice(["semantic", "hybrid"]), help="Search type")
|
||||
def search(db_path: str, query: str, path_prefix: Optional[str], top_k: int, search_type: str):
|
||||
"""Search contexts."""
|
||||
context_db = ContextDatabase(db_path)
|
||||
embedding_manager = EmbeddingManager(context_db)
|
||||
|
||||
if search_type == "semantic":
|
||||
results = embedding_manager.semantic_search(query, path_prefix, top_k)
|
||||
else:
|
||||
results = embedding_manager.hybrid_search(query, path_prefix, top_k)
|
||||
|
||||
if not results:
|
||||
click.echo("No results found")
|
||||
return
|
||||
|
||||
click.echo(f"Found {len(results)} results:\\n")
|
||||
|
||||
for ctx, score in results:
|
||||
click.echo(f"Score: {score:.4f} | Path: {ctx.path} | ID: {ctx.id}")
|
||||
click.echo(f"Content: {ctx.content[:100]}...")
|
||||
click.echo()
|
||||
|
||||
|
||||
@main.command()
|
||||
@click.option("--db-path", "-d", default="hcfs_context.db", help="Database path")
|
||||
def init(db_path: str):
|
||||
"""Initialize HCFS database."""
|
||||
context_db = ContextDatabase(db_path)
|
||||
click.echo(f"HCFS database initialized at {db_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
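
# Illustrative invocation of the commands above via Click's test runner, which
# avoids installing a console script; the database path, context path, and
# author are placeholders. Note the push command loads the embedding model.
from click.testing import CliRunner

runner = CliRunner()
runner.invoke(main, ["init", "--db-path", "demo.db"])
runner.invoke(main, ["push", "--db-path", "demo.db", "/projects/hcfs",
                     "Initial project context", "--author", "tony"])
result = runner.invoke(main, ["get", "--db-path", "demo.db", "/projects/hcfs"])
print(result.output)
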
|
||||
1
hcfs-python/hcfs/core/__init__.py
Normal file
1
hcfs-python/hcfs/core/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Core HCFS components."""
|
||||
148
hcfs-python/hcfs/core/context_db.py
Normal file
148
hcfs-python/hcfs/core/context_db.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
Context Database - Storage and retrieval of context blobs.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import create_engine, Column, Integer, String, DateTime, Text, Float
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class ContextBlob(Base):
|
||||
"""Database model for context blobs."""
|
||||
|
||||
__tablename__ = "context_blobs"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
path = Column(String(512), nullable=False, index=True)
|
||||
content = Column(Text, nullable=False)
|
||||
summary = Column(Text)
|
||||
embedding_model = Column(String(100))
|
||||
embedding_vector = Column(Text) # JSON serialized vector
|
||||
author = Column(String(100))
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
version = Column(Integer, default=1)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
"""Context data structure."""
|
||||
id: Optional[int]
|
||||
path: str
|
||||
content: str
|
||||
summary: Optional[str] = None
|
||||
author: Optional[str] = None
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
version: int = 1
|
||||
|
||||
|
||||
class ContextDatabase:
|
||||
"""Main interface for context storage and retrieval."""
|
||||
|
||||
def __init__(self, db_path: str = "hcfs_context.db"):
|
||||
self.db_path = db_path
|
||||
self.engine = create_engine(f"sqlite:///{db_path}")
|
||||
Base.metadata.create_all(self.engine)
|
||||
self.SessionLocal = sessionmaker(bind=self.engine)
|
||||
|
||||
def get_session(self) -> Session:
|
||||
"""Get database session."""
|
||||
return self.SessionLocal()
|
||||
|
||||
def store_context(self, context: Context) -> int:
|
||||
"""Store a context blob and return its ID."""
|
||||
with self.get_session() as session:
|
||||
blob = ContextBlob(
|
||||
path=context.path,
|
||||
content=context.content,
|
||||
summary=context.summary,
|
||||
author=context.author,
|
||||
version=context.version
|
||||
)
|
||||
session.add(blob)
|
||||
session.commit()
|
||||
session.refresh(blob)
|
||||
return blob.id
|
||||
|
||||
def get_context_by_path(self, path: str, depth: int = 1) -> List[Context]:
|
||||
"""Retrieve contexts for a path and optionally parent paths."""
|
||||
contexts = []
|
||||
current_path = Path(path)
|
||||
|
||||
with self.get_session() as session:
|
||||
# Get contexts for current path and parents up to depth
|
||||
for i in range(depth + 1):
|
||||
search_path = str(current_path) if current_path != Path(".") else "/"
|
||||
|
||||
blobs = session.query(ContextBlob).filter(
|
||||
ContextBlob.path == search_path
|
||||
).order_by(ContextBlob.created_at.desc()).all()
|
||||
|
||||
for blob in blobs:
|
||||
contexts.append(Context(
|
||||
id=blob.id,
|
||||
path=blob.path,
|
||||
content=blob.content,
|
||||
summary=blob.summary,
|
||||
author=blob.author,
|
||||
created_at=blob.created_at,
|
||||
updated_at=blob.updated_at,
|
||||
version=blob.version
|
||||
))
|
||||
|
||||
if current_path.parent == current_path: # Root reached
|
||||
break
|
||||
current_path = current_path.parent
|
||||
|
||||
return contexts
|
||||
|
||||
def list_contexts_at_path(self, path: str) -> List[Context]:
|
||||
"""List all contexts at a specific path."""
|
||||
with self.get_session() as session:
|
||||
blobs = session.query(ContextBlob).filter(
|
||||
ContextBlob.path == path
|
||||
).order_by(ContextBlob.created_at.desc()).all()
|
||||
|
||||
return [Context(
|
||||
id=blob.id,
|
||||
path=blob.path,
|
||||
content=blob.content,
|
||||
summary=blob.summary,
|
||||
author=blob.author,
|
||||
created_at=blob.created_at,
|
||||
updated_at=blob.updated_at,
|
||||
version=blob.version
|
||||
) for blob in blobs]
|
||||
|
||||
def update_context(self, context_id: int, content: str, summary: str = None) -> bool:
|
||||
"""Update an existing context."""
|
||||
with self.get_session() as session:
|
||||
blob = session.query(ContextBlob).filter(ContextBlob.id == context_id).first()
|
||||
if blob:
|
||||
blob.content = content
|
||||
if summary:
|
||||
blob.summary = summary
|
||||
blob.version += 1
|
||||
blob.updated_at = datetime.utcnow()
|
||||
session.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
def delete_context(self, context_id: int) -> bool:
|
||||
"""Delete a context by ID."""
|
||||
with self.get_session() as session:
|
||||
blob = session.query(ContextBlob).filter(ContextBlob.id == context_id).first()
|
||||
if blob:
|
||||
session.delete(blob)
|
||||
session.commit()
|
||||
return True
|
||||
return False
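
# Small illustration of the depth/inheritance behaviour of get_context_by_path
# above: with depth=2, a lookup on "/projects/hcfs/api" also returns contexts
# stored at "/projects/hcfs" and "/projects". Paths and contents are made up.
db = ContextDatabase("example_context.db")
db.store_context(Context(id=None, path="/projects", content="Org-wide conventions"))
db.store_context(Context(id=None, path="/projects/hcfs", content="Project notes"))
db.store_context(Context(id=None, path="/projects/hcfs/api", content="API notes"))

for ctx in db.get_context_by_path("/projects/hcfs/api", depth=2):
    print(ctx.path, "->", ctx.content)
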
|
||||
188
hcfs-python/hcfs/core/embeddings.py
Normal file
188
hcfs-python/hcfs/core/embeddings.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
Embedding Manager - Generate and manage context embeddings.
|
||||
"""
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
||||
from .context_db import Context, ContextDatabase
|
||||
|
||||
|
||||
class EmbeddingManager:
|
||||
"""
|
||||
Manages embeddings for context blobs and semantic similarity search.
|
||||
"""
|
||||
|
||||
def __init__(self, context_db: ContextDatabase, model_name: str = "all-MiniLM-L6-v2"):
|
||||
self.context_db = context_db
|
||||
self.model_name = model_name
|
||||
self.model = SentenceTransformer(model_name)
|
||||
self.tfidf_vectorizer = TfidfVectorizer(stop_words='english', max_features=1000)
|
||||
self._tfidf_fitted = False
|
||||
|
||||
def generate_embedding(self, text: str) -> np.ndarray:
|
||||
"""Generate embedding for a text."""
|
||||
return self.model.encode(text, normalize_embeddings=True)
|
||||
|
||||
def store_context_with_embedding(self, context: Context) -> int:
|
||||
"""Store context and generate its embedding."""
|
||||
# Generate embedding
|
||||
embedding = self.generate_embedding(context.content)
|
||||
|
||||
# Store in database
|
||||
context_id = self.context_db.store_context(context)
|
||||
|
||||
# Update with embedding (you'd extend ContextBlob model for this)
|
||||
self._store_embedding(context_id, embedding)
|
||||
|
||||
return context_id
|
||||
|
||||
def _store_embedding(self, context_id: int, embedding: np.ndarray) -> None:
|
||||
"""Store embedding vector in database."""
|
||||
embedding_json = json.dumps(embedding.tolist())
|
||||
|
||||
with self.context_db.get_session() as session:
|
||||
from .context_db import ContextBlob
|
||||
blob = session.query(ContextBlob).filter(ContextBlob.id == context_id).first()
|
||||
if blob:
|
||||
blob.embedding_model = self.model_name
|
||||
blob.embedding_vector = embedding_json
|
||||
session.commit()
|
||||
|
||||
def semantic_search(self, query: str, path_prefix: str = None, top_k: int = 5) -> List[Tuple[Context, float]]:
|
||||
"""
|
||||
Perform semantic search for contexts similar to query.
|
||||
|
||||
Args:
|
||||
query: Search query text
|
||||
path_prefix: Optional path prefix to limit search scope
|
||||
top_k: Number of results to return
|
||||
|
||||
Returns:
|
||||
List of (Context, similarity_score) tuples
|
||||
"""
|
||||
query_embedding = self.generate_embedding(query)
|
||||
|
||||
with self.context_db.get_session() as session:
|
||||
from .context_db import ContextBlob
|
||||
|
||||
query_filter = session.query(ContextBlob).filter(
|
||||
ContextBlob.embedding_vector.isnot(None)
|
||||
)
|
||||
|
||||
if path_prefix:
|
||||
query_filter = query_filter.filter(ContextBlob.path.startswith(path_prefix))
|
||||
|
||||
blobs = query_filter.all()
|
||||
|
||||
if not blobs:
|
||||
return []
|
||||
|
||||
# Calculate similarities
|
||||
similarities = []
|
||||
for blob in blobs:
|
||||
if blob.embedding_vector:
|
||||
stored_embedding = np.array(json.loads(blob.embedding_vector))
|
||||
similarity = cosine_similarity(
|
||||
query_embedding.reshape(1, -1),
|
||||
stored_embedding.reshape(1, -1)
|
||||
)[0][0]
|
||||
|
||||
context = Context(
|
||||
id=blob.id,
|
||||
path=blob.path,
|
||||
content=blob.content,
|
||||
summary=blob.summary,
|
||||
author=blob.author,
|
||||
created_at=blob.created_at,
|
||||
updated_at=blob.updated_at,
|
||||
version=blob.version
|
||||
)
|
||||
|
||||
similarities.append((context, float(similarity)))
|
||||
|
||||
# Sort by similarity and return top_k
|
||||
similarities.sort(key=lambda x: x[1], reverse=True)
|
||||
return similarities[:top_k]
|
||||
|
||||
def hybrid_search(self, query: str, path_prefix: str = None, top_k: int = 5,
|
||||
semantic_weight: float = 0.7) -> List[Tuple[Context, float]]:
|
||||
"""
|
||||
Hybrid search combining semantic similarity and BM25.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
path_prefix: Optional path filter
|
||||
top_k: Number of results
|
||||
semantic_weight: Weight for semantic vs BM25 (0.0-1.0)
|
||||
"""
|
||||
# Get contexts for BM25
|
||||
with self.context_db.get_session() as session:
|
||||
from .context_db import ContextBlob
|
||||
|
||||
query_filter = session.query(ContextBlob)
|
||||
if path_prefix:
|
||||
query_filter = query_filter.filter(ContextBlob.path.startswith(path_prefix))
|
||||
|
||||
blobs = query_filter.all()
|
||||
|
||||
if not blobs:
|
||||
return []
|
||||
|
||||
# Prepare documents for BM25
|
||||
documents = [blob.content for blob in blobs]
|
||||
|
||||
# Fit TF-IDF if not already fitted or refitting needed
|
||||
if not self._tfidf_fitted or len(documents) > 100: # Refit periodically
|
||||
self.tfidf_vectorizer.fit(documents)
|
||||
self._tfidf_fitted = True
|
||||
|
||||
# BM25 scoring (using TF-IDF as approximation)
|
||||
doc_vectors = self.tfidf_vectorizer.transform(documents)
|
||||
query_vector = self.tfidf_vectorizer.transform([query])
|
||||
bm25_scores = cosine_similarity(query_vector, doc_vectors)[0]
|
||||
|
||||
# Semantic scoring
|
||||
semantic_results = self.semantic_search(query, path_prefix, len(blobs))
|
||||
semantic_scores = {ctx.id: score for ctx, score in semantic_results}
|
||||
|
||||
# Combine scores
|
||||
combined_results = []
|
||||
for i, blob in enumerate(blobs):
|
||||
bm25_score = bm25_scores[i]
|
||||
semantic_score = semantic_scores.get(blob.id, 0.0)
|
||||
|
||||
combined_score = (semantic_weight * semantic_score +
|
||||
(1 - semantic_weight) * bm25_score)
|
||||
|
||||
context = Context(
|
||||
id=blob.id,
|
||||
path=blob.path,
|
||||
content=blob.content,
|
||||
summary=blob.summary,
|
||||
author=blob.author,
|
||||
created_at=blob.created_at,
|
||||
updated_at=blob.updated_at,
|
||||
version=blob.version
|
||||
)
|
||||
|
||||
combined_results.append((context, combined_score))
|
||||
|
||||
# Sort and return top results
|
||||
combined_results.sort(key=lambda x: x[1], reverse=True)
|
||||
return combined_results[:top_k]
|
||||
|
||||
def get_similar_contexts(self, context_id: int, top_k: int = 5) -> List[Tuple[Context, float]]:
|
||||
"""Find contexts similar to a given context."""
|
||||
with self.context_db.get_session() as session:
|
||||
from .context_db import ContextBlob
|
||||
reference_blob = session.query(ContextBlob).filter(ContextBlob.id == context_id).first()
|
||||
|
||||
if not reference_blob or not reference_blob.content:
|
||||
return []
|
||||
|
||||
return self.semantic_search(reference_blob.content, top_k=top_k)
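
# Hedged usage sketch for the manager above: index two contexts, then run a
# hybrid query that blends semantic and TF-IDF scores with the given weight.
# Paths, contents, and the weight are examples only.
db = ContextDatabase("example_context.db")
manager = EmbeddingManager(db)
manager.store_context_with_embedding(
    Context(id=None, path="/ml", content="Sentence-transformer embeddings for search"))
manager.store_context_with_embedding(
    Context(id=None, path="/db", content="Tuning SQLite WAL mode and page cache"))

for ctx, score in manager.hybrid_search("vector search embeddings", top_k=2, semantic_weight=0.7):
    print(f"{score:.3f}  {ctx.path}")
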
|
||||
616
hcfs-python/hcfs/core/embeddings_optimized.py
Normal file
616
hcfs-python/hcfs/core/embeddings_optimized.py
Normal file
@@ -0,0 +1,616 @@
|
||||
"""
|
||||
Optimized Embedding Manager - High-performance vector operations and storage.
|
||||
|
||||
This module provides enhanced embedding capabilities including:
|
||||
- Vector database integration with SQLite-Vec
|
||||
- Optimized batch processing and caching
|
||||
- Multiple embedding model support
|
||||
- Efficient similarity search with indexing
|
||||
- Memory-efficient embedding storage
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import numpy as np
|
||||
import sqlite3
|
||||
from typing import List, Dict, Optional, Tuple, Union, Any
|
||||
from dataclasses import dataclass, asdict
|
||||
from pathlib import Path
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
|
||||
from .context_db import Context, ContextDatabase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class EmbeddingModel:
|
||||
"""Configuration for embedding models."""
|
||||
name: str
|
||||
model_path: str
|
||||
dimension: int
|
||||
max_tokens: int = 512
|
||||
normalize: bool = True
|
||||
|
||||
@dataclass
|
||||
class VectorSearchResult:
|
||||
"""Result from vector search operations."""
|
||||
context_id: int
|
||||
score: float
|
||||
context: Optional[Context] = None
|
||||
metadata: Dict[str, Any] = None
|
||||
|
||||
class VectorCache:
|
||||
"""High-performance LRU cache for embeddings."""
|
||||
|
||||
def __init__(self, max_size: int = 5000, ttl_seconds: int = 3600):
|
||||
self.max_size = max_size
|
||||
self.ttl_seconds = ttl_seconds
|
||||
self.cache: Dict[str, Tuple[np.ndarray, float]] = {}
|
||||
self.access_times: Dict[str, float] = {}
|
||||
self.lock = threading.RLock()
|
||||
|
||||
def get(self, key: str) -> Optional[np.ndarray]:
|
||||
"""Get embedding from cache."""
|
||||
with self.lock:
|
||||
current_time = time.time()
|
||||
|
||||
if key in self.cache:
|
||||
embedding, created_time = self.cache[key]
|
||||
|
||||
# Check TTL
|
||||
if current_time - created_time < self.ttl_seconds:
|
||||
self.access_times[key] = current_time
|
||||
return embedding.copy()
|
||||
else:
|
||||
# Expired
|
||||
del self.cache[key]
|
||||
del self.access_times[key]
|
||||
return None
|
||||
|
||||
def put(self, key: str, embedding: np.ndarray) -> None:
|
||||
"""Store embedding in cache."""
|
||||
with self.lock:
|
||||
current_time = time.time()
|
||||
|
||||
# Evict if cache is full
|
||||
if len(self.cache) >= self.max_size:
|
||||
self._evict_lru()
|
||||
|
||||
self.cache[key] = (embedding.copy(), current_time)
|
||||
self.access_times[key] = current_time
|
||||
|
||||
def _evict_lru(self) -> None:
|
||||
"""Evict least recently used item."""
|
||||
if not self.access_times:
|
||||
return
|
||||
|
||||
lru_key = min(self.access_times.items(), key=lambda x: x[1])[0]
|
||||
del self.cache[lru_key]
|
||||
del self.access_times[lru_key]
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear cache."""
|
||||
with self.lock:
|
||||
self.cache.clear()
|
||||
self.access_times.clear()
|
||||
|
||||
def stats(self) -> Dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
with self.lock:
|
||||
return {
|
||||
"size": len(self.cache),
|
||||
"max_size": self.max_size,
|
||||
"hit_rate": getattr(self, '_hits', 0) / max(getattr(self, '_requests', 1), 1),
|
||||
"ttl_seconds": self.ttl_seconds
|
||||
}
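
# Minimal illustration of the cache semantics above; keys and vectors are arbitrary.
import numpy as np

cache = VectorCache(max_size=2, ttl_seconds=60)
cache.put("a", np.ones(3, dtype=np.float32))
cache.put("b", np.zeros(3, dtype=np.float32))
cache.get("a")                                     # refreshes "a" as most recently used
cache.put("c", np.full(3, 2.0, dtype=np.float32))  # evicts "b", the LRU entry
print(sorted(cache.cache.keys()), cache.stats()["size"])
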
|
||||
|
||||
class OptimizedEmbeddingManager:
|
||||
"""
|
||||
High-performance embedding manager with vector database capabilities.
|
||||
"""
|
||||
|
||||
# Predefined embedding models
|
||||
MODELS = {
|
||||
"mini": EmbeddingModel("all-MiniLM-L6-v2", "all-MiniLM-L6-v2", 384),
|
||||
"base": EmbeddingModel("all-MiniLM-L12-v2", "all-MiniLM-L12-v2", 384),
|
||||
"large": EmbeddingModel("all-mpnet-base-v2", "all-mpnet-base-v2", 768),
|
||||
"multilingual": EmbeddingModel("paraphrase-multilingual-MiniLM-L12-v2",
|
||||
"paraphrase-multilingual-MiniLM-L12-v2", 384)
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
context_db: ContextDatabase,
|
||||
model_name: str = "mini",
|
||||
vector_db_path: Optional[str] = None,
|
||||
cache_size: int = 5000,
|
||||
batch_size: int = 32):
|
||||
self.context_db = context_db
|
||||
self.model_config = self.MODELS.get(model_name, self.MODELS["mini"])
|
||||
self.model = None # Lazy loading
|
||||
self.vector_cache = VectorCache(cache_size)
|
||||
self.batch_size = batch_size
|
||||
|
||||
# Vector database setup
|
||||
self.vector_db_path = vector_db_path or "hcfs_vectors.db"
|
||||
self._init_vector_db()
|
||||
|
||||
# TF-IDF for hybrid search
|
||||
self.tfidf_vectorizer = TfidfVectorizer(
|
||||
stop_words='english',
|
||||
max_features=5000,
|
||||
ngram_range=(1, 2),
|
||||
min_df=2
|
||||
)
|
||||
self._tfidf_fitted = False
|
||||
self._model_lock = threading.RLock()
|
||||
|
||||
logger.info(f"Initialized OptimizedEmbeddingManager with model: {self.model_config.name}")
|
||||
|
||||
def _get_model(self) -> SentenceTransformer:
|
||||
"""Lazy load the embedding model."""
|
||||
if self.model is None:
|
||||
with self._model_lock:
|
||||
if self.model is None:
|
||||
logger.info(f"Loading embedding model: {self.model_config.model_path}")
|
||||
self.model = SentenceTransformer(self.model_config.model_path)
|
||||
return self.model
|
||||
|
||||
def _init_vector_db(self) -> None:
|
||||
"""Initialize SQLite vector database for fast similarity search."""
|
||||
conn = sqlite3.connect(self.vector_db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Create vectors table
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS context_vectors (
|
||||
context_id INTEGER PRIMARY KEY,
|
||||
model_name TEXT NOT NULL,
|
||||
embedding_dimension INTEGER NOT NULL,
|
||||
vector_data BLOB NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
''')
|
||||
|
||||
# Create index for fast lookups
|
||||
cursor.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_context_vectors_model
|
||||
ON context_vectors(model_name, context_id)
|
||||
''')
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
logger.info(f"Vector database initialized: {self.vector_db_path}")
|
||||
|
||||
@contextmanager
|
||||
def _get_vector_db(self):
|
||||
"""Get vector database connection with proper cleanup."""
|
||||
conn = sqlite3.connect(self.vector_db_path)
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def generate_embedding(self, text: str, use_cache: bool = True) -> np.ndarray:
|
||||
"""Generate embedding for text with caching."""
|
||||
cache_key = f"{self.model_config.name}:{hash(text)}"
|
||||
|
||||
if use_cache:
|
||||
cached = self.vector_cache.get(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
model = self._get_model()
|
||||
embedding = model.encode(
|
||||
text,
|
||||
normalize_embeddings=self.model_config.normalize,
|
||||
show_progress_bar=False
|
||||
)
|
||||
|
||||
if use_cache:
|
||||
self.vector_cache.put(cache_key, embedding)
|
||||
|
||||
return embedding
|
||||
|
||||
def generate_embeddings_batch(self, texts: List[str], use_cache: bool = True) -> List[np.ndarray]:
|
||||
"""Generate embeddings for multiple texts efficiently."""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
# Check cache first
|
||||
cache_results = []
|
||||
uncached_indices = []
|
||||
uncached_texts = []
|
||||
|
||||
if use_cache:
|
||||
for i, text in enumerate(texts):
|
||||
cache_key = f"{self.model_config.name}:{hash(text)}"
|
||||
cached = self.vector_cache.get(cache_key)
|
||||
if cached is not None:
|
||||
cache_results.append((i, cached))
|
||||
else:
|
||||
uncached_indices.append(i)
|
||||
uncached_texts.append(text)
|
||||
else:
|
||||
uncached_indices = list(range(len(texts)))
|
||||
uncached_texts = texts
|
||||
|
||||
# Generate embeddings for uncached texts
|
||||
embeddings = [None] * len(texts)
|
||||
|
||||
# Place cached results
|
||||
for i, embedding in cache_results:
|
||||
embeddings[i] = embedding
|
||||
|
||||
if uncached_texts:
|
||||
model = self._get_model()
|
||||
|
||||
# Process in batches
|
||||
for batch_start in range(0, len(uncached_texts), self.batch_size):
|
||||
batch_end = min(batch_start + self.batch_size, len(uncached_texts))
|
||||
batch_texts = uncached_texts[batch_start:batch_end]
|
||||
batch_indices = uncached_indices[batch_start:batch_end]
|
||||
|
||||
batch_embeddings = model.encode(
|
||||
batch_texts,
|
||||
normalize_embeddings=self.model_config.normalize,
|
||||
show_progress_bar=False,
|
||||
batch_size=self.batch_size
|
||||
)
|
||||
|
||||
# Store results and cache
|
||||
for i, (orig_idx, embedding) in enumerate(zip(batch_indices, batch_embeddings)):
|
||||
embeddings[orig_idx] = embedding
|
||||
|
||||
if use_cache:
|
||||
cache_key = f"{self.model_config.name}:{hash(batch_texts[i])}"
|
||||
self.vector_cache.put(cache_key, embedding)
|
||||
|
||||
return embeddings
|
||||
|
||||
def store_embedding(self, context_id: int, embedding: np.ndarray) -> None:
|
||||
"""Store embedding in vector database."""
|
||||
with self._get_vector_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Convert to bytes for storage
|
||||
vector_bytes = embedding.astype(np.float32).tobytes()
|
||||
|
||||
cursor.execute('''
|
||||
INSERT OR REPLACE INTO context_vectors
|
||||
(context_id, model_name, embedding_dimension, vector_data, updated_at)
|
||||
VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
|
||||
''', (context_id, self.model_config.name, embedding.shape[0], vector_bytes))
|
||||
|
||||
conn.commit()
|
||||
|
||||
def store_embeddings_batch(self, context_embeddings: List[Tuple[int, np.ndarray]]) -> None:
|
||||
"""Store multiple embeddings efficiently."""
|
||||
if not context_embeddings:
|
||||
return
|
||||
|
||||
with self._get_vector_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
data = [
|
||||
(context_id, self.model_config.name, embedding.shape[0],
|
||||
embedding.astype(np.float32).tobytes())
|
||||
for context_id, embedding in context_embeddings
|
||||
]
|
||||
|
||||
cursor.executemany('''
|
||||
INSERT OR REPLACE INTO context_vectors
|
||||
(context_id, model_name, embedding_dimension, vector_data, updated_at)
|
||||
VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
|
||||
''', data)
|
||||
|
||||
conn.commit()
|
||||
|
||||
logger.info(f"Stored {len(context_embeddings)} embeddings in batch")
|
||||
|
||||
def get_embedding(self, context_id: int) -> Optional[np.ndarray]:
|
||||
"""Retrieve embedding for a context."""
|
||||
with self._get_vector_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
SELECT vector_data, embedding_dimension FROM context_vectors
|
||||
WHERE context_id = ? AND model_name = ?
|
||||
''', (context_id, self.model_config.name))
|
||||
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
vector_bytes, dimension = result
|
||||
return np.frombuffer(vector_bytes, dtype=np.float32).reshape(dimension)
|
||||
|
||||
return None
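
# The round-trip used by store_embedding/get_embedding above: vectors are
# persisted as raw float32 bytes and rebuilt with np.frombuffer. Standalone check.
import numpy as np

vec = np.random.rand(384).astype(np.float32)
blob = vec.tobytes()
restored = np.frombuffer(blob, dtype=np.float32).reshape(384)
assert np.array_equal(vec, restored)
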
|
||||
|
||||
def vector_similarity_search(self,
|
||||
query_embedding: np.ndarray,
|
||||
context_ids: Optional[List[int]] = None,
|
||||
top_k: int = 10,
|
||||
min_similarity: float = 0.0) -> List[VectorSearchResult]:
|
||||
"""Efficient vector similarity search."""
|
||||
with self._get_vector_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Build query
|
||||
if context_ids:
|
||||
placeholders = ','.join(['?'] * len(context_ids))
|
||||
query = f'''
|
||||
SELECT context_id, vector_data, embedding_dimension
|
||||
FROM context_vectors
|
||||
WHERE model_name = ? AND context_id IN ({placeholders})
|
||||
'''
|
||||
params = [self.model_config.name] + context_ids
|
||||
else:
|
||||
query = '''
|
||||
SELECT context_id, vector_data, embedding_dimension
|
||||
FROM context_vectors
|
||||
WHERE model_name = ?
|
||||
'''
|
||||
params = [self.model_config.name]
|
||||
|
||||
cursor.execute(query, params)
|
||||
results = cursor.fetchall()
|
||||
|
||||
if not results:
|
||||
return []
|
||||
|
||||
# Calculate similarities
|
||||
similarities = []
|
||||
query_embedding = query_embedding.reshape(1, -1)
|
||||
|
||||
for context_id, vector_bytes, dimension in results:
|
||||
stored_embedding = np.frombuffer(vector_bytes, dtype=np.float32).reshape(1, dimension)
|
||||
|
||||
similarity = cosine_similarity(query_embedding, stored_embedding)[0][0]
|
||||
|
||||
if similarity >= min_similarity:
|
||||
similarities.append(VectorSearchResult(
|
||||
context_id=context_id,
|
||||
score=float(similarity)
|
||||
))
|
||||
|
||||
# Sort by similarity and return top_k
|
||||
similarities.sort(key=lambda x: x.score, reverse=True)
|
||||
return similarities[:top_k]
|
||||
|
||||
def semantic_search_optimized(self,
|
||||
query: str,
|
||||
path_prefix: str = None,
|
||||
top_k: int = 5,
|
||||
include_contexts: bool = True) -> List[VectorSearchResult]:
|
||||
"""High-performance semantic search."""
|
||||
# Generate query embedding
|
||||
query_embedding = self.generate_embedding(query)
|
||||
|
||||
# Get relevant context IDs based on path filter
|
||||
context_ids = None
|
||||
if path_prefix:
|
||||
with self.context_db.get_session() as session:
|
||||
from .context_db import ContextBlob
|
||||
blobs = session.query(ContextBlob.id).filter(
|
||||
ContextBlob.path.startswith(path_prefix)
|
||||
).all()
|
||||
context_ids = [blob.id for blob in blobs]
|
||||
|
||||
if not context_ids:
|
||||
return []
|
||||
|
||||
# Perform vector search
|
||||
results = self.vector_similarity_search(
|
||||
query_embedding,
|
||||
context_ids=context_ids,
|
||||
top_k=top_k
|
||||
)
|
||||
|
||||
# Populate with context data if requested
|
||||
if include_contexts and results:
|
||||
context_map = {}
|
||||
with self.context_db.get_session() as session:
|
||||
from .context_db import ContextBlob
|
||||
|
||||
result_ids = [r.context_id for r in results]
|
||||
blobs = session.query(ContextBlob).filter(
|
||||
ContextBlob.id.in_(result_ids)
|
||||
).all()
|
||||
|
||||
for blob in blobs:
|
||||
context_map[blob.id] = Context(
|
||||
id=blob.id,
|
||||
path=blob.path,
|
||||
content=blob.content,
|
||||
summary=blob.summary,
|
||||
author=blob.author,
|
||||
created_at=blob.created_at,
|
||||
updated_at=blob.updated_at,
|
||||
version=blob.version
|
||||
)
|
||||
|
||||
# Add contexts to results
|
||||
for result in results:
|
||||
result.context = context_map.get(result.context_id)
|
||||
|
||||
return results
|
||||
|
||||
def hybrid_search_optimized(self,
|
||||
query: str,
|
||||
path_prefix: str = None,
|
||||
top_k: int = 5,
|
||||
semantic_weight: float = 0.7,
|
||||
rerank_top_n: int = 50) -> List[VectorSearchResult]:
|
||||
"""Optimized hybrid search with two-stage ranking."""
|
||||
|
||||
# Stage 1: Fast semantic search to get candidate set
|
||||
semantic_results = self.semantic_search_optimized(
|
||||
query, path_prefix, rerank_top_n, include_contexts=True
|
||||
)
|
||||
|
||||
if not semantic_results or len(semantic_results) < 2:
|
||||
return semantic_results[:top_k]
|
||||
|
||||
# Stage 2: Re-rank with BM25 scores
|
||||
contexts = [r.context for r in semantic_results if r.context]
|
||||
if not contexts:
|
||||
return semantic_results[:top_k]
|
||||
|
||||
documents = [ctx.content for ctx in contexts]
|
||||
|
||||
# Compute BM25 scores
|
||||
try:
|
||||
if not self._tfidf_fitted:
|
||||
self.tfidf_vectorizer.fit(documents)
|
||||
self._tfidf_fitted = True
|
||||
|
||||
doc_vectors = self.tfidf_vectorizer.transform(documents)
|
||||
query_vector = self.tfidf_vectorizer.transform([query])
|
||||
bm25_scores = cosine_similarity(query_vector, doc_vectors)[0]
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"BM25 scoring failed: {e}, using semantic only")
|
||||
return semantic_results[:top_k]
|
||||
|
||||
# Combine scores
|
||||
for i, result in enumerate(semantic_results[:len(bm25_scores)]):
|
||||
semantic_score = result.score
|
||||
bm25_score = bm25_scores[i]
|
||||
|
||||
combined_score = (semantic_weight * semantic_score +
|
||||
(1 - semantic_weight) * bm25_score)
|
||||
|
||||
result.score = float(combined_score)
|
||||
result.metadata = {
|
||||
"semantic_score": float(semantic_score),
|
||||
"bm25_score": float(bm25_score),
|
||||
"semantic_weight": semantic_weight
|
||||
}
|
||||
|
||||
# Re-sort by combined score
|
||||
semantic_results.sort(key=lambda x: x.score, reverse=True)
|
||||
return semantic_results[:top_k]
|
||||
|
||||
def build_embeddings_index(self, batch_size: int = 100) -> Dict[str, Any]:
|
||||
"""Build embeddings for all contexts without embeddings."""
|
||||
start_time = time.time()
|
||||
|
||||
# Get contexts without embeddings
|
||||
with self.context_db.get_session() as session:
|
||||
from .context_db import ContextBlob
|
||||
|
||||
# Find contexts missing embeddings
|
||||
with self._get_vector_db() as vector_conn:
|
||||
vector_cursor = vector_conn.cursor()
|
||||
vector_cursor.execute('''
|
||||
SELECT context_id FROM context_vectors
|
||||
WHERE model_name = ?
|
||||
''', (self.model_config.name,))
|
||||
|
||||
existing_ids = {row[0] for row in vector_cursor.fetchall()}
|
||||
|
||||
# Get contexts that need embeddings
|
||||
all_blobs = session.query(ContextBlob).all()
|
||||
missing_blobs = [blob for blob in all_blobs if blob.id not in existing_ids]
|
||||
|
||||
if not missing_blobs:
|
||||
return {
|
||||
"total_processed": 0,
|
||||
"processing_time": 0,
|
||||
"embeddings_per_second": 0,
|
||||
"message": "All contexts already have embeddings"
|
||||
}
|
||||
|
||||
logger.info(f"Building embeddings for {len(missing_blobs)} contexts")
|
||||
|
||||
# Process in batches
|
||||
total_processed = 0
|
||||
for batch_start in range(0, len(missing_blobs), batch_size):
|
||||
batch_end = min(batch_start + batch_size, len(missing_blobs))
|
||||
batch_blobs = missing_blobs[batch_start:batch_end]
|
||||
|
||||
# Generate embeddings for batch
|
||||
texts = [blob.content for blob in batch_blobs]
|
||||
embeddings = self.generate_embeddings_batch(texts, use_cache=False)
|
||||
|
||||
# Store embeddings
|
||||
context_embeddings = [
|
||||
(blob.id, embedding)
|
||||
for blob, embedding in zip(batch_blobs, embeddings)
|
||||
]
|
||||
self.store_embeddings_batch(context_embeddings)
|
||||
|
||||
total_processed += len(batch_blobs)
|
||||
logger.info(f"Processed {total_processed}/{len(missing_blobs)} contexts")
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
embeddings_per_second = total_processed / processing_time if processing_time > 0 else 0
|
||||
|
||||
return {
|
||||
"total_processed": total_processed,
|
||||
"processing_time": processing_time,
|
||||
"embeddings_per_second": embeddings_per_second,
|
||||
"model_used": self.model_config.name,
|
||||
"embedding_dimension": self.model_config.dimension
|
||||
}
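
# Hedged end-to-end sketch for the indexing and search paths above; database
# paths, the query string, and the path prefix are placeholders.
db = ContextDatabase("hcfs_context.db")
manager = OptimizedEmbeddingManager(db, model_name="mini", vector_db_path="hcfs_vectors.db")

stats = manager.build_embeddings_index(batch_size=100)
print(stats.get("total_processed", 0), "embeddings built")

for hit in manager.semantic_search_optimized("FUSE mount troubleshooting",
                                             path_prefix="/projects", top_k=3):
    label = hit.context.path if hit.context else hit.context_id
    print(f"{hit.score:.3f}  {label}")
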
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""Get embedding manager statistics."""
|
||||
with self._get_vector_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
SELECT
|
||||
COUNT(*) as total_embeddings,
|
||||
COUNT(DISTINCT model_name) as unique_models,
|
||||
AVG(embedding_dimension) as avg_dimension
|
||||
FROM context_vectors
|
||||
''')
|
||||
|
||||
db_stats = cursor.fetchone()
|
||||
|
||||
cursor.execute('''
|
||||
SELECT model_name, COUNT(*) as count
|
||||
FROM context_vectors
|
||||
GROUP BY model_name
|
||||
''')
|
||||
|
||||
model_counts = dict(cursor.fetchall())
|
||||
|
||||
return {
|
||||
"database_stats": {
|
||||
"total_embeddings": db_stats[0] if db_stats else 0,
|
||||
"unique_models": db_stats[1] if db_stats else 0,
|
||||
"average_dimension": db_stats[2] if db_stats else 0,
|
||||
"model_counts": model_counts
|
||||
},
|
||||
"cache_stats": self.vector_cache.stats(),
|
||||
"current_model": asdict(self.model_config),
|
||||
"vector_db_path": self.vector_db_path,
|
||||
"batch_size": self.batch_size
|
||||
}
|
||||
|
||||
def cleanup_old_embeddings(self, days_old: int = 30) -> int:
|
||||
"""Remove old unused embeddings."""
|
||||
with self._get_vector_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
DELETE FROM context_vectors
|
||||
WHERE updated_at < datetime('now', '-{} days')
|
||||
AND context_id NOT IN (
|
||||
SELECT id FROM context_blobs
|
||||
)
|
||||
'''.format(days_old))
|
||||
|
||||
deleted_count = cursor.rowcount
|
||||
conn.commit()
|
||||
|
||||
logger.info(f"Cleaned up {deleted_count} old embeddings")
|
||||
return deleted_count
|
||||
136
hcfs-python/hcfs/core/embeddings_trio.py
Normal file
136
hcfs-python/hcfs/core/embeddings_trio.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Trio-compatible wrapper for OptimizedEmbeddingManager.
|
||||
|
||||
This module provides async compatibility for the optimized embedding system
|
||||
to work with FUSE filesystem operations that require Trio async context.
|
||||
"""
|
||||
|
||||
import trio
|
||||
from typing import List, Dict, Optional, Tuple, Any
|
||||
|
||||
from .embeddings_optimized import OptimizedEmbeddingManager, VectorSearchResult
|
||||
from .context_db import Context
|
||||
|
||||
|
||||
class TrioOptimizedEmbeddingManager:
|
||||
"""
|
||||
Trio-compatible async wrapper for OptimizedEmbeddingManager.
|
||||
"""
|
||||
|
||||
def __init__(self, sync_embedding_manager: OptimizedEmbeddingManager):
|
||||
self.sync_manager = sync_embedding_manager
|
||||
|
||||
async def generate_embedding(self, text: str, use_cache: bool = True) -> 'np.ndarray':
|
||||
"""Generate embedding asynchronously."""
|
||||
return await trio.to_thread.run_sync(
|
||||
self.sync_manager.generate_embedding,
|
||||
text,
|
||||
use_cache
|
||||
)
|
||||
|
||||
async def generate_embeddings_batch(self, texts: List[str], use_cache: bool = True) -> List['np.ndarray']:
|
||||
"""Generate embeddings for multiple texts asynchronously."""
|
||||
return await trio.to_thread.run_sync(
|
||||
self.sync_manager.generate_embeddings_batch,
|
||||
texts,
|
||||
use_cache
|
||||
)
|
||||
|
||||
async def store_embedding(self, context_id: int, embedding: 'np.ndarray') -> None:
|
||||
"""Store embedding asynchronously."""
|
||||
await trio.to_thread.run_sync(
|
||||
self.sync_manager.store_embedding,
|
||||
context_id,
|
||||
embedding
|
||||
)
|
||||
|
||||
async def store_embeddings_batch(self, context_embeddings: List[Tuple[int, 'np.ndarray']]) -> None:
|
||||
"""Store multiple embeddings asynchronously."""
|
||||
await trio.to_thread.run_sync(
|
||||
self.sync_manager.store_embeddings_batch,
|
||||
context_embeddings
|
||||
)
|
||||
|
||||
async def get_embedding(self, context_id: int) -> Optional['np.ndarray']:
|
||||
"""Retrieve embedding asynchronously."""
|
||||
return await trio.to_thread.run_sync(
|
||||
self.sync_manager.get_embedding,
|
||||
context_id
|
||||
)
|
||||
|
||||
async def semantic_search_optimized(self,
|
||||
query: str,
|
||||
path_prefix: str = None,
|
||||
top_k: int = 5,
|
||||
include_contexts: bool = True) -> List[VectorSearchResult]:
|
||||
"""Perform semantic search asynchronously."""
|
||||
return await trio.to_thread.run_sync(
|
||||
self.sync_manager.semantic_search_optimized,
|
||||
query,
|
||||
path_prefix,
|
||||
top_k,
|
||||
include_contexts
|
||||
)
|
||||
|
||||
async def hybrid_search_optimized(self,
|
||||
query: str,
|
||||
path_prefix: str = None,
|
||||
top_k: int = 5,
|
||||
semantic_weight: float = 0.7,
|
||||
rerank_top_n: int = 50) -> List[VectorSearchResult]:
|
||||
"""Perform hybrid search asynchronously."""
|
||||
return await trio.to_thread.run_sync(
|
||||
self.sync_manager.hybrid_search_optimized,
|
||||
query,
|
||||
path_prefix,
|
||||
top_k,
|
||||
semantic_weight,
|
||||
rerank_top_n
|
||||
)
|
||||
|
||||
async def vector_similarity_search(self,
|
||||
query_embedding: 'np.ndarray',
|
||||
context_ids: Optional[List[int]] = None,
|
||||
top_k: int = 10,
|
||||
min_similarity: float = 0.0) -> List[VectorSearchResult]:
|
||||
"""Perform vector similarity search asynchronously."""
|
||||
return await trio.to_thread.run_sync(
|
||||
self.sync_manager.vector_similarity_search,
|
||||
query_embedding,
|
||||
context_ids,
|
||||
top_k,
|
||||
min_similarity
|
||||
)
|
||||
|
||||
async def build_embeddings_index(self, batch_size: int = 100) -> Dict[str, Any]:
|
||||
"""Build embeddings index asynchronously."""
|
||||
return await trio.to_thread.run_sync(
|
||||
self.sync_manager.build_embeddings_index,
|
||||
batch_size
|
||||
)
|
||||
|
||||
async def get_statistics(self) -> Dict[str, Any]:
|
||||
"""Get statistics asynchronously."""
|
||||
return await trio.to_thread.run_sync(
|
||||
self.sync_manager.get_statistics
|
||||
)
|
||||
|
||||
async def cleanup_old_embeddings(self, days_old: int = 30) -> int:
|
||||
"""Clean up old embeddings asynchronously."""
|
||||
return await trio.to_thread.run_sync(
|
||||
self.sync_manager.cleanup_old_embeddings,
|
||||
days_old
|
||||
)
|
||||
|
||||
# Synchronous access to underlying manager properties
|
||||
@property
|
||||
def model_config(self):
|
||||
return self.sync_manager.model_config
|
||||
|
||||
@property
|
||||
def vector_cache(self):
|
||||
return self.sync_manager.vector_cache
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return self.sync_manager.batch_size
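
# Hedged usage from a Trio context (for example inside the FUSE handlers this
# wrapper is meant for). The absolute imports and database path are assumptions
# for a standalone script; the query string is a placeholder.
import trio

from hcfs.core.context_db import ContextDatabase
from hcfs.core.embeddings_optimized import OptimizedEmbeddingManager

async def demo():
    db = ContextDatabase("hcfs_context.db")
    async_manager = TrioOptimizedEmbeddingManager(OptimizedEmbeddingManager(db))
    hits = await async_manager.semantic_search_optimized("context inheritance", top_k=3)
    for hit in hits:
        print(hit.context_id, round(hit.score, 3))

# e.g. trio.run(demo)
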
|
||||
179
hcfs-python/hcfs/core/filesystem.py
Normal file
179
hcfs-python/hcfs/core/filesystem.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
HCFS Filesystem - FUSE-based virtual filesystem layer.
|
||||
"""
|
||||
|
||||
import os
|
||||
import stat
|
||||
import errno
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
from pathlib import Path
|
||||
|
||||
import pyfuse3
|
||||
from pyfuse3 import FUSEError
|
||||
|
||||
from .context_db import ContextDatabase, Context
|
||||
|
||||
|
||||
class HCFSFilesystem(pyfuse3.Operations):
|
||||
"""
|
||||
HCFS FUSE filesystem implementation.
|
||||
|
||||
Maps directory navigation to context scope and provides
|
||||
virtual files for context access.
|
||||
"""
|
||||
|
||||
def __init__(self, context_db: ContextDatabase, mount_point: str):
|
||||
super().__init__()
|
||||
self.context_db = context_db
|
||||
self.mount_point = mount_point
|
||||
self._inode_counter = 1
|
||||
self._inode_to_path: Dict[int, str] = {1: "/"} # Root inode
|
||||
self._path_to_inode: Dict[str, int] = {"/": 1}
|
||||
|
||||
# Virtual files
|
||||
self.CONTEXT_FILE = ".context"
|
||||
self.CONTEXT_LIST_FILE = ".context_list"
|
||||
self.CONTEXT_PUSH_FILE = ".context_push"
|
||||
|
||||
def _get_inode(self, path: str) -> int:
|
||||
"""Get or create inode for path."""
|
||||
if path in self._path_to_inode:
|
||||
return self._path_to_inode[path]
|
||||
|
||||
self._inode_counter += 1
|
||||
inode = self._inode_counter
|
||||
self._inode_to_path[inode] = path
|
||||
self._path_to_inode[path] = inode
|
||||
return inode
|
||||
|
||||
def _get_path(self, inode: int) -> str:
|
||||
"""Get path for inode."""
|
||||
return self._inode_to_path.get(inode, "/")
|
||||
|
||||
def _is_virtual_file(self, path: str) -> bool:
|
||||
"""Check if path is a virtual context file."""
|
||||
basename = os.path.basename(path)
|
||||
return basename in [self.CONTEXT_FILE, self.CONTEXT_LIST_FILE, self.CONTEXT_PUSH_FILE]
|
||||
|
||||
async def getattr(self, inode: int, ctx=None) -> pyfuse3.EntryAttributes:
|
||||
"""Get file attributes."""
|
||||
path = self._get_path(inode)
|
||||
entry = pyfuse3.EntryAttributes()
|
||||
entry.st_ino = inode
|
||||
entry.st_uid = os.getuid()
|
||||
entry.st_gid = os.getgid()
|
||||
entry.st_atime_ns = int(time.time() * 1e9)
|
||||
entry.st_mtime_ns = int(time.time() * 1e9)
|
||||
entry.st_ctime_ns = int(time.time() * 1e9)
|
||||
|
||||
if self._is_virtual_file(path):
|
||||
# Virtual files are readable text files
|
||||
entry.st_mode = stat.S_IFREG | 0o644
|
||||
entry.st_size = 1024 # Placeholder size
|
||||
else:
|
||||
# Directories
|
||||
entry.st_mode = stat.S_IFDIR | 0o755
|
||||
entry.st_size = 0
|
||||
|
||||
return entry
|
||||
|
||||
async def lookup(self, parent_inode: int, name: bytes, ctx=None) -> pyfuse3.EntryAttributes:
|
||||
"""Look up a directory entry."""
|
||||
parent_path = self._get_path(parent_inode)
|
||||
child_path = os.path.join(parent_path, name.decode('utf-8'))
|
||||
|
||||
# Normalize path
|
||||
if child_path.startswith("//"):
|
||||
child_path = child_path[1:]
|
||||
|
||||
child_inode = self._get_inode(child_path)
|
||||
return await self.getattr(child_inode, ctx)
|
||||
|
||||
async def opendir(self, inode: int, ctx=None) -> int:
|
||||
"""Open directory."""
|
||||
return inode
|
||||
|
||||
async def readdir(self, inode: int, start_id: int, token) -> None:
|
||||
"""Read directory contents."""
|
||||
path = self._get_path(inode)
|
||||
|
||||
# Always show virtual context files in every directory
|
||||
entries = [
|
||||
(self.CONTEXT_FILE, await self.getattr(self._get_inode(os.path.join(path, self.CONTEXT_FILE)))),
|
||||
(self.CONTEXT_LIST_FILE, await self.getattr(self._get_inode(os.path.join(path, self.CONTEXT_LIST_FILE)))),
|
||||
(self.CONTEXT_PUSH_FILE, await self.getattr(self._get_inode(os.path.join(path, self.CONTEXT_PUSH_FILE)))),
|
||||
]
|
||||
|
||||
# Add subdirectories (you might want to make this dynamic based on context paths)
|
||||
# For now, allowing any directory to be created by navigation
|
||||
|
||||
for i, (name, attr) in enumerate(entries):
|
||||
if i >= start_id:
|
||||
if not pyfuse3.readdir_reply(token, name.encode('utf-8'), attr, i + 1):
|
||||
break
|
||||
|
||||
async def open(self, inode: int, flags: int, ctx=None) -> int:
|
||||
"""Open file."""
|
||||
path = self._get_path(inode)
|
||||
if not self._is_virtual_file(path):
|
||||
raise FUSEError(errno.EISDIR)
|
||||
return inode
|
||||
|
||||
async def read(self, fh: int, offset: int, size: int) -> bytes:
|
||||
"""Read from virtual files."""
|
||||
path = self._get_path(fh)
|
||||
basename = os.path.basename(path)
|
||||
dir_path = os.path.dirname(path)
|
||||
|
||||
if basename == self.CONTEXT_FILE:
|
||||
# Return aggregated context for current directory
|
||||
contexts = self.context_db.get_context_by_path(dir_path, depth=1)
|
||||
content = "\\n".join(f"[{ctx.path}] {ctx.content}" for ctx in contexts)
|
||||
|
||||
elif basename == self.CONTEXT_LIST_FILE:
|
||||
# List contexts at current path
|
||||
contexts = self.context_db.list_contexts_at_path(dir_path)
|
||||
content = "\\n".join(f"ID: {ctx.id}, Path: {ctx.path}, Author: {ctx.author}, Created: {ctx.created_at}"
|
||||
for ctx in contexts)
|
||||
|
||||
elif basename == self.CONTEXT_PUSH_FILE:
|
||||
# Instructions for pushing context
|
||||
content = f"Write to this file to push context to path: {dir_path}\\nFormat: <content>"
|
||||
|
||||
else:
|
||||
content = "Unknown virtual file"
|
||||
|
||||
content_bytes = content.encode('utf-8')
|
||||
return content_bytes[offset:offset + size]
|
||||
|
||||
async def write(self, fh: int, offset: int, data: bytes) -> int:
|
||||
"""Write to virtual files (context_push only)."""
|
||||
path = self._get_path(fh)
|
||||
basename = os.path.basename(path)
|
||||
dir_path = os.path.dirname(path)
|
||||
|
||||
if basename == self.CONTEXT_PUSH_FILE:
|
||||
# Push new context to current directory
|
||||
content = data.decode('utf-8').strip()
|
||||
context = Context(
|
||||
id=None,
|
||||
path=dir_path,
|
||||
content=content,
|
||||
author="fuse_user"
|
||||
)
|
||||
self.context_db.store_context(context)
|
||||
return len(data)
|
||||
else:
|
||||
raise FUSEError(errno.EACCES)
|
||||
|
||||
async def mkdir(self, parent_inode: int, name: bytes, mode: int, ctx=None) -> pyfuse3.EntryAttributes:
|
||||
"""Create directory (virtual - just for navigation)."""
|
||||
parent_path = self._get_path(parent_inode)
|
||||
new_path = os.path.join(parent_path, name.decode('utf-8'))
|
||||
|
||||
if new_path.startswith("//"):
|
||||
new_path = new_path[1:]
|
||||
|
||||
new_inode = self._get_inode(new_path)
|
||||
return await self.getattr(new_inode, ctx)
|
||||
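A hedged sketch of how this filesystem class might be mounted under trio. Only the pyfuse3/trio calls follow the library's documented entry points; the ContextDatabase("hcfs.db") constructor call and the mount point are illustrative assumptions, since neither appears in this commit.

# Hypothetical mount script for the FUSE layer above.
import trio
import pyfuse3

from hcfs.core.context_db import ContextDatabase
from hcfs.core.filesystem import HCFSFilesystem

def mount(mountpoint: str = "/tmp/hcfs") -> None:
    context_db = ContextDatabase("hcfs.db")          # assumption: accepts a database path
    operations = HCFSFilesystem(context_db, mountpoint)

    fuse_options = set(pyfuse3.default_options)
    fuse_options.add("fsname=hcfs")

    pyfuse3.init(operations, mountpoint, fuse_options)
    try:
        # pyfuse3.main() is a trio-compatible coroutine that serves requests until unmount.
        trio.run(pyfuse3.main)
    finally:
        pyfuse3.close(unmount=True)

if __name__ == "__main__":
    # Once mounted: reading /tmp/hcfs/.context returns aggregated context for that path,
    # and writing to /tmp/hcfs/.context_push stores new context there.
    mount()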
48
hcfs-python/hcfs/sdk/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
HCFS Python Agent SDK
|
||||
|
||||
A comprehensive SDK for AI agents to interact with the HCFS API.
|
||||
Provides high-level abstractions, caching, async support, and utilities.
|
||||
"""
|
||||
|
||||
from .client import HCFSClient
|
||||
from .async_client import HCFSAsyncClient
|
||||
from .models import *
|
||||
from .exceptions import *
|
||||
from .utils import *
|
||||
from .decorators import *
|
||||
|
||||
__version__ = "2.0.0"
|
||||
__all__ = [
|
||||
# Core clients
|
||||
"HCFSClient",
|
||||
"HCFSAsyncClient",
|
||||
|
||||
# Models and data structures
|
||||
"Context",
|
||||
"SearchResult",
|
||||
"ContextFilter",
|
||||
"PaginationOptions",
|
||||
"CacheConfig",
|
||||
"RetryConfig",
|
||||
|
||||
# Exceptions
|
||||
"HCFSError",
|
||||
"HCFSConnectionError",
|
||||
"HCFSAuthenticationError",
|
||||
"HCFSNotFoundError",
|
||||
"HCFSValidationError",
|
||||
"HCFSRateLimitError",
|
||||
|
||||
# Utilities
|
||||
"context_similarity",
|
||||
"batch_processor",
|
||||
"text_chunker",
|
||||
"embedding_cache",
|
||||
|
||||
# Decorators
|
||||
"cached_context",
|
||||
"retry_on_failure",
|
||||
"rate_limited",
|
||||
"context_manager"
|
||||
]
|
||||
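The re-exports above are intended to be consumed directly from the package root. A brief sketch, with a placeholder endpoint and API key:

from hcfs.sdk import HCFSClient, Context

# Placeholder endpoint and key; substitute real values.
with HCFSClient(base_url="http://localhost:8000", api_key="dev-key") as client:
    ctx = client.create_context(Context(path="/notes/todo", content="Ship Phase 2"))
    print(ctx.id)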
667
hcfs-python/hcfs/sdk/async_client.py
Normal file
@@ -0,0 +1,667 @@
|
||||
"""
|
||||
HCFS Asynchronous Client
|
||||
|
||||
High-level asynchronous client for HCFS API operations with WebSocket support.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import List, Optional, Dict, Any, AsyncIterator, Callable
|
||||
from datetime import datetime
|
||||
|
||||
import httpx
|
||||
import websockets
|
||||
from websockets.exceptions import ConnectionClosed, WebSocketException
|
||||
|
||||
from .models import (
|
||||
Context, SearchResult, ContextFilter, PaginationOptions,
|
||||
SearchOptions, ClientConfig, AnalyticsData, BatchResult, StreamEvent
|
||||
)
|
||||
from .exceptions import (
|
||||
HCFSError, HCFSConnectionError, HCFSAuthenticationError,
|
||||
HCFSNotFoundError, HCFSValidationError, HCFSStreamError, handle_api_error
|
||||
)
|
||||
from .utils import MemoryCache, validate_path, normalize_path
|
||||
from .decorators import cached_context, retry_on_failure, rate_limited
|
||||
|
||||
|
||||
class HCFSAsyncClient:
|
||||
"""
|
||||
Asynchronous HCFS API client with WebSocket streaming capabilities.
|
||||
|
||||
This client provides async/await support for all operations and includes
|
||||
real-time streaming capabilities through WebSocket connections.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> from hcfs.sdk import HCFSAsyncClient, Context
|
||||
>>>
|
||||
>>> async def main():
|
||||
... async with HCFSAsyncClient(
|
||||
... base_url="https://api.hcfs.example.com",
|
||||
... api_key="your-api-key"
|
||||
... ) as client:
|
||||
... # Create a context
|
||||
... context = Context(
|
||||
... path="/docs/async_readme",
|
||||
... content="Async README content",
|
||||
... summary="Async documentation"
|
||||
... )
|
||||
... created = await client.create_context(context)
|
||||
...
|
||||
... # Search with async
|
||||
... results = await client.search_contexts("async README")
|
||||
...         for result in results:
|
||||
... print(f"Found: {result.context.path}")
|
||||
>>>
|
||||
>>> asyncio.run(main())
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[ClientConfig] = None, **kwargs):
|
||||
"""
|
||||
Initialize async HCFS client.
|
||||
|
||||
Args:
|
||||
config: Client configuration object
|
||||
**kwargs: Configuration overrides
|
||||
"""
|
||||
# Merge configuration
|
||||
if config:
|
||||
self.config = config
|
||||
else:
|
||||
self.config = ClientConfig(**kwargs)
|
||||
|
||||
# HTTP client will be initialized in __aenter__
|
||||
self.http_client: Optional[httpx.AsyncClient] = None
|
||||
self.websocket: Optional[websockets.WebSocketClientProtocol] = None
|
||||
self._websocket_listeners: List[Callable[[StreamEvent], None]] = []
|
||||
self._websocket_task: Optional[asyncio.Task] = None
|
||||
|
||||
# Initialize cache
|
||||
self._cache = MemoryCache(
|
||||
max_size=self.config.cache.max_size,
|
||||
strategy=self.config.cache.strategy,
|
||||
ttl_seconds=self.config.cache.ttl_seconds
|
||||
) if self.config.cache.enabled else None
|
||||
|
||||
# Analytics
|
||||
self.analytics = AnalyticsData()
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry."""
|
||||
await self._initialize_http_client()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit."""
|
||||
await self.close()
|
||||
|
||||
async def _initialize_http_client(self):
|
||||
"""Initialize the HTTP client with proper configuration."""
|
||||
headers = {
|
||||
"User-Agent": self.config.user_agent,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
if self.config.api_key:
|
||||
headers["X-API-Key"] = self.config.api_key
|
||||
elif self.config.jwt_token:
|
||||
headers["Authorization"] = f"Bearer {self.config.jwt_token}"
|
||||
|
||||
# Configure timeouts
|
||||
timeout = httpx.Timeout(
|
||||
connect=self.config.timeout,
|
||||
read=self.config.timeout,
|
||||
write=self.config.timeout,
|
||||
pool=self.config.timeout * 2
|
||||
)
|
||||
|
||||
# Configure connection limits
|
||||
limits = httpx.Limits(
|
||||
max_connections=self.config.max_connections,
|
||||
max_keepalive_connections=self.config.max_keepalive_connections
|
||||
)
|
||||
|
||||
self.http_client = httpx.AsyncClient(
|
||||
base_url=self.config.base_url,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
limits=limits,
|
||||
follow_redirects=True
|
||||
)
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Check API health status asynchronously.
|
||||
|
||||
Returns:
|
||||
Health status information
|
||||
|
||||
Raises:
|
||||
HCFSConnectionError: If health check fails
|
||||
"""
|
||||
try:
|
||||
response = await self.http_client.get("/health")
|
||||
|
||||
if response.status_code == 200:
|
||||
self._update_analytics("health_check", success=True)
|
||||
return response.json()
|
||||
else:
|
||||
self._update_analytics("health_check", success=False)
|
||||
handle_api_error(response)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
self._update_analytics("health_check", success=False, error=str(e))
|
||||
raise HCFSConnectionError(f"Health check failed: {str(e)}")
|
||||
|
||||
@cached_context()
|
||||
@retry_on_failure()
|
||||
async def create_context(self, context: Context) -> Context:
|
||||
"""
|
||||
Create a new context asynchronously.
|
||||
|
||||
Args:
|
||||
context: Context object to create
|
||||
|
||||
Returns:
|
||||
Created context with assigned ID
|
||||
|
||||
Raises:
|
||||
HCFSValidationError: If context data is invalid
|
||||
HCFSError: If creation fails
|
||||
"""
|
||||
if not validate_path(context.path):
|
||||
raise HCFSValidationError(f"Invalid context path: {context.path}")
|
||||
|
||||
context.path = normalize_path(context.path)
|
||||
|
||||
try:
|
||||
response = await self.http_client.post(
|
||||
"/api/v1/contexts",
|
||||
json=context.to_create_dict()
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()["data"]
|
||||
created_context = Context(**data)
|
||||
self._update_analytics("create_context", success=True)
|
||||
return created_context
|
||||
else:
|
||||
self._update_analytics("create_context", success=False)
|
||||
handle_api_error(response)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
self._update_analytics("create_context", success=False, error=str(e))
|
||||
raise HCFSConnectionError(f"Failed to create context: {str(e)}")
|
||||
|
||||
@cached_context()
|
||||
async def get_context(self, context_id: int) -> Context:
|
||||
"""
|
||||
Retrieve a context by ID asynchronously.
|
||||
|
||||
Args:
|
||||
context_id: Context identifier
|
||||
|
||||
Returns:
|
||||
Context object
|
||||
|
||||
Raises:
|
||||
HCFSNotFoundError: If context doesn't exist
|
||||
"""
|
||||
try:
|
||||
response = await self.http_client.get(f"/api/v1/contexts/{context_id}")
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()["data"]
|
||||
context = Context(**data)
|
||||
self._update_analytics("get_context", success=True)
|
||||
return context
|
||||
else:
|
||||
self._update_analytics("get_context", success=False)
|
||||
handle_api_error(response)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
self._update_analytics("get_context", success=False, error=str(e))
|
||||
raise HCFSConnectionError(f"Failed to get context: {str(e)}")
|
||||
|
||||
async def list_contexts(self,
|
||||
filter_opts: Optional[ContextFilter] = None,
|
||||
pagination: Optional[PaginationOptions] = None) -> List[Context]:
|
||||
"""
|
||||
List contexts with filtering and pagination asynchronously.
|
||||
|
||||
Args:
|
||||
filter_opts: Context filtering options
|
||||
pagination: Pagination configuration
|
||||
|
||||
Returns:
|
||||
List of contexts
|
||||
"""
|
||||
params = {}
|
||||
|
||||
if filter_opts:
|
||||
params.update(filter_opts.to_query_params())
|
||||
|
||||
if pagination:
|
||||
params.update(pagination.to_query_params())
|
||||
|
||||
try:
|
||||
response = await self.http_client.get("/api/v1/contexts", params=params)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()["data"]
|
||||
contexts = [Context(**ctx_data) for ctx_data in data]
|
||||
self._update_analytics("list_contexts", success=True)
|
||||
return contexts
|
||||
else:
|
||||
self._update_analytics("list_contexts", success=False)
|
||||
handle_api_error(response)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
self._update_analytics("list_contexts", success=False, error=str(e))
|
||||
raise HCFSConnectionError(f"Failed to list contexts: {str(e)}")
|
||||
|
||||
async def update_context(self, context_id: int, updates: Dict[str, Any]) -> Context:
|
||||
"""
|
||||
Update an existing context asynchronously.
|
||||
|
||||
Args:
|
||||
context_id: Context identifier
|
||||
updates: Fields to update
|
||||
|
||||
Returns:
|
||||
Updated context
|
||||
|
||||
Raises:
|
||||
HCFSNotFoundError: If context doesn't exist
|
||||
HCFSValidationError: If update data is invalid
|
||||
"""
|
||||
try:
|
||||
response = await self.http_client.put(
|
||||
f"/api/v1/contexts/{context_id}",
|
||||
json=updates
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()["data"]
|
||||
updated_context = Context(**data)
|
||||
self._update_analytics("update_context", success=True)
|
||||
|
||||
# Invalidate cache
|
||||
if self._cache:
|
||||
cache_key = f"get_context:{context_id}"
|
||||
self._cache.remove(cache_key)
|
||||
|
||||
return updated_context
|
||||
else:
|
||||
self._update_analytics("update_context", success=False)
|
||||
handle_api_error(response)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
self._update_analytics("update_context", success=False, error=str(e))
|
||||
raise HCFSConnectionError(f"Failed to update context: {str(e)}")
|
||||
|
||||
async def delete_context(self, context_id: int) -> bool:
|
||||
"""
|
||||
Delete a context asynchronously.
|
||||
|
||||
Args:
|
||||
context_id: Context identifier
|
||||
|
||||
Returns:
|
||||
True if deletion was successful
|
||||
|
||||
Raises:
|
||||
HCFSNotFoundError: If context doesn't exist
|
||||
"""
|
||||
try:
|
||||
response = await self.http_client.delete(f"/api/v1/contexts/{context_id}")
|
||||
|
||||
if response.status_code == 200:
|
||||
self._update_analytics("delete_context", success=True)
|
||||
|
||||
# Invalidate cache
|
||||
if self._cache:
|
||||
cache_key = f"get_context:{context_id}"
|
||||
self._cache.remove(cache_key)
|
||||
|
||||
return True
|
||||
else:
|
||||
self._update_analytics("delete_context", success=False)
|
||||
handle_api_error(response)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
self._update_analytics("delete_context", success=False, error=str(e))
|
||||
raise HCFSConnectionError(f"Failed to delete context: {str(e)}")
|
||||
|
||||
@rate_limited(requests_per_second=10.0)
|
||||
async def search_contexts(self,
|
||||
query: str,
|
||||
options: Optional[SearchOptions] = None) -> List[SearchResult]:
|
||||
"""
|
||||
Search contexts asynchronously using various search methods.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
options: Search configuration options
|
||||
|
||||
Returns:
|
||||
List of search results ordered by relevance
|
||||
"""
|
||||
search_opts = options or SearchOptions()
|
||||
|
||||
request_data = {
|
||||
"query": query,
|
||||
**search_opts.to_request_dict()
|
||||
}
|
||||
|
||||
try:
|
||||
response = await self.http_client.post(
|
||||
"/api/v1/search",
|
||||
json=request_data
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()["data"]
|
||||
results = []
|
||||
|
||||
for result_data in data:
|
||||
context = Context(**result_data["context"])
|
||||
search_result = SearchResult(
|
||||
context=context,
|
||||
score=result_data["score"],
|
||||
explanation=result_data.get("explanation"),
|
||||
highlights=result_data.get("highlights", [])
|
||||
)
|
||||
results.append(search_result)
|
||||
|
||||
self._update_analytics("search_contexts", success=True)
|
||||
return sorted(results, key=lambda x: x.score, reverse=True)
|
||||
else:
|
||||
self._update_analytics("search_contexts", success=False)
|
||||
handle_api_error(response)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
self._update_analytics("search_contexts", success=False, error=str(e))
|
||||
raise HCFSConnectionError(f"Search failed: {str(e)}")
|
||||
|
||||
async def batch_create_contexts(self, contexts: List[Context]) -> BatchResult:
|
||||
"""
|
||||
Create multiple contexts in a single batch operation asynchronously.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts to create
|
||||
|
||||
Returns:
|
||||
Batch operation results
|
||||
"""
|
||||
request_data = {
|
||||
"contexts": [ctx.to_create_dict() for ctx in contexts]
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
response = await self.http_client.post(
|
||||
"/api/v1/contexts/batch",
|
||||
json=request_data,
|
||||
timeout=self.config.timeout * 3 # Extended timeout for batch ops
|
||||
)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()["data"]
|
||||
|
||||
result = BatchResult(
|
||||
success_count=data["success_count"],
|
||||
error_count=data["error_count"],
|
||||
total_items=data["total_items"],
|
||||
successful_items=data.get("created_ids", []),
|
||||
failed_items=data.get("errors", []),
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
self._update_analytics("batch_create", success=True)
|
||||
return result
|
||||
else:
|
||||
self._update_analytics("batch_create", success=False)
|
||||
handle_api_error(response)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
execution_time = time.time() - start_time
|
||||
self._update_analytics("batch_create", success=False, error=str(e))
|
||||
|
||||
return BatchResult(
|
||||
success_count=0,
|
||||
error_count=len(contexts),
|
||||
total_items=len(contexts),
|
||||
successful_items=[],
|
||||
failed_items=[{"error": str(e)}],
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
async def get_statistics(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get comprehensive system statistics asynchronously.
|
||||
|
||||
Returns:
|
||||
System statistics and metrics
|
||||
"""
|
||||
try:
|
||||
response = await self.http_client.get("/api/v1/stats")
|
||||
|
||||
if response.status_code == 200:
|
||||
self._update_analytics("get_statistics", success=True)
|
||||
return response.json()
|
||||
else:
|
||||
self._update_analytics("get_statistics", success=False)
|
||||
handle_api_error(response)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
self._update_analytics("get_statistics", success=False, error=str(e))
|
||||
raise HCFSConnectionError(f"Failed to get statistics: {str(e)}")
|
||||
|
||||
async def iterate_contexts(self,
|
||||
filter_opts: Optional[ContextFilter] = None,
|
||||
page_size: int = 100) -> AsyncIterator[Context]:
|
||||
"""
|
||||
Asynchronously iterate through all contexts with automatic pagination.
|
||||
|
||||
Args:
|
||||
filter_opts: Context filtering options
|
||||
page_size: Number of contexts per page
|
||||
|
||||
Yields:
|
||||
Context objects
|
||||
"""
|
||||
page = 1
|
||||
|
||||
while True:
|
||||
pagination = PaginationOptions(page=page, page_size=page_size)
|
||||
contexts = await self.list_contexts(filter_opts, pagination)
|
||||
|
||||
if not contexts:
|
||||
break
|
||||
|
||||
for context in contexts:
|
||||
yield context
|
||||
|
||||
# If we got fewer contexts than requested, we've reached the end
|
||||
if len(contexts) < page_size:
|
||||
break
|
||||
|
||||
page += 1
|
||||
|
||||
async def connect_websocket(self,
|
||||
path_prefix: Optional[str] = None,
|
||||
event_types: Optional[List[str]] = None) -> None:
|
||||
"""
|
||||
Connect to WebSocket for real-time updates.
|
||||
|
||||
Args:
|
||||
path_prefix: Filter events by path prefix
|
||||
event_types: List of event types to subscribe to
|
||||
|
||||
Raises:
|
||||
HCFSStreamError: If WebSocket connection fails
|
||||
"""
|
||||
if self.websocket and not self.websocket.closed:
|
||||
return # Already connected
|
||||
|
||||
# Convert HTTP URL to WebSocket URL
|
||||
ws_url = self.config.base_url.replace("http://", "ws://").replace("https://", "wss://")
|
||||
ws_url += "/ws"
|
||||
|
||||
# Add authentication headers
|
||||
headers = {}
|
||||
if self.config.api_key:
|
||||
headers["X-API-Key"] = self.config.api_key
|
||||
elif self.config.jwt_token:
|
||||
headers["Authorization"] = f"Bearer {self.config.jwt_token}"
|
||||
|
||||
try:
|
||||
self.websocket = await websockets.connect(
|
||||
ws_url,
|
||||
extra_headers=headers,
|
||||
ping_interval=self.config.websocket.ping_interval,
|
||||
ping_timeout=self.config.websocket.ping_timeout
|
||||
)
|
||||
|
||||
# Send subscription request
|
||||
subscription = {
|
||||
"type": "subscribe",
|
||||
"data": {
|
||||
"path_prefix": path_prefix,
|
||||
"event_types": event_types or ["created", "updated", "deleted"],
|
||||
"filters": {}
|
||||
}
|
||||
}
|
||||
|
||||
await self.websocket.send(json.dumps(subscription))
|
||||
|
||||
# Start listening task
|
||||
self._websocket_task = asyncio.create_task(self._websocket_listener())
|
||||
|
||||
except (WebSocketException, ConnectionClosed) as e:
|
||||
raise HCFSStreamError(f"Failed to connect to WebSocket: {str(e)}")
|
||||
|
||||
async def disconnect_websocket(self) -> None:
|
||||
"""Disconnect from WebSocket."""
|
||||
if self._websocket_task:
|
||||
self._websocket_task.cancel()
|
||||
try:
|
||||
await self._websocket_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._websocket_task = None
|
||||
|
||||
if self.websocket:
|
||||
await self.websocket.close()
|
||||
self.websocket = None
|
||||
|
||||
def add_event_listener(self, listener: Callable[[StreamEvent], None]) -> None:
|
||||
"""
|
||||
Add an event listener for WebSocket events.
|
||||
|
||||
Args:
|
||||
listener: Function to call when events are received
|
||||
"""
|
||||
self._websocket_listeners.append(listener)
|
||||
|
||||
def remove_event_listener(self, listener: Callable[[StreamEvent], None]) -> None:
|
||||
"""
|
||||
Remove an event listener.
|
||||
|
||||
Args:
|
||||
listener: Function to remove
|
||||
"""
|
||||
if listener in self._websocket_listeners:
|
||||
self._websocket_listeners.remove(listener)
|
||||
|
||||
async def _websocket_listener(self) -> None:
|
||||
"""Internal WebSocket message listener."""
|
||||
try:
|
||||
async for message in self.websocket:
|
||||
try:
|
||||
data = json.loads(message)
|
||||
event = StreamEvent(
|
||||
event_type=data.get("type", "unknown"),
|
||||
data=data.get("data", {}),
|
||||
timestamp=datetime.fromisoformat(data.get("timestamp", datetime.utcnow().isoformat())),
|
||||
context_id=data.get("context_id"),
|
||||
path=data.get("path")
|
||||
)
|
||||
|
||||
# Notify all listeners
|
||||
for listener in self._websocket_listeners:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(listener):
|
||||
await listener(event)
|
||||
else:
|
||||
listener(event)
|
||||
except Exception:
|
||||
pass # Don't let listener errors break the connection
|
||||
|
||||
except json.JSONDecodeError:
|
||||
pass # Ignore malformed messages
|
||||
|
||||
except (WebSocketException, ConnectionClosed):
|
||||
# Connection was closed, attempt reconnection if configured
|
||||
if self.config.websocket.auto_reconnect:
|
||||
await self._attempt_websocket_reconnection()
|
||||
|
||||
async def _attempt_websocket_reconnection(self) -> None:
|
||||
"""Attempt to reconnect WebSocket with backoff."""
|
||||
for attempt in range(self.config.websocket.max_reconnect_attempts):
|
||||
try:
|
||||
await asyncio.sleep(self.config.websocket.reconnect_interval)
|
||||
await self.connect_websocket()
|
||||
return # Successfully reconnected
|
||||
except Exception:
|
||||
continue # Try again
|
||||
|
||||
# All reconnection attempts failed
|
||||
raise HCFSStreamError("Failed to reconnect WebSocket after multiple attempts")
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear all cached data."""
|
||||
if self._cache:
|
||||
self._cache.clear()
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
if self._cache:
|
||||
stats = self._cache.stats()
|
||||
self.analytics.cache_stats = stats
|
||||
return stats
|
||||
return {}
|
||||
|
||||
def get_analytics(self) -> AnalyticsData:
|
||||
"""
|
||||
Get client analytics and usage statistics.
|
||||
|
||||
Returns:
|
||||
Analytics data including operation counts and performance metrics
|
||||
"""
|
||||
# Update cache stats
|
||||
if self._cache:
|
||||
self.analytics.cache_stats = self._cache.stats()
|
||||
|
||||
return self.analytics
|
||||
|
||||
def _update_analytics(self, operation: str, success: bool, error: Optional[str] = None):
|
||||
"""Update internal analytics tracking."""
|
||||
self.analytics.operation_count[operation] = self.analytics.operation_count.get(operation, 0) + 1
|
||||
|
||||
if not success:
|
||||
error_key = error or "unknown_error"
|
||||
self.analytics.error_stats[error_key] = self.analytics.error_stats.get(error_key, 0) + 1
|
||||
|
||||
async def close(self):
|
||||
"""Close the client and cleanup resources."""
|
||||
await self.disconnect_websocket()
|
||||
|
||||
if self.http_client:
|
||||
await self.http_client.aclose()
|
||||
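A hedged sketch of the streaming workflow this client exposes: connect, register a listener, and let the internal _websocket_listener dispatch events. The URL, API key, sleep duration, and event handling below are illustrative placeholders; only the client methods and StreamEvent fields come from the code above.

import asyncio

from hcfs.sdk import HCFSAsyncClient
from hcfs.sdk.models import StreamEvent

def on_event(event: StreamEvent) -> None:
    # Called for every "created"/"updated"/"deleted" event pushed over the WebSocket.
    print(f"{event.event_type}: {event.path} (context_id={event.context_id})")

async def main() -> None:
    async with HCFSAsyncClient(base_url="http://localhost:8000", api_key="dev-key") as client:
        client.add_event_listener(on_event)
        await client.connect_websocket(path_prefix="/docs", event_types=["created", "updated"])
        await asyncio.sleep(30)          # keep the connection open while events arrive
        await client.disconnect_websocket()

asyncio.run(main())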
539
hcfs-python/hcfs/sdk/client.py
Normal file
@@ -0,0 +1,539 @@
|
||||
"""
|
||||
HCFS Synchronous Client
|
||||
|
||||
High-level synchronous client for HCFS API operations.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import List, Optional, Dict, Any, Iterator
|
||||
from datetime import datetime
|
||||
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3.util.retry import Retry
|
||||
|
||||
from .models import (
|
||||
Context, SearchResult, ContextFilter, PaginationOptions,
|
||||
SearchOptions, ClientConfig, AnalyticsData, BatchResult
|
||||
)
|
||||
from .exceptions import (
|
||||
HCFSError, HCFSConnectionError, HCFSAuthenticationError,
|
||||
HCFSNotFoundError, HCFSValidationError, handle_api_error
|
||||
)
|
||||
from .utils import MemoryCache, validate_path, normalize_path
|
||||
from .decorators import cached_context, retry_on_failure, rate_limited
|
||||
|
||||
|
||||
class HCFSClient:
|
||||
"""
|
||||
Synchronous HCFS API client with caching and retry capabilities.
|
||||
|
||||
This client provides a high-level interface for interacting with the HCFS API,
|
||||
including context management, search operations, and batch processing.
|
||||
|
||||
Example:
|
||||
>>> from hcfs.sdk import HCFSClient, Context
|
||||
>>>
|
||||
>>> # Initialize client
|
||||
>>> client = HCFSClient(
|
||||
... base_url="https://api.hcfs.example.com",
|
||||
... api_key="your-api-key"
|
||||
... )
|
||||
>>>
|
||||
>>> # Create a context
|
||||
>>> context = Context(
|
||||
... path="/docs/readme",
|
||||
... content="This is a README file",
|
||||
... summary="Project documentation"
|
||||
... )
|
||||
>>> created = client.create_context(context)
|
||||
>>>
|
||||
>>> # Search contexts
|
||||
>>> results = client.search_contexts("README documentation")
|
||||
>>> for result in results:
|
||||
... print(f"Found: {result.context.path} (score: {result.score})")
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[ClientConfig] = None, **kwargs):
|
||||
"""
|
||||
Initialize HCFS client.
|
||||
|
||||
Args:
|
||||
config: Client configuration object
|
||||
**kwargs: Configuration overrides (base_url, api_key, etc.)
|
||||
"""
|
||||
# Merge configuration
|
||||
if config:
|
||||
self.config = config
|
||||
else:
|
||||
self.config = ClientConfig(**kwargs)
|
||||
|
||||
# Initialize session with retry strategy
|
||||
self.session = requests.Session()
|
||||
|
||||
# Configure retries
|
||||
retry_strategy = Retry(
|
||||
total=self.config.retry.max_attempts if self.config.retry.enabled else 0,
|
||||
status_forcelist=[429, 500, 502, 503, 504],
|
||||
backoff_factor=self.config.retry.base_delay,
|
||||
raise_on_status=False
|
||||
)
|
||||
|
||||
adapter = HTTPAdapter(
|
||||
max_retries=retry_strategy,
|
||||
pool_connections=self.config.max_connections,
|
||||
pool_maxsize=self.config.max_keepalive_connections
|
||||
)
|
||||
|
||||
self.session.mount("http://", adapter)
|
||||
self.session.mount("https://", adapter)
|
||||
|
||||
# Set headers
|
||||
self.session.headers.update({
|
||||
"User-Agent": self.config.user_agent,
|
||||
"Content-Type": "application/json"
|
||||
})
|
||||
|
||||
if self.config.api_key:
|
||||
self.session.headers["X-API-Key"] = self.config.api_key
|
||||
elif self.config.jwt_token:
|
||||
self.session.headers["Authorization"] = f"Bearer {self.config.jwt_token}"
|
||||
|
||||
# Initialize cache
|
||||
self._cache = MemoryCache(
|
||||
max_size=self.config.cache.max_size,
|
||||
strategy=self.config.cache.strategy,
|
||||
ttl_seconds=self.config.cache.ttl_seconds
|
||||
) if self.config.cache.enabled else None
|
||||
|
||||
# Analytics
|
||||
self.analytics = AnalyticsData()
|
||||
|
||||
def health_check(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Check API health status.
|
||||
|
||||
Returns:
|
||||
Health status information
|
||||
|
||||
Raises:
|
||||
HCFSConnectionError: If health check fails
|
||||
"""
|
||||
try:
|
||||
response = self.session.get(
|
||||
f"{self.config.base_url}/health",
|
||||
timeout=self.config.timeout
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
self._update_analytics("health_check", success=True)
|
||||
return response.json()
|
||||
else:
|
||||
self._update_analytics("health_check", success=False)
|
||||
handle_api_error(response)
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
self._update_analytics("health_check", success=False, error=str(e))
|
||||
raise HCFSConnectionError(f"Health check failed: {str(e)}")
|
||||
|
||||
@cached_context()
|
||||
@retry_on_failure()
|
||||
def create_context(self, context: Context) -> Context:
|
||||
"""
|
||||
Create a new context.
|
||||
|
||||
Args:
|
||||
context: Context object to create
|
||||
|
||||
Returns:
|
||||
Created context with assigned ID
|
||||
|
||||
Raises:
|
||||
HCFSValidationError: If context data is invalid
|
||||
HCFSError: If creation fails
|
||||
"""
|
||||
if not validate_path(context.path):
|
||||
raise HCFSValidationError(f"Invalid context path: {context.path}")
|
||||
|
||||
context.path = normalize_path(context.path)
|
||||
|
||||
try:
|
||||
response = self.session.post(
|
||||
f"{self.config.base_url}/api/v1/contexts",
|
||||
json=context.to_create_dict(),
|
||||
timeout=self.config.timeout
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()["data"]
|
||||
created_context = Context(**data)
|
||||
self._update_analytics("create_context", success=True)
|
||||
return created_context
|
||||
else:
|
||||
self._update_analytics("create_context", success=False)
|
||||
handle_api_error(response)
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
self._update_analytics("create_context", success=False, error=str(e))
|
||||
raise HCFSConnectionError(f"Failed to create context: {str(e)}")
|
||||
|
||||
@cached_context()
|
||||
def get_context(self, context_id: int) -> Context:
|
||||
"""
|
||||
Retrieve a context by ID.
|
||||
|
||||
Args:
|
||||
context_id: Context identifier
|
||||
|
||||
Returns:
|
||||
Context object
|
||||
|
||||
Raises:
|
||||
HCFSNotFoundError: If context doesn't exist
|
||||
"""
|
||||
try:
|
||||
response = self.session.get(
|
||||
f"{self.config.base_url}/api/v1/contexts/{context_id}",
|
||||
timeout=self.config.timeout
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()["data"]
|
||||
context = Context(**data)
|
||||
self._update_analytics("get_context", success=True)
|
||||
return context
|
||||
else:
|
||||
self._update_analytics("get_context", success=False)
|
||||
handle_api_error(response)
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
self._update_analytics("get_context", success=False, error=str(e))
|
||||
raise HCFSConnectionError(f"Failed to get context: {str(e)}")
|
||||
|
||||
def list_contexts(self,
|
||||
filter_opts: Optional[ContextFilter] = None,
|
||||
pagination: Optional[PaginationOptions] = None) -> List[Context]:
|
||||
"""
|
||||
List contexts with filtering and pagination.
|
||||
|
||||
Args:
|
||||
filter_opts: Context filtering options
|
||||
pagination: Pagination configuration
|
||||
|
||||
Returns:
|
||||
List of contexts
|
||||
"""
|
||||
params = {}
|
||||
|
||||
if filter_opts:
|
||||
params.update(filter_opts.to_query_params())
|
||||
|
||||
if pagination:
|
||||
params.update(pagination.to_query_params())
|
||||
|
||||
try:
|
||||
response = self.session.get(
|
||||
f"{self.config.base_url}/api/v1/contexts",
|
||||
params=params,
|
||||
timeout=self.config.timeout
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()["data"]
|
||||
contexts = [Context(**ctx_data) for ctx_data in data]
|
||||
self._update_analytics("list_contexts", success=True)
|
||||
return contexts
|
||||
else:
|
||||
self._update_analytics("list_contexts", success=False)
|
||||
handle_api_error(response)
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
self._update_analytics("list_contexts", success=False, error=str(e))
|
||||
raise HCFSConnectionError(f"Failed to list contexts: {str(e)}")
|
||||
|
||||
def update_context(self, context_id: int, updates: Dict[str, Any]) -> Context:
|
||||
"""
|
||||
Update an existing context.
|
||||
|
||||
Args:
|
||||
context_id: Context identifier
|
||||
updates: Fields to update
|
||||
|
||||
Returns:
|
||||
Updated context
|
||||
|
||||
Raises:
|
||||
HCFSNotFoundError: If context doesn't exist
|
||||
HCFSValidationError: If update data is invalid
|
||||
"""
|
||||
try:
|
||||
response = self.session.put(
|
||||
f"{self.config.base_url}/api/v1/contexts/{context_id}",
|
||||
json=updates,
|
||||
timeout=self.config.timeout
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()["data"]
|
||||
updated_context = Context(**data)
|
||||
self._update_analytics("update_context", success=True)
|
||||
|
||||
# Invalidate cache
|
||||
if self._cache:
|
||||
cache_key = f"get_context:{context_id}"
|
||||
self._cache.remove(cache_key)
|
||||
|
||||
return updated_context
|
||||
else:
|
||||
self._update_analytics("update_context", success=False)
|
||||
handle_api_error(response)
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
self._update_analytics("update_context", success=False, error=str(e))
|
||||
raise HCFSConnectionError(f"Failed to update context: {str(e)}")
|
||||
|
||||
def delete_context(self, context_id: int) -> bool:
|
||||
"""
|
||||
Delete a context.
|
||||
|
||||
Args:
|
||||
context_id: Context identifier
|
||||
|
||||
Returns:
|
||||
True if deletion was successful
|
||||
|
||||
Raises:
|
||||
HCFSNotFoundError: If context doesn't exist
|
||||
"""
|
||||
try:
|
||||
response = self.session.delete(
|
||||
f"{self.config.base_url}/api/v1/contexts/{context_id}",
|
||||
timeout=self.config.timeout
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
self._update_analytics("delete_context", success=True)
|
||||
|
||||
# Invalidate cache
|
||||
if self._cache:
|
||||
cache_key = f"get_context:{context_id}"
|
||||
self._cache.remove(cache_key)
|
||||
|
||||
return True
|
||||
else:
|
||||
self._update_analytics("delete_context", success=False)
|
||||
handle_api_error(response)
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
self._update_analytics("delete_context", success=False, error=str(e))
|
||||
raise HCFSConnectionError(f"Failed to delete context: {str(e)}")
|
||||
|
||||
@rate_limited(requests_per_second=10.0)
|
||||
def search_contexts(self,
|
||||
query: str,
|
||||
options: Optional[SearchOptions] = None) -> List[SearchResult]:
|
||||
"""
|
||||
Search contexts using various search methods.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
options: Search configuration options
|
||||
|
||||
Returns:
|
||||
List of search results ordered by relevance
|
||||
"""
|
||||
search_opts = options or SearchOptions()
|
||||
|
||||
request_data = {
|
||||
"query": query,
|
||||
**search_opts.to_request_dict()
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.session.post(
|
||||
f"{self.config.base_url}/api/v1/search",
|
||||
json=request_data,
|
||||
timeout=self.config.timeout
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()["data"]
|
||||
results = []
|
||||
|
||||
for result_data in data:
|
||||
context = Context(**result_data["context"])
|
||||
search_result = SearchResult(
|
||||
context=context,
|
||||
score=result_data["score"],
|
||||
explanation=result_data.get("explanation"),
|
||||
highlights=result_data.get("highlights", [])
|
||||
)
|
||||
results.append(search_result)
|
||||
|
||||
self._update_analytics("search_contexts", success=True)
|
||||
return sorted(results, key=lambda x: x.score, reverse=True)
|
||||
else:
|
||||
self._update_analytics("search_contexts", success=False)
|
||||
handle_api_error(response)
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
self._update_analytics("search_contexts", success=False, error=str(e))
|
||||
raise HCFSConnectionError(f"Search failed: {str(e)}")
|
||||
|
||||
def batch_create_contexts(self, contexts: List[Context]) -> BatchResult:
|
||||
"""
|
||||
Create multiple contexts in a single batch operation.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts to create
|
||||
|
||||
Returns:
|
||||
Batch operation results
|
||||
"""
|
||||
request_data = {
|
||||
"contexts": [ctx.to_create_dict() for ctx in contexts]
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
response = self.session.post(
|
||||
f"{self.config.base_url}/api/v1/contexts/batch",
|
||||
json=request_data,
|
||||
timeout=self.config.timeout * 3 # Extended timeout for batch ops
|
||||
)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()["data"]
|
||||
|
||||
result = BatchResult(
|
||||
success_count=data["success_count"],
|
||||
error_count=data["error_count"],
|
||||
total_items=data["total_items"],
|
||||
successful_items=data.get("created_ids", []),
|
||||
failed_items=data.get("errors", []),
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
self._update_analytics("batch_create", success=True)
|
||||
return result
|
||||
else:
|
||||
self._update_analytics("batch_create", success=False)
|
||||
handle_api_error(response)
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
execution_time = time.time() - start_time
|
||||
self._update_analytics("batch_create", success=False, error=str(e))
|
||||
|
||||
return BatchResult(
|
||||
success_count=0,
|
||||
error_count=len(contexts),
|
||||
total_items=len(contexts),
|
||||
successful_items=[],
|
||||
failed_items=[{"error": str(e)}],
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get comprehensive system statistics.
|
||||
|
||||
Returns:
|
||||
System statistics and metrics
|
||||
"""
|
||||
try:
|
||||
response = self.session.get(
|
||||
f"{self.config.base_url}/api/v1/stats",
|
||||
timeout=self.config.timeout
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
self._update_analytics("get_statistics", success=True)
|
||||
return response.json()
|
||||
else:
|
||||
self._update_analytics("get_statistics", success=False)
|
||||
handle_api_error(response)
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
self._update_analytics("get_statistics", success=False, error=str(e))
|
||||
raise HCFSConnectionError(f"Failed to get statistics: {str(e)}")
|
||||
|
||||
def iterate_contexts(self,
|
||||
filter_opts: Optional[ContextFilter] = None,
|
||||
page_size: int = 100) -> Iterator[Context]:
|
||||
"""
|
||||
Iterate through all contexts with automatic pagination.
|
||||
|
||||
Args:
|
||||
filter_opts: Context filtering options
|
||||
page_size: Number of contexts per page
|
||||
|
||||
Yields:
|
||||
Context objects
|
||||
"""
|
||||
page = 1
|
||||
|
||||
while True:
|
||||
pagination = PaginationOptions(page=page, page_size=page_size)
|
||||
contexts = self.list_contexts(filter_opts, pagination)
|
||||
|
||||
if not contexts:
|
||||
break
|
||||
|
||||
for context in contexts:
|
||||
yield context
|
||||
|
||||
# If we got fewer contexts than requested, we've reached the end
|
||||
if len(contexts) < page_size:
|
||||
break
|
||||
|
||||
page += 1
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear all cached data."""
|
||||
if self._cache:
|
||||
self._cache.clear()
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
if self._cache:
|
||||
stats = self._cache.stats()
|
||||
self.analytics.cache_stats = stats
|
||||
return stats
|
||||
return {}
|
||||
|
||||
def get_analytics(self) -> AnalyticsData:
|
||||
"""
|
||||
Get client analytics and usage statistics.
|
||||
|
||||
Returns:
|
||||
Analytics data including operation counts and performance metrics
|
||||
"""
|
||||
# Update cache stats
|
||||
if self._cache:
|
||||
self.analytics.cache_stats = self._cache.stats()
|
||||
|
||||
return self.analytics
|
||||
|
||||
def _update_analytics(self, operation: str, success: bool, error: Optional[str] = None):
|
||||
"""Update internal analytics tracking."""
|
||||
self.analytics.operation_count[operation] = self.analytics.operation_count.get(operation, 0) + 1
|
||||
|
||||
if not success:
|
||||
error_key = error or "unknown_error"
|
||||
self.analytics.error_stats[error_key] = self.analytics.error_stats.get(error_key, 0) + 1
|
||||
|
||||
def close(self):
|
||||
"""Close the client and cleanup resources."""
|
||||
self.session.close()
|
||||
|
||||
def __enter__(self):
|
||||
"""Context manager entry."""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Context manager exit."""
|
||||
self.close()
|
||||
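A short usage sketch of the synchronous client's pagination helper; the endpoint and key are placeholders, and the loop relies only on methods defined above.

from hcfs.sdk import HCFSClient

# Placeholder endpoint and key.
with HCFSClient(base_url="http://localhost:8000", api_key="dev-key") as client:
    # iterate_contexts hides the page/page_size bookkeeping shown above.
    for context in client.iterate_contexts(page_size=50):
        print(context.path)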
472
hcfs-python/hcfs/sdk/decorators.py
Normal file
@@ -0,0 +1,472 @@
|
||||
"""
|
||||
HCFS SDK Decorators
|
||||
|
||||
Decorators for caching, retry logic, rate limiting, and context management.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import random
|
||||
from functools import wraps
|
||||
from typing import Optional, Dict, Any, Callable, List
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from .models import RetryConfig, RetryStrategy, CacheConfig
|
||||
from .exceptions import HCFSError, HCFSRateLimitError, HCFSTimeoutError
|
||||
from .utils import MemoryCache, cache_key
|
||||
|
||||
|
||||
def cached_context(cache_config: Optional[CacheConfig] = None, key_func: Optional[Callable] = None):
|
||||
"""
|
||||
Decorator to cache context-related operations.
|
||||
|
||||
Args:
|
||||
cache_config: Cache configuration
|
||||
key_func: Custom function to generate cache keys
|
||||
"""
|
||||
config = cache_config or CacheConfig()
|
||||
cache = MemoryCache(
|
||||
max_size=config.max_size,
|
||||
strategy=config.strategy,
|
||||
ttl_seconds=config.ttl_seconds
|
||||
)
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
if not config.enabled:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
# Generate cache key
|
||||
if key_func:
|
||||
key = key_func(*args, **kwargs)
|
||||
else:
|
||||
key = cache_key(func.__name__, *args, **kwargs)
|
||||
|
||||
# Try to get from cache
|
||||
cached_result = cache.get(key)
|
||||
if cached_result is not None:
|
||||
return cached_result
|
||||
|
||||
# Execute function and cache result
|
||||
result = await func(*args, **kwargs)
|
||||
cache.put(key, result)
|
||||
return result
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
if not config.enabled:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# Generate cache key
|
||||
if key_func:
|
||||
key = key_func(*args, **kwargs)
|
||||
else:
|
||||
key = cache_key(func.__name__, *args, **kwargs)
|
||||
|
||||
# Try to get from cache
|
||||
cached_result = cache.get(key)
|
||||
if cached_result is not None:
|
||||
return cached_result
|
||||
|
||||
# Execute function and cache result
|
||||
result = func(*args, **kwargs)
|
||||
cache.put(key, result)
|
||||
return result
|
||||
|
||||
# Attach cache management methods
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
async_wrapper.cache = cache
|
||||
async_wrapper.clear_cache = cache.clear
|
||||
async_wrapper.cache_stats = cache.stats
|
||||
return async_wrapper
|
||||
else:
|
||||
sync_wrapper.cache = cache
|
||||
sync_wrapper.clear_cache = cache.clear
|
||||
sync_wrapper.cache_stats = cache.stats
|
||||
return sync_wrapper
|
||||
|
||||
return decorator
|
||||
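# Illustrative note: cached_context also works on standalone callables, and the wrapper
# exposes the cache helpers attached above (cache, clear_cache, cache_stats):
#
#     @cached_context()
#     def expensive_lookup(context_id: int) -> str:
#         ...  # slow work
#
#     expensive_lookup(1)              # miss -> executes and caches
#     expensive_lookup(1)              # hit  -> served from MemoryCache
#     print(expensive_lookup.cache_stats())
#     expensive_lookup.clear_cache()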
|
||||
|
||||
def retry_on_failure(retry_config: Optional[RetryConfig] = None):
|
||||
"""
|
||||
Decorator to retry failed operations with configurable strategies.
|
||||
|
||||
Args:
|
||||
retry_config: Retry configuration
|
||||
"""
|
||||
config = retry_config or RetryConfig()
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
if not config.enabled:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(config.max_attempts):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
|
||||
# Check if we should retry this exception
|
||||
if not _should_retry_exception(e, config):
|
||||
raise e
|
||||
|
||||
# Don't delay on the last attempt
|
||||
if attempt < config.max_attempts - 1:
|
||||
delay = _calculate_delay(attempt, config)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# All attempts failed, raise the last exception
|
||||
raise last_exception
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
if not config.enabled:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(config.max_attempts):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
|
||||
# Check if we should retry this exception
|
||||
if not _should_retry_exception(e, config):
|
||||
raise e
|
||||
|
||||
# Don't delay on the last attempt
|
||||
if attempt < config.max_attempts - 1:
|
||||
delay = _calculate_delay(attempt, config)
|
||||
time.sleep(delay)
|
||||
|
||||
# All attempts failed, raise the last exception
|
||||
raise last_exception
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return async_wrapper
|
||||
else:
|
||||
return sync_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _should_retry_exception(exception: Exception, config: RetryConfig) -> bool:
|
||||
"""Check if an exception should trigger a retry."""
|
||||
# Check for timeout errors
|
||||
if isinstance(exception, HCFSTimeoutError) and config.retry_on_timeout:
|
||||
return True
|
||||
|
||||
# Check for rate limit errors
|
||||
if isinstance(exception, HCFSRateLimitError):
|
||||
return True
|
||||
|
||||
# Check for HTTP status codes (if it's an HTTP-related error)
|
||||
if hasattr(exception, 'status_code'):
|
||||
return exception.status_code in config.retry_on_status
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _calculate_delay(attempt: int, config: RetryConfig) -> float:
|
||||
"""Calculate delay for retry attempt."""
|
||||
if config.strategy == RetryStrategy.EXPONENTIAL_BACKOFF:
|
||||
delay = config.base_delay * (config.backoff_multiplier ** attempt)
|
||||
elif config.strategy == RetryStrategy.LINEAR_BACKOFF:
|
||||
delay = config.base_delay + (config.base_delay * attempt)
|
||||
elif config.strategy == RetryStrategy.FIBONACCI:
|
||||
delay = config.base_delay * _fibonacci(attempt + 1)
|
||||
else: # CONSTANT_DELAY
|
||||
delay = config.base_delay
|
||||
|
||||
# Apply maximum delay limit
|
||||
delay = min(delay, config.max_delay)
|
||||
|
||||
# Add jitter if enabled
|
||||
if config.jitter:
|
||||
jitter_range = delay * 0.1 # 10% jitter
|
||||
delay += random.uniform(-jitter_range, jitter_range)
|
||||
|
||||
return max(0, delay)
|
||||
|
||||
|
||||
def _fibonacci(n: int) -> int:
|
||||
"""Calculate nth Fibonacci number."""
|
||||
if n <= 1:
|
||||
return n
|
||||
a, b = 0, 1
|
||||
for _ in range(2, n + 1):
|
||||
a, b = b, a + b
|
||||
return b
|
||||
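# Worked example of _calculate_delay with base_delay=1.0, backoff_multiplier=2.0,
# max_delay=10.0 and jitter disabled (values chosen purely for illustration):
#   EXPONENTIAL_BACKOFF: attempts 0..4 -> 1.0, 2.0, 4.0, 8.0, 10.0 (capped at max_delay)
#   LINEAR_BACKOFF:      attempts 0..4 -> 1.0, 2.0, 3.0, 4.0, 5.0
#   FIBONACCI:           attempts 0..4 -> 1.0, 1.0, 2.0, 3.0, 5.0
#   With jitter enabled, each delay is additionally perturbed by up to +/-10%.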
|
||||
|
||||
class RateLimiter:
|
||||
"""Token bucket rate limiter."""
|
||||
|
||||
def __init__(self, rate: float, burst: int = 1):
|
||||
self.rate = rate # tokens per second
|
||||
self.burst = burst # maximum tokens in bucket
|
||||
self.tokens = burst
|
||||
self.last_update = time.time()
|
||||
|
||||
def acquire(self, tokens: int = 1) -> bool:
|
||||
"""Try to acquire tokens from the bucket."""
|
||||
now = time.time()
|
||||
|
||||
# Add tokens based on elapsed time
|
||||
elapsed = now - self.last_update
|
||||
self.tokens = min(self.burst, self.tokens + elapsed * self.rate)
|
||||
self.last_update = now
|
||||
|
||||
# Check if we have enough tokens
|
||||
if self.tokens >= tokens:
|
||||
self.tokens -= tokens
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def time_until_tokens(self, tokens: int = 1) -> float:
|
||||
"""Calculate time until enough tokens are available."""
|
||||
if self.tokens >= tokens:
|
||||
return 0.0
|
||||
|
||||
needed_tokens = tokens - self.tokens
|
||||
return needed_tokens / self.rate
|
||||
|
||||
|
||||
def rate_limited(requests_per_second: float, burst: int = 1):
|
||||
"""
|
||||
Decorator to rate limit function calls.
|
||||
|
||||
Args:
|
||||
requests_per_second: Rate limit (requests per second)
|
||||
burst: Maximum burst size
|
||||
"""
|
||||
limiter = RateLimiter(requests_per_second, burst)
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
if not limiter.acquire():
|
||||
wait_time = limiter.time_until_tokens()
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
if not limiter.acquire():
|
||||
raise HCFSRateLimitError()
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
if not limiter.acquire():
|
||||
wait_time = limiter.time_until_tokens()
|
||||
time.sleep(wait_time)
|
||||
|
||||
if not limiter.acquire():
|
||||
raise HCFSRateLimitError()
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return async_wrapper
|
||||
else:
|
||||
return sync_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
class ContextManager:
    """Context manager for HCFS operations with automatic cleanup."""

    def __init__(self, client, auto_cleanup: bool = True):
        self.client = client
        self.auto_cleanup = auto_cleanup
        self.created_contexts: List[int] = []
        self.temp_files: List[str] = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.auto_cleanup:
            self.cleanup()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.auto_cleanup:
            await self.cleanup_async()

    def track_context(self, context_id: int):
        """Track a created context for cleanup."""
        self.created_contexts.append(context_id)

    def track_file(self, file_path: str):
        """Track a temporary file for cleanup."""
        self.temp_files.append(file_path)

    def cleanup(self):
        """Cleanup tracked resources synchronously."""
        # Cleanup contexts
        for context_id in self.created_contexts:
            try:
                self.client.delete_context(context_id)
            except Exception:
                pass  # Ignore cleanup errors

        # Cleanup files
        import os
        for file_path in self.temp_files:
            try:
                if os.path.exists(file_path):
                    os.remove(file_path)
            except Exception:
                pass  # Ignore cleanup errors

        self.created_contexts.clear()
        self.temp_files.clear()

    async def cleanup_async(self):
        """Cleanup tracked resources asynchronously."""
        # Cleanup contexts
        for context_id in self.created_contexts:
            try:
                await self.client.delete_context(context_id)
            except Exception:
                pass  # Ignore cleanup errors

        # Cleanup files
        import os
        for file_path in self.temp_files:
            try:
                if os.path.exists(file_path):
                    os.remove(file_path)
            except Exception:
                pass  # Ignore cleanup errors

        self.created_contexts.clear()
        self.temp_files.clear()


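# Usage sketch (illustrative; `client` and `create_context` stand in for an
# SDK client and are hypothetical here):
#
#     with ContextManager(client) as mgr:
#         ctx = client.create_context({"path": "/tmp/scratch", "content": "temp"})
#         mgr.track_context(ctx.id)
#         mgr.track_file("/tmp/scratch.json")
#     # On exit, tracked contexts are deleted and tracked files are removed.
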
def context_manager(auto_cleanup: bool = True):
    """
    Decorator to automatically manage context lifecycle.

    Args:
        auto_cleanup: Whether to automatically cleanup contexts on exit
    """
    def decorator(func):
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            # Assume first argument is the client
            client = args[0] if args else None
            if not client:
                return await func(*args, **kwargs)

            async with ContextManager(client, auto_cleanup) as ctx_mgr:
                # Inject context manager into kwargs
                kwargs['_context_manager'] = ctx_mgr
                return await func(*args, **kwargs)

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            # Assume first argument is the client
            client = args[0] if args else None
            if not client:
                return func(*args, **kwargs)

            with ContextManager(client, auto_cleanup) as ctx_mgr:
                # Inject context manager into kwargs
                kwargs['_context_manager'] = ctx_mgr
                return func(*args, **kwargs)

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        else:
            return sync_wrapper

    return decorator


def performance_monitor(track_timing: bool = True, track_memory: bool = False):
    """
    Decorator to monitor function performance.

    Args:
        track_timing: Whether to track execution timing
        track_memory: Whether to track memory usage
    """
    def decorator(func):
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            start_time = time.time() if track_timing else None
            start_memory = None

            if track_memory:
                import psutil
                process = psutil.Process()
                start_memory = process.memory_info().rss

            try:
                result = await func(*args, **kwargs)

                # Record performance metrics
                if track_timing:
                    execution_time = time.time() - start_time
                    # Could store or log timing data here

                if track_memory and start_memory:
                    end_memory = process.memory_info().rss
                    memory_delta = end_memory - start_memory
                    # Could store or log memory usage here

                return result

            except Exception:
                # Record error metrics, then re-raise with the original traceback
                raise

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            start_time = time.time() if track_timing else None
            start_memory = None

            if track_memory:
                import psutil
                process = psutil.Process()
                start_memory = process.memory_info().rss

            try:
                result = func(*args, **kwargs)

                # Record performance metrics
                if track_timing:
                    execution_time = time.time() - start_time
                    # Could store or log timing data here

                if track_memory and start_memory:
                    end_memory = process.memory_info().rss
                    memory_delta = end_memory - start_memory
                    # Could store or log memory usage here

                return result

            except Exception:
                # Record error metrics, then re-raise with the original traceback
                raise

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        else:
            return sync_wrapper

    return decorator
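# Usage sketch (illustrative; the decorated coroutine is a placeholder):
#
#     @performance_monitor(track_timing=True, track_memory=True)
#     async def batch_index(contexts):
#         ...
#
# Memory tracking relies on the optional psutil dependency being installed;
# timing uses time.time() around the wrapped call.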
184
hcfs-python/hcfs/sdk/exceptions.py
Normal file
184
hcfs-python/hcfs/sdk/exceptions.py
Normal file
@@ -0,0 +1,184 @@
"""
HCFS SDK Exception Classes

Comprehensive exception hierarchy for error handling.
"""

from typing import Optional, Dict, Any


class HCFSError(Exception):
    """Base exception for all HCFS SDK errors."""

    def __init__(self, message: str, error_code: Optional[str] = None, details: Optional[Dict[str, Any]] = None):
        super().__init__(message)
        self.message = message
        self.error_code = error_code
        self.details = details or {}

    def __str__(self) -> str:
        if self.error_code:
            return f"[{self.error_code}] {self.message}"
        return self.message

    def to_dict(self) -> Dict[str, Any]:
        """Convert exception to dictionary for serialization."""
        return {
            "type": self.__class__.__name__,
            "message": self.message,
            "error_code": self.error_code,
            "details": self.details
        }


class HCFSConnectionError(HCFSError):
    """Raised when connection to HCFS API fails."""

    def __init__(self, message: str = "Failed to connect to HCFS API", **kwargs):
        super().__init__(message, error_code="CONNECTION_FAILED", **kwargs)


class HCFSAuthenticationError(HCFSError):
    """Raised when authentication fails."""

    def __init__(self, message: str = "Authentication failed", **kwargs):
        super().__init__(message, error_code="AUTH_FAILED", **kwargs)


class HCFSAuthorizationError(HCFSError):
    """Raised when user lacks permissions for an operation."""

    def __init__(self, message: str = "Insufficient permissions", **kwargs):
        super().__init__(message, error_code="INSUFFICIENT_PERMISSIONS", **kwargs)


class HCFSNotFoundError(HCFSError):
    """Raised when a requested resource is not found."""

    def __init__(self, resource_type: str = "Resource", resource_id: str = "", **kwargs):
        message = f"{resource_type} not found"
        if resource_id:
            message += f": {resource_id}"
        super().__init__(message, error_code="NOT_FOUND", **kwargs)


class HCFSValidationError(HCFSError):
    """Raised when request validation fails."""

    def __init__(self, message: str = "Request validation failed", validation_errors: Optional[list] = None, **kwargs):
        super().__init__(message, error_code="VALIDATION_FAILED", **kwargs)
        self.validation_errors = validation_errors or []

    def to_dict(self) -> Dict[str, Any]:
        result = super().to_dict()
        result["validation_errors"] = self.validation_errors
        return result


class HCFSRateLimitError(HCFSError):
    """Raised when rate limit is exceeded."""

    def __init__(self, retry_after: Optional[int] = None, **kwargs):
        message = "Rate limit exceeded"
        if retry_after:
            message += f". Retry after {retry_after} seconds"
        super().__init__(message, error_code="RATE_LIMIT_EXCEEDED", **kwargs)
        self.retry_after = retry_after


class HCFSServerError(HCFSError):
    """Raised for server-side errors (5xx status codes)."""

    def __init__(self, message: str = "Internal server error", status_code: Optional[int] = None, **kwargs):
        super().__init__(message, error_code="SERVER_ERROR", **kwargs)
        self.status_code = status_code


class HCFSTimeoutError(HCFSError):
    """Raised when a request times out."""

    def __init__(self, operation: str = "Request", timeout_seconds: Optional[float] = None, **kwargs):
        message = f"{operation} timed out"
        if timeout_seconds:
            message += f" after {timeout_seconds}s"
        super().__init__(message, error_code="TIMEOUT", **kwargs)
        self.timeout_seconds = timeout_seconds


class HCFSCacheError(HCFSError):
    """Raised for cache-related errors."""

    def __init__(self, message: str = "Cache operation failed", **kwargs):
        super().__init__(message, error_code="CACHE_ERROR", **kwargs)


class HCFSBatchError(HCFSError):
    """Raised for batch operation errors."""

    def __init__(self, message: str = "Batch operation failed", failed_items: Optional[list] = None, **kwargs):
        super().__init__(message, error_code="BATCH_ERROR", **kwargs)
        self.failed_items = failed_items or []

    def to_dict(self) -> Dict[str, Any]:
        result = super().to_dict()
        result["failed_items"] = self.failed_items
        return result


class HCFSStreamError(HCFSError):
    """Raised for streaming/WebSocket errors."""

    def __init__(self, message: str = "Stream operation failed", **kwargs):
        super().__init__(message, error_code="STREAM_ERROR", **kwargs)


class HCFSSearchError(HCFSError):
    """Raised for search operation errors."""

    def __init__(self, query: str = "", search_type: str = "", **kwargs):
        message = "Search failed"
        if search_type:
            message += f" ({search_type})"
        if query:
            message += f": '{query}'"
        super().__init__(message, error_code="SEARCH_ERROR", **kwargs)
        self.query = query
        self.search_type = search_type


def handle_api_error(response) -> None:
    """
    Convert HTTP response errors to appropriate HCFS exceptions.

    Args:
        response: HTTP response object

    Raises:
        Appropriate HCFSError subclass based on status code
    """
    status_code = response.status_code

    try:
        error_data = response.json() if response.content else {}
    except Exception:
        error_data = {}

    error_message = error_data.get("error", "Unknown error")
    error_details = error_data.get("error_details", [])

    if status_code == 400:
        raise HCFSValidationError(error_message, validation_errors=error_details)
    elif status_code == 401:
        raise HCFSAuthenticationError(error_message)
    elif status_code == 403:
        raise HCFSAuthorizationError(error_message)
    elif status_code == 404:
        raise HCFSNotFoundError("Resource", error_message)
    elif status_code == 429:
        retry_after = response.headers.get("Retry-After")
        retry_after = int(retry_after) if retry_after else None
        raise HCFSRateLimitError(retry_after=retry_after)
    elif 500 <= status_code < 600:
        raise HCFSServerError(error_message, status_code=status_code)
    else:
        raise HCFSError(f"HTTP {status_code}: {error_message}")
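# Usage sketch (illustrative; assumes an httpx/requests-style response object
# exposing status_code, content, json() and headers -- the http_client and
# base_url names are placeholders):
#
#     response = http_client.get(f"{base_url}/contexts/42")
#     if response.status_code >= 400:
#         handle_api_error(response)   # raises the matching HCFSError subclass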
335
hcfs-python/hcfs/sdk/models.py
Normal file
335
hcfs-python/hcfs/sdk/models.py
Normal file
@@ -0,0 +1,335 @@
"""
HCFS SDK Data Models

Pydantic models for SDK operations and configuration.
"""

from typing import Optional, List, Dict, Any, Union, Callable
from datetime import datetime
from enum import Enum
from pydantic import BaseModel, Field, validator
from dataclasses import dataclass


class ContextStatus(str, Enum):
    """Context status enumeration."""
    ACTIVE = "active"
    ARCHIVED = "archived"
    DELETED = "deleted"
    DRAFT = "draft"


class SearchType(str, Enum):
    """Search type enumeration."""
    SEMANTIC = "semantic"
    KEYWORD = "keyword"
    HYBRID = "hybrid"
    FUZZY = "fuzzy"


class CacheStrategy(str, Enum):
    """Cache strategy enumeration."""
    LRU = "lru"
    LFU = "lfu"
    TTL = "ttl"
    FIFO = "fifo"


class RetryStrategy(str, Enum):
    """Retry strategy enumeration."""
    EXPONENTIAL_BACKOFF = "exponential_backoff"
    LINEAR_BACKOFF = "linear_backoff"
    CONSTANT_DELAY = "constant_delay"
    FIBONACCI = "fibonacci"


class Context(BaseModel):
    """Context data model for SDK operations."""

    id: Optional[int] = None
    path: str = Field(..., description="Unique context path")
    content: str = Field(..., description="Context content")
    summary: Optional[str] = Field(None, description="Brief summary")
    author: Optional[str] = Field(None, description="Context author")
    tags: List[str] = Field(default_factory=list, description="Context tags")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
    status: ContextStatus = Field(default=ContextStatus.ACTIVE, description="Context status")
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
    version: int = Field(default=1, description="Context version")
    similarity_score: Optional[float] = Field(None, description="Similarity score (for search results)")

    @validator('path')
    def validate_path(cls, v):
        if not v or not v.startswith('/'):
            raise ValueError('Path must start with /')
        return v

    @validator('content')
    def validate_content(cls, v):
        if not v or len(v.strip()) == 0:
            raise ValueError('Content cannot be empty')
        return v

    def to_create_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for context creation."""
        return {
            "path": self.path,
            "content": self.content,
            "summary": self.summary,
            "author": self.author,
            "tags": self.tags,
            "metadata": self.metadata
        }

    def to_update_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for context updates (excluding read-only fields)."""
        return {
            k: v for k, v in {
                "content": self.content,
                "summary": self.summary,
                "tags": self.tags,
                "metadata": self.metadata,
                "status": self.status.value
            }.items() if v is not None
        }


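# Usage sketch (illustrative):
#
#     ctx = Context(path="/projects/hcfs/notes", content="Design notes", tags=["design"])
#     create_payload = ctx.to_create_dict()   # body for a create request
#     update_payload = ctx.to_update_dict()   # partial body for an update request
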
class SearchResult(BaseModel):
    """Search result model."""

    context: Context
    score: float = Field(..., description="Relevance score")
    explanation: Optional[str] = Field(None, description="Search result explanation")
    highlights: List[str] = Field(default_factory=list, description="Highlighted text snippets")

    def __lt__(self, other):
        """Enable sorting by score."""
        return self.score < other.score

    def __gt__(self, other):
        """Enable sorting by score."""
        return self.score > other.score


class ContextFilter(BaseModel):
    """Context filtering options."""

    path_prefix: Optional[str] = Field(None, description="Filter by path prefix")
    author: Optional[str] = Field(None, description="Filter by author")
    status: Optional[ContextStatus] = Field(None, description="Filter by status")
    tags: Optional[List[str]] = Field(None, description="Filter by tags")
    created_after: Optional[datetime] = Field(None, description="Filter by creation date")
    created_before: Optional[datetime] = Field(None, description="Filter by creation date")
    content_contains: Optional[str] = Field(None, description="Filter by content substring")
    min_content_length: Optional[int] = Field(None, description="Minimum content length")
    max_content_length: Optional[int] = Field(None, description="Maximum content length")

    def to_query_params(self) -> Dict[str, Any]:
        """Convert to query parameters for API requests."""
        params = {}

        if self.path_prefix:
            params["path_prefix"] = self.path_prefix
        if self.author:
            params["author"] = self.author
        if self.status:
            params["status"] = self.status.value
        if self.created_after:
            params["created_after"] = self.created_after.isoformat()
        if self.created_before:
            params["created_before"] = self.created_before.isoformat()
        if self.content_contains:
            params["content_contains"] = self.content_contains
        if self.min_content_length is not None:
            params["min_content_length"] = self.min_content_length
        if self.max_content_length is not None:
            params["max_content_length"] = self.max_content_length

        return params


class PaginationOptions(BaseModel):
    """Pagination configuration."""

    page: int = Field(default=1, ge=1, description="Page number")
    page_size: int = Field(default=20, ge=1, le=1000, description="Items per page")
    sort_by: Optional[str] = Field(None, description="Sort field")
    sort_order: str = Field(default="desc", description="Sort order (asc/desc)")

    @validator('sort_order')
    def validate_sort_order(cls, v):
        if v not in ['asc', 'desc']:
            raise ValueError('Sort order must be "asc" or "desc"')
        return v

    @property
    def offset(self) -> int:
        """Calculate offset for database queries."""
        return (self.page - 1) * self.page_size

    def to_query_params(self) -> Dict[str, Any]:
        """Convert to query parameters."""
        params = {
            "page": self.page,
            "page_size": self.page_size,
            "sort_order": self.sort_order
        }
        if self.sort_by:
            params["sort_by"] = self.sort_by
        return params


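# Usage sketch (illustrative): combining a filter with pagination to build the
# query parameters for a list request.
#
#     params = {
#         **ContextFilter(path_prefix="/projects", author="alice").to_query_params(),
#         **PaginationOptions(page=2, page_size=50, sort_by="updated_at").to_query_params(),
#     }
#     # PaginationOptions(page=2, page_size=50).offset == 50
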
class SearchOptions(BaseModel):
    """Search configuration options."""

    search_type: SearchType = Field(default=SearchType.SEMANTIC, description="Type of search")
    top_k: int = Field(default=10, ge=1, le=1000, description="Maximum results to return")
    similarity_threshold: float = Field(default=0.0, ge=0.0, le=1.0, description="Minimum similarity score")
    path_prefix: Optional[str] = Field(None, description="Search within path prefix")
    semantic_weight: float = Field(default=0.7, ge=0.0, le=1.0, description="Weight for semantic search in hybrid mode")
    include_content: bool = Field(default=True, description="Include full content in results")
    include_highlights: bool = Field(default=True, description="Include text highlights")
    max_highlights: int = Field(default=3, ge=0, le=10, description="Maximum highlight snippets")

    def to_request_dict(self) -> Dict[str, Any]:
        """Convert to API request dictionary."""
        return {
            "search_type": self.search_type.value,
            "top_k": self.top_k,
            "similarity_threshold": self.similarity_threshold,
            "path_prefix": self.path_prefix,
            "semantic_weight": self.semantic_weight,
            "include_content": self.include_content,
            "include_highlights": self.include_highlights
        }


class CacheConfig(BaseModel):
    """Cache configuration."""

    enabled: bool = Field(default=True, description="Enable caching")
    strategy: CacheStrategy = Field(default=CacheStrategy.LRU, description="Cache eviction strategy")
    max_size: int = Field(default=1000, ge=1, description="Maximum cache entries")
    ttl_seconds: Optional[int] = Field(default=3600, ge=1, description="Time-to-live in seconds")
    memory_limit_mb: Optional[int] = Field(default=100, ge=1, description="Memory limit in MB")
    persist_to_disk: bool = Field(default=False, description="Persist cache to disk")
    disk_cache_path: Optional[str] = Field(None, description="Disk cache directory")

    @validator('ttl_seconds')
    def validate_ttl(cls, v, values):
        if values.get('strategy') == CacheStrategy.TTL and v is None:
            raise ValueError('TTL must be specified for TTL cache strategy')
        return v


class RetryConfig(BaseModel):
    """Retry configuration for failed requests."""

    enabled: bool = Field(default=True, description="Enable retry logic")
    max_attempts: int = Field(default=3, ge=1, le=10, description="Maximum retry attempts")
    strategy: RetryStrategy = Field(default=RetryStrategy.EXPONENTIAL_BACKOFF, description="Retry strategy")
    base_delay: float = Field(default=1.0, ge=0.1, description="Base delay in seconds")
    max_delay: float = Field(default=60.0, ge=1.0, description="Maximum delay in seconds")
    backoff_multiplier: float = Field(default=2.0, ge=1.0, description="Backoff multiplier")
    jitter: bool = Field(default=True, description="Add random jitter to delays")
    retry_on_status: List[int] = Field(
        default_factory=lambda: [429, 500, 502, 503, 504],
        description="HTTP status codes to retry on"
    )
    retry_on_timeout: bool = Field(default=True, description="Retry on timeout errors")


class WebSocketConfig(BaseModel):
    """WebSocket connection configuration."""

    auto_reconnect: bool = Field(default=True, description="Automatically reconnect on disconnect")
    reconnect_interval: float = Field(default=5.0, ge=1.0, description="Reconnect interval in seconds")
    max_reconnect_attempts: int = Field(default=10, ge=1, description="Maximum reconnection attempts")
    ping_interval: float = Field(default=30.0, ge=1.0, description="Ping interval in seconds")
    ping_timeout: float = Field(default=10.0, ge=1.0, description="Ping timeout in seconds")
    message_queue_size: int = Field(default=1000, ge=1, description="Maximum queued messages")


class ClientConfig(BaseModel):
    """Main client configuration."""

    base_url: str = Field(..., description="HCFS API base URL")
    api_key: Optional[str] = Field(None, description="API key for authentication")
    jwt_token: Optional[str] = Field(None, description="JWT token for authentication")
    timeout: float = Field(default=30.0, ge=1.0, description="Request timeout in seconds")
    user_agent: str = Field(default="HCFS-SDK/2.0.0", description="User agent string")

    # Advanced configurations
    cache: CacheConfig = Field(default_factory=CacheConfig)
    retry: RetryConfig = Field(default_factory=RetryConfig)
    websocket: WebSocketConfig = Field(default_factory=WebSocketConfig)

    # Connection pooling
    max_connections: int = Field(default=100, ge=1, description="Maximum connection pool size")
    max_keepalive_connections: int = Field(default=20, ge=1, description="Maximum keep-alive connections")

    @validator('base_url')
    def validate_base_url(cls, v):
        if not v.startswith(('http://', 'https://')):
            raise ValueError('Base URL must start with http:// or https://')
        return v.rstrip('/')


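# Usage sketch (illustrative; the URL and key values are placeholders):
#
#     config = ClientConfig(
#         base_url="https://hcfs.example.com/",   # trailing slash is stripped by the validator
#         api_key="hcfs_xxx",
#         retry=RetryConfig(max_attempts=5),
#         cache=CacheConfig(strategy=CacheStrategy.TTL, ttl_seconds=600),
#     )
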
@dataclass
class BatchResult:
    """Result of a batch operation."""

    success_count: int
    error_count: int
    total_items: int
    successful_items: List[Any]
    failed_items: List[Dict[str, Any]]
    execution_time: float

    @property
    def success_rate(self) -> float:
        """Calculate success rate."""
        return self.success_count / self.total_items if self.total_items > 0 else 0.0

    @property
    def has_errors(self) -> bool:
        """Check if there were any errors."""
        return self.error_count > 0


class StreamEvent(BaseModel):
    """WebSocket stream event."""

    event_type: str = Field(..., description="Event type (created/updated/deleted)")
    data: Dict[str, Any] = Field(..., description="Event data")
    timestamp: datetime = Field(default_factory=datetime.utcnow, description="Event timestamp")
    context_id: Optional[int] = Field(None, description="Related context ID")
    path: Optional[str] = Field(None, description="Related context path")

    def is_context_event(self) -> bool:
        """Check if this is a context-related event."""
        return self.event_type in ['context_created', 'context_updated', 'context_deleted']


class AnalyticsData(BaseModel):
    """Analytics and usage data."""

    operation_count: Dict[str, int] = Field(default_factory=dict, description="Operation counts")
    cache_stats: Dict[str, Any] = Field(default_factory=dict, description="Cache statistics")
    error_stats: Dict[str, int] = Field(default_factory=dict, description="Error statistics")
    performance_stats: Dict[str, float] = Field(default_factory=dict, description="Performance metrics")
    session_start: datetime = Field(default_factory=datetime.utcnow, description="Session start time")

    def get_cache_hit_rate(self) -> float:
        """Calculate cache hit rate."""
        hits = self.cache_stats.get('hits', 0)
        misses = self.cache_stats.get('misses', 0)
        total = hits + misses
        return hits / total if total > 0 else 0.0

    def get_error_rate(self) -> float:
        """Calculate overall error rate."""
        total_operations = sum(self.operation_count.values())
        total_errors = sum(self.error_stats.values())
        return total_errors / total_operations if total_operations > 0 else 0.0
564
hcfs-python/hcfs/sdk/utils.py
Normal file
564
hcfs-python/hcfs/sdk/utils.py
Normal file
@@ -0,0 +1,564 @@
"""
HCFS SDK Utility Functions

Common utilities for text processing, caching, and data manipulation.
"""

import hashlib
import json
import math
import os  # required by BatchProcessor's default worker count
import re
import time
from typing import List, Dict, Any, Optional, Tuple, Iterator, Callable, Union
from datetime import datetime, timedelta
from collections import defaultdict, OrderedDict
from threading import Lock
import asyncio
from functools import lru_cache, wraps

from .models import Context, SearchResult, CacheStrategy
from .exceptions import HCFSError, HCFSCacheError


def context_similarity(context1: Context, context2: Context, method: str = "jaccard") -> float:
    """
    Calculate similarity between two contexts.

    Args:
        context1: First context
        context2: Second context
        method: Similarity method ("jaccard", "cosine", "levenshtein")

    Returns:
        Similarity score between 0.0 and 1.0
    """
    if method == "jaccard":
        return _jaccard_similarity(context1.content, context2.content)
    elif method == "cosine":
        return _cosine_similarity(context1.content, context2.content)
    elif method == "levenshtein":
        return _levenshtein_similarity(context1.content, context2.content)
    else:
        raise ValueError(f"Unknown similarity method: {method}")


def _jaccard_similarity(text1: str, text2: str) -> float:
    """Calculate Jaccard similarity between two texts."""
    words1 = set(text1.lower().split())
    words2 = set(text2.lower().split())

    intersection = words1.intersection(words2)
    union = words1.union(words2)

    return len(intersection) / len(union) if union else 0.0


def _cosine_similarity(text1: str, text2: str) -> float:
    """Calculate cosine similarity between two texts."""
    words1 = text1.lower().split()
    words2 = text2.lower().split()

    # Create word frequency vectors
    all_words = set(words1 + words2)
    vector1 = [words1.count(word) for word in all_words]
    vector2 = [words2.count(word) for word in all_words]

    # Calculate dot product and magnitudes
    dot_product = sum(a * b for a, b in zip(vector1, vector2))
    magnitude1 = math.sqrt(sum(a * a for a in vector1))
    magnitude2 = math.sqrt(sum(a * a for a in vector2))

    if magnitude1 == 0 or magnitude2 == 0:
        return 0.0

    return dot_product / (magnitude1 * magnitude2)


def _levenshtein_similarity(text1: str, text2: str) -> float:
    """Calculate normalized Levenshtein similarity."""
    def levenshtein_distance(s1: str, s2: str) -> int:
        if len(s1) < len(s2):
            return levenshtein_distance(s2, s1)

        if len(s2) == 0:
            return len(s1)

        previous_row = list(range(len(s2) + 1))
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row

        return previous_row[-1]

    max_len = max(len(text1), len(text2))
    if max_len == 0:
        return 1.0

    distance = levenshtein_distance(text1.lower(), text2.lower())
    return 1.0 - (distance / max_len)


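# Usage sketch (illustrative): the three methods trade cost for granularity --
# Jaccard and cosine compare word sets/frequencies, Levenshtein compares characters.
#
#     a = Context(path="/a", content="hierarchical context file system")
#     b = Context(path="/b", content="context file system for agents")
#     context_similarity(a, b, method="jaccard")      # word-overlap ratio in [0, 1]
#     context_similarity(a, b, method="levenshtein")  # 1 - normalized edit distance
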
def text_chunker(text: str, chunk_size: int = 512, overlap: int = 50, preserve_sentences: bool = True) -> List[str]:
    """
    Split text into overlapping chunks.

    Args:
        text: Text to chunk
        chunk_size: Maximum chunk size in characters
        overlap: Overlap between chunks
        preserve_sentences: Try to preserve sentence boundaries

    Returns:
        List of text chunks
    """
    if len(text) <= chunk_size:
        return [text]

    chunks = []
    start = 0

    while start < len(text):
        end = start + chunk_size

        if end >= len(text):
            chunks.append(text[start:])
            break

        # Try to find a good break point
        chunk = text[start:end]

        if preserve_sentences and '.' in chunk:
            # Find the last sentence boundary
            last_period = chunk.rfind('.')
            if last_period > chunk_size // 2:  # Don't make chunks too small
                end = start + last_period + 1
                chunk = text[start:end]

        chunks.append(chunk.strip())
        start = end - overlap

    return [chunk for chunk in chunks if chunk.strip()]


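# Usage sketch (illustrative; `long_text` is a placeholder string): chunk a
# document before embedding, with overlap so sentences spanning a boundary
# appear in both neighbouring chunks.
#
#     chunks = text_chunker(long_text, chunk_size=512, overlap=50)
#     for chunk in chunks:
#         ...  # e.g. embed or index each chunk
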
def extract_keywords(text: str, max_keywords: int = 10, min_length: int = 3) -> List[str]:
    """
    Extract keywords from text using simple frequency analysis.

    Args:
        text: Input text
        max_keywords: Maximum number of keywords
        min_length: Minimum keyword length

    Returns:
        List of keywords ordered by frequency
    """
    # Simple stopwords
    stopwords = {
        'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
        'of', 'with', 'by', 'from', 'as', 'is', 'was', 'are', 'were', 'be',
        'been', 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would',
        'could', 'should', 'may', 'might', 'can', 'this', 'that', 'these',
        'those', 'i', 'you', 'he', 'she', 'it', 'we', 'they', 'me', 'him',
        'her', 'us', 'them', 'my', 'your', 'his', 'its', 'our', 'their'
    }

    # Extract words and count frequencies
    words = re.findall(r'\b[a-zA-Z]+\b', text.lower())
    word_freq = defaultdict(int)

    for word in words:
        if len(word) >= min_length and word not in stopwords:
            word_freq[word] += 1

    # Sort by frequency and return top keywords
    return sorted(word_freq.keys(), key=lambda x: word_freq[x], reverse=True)[:max_keywords]


def format_content_preview(content: str, max_length: int = 200) -> str:
    """
    Format content for preview display.

    Args:
        content: Full content
        max_length: Maximum preview length

    Returns:
        Formatted preview string
    """
    if len(content) <= max_length:
        return content

    # Try to cut at word boundary
    preview = content[:max_length]
    last_space = preview.rfind(' ')

    if last_space > max_length * 0.8:  # Don't cut too much
        preview = preview[:last_space]

    return preview + "..."


def validate_path(path: str) -> bool:
    """
    Validate context path format.

    Args:
        path: Path to validate

    Returns:
        True if valid, False otherwise
    """
    if not path or not isinstance(path, str):
        return False

    if not path.startswith('/'):
        return False

    # Check for invalid characters
    invalid_chars = set('<>"|?*')
    if any(char in path for char in invalid_chars):
        return False

    # Check path components
    components = path.split('/')
    for component in components[1:]:  # Skip empty first component
        if not component or component in ['.', '..']:
            return False

    return True


def normalize_path(path: str) -> str:
    """
    Normalize context path.

    Args:
        path: Path to normalize

    Returns:
        Normalized path
    """
    if not path.startswith('/'):
        path = '/' + path

    # Remove duplicate slashes and normalize
    components = [c for c in path.split('/') if c]
    return '/' + '/'.join(components) if components else '/'


def hash_content(content: str, algorithm: str = "sha256") -> str:
    """
    Generate hash of content for deduplication.

    Args:
        content: Content to hash
        algorithm: Hash algorithm

    Returns:
        Hex digest of content hash
    """
    if algorithm == "md5":
        hasher = hashlib.md5()
    elif algorithm == "sha1":
        hasher = hashlib.sha1()
    elif algorithm == "sha256":
        hasher = hashlib.sha256()
    else:
        raise ValueError(f"Unsupported hash algorithm: {algorithm}")

    hasher.update(content.encode('utf-8'))
    return hasher.hexdigest()


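# Usage sketch (illustrative):
#
#     validate_path("/projects/hcfs/notes")   # True
#     validate_path("projects/../etc")        # False (no leading slash, '..' component)
#     normalize_path("projects//hcfs/")       # '/projects/hcfs'
#     hash_content("same text") == hash_content("same text")   # stable deduplication key
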
def merge_contexts(contexts: List[Context], strategy: str = "latest") -> Context:
    """
    Merge multiple contexts into one.

    Args:
        contexts: List of contexts to merge
        strategy: Merge strategy ("latest", "longest", "combined")

    Returns:
        Merged context
    """
    if not contexts:
        raise ValueError("No contexts to merge")

    if len(contexts) == 1:
        return contexts[0]

    if strategy == "latest":
        return max(contexts, key=lambda c: c.updated_at or c.created_at or datetime.min)
    elif strategy == "longest":
        return max(contexts, key=lambda c: len(c.content))
    elif strategy == "combined":
        # Combine content and metadata
        merged = contexts[0].copy()
        merged.content = "\n\n".join(c.content for c in contexts)
        merged.tags = list(set(tag for c in contexts for tag in c.tags))

        # Merge metadata
        merged_metadata = {}
        for context in contexts:
            merged_metadata.update(context.metadata)
        merged.metadata = merged_metadata

        return merged
    else:
        raise ValueError(f"Unknown merge strategy: {strategy}")


class MemoryCache:
    """Thread-safe in-memory cache with configurable eviction strategies."""

    def __init__(self, max_size: int = 1000, strategy: CacheStrategy = CacheStrategy.LRU, ttl_seconds: Optional[int] = None):
        self.max_size = max_size
        self.strategy = strategy
        self.ttl_seconds = ttl_seconds
        self._cache = OrderedDict()
        self._access_counts = defaultdict(int)
        self._timestamps = {}
        self._lock = Lock()

    def get(self, key: str) -> Optional[Any]:
        """Get value from cache."""
        with self._lock:
            if key not in self._cache:
                return None

            # Check TTL
            if self.ttl_seconds and key in self._timestamps:
                if time.time() - self._timestamps[key] > self.ttl_seconds:
                    self._remove(key)
                    return None

            # Update access patterns
            if self.strategy == CacheStrategy.LRU:
                # Move to end (most recently used)
                self._cache.move_to_end(key)
            elif self.strategy == CacheStrategy.LFU:
                self._access_counts[key] += 1

            return self._cache[key]

    def put(self, key: str, value: Any) -> None:
        """Put value in cache."""
        with self._lock:
            # Remove if already exists
            if key in self._cache:
                self._remove(key)

            # Evict if necessary
            while len(self._cache) >= self.max_size:
                self._evict_one()

            # Add new entry
            self._cache[key] = value
            self._timestamps[key] = time.time()
            if self.strategy == CacheStrategy.LFU:
                self._access_counts[key] = 1

    def remove(self, key: str) -> bool:
        """Remove key from cache."""
        with self._lock:
            return self._remove(key)

    def clear(self) -> None:
        """Clear all cache entries."""
        with self._lock:
            self._cache.clear()
            self._access_counts.clear()
            self._timestamps.clear()

    def size(self) -> int:
        """Get current cache size."""
        return len(self._cache)

    def stats(self) -> Dict[str, Any]:
        """Get cache statistics."""
        with self._lock:
            return {
                "size": len(self._cache),
                "max_size": self.max_size,
                "strategy": self.strategy.value,
                "ttl_seconds": self.ttl_seconds,
                "keys": list(self._cache.keys())
            }

    def _remove(self, key: str) -> bool:
        """Remove key without lock (internal use)."""
        if key in self._cache:
            del self._cache[key]
            self._access_counts.pop(key, None)
            self._timestamps.pop(key, None)
            return True
        return False

    def _evict_one(self) -> None:
        """Evict one item based on strategy."""
        if not self._cache:
            return

        if self.strategy == CacheStrategy.LRU:
            # Remove least recently used (first item)
            key = next(iter(self._cache))
            self._remove(key)
        elif self.strategy == CacheStrategy.LFU:
            # Remove least frequently used
            if self._access_counts:
                key = min(self._access_counts.keys(), key=lambda k: self._access_counts[k])
                self._remove(key)
        elif self.strategy == CacheStrategy.FIFO:
            # Remove first in, first out
            key = next(iter(self._cache))
            self._remove(key)
        elif self.strategy == CacheStrategy.TTL:
            # Remove expired items first, then oldest
            current_time = time.time()
            expired_keys = [
                key for key, timestamp in self._timestamps.items()
                if current_time - timestamp > (self.ttl_seconds or 0)
            ]

            if expired_keys:
                self._remove(expired_keys[0])
            else:
                # Remove oldest
                key = min(self._timestamps.keys(), key=lambda k: self._timestamps[k])
                self._remove(key)


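# Usage sketch (illustrative): a small TTL cache for search responses.
#
#     cache = MemoryCache(max_size=500, strategy=CacheStrategy.TTL, ttl_seconds=60)
#     cache.put("search:kernel", ["result-a", "result-b"])
#     cache.get("search:kernel")   # hit until 60 seconds elapse, then None
#     cache.stats()                # size, strategy, TTL and current keys
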
class BatchProcessor:
    """Utility for processing items in batches with error handling."""

    def __init__(self, batch_size: int = 10, max_workers: Optional[int] = None):
        self.batch_size = batch_size
        # os.sched_getaffinity is only available on Linux/POSIX schedulers;
        # os.cpu_count() is the portable fallback if this needs to run elsewhere.
        self.max_workers = max_workers or min(32, (len(os.sched_getaffinity(0)) or 1) + 4)

    async def process_async(self,
                            items: List[Any],
                            processor: Callable[[Any], Any],
                            on_success: Optional[Callable[[Any, Any], None]] = None,
                            on_error: Optional[Callable[[Any, Exception], None]] = None) -> Dict[str, Any]:
        """
        Process items asynchronously in batches.

        Args:
            items: Items to process
            processor: Async function to process each item
            on_success: Callback for successful processing
            on_error: Callback for processing errors

        Returns:
            Processing results summary
        """
        results = {
            "success_count": 0,
            "error_count": 0,
            "total_items": len(items),
            "successful_items": [],
            "failed_items": [],
            "execution_time": 0
        }

        start_time = time.time()

        # Process in batches
        for i in range(0, len(items), self.batch_size):
            batch = items[i:i + self.batch_size]

            # Create tasks for this batch
            tasks = []
            for item in batch:
                task = asyncio.create_task(self._process_item_async(item, processor))
                tasks.append((item, task))

            # Wait for batch completion
            for item, task in tasks:
                try:
                    result = await task
                    results["success_count"] += 1
                    results["successful_items"].append(result)

                    if on_success:
                        on_success(item, result)

                except Exception as e:
                    results["error_count"] += 1
                    results["failed_items"].append({"item": item, "error": str(e)})

                    if on_error:
                        on_error(item, e)

        results["execution_time"] = time.time() - start_time
        return results

    async def _process_item_async(self, item: Any, processor: Callable) -> Any:
        """Process a single item asynchronously."""
        if asyncio.iscoroutinefunction(processor):
            return await processor(item)
        else:
            # Run synchronous processor in thread pool
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(None, processor, item)


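# Usage sketch (illustrative; `create_one` and the `contexts` list are
# hypothetical -- any coroutine or plain callable works as the processor):
#
#     results = await BatchProcessor(batch_size=20).process_async(
#         items=contexts,
#         processor=create_one,
#         on_error=lambda item, exc: print(f"failed: {exc}"),
#     )
#     # results["success_count"], results["failed_items"], results["execution_time"]
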
# Global instances
embedding_cache = MemoryCache(max_size=2000, strategy=CacheStrategy.LRU, ttl_seconds=3600)
batch_processor = BatchProcessor(batch_size=10)


def cache_key(*args, **kwargs) -> str:
    """Generate cache key from arguments."""
    key_parts = []

    # Add positional arguments
    for arg in args:
        if isinstance(arg, (str, int, float, bool)):
            key_parts.append(str(arg))
        else:
            key_parts.append(str(hash(str(arg))))

    # Add keyword arguments
    for k, v in sorted(kwargs.items()):
        if isinstance(v, (str, int, float, bool)):
            key_parts.append(f"{k}={v}")
        else:
            key_parts.append(f"{k}={hash(str(v))}")

    return ":".join(key_parts)


def timing_decorator(func):
    """Decorator to measure function execution time."""
    @wraps(func)
    async def async_wrapper(*args, **kwargs):
        start_time = time.time()
        try:
            result = await func(*args, **kwargs)
            return result
        finally:
            execution_time = time.time() - start_time
            # Could log or store timing data here
            pass

    @wraps(func)
    def sync_wrapper(*args, **kwargs):
        start_time = time.time()
        try:
            result = func(*args, **kwargs)
            return result
        finally:
            execution_time = time.time() - start_time
            # Could log or store timing data here
            pass

    if asyncio.iscoroutinefunction(func):
        return async_wrapper
    else:
        return sync_wrapper