""" 公司管理路由 —— /api/companies """ from __future__ import annotations import uuid from datetime import datetime from fastapi import APIRouter, Body, Depends from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession from app.api.deps import get_current_user, get_current_company_id from app.db.database import get_db from app.models.sys import SysCompany, SysUserCompany from app.schemas.auth import CurrentUserPayload from app.schemas.response import ok router = APIRouter(prefix="/companies", tags=["公司管理"]) @router.get("", summary="获取当前用户可访问的公司列表") async def list_companies( db: AsyncSession = Depends(get_db), current_user: CurrentUserPayload = Depends(get_current_user), ) -> dict: """返回当前登录用户所关联的所有激活公司""" stmt = ( select(SysCompany) .join(SysUserCompany, SysUserCompany.company_id == SysCompany.id) .where( SysUserCompany.user_id == current_user.user_id, SysCompany.is_active.is_(True), ) .order_by(SysCompany.created_at) ) companies = (await db.execute(stmt)).scalars().all() # 查该用户的默认公司 default_stmt = ( select(SysUserCompany.company_id) .where( SysUserCompany.user_id == current_user.user_id, SysUserCompany.is_default.is_(True), ) ) default_id = (await db.execute(default_stmt)).scalar_one_or_none() return ok(data={ "companies": [ { "id": str(c.id), "name": c.name, "code": c.code, "is_active": c.is_active, } for c in companies ], "default_company_id": str(default_id) if default_id else ( str(companies[0].id) if companies else None ), }) @router.get("/current", summary="获取当前公司详情(含 full_info)") async def get_current_company( db: AsyncSession = Depends(get_db), current_user: CurrentUserPayload = Depends(get_current_user), company_id: uuid.UUID = Depends(get_current_company_id), ) -> dict: company = (await db.execute( select(SysCompany).where(SysCompany.id == company_id) )).scalar_one_or_none() if company is None: return ok(data=None) return ok(data={ "id": str(company.id), "name": company.name, "code": company.code, "full_info": company.full_info or {}, "is_active": company.is_active, }) @router.put("/current", summary="更新当前公司信息(含 full_info)") async def update_current_company( body: dict = Body(...), db: AsyncSession = Depends(get_db), current_user: CurrentUserPayload = Depends(get_current_user), company_id: uuid.UUID = Depends(get_current_company_id), ) -> dict: # 仅管理员可编辑 if current_user.data_scope != "all": from app.core.exceptions import ForbiddenException raise ForbiddenException("仅管理员可编辑公司信息") values: dict = {} if "name" in body: values["name"] = body["name"] if "full_info" in body: values["full_info"] = body["full_info"] if values: values["updated_at"] = datetime.utcnow() await db.execute( update(SysCompany).where(SysCompany.id == company_id).values(**values) ) await db.commit() # 返回更新后的数据 company = (await db.execute( select(SysCompany).where(SysCompany.id == company_id) )).scalar_one() return ok(data={ "id": str(company.id), "name": company.name, "code": company.code, "full_info": company.full_info or {}, "is_active": company.is_active, }, message="公司信息已更新")