250 lines
8.0 KiB
Python
250 lines
8.0 KiB
Python
import os
import secrets
import warnings
from datetime import datetime, timedelta, timezone
from typing import List, Optional

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.orm import Session

from database import get_db
from models import Permission, Role, RolePermission, User, UserRole
|
|
|
|
# bcrypt-based hashing context shared by verify_password / get_password_hash.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])

# Bearer-token extractor used as the credentials dependency of get_current_user.
security = HTTPBearer()
|
|
|
|
# JWT settings, read once from the environment at import time.
# Placeholder signing key, used only when JWT_SECRET is not configured.
_INSECURE_DEFAULT_JWT_SECRET = 'your-secret-key'

JWT_SECRET = os.environ.get('JWT_SECRET', _INSECURE_DEFAULT_JWT_SECRET)
if JWT_SECRET in ('', _INSECURE_DEFAULT_JWT_SECRET):
    # The fallback key is public (it is in the source), so anyone could sign
    # tokens that decode_token accepts. Keep the fallback for local
    # development, but make the misconfiguration visible.
    warnings.warn(
        "JWT_SECRET is unset, empty or the placeholder default; tokens are "
        "signed with an insecure key. Set JWT_SECRET before deploying.",
        RuntimeWarning,
    )

JWT_ALGORITHM = os.environ.get('JWT_ALGORITHM', 'HS256')

# Default access-token lifetime in minutes; create_access_token uses it when
# no explicit expires_delta is passed.
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get('ACCESS_TOKEN_EXPIRE_MINUTES', 30))
|
|
|
|
def verify_password(plain_password, hashed_password):
    """Check a plaintext password against a stored hash.

    Args:
        plain_password: Password as supplied by the user.
        hashed_password: Stored hash (presumably produced by
            ``get_password_hash``).

    Returns:
        Whatever ``pwd_context.verify`` returns (truthy on a match).
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
|
|
|
|
def get_password_hash(password):
    """Hash a plaintext password for storage.

    Args:
        password: Plaintext password.

    Returns:
        The hash produced by ``pwd_context.hash``.
    """
    hashed = pwd_context.hash(password)
    return hashed
|
|
|
|
def generate_reset_token():
    """Return a fresh, URL-safe random token for password resets.

    Built from 32 bytes of ``secrets`` randomness, Base64url-encoded
    without padding (43 characters).
    """
    num_random_bytes = 32
    return secrets.token_urlsafe(num_random_bytes)
|
|
|
|
def create_password_reset_token(user, db):
    """Issue a password-reset token for ``user`` that is valid for one hour.

    The token and its timezone-aware UTC expiry are written onto ``user``
    and committed immediately.

    Args:
        user: User instance to update.
        db: Database session used to commit the change.

    Returns:
        The newly generated token string.
    """
    reset_token = generate_reset_token()
    valid_until = datetime.now(timezone.utc) + timedelta(hours=1)

    user.password_reset_token, user.password_reset_expires = reset_token, valid_until
    db.commit()
    return reset_token
|
|
|
|
def verify_reset_token(token, db):
    """Return the user owning a valid, unexpired reset token, else ``None``.

    Args:
        token: Reset token supplied by the client.
        db: Database session used to look the token up.

    Returns:
        The matching ``User`` if the token exists and its expiry is still in
        the future; ``None`` if the token is empty, unknown, has no expiry,
        or has expired.
    """
    # Never look up an empty/None token. In SQLAlchemy, ``column == None``
    # compiles to ``IS NULL`` and would match every user with no pending reset.
    if not token:
        return None

    user = db.query(User).filter(User.password_reset_token == token).first()
    if user is None:
        return None

    expires = user.password_reset_expires
    if expires is None:
        # A token without an expiry is treated as invalid rather than crashing.
        return None
    if expires.tzinfo is None:
        # NOTE(review): some backends (e.g. SQLite, or DateTime columns without
        # timezone=True) hand back naive datetimes. create_password_reset_token
        # stores UTC, so naive values are assumed to be UTC. Confirm for the
        # deployed database.
        expires = expires.replace(tzinfo=timezone.utc)

    if expires < datetime.now(timezone.utc):
        return None  # Token expired

    return user
|
|
|
|
def get_user_role_code(user: User) -> str:
    """Return the user's role code, preferring the dynamic role system.

    Exists for backward compatibility during the role migration (Phase 3).
    When the user has a ``role_id`` and the related ``role_obj`` is loaded,
    its ``code`` is used. Otherwise the code comes from the legacy ``role``
    enum's value.

    Args:
        user: User object.

    Returns:
        Role code string (e.g. "superadmin", "admin", "member", "guest").
    """
    # role_obj is only consulted when role_id is set (dynamic roles, Phase 3+).
    dynamic_role = user.role_obj if user.role_id is not None else None
    if dynamic_role is not None:
        return dynamic_role.code
    # Legacy enum (Phase 1-2).
    return user.role.value
|
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Create a signed JWT carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed. ``get_current_user`` expects a ``"sub"`` claim
            identifying the user. The dict is copied, so the caller's object
            is not mutated.
        expires_delta: Token lifetime. If ``None``, the lifetime is
            ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes. Any explicit timedelta,
            including ``timedelta(0)``, is used as given.

    Returns:
        The token produced by ``jwt.encode`` with ``JWT_SECRET`` and
        ``JWT_ALGORITHM``.
    """
    # ``is None`` rather than truthiness: timedelta(0) is falsy and must not
    # silently fall back to the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)

    to_encode = data.copy()
    # An "exp" key already present in ``data`` is overwritten.
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM)
|
|
|
|
def decode_token(token: str):
    """Decode and verify a JWT using the module's JWT settings.

    Only ``JWT_ALGORITHM`` is accepted, and the signature is checked
    against ``JWT_SECRET``.

    Args:
        token: Encoded JWT string.

    Returns:
        The decoded payload, or ``None`` if ``jwt.decode`` raises ``JWTError``.
    """
    try:
        claims = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except JWTError:
        claims = None
    return claims
