188 lines
7.4 KiB
Python
188 lines
7.4 KiB
Python
"""
|
|
Embedding Manager - Generate and manage context embeddings.
|
|
"""
|
|
|
|
import json
|
|
import numpy as np
|
|
from typing import List, Dict, Optional, Tuple
|
|
from sentence_transformers import SentenceTransformer
|
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
|
from sklearn.metrics.pairwise import cosine_similarity
|
|
|
|
from .context_db import Context, ContextDatabase
|
|
|
|
|
|
class EmbeddingManager:
    """
    Manages embeddings for context blobs and semantic similarity search.

    Embeddings are produced with a ``SentenceTransformer`` model (unit
    normalized) and persisted on each ``ContextBlob`` row as a JSON-encoded
    list of floats, together with the name of the model that produced them.
    """

    def __init__(self, context_db: ContextDatabase, model_name: str = "all-MiniLM-L6-v2"):
        """
        Args:
            context_db: Database used to load and persist context blobs.
            model_name: SentenceTransformer model name used to embed text.
        """
        self.context_db = context_db
        self.model_name = model_name
        self.model = SentenceTransformer(model_name)
        self.tfidf_vectorizer = TfidfVectorizer(stop_words='english', max_features=1000)
        # True once the TF-IDF vectorizer has been successfully fitted at least once.
        self._tfidf_fitted = False

    @staticmethod
    def _blob_to_context(blob) -> Context:
        """Build a ``Context`` value object from a ``ContextBlob`` row."""
        return Context(
            id=blob.id,
            path=blob.path,
            content=blob.content,
            summary=blob.summary,
            author=blob.author,
            created_at=blob.created_at,
            updated_at=blob.updated_at,
            version=blob.version,
        )

    @staticmethod
    def _filter_by_path(query, model, path_prefix: Optional[str]):
        """Restrict ``query`` to rows whose ``path`` starts with ``path_prefix`` (if given)."""
        if path_prefix:
            # autoescape=True makes '%' and '_' in the prefix match literally
            # instead of acting as SQL LIKE wildcards ('_' is common in paths).
            query = query.filter(model.path.startswith(path_prefix, autoescape=True))
        return query

    def generate_embedding(self, text: str) -> np.ndarray:
        """Generate a unit-normalized embedding for ``text``."""
        return self.model.encode(text, normalize_embeddings=True)

    def store_context_with_embedding(self, context: Context) -> int:
        """
        Store a context and attach an embedding of its content.

        Args:
            context: Context to persist; its ``content`` is embedded.

        Returns:
            The id assigned by ``ContextDatabase.store_context``.
        """
        # Embed first so a model failure does not leave a stored row without a vector.
        embedding = self.generate_embedding(context.content)
        context_id = self.context_db.store_context(context)
        self._store_embedding(context_id, embedding)
        return context_id

    def _store_embedding(self, context_id: int, embedding: np.ndarray) -> None:
        """
        Persist ``embedding`` (as a JSON list) and the model name on blob ``context_id``.

        Does nothing if no blob with that id exists.
        """
        embedding_json = json.dumps(embedding.tolist())

        with self.context_db.get_session() as session:
            from .context_db import ContextBlob

            blob = session.query(ContextBlob).filter(ContextBlob.id == context_id).first()
            if blob:
                blob.embedding_model = self.model_name
                blob.embedding_vector = embedding_json
                session.commit()

    @staticmethod
    def _semantic_scores(query_embedding: np.ndarray, blobs) -> Dict[int, float]:
        """
        Cosine similarity between ``query_embedding`` and each blob's stored vector.

        Blobs with no stored vector, or whose vector length differs from the
        query's (e.g. written by a different model), are skipped rather than
        raising inside ``cosine_similarity``.

        Returns:
            Mapping of blob id to similarity in [-1, 1].
        """
        query_vec = np.asarray(query_embedding, dtype=float).reshape(1, -1)
        blob_ids = []
        rows = []
        for blob in blobs:
            if not blob.embedding_vector:
                continue
            vec = np.asarray(json.loads(blob.embedding_vector), dtype=float).ravel()
            if vec.shape[0] != query_vec.shape[1]:
                continue
            blob_ids.append(blob.id)
            rows.append(vec)

        if not rows:
            return {}

        # One vectorized call instead of one cosine_similarity call per blob.
        sims = cosine_similarity(query_vec, np.vstack(rows))[0]
        return {blob_id: float(sim) for blob_id, sim in zip(blob_ids, sims)}

    def _lexical_scores(self, query: str, documents: List[str]) -> np.ndarray:
        """
        TF-IDF cosine similarity between ``query`` and each document.

        The vectorizer is refitted on exactly the documents being ranked on
        every call, so vocabulary and IDF weights never come from a stale or
        differently-filtered corpus.

        Returns:
            One score per document, in document order; all zeros when the
            documents yield an empty vocabulary (e.g. only stop words).
        """
        try:
            doc_vectors = self.tfidf_vectorizer.fit_transform(documents)
        except ValueError:
            # scikit-learn raises ValueError when no terms survive tokenization
            # and stop-word removal; there is no lexical signal to score.
            return np.zeros(len(documents))
        self._tfidf_fitted = True

        query_vector = self.tfidf_vectorizer.transform([query])
        return cosine_similarity(query_vector, doc_vectors)[0]

    def semantic_search(self, query: str, path_prefix: Optional[str] = None,
                        top_k: int = 5) -> List[Tuple[Context, float]]:
        """
        Perform semantic search for contexts similar to query.

        Args:
            query: Search query text
            path_prefix: Optional path prefix to limit search scope
            top_k: Number of results to return (negative values return nothing)

        Returns:
            List of (Context, similarity_score) tuples, best match first
        """
        query_embedding = self.generate_embedding(query)

        with self.context_db.get_session() as session:
            from .context_db import ContextBlob

            query_filter = session.query(ContextBlob).filter(
                ContextBlob.embedding_vector.isnot(None)
            )
            blobs = self._filter_by_path(query_filter, ContextBlob, path_prefix).all()
            if not blobs:
                return []

            scores = self._semantic_scores(query_embedding, blobs)
            # Build Context objects while the session is still open.
            results = [
                (self._blob_to_context(blob), scores[blob.id])
                for blob in blobs
                if blob.id in scores
            ]

        results.sort(key=lambda item: item[1], reverse=True)
        return results[:max(top_k, 0)]

    def hybrid_search(self, query: str, path_prefix: Optional[str] = None, top_k: int = 5,
                      semantic_weight: float = 0.7) -> List[Tuple[Context, float]]:
        """
        Hybrid search combining semantic similarity and TF-IDF lexical similarity.

        The lexical component is TF-IDF cosine similarity (a stand-in for
        BM25, not BM25 itself). Blobs without a usable stored embedding get
        a semantic score of 0.0.

        Args:
            query: Search query
            path_prefix: Optional path filter
            top_k: Number of results (negative values return nothing)
            semantic_weight: Weight for semantic vs lexical score (0.0-1.0)

        Returns:
            List of (Context, combined_score) tuples, best match first
        """
        # Embed before opening the session so model inference doesn't hold it.
        query_embedding = self.generate_embedding(query)

        with self.context_db.get_session() as session:
            from .context_db import ContextBlob

            blobs = self._filter_by_path(
                session.query(ContextBlob), ContextBlob, path_prefix
            ).all()
            if not blobs:
                return []

            # Missing content is treated as an empty document rather than crashing the vectorizer.
            documents = [blob.content or "" for blob in blobs]
            lexical_scores = self._lexical_scores(query, documents)
            # Score the rows already loaded instead of re-querying via semantic_search.
            semantic_scores = self._semantic_scores(query_embedding, blobs)

            combined_results = []
            for blob, lexical_score in zip(blobs, lexical_scores):
                semantic_score = semantic_scores.get(blob.id, 0.0)
                combined_score = (semantic_weight * semantic_score
                                  + (1 - semantic_weight) * float(lexical_score))
                combined_results.append((self._blob_to_context(blob), float(combined_score)))

        combined_results.sort(key=lambda item: item[1], reverse=True)
        return combined_results[:max(top_k, 0)]

    def get_similar_contexts(self, context_id: int, top_k: int = 5) -> List[Tuple[Context, float]]:
        """
        Find contexts semantically similar to the context with id ``context_id``.

        The reference context itself is excluded from the results.

        Returns:
            Up to ``top_k`` (Context, similarity_score) tuples, best match
            first; empty if the context does not exist or has no content.
        """
        with self.context_db.get_session() as session:
            from .context_db import ContextBlob

            reference_blob = session.query(ContextBlob).filter(ContextBlob.id == context_id).first()
            # Copy the content out while the session is open.
            reference_content = reference_blob.content if reference_blob else None

        if not reference_content:
            return []

        top_k = max(top_k, 0)
        # Request one extra hit: the reference context normally ranks first
        # against its own content and is filtered out below.
        results = self.semantic_search(reference_content, top_k=top_k + 1)
        return [(ctx, score) for ctx, score in results if ctx.id != context_id][:top_k]