# Source: HCFS/hcfs-python/tests/test_context_db.py
# (file-viewer header: last modified 2025-07-30 09:34:16 +10:00; 464 lines; 16 KiB; Python)

"""
Test suite for Context Database functionality.
Tests covering:
- Basic CRUD operations
- Context versioning
- Database integrity
- Performance characteristics
- Error handling
"""
import pytest
import tempfile
import shutil
from pathlib import Path
from datetime import datetime
import sqlite3
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))
from hcfs.core.context_db import Context, ContextDatabase
from hcfs.core.context_db_optimized_fixed import OptimizedContextDatabase
from hcfs.core.context_versioning import VersioningSystem
class TestContextDatabase:
    """Exercise the core CRUD and search API of ``ContextDatabase``."""

    @pytest.fixture
    def temp_db(self):
        """Yield a ``ContextDatabase`` backed by a throwaway SQLite file.

        The temporary directory holding the file is deleted on teardown.
        """
        workdir = Path(tempfile.mkdtemp())
        database = ContextDatabase(str(workdir / "test.db"))
        yield database
        shutil.rmtree(workdir)

    @pytest.fixture
    def sample_context(self):
        """Return an unsaved ``Context`` (``id=None``) with known field values."""
        return Context(
            id=None,
            path="/test/path",
            content="Test content for context",
            summary="Test summary",
            author="test_user",
            version=1,
        )

    def test_store_context(self, temp_db, sample_context):
        """Storing a context yields a positive integer id."""
        new_id = temp_db.store_context(sample_context)

        assert new_id is not None
        assert isinstance(new_id, int)
        assert new_id > 0

    def test_get_context(self, temp_db, sample_context):
        """A stored context reads back with the same path/content/summary/author."""
        fetched = temp_db.get_context(temp_db.store_context(sample_context))

        assert fetched is not None
        for field in ("path", "content", "summary", "author"):
            assert getattr(fetched, field) == getattr(sample_context, field)

    def test_get_contexts_by_path(self, temp_db):
        """Querying "/test" is expected to match two entries; an exact match, one."""
        rows = (
            ("/test/path1", "Content 1", "Summary 1", "user1"),
            ("/test/path2", "Content 2", "Summary 2", "user2"),
            ("/other/path", "Content 3", "Summary 3", "user3"),
        )
        for path, content, summary, author in rows:
            temp_db.store_context(Context(None, path, content, summary, author, 1))

        # Only the two "/test/..." rows live under "/test"; "/other/path" does not.
        assert len(temp_db.get_contexts_by_path("/test")) == 2
        assert len(temp_db.get_contexts_by_path("/test/path1", exact_match=True)) == 1

    def test_update_context(self, temp_db, sample_context):
        """``update_context`` replaces the content of an existing context."""
        cid = temp_db.store_context(sample_context)
        new_content = "Updated content"

        temp_db.update_context(cid, content=new_content)

        assert temp_db.get_context(cid).content == new_content

    def test_delete_context(self, temp_db, sample_context):
        """A deleted context reports success and can no longer be fetched."""
        cid = temp_db.store_context(sample_context)
        assert temp_db.get_context(cid) is not None  # present before deletion

        assert temp_db.delete_context(cid)
        assert temp_db.get_context(cid) is None  # gone afterwards

    def test_search_contexts(self, temp_db):
        """Free-text search finds single entries by content and by keyword.

        The queries are lower-case while the stored text is mixed-case, so the
        test relies on case-insensitive matching.
        """
        fixtures = [
            ("/ml/algorithms", "Machine learning algorithms", "ML summary", "user1"),
            ("/web/api", "RESTful API development", "API summary", "user2"),
            ("/db/optimization", "Database query optimization", "DB summary", "user3"),
        ]
        for path, content, summary, author in fixtures:
            temp_db.store_context(Context(None, path, content, summary, author, 1))

        ml_hits = temp_db.search_contexts("machine learning")
        assert len(ml_hits) == 1
        assert "algorithms" in ml_hits[0].path

        # "api" appears in the /web/api path and in that entry's content and
        # summary, so exactly one context is expected back.
        api_hits = temp_db.search_contexts("api")
        assert len(api_hits) == 1
        assert "web" in api_hits[0].path
