""" Embedding Manager - Generate and manage context embeddings. """ import json import numpy as np from typing import List, Dict, Optional, Tuple from sentence_transformers import SentenceTransformer from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity from .context_db import Context, ContextDatabase class EmbeddingManager: """ Manages embeddings for context blobs and semantic similarity search. """ def __init__(self, context_db: ContextDatabase, model_name: str = "all-MiniLM-L6-v2"): self.context_db = context_db self.model_name = model_name self.model = SentenceTransformer(model_name) self.tfidf_vectorizer = TfidfVectorizer(stop_words='english', max_features=1000) self._tfidf_fitted = False def generate_embedding(self, text: str) -> np.ndarray: """Generate embedding for a text.""" return self.model.encode(text, normalize_embeddings=True) def store_context_with_embedding(self, context: Context) -> int: """Store context and generate its embedding.""" # Generate embedding embedding = self.generate_embedding(context.content) # Store in database context_id = self.context_db.store_context(context) # Update with embedding (you'd extend ContextBlob model for this) self._store_embedding(context_id, embedding) return context_id def _store_embedding(self, context_id: int, embedding: np.ndarray) -> None: """Store embedding vector in database.""" embedding_json = json.dumps(embedding.tolist()) with self.context_db.get_session() as session: from .context_db import ContextBlob blob = session.query(ContextBlob).filter(ContextBlob.id == context_id).first() if blob: blob.embedding_model = self.model_name blob.embedding_vector = embedding_json session.commit() def semantic_search(self, query: str, path_prefix: str = None, top_k: int = 5) -> List[Tuple[Context, float]]: """ Perform semantic search for contexts similar to query. Args: query: Search query text path_prefix: Optional path prefix to limit search scope top_k: Number of results to return Returns: List of (Context, similarity_score) tuples """ query_embedding = self.generate_embedding(query) with self.context_db.get_session() as session: from .context_db import ContextBlob query_filter = session.query(ContextBlob).filter( ContextBlob.embedding_vector.isnot(None) ) if path_prefix: query_filter = query_filter.filter(ContextBlob.path.startswith(path_prefix)) blobs = query_filter.all() if not blobs: return [] # Calculate similarities similarities = [] for blob in blobs: if blob.embedding_vector: stored_embedding = np.array(json.loads(blob.embedding_vector)) similarity = cosine_similarity( query_embedding.reshape(1, -1), stored_embedding.reshape(1, -1) )[0][0] context = Context( id=blob.id, path=blob.path, content=blob.content, summary=blob.summary, author=blob.author, created_at=blob.created_at, updated_at=blob.updated_at, version=blob.version ) similarities.append((context, float(similarity))) # Sort by similarity and return top_k similarities.sort(key=lambda x: x[1], reverse=True) return similarities[:top_k] def hybrid_search(self, query: str, path_prefix: str = None, top_k: int = 5, semantic_weight: float = 0.7) -> List[Tuple[Context, float]]: """ Hybrid search combining semantic similarity and BM25. Args: query: Search query path_prefix: Optional path filter top_k: Number of results semantic_weight: Weight for semantic vs BM25 (0.0-1.0) """ # Get contexts for BM25 with self.context_db.get_session() as session: from .context_db import ContextBlob query_filter = session.query(ContextBlob) if path_prefix: query_filter = query_filter.filter(ContextBlob.path.startswith(path_prefix)) blobs = query_filter.all() if not blobs: return [] # Prepare documents for BM25 documents = [blob.content for blob in blobs] # Fit TF-IDF if not already fitted or refitting needed if not self._tfidf_fitted or len(documents) > 100: # Refit periodically self.tfidf_vectorizer.fit(documents) self._tfidf_fitted = True # BM25 scoring (using TF-IDF as approximation) doc_vectors = self.tfidf_vectorizer.transform(documents) query_vector = self.tfidf_vectorizer.transform([query]) bm25_scores = cosine_similarity(query_vector, doc_vectors)[0] # Semantic scoring semantic_results = self.semantic_search(query, path_prefix, len(blobs)) semantic_scores = {ctx.id: score for ctx, score in semantic_results} # Combine scores combined_results = [] for i, blob in enumerate(blobs): bm25_score = bm25_scores[i] semantic_score = semantic_scores.get(blob.id, 0.0) combined_score = (semantic_weight * semantic_score + (1 - semantic_weight) * bm25_score) context = Context( id=blob.id, path=blob.path, content=blob.content, summary=blob.summary, author=blob.author, created_at=blob.created_at, updated_at=blob.updated_at, version=blob.version ) combined_results.append((context, combined_score)) # Sort and return top results combined_results.sort(key=lambda x: x[1], reverse=True) return combined_results[:top_k] def get_similar_contexts(self, context_id: int, top_k: int = 5) -> List[Tuple[Context, float]]: """Find contexts similar to a given context.""" with self.context_db.get_session() as session: from .context_db import ContextBlob reference_blob = session.query(ContextBlob).filter(ContextBlob.id == context_id).first() if not reference_blob or not reference_blob.content: return [] return self.semantic_search(reference_blob.content, top_k=top_k)