Phase 2 build initial
48
hcfs-python/hcfs/sdk/__init__.py
Normal file
@@ -0,0 +1,48 @@
"""
HCFS Python Agent SDK

A comprehensive SDK for AI agents to interact with the HCFS API.
Provides high-level abstractions, caching, async support, and utilities.
"""

from .client import HCFSClient
from .async_client import HCFSAsyncClient
from .models import *
from .exceptions import *
from .utils import *
from .decorators import *

__version__ = "2.0.0"

__all__ = [
    # Core clients
    "HCFSClient",
    "HCFSAsyncClient",

    # Models and data structures
    "Context",
    "SearchResult",
    "ContextFilter",
    "PaginationOptions",
    "CacheConfig",
    "RetryConfig",

    # Exceptions
    "HCFSError",
    "HCFSConnectionError",
    "HCFSAuthenticationError",
    "HCFSNotFoundError",
    "HCFSValidationError",
    "HCFSRateLimitError",

    # Utilities
    "context_similarity",
    "batch_processor",
    "text_chunker",
    "embedding_cache",

    # Decorators
    "cached_context",
    "retry_on_failure",
    "rate_limited",
    "context_manager"
]
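The package root re-exports both clients plus the model, exception, utility, and decorator names, so agent code needs only one import. A minimal quick-start sketch against these exports (the endpoint URL and API key are placeholders, not values from this commit):

from hcfs.sdk import HCFSClient, Context

# Placeholder connection values; substitute real deployment settings.
with HCFSClient(base_url="https://api.hcfs.example.com", api_key="your-api-key") as client:
    print(client.health_check())
    ctx = client.create_context(Context(
        path="/docs/quickstart",
        content="Quick-start example content",
        summary="Quick-start doc"
    ))
    print(ctx.path)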
667
hcfs-python/hcfs/sdk/async_client.py
Normal file
@@ -0,0 +1,667 @@
"""
HCFS Asynchronous Client

High-level asynchronous client for HCFS API operations with WebSocket support.
"""

import asyncio
import json
import time
from typing import List, Optional, Dict, Any, AsyncIterator, Callable
from datetime import datetime

import httpx
import websockets
from websockets.exceptions import ConnectionClosed, WebSocketException

from .models import (
    Context, SearchResult, ContextFilter, PaginationOptions,
    SearchOptions, ClientConfig, AnalyticsData, BatchResult, StreamEvent
)
from .exceptions import (
    HCFSError, HCFSConnectionError, HCFSAuthenticationError,
    HCFSNotFoundError, HCFSValidationError, HCFSStreamError, handle_api_error
)
from .utils import MemoryCache, validate_path, normalize_path
from .decorators import cached_context, retry_on_failure, rate_limited


class HCFSAsyncClient:
    """
    Asynchronous HCFS API client with WebSocket streaming capabilities.

    This client provides async/await support for all operations and includes
    real-time streaming capabilities through WebSocket connections.

    Example:
        >>> import asyncio
        >>> from hcfs.sdk import HCFSAsyncClient, Context
        >>>
        >>> async def main():
        ...     async with HCFSAsyncClient(
        ...         base_url="https://api.hcfs.example.com",
        ...         api_key="your-api-key"
        ...     ) as client:
        ...         # Create a context
        ...         context = Context(
        ...             path="/docs/async_readme",
        ...             content="Async README content",
        ...             summary="Async documentation"
        ...         )
        ...         created = await client.create_context(context)
        ...
        ...         # Search with async (results come back as a list)
        ...         results = await client.search_contexts("async README")
        ...         for result in results:
        ...             print(f"Found: {result.context.path}")
        >>>
        >>> asyncio.run(main())
    """

    def __init__(self, config: Optional[ClientConfig] = None, **kwargs):
        """
        Initialize async HCFS client.

        Args:
            config: Client configuration object
            **kwargs: Configuration overrides
        """
        # Merge configuration
        if config:
            self.config = config
        else:
            self.config = ClientConfig(**kwargs)

        # HTTP client will be initialized in __aenter__
        self.http_client: Optional[httpx.AsyncClient] = None
        self.websocket: Optional[websockets.WebSocketClientProtocol] = None
        self._websocket_listeners: List[Callable[[StreamEvent], None]] = []
        self._websocket_task: Optional[asyncio.Task] = None

        # Initialize cache
        self._cache = MemoryCache(
            max_size=self.config.cache.max_size,
            strategy=self.config.cache.strategy,
            ttl_seconds=self.config.cache.ttl_seconds
        ) if self.config.cache.enabled else None

        # Analytics
        self.analytics = AnalyticsData()

    async def __aenter__(self):
        """Async context manager entry."""
        await self._initialize_http_client()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.close()

    async def _initialize_http_client(self):
        """Initialize the HTTP client with proper configuration."""
        headers = {
            "User-Agent": self.config.user_agent,
            "Content-Type": "application/json"
        }

        if self.config.api_key:
            headers["X-API-Key"] = self.config.api_key
        elif self.config.jwt_token:
            headers["Authorization"] = f"Bearer {self.config.jwt_token}"

        # Configure timeouts
        timeout = httpx.Timeout(
            connect=self.config.timeout,
            read=self.config.timeout,
            write=self.config.timeout,
            pool=self.config.timeout * 2
        )

        # Configure connection limits
        limits = httpx.Limits(
            max_connections=self.config.max_connections,
            max_keepalive_connections=self.config.max_keepalive_connections
        )

        self.http_client = httpx.AsyncClient(
            base_url=self.config.base_url,
            headers=headers,
            timeout=timeout,
            limits=limits,
            follow_redirects=True
        )

    async def health_check(self) -> Dict[str, Any]:
        """
        Check API health status asynchronously.

        Returns:
            Health status information

        Raises:
            HCFSConnectionError: If health check fails
        """
        try:
            response = await self.http_client.get("/health")

            if response.status_code == 200:
                self._update_analytics("health_check", success=True)
                return response.json()
            else:
                self._update_analytics("health_check", success=False)
                handle_api_error(response)

        except httpx.RequestError as e:
            self._update_analytics("health_check", success=False, error=str(e))
            raise HCFSConnectionError(f"Health check failed: {str(e)}")

    @cached_context()
    @retry_on_failure()
    async def create_context(self, context: Context) -> Context:
        """
        Create a new context asynchronously.

        Args:
            context: Context object to create

        Returns:
            Created context with assigned ID

        Raises:
            HCFSValidationError: If context data is invalid
            HCFSError: If creation fails
        """
        if not validate_path(context.path):
            raise HCFSValidationError(f"Invalid context path: {context.path}")

        context.path = normalize_path(context.path)

        try:
            response = await self.http_client.post(
                "/api/v1/contexts",
                json=context.to_create_dict()
            )

            if response.status_code == 200:
                data = response.json()["data"]
                created_context = Context(**data)
                self._update_analytics("create_context", success=True)
                return created_context
            else:
                self._update_analytics("create_context", success=False)
                handle_api_error(response)

        except httpx.RequestError as e:
            self._update_analytics("create_context", success=False, error=str(e))
            raise HCFSConnectionError(f"Failed to create context: {str(e)}")

    @cached_context()
    async def get_context(self, context_id: int) -> Context:
        """
        Retrieve a context by ID asynchronously.

        Args:
            context_id: Context identifier

        Returns:
            Context object

        Raises:
            HCFSNotFoundError: If context doesn't exist
        """
        try:
            response = await self.http_client.get(f"/api/v1/contexts/{context_id}")

            if response.status_code == 200:
                data = response.json()["data"]
                context = Context(**data)
                self._update_analytics("get_context", success=True)
                return context
            else:
                self._update_analytics("get_context", success=False)
                handle_api_error(response)

        except httpx.RequestError as e:
            self._update_analytics("get_context", success=False, error=str(e))
            raise HCFSConnectionError(f"Failed to get context: {str(e)}")

    async def list_contexts(self,
                            filter_opts: Optional[ContextFilter] = None,
                            pagination: Optional[PaginationOptions] = None) -> List[Context]:
        """
        List contexts with filtering and pagination asynchronously.

        Args:
            filter_opts: Context filtering options
            pagination: Pagination configuration

        Returns:
            List of contexts
        """
        params = {}

        if filter_opts:
            params.update(filter_opts.to_query_params())

        if pagination:
            params.update(pagination.to_query_params())

        try:
            response = await self.http_client.get("/api/v1/contexts", params=params)

            if response.status_code == 200:
                data = response.json()["data"]
                contexts = [Context(**ctx_data) for ctx_data in data]
                self._update_analytics("list_contexts", success=True)
                return contexts
            else:
                self._update_analytics("list_contexts", success=False)
                handle_api_error(response)

        except httpx.RequestError as e:
            self._update_analytics("list_contexts", success=False, error=str(e))
            raise HCFSConnectionError(f"Failed to list contexts: {str(e)}")

    async def update_context(self, context_id: int, updates: Dict[str, Any]) -> Context:
        """
        Update an existing context asynchronously.

        Args:
            context_id: Context identifier
            updates: Fields to update

        Returns:
            Updated context

        Raises:
            HCFSNotFoundError: If context doesn't exist
            HCFSValidationError: If update data is invalid
        """
        try:
            response = await self.http_client.put(
                f"/api/v1/contexts/{context_id}",
                json=updates
            )

            if response.status_code == 200:
                data = response.json()["data"]
                updated_context = Context(**data)
                self._update_analytics("update_context", success=True)

                # Invalidate cache
                if self._cache:
                    cache_key = f"get_context:{context_id}"
                    self._cache.remove(cache_key)

                return updated_context
            else:
                self._update_analytics("update_context", success=False)
                handle_api_error(response)

        except httpx.RequestError as e:
            self._update_analytics("update_context", success=False, error=str(e))
            raise HCFSConnectionError(f"Failed to update context: {str(e)}")

    async def delete_context(self, context_id: int) -> bool:
        """
        Delete a context asynchronously.

        Args:
            context_id: Context identifier

        Returns:
            True if deletion was successful

        Raises:
            HCFSNotFoundError: If context doesn't exist
        """
        try:
            response = await self.http_client.delete(f"/api/v1/contexts/{context_id}")

            if response.status_code == 200:
                self._update_analytics("delete_context", success=True)

                # Invalidate cache
                if self._cache:
                    cache_key = f"get_context:{context_id}"
                    self._cache.remove(cache_key)

                return True
            else:
                self._update_analytics("delete_context", success=False)
                handle_api_error(response)

        except httpx.RequestError as e:
            self._update_analytics("delete_context", success=False, error=str(e))
            raise HCFSConnectionError(f"Failed to delete context: {str(e)}")

    @rate_limited(requests_per_second=10.0)
    async def search_contexts(self,
                              query: str,
                              options: Optional[SearchOptions] = None) -> List[SearchResult]:
        """
        Search contexts asynchronously using various search methods.

        Args:
            query: Search query string
            options: Search configuration options

        Returns:
            List of search results ordered by relevance
        """
        search_opts = options or SearchOptions()

        request_data = {
            "query": query,
            **search_opts.to_request_dict()
        }

        try:
            response = await self.http_client.post(
                "/api/v1/search",
                json=request_data
            )

            if response.status_code == 200:
                data = response.json()["data"]
                results = []

                for result_data in data:
                    context = Context(**result_data["context"])
                    search_result = SearchResult(
                        context=context,
                        score=result_data["score"],
                        explanation=result_data.get("explanation"),
                        highlights=result_data.get("highlights", [])
                    )
                    results.append(search_result)

                self._update_analytics("search_contexts", success=True)
                return sorted(results, key=lambda x: x.score, reverse=True)
            else:
                self._update_analytics("search_contexts", success=False)
                handle_api_error(response)

        except httpx.RequestError as e:
            self._update_analytics("search_contexts", success=False, error=str(e))
            raise HCFSConnectionError(f"Search failed: {str(e)}")

    async def batch_create_contexts(self, contexts: List[Context]) -> BatchResult:
        """
        Create multiple contexts in a single batch operation asynchronously.

        Args:
            contexts: List of contexts to create

        Returns:
            Batch operation results
        """
        request_data = {
            "contexts": [ctx.to_create_dict() for ctx in contexts]
        }

        start_time = time.time()

        try:
            response = await self.http_client.post(
                "/api/v1/contexts/batch",
                json=request_data,
                timeout=self.config.timeout * 3  # Extended timeout for batch ops
            )

            execution_time = time.time() - start_time

            if response.status_code == 200:
                data = response.json()["data"]

                result = BatchResult(
                    success_count=data["success_count"],
                    error_count=data["error_count"],
                    total_items=data["total_items"],
                    successful_items=data.get("created_ids", []),
                    failed_items=data.get("errors", []),
                    execution_time=execution_time
                )

                self._update_analytics("batch_create", success=True)
                return result
            else:
                self._update_analytics("batch_create", success=False)
                handle_api_error(response)

        except httpx.RequestError as e:
            execution_time = time.time() - start_time
            self._update_analytics("batch_create", success=False, error=str(e))

            return BatchResult(
                success_count=0,
                error_count=len(contexts),
                total_items=len(contexts),
                successful_items=[],
                failed_items=[{"error": str(e)}],
                execution_time=execution_time
            )

    async def get_statistics(self) -> Dict[str, Any]:
        """
        Get comprehensive system statistics asynchronously.

        Returns:
            System statistics and metrics
        """
        try:
            response = await self.http_client.get("/api/v1/stats")

            if response.status_code == 200:
                self._update_analytics("get_statistics", success=True)
                return response.json()
            else:
                self._update_analytics("get_statistics", success=False)
                handle_api_error(response)

        except httpx.RequestError as e:
            self._update_analytics("get_statistics", success=False, error=str(e))
            raise HCFSConnectionError(f"Failed to get statistics: {str(e)}")

    async def iterate_contexts(self,
                               filter_opts: Optional[ContextFilter] = None,
                               page_size: int = 100) -> AsyncIterator[Context]:
        """
        Asynchronously iterate through all contexts with automatic pagination.

        Args:
            filter_opts: Context filtering options
            page_size: Number of contexts per page

        Yields:
            Context objects
        """
        page = 1

        while True:
            pagination = PaginationOptions(page=page, page_size=page_size)
            contexts = await self.list_contexts(filter_opts, pagination)

            if not contexts:
                break

            for context in contexts:
                yield context

            # If we got fewer contexts than requested, we've reached the end
            if len(contexts) < page_size:
                break

            page += 1

    async def connect_websocket(self,
                                path_prefix: Optional[str] = None,
                                event_types: Optional[List[str]] = None) -> None:
        """
        Connect to WebSocket for real-time updates.

        Args:
            path_prefix: Filter events by path prefix
            event_types: List of event types to subscribe to

        Raises:
            HCFSStreamError: If WebSocket connection fails
        """
        if self.websocket and not self.websocket.closed:
            return  # Already connected

        # Convert HTTP URL to WebSocket URL
        ws_url = self.config.base_url.replace("http://", "ws://").replace("https://", "wss://")
        ws_url += "/ws"

        # Add authentication headers
        headers = {}
        if self.config.api_key:
            headers["X-API-Key"] = self.config.api_key
        elif self.config.jwt_token:
            headers["Authorization"] = f"Bearer {self.config.jwt_token}"

        try:
            self.websocket = await websockets.connect(
                ws_url,
                extra_headers=headers,
                ping_interval=self.config.websocket.ping_interval,
                ping_timeout=self.config.websocket.ping_timeout
            )

            # Send subscription request
            subscription = {
                "type": "subscribe",
                "data": {
                    "path_prefix": path_prefix,
                    "event_types": event_types or ["created", "updated", "deleted"],
                    "filters": {}
                }
            }

            await self.websocket.send(json.dumps(subscription))

            # Start listening task
            self._websocket_task = asyncio.create_task(self._websocket_listener())

        except (WebSocketException, ConnectionClosed) as e:
            raise HCFSStreamError(f"Failed to connect to WebSocket: {str(e)}")

    async def disconnect_websocket(self) -> None:
        """Disconnect from WebSocket."""
        if self._websocket_task:
            self._websocket_task.cancel()
            try:
                await self._websocket_task
            except asyncio.CancelledError:
                pass
            self._websocket_task = None

        if self.websocket:
            await self.websocket.close()
            self.websocket = None

    def add_event_listener(self, listener: Callable[[StreamEvent], None]) -> None:
        """
        Add an event listener for WebSocket events.

        Args:
            listener: Function to call when events are received
        """
        self._websocket_listeners.append(listener)

    def remove_event_listener(self, listener: Callable[[StreamEvent], None]) -> None:
        """
        Remove an event listener.

        Args:
            listener: Function to remove
        """
        if listener in self._websocket_listeners:
            self._websocket_listeners.remove(listener)

    async def _websocket_listener(self) -> None:
        """Internal WebSocket message listener."""
        try:
            async for message in self.websocket:
                try:
                    data = json.loads(message)
                    event = StreamEvent(
                        event_type=data.get("type", "unknown"),
                        data=data.get("data", {}),
                        timestamp=datetime.fromisoformat(data.get("timestamp", datetime.utcnow().isoformat())),
                        context_id=data.get("context_id"),
                        path=data.get("path")
                    )

                    # Notify all listeners
                    for listener in self._websocket_listeners:
                        try:
                            if asyncio.iscoroutinefunction(listener):
                                await listener(event)
                            else:
                                listener(event)
                        except Exception:
                            pass  # Don't let listener errors break the connection

                except json.JSONDecodeError:
                    pass  # Ignore malformed messages

        except (WebSocketException, ConnectionClosed):
            # Connection was closed, attempt reconnection if configured
            if self.config.websocket.auto_reconnect:
                await self._attempt_websocket_reconnection()

    async def _attempt_websocket_reconnection(self) -> None:
        """Attempt to reconnect WebSocket with backoff."""
        for attempt in range(self.config.websocket.max_reconnect_attempts):
            try:
                await asyncio.sleep(self.config.websocket.reconnect_interval)
                await self.connect_websocket()
                return  # Successfully reconnected
            except Exception:
                continue  # Try again

        # All reconnection attempts failed
        raise HCFSStreamError("Failed to reconnect WebSocket after multiple attempts")

    def clear_cache(self) -> None:
        """Clear all cached data."""
        if self._cache:
            self._cache.clear()

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache statistics."""
        if self._cache:
            stats = self._cache.stats()
            self.analytics.cache_stats = stats
            return stats
        return {}

    def get_analytics(self) -> AnalyticsData:
        """
        Get client analytics and usage statistics.

        Returns:
            Analytics data including operation counts and performance metrics
        """
        # Update cache stats
        if self._cache:
            self.analytics.cache_stats = self._cache.stats()

        return self.analytics

    def _update_analytics(self, operation: str, success: bool, error: Optional[str] = None):
        """Update internal analytics tracking."""
        self.analytics.operation_count[operation] = self.analytics.operation_count.get(operation, 0) + 1

        if not success:
            error_key = error or "unknown_error"
            self.analytics.error_stats[error_key] = self.analytics.error_stats.get(error_key, 0) + 1

    async def close(self):
        """Close the client and cleanup resources."""
        await self.disconnect_websocket()

        if self.http_client:
            await self.http_client.aclose()
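Putting the streaming pieces together: a hedged sketch of subscribing to real-time events alongside normal async calls (not part of the committed files; the endpoint and key are placeholders, and the handler just prints events):

import asyncio
from hcfs.sdk import HCFSAsyncClient

async def watch_docs():
    async with HCFSAsyncClient(base_url="https://api.hcfs.example.com",
                               api_key="your-api-key") as client:
        # Listeners may be sync or async; errors inside them are swallowed by the listener loop.
        client.add_event_listener(lambda event: print(event.event_type, event.path))
        await client.connect_websocket(path_prefix="/docs",
                                       event_types=["created", "updated"])

        # Regular API calls keep working while the WebSocket task runs in the background.
        async for ctx in client.iterate_contexts(page_size=50):
            print(ctx.path)

        await asyncio.sleep(30)  # stay subscribed for a while
    # __aexit__ -> close() disconnects the WebSocket and closes the HTTP client.

asyncio.run(watch_docs())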
539
hcfs-python/hcfs/sdk/client.py
Normal file
@@ -0,0 +1,539 @@
"""
HCFS Synchronous Client

High-level synchronous client for HCFS API operations.
"""

import json
import time
from typing import List, Optional, Dict, Any, Iterator
from datetime import datetime

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from .models import (
    Context, SearchResult, ContextFilter, PaginationOptions,
    SearchOptions, ClientConfig, AnalyticsData, BatchResult
)
from .exceptions import (
    HCFSError, HCFSConnectionError, HCFSAuthenticationError,
    HCFSNotFoundError, HCFSValidationError, handle_api_error
)
from .utils import MemoryCache, validate_path, normalize_path
from .decorators import cached_context, retry_on_failure, rate_limited


class HCFSClient:
    """
    Synchronous HCFS API client with caching and retry capabilities.

    This client provides a high-level interface for interacting with the HCFS API,
    including context management, search operations, and batch processing.

    Example:
        >>> from hcfs.sdk import HCFSClient, Context
        >>>
        >>> # Initialize client
        >>> client = HCFSClient(
        ...     base_url="https://api.hcfs.example.com",
        ...     api_key="your-api-key"
        ... )
        >>>
        >>> # Create a context
        >>> context = Context(
        ...     path="/docs/readme",
        ...     content="This is a README file",
        ...     summary="Project documentation"
        ... )
        >>> created = client.create_context(context)
        >>>
        >>> # Search contexts
        >>> results = client.search_contexts("README documentation")
        >>> for result in results:
        ...     print(f"Found: {result.context.path} (score: {result.score})")
    """

    def __init__(self, config: Optional[ClientConfig] = None, **kwargs):
        """
        Initialize HCFS client.

        Args:
            config: Client configuration object
            **kwargs: Configuration overrides (base_url, api_key, etc.)
        """
        # Merge configuration
        if config:
            self.config = config
        else:
            self.config = ClientConfig(**kwargs)

        # Initialize session with retry strategy
        self.session = requests.Session()

        # Configure retries
        retry_strategy = Retry(
            total=self.config.retry.max_attempts if self.config.retry.enabled else 0,
            status_forcelist=[429, 500, 502, 503, 504],
            backoff_factor=self.config.retry.base_delay,
            raise_on_status=False
        )

        adapter = HTTPAdapter(
            max_retries=retry_strategy,
            pool_connections=self.config.max_connections,
            pool_maxsize=self.config.max_keepalive_connections
        )

        self.session.mount("http://", adapter)
        self.session.mount("https://", adapter)

        # Set headers
        self.session.headers.update({
            "User-Agent": self.config.user_agent,
            "Content-Type": "application/json"
        })

        if self.config.api_key:
            self.session.headers["X-API-Key"] = self.config.api_key
        elif self.config.jwt_token:
            self.session.headers["Authorization"] = f"Bearer {self.config.jwt_token}"

        # Initialize cache
        self._cache = MemoryCache(
            max_size=self.config.cache.max_size,
            strategy=self.config.cache.strategy,
            ttl_seconds=self.config.cache.ttl_seconds
        ) if self.config.cache.enabled else None

        # Analytics
        self.analytics = AnalyticsData()

    def health_check(self) -> Dict[str, Any]:
        """
        Check API health status.

        Returns:
            Health status information

        Raises:
            HCFSConnectionError: If health check fails
        """
        try:
            response = self.session.get(
                f"{self.config.base_url}/health",
                timeout=self.config.timeout
            )

            if response.status_code == 200:
                self._update_analytics("health_check", success=True)
                return response.json()
            else:
                self._update_analytics("health_check", success=False)
                handle_api_error(response)

        except requests.exceptions.RequestException as e:
            self._update_analytics("health_check", success=False, error=str(e))
            raise HCFSConnectionError(f"Health check failed: {str(e)}")

    @cached_context()
    @retry_on_failure()
    def create_context(self, context: Context) -> Context:
        """
        Create a new context.

        Args:
            context: Context object to create

        Returns:
            Created context with assigned ID

        Raises:
            HCFSValidationError: If context data is invalid
            HCFSError: If creation fails
        """
        if not validate_path(context.path):
            raise HCFSValidationError(f"Invalid context path: {context.path}")

        context.path = normalize_path(context.path)

        try:
            response = self.session.post(
                f"{self.config.base_url}/api/v1/contexts",
                json=context.to_create_dict(),
                timeout=self.config.timeout
            )

            if response.status_code == 200:
                data = response.json()["data"]
                created_context = Context(**data)
                self._update_analytics("create_context", success=True)
                return created_context
            else:
                self._update_analytics("create_context", success=False)
                handle_api_error(response)

        except requests.exceptions.RequestException as e:
            self._update_analytics("create_context", success=False, error=str(e))
            raise HCFSConnectionError(f"Failed to create context: {str(e)}")

    @cached_context()
    def get_context(self, context_id: int) -> Context:
        """
        Retrieve a context by ID.

        Args:
            context_id: Context identifier

        Returns:
            Context object

        Raises:
            HCFSNotFoundError: If context doesn't exist
        """
        try:
            response = self.session.get(
                f"{self.config.base_url}/api/v1/contexts/{context_id}",
                timeout=self.config.timeout
            )

            if response.status_code == 200:
                data = response.json()["data"]
                context = Context(**data)
                self._update_analytics("get_context", success=True)
                return context
            else:
                self._update_analytics("get_context", success=False)
                handle_api_error(response)

        except requests.exceptions.RequestException as e:
            self._update_analytics("get_context", success=False, error=str(e))
            raise HCFSConnectionError(f"Failed to get context: {str(e)}")

    def list_contexts(self,
                      filter_opts: Optional[ContextFilter] = None,
                      pagination: Optional[PaginationOptions] = None) -> List[Context]:
        """
        List contexts with filtering and pagination.

        Args:
            filter_opts: Context filtering options
            pagination: Pagination configuration

        Returns:
            List of contexts
        """
        params = {}

        if filter_opts:
            params.update(filter_opts.to_query_params())

        if pagination:
            params.update(pagination.to_query_params())

        try:
            response = self.session.get(
                f"{self.config.base_url}/api/v1/contexts",
                params=params,
                timeout=self.config.timeout
            )

            if response.status_code == 200:
                data = response.json()["data"]
                contexts = [Context(**ctx_data) for ctx_data in data]
                self._update_analytics("list_contexts", success=True)
                return contexts
            else:
                self._update_analytics("list_contexts", success=False)
                handle_api_error(response)

        except requests.exceptions.RequestException as e:
            self._update_analytics("list_contexts", success=False, error=str(e))
            raise HCFSConnectionError(f"Failed to list contexts: {str(e)}")

    def update_context(self, context_id: int, updates: Dict[str, Any]) -> Context:
        """
        Update an existing context.

        Args:
            context_id: Context identifier
            updates: Fields to update

        Returns:
            Updated context

        Raises:
            HCFSNotFoundError: If context doesn't exist
            HCFSValidationError: If update data is invalid
        """
        try:
            response = self.session.put(
                f"{self.config.base_url}/api/v1/contexts/{context_id}",
                json=updates,
                timeout=self.config.timeout
            )

            if response.status_code == 200:
                data = response.json()["data"]
                updated_context = Context(**data)
                self._update_analytics("update_context", success=True)

                # Invalidate cache
                if self._cache:
                    cache_key = f"get_context:{context_id}"
                    self._cache.remove(cache_key)

                return updated_context
            else:
                self._update_analytics("update_context", success=False)
                handle_api_error(response)

        except requests.exceptions.RequestException as e:
            self._update_analytics("update_context", success=False, error=str(e))
            raise HCFSConnectionError(f"Failed to update context: {str(e)}")

    def delete_context(self, context_id: int) -> bool:
        """
        Delete a context.

        Args:
            context_id: Context identifier

        Returns:
            True if deletion was successful

        Raises:
            HCFSNotFoundError: If context doesn't exist
        """
        try:
            response = self.session.delete(
                f"{self.config.base_url}/api/v1/contexts/{context_id}",
                timeout=self.config.timeout
            )

            if response.status_code == 200:
                self._update_analytics("delete_context", success=True)

                # Invalidate cache
                if self._cache:
                    cache_key = f"get_context:{context_id}"
                    self._cache.remove(cache_key)

                return True
            else:
                self._update_analytics("delete_context", success=False)
                handle_api_error(response)

        except requests.exceptions.RequestException as e:
            self._update_analytics("delete_context", success=False, error=str(e))
            raise HCFSConnectionError(f"Failed to delete context: {str(e)}")

    @rate_limited(requests_per_second=10.0)
    def search_contexts(self,
                        query: str,
                        options: Optional[SearchOptions] = None) -> List[SearchResult]:
        """
        Search contexts using various search methods.

        Args:
            query: Search query string
            options: Search configuration options

        Returns:
            List of search results ordered by relevance
        """
        search_opts = options or SearchOptions()

        request_data = {
            "query": query,
            **search_opts.to_request_dict()
        }

        try:
            response = self.session.post(
                f"{self.config.base_url}/api/v1/search",
                json=request_data,
                timeout=self.config.timeout
            )

            if response.status_code == 200:
                data = response.json()["data"]
                results = []

                for result_data in data:
                    context = Context(**result_data["context"])
                    search_result = SearchResult(
                        context=context,
                        score=result_data["score"],
                        explanation=result_data.get("explanation"),
                        highlights=result_data.get("highlights", [])
                    )
                    results.append(search_result)

                self._update_analytics("search_contexts", success=True)
                return sorted(results, key=lambda x: x.score, reverse=True)
            else:
                self._update_analytics("search_contexts", success=False)
                handle_api_error(response)

        except requests.exceptions.RequestException as e:
            self._update_analytics("search_contexts", success=False, error=str(e))
            raise HCFSConnectionError(f"Search failed: {str(e)}")

    def batch_create_contexts(self, contexts: List[Context]) -> BatchResult:
        """
        Create multiple contexts in a single batch operation.

        Args:
            contexts: List of contexts to create

        Returns:
            Batch operation results
        """
        request_data = {
            "contexts": [ctx.to_create_dict() for ctx in contexts]
        }

        start_time = time.time()

        try:
            response = self.session.post(
                f"{self.config.base_url}/api/v1/contexts/batch",
                json=request_data,
                timeout=self.config.timeout * 3  # Extended timeout for batch ops
            )

            execution_time = time.time() - start_time

            if response.status_code == 200:
                data = response.json()["data"]

                result = BatchResult(
                    success_count=data["success_count"],
                    error_count=data["error_count"],
                    total_items=data["total_items"],
                    successful_items=data.get("created_ids", []),
                    failed_items=data.get("errors", []),
                    execution_time=execution_time
                )

                self._update_analytics("batch_create", success=True)
                return result
            else:
                self._update_analytics("batch_create", success=False)
                handle_api_error(response)

        except requests.exceptions.RequestException as e:
            execution_time = time.time() - start_time
            self._update_analytics("batch_create", success=False, error=str(e))

            return BatchResult(
                success_count=0,
                error_count=len(contexts),
                total_items=len(contexts),
                successful_items=[],
                failed_items=[{"error": str(e)}],
                execution_time=execution_time
            )

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get comprehensive system statistics.

        Returns:
            System statistics and metrics
        """
        try:
            response = self.session.get(
                f"{self.config.base_url}/api/v1/stats",
                timeout=self.config.timeout
            )

            if response.status_code == 200:
                self._update_analytics("get_statistics", success=True)
                return response.json()
            else:
                self._update_analytics("get_statistics", success=False)
                handle_api_error(response)

        except requests.exceptions.RequestException as e:
            self._update_analytics("get_statistics", success=False, error=str(e))
            raise HCFSConnectionError(f"Failed to get statistics: {str(e)}")

    def iterate_contexts(self,
                         filter_opts: Optional[ContextFilter] = None,
                         page_size: int = 100) -> Iterator[Context]:
        """
        Iterate through all contexts with automatic pagination.

        Args:
            filter_opts: Context filtering options
            page_size: Number of contexts per page

        Yields:
            Context objects
        """
        page = 1

        while True:
            pagination = PaginationOptions(page=page, page_size=page_size)
            contexts = self.list_contexts(filter_opts, pagination)

            if not contexts:
                break

            for context in contexts:
                yield context

            # If we got fewer contexts than requested, we've reached the end
            if len(contexts) < page_size:
                break

            page += 1

    def clear_cache(self) -> None:
        """Clear all cached data."""
        if self._cache:
            self._cache.clear()

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache statistics."""
        if self._cache:
            stats = self._cache.stats()
            self.analytics.cache_stats = stats
            return stats
        return {}

    def get_analytics(self) -> AnalyticsData:
        """
        Get client analytics and usage statistics.

        Returns:
            Analytics data including operation counts and performance metrics
        """
        # Update cache stats
        if self._cache:
            self.analytics.cache_stats = self._cache.stats()

        return self.analytics

    def _update_analytics(self, operation: str, success: bool, error: Optional[str] = None):
        """Update internal analytics tracking."""
        self.analytics.operation_count[operation] = self.analytics.operation_count.get(operation, 0) + 1

        if not success:
            error_key = error or "unknown_error"
            self.analytics.error_stats[error_key] = self.analytics.error_stats.get(error_key, 0) + 1

    def close(self):
        """Close the client and cleanup resources."""
        self.session.close()

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        self.close()
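A companion sketch for the synchronous client showing batch creation, automatic pagination, and the built-in analytics counters (again with placeholder connection values, not taken from the commit):

from hcfs.sdk import HCFSClient, Context

with HCFSClient(base_url="https://api.hcfs.example.com", api_key="your-api-key") as client:
    drafts = [Context(path=f"/notes/item_{i}", content=f"note {i}", summary="batch demo")
              for i in range(25)]
    result = client.batch_create_contexts(drafts)
    print(result.success_count, result.error_count, round(result.execution_time, 2))

    # iterate_contexts pages through /api/v1/contexts transparently.
    for ctx in client.iterate_contexts(page_size=10):
        print(ctx.path)

    print(client.get_analytics().operation_count)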
472
hcfs-python/hcfs/sdk/decorators.py
Normal file
@@ -0,0 +1,472 @@
"""
HCFS SDK Decorators

Decorators for caching, retry logic, rate limiting, and context management.
"""

import asyncio
import time
import random
from functools import wraps
from typing import Optional, Dict, Any, Callable, List
from datetime import datetime, timedelta

from .models import RetryConfig, RetryStrategy, CacheConfig
from .exceptions import HCFSError, HCFSRateLimitError, HCFSTimeoutError
from .utils import MemoryCache, cache_key


def cached_context(cache_config: Optional[CacheConfig] = None, key_func: Optional[Callable] = None):
    """
    Decorator to cache context-related operations.

    Args:
        cache_config: Cache configuration
        key_func: Custom function to generate cache keys
    """
    config = cache_config or CacheConfig()
    cache = MemoryCache(
        max_size=config.max_size,
        strategy=config.strategy,
        ttl_seconds=config.ttl_seconds
    )

    def decorator(func):
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            if not config.enabled:
                return await func(*args, **kwargs)

            # Generate cache key
            if key_func:
                key = key_func(*args, **kwargs)
            else:
                key = cache_key(func.__name__, *args, **kwargs)

            # Try to get from cache
            cached_result = cache.get(key)
            if cached_result is not None:
                return cached_result

            # Execute function and cache result
            result = await func(*args, **kwargs)
            cache.put(key, result)
            return result

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            if not config.enabled:
                return func(*args, **kwargs)

            # Generate cache key
            if key_func:
                key = key_func(*args, **kwargs)
            else:
                key = cache_key(func.__name__, *args, **kwargs)

            # Try to get from cache
            cached_result = cache.get(key)
            if cached_result is not None:
                return cached_result

            # Execute function and cache result
            result = func(*args, **kwargs)
            cache.put(key, result)
            return result

        # Attach cache management methods
        if asyncio.iscoroutinefunction(func):
            async_wrapper.cache = cache
            async_wrapper.clear_cache = cache.clear
            async_wrapper.cache_stats = cache.stats
            return async_wrapper
        else:
            sync_wrapper.cache = cache
            sync_wrapper.clear_cache = cache.clear
            sync_wrapper.cache_stats = cache.stats
            return sync_wrapper

    return decorator


def retry_on_failure(retry_config: Optional[RetryConfig] = None):
    """
    Decorator to retry failed operations with configurable strategies.

    Args:
        retry_config: Retry configuration
    """
    config = retry_config or RetryConfig()

    def decorator(func):
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            if not config.enabled:
                return await func(*args, **kwargs)

            last_exception = None

            for attempt in range(config.max_attempts):
                try:
                    return await func(*args, **kwargs)

                except Exception as e:
                    last_exception = e

                    # Check if we should retry this exception
                    if not _should_retry_exception(e, config):
                        raise e

                    # Don't delay on the last attempt
                    if attempt < config.max_attempts - 1:
                        delay = _calculate_delay(attempt, config)
                        await asyncio.sleep(delay)

            # All attempts failed, raise the last exception
            raise last_exception

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            if not config.enabled:
                return func(*args, **kwargs)

            last_exception = None

            for attempt in range(config.max_attempts):
                try:
                    return func(*args, **kwargs)

                except Exception as e:
                    last_exception = e

                    # Check if we should retry this exception
                    if not _should_retry_exception(e, config):
                        raise e

                    # Don't delay on the last attempt
                    if attempt < config.max_attempts - 1:
                        delay = _calculate_delay(attempt, config)
                        time.sleep(delay)

            # All attempts failed, raise the last exception
            raise last_exception

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        else:
            return sync_wrapper

    return decorator


def _should_retry_exception(exception: Exception, config: RetryConfig) -> bool:
    """Check if an exception should trigger a retry."""
    # Check for timeout errors
    if isinstance(exception, HCFSTimeoutError) and config.retry_on_timeout:
        return True

    # Check for rate limit errors
    if isinstance(exception, HCFSRateLimitError):
        return True

    # Check for HTTP status codes (if it's an HTTP-related error)
    if hasattr(exception, 'status_code'):
        return exception.status_code in config.retry_on_status

    return False


def _calculate_delay(attempt: int, config: RetryConfig) -> float:
    """Calculate delay for retry attempt."""
    if config.strategy == RetryStrategy.EXPONENTIAL_BACKOFF:
        delay = config.base_delay * (config.backoff_multiplier ** attempt)
    elif config.strategy == RetryStrategy.LINEAR_BACKOFF:
        delay = config.base_delay + (config.base_delay * attempt)
    elif config.strategy == RetryStrategy.FIBONACCI:
        delay = config.base_delay * _fibonacci(attempt + 1)
    else:  # CONSTANT_DELAY
        delay = config.base_delay

    # Apply maximum delay limit
    delay = min(delay, config.max_delay)

    # Add jitter if enabled
    if config.jitter:
        jitter_range = delay * 0.1  # 10% jitter
        delay += random.uniform(-jitter_range, jitter_range)

    return max(0, delay)


def _fibonacci(n: int) -> int:
    """Calculate nth Fibonacci number."""
    if n <= 1:
        return n
    a, b = 0, 1
    for _ in range(2, n + 1):
        a, b = b, a + b
    return b

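# Worked example of the delay schedule (illustrative values, not from the commit):
# with base_delay=0.5, backoff_multiplier=2.0, max_delay=10.0 and jitter disabled,
# EXPONENTIAL_BACKOFF yields 0.5s, 1.0s, 2.0s, 4.0s for attempts 0, 1, 2, 3;
# LINEAR_BACKOFF yields 0.5s, 1.0s, 1.5s, 2.0s; FIBONACCI yields
# base_delay * fib(attempt + 1) = 0.5s, 0.5s, 1.0s, 1.5s. Every value is then
# capped at max_delay, and with jitter enabled a uniform +/-10% offset is added.
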

class RateLimiter:
    """Token bucket rate limiter."""

    def __init__(self, rate: float, burst: int = 1):
        self.rate = rate    # tokens per second
        self.burst = burst  # maximum tokens in bucket
        self.tokens = burst
        self.last_update = time.time()

    def acquire(self, tokens: int = 1) -> bool:
        """Try to acquire tokens from the bucket."""
        now = time.time()

        # Add tokens based on elapsed time
        elapsed = now - self.last_update
        self.tokens = min(self.burst, self.tokens + elapsed * self.rate)
        self.last_update = now

        # Check if we have enough tokens
        if self.tokens >= tokens:
            self.tokens -= tokens
            return True

        return False

    def time_until_tokens(self, tokens: int = 1) -> float:
        """Calculate time until enough tokens are available."""
        if self.tokens >= tokens:
            return 0.0

        needed_tokens = tokens - self.tokens
        return needed_tokens / self.rate


def rate_limited(requests_per_second: float, burst: int = 1):
    """
    Decorator to rate limit function calls.

    Args:
        requests_per_second: Rate limit (requests per second)
        burst: Maximum burst size
    """
    limiter = RateLimiter(requests_per_second, burst)

    def decorator(func):
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            if not limiter.acquire():
                wait_time = limiter.time_until_tokens()
                await asyncio.sleep(wait_time)

                if not limiter.acquire():
                    raise HCFSRateLimitError()

            return await func(*args, **kwargs)

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            if not limiter.acquire():
                wait_time = limiter.time_until_tokens()
                time.sleep(wait_time)

                if not limiter.acquire():
                    raise HCFSRateLimitError()

            return func(*args, **kwargs)

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        else:
            return sync_wrapper

    return decorator

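# A small sketch (illustrative, not part of the commit) of the token bucket and the
# decorator applied to an ordinary function:
#
#   limiter = RateLimiter(rate=5.0, burst=2)   # roughly 5 calls/second, bursts of 2
#   limiter.acquire()            # True, True, then False until tokens refill
#   limiter.time_until_tokens()  # seconds to wait before the next call is allowed
#
#   @rate_limited(requests_per_second=2.0, burst=1)
#   def ping(client):
#       return client.health_check()
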

class ContextManager:
    """Context manager for HCFS operations with automatic cleanup."""

    def __init__(self, client, auto_cleanup: bool = True):
        self.client = client
        self.auto_cleanup = auto_cleanup
        self.created_contexts: List[int] = []
        self.temp_files: List[str] = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.auto_cleanup:
            self.cleanup()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.auto_cleanup:
            await self.cleanup_async()

    def track_context(self, context_id: int):
        """Track a created context for cleanup."""
        self.created_contexts.append(context_id)

    def track_file(self, file_path: str):
        """Track a temporary file for cleanup."""
        self.temp_files.append(file_path)

    def cleanup(self):
        """Cleanup tracked resources synchronously."""
        # Cleanup contexts
        for context_id in self.created_contexts:
            try:
                self.client.delete_context(context_id)
            except Exception:
                pass  # Ignore cleanup errors

        # Cleanup files
        import os
        for file_path in self.temp_files:
            try:
                if os.path.exists(file_path):
                    os.remove(file_path)
            except Exception:
                pass  # Ignore cleanup errors

        self.created_contexts.clear()
        self.temp_files.clear()

    async def cleanup_async(self):
        """Cleanup tracked resources asynchronously."""
        # Cleanup contexts
        for context_id in self.created_contexts:
            try:
                await self.client.delete_context(context_id)
            except Exception:
                pass  # Ignore cleanup errors

        # Cleanup files
        import os
        for file_path in self.temp_files:
            try:
                if os.path.exists(file_path):
                    os.remove(file_path)
            except Exception:
                pass  # Ignore cleanup errors

        self.created_contexts.clear()
        self.temp_files.clear()


def context_manager(auto_cleanup: bool = True):
    """
    Decorator to automatically manage context lifecycle.

    Args:
        auto_cleanup: Whether to automatically cleanup contexts on exit
    """
    def decorator(func):
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            # Assume first argument is the client
            client = args[0] if args else None
            if not client:
                return await func(*args, **kwargs)

            async with ContextManager(client, auto_cleanup) as ctx_mgr:
                # Inject context manager into kwargs
                kwargs['_context_manager'] = ctx_mgr
                return await func(*args, **kwargs)

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            # Assume first argument is the client
            client = args[0] if args else None
            if not client:
                return func(*args, **kwargs)

            with ContextManager(client, auto_cleanup) as ctx_mgr:
                # Inject context manager into kwargs
                kwargs['_context_manager'] = ctx_mgr
                return func(*args, **kwargs)

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        else:
            return sync_wrapper

    return decorator


def performance_monitor(track_timing: bool = True, track_memory: bool = False):
    """
    Decorator to monitor function performance.

    Args:
        track_timing: Whether to track execution timing
        track_memory: Whether to track memory usage
    """
    def decorator(func):
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            start_time = time.time() if track_timing else None
            start_memory = None

            if track_memory:
                import psutil
                process = psutil.Process()
                start_memory = process.memory_info().rss

            try:
                result = await func(*args, **kwargs)

                # Record performance metrics
                if track_timing:
                    execution_time = time.time() - start_time
                    # Could store or log timing data here

                if track_memory and start_memory:
                    end_memory = process.memory_info().rss
                    memory_delta = end_memory - start_memory
                    # Could store or log memory usage here

                return result

            except Exception as e:
                # Record error metrics
                raise e

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            start_time = time.time() if track_timing else None
            start_memory = None

            if track_memory:
                import psutil
                process = psutil.Process()
                start_memory = process.memory_info().rss

            try:
                result = func(*args, **kwargs)

                # Record performance metrics
                if track_timing:
                    execution_time = time.time() - start_time
                    # Could store or log timing data here

                if track_memory and start_memory:
                    end_memory = process.memory_info().rss
                    memory_delta = end_memory - start_memory
                    # Could store or log memory usage here

                return result

            except Exception as e:
                # Record error metrics
                raise e

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        else:
            return sync_wrapper

    return decorator
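The decorators are designed to stack; a hedged sketch of combining retry, rate limiting, and tracked cleanup around the synchronous client (the helper function, connection values, and the assumption that the returned Context exposes the assigned id are illustrative, not from the commit; RetryConfig defaults are assumed to be sensible):

from hcfs.sdk import HCFSClient, Context
from hcfs.sdk.decorators import retry_on_failure, rate_limited, ContextManager

@retry_on_failure()
@rate_limited(requests_per_second=5.0, burst=2)
def publish(client: HCFSClient, path: str, body: str) -> int:
    ctx = client.create_context(Context(path=path, content=body, summary=body[:40]))
    return ctx.id  # assumes the created Context carries its assigned id

with HCFSClient(base_url="https://api.hcfs.example.com", api_key="your-api-key") as client:
    with ContextManager(client, auto_cleanup=True) as tracker:
        cid = publish(client, "/scratch/demo", "temporary content")
        tracker.track_context(cid)  # deleted automatically when the block exits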
184
hcfs-python/hcfs/sdk/exceptions.py
Normal file
@@ -0,0 +1,184 @@
"""
HCFS SDK Exception Classes

Comprehensive exception hierarchy for error handling.
"""

from typing import Optional, Dict, Any


class HCFSError(Exception):
    """Base exception for all HCFS SDK errors."""

    def __init__(self, message: str, error_code: Optional[str] = None, details: Optional[Dict[str, Any]] = None):
        super().__init__(message)
        self.message = message
        self.error_code = error_code
        self.details = details or {}

    def __str__(self) -> str:
        if self.error_code:
            return f"[{self.error_code}] {self.message}"
        return self.message

    def to_dict(self) -> Dict[str, Any]:
        """Convert exception to dictionary for serialization."""
        return {
            "type": self.__class__.__name__,
            "message": self.message,
            "error_code": self.error_code,
            "details": self.details
        }


class HCFSConnectionError(HCFSError):
    """Raised when connection to HCFS API fails."""

    def __init__(self, message: str = "Failed to connect to HCFS API", **kwargs):
        super().__init__(message, error_code="CONNECTION_FAILED", **kwargs)


class HCFSAuthenticationError(HCFSError):
    """Raised when authentication fails."""

    def __init__(self, message: str = "Authentication failed", **kwargs):
        super().__init__(message, error_code="AUTH_FAILED", **kwargs)


class HCFSAuthorizationError(HCFSError):
    """Raised when user lacks permissions for an operation."""

    def __init__(self, message: str = "Insufficient permissions", **kwargs):
        super().__init__(message, error_code="INSUFFICIENT_PERMISSIONS", **kwargs)


class HCFSNotFoundError(HCFSError):
    """Raised when a requested resource is not found."""

    def __init__(self, resource_type: str = "Resource", resource_id: str = "", **kwargs):
        message = f"{resource_type} not found"
        if resource_id:
            message += f": {resource_id}"
        super().__init__(message, error_code="NOT_FOUND", **kwargs)


class HCFSValidationError(HCFSError):
    """Raised when request validation fails."""

    def __init__(self, message: str = "Request validation failed", validation_errors: Optional[list] = None, **kwargs):
        super().__init__(message, error_code="VALIDATION_FAILED", **kwargs)
        self.validation_errors = validation_errors or []

    def to_dict(self) -> Dict[str, Any]:
        result = super().to_dict()
        result["validation_errors"] = self.validation_errors
        return result


class HCFSRateLimitError(HCFSError):
    """Raised when rate limit is exceeded."""

    def __init__(self, retry_after: Optional[int] = None, **kwargs):
        message = "Rate limit exceeded"
        if retry_after:
            message += f". Retry after {retry_after} seconds"
        super().__init__(message, error_code="RATE_LIMIT_EXCEEDED", **kwargs)
        self.retry_after = retry_after


class HCFSServerError(HCFSError):
    """Raised for server-side errors (5xx status codes)."""

    def __init__(self, message: str = "Internal server error", status_code: Optional[int] = None, **kwargs):
        super().__init__(message, error_code="SERVER_ERROR", **kwargs)
        self.status_code = status_code


class HCFSTimeoutError(HCFSError):
    """Raised when a request times out."""

    def __init__(self, operation: str = "Request", timeout_seconds: Optional[float] = None, **kwargs):
        message = f"{operation} timed out"
        if timeout_seconds:
            message += f" after {timeout_seconds}s"
        super().__init__(message, error_code="TIMEOUT", **kwargs)
        self.timeout_seconds = timeout_seconds


class HCFSCacheError(HCFSError):
    """Raised for cache-related errors."""

    def __init__(self, message: str = "Cache operation failed", **kwargs):
        super().__init__(message, error_code="CACHE_ERROR", **kwargs)


class HCFSBatchError(HCFSError):
    """Raised for batch operation errors."""

    def __init__(self, message: str = "Batch operation failed", failed_items: Optional[list] = None, **kwargs):
        super().__init__(message, error_code="BATCH_ERROR", **kwargs)
        self.failed_items = failed_items or []

    def to_dict(self) -> Dict[str, Any]:
        result = super().to_dict()
        result["failed_items"] = self.failed_items
        return result


class HCFSStreamError(HCFSError):
    """Raised for streaming/WebSocket errors."""

    def __init__(self, message: str = "Stream operation failed", **kwargs):
        super().__init__(message, error_code="STREAM_ERROR", **kwargs)


class HCFSSearchError(HCFSError):
    """Raised for search operation errors."""

    def __init__(self, query: str = "", search_type: str = "", **kwargs):
        message = "Search failed"
        if search_type:
            message += f" ({search_type})"
        if query:
            message += f": '{query}'"
        super().__init__(message, error_code="SEARCH_ERROR", **kwargs)
        self.query = query
        self.search_type = search_type

def handle_api_error(response) -> None:
|
||||
"""
|
||||
Convert HTTP response errors to appropriate HCFS exceptions.
|
||||
|
||||
Args:
|
||||
response: HTTP response object
|
||||
|
||||
Raises:
|
||||
Appropriate HCFSError subclass based on status code
|
||||
"""
|
||||
status_code = response.status_code
|
||||
|
||||
try:
|
||||
error_data = response.json() if response.content else {}
|
||||
except Exception:
|
||||
error_data = {}
|
||||
|
||||
error_message = error_data.get("error", "Unknown error")
|
||||
error_details = error_data.get("error_details", [])
|
||||
|
||||
if status_code == 400:
|
||||
raise HCFSValidationError(error_message, validation_errors=error_details)
|
||||
elif status_code == 401:
|
||||
raise HCFSAuthenticationError(error_message)
|
||||
elif status_code == 403:
|
||||
raise HCFSAuthorizationError(error_message)
|
||||
elif status_code == 404:
|
||||
raise HCFSNotFoundError("Resource", error_message)
|
||||
elif status_code == 429:
|
||||
retry_after = response.headers.get("Retry-After")
|
||||
retry_after = int(retry_after) if retry_after else None
|
||||
raise HCFSRateLimitError(retry_after=retry_after)
|
||||
elif 500 <= status_code < 600:
|
||||
raise HCFSServerError(error_message, status_code=status_code)
|
||||
else:
|
||||
raise HCFSError(f"HTTP {status_code}: {error_message}")
|
||||
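For orientation, the following is a minimal usage sketch, not part of the file above. It assumes an httpx response object (httpx is what the async client already imports) and an illustrative endpoint path; only handle_api_error and the exception classes come from this module.

# Example (illustrative): mapping an HTTP error response to the hierarchy above.
import httpx

from hcfs.sdk.exceptions import handle_api_error, HCFSRateLimitError, HCFSError

def fetch_context(client: httpx.Client, context_id: int) -> dict:
    response = client.get(f"/api/v1/contexts/{context_id}")  # endpoint path is assumed
    if response.status_code >= 400:
        handle_api_error(response)   # raises the matching HCFSError subclass
    return response.json()

try:
    with httpx.Client(base_url="https://api.hcfs.example.com") as client:
        data = fetch_context(client, 42)
except HCFSRateLimitError as exc:
    print(f"Rate limited, retry after {exc.retry_after}s")
except HCFSError as exc:
    print(f"API error: {exc.to_dict()}")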
335
hcfs-python/hcfs/sdk/models.py
Normal file
335
hcfs-python/hcfs/sdk/models.py
Normal file
@@ -0,0 +1,335 @@
"""
HCFS SDK Data Models

Pydantic models for SDK operations and configuration.
"""

from typing import Optional, List, Dict, Any, Union, Callable
from datetime import datetime
from enum import Enum
from pydantic import BaseModel, Field, validator
from dataclasses import dataclass


class ContextStatus(str, Enum):
    """Context status enumeration."""
    ACTIVE = "active"
    ARCHIVED = "archived"
    DELETED = "deleted"
    DRAFT = "draft"


class SearchType(str, Enum):
    """Search type enumeration."""
    SEMANTIC = "semantic"
    KEYWORD = "keyword"
    HYBRID = "hybrid"
    FUZZY = "fuzzy"


class CacheStrategy(str, Enum):
    """Cache strategy enumeration."""
    LRU = "lru"
    LFU = "lfu"
    TTL = "ttl"
    FIFO = "fifo"


class RetryStrategy(str, Enum):
    """Retry strategy enumeration."""
    EXPONENTIAL_BACKOFF = "exponential_backoff"
    LINEAR_BACKOFF = "linear_backoff"
    CONSTANT_DELAY = "constant_delay"
    FIBONACCI = "fibonacci"


class Context(BaseModel):
    """Context data model for SDK operations."""

    id: Optional[int] = None
    path: str = Field(..., description="Unique context path")
    content: str = Field(..., description="Context content")
    summary: Optional[str] = Field(None, description="Brief summary")
    author: Optional[str] = Field(None, description="Context author")
    tags: List[str] = Field(default_factory=list, description="Context tags")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
    status: ContextStatus = Field(default=ContextStatus.ACTIVE, description="Context status")
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
    version: int = Field(default=1, description="Context version")
    similarity_score: Optional[float] = Field(None, description="Similarity score (for search results)")

    @validator('path')
    def validate_path(cls, v):
        if not v or not v.startswith('/'):
            raise ValueError('Path must start with /')
        return v

    @validator('content')
    def validate_content(cls, v):
        if not v or len(v.strip()) == 0:
            raise ValueError('Content cannot be empty')
        return v

    def to_create_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for context creation."""
        return {
            "path": self.path,
            "content": self.content,
            "summary": self.summary,
            "author": self.author,
            "tags": self.tags,
            "metadata": self.metadata
        }

    def to_update_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for context updates (excluding read-only fields)."""
        return {
            k: v for k, v in {
                "content": self.content,
                "summary": self.summary,
                "tags": self.tags,
                "metadata": self.metadata,
                "status": self.status.value
            }.items() if v is not None
        }
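A short sketch of the Context model's validators and dict helpers in action (illustrative only; module path follows the package layout of this commit):

# Example: path validation and payload conversion.
from hcfs.sdk.models import Context

ctx = Context(path="/docs/readme", content="Install with pip.", tags=["docs"])
payload = ctx.to_create_dict()   # dict ready for a create request

ctx.summary = "Install notes"
patch = ctx.to_update_dict()     # None-valued fields are dropped

try:
    Context(path="docs/readme", content="missing leading slash")
except ValueError as exc:        # pydantic wraps the validator's ValueError
    print(f"rejected: {exc}")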
class SearchResult(BaseModel):
    """Search result model."""

    context: Context
    score: float = Field(..., description="Relevance score")
    explanation: Optional[str] = Field(None, description="Search result explanation")
    highlights: List[str] = Field(default_factory=list, description="Highlighted text snippets")

    def __lt__(self, other):
        """Enable sorting by score."""
        return self.score < other.score

    def __gt__(self, other):
        """Enable sorting by score."""
        return self.score > other.score


class ContextFilter(BaseModel):
    """Context filtering options."""

    path_prefix: Optional[str] = Field(None, description="Filter by path prefix")
    author: Optional[str] = Field(None, description="Filter by author")
    status: Optional[ContextStatus] = Field(None, description="Filter by status")
    tags: Optional[List[str]] = Field(None, description="Filter by tags")
    created_after: Optional[datetime] = Field(None, description="Filter by creation date")
    created_before: Optional[datetime] = Field(None, description="Filter by creation date")
    content_contains: Optional[str] = Field(None, description="Filter by content substring")
    min_content_length: Optional[int] = Field(None, description="Minimum content length")
    max_content_length: Optional[int] = Field(None, description="Maximum content length")

    def to_query_params(self) -> Dict[str, Any]:
        """Convert to query parameters for API requests."""
        params = {}

        if self.path_prefix:
            params["path_prefix"] = self.path_prefix
        if self.author:
            params["author"] = self.author
        if self.status:
            params["status"] = self.status.value
        if self.created_after:
            params["created_after"] = self.created_after.isoformat()
        if self.created_before:
            params["created_before"] = self.created_before.isoformat()
        if self.content_contains:
            params["content_contains"] = self.content_contains
        if self.min_content_length is not None:
            params["min_content_length"] = self.min_content_length
        if self.max_content_length is not None:
            params["max_content_length"] = self.max_content_length

        return params


class PaginationOptions(BaseModel):
    """Pagination configuration."""

    page: int = Field(default=1, ge=1, description="Page number")
    page_size: int = Field(default=20, ge=1, le=1000, description="Items per page")
    sort_by: Optional[str] = Field(None, description="Sort field")
    sort_order: str = Field(default="desc", description="Sort order (asc/desc)")

    @validator('sort_order')
    def validate_sort_order(cls, v):
        if v not in ['asc', 'desc']:
            raise ValueError('Sort order must be "asc" or "desc"')
        return v

    @property
    def offset(self) -> int:
        """Calculate offset for database queries."""
        return (self.page - 1) * self.page_size

    def to_query_params(self) -> Dict[str, Any]:
        """Convert to query parameters."""
        params = {
            "page": self.page,
            "page_size": self.page_size,
            "sort_order": self.sort_order
        }
        if self.sort_by:
            params["sort_by"] = self.sort_by
        return params
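As a quick illustration of how ContextFilter and PaginationOptions are meant to combine into one set of request parameters (a sketch, not part of the file; the endpoint consuming the params is not shown here):

# Example: building query parameters for a filtered, paginated listing.
from datetime import datetime
from hcfs.sdk.models import ContextFilter, PaginationOptions, ContextStatus

filters = ContextFilter(
    path_prefix="/docs",
    status=ContextStatus.ACTIVE,
    created_after=datetime(2025, 1, 1),
)
pagination = PaginationOptions(page=2, page_size=50, sort_by="updated_at")

params = {**filters.to_query_params(), **pagination.to_query_params()}
# {"path_prefix": "/docs", "status": "active", "created_after": "2025-01-01T00:00:00",
#  "page": 2, "page_size": 50, "sort_order": "desc", "sort_by": "updated_at"}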
class SearchOptions(BaseModel):
    """Search configuration options."""

    search_type: SearchType = Field(default=SearchType.SEMANTIC, description="Type of search")
    top_k: int = Field(default=10, ge=1, le=1000, description="Maximum results to return")
    similarity_threshold: float = Field(default=0.0, ge=0.0, le=1.0, description="Minimum similarity score")
    path_prefix: Optional[str] = Field(None, description="Search within path prefix")
    semantic_weight: float = Field(default=0.7, ge=0.0, le=1.0, description="Weight for semantic search in hybrid mode")
    include_content: bool = Field(default=True, description="Include full content in results")
    include_highlights: bool = Field(default=True, description="Include text highlights")
    max_highlights: int = Field(default=3, ge=0, le=10, description="Maximum highlight snippets")

    def to_request_dict(self) -> Dict[str, Any]:
        """Convert to API request dictionary."""
        return {
            "search_type": self.search_type.value,
            "top_k": self.top_k,
            "similarity_threshold": self.similarity_threshold,
            "path_prefix": self.path_prefix,
            "semantic_weight": self.semantic_weight,
            "include_content": self.include_content,
            "include_highlights": self.include_highlights
        }
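A minimal sketch of configuring a hybrid search request body with these options (the query field added at the end is an assumption, not defined by this model):

# Example: hybrid search options.
from hcfs.sdk.models import SearchOptions, SearchType

options = SearchOptions(
    search_type=SearchType.HYBRID,
    top_k=25,
    similarity_threshold=0.4,
    path_prefix="/docs",
    semantic_weight=0.6,   # 60% semantic, 40% keyword in hybrid mode
)
body = options.to_request_dict()
body["query"] = "installation guide"   # query field itself is assumed here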
class CacheConfig(BaseModel):
    """Cache configuration."""

    enabled: bool = Field(default=True, description="Enable caching")
    strategy: CacheStrategy = Field(default=CacheStrategy.LRU, description="Cache eviction strategy")
    max_size: int = Field(default=1000, ge=1, description="Maximum cache entries")
    ttl_seconds: Optional[int] = Field(default=3600, ge=1, description="Time-to-live in seconds")
    memory_limit_mb: Optional[int] = Field(default=100, ge=1, description="Memory limit in MB")
    persist_to_disk: bool = Field(default=False, description="Persist cache to disk")
    disk_cache_path: Optional[str] = Field(None, description="Disk cache directory")

    @validator('ttl_seconds')
    def validate_ttl(cls, v, values):
        if values.get('strategy') == CacheStrategy.TTL and v is None:
            raise ValueError('TTL must be specified for TTL cache strategy')
        return v


class RetryConfig(BaseModel):
    """Retry configuration for failed requests."""

    enabled: bool = Field(default=True, description="Enable retry logic")
    max_attempts: int = Field(default=3, ge=1, le=10, description="Maximum retry attempts")
    strategy: RetryStrategy = Field(default=RetryStrategy.EXPONENTIAL_BACKOFF, description="Retry strategy")
    base_delay: float = Field(default=1.0, ge=0.1, description="Base delay in seconds")
    max_delay: float = Field(default=60.0, ge=1.0, description="Maximum delay in seconds")
    backoff_multiplier: float = Field(default=2.0, ge=1.0, description="Backoff multiplier")
    jitter: bool = Field(default=True, description="Add random jitter to delays")
    retry_on_status: List[int] = Field(
        default_factory=lambda: [429, 500, 502, 503, 504],
        description="HTTP status codes to retry on"
    )
    retry_on_timeout: bool = Field(default=True, description="Retry on timeout errors")


class WebSocketConfig(BaseModel):
    """WebSocket connection configuration."""

    auto_reconnect: bool = Field(default=True, description="Automatically reconnect on disconnect")
    reconnect_interval: float = Field(default=5.0, ge=1.0, description="Reconnect interval in seconds")
    max_reconnect_attempts: int = Field(default=10, ge=1, description="Maximum reconnection attempts")
    ping_interval: float = Field(default=30.0, ge=1.0, description="Ping interval in seconds")
    ping_timeout: float = Field(default=10.0, ge=1.0, description="Ping timeout in seconds")
    message_queue_size: int = Field(default=1000, ge=1, description="Maximum queued messages")


class ClientConfig(BaseModel):
    """Main client configuration."""

    base_url: str = Field(..., description="HCFS API base URL")
    api_key: Optional[str] = Field(None, description="API key for authentication")
    jwt_token: Optional[str] = Field(None, description="JWT token for authentication")
    timeout: float = Field(default=30.0, ge=1.0, description="Request timeout in seconds")
    user_agent: str = Field(default="HCFS-SDK/2.0.0", description="User agent string")

    # Advanced configurations
    cache: CacheConfig = Field(default_factory=CacheConfig)
    retry: RetryConfig = Field(default_factory=RetryConfig)
    websocket: WebSocketConfig = Field(default_factory=WebSocketConfig)

    # Connection pooling
    max_connections: int = Field(default=100, ge=1, description="Maximum connection pool size")
    max_keepalive_connections: int = Field(default=20, ge=1, description="Maximum keep-alive connections")

    @validator('base_url')
    def validate_base_url(cls, v):
        if not v.startswith(('http://', 'https://')):
            raise ValueError('Base URL must start with http:// or https://')
        return v.rstrip('/')
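A sketch of building a ClientConfig with non-default cache and retry behaviour (illustrative; how the clients consume this object is not shown in this hunk, and the URL is a placeholder):

# Example: composing the configuration models.
from hcfs.sdk.models import (
    ClientConfig, CacheConfig, RetryConfig, CacheStrategy, RetryStrategy,
)

config = ClientConfig(
    base_url="https://api.hcfs.example.com/",   # trailing slash stripped by the validator
    api_key="your-api-key",
    timeout=15.0,
    cache=CacheConfig(strategy=CacheStrategy.TTL, ttl_seconds=600, max_size=500),
    retry=RetryConfig(max_attempts=5, strategy=RetryStrategy.EXPONENTIAL_BACKOFF),
)
assert config.base_url == "https://api.hcfs.example.com"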
@dataclass
class BatchResult:
    """Result of a batch operation."""

    success_count: int
    error_count: int
    total_items: int
    successful_items: List[Any]
    failed_items: List[Dict[str, Any]]
    execution_time: float

    @property
    def success_rate(self) -> float:
        """Calculate success rate."""
        return self.success_count / self.total_items if self.total_items > 0 else 0.0

    @property
    def has_errors(self) -> bool:
        """Check if there were any errors."""
        return self.error_count > 0


class StreamEvent(BaseModel):
    """WebSocket stream event."""

    event_type: str = Field(..., description="Event type (created/updated/deleted)")
    data: Dict[str, Any] = Field(..., description="Event data")
    timestamp: datetime = Field(default_factory=datetime.utcnow, description="Event timestamp")
    context_id: Optional[int] = Field(None, description="Related context ID")
    path: Optional[str] = Field(None, description="Related context path")

    def is_context_event(self) -> bool:
        """Check if this is a context-related event."""
        return self.event_type in ['context_created', 'context_updated', 'context_deleted']


class AnalyticsData(BaseModel):
    """Analytics and usage data."""

    operation_count: Dict[str, int] = Field(default_factory=dict, description="Operation counts")
    cache_stats: Dict[str, Any] = Field(default_factory=dict, description="Cache statistics")
    error_stats: Dict[str, int] = Field(default_factory=dict, description="Error statistics")
    performance_stats: Dict[str, float] = Field(default_factory=dict, description="Performance metrics")
    session_start: datetime = Field(default_factory=datetime.utcnow, description="Session start time")

    def get_cache_hit_rate(self) -> float:
        """Calculate cache hit rate."""
        hits = self.cache_stats.get('hits', 0)
        misses = self.cache_stats.get('misses', 0)
        total = hits + misses
        return hits / total if total > 0 else 0.0

    def get_error_rate(self) -> float:
        """Calculate overall error rate."""
        total_operations = sum(self.operation_count.values())
        total_errors = sum(self.error_stats.values())
        return total_errors / total_operations if total_operations > 0 else 0.0
564
hcfs-python/hcfs/sdk/utils.py
Normal file
564
hcfs-python/hcfs/sdk/utils.py
Normal file
@@ -0,0 +1,564 @@
"""
HCFS SDK Utility Functions

Common utilities for text processing, caching, and data manipulation.
"""

import hashlib
import json
import math
import os  # needed below by BatchProcessor (os.sched_getaffinity)
import re
import time
from typing import List, Dict, Any, Optional, Tuple, Iterator, Callable, Union
from datetime import datetime, timedelta
from collections import defaultdict, OrderedDict
from threading import Lock
import asyncio
from functools import lru_cache, wraps

from .models import Context, SearchResult, CacheStrategy
from .exceptions import HCFSError, HCFSCacheError


def context_similarity(context1: Context, context2: Context, method: str = "jaccard") -> float:
    """
    Calculate similarity between two contexts.

    Args:
        context1: First context
        context2: Second context
        method: Similarity method ("jaccard", "cosine", "levenshtein")

    Returns:
        Similarity score between 0.0 and 1.0
    """
    if method == "jaccard":
        return _jaccard_similarity(context1.content, context2.content)
    elif method == "cosine":
        return _cosine_similarity(context1.content, context2.content)
    elif method == "levenshtein":
        return _levenshtein_similarity(context1.content, context2.content)
    else:
        raise ValueError(f"Unknown similarity method: {method}")


def _jaccard_similarity(text1: str, text2: str) -> float:
    """Calculate Jaccard similarity between two texts."""
    words1 = set(text1.lower().split())
    words2 = set(text2.lower().split())

    intersection = words1.intersection(words2)
    union = words1.union(words2)

    return len(intersection) / len(union) if union else 0.0


def _cosine_similarity(text1: str, text2: str) -> float:
    """Calculate cosine similarity between two texts."""
    words1 = text1.lower().split()
    words2 = text2.lower().split()

    # Create word frequency vectors
    all_words = set(words1 + words2)
    vector1 = [words1.count(word) for word in all_words]
    vector2 = [words2.count(word) for word in all_words]

    # Calculate dot product and magnitudes
    dot_product = sum(a * b for a, b in zip(vector1, vector2))
    magnitude1 = math.sqrt(sum(a * a for a in vector1))
    magnitude2 = math.sqrt(sum(a * a for a in vector2))

    if magnitude1 == 0 or magnitude2 == 0:
        return 0.0

    return dot_product / (magnitude1 * magnitude2)


def _levenshtein_similarity(text1: str, text2: str) -> float:
    """Calculate normalized Levenshtein similarity."""
    def levenshtein_distance(s1: str, s2: str) -> int:
        if len(s1) < len(s2):
            return levenshtein_distance(s2, s1)

        if len(s2) == 0:
            return len(s1)

        previous_row = list(range(len(s2) + 1))
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row

        return previous_row[-1]

    max_len = max(len(text1), len(text2))
    if max_len == 0:
        return 1.0

    distance = levenshtein_distance(text1.lower(), text2.lower())
    return 1.0 - (distance / max_len)
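A short sketch comparing two contexts with each of the supported similarity methods (illustrative only):

# Example: the three similarity methods on the same pair of contexts.
from hcfs.sdk.models import Context
from hcfs.sdk.utils import context_similarity

a = Context(path="/docs/install", content="Install the SDK with pip install hcfs")
b = Context(path="/docs/setup", content="Use pip install hcfs to set up the SDK")

for method in ("jaccard", "cosine", "levenshtein"):
    score = context_similarity(a, b, method=method)
    print(f"{method}: {score:.2f}")   # each score is in [0.0, 1.0]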
def text_chunker(text: str, chunk_size: int = 512, overlap: int = 50, preserve_sentences: bool = True) -> List[str]:
    """
    Split text into overlapping chunks.

    Args:
        text: Text to chunk
        chunk_size: Maximum chunk size in characters
        overlap: Overlap between chunks
        preserve_sentences: Try to preserve sentence boundaries

    Returns:
        List of text chunks
    """
    if len(text) <= chunk_size:
        return [text]

    chunks = []
    start = 0

    while start < len(text):
        end = start + chunk_size

        if end >= len(text):
            chunks.append(text[start:])
            break

        # Try to find a good break point
        chunk = text[start:end]

        if preserve_sentences and '.' in chunk:
            # Find the last sentence boundary
            last_period = chunk.rfind('.')
            if last_period > chunk_size // 2:  # Don't make chunks too small
                end = start + last_period + 1
                chunk = text[start:end]

        chunks.append(chunk.strip())
        start = end - overlap

    return [chunk for chunk in chunks if chunk.strip()]


def extract_keywords(text: str, max_keywords: int = 10, min_length: int = 3) -> List[str]:
    """
    Extract keywords from text using simple frequency analysis.

    Args:
        text: Input text
        max_keywords: Maximum number of keywords
        min_length: Minimum keyword length

    Returns:
        List of keywords ordered by frequency
    """
    # Simple stopwords
    stopwords = {
        'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
        'of', 'with', 'by', 'from', 'as', 'is', 'was', 'are', 'were', 'be',
        'been', 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would',
        'could', 'should', 'may', 'might', 'can', 'this', 'that', 'these',
        'those', 'i', 'you', 'he', 'she', 'it', 'we', 'they', 'me', 'him',
        'her', 'us', 'them', 'my', 'your', 'his', 'its', 'our', 'their'
    }

    # Extract words and count frequencies
    words = re.findall(r'\b[a-zA-Z]+\b', text.lower())
    word_freq = defaultdict(int)

    for word in words:
        if len(word) >= min_length and word not in stopwords:
            word_freq[word] += 1

    # Sort by frequency and return top keywords
    return sorted(word_freq.keys(), key=lambda x: word_freq[x], reverse=True)[:max_keywords]
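A minimal sketch of chunking a long document and extracting keywords per chunk (illustrative input text):

# Example: overlapping chunks plus per-chunk keywords.
from hcfs.sdk.utils import text_chunker, extract_keywords

document = "HCFS stores contexts at filesystem-like paths. " * 40

chunks = text_chunker(document, chunk_size=200, overlap=20)
for chunk in chunks[:3]:
    print(len(chunk), extract_keywords(chunk, max_keywords=5))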
def format_content_preview(content: str, max_length: int = 200) -> str:
    """
    Format content for preview display.

    Args:
        content: Full content
        max_length: Maximum preview length

    Returns:
        Formatted preview string
    """
    if len(content) <= max_length:
        return content

    # Try to cut at word boundary
    preview = content[:max_length]
    last_space = preview.rfind(' ')

    if last_space > max_length * 0.8:  # Don't cut too much
        preview = preview[:last_space]

    return preview + "..."


def validate_path(path: str) -> bool:
    """
    Validate context path format.

    Args:
        path: Path to validate

    Returns:
        True if valid, False otherwise
    """
    if not path or not isinstance(path, str):
        return False

    if not path.startswith('/'):
        return False

    # Check for invalid characters
    invalid_chars = set('<>"|?*')
    if any(char in path for char in invalid_chars):
        return False

    # Check path components
    components = path.split('/')
    for component in components[1:]:  # Skip empty first component
        if not component or component in ['.', '..']:
            return False

    return True


def normalize_path(path: str) -> str:
    """
    Normalize context path.

    Args:
        path: Path to normalize

    Returns:
        Normalized path
    """
    if not path.startswith('/'):
        path = '/' + path

    # Remove duplicate slashes and normalize
    components = [c for c in path.split('/') if c]
    return '/' + '/'.join(components) if components else '/'
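A quick sketch of normalizing a user-supplied path before validating it (illustrative values):

# Example: normalize, then validate.
from hcfs.sdk.utils import normalize_path, validate_path

raw = "projects//hcfs/notes"
path = normalize_path(raw)          # -> "/projects/hcfs/notes"
assert validate_path(path)

assert not validate_path("/projects/../secrets")   # '..' components are rejected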
def hash_content(content: str, algorithm: str = "sha256") -> str:
    """
    Generate hash of content for deduplication.

    Args:
        content: Content to hash
        algorithm: Hash algorithm

    Returns:
        Hex digest of content hash
    """
    if algorithm == "md5":
        hasher = hashlib.md5()
    elif algorithm == "sha1":
        hasher = hashlib.sha1()
    elif algorithm == "sha256":
        hasher = hashlib.sha256()
    else:
        raise ValueError(f"Unsupported hash algorithm: {algorithm}")

    hasher.update(content.encode('utf-8'))
    return hasher.hexdigest()


def merge_contexts(contexts: List[Context], strategy: str = "latest") -> Context:
    """
    Merge multiple contexts into one.

    Args:
        contexts: List of contexts to merge
        strategy: Merge strategy ("latest", "longest", "combined")

    Returns:
        Merged context
    """
    if not contexts:
        raise ValueError("No contexts to merge")

    if len(contexts) == 1:
        return contexts[0]

    if strategy == "latest":
        return max(contexts, key=lambda c: c.updated_at or c.created_at or datetime.min)
    elif strategy == "longest":
        return max(contexts, key=lambda c: len(c.content))
    elif strategy == "combined":
        # Combine content and metadata
        merged = contexts[0].copy()
        merged.content = "\n\n".join(c.content for c in contexts)
        merged.tags = list(set(tag for c in contexts for tag in c.tags))

        # Merge metadata
        merged_metadata = {}
        for context in contexts:
            merged_metadata.update(context.metadata)
        merged.metadata = merged_metadata

        return merged
    else:
        raise ValueError(f"Unknown merge strategy: {strategy}")
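A sketch combining the two helpers above: deduplicate drafts by content hash, then merge what remains (illustrative data):

# Example: hash-based deduplication followed by a "combined" merge.
from hcfs.sdk.models import Context
from hcfs.sdk.utils import hash_content, merge_contexts

drafts = [
    Context(path="/notes/a", content="Alpha notes", tags=["alpha"]),
    Context(path="/notes/b", content="Beta notes", tags=["beta"]),
    Context(path="/notes/c", content="Alpha notes", tags=["alpha"]),   # duplicate content
]

unique = {hash_content(c.content): c for c in drafts}.values()
merged = merge_contexts(list(unique), strategy="combined")
print(merged.tags)          # tags from all merged contexts, deduplicated
print(len(merged.content))  # contents joined with blank lines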
class MemoryCache:
    """Thread-safe in-memory cache with configurable eviction strategies."""

    def __init__(self, max_size: int = 1000, strategy: CacheStrategy = CacheStrategy.LRU, ttl_seconds: Optional[int] = None):
        self.max_size = max_size
        self.strategy = strategy
        self.ttl_seconds = ttl_seconds
        self._cache = OrderedDict()
        self._access_counts = defaultdict(int)
        self._timestamps = {}
        self._lock = Lock()

    def get(self, key: str) -> Optional[Any]:
        """Get value from cache."""
        with self._lock:
            if key not in self._cache:
                return None

            # Check TTL
            if self.ttl_seconds and key in self._timestamps:
                if time.time() - self._timestamps[key] > self.ttl_seconds:
                    self._remove(key)
                    return None

            # Update access patterns
            if self.strategy == CacheStrategy.LRU:
                # Move to end (most recently used)
                self._cache.move_to_end(key)
            elif self.strategy == CacheStrategy.LFU:
                self._access_counts[key] += 1

            return self._cache[key]

    def put(self, key: str, value: Any) -> None:
        """Put value in cache."""
        with self._lock:
            # Remove if already exists
            if key in self._cache:
                self._remove(key)

            # Evict if necessary
            while len(self._cache) >= self.max_size:
                self._evict_one()

            # Add new entry
            self._cache[key] = value
            self._timestamps[key] = time.time()
            if self.strategy == CacheStrategy.LFU:
                self._access_counts[key] = 1

    def remove(self, key: str) -> bool:
        """Remove key from cache."""
        with self._lock:
            return self._remove(key)

    def clear(self) -> None:
        """Clear all cache entries."""
        with self._lock:
            self._cache.clear()
            self._access_counts.clear()
            self._timestamps.clear()

    def size(self) -> int:
        """Get current cache size."""
        return len(self._cache)

    def stats(self) -> Dict[str, Any]:
        """Get cache statistics."""
        with self._lock:
            return {
                "size": len(self._cache),
                "max_size": self.max_size,
                "strategy": self.strategy.value,
                "ttl_seconds": self.ttl_seconds,
                "keys": list(self._cache.keys())
            }

    def _remove(self, key: str) -> bool:
        """Remove key without lock (internal use)."""
        if key in self._cache:
            del self._cache[key]
            self._access_counts.pop(key, None)
            self._timestamps.pop(key, None)
            return True
        return False

    def _evict_one(self) -> None:
        """Evict one item based on strategy."""
        if not self._cache:
            return

        if self.strategy == CacheStrategy.LRU:
            # Remove least recently used (first item)
            key = next(iter(self._cache))
            self._remove(key)
        elif self.strategy == CacheStrategy.LFU:
            # Remove least frequently used
            if self._access_counts:
                key = min(self._access_counts.keys(), key=lambda k: self._access_counts[k])
                self._remove(key)
        elif self.strategy == CacheStrategy.FIFO:
            # Remove first in, first out
            key = next(iter(self._cache))
            self._remove(key)
        elif self.strategy == CacheStrategy.TTL:
            # Remove expired items first, then oldest
            current_time = time.time()
            expired_keys = [
                key for key, timestamp in self._timestamps.items()
                if current_time - timestamp > (self.ttl_seconds or 0)
            ]

            if expired_keys:
                self._remove(expired_keys[0])
            else:
                # Remove oldest
                key = min(self._timestamps.keys(), key=lambda k: self._timestamps[k])
                self._remove(key)
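A minimal sketch of MemoryCache as a TTL cache in front of an expensive lookup (illustrative key and value):

# Example: TTL eviction in action.
import time
from hcfs.sdk.models import CacheStrategy
from hcfs.sdk.utils import MemoryCache

cache = MemoryCache(max_size=100, strategy=CacheStrategy.TTL, ttl_seconds=2)

cache.put("context:/docs/readme", {"id": 1, "path": "/docs/readme"})
print(cache.get("context:/docs/readme"))   # -> cached dict
print(cache.stats()["size"])               # -> 1

time.sleep(2.1)
print(cache.get("context:/docs/readme"))   # -> None, entry expired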
class BatchProcessor:
    """Utility for processing items in batches with error handling."""

    def __init__(self, batch_size: int = 10, max_workers: Optional[int] = None):
        self.batch_size = batch_size
        self.max_workers = max_workers or min(32, (len(os.sched_getaffinity(0)) or 1) + 4)

    async def process_async(self,
                            items: List[Any],
                            processor: Callable[[Any], Any],
                            on_success: Optional[Callable[[Any, Any], None]] = None,
                            on_error: Optional[Callable[[Any, Exception], None]] = None) -> Dict[str, Any]:
        """
        Process items asynchronously in batches.

        Args:
            items: Items to process
            processor: Async function to process each item
            on_success: Callback for successful processing
            on_error: Callback for processing errors

        Returns:
            Processing results summary
        """
        results = {
            "success_count": 0,
            "error_count": 0,
            "total_items": len(items),
            "successful_items": [],
            "failed_items": [],
            "execution_time": 0
        }

        start_time = time.time()

        # Process in batches
        for i in range(0, len(items), self.batch_size):
            batch = items[i:i + self.batch_size]

            # Create tasks for this batch
            tasks = []
            for item in batch:
                task = asyncio.create_task(self._process_item_async(item, processor))
                tasks.append((item, task))

            # Wait for batch completion
            for item, task in tasks:
                try:
                    result = await task
                    results["success_count"] += 1
                    results["successful_items"].append(result)

                    if on_success:
                        on_success(item, result)

                except Exception as e:
                    results["error_count"] += 1
                    results["failed_items"].append({"item": item, "error": str(e)})

                    if on_error:
                        on_error(item, e)

        results["execution_time"] = time.time() - start_time
        return results

    async def _process_item_async(self, item: Any, processor: Callable) -> Any:
        """Process a single item asynchronously."""
        if asyncio.iscoroutinefunction(processor):
            return await processor(item)
        else:
            # Run synchronous processor in thread pool
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(None, processor, item)
# Global instances
embedding_cache = MemoryCache(max_size=2000, strategy=CacheStrategy.LRU, ttl_seconds=3600)
batch_processor = BatchProcessor(batch_size=10)
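A sketch of driving the shared batch_processor instance with an async worker (the worker and its failure case are illustrative stand-ins):

# Example: batched async processing with an error callback.
import asyncio
from hcfs.sdk.utils import batch_processor

async def process(item: int) -> int:
    await asyncio.sleep(0.01)     # stand-in for an API call
    if item == 3:
        raise ValueError("boom")
    return item * 2

async def main() -> None:
    summary = await batch_processor.process_async(
        list(range(10)),
        process,
        on_error=lambda item, exc: print(f"item {item} failed: {exc}"),
    )
    print(summary["success_count"], summary["error_count"])   # 9 1

asyncio.run(main())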
def cache_key(*args, **kwargs) -> str:
    """Generate cache key from arguments."""
    key_parts = []

    # Add positional arguments
    for arg in args:
        if isinstance(arg, (str, int, float, bool)):
            key_parts.append(str(arg))
        else:
            key_parts.append(str(hash(str(arg))))

    # Add keyword arguments
    for k, v in sorted(kwargs.items()):
        if isinstance(v, (str, int, float, bool)):
            key_parts.append(f"{k}={v}")
        else:
            key_parts.append(f"{k}={hash(str(v))}")

    return ":".join(key_parts)


def timing_decorator(func):
    """Decorator to measure function execution time."""
    @wraps(func)
    async def async_wrapper(*args, **kwargs):
        start_time = time.time()
        try:
            result = await func(*args, **kwargs)
            return result
        finally:
            execution_time = time.time() - start_time
            # Could log or store timing data here
            pass

    @wraps(func)
    def sync_wrapper(*args, **kwargs):
        start_time = time.time()
        try:
            result = func(*args, **kwargs)
            return result
        finally:
            execution_time = time.time() - start_time
            # Could log or store timing data here
            pass

    if asyncio.iscoroutinefunction(func):
        return async_wrapper
    else:
        return sync_wrapper