hive/backend/app/services/ai_model_service.py
anthonyrawlins 268214d971 Major WHOOSH system refactoring and feature enhancements
- Migrated from HIVE branding to WHOOSH across all components
- Enhanced backend API with new services: AI models, BZZZ integration, templates, members
- Added comprehensive testing suite with security, performance, and integration tests
- Improved frontend with new components for project setup, AI models, and team management
- Updated MCP server implementation with WHOOSH-specific tools and resources
- Enhanced deployment configurations with production-ready Docker setups
- Added comprehensive documentation and setup guides
- Implemented age encryption service and UCXL integration

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-27 08:34:48 +10:00

"""
WHOOSH AI Model Service - Phase 6.1
Advanced AI model integration with distributed Ollama cluster
"""
import asyncio
import json
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional

import aiohttp

logger = logging.getLogger(__name__)


class ModelCapability(Enum):
    """AI model capabilities."""

    CODE_GENERATION = "code_generation"
    CODE_REVIEW = "code_review"
    DOCUMENTATION = "documentation"
    TESTING = "testing"
    ARCHITECTURE = "architecture"
    DEBUGGING = "debugging"
    REFACTORING = "refactoring"
    GENERAL_CHAT = "general_chat"
    SPECIALIZED_DOMAIN = "specialized_domain"

@dataclass
class AIModel:
    """AI model information."""

    name: str
    node_url: str
    capabilities: List[ModelCapability]
    context_length: int
    parameter_count: str
    specialization: Optional[str] = None
    performance_score: float = 0.0
    availability: bool = True
    last_used: Optional[datetime] = None
    usage_count: int = 0
    avg_response_time: float = 0.0


@dataclass
class ClusterNode:
    """Ollama cluster node information."""

    host: str
    port: int
    status: str = "unknown"
    # A default factory avoids a shared mutable default across instances
    models: List[str] = field(default_factory=list)
    load: float = 0.0
    last_ping: Optional[datetime] = None

class AIModelService:
    """Advanced AI Model Service for WHOOSH"""

    def __init__(self):
        # Distributed Ollama cluster nodes from CLAUDE.md
        self.cluster_nodes = [
            ClusterNode("192.168.1.27", 11434),   # Node 1
            ClusterNode("192.168.1.72", 11434),   # Node 2
            ClusterNode("192.168.1.113", 11434),  # Node 3
            ClusterNode("192.168.1.106", 11434),  # Node 4
        ]
        self.models: Dict[str, AIModel] = {}
        self.model_cache = {}
        self.load_balancer_state = {}
        self.session: Optional[aiohttp.ClientSession] = None

    async def initialize(self):
        """Initialize the AI model service."""
        logger.info("Initializing AI Model Service...")

        # Create a shared aiohttp session with a default request timeout
        self.session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=30)
        )

        # Discover all available models across the cluster
        await self.discover_cluster_models()

        # Set up load balancing
        await self.initialize_load_balancer()

        logger.info(
            f"AI Model Service initialized with {len(self.models)} models "
            f"across {len(self.cluster_nodes)} nodes"
        )

    async def discover_cluster_models(self):
        """Discover all available models across the Ollama cluster."""
        logger.info("Discovering models across Ollama cluster...")
        discovered_models = {}

        for node in self.cluster_nodes:
            try:
                node_url = f"http://{node.host}:{node.port}"

                # Check node health
                async with self.session.get(
                    f"{node_url}/api/tags",
                    timeout=aiohttp.ClientTimeout(total=5),
                ) as response:
                    if response.status == 200:
                        data = await response.json()
                        node.status = "healthy"
                        node.models = [model["name"] for model in data.get("models", [])]
                        node.last_ping = datetime.now()

                        # Process each model; the first node seen to host a
                        # model becomes its serving node
                        for model_info in data.get("models", []):
                            model_name = model_info["name"]

                            # Determine model capabilities based on name patterns
                            capabilities = self._determine_model_capabilities(model_name)

                            # Create or update model entry
                            if model_name not in discovered_models:
                                discovered_models[model_name] = AIModel(
                                    name=model_name,
                                    node_url=node_url,
                                    capabilities=capabilities,
                                    context_length=self._estimate_context_length(model_name),
                                    parameter_count=self._estimate_parameters(model_name),
                                    specialization=self._determine_specialization(model_name),
                                )

                        logger.info(f"Node {node.host}: {len(node.models)} models available")

            except Exception as e:
                logger.warning(f"Failed to connect to node {node.host}:{node.port}: {e}")
                node.status = "unavailable"
                node.models = []

        self.models = discovered_models
        logger.info(f"Discovered {len(self.models)} total models across cluster")

    def _determine_model_capabilities(self, model_name: str) -> List[ModelCapability]:
        """Determine model capabilities based on name patterns."""
        capabilities = []
        name_lower = model_name.lower()

        # Code-focused models
        if any(keyword in name_lower for keyword in ["code", "codellama", "deepseek", "starcoder", "wizard"]):
            capabilities.extend([
                ModelCapability.CODE_GENERATION,
                ModelCapability.CODE_REVIEW,
                ModelCapability.DEBUGGING,
                ModelCapability.REFACTORING,
            ])

        # Documentation models
        if any(keyword in name_lower for keyword in ["llama", "mistral", "gemma"]):
            capabilities.append(ModelCapability.DOCUMENTATION)

        # Testing models
        if "test" in name_lower or "wizard" in name_lower:
            capabilities.append(ModelCapability.TESTING)

        # Architecture models (larger models)
        if any(keyword in name_lower for keyword in ["70b", "34b", "33b"]):
            capabilities.append(ModelCapability.ARCHITECTURE)

        # General chat (most models)
        capabilities.append(ModelCapability.GENERAL_CHAT)

        # Default if no specific capabilities were found
        if len(capabilities) == 1:  # Only GENERAL_CHAT
            capabilities.append(ModelCapability.CODE_GENERATION)

        return capabilities

    def _estimate_context_length(self, model_name: str) -> int:
        """Estimate context length based on model name."""
        name_lower = model_name.lower()
        if "32k" in name_lower:
            return 32768
        elif "16k" in name_lower:
            return 16384
        elif "8k" in name_lower:
            return 8192
        elif any(size in name_lower for size in ["70b", "65b", "34b", "33b"]):
            return 4096
        else:
            return 2048  # Default

    def _estimate_parameters(self, model_name: str) -> str:
        """Estimate parameter count based on model name."""
        name_lower = model_name.lower()
        if "70b" in name_lower:
            return "70B"
        elif "34b" in name_lower or "33b" in name_lower:
            return "34B"
        elif "13b" in name_lower:
            return "13B"
        elif "7b" in name_lower:
            return "7B"
        elif "3b" in name_lower:
            return "3B"
        elif "1b" in name_lower:
            return "1B"
        else:
            return "Unknown"

    def _determine_specialization(self, model_name: str) -> Optional[str]:
        """Determine model specialization."""
        name_lower = model_name.lower()
        if "code" in name_lower:
            return "Programming"
        elif "math" in name_lower:
            return "Mathematics"
        elif "sql" in name_lower:
            return "Database"
        elif "medical" in name_lower:
            return "Healthcare"
        else:
            return None

    async def get_best_model_for_task(
        self,
        task_type: ModelCapability,
        context_requirements: int = 2048,
        prefer_specialized: bool = True,
    ) -> Optional[AIModel]:
        """Select the best model for a specific task."""
        # Filter models by capability, availability, and context length
        suitable_models = [
            model for model in self.models.values()
            if task_type in model.capabilities
            and model.availability
            and model.context_length >= context_requirements
        ]

        if not suitable_models:
            logger.warning(f"No suitable models found for task {task_type}")
            return None

        # Scoring: a weighted mix of performance, specialization, context
        # headroom, usage-based load balancing, and response time
        def score_model(model: AIModel) -> float:
            score = 0.0

            # Base score from performance
            score += model.performance_score * 0.3

            # Capability match bonus (always true after the filter above;
            # kept so the scorer also works standalone)
            if task_type in model.capabilities:
                score += 0.2

            # Specialization bonus
            if prefer_specialized and model.specialization:
                score += 0.2

            # Context length bonus (more is better, up to 2x the requirement)
            context_ratio = min(model.context_length / context_requirements, 2.0)
            score += context_ratio * 0.1

            # Load balancing - prefer less-used models
            if model.usage_count > 0:
                usage_penalty = min(model.usage_count / 100.0, 0.1)
                score -= usage_penalty

            # Response time bonus (faster is better)
            if model.avg_response_time > 0:
                time_bonus = max(0.1 - (model.avg_response_time / 10.0), 0)
                score += time_bonus

            return score

        # Pick the highest-scoring model
        best_model = max(suitable_models, key=score_model)
        logger.info(f"Selected model {best_model.name} for task {task_type}")
        return best_model

    async def generate_completion(
        self,
        model_name: str,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: int = 1000,
        temperature: float = 0.7,
    ) -> Dict[str, Any]:
        """Generate a completion using the specified model."""
        if model_name not in self.models:
            raise ValueError(f"Model {model_name} not available")

        model = self.models[model_name]
        start_time = time.time()

        try:
            # Prepare request
            request_data = {
                "model": model_name,
                "prompt": prompt,
                "stream": False,
                "options": {
                    "num_predict": max_tokens,
                    "temperature": temperature,
                },
            }
            if system_prompt:
                request_data["system"] = system_prompt

            # Make request to Ollama
            async with self.session.post(
                f"{model.node_url}/api/generate",
                json=request_data,
            ) as response:
                if response.status == 200:
                    result = await response.json()

                    # Update model statistics
                    response_time = time.time() - start_time
                    model.usage_count += 1
                    model.last_used = datetime.now()

                    # Update average response time (exponential moving average)
                    if model.avg_response_time == 0:
                        model.avg_response_time = response_time
                    else:
                        model.avg_response_time = (
                            model.avg_response_time * 0.8 + response_time * 0.2
                        )

                    return {
                        "success": True,
                        "content": result.get("response", ""),
                        "model": model_name,
                        "response_time": response_time,
                        "usage_stats": {
                            "total_duration": result.get("total_duration", 0),
                            "load_duration": result.get("load_duration", 0),
                            "prompt_eval_count": result.get("prompt_eval_count", 0),
                            "eval_count": result.get("eval_count", 0),
                        },
                    }
                else:
                    error_text = await response.text()
                    raise Exception(f"API error {response.status}: {error_text}")

        except Exception as e:
            logger.error(f"Error generating completion with {model_name}: {e}")
            # Mark the model unavailable; a later discover_cluster_models()
            # run rebuilds the registry with fresh availability
            model.availability = False
            return {
                "success": False,
                "error": str(e),
                "model": model_name,
            }

    async def initialize_load_balancer(self):
        """Initialize load-balancing state for healthy cluster nodes."""
        logger.info("Initializing load balancer...")
        for node in self.cluster_nodes:
            if node.status == "healthy":
                self.load_balancer_state[f"{node.host}:{node.port}"] = {
                    "active_requests": 0,
                    "total_requests": 0,
                    "last_request": None,
                    "average_response_time": 0.0,
                }

    async def get_cluster_status(self) -> Dict[str, Any]:
        """Get comprehensive cluster status."""
        return {
            "total_nodes": len(self.cluster_nodes),
            "healthy_nodes": len([n for n in self.cluster_nodes if n.status == "healthy"]),
            "total_models": len(self.models),
            "models_by_capability": {
                capability.value: len([
                    m for m in self.models.values()
                    if capability in m.capabilities
                ])
                for capability in ModelCapability
            },
            "cluster_load": self._calculate_cluster_load(),
            "model_usage_stats": {
                name: {
                    "usage_count": model.usage_count,
                    "avg_response_time": model.avg_response_time,
                    "last_used": model.last_used.isoformat() if model.last_used else None,
                }
                for name, model in self.models.items()
            },
        }

    def _calculate_cluster_load(self) -> float:
        """Calculate cluster load as the average active requests per healthy node."""
        if not self.load_balancer_state:
            return 0.0

        total_load = sum(
            state["active_requests"]
            for state in self.load_balancer_state.values()
        )
        healthy_nodes = len([n for n in self.cluster_nodes if n.status == "healthy"])
        if healthy_nodes == 0:
            return 0.0
        return total_load / healthy_nodes

    async def cleanup(self):
        """Close the shared HTTP session."""
        if self.session:
            await self.session.close()


# Global instance
ai_model_service = AIModelService()
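
# Minimal usage sketch (illustration only, not part of the service API):
# exercises discovery, status, model selection, and completion against the
# nodes configured above. It assumes at least one Ollama node is reachable;
# the prompt and task type are arbitrary examples.
if __name__ == "__main__":
    async def _demo():
        await ai_model_service.initialize()
        try:
            status = await ai_model_service.get_cluster_status()
            print(
                f"Healthy nodes: {status['healthy_nodes']}/{status['total_nodes']}, "
                f"models: {status['total_models']}"
            )

            model = await ai_model_service.get_best_model_for_task(
                ModelCapability.CODE_GENERATION,
                context_requirements=4096,
            )
            if model:
                result = await ai_model_service.generate_completion(
                    model.name,
                    "Write a Python function that reverses a string.",
                    max_tokens=200,
                )
                if result["success"]:
                    print(result["content"])
                else:
                    print(f"Generation failed: {result['error']}")
        finally:
            await ai_model_service.cleanup()

    asyncio.run(_demo())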