"""
Test suite for Context Database functionality.

Tests covering:
- Basic CRUD operations
- Context versioning
- Database integrity
- Performance characteristics
- Error handling
"""

import pytest
import tempfile
import shutil
from contextlib import closing, suppress
from pathlib import Path
from datetime import datetime
import sqlite3
import sys

# Make the project root importable so ``hcfs`` resolves when this file is
# run directly or by pytest from another working directory.
sys.path.insert(0, str(Path(__file__).parent.parent))

from hcfs.core.context_db import Context, ContextDatabase
from hcfs.core.context_db_optimized_fixed import OptimizedContextDatabase
from hcfs.core.context_versioning import VersioningSystem


class TestContextDatabase:
    """Test basic context database operations."""

    @pytest.fixture
    def temp_db(self, tmp_path):
        """Create a ContextDatabase backed by a per-test temporary file.

        pytest's ``tmp_path`` owns the directory, so nothing leaks if the
        constructor raises, and no ``rmtree`` runs against a database file
        the object may still hold open (which fails on Windows).
        """
        return ContextDatabase(str(tmp_path / "test.db"))

    @pytest.fixture
    def sample_context(self):
        """Create sample context for testing."""
        return Context(
            id=None,
            path="/test/path",
            content="Test content for context",
            summary="Test summary",
            author="test_user",
            version=1,
        )

    def test_store_context(self, temp_db, sample_context):
        """Test storing a context returns a positive integer ID."""
        context_id = temp_db.store_context(sample_context)
        assert context_id is not None
        assert isinstance(context_id, int)
        assert context_id > 0

    def test_get_context(self, temp_db, sample_context):
        """Test retrieving a context round-trips all stored fields."""
        context_id = temp_db.store_context(sample_context)
        retrieved = temp_db.get_context(context_id)

        assert retrieved is not None
        assert retrieved.path == sample_context.path
        assert retrieved.content == sample_context.content
        assert retrieved.summary == sample_context.summary
        assert retrieved.author == sample_context.author

    def test_get_contexts_by_path(self, temp_db):
        """Test prefix and exact path-based context retrieval."""
        contexts = [
            Context(None, "/test/path1", "Content 1", "Summary 1", "user1", 1),
            Context(None, "/test/path2", "Content 2", "Summary 2", "user2", 1),
            Context(None, "/other/path", "Content 3", "Summary 3", "user3", 1),
        ]
        for context in contexts:
            temp_db.store_context(context)

        test_contexts = temp_db.get_contexts_by_path("/test")
        assert len(test_contexts) == 2

        exact_context = temp_db.get_contexts_by_path("/test/path1", exact_match=True)
        assert len(exact_context) == 1

    def test_update_context(self, temp_db, sample_context):
        """Test updating a context's content."""
        context_id = temp_db.store_context(sample_context)

        updated_content = "Updated content"
        temp_db.update_context(context_id, content=updated_content)

        retrieved = temp_db.get_context(context_id)
        assert retrieved.content == updated_content

    def test_delete_context(self, temp_db, sample_context):
        """Test deleting a context removes it."""
        context_id = temp_db.store_context(sample_context)

        # Verify it exists before deleting.
        assert temp_db.get_context(context_id) is not None

        success = temp_db.delete_context(context_id)
        assert success

        # Verify it's gone.
        assert temp_db.get_context(context_id) is None

    def test_search_contexts(self, temp_db):
        """Test context search by content and by path."""
        contexts = [
            Context(None, "/ml/algorithms", "Machine learning algorithms", "ML summary", "user1", 1),
            Context(None, "/web/api", "RESTful API development", "API summary", "user2", 1),
            Context(None, "/db/optimization", "Database query optimization", "DB summary", "user3", 1),
        ]
        for context in contexts:
            temp_db.store_context(context)

        # Search by content
        results = temp_db.search_contexts("machine learning")
        assert len(results) == 1
        assert "algorithms" in results[0].path

        # Search by path
        results = temp_db.search_contexts("api")
        assert len(results) == 1
        assert "web" in results[0].path


