""" FastAPI 依赖注入 —— 权限拦截核心 get_current_user: 解析 JWT → 查表获取完整权限上下文 get_current_company_id: 从 X-Company-Id Header 提取公司 ID + IDOR 校验 """ from __future__ import annotations import uuid from fastapi import Depends, Header from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.core.exceptions import ForbiddenException, UnauthorizedException from app.core.security import decode_access_token from app.db.database import get_db from app.models.sys import SysCompany, SysUser, SysUserCompany from app.schemas.auth import CurrentUserPayload async def get_current_user( authorization: str = Header(..., description="Bearer "), db: AsyncSession = Depends(get_db), ) -> CurrentUserPayload: """ 核心鉴权依赖: 1. 从 Header 提取 Bearer Token 2. 解码 JWT 拿到 user_id 3. 查 sys_users + 联表 role/dept 拿到 data_scope 等完整上下文 4. 返回 CurrentUserPayload 供业务层使用 """ # ── 解析 Bearer Token ── if not authorization.startswith("Bearer "): raise UnauthorizedException("Authorization 格式错误,需为 Bearer ") token = authorization.removeprefix("Bearer ").strip() payload = decode_access_token(token) if payload is None: raise UnauthorizedException("Token 无效或已过期") user_id: str | None = payload.get("sub") if user_id is None: raise UnauthorizedException("Token 载荷缺少 sub 字段") # ── 查库获取用户及关联角色 ── stmt = ( select(SysUser) .where(SysUser.id == user_id, SysUser.is_deleted.is_(False)) ) result = await db.execute(stmt) user = result.scalar_one_or_none() if user is None: raise UnauthorizedException("用户不存在或已被停用") if user.status != 1: raise UnauthorizedException("账号已被禁用") # ── 组装权限上下文 ── return CurrentUserPayload( user_id=user.id, username=user.username, real_name=user.real_name, dept_id=user.dept_id, dept_name=user.department.name if user.department else None, role_id=user.role_id, role_name=user.role.role_name if user.role else None, data_scope=user.role.data_scope if user.role else "self", menu_keys=user.role.menu_keys if user.role else [], ) async def get_current_company_id( x_company_id: str = Header(..., alias="X-Company-Id", description="当前工作台的公司 ID"), current_user: CurrentUserPayload = Depends(get_current_user), db: AsyncSession = Depends(get_db), ) -> uuid.UUID: """ 公司视角依赖(IDOR 防护核心): 1. 从 X-Company-Id Header 提取公司 UUID 2. 校验当前用户是否归属于该公司(查 sys_user_companies) 3. 校验公司是否启用 """ # ── 解析 company_id ── try: company_uuid = uuid.UUID(x_company_id) except ValueError: raise UnauthorizedException("X-Company-Id 格式错误,需为合法 UUID") # ── IDOR 防护:校验用户-公司归属 ── assoc = (await db.execute( select(SysUserCompany).where( SysUserCompany.user_id == current_user.user_id, SysUserCompany.company_id == company_uuid, ) )).scalar_one_or_none() if assoc is None: raise ForbiddenException("您无权访问该公司数据") # ── 校验公司是否启用 ── company = (await db.execute( select(SysCompany).where( SysCompany.id == company_uuid, SysCompany.is_active.is_(True), ) )).scalar_one_or_none() if company is None: raise ForbiddenException("公司不存在或已停用") return company_uuid