Fix critical in-memory task storage issue by adding database persistence
Major architectural improvement: replace in-memory task storage with database-backed persistence while maintaining backward compatibility.

Changes:
- Created a Task SQLAlchemy model matching the database schema
- Added Workflow and Execution SQLAlchemy models
- Created TaskService for database CRUD operations
- Updated UnifiedCoordinator to use database persistence
- Modified the task APIs to leverage database storage
- Added task loading from the database on coordinator initialization
- Implemented status-change persistence during task execution
- Enhanced task cleanup with database support
- Added comprehensive task statistics from the database

Benefits:
- Tasks persist across application restarts
- Better scalability and reliability
- Historical task data retention
- Comprehensive task filtering and querying
- Maintains an in-memory cache for performance

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
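The core pattern throughout the diff is a write-through cache: every task mutation is persisted to the database first, while a dict of active tasks keeps hot-path reads in memory. A minimal sketch of that pattern, using invented MiniTask/MiniStore stand-ins rather than the commit's actual Task and TaskService classes:

# Illustrative sketch only; names here are hypothetical stand-ins.
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class MiniTask:
    id: str
    status: str = "pending"

class MiniStore:
    """Stand-in for the database-backed TaskService."""
    def __init__(self) -> None:
        self._rows: Dict[str, MiniTask] = {}

    def save(self, task: MiniTask) -> None:
        self._rows[task.id] = task

    def load(self, task_id: str) -> Optional[MiniTask]:
        return self._rows.get(task_id)

class MiniCoordinator:
    def __init__(self, store: MiniStore) -> None:
        self.store = store
        self.cache: Dict[str, MiniTask] = {}  # in-memory cache for active tasks

    def create(self, task: MiniTask) -> None:
        self.store.save(task)       # persist first so a restart cannot lose the task
        self.cache[task.id] = task

    def get(self, task_id: str) -> Optional[MiniTask]:
        return self.cache.get(task_id) or self.store.load(task_id)  # cache, then store

store = MiniStore()
coord = MiniCoordinator(store)
coord.create(MiniTask(id="t1"))
coord.cache.clear()                 # simulate a restart wiping memory
assert coord.get("t1") is not None  # still recoverable from the store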
@@ -62,12 +62,54 @@ async def get_task(task_id: str, current_user: dict = Depends(get_current_user))
async def get_tasks(
    status: Optional[str] = Query(None, description="Filter by task status"),
    agent: Optional[str] = Query(None, description="Filter by assigned agent"),
-    limit: int = Query(20, description="Maximum number of tasks to return"),
+    workflow_id: Optional[str] = Query(None, description="Filter by workflow ID"),
+    limit: int = Query(50, description="Maximum number of tasks to return"),
    current_user: dict = Depends(get_current_user)
):
-    """Get list of tasks with optional filtering"""
-
-    # Get all tasks from coordinator
+    """Get list of tasks with optional filtering (includes database tasks)"""
+
+    try:
+        # Get tasks from database (more comprehensive than in-memory only)
+        db_tasks = coordinator.task_service.get_tasks(
+            status=status,
+            agent_id=agent,
+            workflow_id=workflow_id,
+            limit=limit
+        )
+
+        # Convert ORM tasks to coordinator tasks for a consistent response format
+        tasks = []
+        for orm_task in db_tasks:
+            coordinator_task = coordinator.task_service.coordinator_task_from_orm(orm_task)
+            tasks.append({
+                "id": coordinator_task.id,
+                "type": coordinator_task.type.value,
+                "priority": coordinator_task.priority,
+                "status": coordinator_task.status.value,
+                "context": coordinator_task.context,
+                "assigned_agent": coordinator_task.assigned_agent,
+                "result": coordinator_task.result,
+                "created_at": coordinator_task.created_at,
+                "completed_at": coordinator_task.completed_at,
+                "workflow_id": coordinator_task.workflow_id,
+            })
+
+        # Get total count for the response
+        total_count = len(db_tasks)
+
+        return {
+            "tasks": tasks,
+            "total": total_count,
+            "source": "database",
+            "filters_applied": {
+                "status": status,
+                "agent": agent,
+                "workflow_id": workflow_id
+            }
+        }
+
+    except Exception as e:
+        # Fallback to in-memory tasks if database fails
+        all_tasks = list(coordinator.tasks.values())
+
+        # Apply filters
@@ -83,6 +125,9 @@ async def get_tasks(
        if agent:
            filtered_tasks = [t for t in filtered_tasks if t.assigned_agent == agent]

+        if workflow_id:
+            filtered_tasks = [t for t in filtered_tasks if t.workflow_id == workflow_id]
+
        # Sort by creation time (newest first) and limit
        filtered_tasks.sort(key=lambda t: t.created_at or 0, reverse=True)
        filtered_tasks = filtered_tasks[:limit]
@@ -100,10 +145,57 @@ async def get_tasks(
                "result": task.result,
                "created_at": task.created_at,
                "completed_at": task.completed_at,
+                "workflow_id": task.workflow_id,
            })

        return {
            "tasks": tasks,
            "total": len(tasks),
+            "source": "memory_fallback",
+            "database_error": str(e),
+            "filtered": len(all_tasks) != len(tasks),
        }

+@router.get("/tasks/statistics")
+async def get_task_statistics(current_user: dict = Depends(get_current_user)):
+    """Get comprehensive task statistics"""
+    try:
+        db_stats = coordinator.task_service.get_task_statistics()
+
+        # Get in-memory statistics
+        memory_stats = {
+            "in_memory_active": len([t for t in coordinator.tasks.values() if t.status == TaskStatus.IN_PROGRESS]),
+            "in_memory_pending": len(coordinator.task_queue),
+            "in_memory_total": len(coordinator.tasks)
+        }
+
+        return {
+            "database_statistics": db_stats,
+            "memory_statistics": memory_stats,
+            "coordinator_status": "operational" if coordinator.is_initialized else "initializing"
+        }
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Failed to get task statistics: {str(e)}")
+
+@router.delete("/tasks/{task_id}")
+async def delete_task(task_id: str, current_user: dict = Depends(get_current_user)):
+    """Delete a specific task"""
+    try:
+        # Remove from in-memory cache if present
+        if task_id in coordinator.tasks:
+            del coordinator.tasks[task_id]
+
+        # Remove from task queue if present
+        coordinator.task_queue = [t for t in coordinator.task_queue if t.id != task_id]
+
+        # Delete from database
+        success = coordinator.task_service.delete_task(task_id)
+
+        if success:
+            return {"message": f"Task {task_id} deleted successfully"}
+        else:
+            raise HTTPException(status_code=404, detail="Task not found")
+
+    except HTTPException:
+        raise  # let the 404 above propagate instead of rewrapping it as a 500
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Failed to delete task: {str(e)}")
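Assuming the app is served under a hypothetical http://localhost:8000/api prefix with bearer-token auth (both assumptions, not taken from this diff), the new and updated endpoints can be exercised like this:

# Hypothetical client usage; base URL, path prefix, and token are assumptions.
import requests

BASE = "http://localhost:8000/api"
HEADERS = {"Authorization": "Bearer <token>"}

# List completed tasks for one agent, capped at 10 results
resp = requests.get(f"{BASE}/tasks",
                    params={"status": "completed", "agent": "agent-1", "limit": 10},
                    headers=HEADERS)
print(resp.json()["source"])  # "database", or "memory_fallback" if the DB is down

# Aggregate statistics from both the database and the in-memory queue
print(requests.get(f"{BASE}/tasks/statistics", headers=HEADERS).json())

# Delete a task from the cache, the queue, and the database in one call
requests.delete(f"{BASE}/tasks/123e4567-e89b-12d3-a456-426614174000", headers=HEADERS)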
@@ -21,6 +21,7 @@ from prometheus_client import Counter, Histogram, Gauge
from ..models.agent import Agent as ORMAgent
from ..core.database import SessionLocal
from ..cli_agents.cli_agent_manager import get_cli_agent_manager
+from ..services.task_service import TaskService

logger = logging.getLogger(__name__)
@@ -120,10 +121,13 @@ class UnifiedCoordinator:
    def __init__(self, redis_url: str = "redis://localhost:6379"):
        # Core state
        self.agents: Dict[str, Agent] = {}
-        self.tasks: Dict[str, Task] = {}
+        self.tasks: Dict[str, Task] = {}  # In-memory cache for active tasks
        self.task_queue: List[Task] = []
        self.is_initialized = False

+        # Database persistence
+        self.task_service = TaskService()
+
        # CLI agent support
        self.cli_agent_manager = None
@@ -163,6 +167,9 @@ class UnifiedCoordinator:
        # Load agents from database
        await self._load_database_agents()

+        # Load existing tasks from database
+        await self._load_database_tasks()
+
        # Initialize cluster agents
        self._initialize_cluster_agents()
@@ -249,6 +256,31 @@ class UnifiedCoordinator:
        except Exception as e:
            logger.error(f"❌ Failed to load agents from database: {e}")

+    async def _load_database_tasks(self):
+        """Load pending and in-progress tasks from database"""
+        try:
+            # Load pending tasks back into the queue
+            pending_orm_tasks = self.task_service.get_tasks(status='pending', limit=100)
+            for orm_task in pending_orm_tasks:
+                coordinator_task = self.task_service.coordinator_task_from_orm(orm_task)
+                self.tasks[coordinator_task.id] = coordinator_task
+                self.task_queue.append(coordinator_task)
+
+            # Load in-progress tasks into the cache only; they are not added to
+            # task_queue because they are already being processed
+            in_progress_orm_tasks = self.task_service.get_tasks(status='in_progress', limit=100)
+            for orm_task in in_progress_orm_tasks:
+                coordinator_task = self.task_service.coordinator_task_from_orm(orm_task)
+                self.tasks[coordinator_task.id] = coordinator_task
+
+            # Sort task queue by priority
+            self.task_queue.sort(key=lambda t: t.priority)
+
+            logger.info(f"📊 Loaded {len(pending_orm_tasks)} pending and {len(in_progress_orm_tasks)} in-progress tasks from database")
+
+        except Exception as e:
+            logger.error(f"❌ Failed to load tasks from database: {e}")
+
    def _initialize_cluster_agents(self):
        """Initialize predefined cluster agents"""
        # This maintains compatibility with the original HiveCoordinator
@@ -292,6 +324,14 @@ class UnifiedCoordinator:
            payload=context  # For compatibility
        )

+        # Persist to database
+        try:
+            self.task_service.create_task(task)
+            logger.info(f"💾 Task {task_id} persisted to database")
+        except Exception as e:
+            logger.error(f"❌ Failed to persist task {task_id} to database: {e}")
+
        # Add to in-memory structures
        self.tasks[task_id] = task
        self.task_queue.append(task)
@@ -416,6 +456,13 @@ class UnifiedCoordinator:
        task.assigned_agent = agent.id
        agent.current_tasks += 1

+        # Persist status change to database
+        try:
+            self.task_service.update_task(task.id, task)
+            logger.debug(f"💾 Updated task {task.id} status to IN_PROGRESS in database")
+        except Exception as e:
+            logger.error(f"❌ Failed to update task {task.id} status in database: {e}")
+
        ACTIVE_TASKS.labels(agent=agent.id).inc()
        start_time = time.time()
@@ -435,6 +482,13 @@ class UnifiedCoordinator:
            task.status = TaskStatus.COMPLETED
            task.completed_at = time.time()

+            # Persist completion to database
+            try:
+                self.task_service.update_task(task.id, task)
+                logger.debug(f"💾 Updated task {task.id} status to COMPLETED in database")
+            except Exception as e:
+                logger.error(f"❌ Failed to update completed task {task.id} in database: {e}")
+
            # Update agent
            agent.current_tasks -= 1
            self.load_balancer.update_weight(agent.id, execution_time)
@@ -450,6 +504,14 @@ class UnifiedCoordinator:
        except Exception as e:
            task.status = TaskStatus.FAILED
            task.result = {"error": str(e)}

+            # Persist failure to database
+            try:
+                self.task_service.update_task(task.id, task)
+                logger.debug(f"💾 Updated task {task.id} status to FAILED in database")
+            except Exception as db_e:
+                logger.error(f"❌ Failed to update failed task {task.id} in database: {db_e}")
+
            agent.current_tasks -= 1
            ACTIVE_TASKS.labels(agent=agent.id).dec()
            logger.error(f"❌ Task {task.id} failed: {e}")
@@ -622,6 +684,8 @@ Please complete this task based on the provided context and requirements.

    async def _cleanup_completed_tasks(self):
        """Clean up old completed tasks"""
        try:
+            # Clean up in-memory tasks (keep only active ones)
            cutoff_time = time.time() - 3600  # 1 hour ago

            completed_tasks = [
@@ -632,8 +696,19 @@ Please complete this task based on the provided context and requirements.
            for task_id in completed_tasks:
                del self.tasks[task_id]

+            # Clean up database tasks (older ones)
+            try:
+                db_cleaned_count = self.task_service.cleanup_completed_tasks(max_age_hours=24)
+                if db_cleaned_count > 0:
+                    logger.info(f"🧹 Cleaned up {db_cleaned_count} old tasks from database")
+            except Exception as e:
+                logger.error(f"❌ Failed to cleanup database tasks: {e}")
+
            if completed_tasks:
-                logger.info(f"🧹 Cleaned up {len(completed_tasks)} old completed tasks")
+                logger.info(f"🧹 Cleaned up {len(completed_tasks)} old completed tasks from memory")

        except Exception as e:
            logger.error(f"❌ Failed to cleanup completed tasks: {e}")

    # =========================================================================
    # STATUS & METRICS
@@ -641,11 +716,39 @@ Please complete this task based on the provided context and requirements.

    def get_task_status(self, task_id: str) -> Optional[Task]:
        """Get status of a specific task"""
-        return self.tasks.get(task_id)
+        # First check in-memory cache
+        task = self.tasks.get(task_id)
+        if task:
+            return task
+
+        # If not in memory, check database
+        try:
+            orm_task = self.task_service.get_task(task_id)
+            if orm_task:
+                return self.task_service.coordinator_task_from_orm(orm_task)
+        except Exception as e:
+            logger.error(f"❌ Failed to get task {task_id} from database: {e}")
+
+        return None

-    def get_completed_tasks(self) -> List[Task]:
+    def get_completed_tasks(self, limit: int = 50) -> List[Task]:
        """Get all completed tasks"""
-        return [task for task in self.tasks.values() if task.status == TaskStatus.COMPLETED]
+        # Get from in-memory cache first
+        memory_completed = [task for task in self.tasks.values() if task.status == TaskStatus.COMPLETED]
+
+        # Get additional from database if needed
+        try:
+            if len(memory_completed) < limit:
+                db_completed = self.task_service.get_tasks(status='completed', limit=limit)
+                db_tasks = [self.task_service.coordinator_task_from_orm(orm_task) for orm_task in db_completed]
+
+                # Combine and deduplicate
+                all_tasks = {task.id: task for task in memory_completed + db_tasks}
+                return list(all_tasks.values())[:limit]
+        except Exception as e:
+            logger.error(f"❌ Failed to get completed tasks from database: {e}")
+
+        return memory_completed[:limit]

    async def get_health_status(self):
        """Get coordinator health status"""
@@ -660,13 +763,21 @@ Please complete this task based on the provided context and requirements.
                "last_heartbeat": agent.last_heartbeat
            }

+        # Get comprehensive task statistics from database
+        try:
+            db_stats = self.task_service.get_task_statistics()
+        except Exception as e:
+            logger.error(f"❌ Failed to get task statistics from database: {e}")
+            db_stats = {}
+
        return {
            "status": "operational" if self.is_initialized else "initializing",
            "agents": agent_status,
            "total_agents": len(self.agents),
            "active_tasks": len([t for t in self.tasks.values() if t.status == TaskStatus.IN_PROGRESS]),
            "pending_tasks": len(self.task_queue),
-            "completed_tasks": len([t for t in self.tasks.values() if t.status == TaskStatus.COMPLETED])
+            "completed_tasks": len([t for t in self.tasks.values() if t.status == TaskStatus.COMPLETED]),
+            "database_statistics": db_stats
        }

    async def get_comprehensive_status(self):
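One subtlety in the get_completed_tasks merge above: in a dict comprehension over memory_completed + db_tasks, the later duplicate wins, so a database row can shadow a fresher in-memory copy of the same task. If the in-memory copy should take precedence, reversing the concatenation is enough; a self-contained sketch of the behavior, using SimpleNamespace stand-ins:

# Sketch: later duplicates win in a dict comprehension, so listing DB rows
# first lets in-memory copies override potentially stale database copies.
from types import SimpleNamespace

memory_completed = [SimpleNamespace(id="t1", completed_at=200.0, source="memory")]
db_tasks = [SimpleNamespace(id="t1", completed_at=100.0, source="db"),
            SimpleNamespace(id="t2", completed_at=150.0, source="db")]

merged = {t.id: t for t in db_tasks + memory_completed}  # memory copy of t1 wins
completed = sorted(merged.values(), key=lambda t: t.completed_at or 0, reverse=True)
assert merged["t1"].source == "memory"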
@@ -1,2 +1,4 @@
from . import agent
from . import project
+from . import task
+from . import sqlalchemy_models
@@ -1,5 +1,6 @@
from sqlalchemy import Column, Integer, String, DateTime, JSON
from sqlalchemy.sql import func
+from sqlalchemy.orm import relationship
from ..core.database import Base

class Agent(Base):
@@ -23,6 +24,9 @@ class Agent(Base):
    updated_at = Column(DateTime(timezone=True), onupdate=func.now())
    last_seen = Column(DateTime(timezone=True), nullable=True)

+    # Relationships
+    tasks = relationship("Task", back_populates="assigned_agent")
+
    def to_dict(self):
        return {
            "id": self.id,
backend/app/models/sqlalchemy_models.py (new file, 63 lines)
@@ -0,0 +1,63 @@
"""
SQLAlchemy models for workflows and executions
"""

from sqlalchemy import Column, String, Text, Integer, Boolean, DateTime, ForeignKey, UUID as SqlUUID
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from ..core.database import Base
import uuid


class Workflow(Base):
    __tablename__ = "workflows"

    # Primary identification
    id = Column(SqlUUID(as_uuid=True), primary_key=True, index=True, default=uuid.uuid4)

    # Workflow details
    name = Column(String(255), nullable=False)
    description = Column(Text)
    n8n_data = Column(JSONB, nullable=False)
    mcp_tools = Column(JSONB)

    # Ownership
    created_by = Column(SqlUUID(as_uuid=True), ForeignKey("users.id"))

    # Metadata
    version = Column(Integer, default=1)
    active = Column(Boolean, default=True)

    # Timestamps
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

    # Relationships
    creator = relationship("User", back_populates="workflows")
    executions = relationship("Execution", back_populates="workflow")
    tasks = relationship("Task", back_populates="workflow")


class Execution(Base):
    __tablename__ = "executions"

    # Primary identification
    id = Column(SqlUUID(as_uuid=True), primary_key=True, index=True, default=uuid.uuid4)

    # Execution details
    workflow_id = Column(SqlUUID(as_uuid=True), ForeignKey("workflows.id"), nullable=True)
    status = Column(String(50), default='pending')
    input_data = Column(JSONB)
    output_data = Column(JSONB)
    error_message = Column(Text)
    progress = Column(Integer, default=0)

    # Timestamps
    started_at = Column(DateTime(timezone=True), nullable=True)
    completed_at = Column(DateTime(timezone=True), nullable=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    # Relationships
    workflow = relationship("Workflow", back_populates="executions")
    tasks = relationship("Task", back_populates="execution")
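Because Workflow, Execution, and the Task model below all derive from the shared Base, the schema can be materialized in one call. A sketch, assuming ..core.database also exposes an engine (not shown in this diff) and that it points at PostgreSQL, which the JSONB and UUID columns require:

# Sketch: create the workflows/executions/tasks tables from the shared Base.
# Assumes app.core.database exposes `engine` alongside Base/SessionLocal,
# and that the engine targets PostgreSQL (required by JSONB/UUID columns).
from app.core.database import Base, engine
from app.models import task, sqlalchemy_models  # noqa: F401  # register models with Base

Base.metadata.create_all(bind=engine)  # no-op for tables that already exist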
backend/app/models/task.py (new file, 41 lines)
@@ -0,0 +1,41 @@
"""
Task model for SQLAlchemy ORM
"""

from sqlalchemy import Column, String, Text, Integer, DateTime, ForeignKey, UUID as SqlUUID
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from ..core.database import Base
import uuid


class Task(Base):
    __tablename__ = "tasks"

    # Primary identification
    id = Column(SqlUUID(as_uuid=True), primary_key=True, index=True, default=uuid.uuid4)

    # Task details
    title = Column(String(255), nullable=False)
    description = Column(Text)
    priority = Column(Integer, default=5)
    status = Column(String(50), default='pending')

    # Foreign keys
    assigned_agent_id = Column(String(255), ForeignKey("agents.id"), nullable=True)
    workflow_id = Column(SqlUUID(as_uuid=True), ForeignKey("workflows.id"), nullable=True)
    execution_id = Column(SqlUUID(as_uuid=True), ForeignKey("executions.id"), nullable=True)

    # Metadata and context; the attribute is named task_metadata because
    # 'metadata' is reserved by SQLAlchemy's Declarative API, but it still
    # maps to the "metadata" column in the database schema
    task_metadata = Column("metadata", JSONB, nullable=True)

    # Timestamps
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    started_at = Column(DateTime(timezone=True), nullable=True)
    completed_at = Column(DateTime(timezone=True), nullable=True)

    # Relationships
    assigned_agent = relationship("Agent", back_populates="tasks")
    workflow = relationship("Workflow", back_populates="tasks")
    execution = relationship("Execution", back_populates="tasks")
@@ -44,6 +44,7 @@ class User(Base):
    # Relationships for authentication features
    api_keys = relationship("APIKey", back_populates="user", cascade="all, delete-orphan")
    refresh_tokens = relationship("RefreshToken", back_populates="user", cascade="all, delete-orphan")
+    workflows = relationship("Workflow", back_populates="creator")

    def verify_password(self, password: str) -> bool:
        """Verify a password against the hashed password."""
backend/app/services/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .task_service import TaskService
backend/app/services/task_service.py (new file, 220 lines)
@@ -0,0 +1,220 @@
"""
Task service for database operations
Handles CRUD operations for tasks and integrates with the UnifiedCoordinator
"""

from typing import List, Optional, Dict, Any
from sqlalchemy.orm import Session
from sqlalchemy import desc, func
from datetime import datetime, timedelta
import uuid

from ..models.task import Task as ORMTask
from ..core.unified_coordinator import Task as CoordinatorTask, TaskStatus, AgentType
from ..core.database import SessionLocal


class TaskService:
    """Service for managing task persistence and database operations"""

    def __init__(self):
        pass

    def create_task(self, coordinator_task: CoordinatorTask) -> ORMTask:
        """Create a task in the database from a coordinator task"""
        with SessionLocal() as db:
            try:
                # Convert coordinator task to database task
                db_task = ORMTask(
                    id=uuid.UUID(coordinator_task.id) if isinstance(coordinator_task.id, str) else coordinator_task.id,
                    title=coordinator_task.context.get('title', f"Task {coordinator_task.type.value}"),
                    description=coordinator_task.context.get('description', ''),
                    priority=coordinator_task.priority,
                    status=coordinator_task.status.value,
                    assigned_agent_id=coordinator_task.assigned_agent,
                    workflow_id=uuid.UUID(coordinator_task.workflow_id) if coordinator_task.workflow_id else None,
                    task_metadata={  # mapped to the "metadata" column; see models/task.py
                        'type': coordinator_task.type.value,
                        'context': coordinator_task.context,
                        'payload': coordinator_task.payload,
                        'dependencies': coordinator_task.dependencies,
                        'created_at': coordinator_task.created_at,
                        'completed_at': coordinator_task.completed_at,
                        'result': coordinator_task.result
                    }
                )

                if coordinator_task.status == TaskStatus.IN_PROGRESS and coordinator_task.created_at:
                    db_task.started_at = datetime.fromtimestamp(coordinator_task.created_at)

                if coordinator_task.status == TaskStatus.COMPLETED and coordinator_task.completed_at:
                    db_task.completed_at = datetime.fromtimestamp(coordinator_task.completed_at)

                db.add(db_task)
                db.commit()
                db.refresh(db_task)

                return db_task

            except Exception as e:
                db.rollback()
                raise e
    def update_task(self, task_id: str, coordinator_task: CoordinatorTask) -> Optional[ORMTask]:
        """Update a task in the database"""
        with SessionLocal() as db:
            try:
                # Convert string ID to UUID if needed
                uuid_id = uuid.UUID(task_id) if isinstance(task_id, str) else task_id

                db_task = db.query(ORMTask).filter(ORMTask.id == uuid_id).first()
                if not db_task:
                    return None

                # Update fields from coordinator task
                db_task.title = coordinator_task.context.get('title', db_task.title)
                db_task.description = coordinator_task.context.get('description', db_task.description)
                db_task.priority = coordinator_task.priority
                db_task.status = coordinator_task.status.value
                db_task.assigned_agent_id = coordinator_task.assigned_agent

                # Update metadata (stored via the task_metadata attribute)
                db_task.task_metadata = {
                    'type': coordinator_task.type.value,
                    'context': coordinator_task.context,
                    'payload': coordinator_task.payload,
                    'dependencies': coordinator_task.dependencies,
                    'created_at': coordinator_task.created_at,
                    'completed_at': coordinator_task.completed_at,
                    'result': coordinator_task.result
                }

                # Update timestamps based on status
                if coordinator_task.status == TaskStatus.IN_PROGRESS and not db_task.started_at:
                    db_task.started_at = datetime.utcnow()

                if coordinator_task.status == TaskStatus.COMPLETED and not db_task.completed_at:
                    db_task.completed_at = datetime.utcnow()

                db.commit()
                db.refresh(db_task)

                return db_task

            except Exception as e:
                db.rollback()
                raise e
    def get_task(self, task_id: str) -> Optional[ORMTask]:
        """Get a task by ID"""
        with SessionLocal() as db:
            uuid_id = uuid.UUID(task_id) if isinstance(task_id, str) else task_id
            return db.query(ORMTask).filter(ORMTask.id == uuid_id).first()

    def get_tasks(self, status: Optional[str] = None, agent_id: Optional[str] = None,
                  workflow_id: Optional[str] = None, limit: int = 100) -> List[ORMTask]:
        """Get tasks with optional filtering"""
        with SessionLocal() as db:
            query = db.query(ORMTask)

            if status:
                query = query.filter(ORMTask.status == status)
            if agent_id:
                query = query.filter(ORMTask.assigned_agent_id == agent_id)
            if workflow_id:
                uuid_workflow_id = uuid.UUID(workflow_id) if isinstance(workflow_id, str) else workflow_id
                query = query.filter(ORMTask.workflow_id == uuid_workflow_id)

            return query.order_by(desc(ORMTask.created_at)).limit(limit).all()

    def get_pending_tasks(self, limit: int = 50) -> List[ORMTask]:
        """Get pending tasks ordered by priority"""
        with SessionLocal() as db:
            return db.query(ORMTask).filter(
                ORMTask.status == 'pending'
            ).order_by(
                ORMTask.priority.asc(),  # Lower number = higher priority
                ORMTask.created_at.asc()
            ).limit(limit).all()

    def delete_task(self, task_id: str) -> bool:
        """Delete a task"""
        with SessionLocal() as db:
            try:
                uuid_id = uuid.UUID(task_id) if isinstance(task_id, str) else task_id
                task = db.query(ORMTask).filter(ORMTask.id == uuid_id).first()
                if task:
                    db.delete(task)
                    db.commit()
                    return True
                return False
            except Exception as e:
                db.rollback()
                raise e

    def cleanup_completed_tasks(self, max_age_hours: int = 24) -> int:
        """Clean up old completed tasks"""
        with SessionLocal() as db:
            try:
                cutoff_time = datetime.utcnow() - timedelta(hours=max_age_hours)

                deleted_count = db.query(ORMTask).filter(
                    ORMTask.status.in_(['completed', 'failed']),
                    ORMTask.completed_at < cutoff_time
                ).delete(synchronize_session=False)

                db.commit()
                return deleted_count

            except Exception as e:
                db.rollback()
                raise e
    def coordinator_task_from_orm(self, orm_task: ORMTask) -> CoordinatorTask:
        """Convert ORM task back to coordinator task"""
        metadata = orm_task.task_metadata or {}  # task_metadata maps to the "metadata" column

        # Extract fields from metadata
        task_type = AgentType(metadata.get('type', 'general_ai'))
        context = metadata.get('context', {})
        payload = metadata.get('payload', {})
        dependencies = metadata.get('dependencies', [])
        result = metadata.get('result')
        created_at = metadata.get('created_at', orm_task.created_at.timestamp() if orm_task.created_at else None)
        completed_at = metadata.get('completed_at')

        # Convert status
        status = TaskStatus(orm_task.status) if orm_task.status in [s.value for s in TaskStatus] else TaskStatus.PENDING

        return CoordinatorTask(
            id=str(orm_task.id),
            type=task_type,
            priority=orm_task.priority,
            status=status,
            context=context,
            payload=payload,
            assigned_agent=orm_task.assigned_agent_id,
            result=result,
            created_at=created_at,
            completed_at=completed_at,
            workflow_id=str(orm_task.workflow_id) if orm_task.workflow_id else None,
            dependencies=dependencies
        )

    def get_task_statistics(self) -> Dict[str, Any]:
        """Get task statistics"""
        with SessionLocal() as db:
            total_tasks = db.query(ORMTask).count()
            pending_tasks = db.query(ORMTask).filter(ORMTask.status == 'pending').count()
            in_progress_tasks = db.query(ORMTask).filter(ORMTask.status == 'in_progress').count()
            completed_tasks = db.query(ORMTask).filter(ORMTask.status == 'completed').count()
            failed_tasks = db.query(ORMTask).filter(ORMTask.status == 'failed').count()

            return {
                'total_tasks': total_tasks,
                'pending_tasks': pending_tasks,
                'in_progress_tasks': in_progress_tasks,
                'completed_tasks': completed_tasks,
                'failed_tasks': failed_tasks,
                'success_rate': completed_tasks / total_tasks if total_tasks > 0 else 0
            }
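A sketch of a full TaskService round trip, with the coordinator Task built from exactly the field names this diff uses. It assumes a configured database behind SessionLocal, that the modules import as app.services/app.core, and that 'general_ai' is a valid AgentType value, as the fallback in coordinator_task_from_orm suggests:

# Sketch: persist a coordinator task, read it back, and inspect statistics.
# Requires a reachable database; module paths and the AgentType value are assumptions.
import time
import uuid

from app.services.task_service import TaskService
from app.core.unified_coordinator import Task, TaskStatus, AgentType

service = TaskService()
task = Task(
    id=str(uuid.uuid4()),
    type=AgentType("general_ai"),  # fallback value used by coordinator_task_from_orm
    priority=3,
    status=TaskStatus.PENDING,
    context={"title": "Example", "description": "Round-trip demo"},
    payload={},
    assigned_agent=None,
    result=None,
    created_at=time.time(),
    completed_at=None,
    workflow_id=None,
    dependencies=[],
)

service.create_task(task)  # INSERT into the tasks table
restored = service.coordinator_task_from_orm(service.get_task(task.id))
assert restored.id == task.id and restored.status == TaskStatus.PENDING
print(service.get_task_statistics())  # counts per status plus success_rate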