464 lines
16 KiB
Python
464 lines
16 KiB
Python
"""
|
|
Test suite for Context Database functionality.
|
|
|
|
Tests covering:
|
|
- Basic CRUD operations
|
|
- Context versioning
|
|
- Database integrity
|
|
- Performance characteristics
|
|
- Error handling
|
|
"""
|
|
|
|
import pytest
|
|
import tempfile
|
|
import shutil
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
import sqlite3
|
|
|
|
import sys
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
from hcfs.core.context_db import Context, ContextDatabase
|
|
from hcfs.core.context_db_optimized_fixed import OptimizedContextDatabase
|
|
from hcfs.core.context_versioning import VersioningSystem
|
|
|
|
|
|
class TestContextDatabase:
    """Test basic context database operations."""

    @pytest.fixture
    def temp_db(self, tmp_path):
        """Create a ContextDatabase backed by a file in a fresh temp directory.

        The directory comes from pytest's built-in ``tmp_path`` fixture, so
        pytest owns its lifetime. Nothing is leaked if the ``ContextDatabase``
        constructor raises, and no manual ``rmtree`` is needed at teardown.
        """
        db_path = tmp_path / "test.db"
        return ContextDatabase(str(db_path))

    @pytest.fixture
    def sample_context(self):
        """Create sample context for testing."""
        return Context(
            id=None,
            path="/test/path",
            content="Test content for context",
            summary="Test summary",
            author="test_user",
            version=1,
        )

    def test_store_context(self, temp_db, sample_context):
        """Test storing a context."""
        context_id = temp_db.store_context(sample_context)
        assert context_id is not None
        assert isinstance(context_id, int)
        assert context_id > 0

    def test_get_context(self, temp_db, sample_context):
        """Test retrieving a context."""
        context_id = temp_db.store_context(sample_context)
        retrieved = temp_db.get_context(context_id)

        assert retrieved is not None
        assert retrieved.path == sample_context.path
        assert retrieved.content == sample_context.content
        assert retrieved.summary == sample_context.summary
        assert retrieved.author == sample_context.author

    def test_get_contexts_by_path(self, temp_db):
        """Test path-based context retrieval."""
        contexts = [
            Context(None, "/test/path1", "Content 1", "Summary 1", "user1", 1),
            Context(None, "/test/path2", "Content 2", "Summary 2", "user2", 1),
            Context(None, "/other/path", "Content 3", "Summary 3", "user3", 1),
        ]

        for context in contexts:
            temp_db.store_context(context)

        # A non-exact lookup of "/test" is expected to match both /test/* entries
        test_contexts = temp_db.get_contexts_by_path("/test")
        assert len(test_contexts) == 2

        exact_context = temp_db.get_contexts_by_path("/test/path1", exact_match=True)
        assert len(exact_context) == 1

    def test_update_context(self, temp_db, sample_context):
        """Test updating a context."""
        context_id = temp_db.store_context(sample_context)

        # Update the context
        updated_content = "Updated content"
        temp_db.update_context(context_id, content=updated_content)

        retrieved = temp_db.get_context(context_id)
        assert retrieved.content == updated_content

    def test_delete_context(self, temp_db, sample_context):
        """Test deleting a context."""
        context_id = temp_db.store_context(sample_context)

        # Verify it exists
        assert temp_db.get_context(context_id) is not None

        # Delete it
        success = temp_db.delete_context(context_id)
        assert success

        # Verify it's gone
        assert temp_db.get_context(context_id) is None

    def test_search_contexts(self, temp_db):
        """Test context search functionality."""
        contexts = [
            Context(None, "/ml/algorithms", "Machine learning algorithms", "ML summary", "user1", 1),
            Context(None, "/web/api", "RESTful API development", "API summary", "user2", 1),
            Context(None, "/db/optimization", "Database query optimization", "DB summary", "user3", 1),
        ]

        for context in contexts:
            temp_db.store_context(context)

        # Search by content
        results = temp_db.search_contexts("machine learning")
        assert len(results) == 1
        assert "algorithms" in results[0].path

        # Search for a term found only in the /web/api entry (in its path,
        # and also in its content and summary)
        results = temp_db.search_contexts("api")
        assert len(results) == 1
        assert "web" in results[0].path
