815cbf9d8c
- 更新 .gitignore:全面覆盖环境变量、数据库、日志、缓存、上传文件 - 移除误跟踪的 server/venv/、crm_data.db、.env 文件 - 新增 server/.env.example 模板 - 新增合同管理、利润核算、AI教练等功能模块 - 新增 Playwright e2e 测试套件 - 前后端多项功能升级和 bug 修复
120 lines
3.8 KiB
Python
120 lines
3.8 KiB
Python
"""
|
|
公司管理路由 —— /api/companies
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import uuid
|
|
from datetime import datetime
|
|
|
|
from fastapi import APIRouter, Body, Depends
|
|
from sqlalchemy import select, update
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.api.deps import get_current_user, get_current_company_id
|
|
from app.db.database import get_db
|
|
from app.models.sys import SysCompany, SysUserCompany
|
|
from app.schemas.auth import CurrentUserPayload
|
|
from app.schemas.response import ok
|
|
|
|
# Company-management routes ("公司管理" = company management); every path
# declared below is relative to the "/companies" prefix.
router = APIRouter(tags=["公司管理"], prefix="/companies")
|
|
|
|
|
|
@router.get("", summary="获取当前用户可访问的公司列表")
async def list_companies(
    db: AsyncSession = Depends(get_db),
    current_user: CurrentUserPayload = Depends(get_current_user),
) -> dict:
    """Return all active companies the logged-in user is linked to.

    Companies are found through ``SysUserCompany`` link rows for
    ``current_user.user_id``, restricted to ``SysCompany.is_active`` and
    ordered by ``SysCompany.created_at``.

    ``default_company_id`` in the response is the company whose link row has
    ``is_default`` set, provided that company is one of the returned (active)
    companies. Otherwise it falls back to the first returned company, or to
    ``None`` when the user has no active company.
    """
    # Fetch each company together with its per-user ``is_default`` flag in a
    # single query. The chosen default is then guaranteed to be an active
    # company that appears in the returned list. A separate default lookup
    # could point at an inactive company, or raise when several link rows are
    # flagged as default.
    stmt = (
        select(SysCompany, SysUserCompany.is_default)
        .join(SysUserCompany, SysUserCompany.company_id == SysCompany.id)
        .where(
            SysUserCompany.user_id == current_user.user_id,
            SysCompany.is_active.is_(True),
        )
        .order_by(SysCompany.created_at)
    )
    rows = (await db.execute(stmt)).all()
    companies = [company for company, _is_default in rows]

    # If more than one link row is flagged, the earliest-created company wins.
    default_id = next(
        (company.id for company, is_default in rows if is_default),
        None,
    )
    if default_id is None and companies:
        default_id = companies[0].id

    return ok(data={
        "companies": [
            {
                "id": str(c.id),
                "name": c.name,
                "code": c.code,
                "is_active": c.is_active,
            }
            for c in companies
        ],
        "default_company_id": str(default_id) if default_id is not None else None,
    })
|
|
|
|
|
|
@router.get("/current", summary="获取当前公司详情(含 full_info)")
async def get_current_company(
    db: AsyncSession = Depends(get_db),
    current_user: CurrentUserPayload = Depends(get_current_user),
    company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
    """Return details of the request's current company, including ``full_info``.

    ``current_user`` is not read in the body. Declaring it keeps the
    ``get_current_user`` dependency attached to this route (presumably for
    authentication; confirm in ``app.api.deps``).

    Responds with ``data=None`` when no ``SysCompany`` row matches
    ``company_id``. A missing or NULL ``full_info`` is reported as ``{}``.
    """
    query = select(SysCompany).where(SysCompany.id == company_id)
    result = await db.execute(query)
    company = result.scalar_one_or_none()

    payload = None
    if company is not None:
        payload = {
            "id": str(company.id),
            "name": company.name,
            "code": company.code,
            "full_info": company.full_info or {},
            "is_active": company.is_active,
        }
    return ok(data=payload)
|
|
|
|
|
|
@router.put("/current", summary="更新当前公司信息(含 full_info)")
async def update_current_company(
    body: dict = Body(...),
    db: AsyncSession = Depends(get_db),
    current_user: CurrentUserPayload = Depends(get_current_user),
    company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
    """Update ``name`` and/or ``full_info`` of the request's current company.

    Only the keys ``"name"`` and ``"full_info"`` of ``body`` are applied; any
    other keys are ignored. When at least one of them is present,
    ``updated_at`` is stamped as well and the change is committed.

    Returns the company as re-read from the database. If no company row
    matches ``company_id``, returns ``data=None``, the same as
    ``GET /current``.

    Raises:
        ForbiddenException: the caller's ``data_scope`` is not ``"all"``.
    """
    # Only administrators may edit; here that means ``data_scope == "all"``.
    if current_user.data_scope != "all":
        from app.core.exceptions import ForbiddenException

        raise ForbiddenException("仅管理员可编辑公司信息")

    # NOTE(review): ``body`` is an unvalidated dict and values are written
    # as-is. A Pydantic schema would reject e.g. a null/empty name with a 4xx
    # instead of letting it fail at the database layer.
    values: dict = {key: body[key] for key in ("name", "full_info") if key in body}
    if values:
        from datetime import timezone

        # Naive UTC timestamp: the same value ``datetime.utcnow()`` produced,
        # without that API, which is deprecated since Python 3.12.
        values["updated_at"] = datetime.now(timezone.utc).replace(tzinfo=None)
        await db.execute(
            update(SysCompany).where(SysCompany.id == company_id).values(**values)
        )
        await db.commit()

    # Re-read the row so the response reflects what is stored. populate_existing
    # refreshes an instance that may already sit in this session's identity map.
    company = (await db.execute(
        select(SysCompany)
        .where(SysCompany.id == company_id)
        .execution_options(populate_existing=True)
    )).scalar_one_or_none()
    if company is None:
        # Mirror GET /current instead of crashing with NoResultFound (HTTP 500).
        return ok(data=None)

    return ok(data={
        "id": str(company.id),
        "name": company.name,
        "code": company.code,
        "full_info": company.full_info or {},
        "is_active": company.is_active,
    }, message="公司信息已更新")
|