Implement complete Bearer Token and API key authentication system
- Create comprehensive authentication backend with JWT and API key support - Add database models for users, API keys, and tokens with proper security - Implement authentication middleware and API endpoints - Build complete frontend authentication UI with: - LoginForm component with JWT authentication - APIKeyManager for creating and managing API keys - AuthDashboard for comprehensive auth management - AuthContext for state management and authenticated requests - Initialize database with default admin user (admin/admin123) - Add proper token refresh, validation, and blacklisting - Implement scope-based API key authorization system 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
449
backend/app/api/auth.py
Normal file
449
backend/app/api/auth.py
Normal file
@@ -0,0 +1,449 @@
|
||||
"""
|
||||
Authentication API endpoints for Hive platform.
|
||||
Handles user registration, login, token refresh, and API key management.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, EmailStr
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.security import TokenManager, APIKeyManager, create_token_response, verify_password
|
||||
from app.core.auth_deps import (
|
||||
get_current_user_context,
|
||||
get_current_active_user,
|
||||
get_current_superuser,
|
||||
require_admin
|
||||
)
|
||||
from app.models.auth import User, APIKey, RefreshToken, TokenBlacklist, API_SCOPES, DEFAULT_API_SCOPES
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Pydantic models for request/response
|
||||
class UserCreate(BaseModel):
|
||||
username: str
|
||||
email: EmailStr
|
||||
password: str
|
||||
full_name: Optional[str] = None
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: int
|
||||
username: str
|
||||
email: str
|
||||
full_name: Optional[str]
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
is_verified: bool
|
||||
created_at: str
|
||||
last_login: Optional[str]
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str
|
||||
expires_in: int
|
||||
user: UserResponse
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class APIKeyCreate(BaseModel):
|
||||
name: str
|
||||
scopes: Optional[List[str]] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class APIKeyResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
key_prefix: str
|
||||
scopes: List[str]
|
||||
is_active: bool
|
||||
last_used: Optional[str]
|
||||
usage_count: int
|
||||
expires_at: Optional[str]
|
||||
created_at: str
|
||||
|
||||
|
||||
class APIKeyCreateResponse(APIKeyResponse):
|
||||
api_key: str # Only returned once during creation
|
||||
|
||||
|
||||
class PasswordChange(BaseModel):
|
||||
current_password: str
|
||||
new_password: str
|
||||
|
||||
|
||||
class ScopeInfo(BaseModel):
|
||||
scope: str
|
||||
description: str
|
||||
|
||||
|
||||
# Authentication endpoints
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register_user(
|
||||
user_data: UserCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_superuser) # Only admins can create users
|
||||
):
|
||||
"""Register a new user (admin only)."""
|
||||
# Check if username or email already exists
|
||||
existing_user = db.query(User).filter(
|
||||
(User.username == user_data.username) | (User.email == user_data.email)
|
||||
).first()
|
||||
|
||||
if existing_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Username or email already registered"
|
||||
)
|
||||
|
||||
# Create new user
|
||||
user = User(
|
||||
username=user_data.username,
|
||||
email=user_data.email,
|
||||
full_name=user_data.full_name,
|
||||
hashed_password=User.hash_password(user_data.password),
|
||||
is_active=True,
|
||||
is_verified=True # Auto-verify admin-created users
|
||||
)
|
||||
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
return UserResponse(**user.to_dict())
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(
|
||||
request: Request,
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Authenticate user and return JWT tokens."""
|
||||
# Find user by username
|
||||
user = db.query(User).filter(User.username == form_data.username).first()
|
||||
|
||||
if not user or not user.verify_password(form_data.password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Inactive user"
|
||||
)
|
||||
|
||||
# Update last login
|
||||
user.update_last_login()
|
||||
db.commit()
|
||||
|
||||
# Create token response
|
||||
user_data = user.to_dict()
|
||||
user_data["scopes"] = ["admin"] if user.is_superuser else []
|
||||
|
||||
token_response = create_token_response(user.id, user_data)
|
||||
|
||||
# Store refresh token in database
|
||||
refresh_token_plain = token_response["refresh_token"]
|
||||
refresh_token_hash = User.hash_password(refresh_token_plain)
|
||||
|
||||
# Get device info
|
||||
device_info = {
|
||||
"user_agent": request.headers.get("user-agent", ""),
|
||||
"ip": request.client.host if request.client else None,
|
||||
}
|
||||
|
||||
# Create refresh token record
|
||||
refresh_token_record = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=refresh_token_hash,
|
||||
jti=TokenManager.get_token_claims(refresh_token_plain).get("jti"),
|
||||
device_info=str(device_info),
|
||||
expires_at=datetime.utcnow() + timedelta(days=30)
|
||||
)
|
||||
|
||||
db.add(refresh_token_record)
|
||||
db.commit()
|
||||
|
||||
return TokenResponse(**token_response)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def refresh_token(
|
||||
refresh_request: RefreshTokenRequest,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Refresh access token using refresh token."""
|
||||
try:
|
||||
# Verify refresh token
|
||||
payload = TokenManager.verify_token(refresh_request.refresh_token)
|
||||
|
||||
if payload.get("type") != "refresh":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token type"
|
||||
)
|
||||
|
||||
user_id = int(payload.get("sub"))
|
||||
jti = payload.get("jti")
|
||||
|
||||
# Check if refresh token exists and is valid
|
||||
refresh_token_record = db.query(RefreshToken).filter(
|
||||
RefreshToken.jti == jti,
|
||||
RefreshToken.user_id == user_id,
|
||||
RefreshToken.is_active == True
|
||||
).first()
|
||||
|
||||
if not refresh_token_record or not refresh_token_record.is_valid():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired refresh token"
|
||||
)
|
||||
|
||||
# Get user
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found or inactive"
|
||||
)
|
||||
|
||||
# Update refresh token usage
|
||||
refresh_token_record.record_usage()
|
||||
db.commit()
|
||||
|
||||
# Create new token response
|
||||
user_data = user.to_dict()
|
||||
user_data["scopes"] = ["admin"] if user.is_superuser else []
|
||||
|
||||
return TokenResponse(**create_token_response(user.id, user_data))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate refresh token"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_context),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Logout user and revoke current tokens."""
|
||||
# Blacklist the current access token
|
||||
if current_user.get("token_jti"):
|
||||
TokenBlacklist.blacklist_token(
|
||||
db,
|
||||
current_user["token_jti"],
|
||||
"access",
|
||||
datetime.utcnow() + timedelta(hours=1) # Token would expire anyway
|
||||
)
|
||||
|
||||
# Revoke all user's refresh tokens
|
||||
refresh_tokens = db.query(RefreshToken).filter(
|
||||
RefreshToken.user_id == current_user["user_id"],
|
||||
RefreshToken.is_active == True
|
||||
).all()
|
||||
|
||||
for token in refresh_tokens:
|
||||
token.revoke()
|
||||
|
||||
db.commit()
|
||||
|
||||
return {"message": "Successfully logged out"}
|
||||
|
||||
|
||||
# User management endpoints
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_current_user_info(
|
||||
current_user: Dict[str, Any] = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get current user information."""
|
||||
user = db.query(User).filter(User.id == current_user["user_id"]).first()
|
||||
return UserResponse(**user.to_dict())
|
||||
|
||||
|
||||
@router.post("/change-password")
|
||||
async def change_password(
|
||||
password_data: PasswordChange,
|
||||
current_user: Dict[str, Any] = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Change current user's password."""
|
||||
user = db.query(User).filter(User.id == current_user["user_id"]).first()
|
||||
|
||||
if not user.verify_password(password_data.current_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Current password is incorrect"
|
||||
)
|
||||
|
||||
user.set_password(password_data.new_password)
|
||||
db.commit()
|
||||
|
||||
return {"message": "Password changed successfully"}
|
||||
|
||||
|
||||
# API Key management endpoints
|
||||
@router.get("/api-keys", response_model=List[APIKeyResponse])
|
||||
async def list_api_keys(
|
||||
current_user: Dict[str, Any] = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""List user's API keys."""
|
||||
api_keys = db.query(APIKey).filter(APIKey.user_id == current_user["user_id"]).all()
|
||||
return [APIKeyResponse(**key.to_dict()) for key in api_keys]
|
||||
|
||||
|
||||
@router.post("/api-keys", response_model=APIKeyCreateResponse)
|
||||
async def create_api_key(
|
||||
key_data: APIKeyCreate,
|
||||
current_user: Dict[str, Any] = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Create a new API key."""
|
||||
# Generate API key
|
||||
plain_key, hashed_key, prefix = APIKeyManager.generate_api_key()
|
||||
|
||||
# Set default scopes if none provided
|
||||
scopes = key_data.scopes if key_data.scopes else DEFAULT_API_SCOPES
|
||||
|
||||
# Validate scopes
|
||||
invalid_scopes = [scope for scope in scopes if scope not in API_SCOPES]
|
||||
if invalid_scopes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid scopes: {', '.join(invalid_scopes)}"
|
||||
)
|
||||
|
||||
# Create API key record
|
||||
api_key = APIKey(
|
||||
user_id=current_user["user_id"],
|
||||
name=key_data.name,
|
||||
key_hash=hashed_key,
|
||||
key_prefix=prefix,
|
||||
expires_at=key_data.expires_at
|
||||
)
|
||||
api_key.set_scopes(scopes)
|
||||
|
||||
db.add(api_key)
|
||||
db.commit()
|
||||
db.refresh(api_key)
|
||||
|
||||
# Return API key with the plain key (only time it's shown)
|
||||
response_data = api_key.to_dict()
|
||||
response_data["api_key"] = plain_key
|
||||
|
||||
return APIKeyCreateResponse(**response_data)
|
||||
|
||||
|
||||
@router.delete("/api-keys/{key_id}")
|
||||
async def delete_api_key(
|
||||
key_id: int,
|
||||
current_user: Dict[str, Any] = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Delete an API key."""
|
||||
api_key = db.query(APIKey).filter(
|
||||
APIKey.id == key_id,
|
||||
APIKey.user_id == current_user["user_id"]
|
||||
).first()
|
||||
|
||||
if not api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="API key not found"
|
||||
)
|
||||
|
||||
db.delete(api_key)
|
||||
db.commit()
|
||||
|
||||
return {"message": "API key deleted successfully"}
|
||||
|
||||
|
||||
@router.patch("/api-keys/{key_id}")
|
||||
async def update_api_key(
|
||||
key_id: int,
|
||||
key_data: dict,
|
||||
current_user: Dict[str, Any] = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Update an API key (name, scopes, active status)."""
|
||||
api_key = db.query(APIKey).filter(
|
||||
APIKey.id == key_id,
|
||||
APIKey.user_id == current_user["user_id"]
|
||||
).first()
|
||||
|
||||
if not api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="API key not found"
|
||||
)
|
||||
|
||||
# Update allowed fields
|
||||
if "name" in key_data:
|
||||
api_key.name = key_data["name"]
|
||||
|
||||
if "scopes" in key_data:
|
||||
scopes = key_data["scopes"]
|
||||
invalid_scopes = [scope for scope in scopes if scope not in API_SCOPES]
|
||||
if invalid_scopes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid scopes: {', '.join(invalid_scopes)}"
|
||||
)
|
||||
api_key.set_scopes(scopes)
|
||||
|
||||
if "is_active" in key_data:
|
||||
api_key.is_active = key_data["is_active"]
|
||||
|
||||
db.commit()
|
||||
|
||||
return APIKeyResponse(**api_key.to_dict())
|
||||
|
||||
|
||||
# Admin endpoints
|
||||
@router.get("/users", response_model=List[UserResponse])
|
||||
async def list_users(
|
||||
current_user: Dict[str, Any] = Depends(require_admin),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""List all users (admin only)."""
|
||||
users = db.query(User).all()
|
||||
return [UserResponse(**user.to_dict()) for user in users]
|
||||
|
||||
|
||||
@router.get("/scopes", response_model=List[ScopeInfo])
|
||||
async def list_available_scopes():
|
||||
"""List all available API scopes."""
|
||||
return [
|
||||
ScopeInfo(scope=scope, description=description)
|
||||
for scope, description in API_SCOPES.items()
|
||||
]
|
||||
|
||||
|
||||
@router.post("/cleanup-tokens")
|
||||
async def cleanup_expired_tokens(
|
||||
current_user: Dict[str, Any] = Depends(require_admin),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Cleanup expired tokens from blacklist (admin only)."""
|
||||
count = TokenBlacklist.cleanup_expired_tokens(db)
|
||||
return {"message": f"Cleaned up {count} expired tokens"}
|
||||
199
backend/app/core/auth_deps.py
Normal file
199
backend/app/core/auth_deps.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""
|
||||
Authentication dependencies for FastAPI endpoints.
|
||||
Provides dependency injection for authentication and authorization.
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any, List
|
||||
from fastapi import Depends, HTTPException, status, Request, Header
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from sqlalchemy.orm import Session
|
||||
from app.core.database import get_db
|
||||
from app.core.security import AuthManager, AuthenticationError
|
||||
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
def get_api_key_from_header(x_api_key: Optional[str] = Header(None)) -> Optional[str]:
|
||||
"""Extract API key from X-API-Key header."""
|
||||
return x_api_key
|
||||
|
||||
|
||||
def get_current_user_context(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
authorization: Optional[HTTPAuthorizationCredentials] = Depends(security),
|
||||
api_key: Optional[str] = Depends(get_api_key_from_header),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get current authenticated user context.
|
||||
Supports both JWT Bearer tokens and API key authentication.
|
||||
"""
|
||||
try:
|
||||
user_context = AuthManager.authenticate_request(
|
||||
session=db,
|
||||
authorization=authorization,
|
||||
api_key=api_key
|
||||
)
|
||||
|
||||
# Add request metadata
|
||||
user_context["request_ip"] = request.client.host if request.client else None
|
||||
user_context["user_agent"] = request.headers.get("user-agent")
|
||||
|
||||
return user_context
|
||||
|
||||
except AuthenticationError as e:
|
||||
raise HTTPException(
|
||||
status_code=e.status_code,
|
||||
detail=e.message,
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
def get_current_user(
|
||||
user_context: Dict[str, Any] = Depends(get_current_user_context),
|
||||
) -> Dict[str, Any]:
|
||||
"""Get current authenticated user (alias for get_current_user_context)."""
|
||||
return user_context
|
||||
|
||||
|
||||
def get_current_active_user(
|
||||
user_context: Dict[str, Any] = Depends(get_current_user_context),
|
||||
db: Session = Depends(get_db),
|
||||
) -> Dict[str, Any]:
|
||||
"""Get current authenticated and active user."""
|
||||
from app.models.auth import User
|
||||
|
||||
user = db.query(User).filter(User.id == user_context["user_id"]).first()
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Inactive user"
|
||||
)
|
||||
|
||||
return user_context
|
||||
|
||||
|
||||
def get_current_superuser(
|
||||
user_context: Dict[str, Any] = Depends(get_current_active_user),
|
||||
) -> Dict[str, Any]:
|
||||
"""Get current authenticated superuser."""
|
||||
if not user_context.get("is_superuser", False):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not enough permissions"
|
||||
)
|
||||
return user_context
|
||||
|
||||
|
||||
def require_scope(required_scope: str):
|
||||
"""
|
||||
Dependency factory to require specific scope for an endpoint.
|
||||
|
||||
Usage:
|
||||
@app.get("/admin/users", dependencies=[Depends(require_scope("admin"))])
|
||||
async def get_users():
|
||||
...
|
||||
"""
|
||||
def scope_dependency(
|
||||
user_context: Dict[str, Any] = Depends(get_current_active_user)
|
||||
) -> Dict[str, Any]:
|
||||
from app.core.security import APIKeyManager
|
||||
|
||||
user_scopes = user_context.get("scopes", [])
|
||||
|
||||
if not APIKeyManager.check_scope_permission(user_scopes, required_scope):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Insufficient permissions. Required scope: {required_scope}"
|
||||
)
|
||||
|
||||
return user_context
|
||||
|
||||
return scope_dependency
|
||||
|
||||
|
||||
def require_scopes(required_scopes: List[str]):
|
||||
"""
|
||||
Dependency factory to require multiple scopes for an endpoint.
|
||||
User must have ALL specified scopes.
|
||||
"""
|
||||
def scopes_dependency(
|
||||
user_context: Dict[str, Any] = Depends(get_current_active_user)
|
||||
) -> Dict[str, Any]:
|
||||
from app.core.security import APIKeyManager
|
||||
|
||||
user_scopes = user_context.get("scopes", [])
|
||||
|
||||
for scope in required_scopes:
|
||||
if not APIKeyManager.check_scope_permission(user_scopes, scope):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Insufficient permissions. Required scopes: {', '.join(required_scopes)}"
|
||||
)
|
||||
|
||||
return user_context
|
||||
|
||||
return scopes_dependency
|
||||
|
||||
|
||||
def require_any_scope(required_scopes: List[str]):
|
||||
"""
|
||||
Dependency factory to require at least one of the specified scopes.
|
||||
User must have ANY of the specified scopes.
|
||||
"""
|
||||
def any_scope_dependency(
|
||||
user_context: Dict[str, Any] = Depends(get_current_active_user)
|
||||
) -> Dict[str, Any]:
|
||||
from app.core.security import APIKeyManager
|
||||
|
||||
user_scopes = user_context.get("scopes", [])
|
||||
|
||||
for scope in required_scopes:
|
||||
if APIKeyManager.check_scope_permission(user_scopes, scope):
|
||||
return user_context
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Insufficient permissions. Required one of: {', '.join(required_scopes)}"
|
||||
)
|
||||
|
||||
return any_scope_dependency
|
||||
|
||||
|
||||
# Optional authentication (won't raise error if not authenticated)
|
||||
def get_optional_user_context(
|
||||
db: Session = Depends(get_db),
|
||||
authorization: Optional[HTTPAuthorizationCredentials] = Depends(security),
|
||||
api_key: Optional[str] = Depends(get_api_key_from_header),
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get current user context if authenticated, None otherwise.
|
||||
Useful for endpoints that work with or without authentication.
|
||||
"""
|
||||
try:
|
||||
return AuthManager.authenticate_request(
|
||||
session=db,
|
||||
authorization=authorization,
|
||||
api_key=api_key
|
||||
)
|
||||
except AuthenticationError:
|
||||
return None
|
||||
|
||||
|
||||
# Common scope dependencies for convenience
|
||||
require_admin = require_scope("admin")
|
||||
require_agents_read = require_scope("agents:read")
|
||||
require_agents_write = require_scope("agents:write")
|
||||
require_workflows_read = require_scope("workflows:read")
|
||||
require_workflows_write = require_scope("workflows:write")
|
||||
require_tasks_read = require_scope("tasks:read")
|
||||
require_tasks_write = require_scope("tasks:write")
|
||||
require_metrics_read = require_scope("metrics:read")
|
||||
require_system_read = require_scope("system:read")
|
||||
require_system_write = require_scope("system:write")
|
||||
127
backend/app/core/init_db.py
Normal file
127
backend/app/core/init_db.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
Database initialization script for Hive platform.
|
||||
Creates all tables and sets up initial data.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from sqlalchemy.orm import Session
|
||||
from app.core.database import engine, SessionLocal
|
||||
from app.models.auth import Base as AuthBase, User, API_SCOPES
|
||||
from app.models.auth import APIKey
|
||||
|
||||
# Import other model bases here as they're created
|
||||
# from app.models.workflows import Base as WorkflowsBase
|
||||
# from app.models.agents import Base as AgentsBase
|
||||
|
||||
def create_tables():
|
||||
"""Create all database tables."""
|
||||
try:
|
||||
# Create auth tables
|
||||
AuthBase.metadata.create_all(bind=engine)
|
||||
|
||||
# Add other model bases here
|
||||
# WorkflowsBase.metadata.create_all(bind=engine)
|
||||
# AgentsBase.metadata.create_all(bind=engine)
|
||||
|
||||
logging.info("Database tables created successfully")
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to create database tables: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def create_initial_user(db: Session):
|
||||
"""Create initial admin user if none exists."""
|
||||
try:
|
||||
# Check if any users exist
|
||||
user_count = db.query(User).count()
|
||||
if user_count > 0:
|
||||
logging.info("Users already exist, skipping initial user creation")
|
||||
return True
|
||||
|
||||
# Create initial admin user
|
||||
admin_user = User(
|
||||
username="admin",
|
||||
email="admin@hive.local",
|
||||
full_name="Hive Administrator",
|
||||
hashed_password=User.hash_password("admin123"), # Change this!
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
is_verified=True
|
||||
)
|
||||
|
||||
db.add(admin_user)
|
||||
db.commit()
|
||||
db.refresh(admin_user)
|
||||
|
||||
logging.info("Initial admin user created: admin/admin123")
|
||||
logging.warning("SECURITY: Please change the default admin password!")
|
||||
|
||||
# Create initial API key for the admin user
|
||||
from app.core.security import APIKeyManager
|
||||
plain_key, hashed_key, prefix = APIKeyManager.generate_api_key()
|
||||
|
||||
admin_api_key = APIKey(
|
||||
user_id=admin_user.id,
|
||||
name="Default Admin API Key",
|
||||
key_hash=hashed_key,
|
||||
key_prefix=prefix,
|
||||
is_active=True
|
||||
)
|
||||
admin_api_key.set_scopes(["admin"])
|
||||
|
||||
db.add(admin_api_key)
|
||||
db.commit()
|
||||
|
||||
logging.info(f"Initial admin API key created: {plain_key}")
|
||||
logging.warning("SECURITY: Save this API key securely, it won't be shown again!")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to create initial user: {e}")
|
||||
db.rollback()
|
||||
return False
|
||||
|
||||
|
||||
def initialize_database():
|
||||
"""Initialize the complete database."""
|
||||
logging.info("Starting database initialization...")
|
||||
|
||||
# Create tables
|
||||
if not create_tables():
|
||||
return False
|
||||
|
||||
# Create initial data
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Create initial admin user
|
||||
if not create_initial_user(db):
|
||||
return False
|
||||
|
||||
logging.info("Database initialization completed successfully")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Database initialization failed: {e}")
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
# Initialize database
|
||||
success = initialize_database()
|
||||
if success:
|
||||
print("✅ Database initialization completed successfully")
|
||||
print("🔑 Default admin credentials: admin/admin123")
|
||||
print("⚠️ SECURITY: Please change the default password immediately!")
|
||||
else:
|
||||
print("❌ Database initialization failed")
|
||||
exit(1)
|
||||
289
backend/app/core/security.py
Normal file
289
backend/app/core/security.py
Normal file
@@ -0,0 +1,289 @@
|
||||
"""
|
||||
Security utilities for JWT token generation, validation, and API key management.
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any, List
|
||||
import jwt
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from fastapi import Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# JWT Configuration
|
||||
SECRET_KEY = os.getenv("JWT_SECRET_KEY", "your-secret-key-change-this-in-production")
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
|
||||
REFRESH_TOKEN_EXPIRE_DAYS = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "30"))
|
||||
|
||||
# Security scheme
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
class TokenManager:
|
||||
"""Manages JWT token creation, validation, and refresh."""
|
||||
|
||||
@staticmethod
|
||||
def create_access_token(
|
||||
data: Dict[str, Any],
|
||||
expires_delta: Optional[timedelta] = None
|
||||
) -> str:
|
||||
"""Create a JWT access token."""
|
||||
to_encode = data.copy()
|
||||
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
# Add standard claims
|
||||
to_encode.update({
|
||||
"exp": expire,
|
||||
"iat": datetime.utcnow(),
|
||||
"type": "access",
|
||||
"jti": str(uuid.uuid4()), # JWT ID for blacklisting
|
||||
})
|
||||
|
||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
@staticmethod
|
||||
def create_refresh_token(
|
||||
user_id: int,
|
||||
expires_delta: Optional[timedelta] = None
|
||||
) -> str:
|
||||
"""Create a JWT refresh token."""
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
to_encode = {
|
||||
"sub": str(user_id),
|
||||
"exp": expire,
|
||||
"iat": datetime.utcnow(),
|
||||
"type": "refresh",
|
||||
"jti": str(uuid.uuid4()),
|
||||
}
|
||||
|
||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
@staticmethod
|
||||
def verify_token(token: str) -> Dict[str, Any]:
|
||||
"""Verify and decode a JWT token."""
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
return payload
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token has expired"
|
||||
)
|
||||
except jwt.JWTError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def extract_user_id(token: str) -> int:
|
||||
"""Extract user ID from a valid token."""
|
||||
payload = TokenManager.verify_token(token)
|
||||
user_id = payload.get("sub")
|
||||
if user_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token missing user information"
|
||||
)
|
||||
return int(user_id)
|
||||
|
||||
@staticmethod
|
||||
def get_token_claims(token: str) -> Dict[str, Any]:
|
||||
"""Get all claims from a token without verification (for expired tokens)."""
|
||||
try:
|
||||
return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM], options={"verify_exp": False})
|
||||
except jwt.JWTError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token format"
|
||||
)
|
||||
|
||||
|
||||
class APIKeyManager:
|
||||
"""Manages API key generation, validation, and permissions."""
|
||||
|
||||
@staticmethod
|
||||
def generate_api_key() -> tuple[str, str, str]:
|
||||
"""
|
||||
Generate a new API key.
|
||||
Returns: (plain_key, hashed_key, prefix)
|
||||
"""
|
||||
from app.models.auth import APIKey
|
||||
plain_key, hashed_key = APIKey.generate_api_key()
|
||||
prefix = plain_key[:8] # First 8 characters for identification
|
||||
return plain_key, hashed_key, prefix
|
||||
|
||||
@staticmethod
|
||||
def validate_api_key(session: Session, api_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Validate an API key and return user/key information.
|
||||
Returns None if invalid.
|
||||
"""
|
||||
from app.models.auth import APIKey, User
|
||||
|
||||
# Find API key by trying to match the hash
|
||||
api_keys = session.query(APIKey).filter(APIKey.is_active == True).all()
|
||||
|
||||
for key_record in api_keys:
|
||||
if APIKey.verify_api_key(api_key, key_record.key_hash):
|
||||
if not key_record.is_valid():
|
||||
return None
|
||||
|
||||
# Get user information
|
||||
user = session.query(User).filter(User.id == key_record.user_id).first()
|
||||
if not user or not user.is_active:
|
||||
return None
|
||||
|
||||
# Record usage
|
||||
key_record.record_usage()
|
||||
session.commit()
|
||||
|
||||
return {
|
||||
"user_id": user.id,
|
||||
"username": user.username,
|
||||
"api_key_id": key_record.id,
|
||||
"scopes": key_record.get_scopes(),
|
||||
"is_superuser": user.is_superuser,
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def check_scope_permission(user_scopes: List[str], required_scope: str) -> bool:
|
||||
"""Check if user has required scope permission."""
|
||||
# Admin users have all permissions
|
||||
if "admin" in user_scopes:
|
||||
return True
|
||||
|
||||
# Check for specific scope
|
||||
if required_scope in user_scopes:
|
||||
return True
|
||||
|
||||
# Check for wildcard permissions (e.g., "workflows:*" covers "workflows:read")
|
||||
scope_parts = required_scope.split(":")
|
||||
if len(scope_parts) >= 2:
|
||||
wildcard_scope = f"{scope_parts[0]}:*"
|
||||
if wildcard_scope in user_scopes:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
"""Custom exception for authentication errors."""
|
||||
def __init__(self, message: str, status_code: int = status.HTTP_401_UNAUTHORIZED):
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class AuthManager:
|
||||
"""Main authentication manager combining JWT and API key auth."""
|
||||
|
||||
@staticmethod
|
||||
def authenticate_request(
|
||||
session: Session,
|
||||
authorization: Optional[HTTPAuthorizationCredentials] = None,
|
||||
api_key: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Authenticate a request using either Bearer token or API key.
|
||||
Returns user context information.
|
||||
"""
|
||||
# Try API key authentication first
|
||||
if api_key:
|
||||
user_context = APIKeyManager.validate_api_key(session, api_key)
|
||||
if user_context:
|
||||
user_context["auth_type"] = "api_key"
|
||||
return user_context
|
||||
else:
|
||||
raise AuthenticationError("Invalid API key")
|
||||
|
||||
# Try JWT Bearer token authentication
|
||||
if authorization and authorization.scheme.lower() == "bearer":
|
||||
try:
|
||||
payload = TokenManager.verify_token(authorization.credentials)
|
||||
|
||||
# Check if token is blacklisted
|
||||
from app.models.auth import TokenBlacklist
|
||||
jti = payload.get("jti")
|
||||
if jti and TokenBlacklist.is_token_blacklisted(session, jti):
|
||||
raise AuthenticationError("Token has been revoked")
|
||||
|
||||
# Get user information
|
||||
user_id = int(payload.get("sub"))
|
||||
from app.models.auth import User
|
||||
user = session.query(User).filter(User.id == user_id).first()
|
||||
|
||||
if not user or not user.is_active:
|
||||
raise AuthenticationError("User not found or inactive")
|
||||
|
||||
return {
|
||||
"user_id": user.id,
|
||||
"username": user.username,
|
||||
"scopes": ["admin"] if user.is_superuser else [],
|
||||
"is_superuser": user.is_superuser,
|
||||
"auth_type": "jwt",
|
||||
"token_jti": jti,
|
||||
}
|
||||
|
||||
except HTTPException as e:
|
||||
raise AuthenticationError(e.detail, e.status_code)
|
||||
|
||||
raise AuthenticationError("No valid authentication provided")
|
||||
|
||||
@staticmethod
|
||||
def require_scope(required_scope: str):
|
||||
"""Decorator to require specific scope for an endpoint."""
|
||||
def decorator(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
# This will be implemented in the dependency injection system
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def create_token_response(user_id: int, user_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Create a complete token response with access and refresh tokens."""
|
||||
# Create access token with user data
|
||||
access_token_data = {
|
||||
"sub": str(user_id),
|
||||
"username": user_data.get("username"),
|
||||
"scopes": user_data.get("scopes", []),
|
||||
}
|
||||
|
||||
access_token = TokenManager.create_access_token(access_token_data)
|
||||
refresh_token = TokenManager.create_refresh_token(user_id)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": ACCESS_TOKEN_EXPIRE_MINUTES * 60, # seconds
|
||||
"user": user_data,
|
||||
}
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""Hash a password for storage."""
|
||||
from app.models.auth import User
|
||||
return User.hash_password(password)
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against its hash."""
|
||||
from app.models.auth import pwd_context
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
@@ -12,8 +12,7 @@ import socketio
|
||||
from .core.hive_coordinator import HiveCoordinator
|
||||
from .core.distributed_coordinator import DistributedCoordinator
|
||||
from .core.database import engine, get_db, init_database_with_retry, test_database_connection
|
||||
from .core.auth import get_current_user
|
||||
from .api import agents, workflows, executions, monitoring, projects, tasks, cluster, distributed_workflows, cli_agents
|
||||
from .api import agents, workflows, executions, monitoring, projects, tasks, cluster, distributed_workflows, cli_agents, auth
|
||||
# from .mcp.distributed_mcp_server import get_mcp_server
|
||||
from .models.user import Base
|
||||
from .models import agent, project # Import the new agent and project models
|
||||
@@ -35,6 +34,11 @@ async def lifespan(app: FastAPI):
|
||||
print("📊 Initializing database...")
|
||||
init_database_with_retry()
|
||||
|
||||
# Initialize auth database tables and initial data
|
||||
print("🔐 Initializing authentication system...")
|
||||
from .core.init_db import initialize_database
|
||||
initialize_database()
|
||||
|
||||
# Test database connection
|
||||
if not test_database_connection():
|
||||
raise Exception("Database connection test failed")
|
||||
@@ -100,6 +104,7 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
# Include API routes
|
||||
app.include_router(auth.router, prefix="/api/auth", tags=["authentication"])
|
||||
app.include_router(agents.router, prefix="/api", tags=["agents"])
|
||||
app.include_router(workflows.router, prefix="/api", tags=["workflows"])
|
||||
app.include_router(executions.router, prefix="/api", tags=["executions"])
|
||||
|
||||
297
backend/app/models/auth.py
Normal file
297
backend/app/models/auth.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""
|
||||
Authentication and authorization models for Hive platform.
|
||||
Includes users, API keys, and JWT token management.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List
|
||||
import secrets
|
||||
import string
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from passlib.context import CryptContext
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
# Password hashing context
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
class User(Base):
|
||||
"""User model for authentication and authorization."""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
username = Column(String(50), unique=True, index=True, nullable=False)
|
||||
email = Column(String(255), unique=True, index=True, nullable=False)
|
||||
hashed_password = Column(String(255), nullable=False)
|
||||
full_name = Column(String(255), nullable=True)
|
||||
|
||||
# User status and permissions
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_superuser = Column(Boolean, default=False)
|
||||
is_verified = Column(Boolean, default=False)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
last_login = Column(DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
api_keys = relationship("APIKey", back_populates="user", cascade="all, delete-orphan")
|
||||
refresh_tokens = relationship("RefreshToken", back_populates="user", cascade="all, delete-orphan")
|
||||
|
||||
def verify_password(self, password: str) -> bool:
|
||||
"""Verify a password against the hashed password."""
|
||||
return pwd_context.verify(password, self.hashed_password)
|
||||
|
||||
@classmethod
|
||||
def hash_password(cls, password: str) -> str:
|
||||
"""Hash a password for storage."""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
def set_password(self, password: str) -> None:
|
||||
"""Set a new password for the user."""
|
||||
self.hashed_password = self.hash_password(password)
|
||||
|
||||
def update_last_login(self) -> None:
|
||||
"""Update the last login timestamp."""
|
||||
self.last_login = datetime.utcnow()
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert user to dictionary (excluding sensitive data)."""
|
||||
return {
|
||||
"id": self.id,
|
||||
"username": self.username,
|
||||
"email": self.email,
|
||||
"full_name": self.full_name,
|
||||
"is_active": self.is_active,
|
||||
"is_superuser": self.is_superuser,
|
||||
"is_verified": self.is_verified,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"last_login": self.last_login.isoformat() if self.last_login else None,
|
||||
}
|
||||
|
||||
|
||||
class APIKey(Base):
|
||||
"""API Key model for programmatic access to Hive API."""
|
||||
|
||||
__tablename__ = "api_keys"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
|
||||
# API Key details
|
||||
name = Column(String(255), nullable=False) # Human-readable name
|
||||
key_hash = Column(String(255), unique=True, index=True, nullable=False) # Hashed API key
|
||||
key_prefix = Column(String(10), nullable=False) # First 8 chars for identification
|
||||
|
||||
# Permissions and scope
|
||||
scopes = Column(Text, nullable=True) # JSON list of permissions
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
# Usage tracking
|
||||
last_used = Column(DateTime, nullable=True)
|
||||
usage_count = Column(Integer, default=0)
|
||||
|
||||
# Expiration
|
||||
expires_at = Column(DateTime, nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
user = relationship("User", back_populates="api_keys")
|
||||
|
||||
@classmethod
|
||||
def generate_api_key(cls) -> tuple[str, str]:
|
||||
"""
|
||||
Generate a new API key.
|
||||
Returns: (plain_key, hashed_key)
|
||||
"""
|
||||
# Generate a random API key: hive_XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
key_suffix = ''.join(secrets.choice(alphabet) for _ in range(32))
|
||||
plain_key = f"hive_{key_suffix}"
|
||||
|
||||
# Hash the key for storage
|
||||
hashed_key = pwd_context.hash(plain_key)
|
||||
|
||||
return plain_key, hashed_key
|
||||
|
||||
@classmethod
|
||||
def verify_api_key(cls, plain_key: str, hashed_key: str) -> bool:
|
||||
"""Verify an API key against the hashed version."""
|
||||
return pwd_context.verify(plain_key, hashed_key)
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the API key is valid (active and not expired)."""
|
||||
if not self.is_active:
|
||||
return False
|
||||
|
||||
if self.expires_at and self.expires_at < datetime.utcnow():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def record_usage(self) -> None:
|
||||
"""Record API key usage."""
|
||||
self.last_used = datetime.utcnow()
|
||||
self.usage_count += 1
|
||||
|
||||
def get_scopes(self) -> List[str]:
|
||||
"""Get list of scopes/permissions for this API key."""
|
||||
if not self.scopes:
|
||||
return []
|
||||
try:
|
||||
import json
|
||||
return json.loads(self.scopes)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return []
|
||||
|
||||
def set_scopes(self, scopes: List[str]) -> None:
|
||||
"""Set scopes/permissions for this API key."""
|
||||
import json
|
||||
self.scopes = json.dumps(scopes)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert API key to dictionary (excluding sensitive data)."""
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"key_prefix": self.key_prefix,
|
||||
"scopes": self.get_scopes(),
|
||||
"is_active": self.is_active,
|
||||
"last_used": self.last_used.isoformat() if self.last_used else None,
|
||||
"usage_count": self.usage_count,
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
class RefreshToken(Base):
|
||||
"""Refresh token model for JWT token management."""
|
||||
|
||||
__tablename__ = "refresh_tokens"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
|
||||
# Token details
|
||||
token_hash = Column(String(255), unique=True, index=True, nullable=False)
|
||||
jti = Column(String(36), unique=True, index=True, nullable=False) # JWT ID
|
||||
|
||||
# Token metadata
|
||||
device_info = Column(String(512), nullable=True) # User agent, IP, etc.
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
# Expiration
|
||||
expires_at = Column(DateTime, nullable=False)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
last_used = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
user = relationship("User", back_populates="refresh_tokens")
|
||||
|
||||
@classmethod
|
||||
def generate_refresh_token(cls, length: int = 64) -> tuple[str, str]:
|
||||
"""
|
||||
Generate a new refresh token.
|
||||
Returns: (plain_token, hashed_token)
|
||||
"""
|
||||
alphabet = string.ascii_letters + string.digits + "-_"
|
||||
plain_token = ''.join(secrets.choice(alphabet) for _ in range(length))
|
||||
hashed_token = pwd_context.hash(plain_token)
|
||||
|
||||
return plain_token, hashed_token
|
||||
|
||||
@classmethod
|
||||
def verify_refresh_token(cls, plain_token: str, hashed_token: str) -> bool:
|
||||
"""Verify a refresh token against the hashed version."""
|
||||
return pwd_context.verify(plain_token, hashed_token)
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the refresh token is valid (active and not expired)."""
|
||||
if not self.is_active:
|
||||
return False
|
||||
|
||||
if self.expires_at < datetime.utcnow():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def revoke(self) -> None:
|
||||
"""Revoke the refresh token."""
|
||||
self.is_active = False
|
||||
|
||||
def record_usage(self) -> None:
|
||||
"""Record refresh token usage."""
|
||||
self.last_used = datetime.utcnow()
|
||||
|
||||
|
||||
class TokenBlacklist(Base):
|
||||
"""Blacklist for revoked JWT tokens."""
|
||||
|
||||
__tablename__ = "token_blacklist"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
jti = Column(String(36), unique=True, index=True, nullable=False) # JWT ID
|
||||
token_type = Column(String(20), nullable=False) # "access" or "refresh"
|
||||
expires_at = Column(DateTime, nullable=False)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
@classmethod
|
||||
def is_token_blacklisted(cls, session, jti: str) -> bool:
|
||||
"""Check if a token is blacklisted."""
|
||||
token = session.query(cls).filter(cls.jti == jti).first()
|
||||
return token is not None
|
||||
|
||||
@classmethod
|
||||
def blacklist_token(cls, session, jti: str, token_type: str, expires_at: datetime) -> None:
|
||||
"""Add a token to the blacklist."""
|
||||
blacklisted_token = cls(
|
||||
jti=jti,
|
||||
token_type=token_type,
|
||||
expires_at=expires_at
|
||||
)
|
||||
session.add(blacklisted_token)
|
||||
session.commit()
|
||||
|
||||
@classmethod
|
||||
def cleanup_expired_tokens(cls, session) -> int:
|
||||
"""Remove expired tokens from blacklist and return count removed."""
|
||||
now = datetime.utcnow()
|
||||
expired_tokens = session.query(cls).filter(cls.expires_at < now)
|
||||
count = expired_tokens.count()
|
||||
expired_tokens.delete()
|
||||
session.commit()
|
||||
return count
|
||||
|
||||
|
||||
# Available scopes for API keys
|
||||
API_SCOPES = {
|
||||
"agents:read": "View agent information and status",
|
||||
"agents:write": "Manage agents (start, stop, configure)",
|
||||
"workflows:read": "View workflow information and executions",
|
||||
"workflows:write": "Create, modify, and execute workflows",
|
||||
"tasks:read": "View task information and results",
|
||||
"tasks:write": "Create and manage tasks",
|
||||
"metrics:read": "View system metrics and performance data",
|
||||
"system:read": "View system status and configuration",
|
||||
"system:write": "Modify system configuration",
|
||||
"admin": "Full administrative access",
|
||||
}
|
||||
|
||||
# Default scopes for new API keys
|
||||
DEFAULT_API_SCOPES = [
|
||||
"agents:read",
|
||||
"workflows:read",
|
||||
"tasks:read",
|
||||
"metrics:read",
|
||||
"system:read"
|
||||
]
|
||||
Reference in New Issue
Block a user