class TestOptimizedContextDatabase:
    """Test optimized context database operations."""

    @pytest.fixture
    def temp_optimized_db(self, tmp_path):
        """Create an OptimizedContextDatabase in a pytest-managed temp dir."""
        return OptimizedContextDatabase(str(tmp_path / "optimized_test.db"))

    def test_batch_operations(self, temp_optimized_db):
        """Test batch store and batch retrieve preserve order and data."""
        contexts = [
            Context(None, f"/batch/test{i}", f"Content {i}", f"Summary {i}", f"user{i}", 1)
            for i in range(10)
        ]

        # Batch store
        context_ids = temp_optimized_db.store_contexts_batch(contexts)
        assert len(context_ids) == 10
        assert all(isinstance(cid, int) for cid in context_ids)

        # Batch retrieve
        retrieved = temp_optimized_db.get_contexts_batch(context_ids)
        assert len(retrieved) == 10

        for i, context in enumerate(retrieved):
            assert context.path == f"/batch/test{i}"
            assert context.content == f"Content {i}"

    def test_caching_performance(self, temp_optimized_db):
        """Test that repeated reads of a context return consistent data.

        Timing is deliberately not asserted: comparing two sub-millisecond
        wall-clock durations is nondeterministic, and on platforms with a
        coarse ``time.time()`` resolution both can read 0.0, which would
        make a ``second < first`` check fail every time.
        """
        context = Context(None, "/cache/test", "Cached content", "Cache summary", "user", 1)
        context_id = temp_optimized_db.store_context(context)

        first = temp_optimized_db.get_context(context_id)   # expected cache miss
        second = temp_optimized_db.get_context(context_id)  # expected cache hit

        assert first is not None and second is not None
        assert first.content == second.content == "Cached content"
        assert first.path == second.path == "/cache/test"

    def test_connection_pooling(self, temp_optimized_db):
        """Test concurrent stores from a thread pool yield unique IDs."""
        import concurrent.futures

        def worker(worker_id):
            context = Context(
                None,
                f"/worker/{worker_id}",
                f"Worker {worker_id} content",
                f"Summary {worker_id}",
                f"worker{worker_id}",
                1,
            )
            return temp_optimized_db.store_context(context)

        # Test concurrent operations
        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(worker, i) for i in range(10)]
            results = [future.result() for future in futures]

        assert len(results) == 10
        assert all(isinstance(result, int) for result in results)
        assert len(set(results)) == 10  # All IDs should be unique


class TestVersioningSystem:
    """Test context versioning functionality."""

    @pytest.fixture
    def temp_versioning_db(self, tmp_path):
        """Create a context DB and a VersioningSystem sharing one SQLite file."""
        db_path = str(tmp_path / "versioning_test.db")
        context_db = OptimizedContextDatabase(db_path)
        versioning = VersioningSystem(db_path)
        return context_db, versioning

    def test_create_version(self, temp_versioning_db):
        """Test creating context versions."""
        context_db, versioning = temp_versioning_db

        # Create initial context
        context = Context(None, "/version/test", "Original content", "Original summary", "user", 1)
        context_id = context_db.store_context(context)

        # Create version
        version = versioning.create_version(
            context_id, "user", "Initial version", {"tag": "v1.0"}
        )

        assert version is not None
        assert version.context_id == context_id
        assert version.author == "user"
        assert version.message == "Initial version"

    def test_version_history(self, temp_versioning_db):
        """Test version history retrieval is complete and newest-first."""
        context_db, versioning = temp_versioning_db

        context = Context(None, "/history/test", "Content v1", "Summary v1", "user", 1)
        context_id = context_db.store_context(context)

        # Create multiple versions, modifying the context between each.
        for i in range(3):
            versioning.create_version(
                context_id, f"user{i}", f"Version {i+1}", {"iteration": i + 1}
            )
            context_db.update_context(context_id, content=f"Content v{i+2}")

        history = versioning.get_version_history(context_id)
        assert len(history) == 3

        # Verify order (newest first)
        for i, version in enumerate(history):
            assert version.message == f"Version {3-i}"

    def test_rollback_version(self, temp_versioning_db):
        """Test version rollback restores the earlier content."""
        context_db, versioning = temp_versioning_db

        original_content = "Original content"
        context = Context(None, "/rollback/test", original_content, "Summary", "user", 1)
        context_id = context_db.store_context(context)

        # Create version before modification
        version1 = versioning.create_version(context_id, "user", "Before changes")

        # Modify context
        modified_content = "Modified content"
        context_db.update_context(context_id, content=modified_content)

        # Verify modification
        current = context_db.get_context(context_id)
        assert current.content == modified_content

        # Rollback
        rollback_version = versioning.rollback_to_version(
            context_id, version1.version_number, "user", "Rolling back changes"
        )
        assert rollback_version is not None

        # Verify rollback (content should be back to original)
        rolled_back = context_db.get_context(context_id)
        assert rolled_back.content == original_content

    def test_version_comparison(self, temp_versioning_db):
        """Test version comparison reports both sides of each change."""
        context_db, versioning = temp_versioning_db

        context = Context(None, "/compare/test", "Content A", "Summary A", "user", 1)
        context_id = context_db.store_context(context)

        version1 = versioning.create_version(context_id, "user", "Version A")

        context_db.update_context(context_id, content="Content B", summary="Summary B")
        version2 = versioning.create_version(context_id, "user", "Version B")

        diff = versioning.compare_versions(
            context_id, version1.version_number, version2.version_number
        )

        assert diff is not None
        diff_text = str(diff)
        assert "Content A" in diff_text
        assert "Content B" in diff_text
        assert "Summary A" in diff_text
        assert "Summary B" in diff_text


