""" Optimized Embedding Manager - High-performance vector operations and storage. This module provides enhanced embedding capabilities including: - Vector database integration with SQLite-Vec - Optimized batch processing and caching - Multiple embedding model support - Efficient similarity search with indexing - Memory-efficient embedding storage """ import json import time import numpy as np import sqlite3 from typing import List, Dict, Optional, Tuple, Union, Any from dataclasses import dataclass, asdict from pathlib import Path from sentence_transformers import SentenceTransformer from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity import threading from contextlib import contextmanager from functools import lru_cache import logging from .context_db import Context, ContextDatabase logger = logging.getLogger(__name__) @dataclass class EmbeddingModel: """Configuration for embedding models.""" name: str model_path: str dimension: int max_tokens: int = 512 normalize: bool = True @dataclass class VectorSearchResult: """Result from vector search operations.""" context_id: int score: float context: Optional[Context] = None metadata: Dict[str, Any] = None class VectorCache: """High-performance LRU cache for embeddings.""" def __init__(self, max_size: int = 5000, ttl_seconds: int = 3600): self.max_size = max_size self.ttl_seconds = ttl_seconds self.cache: Dict[str, Tuple[np.ndarray, float]] = {} self.access_times: Dict[str, float] = {} self.lock = threading.RLock() def get(self, key: str) -> Optional[np.ndarray]: """Get embedding from cache.""" with self.lock: current_time = time.time() if key in self.cache: embedding, created_time = self.cache[key] # Check TTL if current_time - created_time < self.ttl_seconds: self.access_times[key] = current_time return embedding.copy() else: # Expired del self.cache[key] del self.access_times[key] return None def put(self, key: str, embedding: np.ndarray) -> None: """Store embedding in cache.""" with self.lock: current_time = time.time() # Evict if cache is full if len(self.cache) >= self.max_size: self._evict_lru() self.cache[key] = (embedding.copy(), current_time) self.access_times[key] = current_time def _evict_lru(self) -> None: """Evict least recently used item.""" if not self.access_times: return lru_key = min(self.access_times.items(), key=lambda x: x[1])[0] del self.cache[lru_key] del self.access_times[lru_key] def clear(self) -> None: """Clear cache.""" with self.lock: self.cache.clear() self.access_times.clear() def stats(self) -> Dict[str, Any]: """Get cache statistics.""" with self.lock: return { "size": len(self.cache), "max_size": self.max_size, "hit_rate": getattr(self, '_hits', 0) / max(getattr(self, '_requests', 1), 1), "ttl_seconds": self.ttl_seconds } class OptimizedEmbeddingManager: """ High-performance embedding manager with vector database capabilities. 
""" # Predefined embedding models MODELS = { "mini": EmbeddingModel("all-MiniLM-L6-v2", "all-MiniLM-L6-v2", 384), "base": EmbeddingModel("all-MiniLM-L12-v2", "all-MiniLM-L12-v2", 384), "large": EmbeddingModel("all-mpnet-base-v2", "all-mpnet-base-v2", 768), "multilingual": EmbeddingModel("paraphrase-multilingual-MiniLM-L12-v2", "paraphrase-multilingual-MiniLM-L12-v2", 384) } def __init__(self, context_db: ContextDatabase, model_name: str = "mini", vector_db_path: Optional[str] = None, cache_size: int = 5000, batch_size: int = 32): self.context_db = context_db self.model_config = self.MODELS.get(model_name, self.MODELS["mini"]) self.model = None # Lazy loading self.vector_cache = VectorCache(cache_size) self.batch_size = batch_size # Vector database setup self.vector_db_path = vector_db_path or "hcfs_vectors.db" self._init_vector_db() # TF-IDF for hybrid search self.tfidf_vectorizer = TfidfVectorizer( stop_words='english', max_features=5000, ngram_range=(1, 2), min_df=2 ) self._tfidf_fitted = False self._model_lock = threading.RLock() logger.info(f"Initialized OptimizedEmbeddingManager with model: {self.model_config.name}") def _get_model(self) -> SentenceTransformer: """Lazy load the embedding model.""" if self.model is None: with self._model_lock: if self.model is None: logger.info(f"Loading embedding model: {self.model_config.model_path}") self.model = SentenceTransformer(self.model_config.model_path) return self.model def _init_vector_db(self) -> None: """Initialize SQLite vector database for fast similarity search.""" conn = sqlite3.connect(self.vector_db_path) cursor = conn.cursor() # Create vectors table cursor.execute(''' CREATE TABLE IF NOT EXISTS context_vectors ( context_id INTEGER PRIMARY KEY, model_name TEXT NOT NULL, embedding_dimension INTEGER NOT NULL, vector_data BLOB NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) ''') # Create index for fast lookups cursor.execute(''' CREATE INDEX IF NOT EXISTS idx_context_vectors_model ON context_vectors(model_name, context_id) ''') conn.commit() conn.close() logger.info(f"Vector database initialized: {self.vector_db_path}") @contextmanager def _get_vector_db(self): """Get vector database connection with proper cleanup.""" conn = sqlite3.connect(self.vector_db_path) try: yield conn finally: conn.close() def generate_embedding(self, text: str, use_cache: bool = True) -> np.ndarray: """Generate embedding for text with caching.""" cache_key = f"{self.model_config.name}:{hash(text)}" if use_cache: cached = self.vector_cache.get(cache_key) if cached is not None: return cached model = self._get_model() embedding = model.encode( text, normalize_embeddings=self.model_config.normalize, show_progress_bar=False ) if use_cache: self.vector_cache.put(cache_key, embedding) return embedding def generate_embeddings_batch(self, texts: List[str], use_cache: bool = True) -> List[np.ndarray]: """Generate embeddings for multiple texts efficiently.""" if not texts: return [] # Check cache first cache_results = [] uncached_indices = [] uncached_texts = [] if use_cache: for i, text in enumerate(texts): cache_key = f"{self.model_config.name}:{hash(text)}" cached = self.vector_cache.get(cache_key) if cached is not None: cache_results.append((i, cached)) else: uncached_indices.append(i) uncached_texts.append(text) else: uncached_indices = list(range(len(texts))) uncached_texts = texts # Generate embeddings for uncached texts embeddings = [None] * len(texts) # Place cached results for i, embedding in 
        if uncached_texts:
            model = self._get_model()

            # Process in batches
            for batch_start in range(0, len(uncached_texts), self.batch_size):
                batch_end = min(batch_start + self.batch_size, len(uncached_texts))
                batch_texts = uncached_texts[batch_start:batch_end]
                batch_indices = uncached_indices[batch_start:batch_end]

                batch_embeddings = model.encode(
                    batch_texts,
                    normalize_embeddings=self.model_config.normalize,
                    show_progress_bar=False,
                    batch_size=self.batch_size
                )

                # Store results and cache
                for i, (orig_idx, embedding) in enumerate(zip(batch_indices, batch_embeddings)):
                    embeddings[orig_idx] = embedding
                    if use_cache:
                        cache_key = f"{self.model_config.name}:{hash(batch_texts[i])}"
                        self.vector_cache.put(cache_key, embedding)

        return embeddings

    def store_embedding(self, context_id: int, embedding: np.ndarray) -> None:
        """Store embedding in vector database."""
        with self._get_vector_db() as conn:
            cursor = conn.cursor()

            # Convert to bytes for storage
            vector_bytes = embedding.astype(np.float32).tobytes()

            cursor.execute('''
                INSERT OR REPLACE INTO context_vectors
                (context_id, model_name, embedding_dimension, vector_data, updated_at)
                VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
            ''', (context_id, self.model_config.name, embedding.shape[0], vector_bytes))

            conn.commit()

    def store_embeddings_batch(self, context_embeddings: List[Tuple[int, np.ndarray]]) -> None:
        """Store multiple embeddings efficiently."""
        if not context_embeddings:
            return

        with self._get_vector_db() as conn:
            cursor = conn.cursor()

            data = [
                (context_id, self.model_config.name, embedding.shape[0],
                 embedding.astype(np.float32).tobytes())
                for context_id, embedding in context_embeddings
            ]

            cursor.executemany('''
                INSERT OR REPLACE INTO context_vectors
                (context_id, model_name, embedding_dimension, vector_data, updated_at)
                VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
            ''', data)

            conn.commit()

        logger.info(f"Stored {len(context_embeddings)} embeddings in batch")

    def get_embedding(self, context_id: int) -> Optional[np.ndarray]:
        """Retrieve embedding for a context."""
        with self._get_vector_db() as conn:
            cursor = conn.cursor()

            cursor.execute('''
                SELECT vector_data, embedding_dimension FROM context_vectors
                WHERE context_id = ? AND model_name = ?
            ''', (context_id, self.model_config.name))

            result = cursor.fetchone()
            if result:
                vector_bytes, dimension = result
                return np.frombuffer(vector_bytes, dtype=np.float32).reshape(dimension)

            return None

    def vector_similarity_search(self,
                                 query_embedding: np.ndarray,
                                 context_ids: Optional[List[int]] = None,
                                 top_k: int = 10,
                                 min_similarity: float = 0.0) -> List[VectorSearchResult]:
        """Efficient vector similarity search."""
        with self._get_vector_db() as conn:
            cursor = conn.cursor()

            # Build query
            if context_ids:
                placeholders = ','.join(['?'] * len(context_ids))
                query = f'''
                    SELECT context_id, vector_data, embedding_dimension
                    FROM context_vectors
                    WHERE model_name = ? AND context_id IN ({placeholders})
                '''
                params = [self.model_config.name] + context_ids
            else:
                query = '''
                    SELECT context_id, vector_data, embedding_dimension
                    FROM context_vectors
                    WHERE model_name = ?
                '''
                params = [self.model_config.name]
            cursor.execute(query, params)
            results = cursor.fetchall()

            if not results:
                return []

            # Calculate similarities
            similarities = []
            query_embedding = query_embedding.reshape(1, -1)

            for context_id, vector_bytes, dimension in results:
                stored_embedding = np.frombuffer(vector_bytes, dtype=np.float32).reshape(1, dimension)
                similarity = cosine_similarity(query_embedding, stored_embedding)[0][0]

                if similarity >= min_similarity:
                    similarities.append(VectorSearchResult(
                        context_id=context_id,
                        score=float(similarity)
                    ))

            # Sort by similarity and return top_k
            similarities.sort(key=lambda x: x.score, reverse=True)
            return similarities[:top_k]

    def semantic_search_optimized(self,
                                  query: str,
                                  path_prefix: Optional[str] = None,
                                  top_k: int = 5,
                                  include_contexts: bool = True) -> List[VectorSearchResult]:
        """High-performance semantic search."""
        # Generate query embedding
        query_embedding = self.generate_embedding(query)

        # Get relevant context IDs based on path filter
        context_ids = None
        if path_prefix:
            with self.context_db.get_session() as session:
                from .context_db import ContextBlob
                blobs = session.query(ContextBlob.id).filter(
                    ContextBlob.path.startswith(path_prefix)
                ).all()
                context_ids = [blob.id for blob in blobs]

            if not context_ids:
                return []

        # Perform vector search
        results = self.vector_similarity_search(
            query_embedding,
            context_ids=context_ids,
            top_k=top_k
        )

        # Populate with context data if requested
        if include_contexts and results:
            context_map = {}
            with self.context_db.get_session() as session:
                from .context_db import ContextBlob
                result_ids = [r.context_id for r in results]
                blobs = session.query(ContextBlob).filter(
                    ContextBlob.id.in_(result_ids)
                ).all()

                for blob in blobs:
                    context_map[blob.id] = Context(
                        id=blob.id,
                        path=blob.path,
                        content=blob.content,
                        summary=blob.summary,
                        author=blob.author,
                        created_at=blob.created_at,
                        updated_at=blob.updated_at,
                        version=blob.version
                    )

            # Add contexts to results
            for result in results:
                result.context = context_map.get(result.context_id)

        return results

    def hybrid_search_optimized(self,
                                query: str,
                                path_prefix: Optional[str] = None,
                                top_k: int = 5,
                                semantic_weight: float = 0.7,
                                rerank_top_n: int = 50) -> List[VectorSearchResult]:
        """Optimized hybrid search with two-stage ranking."""
        # Stage 1: Fast semantic search to get candidate set
        semantic_results = self.semantic_search_optimized(
            query, path_prefix, rerank_top_n, include_contexts=True
        )

        if not semantic_results or len(semantic_results) < 2:
            return semantic_results[:top_k]

        # Stage 2: Re-rank with a lexical signal (TF-IDF cosine similarity,
        # stored under BM25-named fields)
        contexts = [r.context for r in semantic_results if r.context]
        if not contexts:
            return semantic_results[:top_k]

        documents = [ctx.content for ctx in contexts]

        # Compute lexical relevance scores
        try:
            if not self._tfidf_fitted:
                self.tfidf_vectorizer.fit(documents)
                self._tfidf_fitted = True

            doc_vectors = self.tfidf_vectorizer.transform(documents)
            query_vector = self.tfidf_vectorizer.transform([query])
            bm25_scores = cosine_similarity(query_vector, doc_vectors)[0]
        except Exception as e:
            logger.warning(f"BM25 scoring failed: {e}, using semantic only")
            return semantic_results[:top_k]

        # Combine scores; lexical scores only exist for results whose context loaded,
        # so zip against that subset to keep scores aligned with documents
        results_with_context = [r for r in semantic_results if r.context]
        for result, bm25_score in zip(results_with_context, bm25_scores):
            semantic_score = result.score
            combined_score = (semantic_weight * semantic_score +
                              (1 - semantic_weight) * bm25_score)

            result.score = float(combined_score)
            result.metadata = {
                "semantic_score": float(semantic_score),
                "bm25_score": float(bm25_score),
                "semantic_weight": semantic_weight
            }

        # Re-sort by combined score
        semantic_results.sort(key=lambda x: x.score, reverse=True)
        return semantic_results[:top_k]

    def build_embeddings_index(self, batch_size: int = 100) -> Dict[str, Any]:
        """Build embeddings for all contexts without embeddings."""
        start_time = time.time()

        # Get contexts without embeddings
        with self.context_db.get_session() as session:
            from .context_db import ContextBlob

            # Find contexts missing embeddings
            with self._get_vector_db() as vector_conn:
                vector_cursor = vector_conn.cursor()
                vector_cursor.execute('''
                    SELECT context_id FROM context_vectors
                    WHERE model_name = ?
                ''', (self.model_config.name,))
                existing_ids = {row[0] for row in vector_cursor.fetchall()}

            # Get contexts that need embeddings
            all_blobs = session.query(ContextBlob).all()
            missing_blobs = [blob for blob in all_blobs if blob.id not in existing_ids]

            if not missing_blobs:
                return {
                    "total_processed": 0,
                    "processing_time": 0,
                    "embeddings_per_second": 0,
                    "message": "All contexts already have embeddings"
                }

            logger.info(f"Building embeddings for {len(missing_blobs)} contexts")

            # Process in batches
            total_processed = 0
            for batch_start in range(0, len(missing_blobs), batch_size):
                batch_end = min(batch_start + batch_size, len(missing_blobs))
                batch_blobs = missing_blobs[batch_start:batch_end]

                # Generate embeddings for batch
                texts = [blob.content for blob in batch_blobs]
                embeddings = self.generate_embeddings_batch(texts, use_cache=False)

                # Store embeddings
                context_embeddings = [
                    (blob.id, embedding)
                    for blob, embedding in zip(batch_blobs, embeddings)
                ]
                self.store_embeddings_batch(context_embeddings)

                total_processed += len(batch_blobs)
                logger.info(f"Processed {total_processed}/{len(missing_blobs)} contexts")

        processing_time = time.time() - start_time
        embeddings_per_second = total_processed / processing_time if processing_time > 0 else 0

        return {
            "total_processed": total_processed,
            "processing_time": processing_time,
            "embeddings_per_second": embeddings_per_second,
            "model_used": self.model_config.name,
            "embedding_dimension": self.model_config.dimension
        }

    def get_statistics(self) -> Dict[str, Any]:
        """Get embedding manager statistics."""
        with self._get_vector_db() as conn:
            cursor = conn.cursor()

            cursor.execute('''
                SELECT
                    COUNT(*) as total_embeddings,
                    COUNT(DISTINCT model_name) as unique_models,
                    AVG(embedding_dimension) as avg_dimension
                FROM context_vectors
            ''')
            db_stats = cursor.fetchone()

            cursor.execute('''
                SELECT model_name, COUNT(*) as count
                FROM context_vectors
                GROUP BY model_name
            ''')
            model_counts = dict(cursor.fetchall())

        return {
            "database_stats": {
                "total_embeddings": db_stats[0] if db_stats else 0,
                "unique_models": db_stats[1] if db_stats else 0,
                "average_dimension": db_stats[2] if db_stats else 0,
                "model_counts": model_counts
            },
            "cache_stats": self.vector_cache.stats(),
            "current_model": asdict(self.model_config),
            "vector_db_path": self.vector_db_path,
            "batch_size": self.batch_size
        }

    def cleanup_old_embeddings(self, days_old: int = 30) -> int:
        """Remove old, unused embeddings whose contexts no longer exist."""
        with self._get_vector_db() as conn:
            cursor = conn.cursor()

            # Bind the age cutoff as a parameter rather than formatting it into the SQL.
            # Assumes the context_blobs table lives in the same SQLite file as the vectors;
            # adjust the subquery if the context database is stored separately.
            cursor.execute('''
                DELETE FROM context_vectors
                WHERE updated_at < datetime('now', ?)
                AND context_id NOT IN (
                    SELECT id FROM context_blobs
                )
            ''', (f'-{int(days_old)} days',))

            deleted_count = cursor.rowcount
            conn.commit()

            logger.info(f"Cleaned up {deleted_count} old embeddings")
            return deleted_count
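

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API). It assumes
# ContextDatabase can be constructed from a SQLite path; adjust to the actual
# constructor in context_db. Run as a module (python -m <package>.<this_module>)
# so the relative imports above resolve. The query strings are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    context_db = ContextDatabase("hcfs_contexts.db")  # hypothetical constructor signature
    manager = OptimizedEmbeddingManager(context_db, model_name="mini")

    # Embed any contexts that do not yet have stored vectors.
    index_stats = manager.build_embeddings_index(batch_size=100)
    print(f"Indexed {index_stats['total_processed']} contexts")

    # Pure semantic search, optionally restricted to a path prefix.
    for result in manager.semantic_search_optimized("database migration steps", top_k=5):
        print(result.context_id, round(result.score, 3))

    # Hybrid search re-ranks semantic candidates with a lexical TF-IDF signal.
    for result in manager.hybrid_search_optimized("database migration steps", top_k=5):
        print(result.context_id, result.metadata)

    print(manager.get_statistics()["cache_stats"])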