|
|
|
|
|
|
class TestOptimizedContextDatabase:
    """Test optimized context database operations."""

    @pytest.fixture
    def temp_optimized_db(self, tmp_path):
        """Create an OptimizedContextDatabase in a pytest-managed temp directory.

        ``tmp_path`` is owned by pytest, so the directory is not leaked if the
        constructor raises, and it needs no manual cleanup.
        """
        db_path = tmp_path / "optimized_test.db"
        return OptimizedContextDatabase(str(db_path))

    def test_batch_operations(self, temp_optimized_db):
        """Test batch context operations."""
        contexts = [
            Context(None, f"/batch/test{i}", f"Content {i}", f"Summary {i}", f"user{i}", 1)
            for i in range(10)
        ]

        # Batch store
        context_ids = temp_optimized_db.store_contexts_batch(contexts)
        assert len(context_ids) == 10
        assert all(isinstance(cid, int) for cid in context_ids)

        # Batch retrieve (results are expected in the same order as the IDs)
        retrieved = temp_optimized_db.get_contexts_batch(context_ids)
        assert len(retrieved) == 10

        for i, context in enumerate(retrieved):
            assert context.path == f"/batch/test{i}"
            assert context.content == f"Content {i}"

    def test_caching_performance(self, temp_optimized_db):
        """Test that a repeated read is no slower than the first read.

        Timing uses ``time.perf_counter()``. ``time.time()`` can have coarse
        resolution, so both readings could be 0.0 and a strict comparison
        would fail. Several repeated reads are timed and the fastest one is
        compared, to damp scheduler noise.

        NOTE(review): assumes ``store_context`` does not itself prime the
        cache, so the first ``get_context`` is a genuine miss. Confirm this.
        """
        import time

        context = Context(None, "/cache/test", "Cached content", "Cache summary", "user", 1)
        context_id = temp_optimized_db.store_context(context)

        # First access (cache miss)
        start = time.perf_counter()
        result1 = temp_optimized_db.get_context(context_id)
        first_time = time.perf_counter() - start

        # Subsequent accesses (cache hits); keep the fastest sample
        hit_times = []
        for _ in range(5):
            start = time.perf_counter()
            result2 = temp_optimized_db.get_context(context_id)
            hit_times.append(time.perf_counter() - start)
        second_time = min(hit_times)

        assert result1 is not None
        assert result1.content == result2.content
        assert second_time <= first_time  # Should be faster due to caching

    def test_connection_pooling(self, temp_optimized_db):
        """Test that concurrent stores from a thread pool get unique IDs."""
        import concurrent.futures

        def worker(worker_id):
            context = Context(
                None, f"/worker/{worker_id}",
                f"Worker {worker_id} content",
                f"Summary {worker_id}",
                f"worker{worker_id}", 1,
            )
            return temp_optimized_db.store_context(context)

        # Test concurrent operations
        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(worker, i) for i in range(10)]
            results = [future.result() for future in futures]

        assert len(results) == 10
        assert all(isinstance(result, int) for result in results)
        assert len(set(results)) == 10  # All IDs should be unique
