Phase 2 build initial
188 hcfs-python/hcfs/core/embeddings.py Normal file
@@ -0,0 +1,188 @@
"""
Embedding Manager - Generate and manage context embeddings.
"""

import json
import numpy as np
from typing import List, Dict, Optional, Tuple
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

from .context_db import Context, ContextDatabase


class EmbeddingManager:
    """
    Manages embeddings for context blobs and semantic similarity search.
    """

    def __init__(self, context_db: ContextDatabase, model_name: str = "all-MiniLM-L6-v2"):
        self.context_db = context_db
        self.model_name = model_name
        self.model = SentenceTransformer(model_name)
        self.tfidf_vectorizer = TfidfVectorizer(stop_words='english', max_features=1000)
        self._tfidf_fitted = False

    def generate_embedding(self, text: str) -> np.ndarray:
        """Generate a normalized embedding for a text."""
        return self.model.encode(text, normalize_embeddings=True)

    def store_context_with_embedding(self, context: Context) -> int:
        """Store a context and generate its embedding."""
        # Generate embedding
        embedding = self.generate_embedding(context.content)

        # Store in database
        context_id = self.context_db.store_context(context)

        # Attach the embedding (requires the embedding columns on the ContextBlob model)
        self._store_embedding(context_id, embedding)

        return context_id

    def _store_embedding(self, context_id: int, embedding: np.ndarray) -> None:
        """Store an embedding vector in the database."""
        embedding_json = json.dumps(embedding.tolist())

        with self.context_db.get_session() as session:
            from .context_db import ContextBlob
            blob = session.query(ContextBlob).filter(ContextBlob.id == context_id).first()
            if blob:
                blob.embedding_model = self.model_name
                blob.embedding_vector = embedding_json
                session.commit()

    def semantic_search(self, query: str, path_prefix: Optional[str] = None,
                        top_k: int = 5) -> List[Tuple[Context, float]]:
        """
        Perform semantic search for contexts similar to the query.

        Args:
            query: Search query text
            path_prefix: Optional path prefix to limit search scope
            top_k: Number of results to return

        Returns:
            List of (Context, similarity_score) tuples
        """
        query_embedding = self.generate_embedding(query)

        with self.context_db.get_session() as session:
            from .context_db import ContextBlob

            query_filter = session.query(ContextBlob).filter(
                ContextBlob.embedding_vector.isnot(None)
            )

            if path_prefix:
                query_filter = query_filter.filter(ContextBlob.path.startswith(path_prefix))

            blobs = query_filter.all()

            if not blobs:
                return []

            # Calculate similarities
            similarities = []
            for blob in blobs:
                if blob.embedding_vector:
                    stored_embedding = np.array(json.loads(blob.embedding_vector))
                    similarity = cosine_similarity(
                        query_embedding.reshape(1, -1),
                        stored_embedding.reshape(1, -1)
                    )[0][0]

                    context = Context(
                        id=blob.id,
                        path=blob.path,
                        content=blob.content,
                        summary=blob.summary,
                        author=blob.author,
                        created_at=blob.created_at,
                        updated_at=blob.updated_at,
                        version=blob.version
                    )

                    similarities.append((context, float(similarity)))

            # Sort by similarity and return top_k
            similarities.sort(key=lambda x: x[1], reverse=True)
            return similarities[:top_k]

    def hybrid_search(self, query: str, path_prefix: Optional[str] = None, top_k: int = 5,
                      semantic_weight: float = 0.7) -> List[Tuple[Context, float]]:
        """
        Hybrid search combining semantic similarity with lexical relevance
        (TF-IDF cosine similarity as a BM25 approximation).

        Args:
            query: Search query
            path_prefix: Optional path filter
            top_k: Number of results
            semantic_weight: Weight of semantic vs lexical score (0.0-1.0)
        """
        # Get contexts for lexical scoring
        with self.context_db.get_session() as session:
            from .context_db import ContextBlob

            query_filter = session.query(ContextBlob)
            if path_prefix:
                query_filter = query_filter.filter(ContextBlob.path.startswith(path_prefix))

            blobs = query_filter.all()

            if not blobs:
                return []

            # Prepare documents for lexical scoring
            documents = [blob.content for blob in blobs]

            # Fit TF-IDF if not already fitted; refit periodically on larger corpora
            if not self._tfidf_fitted or len(documents) > 100:
                self.tfidf_vectorizer.fit(documents)
                self._tfidf_fitted = True

            # Lexical scoring (TF-IDF used as a BM25 approximation)
            doc_vectors = self.tfidf_vectorizer.transform(documents)
            query_vector = self.tfidf_vectorizer.transform([query])
            bm25_scores = cosine_similarity(query_vector, doc_vectors)[0]

            # Semantic scoring
            semantic_results = self.semantic_search(query, path_prefix, len(blobs))
            semantic_scores = {ctx.id: score for ctx, score in semantic_results}

            # Combine scores
            combined_results = []
            for i, blob in enumerate(blobs):
                bm25_score = bm25_scores[i]
                semantic_score = semantic_scores.get(blob.id, 0.0)

                combined_score = (semantic_weight * semantic_score +
                                  (1 - semantic_weight) * bm25_score)

                context = Context(
                    id=blob.id,
                    path=blob.path,
                    content=blob.content,
                    summary=blob.summary,
                    author=blob.author,
                    created_at=blob.created_at,
                    updated_at=blob.updated_at,
                    version=blob.version
                )

                combined_results.append((context, combined_score))

            # Sort and return top results
            combined_results.sort(key=lambda x: x[1], reverse=True)
            return combined_results[:top_k]

    def get_similar_contexts(self, context_id: int, top_k: int = 5) -> List[Tuple[Context, float]]:
        """Find contexts similar to a given context."""
        with self.context_db.get_session() as session:
            from .context_db import ContextBlob
            reference_blob = session.query(ContextBlob).filter(ContextBlob.id == context_id).first()

            if not reference_blob or not reference_blob.content:
                return []

            return self.semantic_search(reference_blob.content, top_k=top_k)
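For reviewers, a minimal usage sketch of the API this commit adds. The ContextDatabase constructor argument and the Context fields are assumptions inferred from how they are used in this file, not confirmed signatures; paths and content are hypothetical.

# Usage sketch (assumed ContextDatabase/Context signatures, hypothetical data)
from hcfs.core.context_db import Context, ContextDatabase
from hcfs.core.embeddings import EmbeddingManager

db = ContextDatabase("hcfs.db")          # assumption: takes a database location
manager = EmbeddingManager(db)           # defaults to all-MiniLM-L6-v2

# Store a context together with its embedding
ctx = Context(
    id=None, path="/projects/hcfs", content="Notes on the embedding layer.",
    summary=None, author="alice", created_at=None, updated_at=None, version=1,
)
context_id = manager.store_context_with_embedding(ctx)

# Pure semantic search, then the 70/30 semantic/lexical blend
for context, score in manager.semantic_search("embedding layer design", top_k=3):
    print(context.path, round(score, 3))
for context, score in manager.hybrid_search("embedding layer design", semantic_weight=0.7):
    print(context.path, round(score, 3))

On the blend: with semantic_weight=0.7, a blob scoring 0.80 semantically and 0.50 lexically combines to 0.7 * 0.80 + 0.3 * 0.50 = 0.71.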