class TestDatabaseIntegrity:
    """Test database integrity and error handling."""

    @pytest.fixture
    def temp_db(self, tmp_path):
        """Create an OptimizedContextDatabase; return ``(db, db_path)``."""
        db_path = tmp_path / "integrity_test.db"
        db = OptimizedContextDatabase(str(db_path))
        return db, db_path

    def test_database_schema(self, temp_db):
        """Test the ``context_blobs`` table exists with the expected columns."""
        # The db object is unused directly; constructing it (in the fixture)
        # is what creates the schema inspected below.
        _, db_path = temp_db

        # ``closing`` guarantees the connection is released even when an
        # assertion fails (sqlite3's own context manager does not close it).
        with closing(sqlite3.connect(str(db_path))) as conn:
            cursor = conn.cursor()

            cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
            tables = {row[0] for row in cursor.fetchall()}
            assert "context_blobs" in tables

            cursor.execute("PRAGMA table_info(context_blobs)")
            columns = {row[1]: row[2] for row in cursor.fetchall()}  # name: type

        # Declared types are listed for reference only; just presence is
        # asserted, since SQLite's declared types are not enforced.
        expected_columns = {
            "id": "INTEGER",
            "path": "TEXT",
            "content": "TEXT",
            "summary": "TEXT",
            "author": "TEXT",
            "created_at": "TIMESTAMP",
            "updated_at": "TIMESTAMP",
            "version": "INTEGER",
        }
        missing = set(expected_columns) - set(columns)
        assert not missing, f"context_blobs is missing columns: {sorted(missing)}"

    def test_constraint_violations(self, temp_db):
        """Test handling of constraint violations."""
        db, _ = temp_db

        # Invalid context (empty/None required fields) must be rejected.
        with pytest.raises((ValueError, TypeError, AttributeError)):
            invalid_context = Context(None, "", "", None, None, 0)
            db.store_context(invalid_context)

    def test_transaction_rollback(self, temp_db):
        """Test that a failed update leaves existing data untouched."""
        db, _ = temp_db

        context = Context(None, "/transaction/test", "Content", "Summary", "user", 1)
        context_id = db.store_context(context)
        assert db.get_context(context_id) is not None

        # Updating a non-existent ID may raise or be a silent no-op depending
        # on the implementation; either is acceptable. Only ``Exception``
        # subclasses are suppressed, so KeyboardInterrupt/SystemExit still
        # propagate.
        with suppress(Exception):
            db.update_context(999999, content="Should fail")  # Non-existent ID

        # Verify original context still exists and is unchanged
        retrieved = db.get_context(context_id)
        assert retrieved is not None
        assert retrieved.content == "Content"

    def test_concurrent_access(self, temp_db):
        """Test concurrent stores from raw threads succeed with unique IDs."""
        db, _ = temp_db

        import threading
        import time

        results = []
        errors = []

        def worker(worker_id):
            try:
                for i in range(5):
                    context = Context(
                        None,
                        f"/concurrent/{worker_id}/{i}",
                        f"Content {worker_id}-{i}",
                        f"Summary {worker_id}-{i}",
                        f"worker{worker_id}",
                        1,
                    )
                    context_id = db.store_context(context)
                    results.append(context_id)
                    time.sleep(0.001)  # Small delay to increase contention
            except Exception as e:
                # Collected and asserted on below, so failures are surfaced
                # in the main thread instead of dying silently in a worker.
                errors.append(e)

        threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        assert len(errors) == 0, f"Concurrent access errors: {errors}"
        assert len(results) == 15  # 3 workers * 5 contexts each
        assert len(set(results)) == 15  # All IDs should be unique


def run_context_db_tests():
    """Run this module under pytest in a subprocess and print the report.

    Returns:
        True if pytest exited with status 0, False otherwise (including
        when the subprocess itself could not be launched).
    """
    import subprocess

    try:
        result = subprocess.run(
            [sys.executable, "-m", "pytest", __file__, "-v", "--tb=short"],
            capture_output=True,
            text=True,
            cwd=Path(__file__).parent.parent,
        )

        print("CONTEXT DATABASE TEST RESULTS")
        print("=" * 50)
        print(result.stdout)

        if result.stderr:
            print("ERRORS:")
            print(result.stderr)

        return result.returncode == 0

    except Exception as e:
        print(f"Failed to run tests: {e}")
        return False


if __name__ == "__main__":
    success = run_context_db_tests()
    # ``sys.exit`` rather than the ``site``-injected ``exit`` builtin, which
    # is absent when Python runs with ``-S``.
    sys.exit(0 if success else 1)