# HCFS/hcfs-python/hcfs/core/embeddings.py
"""
Embedding Manager - Generate and manage context embeddings.
"""
import json
import numpy as np
from typing import List, Dict, Optional, Tuple
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from .context_db import Context, ContextDatabase
class EmbeddingManager:
    """
    Manages embeddings for context blobs and semantic similarity search.
    """

    def __init__(self, context_db: ContextDatabase, model_name: str = "all-MiniLM-L6-v2"):
        self.context_db = context_db
        self.model_name = model_name
        self.model = SentenceTransformer(model_name)
        self.tfidf_vectorizer = TfidfVectorizer(stop_words='english', max_features=1000)
        self._tfidf_fitted = False

    def generate_embedding(self, text: str) -> np.ndarray:
        """Generate a normalized embedding for the given text."""
        return self.model.encode(text, normalize_embeddings=True)
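
    # Usage sketch (illustrative, assumes an EmbeddingManager named `manager`):
    # with normalize_embeddings=True the vectors are unit length, so a plain
    # dot product equals cosine similarity. For the default all-MiniLM-L6-v2
    # model the vectors have 384 dimensions:
    #
    #     a = manager.generate_embedding("kernel scheduling")
    #     b = manager.generate_embedding("process scheduler")
    #     float(np.dot(a, b))  # cosine similarity, in [-1, 1]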

    def store_context_with_embedding(self, context: Context) -> int:
        """Store a context and persist its embedding alongside it."""
        # Generate the embedding first, so a model failure stores nothing
        embedding = self.generate_embedding(context.content)
        # Store the context row in the database
        context_id = self.context_db.store_context(context)
        # Attach the embedding to the stored ContextBlob row
        self._store_embedding(context_id, embedding)
        return context_id

    def _store_embedding(self, context_id: int, embedding: np.ndarray) -> None:
        """Store embedding vector in database."""
        embedding_json = json.dumps(embedding.tolist())
        with self.context_db.get_session() as session:
            from .context_db import ContextBlob
            blob = session.query(ContextBlob).filter(ContextBlob.id == context_id).first()
            if blob:
                blob.embedding_model = self.model_name
                blob.embedding_vector = embedding_json
                session.commit()
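
    # Round-trip note: json.dumps(embedding.tolist()) decodes back with
    # np.array(json.loads(...)), though as float64 rather than float32;
    # re-cast with .astype(np.float32) if the dtype matters downstream.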

    def semantic_search(self, query: str, path_prefix: Optional[str] = None,
                        top_k: int = 5) -> List[Tuple[Context, float]]:
        """
        Perform semantic search for contexts similar to the query.

        Args:
            query: Search query text
            path_prefix: Optional path prefix to limit search scope
            top_k: Number of results to return

        Returns:
            List of (Context, similarity_score) tuples, best match first
        """
        query_embedding = self.generate_embedding(query)

        with self.context_db.get_session() as session:
            from .context_db import ContextBlob

            # Only consider blobs that already have a stored embedding
            query_filter = session.query(ContextBlob).filter(
                ContextBlob.embedding_vector.isnot(None)
            )
            if path_prefix:
                query_filter = query_filter.filter(ContextBlob.path.startswith(path_prefix))

            blobs = query_filter.all()
            if not blobs:
                return []

            # Calculate cosine similarity against each stored embedding
            similarities = []
            for blob in blobs:
                if blob.embedding_vector:
                    stored_embedding = np.array(json.loads(blob.embedding_vector))
                    similarity = cosine_similarity(
                        query_embedding.reshape(1, -1),
                        stored_embedding.reshape(1, -1)
                    )[0][0]
                    context = Context(
                        id=blob.id,
                        path=blob.path,
                        content=blob.content,
                        summary=blob.summary,
                        author=blob.author,
                        created_at=blob.created_at,
                        updated_at=blob.updated_at,
                        version=blob.version
                    )
                    similarities.append((context, float(similarity)))

            # Sort by similarity and return the top_k results
            similarities.sort(key=lambda x: x[1], reverse=True)
            return similarities[:top_k]
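
    # Usage sketch (names illustrative): scores are cosine similarities, so
    # with normalized embeddings they fall in [-1, 1], higher meaning closer.
    #
    #     results = manager.semantic_search("memory allocation",
    #                                       path_prefix="/kernel", top_k=3)
    #     for ctx, score in results:
    #         print(f"{score:.3f}  {ctx.path}")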

    def hybrid_search(self, query: str, path_prefix: Optional[str] = None, top_k: int = 5,
                      semantic_weight: float = 0.7) -> List[Tuple[Context, float]]:
        """
        Hybrid search combining semantic similarity with lexical scoring
        (TF-IDF cosine similarity as a BM25 approximation).

        Args:
            query: Search query
            path_prefix: Optional path filter
            top_k: Number of results
            semantic_weight: Weight given to the semantic score; the lexical
                score gets (1 - semantic_weight). Must be in [0.0, 1.0].
        """
        # Get candidate contexts for lexical scoring
        with self.context_db.get_session() as session:
            from .context_db import ContextBlob

            query_filter = session.query(ContextBlob)
            if path_prefix:
                query_filter = query_filter.filter(ContextBlob.path.startswith(path_prefix))

            blobs = query_filter.all()
            if not blobs:
                return []

            # Prepare documents for lexical scoring
            documents = [blob.content for blob in blobs]

            # Fit TF-IDF on first use; once the corpus exceeds 100 documents
            # this refits on every call so the vocabulary tracks new content
            if not self._tfidf_fitted or len(documents) > 100:
                self.tfidf_vectorizer.fit(documents)
                self._tfidf_fitted = True

            # Lexical scoring via TF-IDF cosine similarity (BM25 approximation)
            doc_vectors = self.tfidf_vectorizer.transform(documents)
            query_vector = self.tfidf_vectorizer.transform([query])
            bm25_scores = cosine_similarity(query_vector, doc_vectors)[0]

            # Semantic scoring over the same candidate set
            semantic_results = self.semantic_search(query, path_prefix, len(blobs))
            semantic_scores = {ctx.id: score for ctx, score in semantic_results}

            # Combine the two scores as a weighted sum
            combined_results = []
            for i, blob in enumerate(blobs):
                bm25_score = bm25_scores[i]
                semantic_score = semantic_scores.get(blob.id, 0.0)
                combined_score = (semantic_weight * semantic_score +
                                  (1 - semantic_weight) * bm25_score)
                context = Context(
                    id=blob.id,
                    path=blob.path,
                    content=blob.content,
                    summary=blob.summary,
                    author=blob.author,
                    created_at=blob.created_at,
                    updated_at=blob.updated_at,
                    version=blob.version
                )
                combined_results.append((context, combined_score))

            # Sort and return the top results
            combined_results.sort(key=lambda x: x[1], reverse=True)
            return combined_results[:top_k]
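
    # Score-blend arithmetic: with semantic_weight=0.7, a blob scoring 0.80
    # semantically and 0.40 lexically combines to
    #     0.7 * 0.80 + 0.3 * 0.40 = 0.68
    # so semantic matches dominate unless the lexical gap is large.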

    def get_similar_contexts(self, context_id: int, top_k: int = 5) -> List[Tuple[Context, float]]:
        """Find contexts similar to a given context, excluding the context itself."""
        with self.context_db.get_session() as session:
            from .context_db import ContextBlob
            reference_blob = session.query(ContextBlob).filter(ContextBlob.id == context_id).first()
            if not reference_blob or not reference_blob.content:
                return []
            # The reference matches its own content with similarity ~1.0, so
            # fetch one extra result and drop the reference from the output
            results = self.semantic_search(reference_blob.content, top_k=top_k + 1)
            return [(ctx, score) for ctx, score in results if ctx.id != context_id][:top_k]
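

# Minimal end-to-end sketch. ContextDatabase's constructor and Context's
# optional fields are not defined in this file, so the calls below (the
# "hcfs.db" path argument and the None defaults) are assumptions for
# illustration, not the project's confirmed API.
if __name__ == "__main__":
    db = ContextDatabase("hcfs.db")  # assumed constructor signature
    manager = EmbeddingManager(db)
    ctx = Context(  # field set mirrors the Context(...) calls above
        id=None,
        path="/projects/hcfs/docs",
        content="HCFS maps filesystem paths to layered context blobs.",
        summary="HCFS overview",
        author="demo",
        created_at=None,
        updated_at=None,
        version=1,
    )
    new_id = manager.store_context_with_embedding(ctx)
    for match, score in manager.hybrid_search("context filesystem", top_k=3):
        print(f"{score:.3f}  {match.path}")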