class TestOptimizedContextDatabase:
    """Test the optimized context database: batch I/O, read caching, concurrent writes."""

    @pytest.fixture
    def temp_optimized_db(self):
        """Yield an ``OptimizedContextDatabase`` backed by a temporary SQLite file.

        The temporary directory is deleted during fixture teardown.
        """
        temp_dir = Path(tempfile.mkdtemp())
        db_path = temp_dir / "optimized_test.db"
        db = OptimizedContextDatabase(str(db_path))
        yield db
        shutil.rmtree(temp_dir)

    def test_batch_operations(self, temp_optimized_db):
        """Batch store returns one int id per context; batch get reads them back."""
        contexts = [
            Context(None, f"/batch/test{i}", f"Content {i}", f"Summary {i}", f"user{i}", 1)
            for i in range(10)
        ]

        context_ids = temp_optimized_db.store_contexts_batch(contexts)
        assert len(context_ids) == 10
        assert all(isinstance(cid, int) for cid in context_ids)

        retrieved = temp_optimized_db.get_contexts_batch(context_ids)
        assert len(retrieved) == 10
        # The test expects results in the same order as the requested ids.
        for i, context in enumerate(retrieved):
            assert context.path == f"/batch/test{i}"
            assert context.content == f"Content {i}"

    def test_caching_performance(self, temp_optimized_db):
        """A repeated ``get_context`` (expected cache hit) is not slower than the first.

        ``time.perf_counter`` is used instead of ``time.time``. The latter can
        have coarse resolution, so both samples could read 0.0 and a strict
        ``<`` comparison would fail spuriously. The fastest of several repeated
        reads is compared against the first read, to damp scheduler noise.
        """
        import time

        context = Context(None, "/cache/test", "Cached content", "Cache summary", "user", 1)
        context_id = temp_optimized_db.store_context(context)

        # First access (expected cache miss).
        start = time.perf_counter()
        result1 = temp_optimized_db.get_context(context_id)
        first_time = time.perf_counter() - start

        # Subsequent accesses (expected cache hits); every one must return the same data.
        hit_times = []
        for _ in range(5):
            start = time.perf_counter()
            result2 = temp_optimized_db.get_context(context_id)
            hit_times.append(time.perf_counter() - start)
            assert result2.content == result1.content

        assert min(hit_times) <= first_time  # a cached read should not be slower

    def test_connection_pooling(self, temp_optimized_db):
        """Concurrent ``store_context`` calls from a thread pool all succeed with unique ids."""
        import concurrent.futures

        def worker(worker_id):
            context = Context(
                None,
                f"/worker/{worker_id}",
                f"Worker {worker_id} content",
                f"Summary {worker_id}",
                f"worker{worker_id}",
                1,
            )
            return temp_optimized_db.store_context(context)

        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(worker, i) for i in range(10)]
            # Future.result() re-raises any exception raised inside the worker.
            results = [future.result() for future in futures]

        assert len(results) == 10
        assert all(isinstance(result, int) for result in results)
        assert len(set(results)) == 10  # every id must be unique
class TestVersioningSystem:
    """Exercise ``VersioningSystem``: create, list, roll back and diff versions."""

    @pytest.fixture
    def temp_versioning_db(self):
        """Yield ``(context_db, versioning)`` that share one temporary SQLite file.

        The temporary directory is deleted on teardown.
        """
        scratch = Path(tempfile.mkdtemp())
        location = str(scratch / "versioning_test.db")
        store = OptimizedContextDatabase(location)
        versions = VersioningSystem(location)
        yield store, versions
        shutil.rmtree(scratch)

    def test_create_version(self, temp_versioning_db):
        """A new version records the context id, author and message."""
        store, versions = temp_versioning_db
        cid = store.store_context(
            Context(None, "/version/test", "Original content", "Original summary", "user", 1)
        )

        created = versions.create_version(cid, "user", "Initial version", {"tag": "v1.0"})

        assert created is not None
        assert created.context_id == cid
        assert created.author == "user"
        assert created.message == "Initial version"

    def test_version_history(self, temp_versioning_db):
        """History holds every created version, newest first."""
        store, versions = temp_versioning_db
        cid = store.store_context(
            Context(None, "/history/test", "Content v1", "Summary v1", "user", 1)
        )

        # Snapshot then edit, three times over.
        for n in range(1, 4):
            versions.create_version(cid, f"user{n - 1}", f"Version {n}", {"iteration": n})
            store.update_context(cid, content=f"Content v{n + 1}")

        history = versions.get_version_history(cid)
        assert len(history) == 3
        assert [entry.message for entry in history] == ["Version 3", "Version 2", "Version 1"]

    def test_rollback_version(self, temp_versioning_db):
        """Rolling back restores the content captured by an earlier version."""
        store, versions = temp_versioning_db
        original_text = "Original content"
        cid = store.store_context(
            Context(None, "/rollback/test", original_text, "Summary", "user", 1)
        )

        snapshot = versions.create_version(cid, "user", "Before changes")

        changed_text = "Modified content"
        store.update_context(cid, content=changed_text)
        assert store.get_context(cid).content == changed_text  # edit took effect

        restored = versions.rollback_to_version(
            cid, snapshot.version_number, "user", "Rolling back changes"
        )
        assert restored is not None
        assert store.get_context(cid).content == original_text

    def test_version_comparison(self, temp_versioning_db):
        """A diff between two versions mentions both old and new field values."""
        store, versions = temp_versioning_db
        cid = store.store_context(
            Context(None, "/compare/test", "Content A", "Summary A", "user", 1)
        )

        first = versions.create_version(cid, "user", "Version A")
        store.update_context(cid, content="Content B", summary="Summary B")
        second = versions.create_version(cid, "user", "Version B")

        diff = versions.compare_versions(cid, first.version_number, second.version_number)
        assert diff is not None
        rendered = str(diff)
        for needle in ("Content A", "Content B", "Summary A", "Summary B"):
            assert needle in rendered