|
|
|
|
|
|
class TestVersioningSystem:
    """Test context versioning functionality."""

    @pytest.fixture
    def temp_versioning_db(self, tmp_path):
        """Create a context database and a versioning system on one shared file.

        Returns:
            A ``(context_db, versioning)`` tuple; both are opened on the same
            SQLite path inside pytest's managed ``tmp_path`` directory, so the
            directory is not leaked if either constructor raises.
        """
        db_path = tmp_path / "versioning_test.db"
        context_db = OptimizedContextDatabase(str(db_path))
        versioning = VersioningSystem(str(db_path))
        return context_db, versioning

    def test_create_version(self, temp_versioning_db):
        """Test creating context versions."""
        context_db, versioning = temp_versioning_db

        # Create initial context
        context = Context(None, "/version/test", "Original content", "Original summary", "user", 1)
        context_id = context_db.store_context(context)

        # Create version
        version = versioning.create_version(
            context_id, "user", "Initial version", {"tag": "v1.0"}
        )

        assert version is not None
        assert version.context_id == context_id
        assert version.author == "user"
        assert version.message == "Initial version"

    def test_version_history(self, temp_versioning_db):
        """Test version history retrieval."""
        context_db, versioning = temp_versioning_db

        # Create context with multiple versions
        context = Context(None, "/history/test", "Content v1", "Summary v1", "user", 1)
        context_id = context_db.store_context(context)

        # Create multiple versions, changing the content after each snapshot
        for i in range(3):
            versioning.create_version(
                context_id, f"user{i}", f"Version {i+1}", {"iteration": i+1}
            )

            # Update context so the next version captures different content
            context_db.update_context(context_id, content=f"Content v{i+2}")

        # Get history
        history = versioning.get_version_history(context_id)
        assert len(history) == 3

        # Verify order (newest first)
        for i, version in enumerate(history):
            assert version.message == f"Version {3-i}"

    def test_rollback_version(self, temp_versioning_db):
        """Test version rollback functionality."""
        context_db, versioning = temp_versioning_db

        # Create context
        original_content = "Original content"
        context = Context(None, "/rollback/test", original_content, "Summary", "user", 1)
        context_id = context_db.store_context(context)

        # Create version before modification
        version1 = versioning.create_version(context_id, "user", "Before changes")

        # Modify context
        modified_content = "Modified content"
        context_db.update_context(context_id, content=modified_content)

        # Verify modification
        current = context_db.get_context(context_id)
        assert current.content == modified_content

        # Rollback
        rollback_version = versioning.rollback_to_version(
            context_id, version1.version_number, "user", "Rolling back changes"
        )

        assert rollback_version is not None

        # Verify rollback (content should be back to original)
        rolled_back = context_db.get_context(context_id)
        assert rolled_back.content == original_content

    def test_version_comparison(self, temp_versioning_db):
        """Test version comparison."""
        context_db, versioning = temp_versioning_db

        # Create context with versions
        context = Context(None, "/compare/test", "Content A", "Summary A", "user", 1)
        context_id = context_db.store_context(context)

        version1 = versioning.create_version(context_id, "user", "Version A")

        context_db.update_context(context_id, content="Content B", summary="Summary B")
        version2 = versioning.create_version(context_id, "user", "Version B")

        # Compare versions
        diff = versioning.compare_versions(context_id, version1.version_number, version2.version_number)

        assert diff is not None
        assert "Content A" in str(diff)
        assert "Content B" in str(diff)
        assert "Summary A" in str(diff)
        assert "Summary B" in str(diff)