|
|
|
|
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: Session = Depends(get_db)
) -> User:
    """FastAPI dependency resolving the authenticated ``User`` from a Bearer token.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            token fails to decode, carries no ``sub`` claim, or ``sub`` does
            not match any user.
    """
    def unauthorized(detail: str) -> HTTPException:
        # Every 401 from this dependency carries the same auth challenge header.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    payload = decode_token(credentials.credentials)
    if payload is None:
        raise unauthorized("Invalid authentication credentials")

    user_id: str = payload.get("sub")
    if user_id is None:
        raise unauthorized("Invalid authentication credentials")

    user = db.query(User).filter(User.id == user_id).first()
    if user is None:
        raise unauthorized("User not found")
    return user
|
|
|
|
async def get_current_admin_user(current_user: User = Depends(get_current_user)) -> User:
    """Require the current user to be an admin or superadmin.

    Returns:
        ``current_user`` when their role code is "admin" or "superadmin".

    Raises:
        HTTPException: 403 for any other role.
    """
    admin_roles = ("admin", "superadmin")
    if get_user_role_code(current_user) in admin_roles:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Not enough permissions"
    )
|
|
|
|
async def get_active_member(current_user: User = Depends(get_current_user)) -> User:
    """Require an active account with a member-or-higher role.

    The status check runs first, so an inactive user always gets the
    payment message regardless of role.

    Returns:
        ``current_user`` when ``status`` is ``UserStatus.active`` and the role
        code is "member", "admin" or "superadmin".

    Raises:
        HTTPException: 403 if the account is not active, or if the role is
            not one of the allowed ones.
    """
    from models import UserStatus

    if current_user.status != UserStatus.active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Active membership required. Please complete payment."
        )

    member_roles = ("member", "admin", "superadmin")
    if get_user_role_code(current_user) in member_roles:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Member access only"
    )
|
|
|
|
|
|
# ============================================================
|
|
# RBAC Permission System
|
|
# ============================================================
|
|
|
|
async def get_user_permissions(user: User, db: Session) -> List[str]:
    """Return the permission codes granted to ``user``.

    - Superadmins receive every code in the ``Permission`` table.
    - Everyone else receives the codes linked to their role via
      ``RolePermission``. The match is on ``role_id`` when the user has one
      (dynamic roles), otherwise on the legacy ``role`` enum.

    The result is memoised on the user object as ``_permission_cache``, so
    repeated calls with the same ``User`` instance skip the database.
    NOTE(review): this is request-scoped only as long as ``User`` instances
    are loaded per request (as ``get_current_user`` does). Confirm no caller
    reuses an instance across requests, or role changes will not be seen.

    Args:
        user: Current authenticated user.
        db: Database session.

    Returns:
        List of permission code strings (e.g. ["users.view", "events.create"]).
    """
    try:
        return user._permission_cache
    except AttributeError:
        pass  # Not computed yet for this user instance.

    if get_user_role_code(user) == "superadmin":
        rows = db.query(Permission.code).all()
    else:
        # Prefer the dynamic role_id; fall back to the legacy enum column.
        if user.role_id is not None:
            role_match = RolePermission.role_id == user.role_id
        else:
            role_match = RolePermission.role == user.role
        rows = (
            db.query(Permission.code)
            .join(RolePermission)
            .filter(role_match)
            .all()
        )

    codes = [row[0] for row in rows]
    user._permission_cache = codes
    return codes
|
|
|
|
|
|
def require_permission(permission_code: str):
    """Build a FastAPI dependency that enforces a single permission.

    Usage:
        @app.get("/admin/users", dependencies=[Depends(require_permission("users.view"))])
        async def get_users():
            ...

    Args:
        permission_code: Permission code to require (e.g. "users.create").

    Returns:
        An async dependency. It resolves the current user, loads that user's
        permissions via ``get_user_permissions``, and returns the user if
        ``permission_code`` is among them.

    Raises:
        HTTPException: 403, raised by the returned dependency, when the user
            lacks ``permission_code``.
    """
    async def permission_checker(
        current_user: User = Depends(get_current_user),
        db: Session = Depends(get_db)
    ) -> User:
        granted = await get_user_permissions(current_user, db)
        if permission_code in granted:
            return current_user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Permission required: {permission_code}"
        )

    return permission_checker
|
|
|
|
|
|
async def get_current_superadmin(current_user: User = Depends(get_current_user)) -> User:
    """Require the current user to be a superadmin.

    Intended for endpoints that only superadmins may access.

    Returns:
        ``current_user`` when their role code is "superadmin".

    Raises:
        HTTPException: 403 for any other role.
    """
    if get_user_role_code(current_user) == "superadmin":
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Superadmin access required"
    )
|