class TestDatabaseIntegrity:
    """Test database integrity and error handling."""

    @pytest.fixture
    def temp_db(self):
        """Yield ``(db, db_path)`` for an ``OptimizedContextDatabase`` in a temp dir.

        ``db_path`` is a ``Path`` so tests can open the SQLite file directly.
        The temporary directory is deleted during fixture teardown.
        """
        temp_dir = Path(tempfile.mkdtemp())
        db_path = temp_dir / "integrity_test.db"
        db = OptimizedContextDatabase(str(db_path))
        yield db, db_path
        shutil.rmtree(temp_dir)

    def test_database_schema(self, temp_db):
        """The ``context_blobs`` table exists and exposes the expected columns.

        Only column names are asserted. The declared types noted beside each
        name are for reference and are not checked.
        """
        from contextlib import closing

        _, db_path = temp_db

        expected_columns = (
            "id",          # INTEGER
            "path",        # TEXT
            "content",     # TEXT
            "summary",     # TEXT
            "author",      # TEXT
            "created_at",  # TIMESTAMP
            "updated_at",  # TIMESTAMP
            "version",     # INTEGER
        )

        # Using a sqlite3 connection as a context manager only commits/rolls back;
        # closing() is what guarantees it is closed, even if an assertion fails.
        with closing(sqlite3.connect(str(db_path))) as conn:
            cursor = conn.cursor()

            cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
            tables = {row[0] for row in cursor.fetchall()}
            assert "context_blobs" in tables

            # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
            cursor.execute("PRAGMA table_info(context_blobs)")
            columns = {row[1] for row in cursor.fetchall()}

        missing = [name for name in expected_columns if name not in columns]
        assert not missing, f"context_blobs is missing columns: {missing}"

    def test_constraint_violations(self, temp_db):
        """A context with empty/None required fields is rejected.

        Either the ``Context`` constructor or ``store_context`` may raise; any
        of ValueError, TypeError or AttributeError is accepted.
        """
        db, _ = temp_db
        with pytest.raises((ValueError, TypeError, AttributeError)):
            invalid_context = Context(None, "", "", None, None, 0)  # empty required fields
            db.store_context(invalid_context)

    def test_transaction_rollback(self, temp_db):
        """A failing update on a missing id leaves existing rows untouched."""
        from contextlib import suppress

        db, _ = temp_db

        context = Context(None, "/transaction/test", "Content", "Summary", "user", 1)
        context_id = db.store_context(context)
        assert db.get_context(context_id) is not None

        # Updating a non-existent id may raise or may be a no-op, depending on the
        # implementation; either is acceptable here. Suppress Exception (not a bare
        # except) so KeyboardInterrupt/SystemExit still propagate.
        with suppress(Exception):
            db.update_context(999999, content="Should fail")

        retrieved = db.get_context(context_id)
        assert retrieved is not None
        assert retrieved.content == "Content"

    def test_concurrent_access(self, temp_db):
        """Three threads storing five contexts each all succeed with unique ids."""
        import threading
        import time

        db, _ = temp_db
        results = []
        errors = []

        def worker(worker_id):
            # Record exceptions rather than raising: Thread.join() does not
            # propagate a worker's exception to the test thread.
            try:
                for i in range(5):
                    context = Context(
                        None,
                        f"/concurrent/{worker_id}/{i}",
                        f"Content {worker_id}-{i}",
                        f"Summary {worker_id}-{i}",
                        f"worker{worker_id}",
                        1,
                    )
                    results.append(db.store_context(context))
                    time.sleep(0.001)  # small delay to increase contention
            except Exception as e:
                errors.append(e)

        threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        assert not errors, f"Concurrent access errors: {errors}"
        assert len(results) == 15  # 3 workers * 5 contexts each
        assert len(set(results)) == 15  # every id must be unique
def run_context_db_tests():
    """Run this module under pytest in a child process and print its report.

    The child runs with the parent of this file's directory as its working
    directory.

    Returns:
        True when pytest exits with status 0; False when it fails or when
        launching/reporting raises any exception (which is printed).
    """
    import subprocess
    import sys

    try:
        argv = [sys.executable, "-m", "pytest", __file__, "-v", "--tb=short"]
        project_root = Path(__file__).parent.parent
        proc = subprocess.run(argv, capture_output=True, text=True, cwd=project_root)

        print("CONTEXT DATABASE TEST RESULTS")
        print("=" * 50)
        print(proc.stdout)
        if proc.stderr:
            print("ERRORS:")
            print(proc.stderr)
        return proc.returncode == 0
    except Exception as exc:
        print(f"Failed to run tests: {exc}")
        return False
if __name__ == "__main__":
    # Use sys.exit rather than the bare ``exit`` builtin: ``exit`` is injected by
    # the ``site`` module for interactive use and is absent under ``python -S``.
    success = run_context_db_tests()
    sys.exit(0 if success else 1)