HCFS/hcfs-python/hcfs/core/embeddings_optimized.py

"""
Optimized Embedding Manager - High-performance vector operations and storage.
This module provides enhanced embedding capabilities including:
- Vector storage in SQLite with BLOB-encoded embeddings
- Optimized batch processing and caching
- Multiple embedding model support
- Similarity search via cosine scoring over stored vectors
- Memory-efficient embedding storage
"""
import json
import time
import numpy as np
import sqlite3
from typing import List, Dict, Optional, Tuple, Union, Any
from dataclasses import dataclass, asdict
from pathlib import Path
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import threading
from contextlib import contextmanager
from functools import lru_cache
import logging
from .context_db import Context, ContextDatabase
logger = logging.getLogger(__name__)
@dataclass
class EmbeddingModel:
"""Configuration for embedding models."""
name: str
model_path: str
dimension: int
max_tokens: int = 512
normalize: bool = True
@dataclass
class VectorSearchResult:
"""Result from vector search operations."""
context_id: int
score: float
context: Optional[Context] = None
    metadata: Optional[Dict[str, Any]] = None
class VectorCache:
"""High-performance LRU cache for embeddings."""
    def __init__(self, max_size: int = 5000, ttl_seconds: int = 3600):
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self.cache: Dict[str, Tuple[np.ndarray, float]] = {}
        self.access_times: Dict[str, float] = {}
        self.lock = threading.RLock()
        # Counters backing the hit-rate figure reported by stats()
        self._hits = 0
        self._requests = 0
    def get(self, key: str) -> Optional[np.ndarray]:
        """Get embedding from cache."""
        with self.lock:
            current_time = time.time()
            self._requests += 1
            if key in self.cache:
                embedding, created_time = self.cache[key]
                # Check TTL
                if current_time - created_time < self.ttl_seconds:
                    self.access_times[key] = current_time
                    self._hits += 1
                    return embedding.copy()
                else:
                    # Expired
                    del self.cache[key]
                    del self.access_times[key]
            return None
def put(self, key: str, embedding: np.ndarray) -> None:
"""Store embedding in cache."""
with self.lock:
current_time = time.time()
# Evict if cache is full
if len(self.cache) >= self.max_size:
self._evict_lru()
self.cache[key] = (embedding.copy(), current_time)
self.access_times[key] = current_time
def _evict_lru(self) -> None:
"""Evict least recently used item."""
if not self.access_times:
return
lru_key = min(self.access_times.items(), key=lambda x: x[1])[0]
del self.cache[lru_key]
del self.access_times[lru_key]
def clear(self) -> None:
"""Clear cache."""
with self.lock:
self.cache.clear()
self.access_times.clear()
def stats(self) -> Dict[str, Any]:
"""Get cache statistics."""
with self.lock:
return {
"size": len(self.cache),
"max_size": self.max_size,
"hit_rate": getattr(self, '_hits', 0) / max(getattr(self, '_requests', 1), 1),
"ttl_seconds": self.ttl_seconds
}
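
# Minimal sketch of using VectorCache on its own; the key string and vector
# below are arbitrary examples.
#
#     cache = VectorCache(max_size=1000, ttl_seconds=600)
#     cache.put("mini:123", np.zeros(384, dtype=np.float32))
#     vec = cache.get("mini:123")          # returns a copy, or None if expired
#     print(cache.stats()["hit_rate"])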
class OptimizedEmbeddingManager:
"""
High-performance embedding manager with vector database capabilities.
"""
# Predefined embedding models
MODELS = {
"mini": EmbeddingModel("all-MiniLM-L6-v2", "all-MiniLM-L6-v2", 384),
"base": EmbeddingModel("all-MiniLM-L12-v2", "all-MiniLM-L12-v2", 384),
"large": EmbeddingModel("all-mpnet-base-v2", "all-mpnet-base-v2", 768),
"multilingual": EmbeddingModel("paraphrase-multilingual-MiniLM-L12-v2",
"paraphrase-multilingual-MiniLM-L12-v2", 384)
}
def __init__(self,
context_db: ContextDatabase,
model_name: str = "mini",
vector_db_path: Optional[str] = None,
cache_size: int = 5000,
batch_size: int = 32):
self.context_db = context_db
self.model_config = self.MODELS.get(model_name, self.MODELS["mini"])
self.model = None # Lazy loading
self.vector_cache = VectorCache(cache_size)
self.batch_size = batch_size
# Vector database setup
self.vector_db_path = vector_db_path or "hcfs_vectors.db"
self._init_vector_db()
# TF-IDF for hybrid search
self.tfidf_vectorizer = TfidfVectorizer(
stop_words='english',
max_features=5000,
ngram_range=(1, 2),
min_df=2
)
self._tfidf_fitted = False
self._model_lock = threading.RLock()
logger.info(f"Initialized OptimizedEmbeddingManager with model: {self.model_config.name}")
def _get_model(self) -> SentenceTransformer:
"""Lazy load the embedding model."""
if self.model is None:
with self._model_lock:
if self.model is None:
logger.info(f"Loading embedding model: {self.model_config.model_path}")
self.model = SentenceTransformer(self.model_config.model_path)
return self.model
def _init_vector_db(self) -> None:
"""Initialize SQLite vector database for fast similarity search."""
conn = sqlite3.connect(self.vector_db_path)
cursor = conn.cursor()
# Create vectors table
cursor.execute('''
CREATE TABLE IF NOT EXISTS context_vectors (
context_id INTEGER PRIMARY KEY,
model_name TEXT NOT NULL,
embedding_dimension INTEGER NOT NULL,
vector_data BLOB NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
''')
# Create index for fast lookups
cursor.execute('''
CREATE INDEX IF NOT EXISTS idx_context_vectors_model
ON context_vectors(model_name, context_id)
''')
conn.commit()
conn.close()
logger.info(f"Vector database initialized: {self.vector_db_path}")
@contextmanager
def _get_vector_db(self):
"""Get vector database connection with proper cleanup."""
conn = sqlite3.connect(self.vector_db_path)
try:
yield conn
finally:
conn.close()
def generate_embedding(self, text: str, use_cache: bool = True) -> np.ndarray:
"""Generate embedding for text with caching."""
cache_key = f"{self.model_config.name}:{hash(text)}"
if use_cache:
cached = self.vector_cache.get(cache_key)
if cached is not None:
return cached
model = self._get_model()
embedding = model.encode(
text,
normalize_embeddings=self.model_config.normalize,
show_progress_bar=False
)
if use_cache:
self.vector_cache.put(cache_key, embedding)
return embedding
def generate_embeddings_batch(self, texts: List[str], use_cache: bool = True) -> List[np.ndarray]:
"""Generate embeddings for multiple texts efficiently."""
if not texts:
return []
# Check cache first
cache_results = []
uncached_indices = []
uncached_texts = []
if use_cache:
for i, text in enumerate(texts):
cache_key = f"{self.model_config.name}:{hash(text)}"
cached = self.vector_cache.get(cache_key)
if cached is not None:
cache_results.append((i, cached))
else:
uncached_indices.append(i)
uncached_texts.append(text)
else:
uncached_indices = list(range(len(texts)))
uncached_texts = texts
# Generate embeddings for uncached texts
embeddings = [None] * len(texts)
# Place cached results
for i, embedding in cache_results:
embeddings[i] = embedding
if uncached_texts:
model = self._get_model()
# Process in batches
for batch_start in range(0, len(uncached_texts), self.batch_size):
batch_end = min(batch_start + self.batch_size, len(uncached_texts))
batch_texts = uncached_texts[batch_start:batch_end]
batch_indices = uncached_indices[batch_start:batch_end]
batch_embeddings = model.encode(
batch_texts,
normalize_embeddings=self.model_config.normalize,
show_progress_bar=False,
batch_size=self.batch_size
)
# Store results and cache
for i, (orig_idx, embedding) in enumerate(zip(batch_indices, batch_embeddings)):
embeddings[orig_idx] = embedding
if use_cache:
cache_key = f"{self.model_config.name}:{hash(batch_texts[i])}"
self.vector_cache.put(cache_key, embedding)
return embeddings
def store_embedding(self, context_id: int, embedding: np.ndarray) -> None:
"""Store embedding in vector database."""
with self._get_vector_db() as conn:
cursor = conn.cursor()
# Convert to bytes for storage
vector_bytes = embedding.astype(np.float32).tobytes()
cursor.execute('''
INSERT OR REPLACE INTO context_vectors
(context_id, model_name, embedding_dimension, vector_data, updated_at)
VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
''', (context_id, self.model_config.name, embedding.shape[0], vector_bytes))
conn.commit()
def store_embeddings_batch(self, context_embeddings: List[Tuple[int, np.ndarray]]) -> None:
"""Store multiple embeddings efficiently."""
if not context_embeddings:
return
with self._get_vector_db() as conn:
cursor = conn.cursor()
data = [
(context_id, self.model_config.name, embedding.shape[0],
embedding.astype(np.float32).tobytes())
for context_id, embedding in context_embeddings
]
cursor.executemany('''
INSERT OR REPLACE INTO context_vectors
(context_id, model_name, embedding_dimension, vector_data, updated_at)
VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
''', data)
conn.commit()
logger.info(f"Stored {len(context_embeddings)} embeddings in batch")
def get_embedding(self, context_id: int) -> Optional[np.ndarray]:
"""Retrieve embedding for a context."""
with self._get_vector_db() as conn:
cursor = conn.cursor()
cursor.execute('''
SELECT vector_data, embedding_dimension FROM context_vectors
WHERE context_id = ? AND model_name = ?
''', (context_id, self.model_config.name))
result = cursor.fetchone()
if result:
vector_bytes, dimension = result
return np.frombuffer(vector_bytes, dtype=np.float32).reshape(dimension)
return None
def vector_similarity_search(self,
query_embedding: np.ndarray,
context_ids: Optional[List[int]] = None,
top_k: int = 10,
min_similarity: float = 0.0) -> List[VectorSearchResult]:
"""Efficient vector similarity search."""
with self._get_vector_db() as conn:
cursor = conn.cursor()
# Build query
if context_ids:
placeholders = ','.join(['?'] * len(context_ids))
query = f'''
SELECT context_id, vector_data, embedding_dimension
FROM context_vectors
WHERE model_name = ? AND context_id IN ({placeholders})
'''
params = [self.model_config.name] + context_ids
else:
query = '''
SELECT context_id, vector_data, embedding_dimension
FROM context_vectors
WHERE model_name = ?
'''
params = [self.model_config.name]
cursor.execute(query, params)
results = cursor.fetchall()
if not results:
return []
# Calculate similarities
similarities = []
query_embedding = query_embedding.reshape(1, -1)
for context_id, vector_bytes, dimension in results:
stored_embedding = np.frombuffer(vector_bytes, dtype=np.float32).reshape(1, dimension)
similarity = cosine_similarity(query_embedding, stored_embedding)[0][0]
if similarity >= min_similarity:
similarities.append(VectorSearchResult(
context_id=context_id,
score=float(similarity)
))
# Sort by similarity and return top_k
similarities.sort(key=lambda x: x.score, reverse=True)
return similarities[:top_k]
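
    # Example call with a precomputed query embedding, optionally restricted
    # to a candidate ID list (the values below are illustrative only).
    #
    #     query_vec = manager.generate_embedding("error handling strategy")
    #     matches = manager.vector_similarity_search(query_vec, context_ids=[1, 2, 3],
    #                                                top_k=10, min_similarity=0.2)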
    def semantic_search_optimized(self,
                                  query: str,
                                  path_prefix: Optional[str] = None,
                                  top_k: int = 5,
                                  include_contexts: bool = True) -> List[VectorSearchResult]:
"""High-performance semantic search."""
# Generate query embedding
query_embedding = self.generate_embedding(query)
# Get relevant context IDs based on path filter
context_ids = None
if path_prefix:
with self.context_db.get_session() as session:
from .context_db import ContextBlob
blobs = session.query(ContextBlob.id).filter(
ContextBlob.path.startswith(path_prefix)
).all()
context_ids = [blob.id for blob in blobs]
if not context_ids:
return []
# Perform vector search
results = self.vector_similarity_search(
query_embedding,
context_ids=context_ids,
top_k=top_k
)
# Populate with context data if requested
if include_contexts and results:
context_map = {}
with self.context_db.get_session() as session:
from .context_db import ContextBlob
result_ids = [r.context_id for r in results]
blobs = session.query(ContextBlob).filter(
ContextBlob.id.in_(result_ids)
).all()
for blob in blobs:
context_map[blob.id] = Context(
id=blob.id,
path=blob.path,
content=blob.content,
summary=blob.summary,
author=blob.author,
created_at=blob.created_at,
updated_at=blob.updated_at,
version=blob.version
)
# Add contexts to results
for result in results:
result.context = context_map.get(result.context_id)
return results
    def hybrid_search_optimized(self,
                                query: str,
                                path_prefix: Optional[str] = None,
                                top_k: int = 5,
                                semantic_weight: float = 0.7,
                                rerank_top_n: int = 50) -> List[VectorSearchResult]:
"""Optimized hybrid search with two-stage ranking."""
# Stage 1: Fast semantic search to get candidate set
semantic_results = self.semantic_search_optimized(
query, path_prefix, rerank_top_n, include_contexts=True
)
if not semantic_results or len(semantic_results) < 2:
return semantic_results[:top_k]
        # Stage 2: Re-rank with lexical (TF-IDF) scores
contexts = [r.context for r in semantic_results if r.context]
if not contexts:
return semantic_results[:top_k]
documents = [ctx.content for ctx in contexts]
        # Compute lexical similarity scores with TF-IDF (variables keep the historical "bm25" name)
try:
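            # The vectorizer is fitted lazily on the first query's candidate
            # documents and then reused, so its vocabulary reflects that
            # initial candidate set rather than the whole corpus.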
if not self._tfidf_fitted:
self.tfidf_vectorizer.fit(documents)
self._tfidf_fitted = True
doc_vectors = self.tfidf_vectorizer.transform(documents)
query_vector = self.tfidf_vectorizer.transform([query])
bm25_scores = cosine_similarity(query_vector, doc_vectors)[0]
except Exception as e:
logger.warning(f"BM25 scoring failed: {e}, using semantic only")
return semantic_results[:top_k]
# Combine scores
for i, result in enumerate(semantic_results[:len(bm25_scores)]):
semantic_score = result.score
bm25_score = bm25_scores[i]
combined_score = (semantic_weight * semantic_score +
(1 - semantic_weight) * bm25_score)
result.score = float(combined_score)
result.metadata = {
"semantic_score": float(semantic_score),
"bm25_score": float(bm25_score),
"semantic_weight": semantic_weight
}
# Re-sort by combined score
semantic_results.sort(key=lambda x: x.score, reverse=True)
return semantic_results[:top_k]
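
    # Worked example of the score combination above: with semantic_weight=0.7,
    # a semantic score of 0.80 and a lexical score of 0.40 combine to
    # 0.7 * 0.80 + 0.3 * 0.40 = 0.68.
    #
    #     results = manager.hybrid_search_optimized("retry with backoff",
    #                                               top_k=5, semantic_weight=0.7)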
def build_embeddings_index(self, batch_size: int = 100) -> Dict[str, Any]:
"""Build embeddings for all contexts without embeddings."""
start_time = time.time()
# Get contexts without embeddings
with self.context_db.get_session() as session:
from .context_db import ContextBlob
# Find contexts missing embeddings
with self._get_vector_db() as vector_conn:
vector_cursor = vector_conn.cursor()
vector_cursor.execute('''
SELECT context_id FROM context_vectors
WHERE model_name = ?
''', (self.model_config.name,))
existing_ids = {row[0] for row in vector_cursor.fetchall()}
# Get contexts that need embeddings
all_blobs = session.query(ContextBlob).all()
missing_blobs = [blob for blob in all_blobs if blob.id not in existing_ids]
if not missing_blobs:
return {
"total_processed": 0,
"processing_time": 0,
"embeddings_per_second": 0,
"message": "All contexts already have embeddings"
}
logger.info(f"Building embeddings for {len(missing_blobs)} contexts")
# Process in batches
total_processed = 0
for batch_start in range(0, len(missing_blobs), batch_size):
batch_end = min(batch_start + batch_size, len(missing_blobs))
batch_blobs = missing_blobs[batch_start:batch_end]
# Generate embeddings for batch
texts = [blob.content for blob in batch_blobs]
embeddings = self.generate_embeddings_batch(texts, use_cache=False)
# Store embeddings
context_embeddings = [
(blob.id, embedding)
for blob, embedding in zip(batch_blobs, embeddings)
]
self.store_embeddings_batch(context_embeddings)
total_processed += len(batch_blobs)
logger.info(f"Processed {total_processed}/{len(missing_blobs)} contexts")
processing_time = time.time() - start_time
embeddings_per_second = total_processed / processing_time if processing_time > 0 else 0
return {
"total_processed": total_processed,
"processing_time": processing_time,
"embeddings_per_second": embeddings_per_second,
"model_used": self.model_config.name,
"embedding_dimension": self.model_config.dimension
}
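
    # Example of running the index build and reading its summary dict
    # (field values shown are illustrative only).
    #
    #     stats = manager.build_embeddings_index(batch_size=100)
    #     print(stats["total_processed"], stats["embeddings_per_second"])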
def get_statistics(self) -> Dict[str, Any]:
"""Get embedding manager statistics."""
with self._get_vector_db() as conn:
cursor = conn.cursor()
cursor.execute('''
SELECT
COUNT(*) as total_embeddings,
COUNT(DISTINCT model_name) as unique_models,
AVG(embedding_dimension) as avg_dimension
FROM context_vectors
''')
db_stats = cursor.fetchone()
cursor.execute('''
SELECT model_name, COUNT(*) as count
FROM context_vectors
GROUP BY model_name
''')
model_counts = dict(cursor.fetchall())
return {
"database_stats": {
"total_embeddings": db_stats[0] if db_stats else 0,
"unique_models": db_stats[1] if db_stats else 0,
"average_dimension": db_stats[2] if db_stats else 0,
"model_counts": model_counts
},
"cache_stats": self.vector_cache.stats(),
"current_model": asdict(self.model_config),
"vector_db_path": self.vector_db_path,
"batch_size": self.batch_size
}
    def cleanup_old_embeddings(self, days_old: int = 30) -> int:
        """Remove stale embeddings whose contexts no longer exist."""
        # Context blobs live in the main context database, not in the separate
        # vector database file, so collect the surviving IDs there first.
        with self.context_db.get_session() as session:
            from .context_db import ContextBlob
            existing_ids = {row[0] for row in session.query(ContextBlob.id).all()}
        with self._get_vector_db() as conn:
            cursor = conn.cursor()
            cursor.execute('''
                SELECT context_id FROM context_vectors
                WHERE updated_at < datetime('now', ?)
            ''', (f'-{int(days_old)} days',))
            stale_ids = [row[0] for row in cursor.fetchall() if row[0] not in existing_ids]
            cursor.executemany(
                'DELETE FROM context_vectors WHERE context_id = ?',
                [(cid,) for cid in stale_ids]
            )
            deleted_count = len(stale_ids)
            conn.commit()
        logger.info(f"Cleaned up {deleted_count} old embeddings")
        return deleted_count