|
|
|
|
|
|
class TestDatabaseIntegrity:
    """Test database integrity and error handling."""

    @pytest.fixture
    def temp_db(self, tmp_path):
        """Create an OptimizedContextDatabase in a pytest-managed temp directory.

        Returns:
            A ``(db, db_path)`` tuple. ``db_path`` is a ``pathlib.Path`` so
            tests can open the SQLite file directly.
        """
        db_path = tmp_path / "integrity_test.db"
        db = OptimizedContextDatabase(str(db_path))
        return db, db_path

    def test_database_schema(self, temp_db):
        """Test that the context_blobs table exists with the expected columns."""
        _, db_path = temp_db

        # Connect directly to check schema; close even if an assertion fails
        conn = sqlite3.connect(str(db_path))
        try:
            cursor = conn.cursor()

            # Check tables exist
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
            tables = [row[0] for row in cursor.fetchall()]
            assert "context_blobs" in tables

            # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk)
            cursor.execute("PRAGMA table_info(context_blobs)")
            columns = {row[1] for row in cursor.fetchall()}

            # Declared types are listed for reference only; this test asserts
            # column presence, not declared type.
            expected_columns = {
                "id": "INTEGER",
                "path": "TEXT",
                "content": "TEXT",
                "summary": "TEXT",
                "author": "TEXT",
                "created_at": "TIMESTAMP",
                "updated_at": "TIMESTAMP",
                "version": "INTEGER",
            }

            missing = [name for name in expected_columns if name not in columns]
            assert not missing, f"context_blobs is missing columns: {missing}"
        finally:
            conn.close()

    def test_constraint_violations(self, temp_db):
        """Test handling of constraint violations."""
        db, _ = temp_db

        # Test invalid context (missing required fields)
        with pytest.raises((ValueError, TypeError, AttributeError)):
            invalid_context = Context(None, "", "", None, None, 0)  # Empty required fields
            db.store_context(invalid_context)

    def test_transaction_rollback(self, temp_db):
        """Test that a failed update leaves existing rows untouched."""
        db, _ = temp_db

        # Create a valid context first
        context = Context(None, "/transaction/test", "Content", "Summary", "user", 1)
        context_id = db.store_context(context)

        # Verify it exists
        assert db.get_context(context_id) is not None

        # Updating a non-existent ID may raise or may be a silent no-op; either
        # outcome is acceptable here. Catch Exception, not a bare except, so
        # KeyboardInterrupt/SystemExit still propagate.
        try:
            db.update_context(999999, content="Should fail")  # Non-existent ID
        except Exception:
            pass

        # Verify original context still exists and is unchanged
        retrieved = db.get_context(context_id)
        assert retrieved is not None
        assert retrieved.content == "Content"

    def test_concurrent_access(self, temp_db):
        """Test concurrent database access from several threads."""
        import threading
        import time

        db, _ = temp_db

        results = []
        errors = []

        def worker(worker_id):
            # Store five contexts; record IDs, or the first error raised
            try:
                for i in range(5):
                    context = Context(
                        None, f"/concurrent/{worker_id}/{i}",
                        f"Content {worker_id}-{i}",
                        f"Summary {worker_id}-{i}",
                        f"worker{worker_id}", 1,
                    )
                    context_id = db.store_context(context)
                    results.append(context_id)
                    time.sleep(0.001)  # Small delay to increase contention
            except Exception as e:
                errors.append(e)

        # Run multiple workers concurrently
        threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)]

        for thread in threads:
            thread.start()

        for thread in threads:
            thread.join()

        # Check results
        assert len(errors) == 0, f"Concurrent access errors: {errors}"
        assert len(results) == 15  # 3 workers * 5 contexts each
        assert len(set(results)) == 15  # All IDs should be unique
|
|
|
|
|
|
def run_context_db_tests():
    """Run this module's tests in a pytest subprocess and print the report.

    pytest is launched with the current interpreter from the project root
    (two directories above this file). Its captured stdout is printed under a
    banner, followed by stderr when stderr is non-empty.

    Returns:
        True when pytest exits with status 0. False otherwise, including when
        launching or reporting raised an exception.
    """
    import subprocess
    import sys

    try:
        # Run pytest on this module
        command = [sys.executable, "-m", "pytest", __file__, "-v", "--tb=short"]
        project_root = Path(__file__).parent.parent
        completed = subprocess.run(
            command,
            capture_output=True,
            text=True,
            cwd=project_root,
        )

        print("CONTEXT DATABASE TEST RESULTS")
        print("=" * 50)
        print(completed.stdout)

        if completed.stderr:
            print("ERRORS:")
            print(completed.stderr)

        return completed.returncode == 0

    except Exception as exc:
        print(f"Failed to run tests: {exc}")
        return False
|
|
|
|
|
|
if __name__ == "__main__":
    # Use sys.exit rather than the site-injected exit(), which is missing
    # under `python -S` and in some embedded/frozen interpreters.
    success = run_context_db_tests()
    sys.exit(0 if success else 1)