v0.2.0: CRM/ERP 系统升级 - 清理 .gitignore 并移除误提交的 venv/env/db 文件
- 更新 .gitignore:全面覆盖环境变量、数据库、日志、缓存、上传文件 - 移除误跟踪的 server/venv/、crm_data.db、.env 文件 - 新增 server/.env.example 模板 - 新增合同管理、利润核算、AI教练等功能模块 - 新增 Playwright e2e 测试套件 - 前后端多项功能升级和 bug 修复
This commit is contained in:
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
AI 教练引擎路由 —— /api/ai-coaching
|
||||
Dify 回调 + SSE 通知流
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.deps import get_current_user
|
||||
from app.db.database import get_db
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.response import ok
|
||||
from app.services import ai_coaching_service as svc
|
||||
|
||||
router = APIRouter(prefix="/ai-coaching", tags=["AI教练引擎"])
|
||||
|
||||
|
||||
@router.post("/dify-callback/{sales_log_id}", summary="Dify Workflow 回调端点")
|
||||
async def dify_coaching_callback(
|
||||
sales_log_id: uuid.UUID,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> dict:
|
||||
"""接收 Dify Workflow 的异步回调,写回教练反馈"""
|
||||
import json
|
||||
body = await request.json()
|
||||
await svc.handle_dify_coaching_callback(db, sales_log_id, body)
|
||||
return ok(message="教练反馈已回写")
|
||||
|
||||
|
||||
@router.get("/notifications/stream", summary="SSE 通知流")
|
||||
async def sse_notifications(
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
):
|
||||
"""Server-Sent Events 推送通知(AI 教练反馈、系统通知等)"""
|
||||
return StreamingResponse(
|
||||
svc.sse_notification_generator(current_user.user_id),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
公司管理路由 —— /api/companies
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Body, Depends
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user, get_current_company_id
|
||||
from app.db.database import get_db
|
||||
from app.models.sys import SysCompany, SysUserCompany
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.response import ok
|
||||
|
||||
router = APIRouter(prefix="/companies", tags=["公司管理"])
|
||||
|
||||
|
||||
@router.get("", summary="获取当前用户可访问的公司列表")
|
||||
async def list_companies(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""返回当前登录用户所关联的所有激活公司"""
|
||||
stmt = (
|
||||
select(SysCompany)
|
||||
.join(SysUserCompany, SysUserCompany.company_id == SysCompany.id)
|
||||
.where(
|
||||
SysUserCompany.user_id == current_user.user_id,
|
||||
SysCompany.is_active.is_(True),
|
||||
)
|
||||
.order_by(SysCompany.created_at)
|
||||
)
|
||||
companies = (await db.execute(stmt)).scalars().all()
|
||||
|
||||
# 查该用户的默认公司
|
||||
default_stmt = (
|
||||
select(SysUserCompany.company_id)
|
||||
.where(
|
||||
SysUserCompany.user_id == current_user.user_id,
|
||||
SysUserCompany.is_default.is_(True),
|
||||
)
|
||||
)
|
||||
default_id = (await db.execute(default_stmt)).scalar_one_or_none()
|
||||
|
||||
return ok(data={
|
||||
"companies": [
|
||||
{
|
||||
"id": str(c.id),
|
||||
"name": c.name,
|
||||
"code": c.code,
|
||||
"is_active": c.is_active,
|
||||
}
|
||||
for c in companies
|
||||
],
|
||||
"default_company_id": str(default_id) if default_id else (
|
||||
str(companies[0].id) if companies else None
|
||||
),
|
||||
})
|
||||
|
||||
|
||||
@router.get("/current", summary="获取当前公司详情(含 full_info)")
|
||||
async def get_current_company(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
company = (await db.execute(
|
||||
select(SysCompany).where(SysCompany.id == company_id)
|
||||
)).scalar_one_or_none()
|
||||
if company is None:
|
||||
return ok(data=None)
|
||||
return ok(data={
|
||||
"id": str(company.id),
|
||||
"name": company.name,
|
||||
"code": company.code,
|
||||
"full_info": company.full_info or {},
|
||||
"is_active": company.is_active,
|
||||
})
|
||||
|
||||
|
||||
@router.put("/current", summary="更新当前公司信息(含 full_info)")
|
||||
async def update_current_company(
|
||||
body: dict = Body(...),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
# 仅管理员可编辑
|
||||
if current_user.data_scope != "all":
|
||||
from app.core.exceptions import ForbiddenException
|
||||
raise ForbiddenException("仅管理员可编辑公司信息")
|
||||
|
||||
values: dict = {}
|
||||
if "name" in body:
|
||||
values["name"] = body["name"]
|
||||
if "full_info" in body:
|
||||
values["full_info"] = body["full_info"]
|
||||
if values:
|
||||
values["updated_at"] = datetime.utcnow()
|
||||
await db.execute(
|
||||
update(SysCompany).where(SysCompany.id == company_id).values(**values)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
# 返回更新后的数据
|
||||
company = (await db.execute(
|
||||
select(SysCompany).where(SysCompany.id == company_id)
|
||||
)).scalar_one()
|
||||
return ok(data={
|
||||
"id": str(company.id),
|
||||
"name": company.name,
|
||||
"code": company.code,
|
||||
"full_info": company.full_info or {},
|
||||
"is_active": company.is_active,
|
||||
}, message="公司信息已更新")
|
||||
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
合同管理路由 —— /api/contracts
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import uuid
|
||||
from fastapi import APIRouter, Body, Depends, Query, UploadFile, File
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.deps import get_current_user, get_current_company_id
|
||||
from app.db.database import get_db
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.contract import ContractCreate, ContractUpdate
|
||||
from app.schemas.response import ok
|
||||
from app.services import contract_service as svc
|
||||
|
||||
router = APIRouter(prefix="/contracts", tags=["合同管理"])
|
||||
|
||||
|
||||
@router.post("", summary="新增合同")
|
||||
async def create_contract(
|
||||
body: ContractCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.create_contract(db, current_user, company_id, body)
|
||||
return ok(data=result.model_dump(mode="json"), message="合同创建成功")
|
||||
|
||||
|
||||
@router.get("", summary="合同列表(分页)")
|
||||
async def list_contracts(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
keyword: str | None = Query(None, description="合同编号搜索"),
|
||||
status: str | None = Query(None, description="状态筛选"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.list_contracts(db, company_id, page, size, keyword, status)
|
||||
return ok(data=result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@router.get("/{contract_id}", summary="合同详情(含执行进度)")
|
||||
async def get_contract(
|
||||
contract_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.get_contract(db, contract_id, company_id)
|
||||
return ok(data=result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@router.put("/{contract_id}", summary="编辑合同")
|
||||
async def update_contract(
|
||||
contract_id: uuid.UUID,
|
||||
body: ContractUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.update_contract(db, contract_id, company_id, body)
|
||||
return ok(data=result.model_dump(mode="json"), message="合同已更新")
|
||||
|
||||
|
||||
@router.delete("/{contract_id}", summary="删除合同")
|
||||
async def delete_contract(
|
||||
contract_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
await svc.delete_contract(db, contract_id, company_id)
|
||||
return ok(message="合同已删除")
|
||||
|
||||
|
||||
@router.post("/{contract_id}/generate-order", summary="一键从合同生成订单")
|
||||
async def generate_order_from_contract(
|
||||
contract_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.generate_order_from_contract(db, current_user, contract_id, company_id)
|
||||
return ok(data=result, message="订单生成成功")
|
||||
|
||||
|
||||
@router.get("/{contract_id}/generate", summary="生成合同 Word 文档下载")
|
||||
async def generate_contract_document(
|
||||
contract_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
):
|
||||
from fastapi.responses import Response
|
||||
docx_bytes = await svc.generate_contract_docx(db, contract_id, company_id)
|
||||
return Response(
|
||||
content=docx_bytes,
|
||||
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
headers={"Content-Disposition": f"attachment; filename=contract_{contract_id}.docx"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{contract_id}/upload-signed", summary="上传双签盖章版")
|
||||
async def upload_signed_copy(
|
||||
contract_id: uuid.UUID,
|
||||
file: UploadFile = File(...),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
import os
|
||||
from app.models.contract import ErpContract, ErpContractAttachment
|
||||
from sqlalchemy import update as sa_update
|
||||
|
||||
# 验证合同存在
|
||||
from sqlalchemy import select as sa_select
|
||||
contract = (await db.execute(
|
||||
sa_select(ErpContract).where(
|
||||
ErpContract.id == contract_id,
|
||||
ErpContract.company_id == company_id,
|
||||
ErpContract.is_deleted.is_(False),
|
||||
)
|
||||
)).scalar_one_or_none()
|
||||
if contract is None:
|
||||
raise Exception("合同不存在")
|
||||
|
||||
# 保存文件
|
||||
upload_dir = f"uploads/contracts/{contract_id}"
|
||||
os.makedirs(upload_dir, exist_ok=True)
|
||||
file_path = f"{upload_dir}/{file.filename}"
|
||||
with open(file_path, "wb") as f:
|
||||
content = await file.read()
|
||||
f.write(content)
|
||||
|
||||
file_url = f"/{file_path}"
|
||||
|
||||
# 记录附件
|
||||
attachment = ErpContractAttachment(
|
||||
contract_id=contract_id,
|
||||
file_name=file.filename or "signed_copy",
|
||||
file_url=file_url,
|
||||
file_type="signed_copy",
|
||||
uploader_id=current_user.user_id,
|
||||
)
|
||||
db.add(attachment)
|
||||
|
||||
# 更新合同签署状态
|
||||
await db.execute(
|
||||
sa_update(ErpContract)
|
||||
.where(ErpContract.id == contract_id)
|
||||
.values(is_signed=True, signed_file_url=file_url)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return ok(message="双签盖章版上传成功", data={"file_url": file_url})
|
||||
@@ -4,7 +4,7 @@ CRM 客户模块路由 —— /api/customers
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from fastapi import APIRouter, Body, Depends, Query, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.deps import get_current_user
|
||||
from app.db.database import get_db
|
||||
@@ -91,6 +91,20 @@ async def restore_customer(
|
||||
return ok(message="客户已恢复")
|
||||
|
||||
|
||||
@router.put("/{customer_id}/transfer", summary="转移客户负责人(仅管理员)")
|
||||
async def transfer_customer(
|
||||
customer_id: uuid.UUID,
|
||||
body: dict = Body(..., examples=[{"new_owner_id": "uuid-here"}]),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
) -> dict:
|
||||
new_owner_id = body.get("new_owner_id")
|
||||
if not new_owner_id:
|
||||
raise Exception("缺少 new_owner_id 参数")
|
||||
result = await svc.transfer_customer(db, current_user, customer_id, uuid.UUID(str(new_owner_id)))
|
||||
return ok(data=result.model_dump(mode="json"), message="客户转移成功")
|
||||
|
||||
|
||||
@router.get("/{customer_id}/products", summary="获取客户关联产品(通过订单反查)")
|
||||
async def get_customer_products(
|
||||
customer_id: uuid.UUID,
|
||||
|
||||
+15
-10
@@ -3,20 +3,21 @@ Dashboard 统计 API — /api/dashboard
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import date, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import func, select, and_, extract
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.deps import get_current_user, get_current_company_id
|
||||
from app.db.database import get_db
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.response import ok
|
||||
|
||||
from app.models.order import ErpOrder
|
||||
from app.models.shipping import ErpShippingRecord
|
||||
from app.models.erp import ProductSku
|
||||
from app.models.erp import ErpSkuInventory
|
||||
|
||||
router = APIRouter(prefix="/dashboard", tags=["Dashboard"])
|
||||
|
||||
@@ -25,42 +26,46 @@ router = APIRouter(prefix="/dashboard", tags=["Dashboard"])
|
||||
async def get_stats(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
):
|
||||
today = date.today()
|
||||
month_start = today.replace(day=1)
|
||||
|
||||
# 本月新增订单数
|
||||
# 本月新增订单数(按公司隔离)
|
||||
orders_count_q = select(func.count()).select_from(ErpOrder).where(
|
||||
and_(
|
||||
ErpOrder.is_deleted.is_(False),
|
||||
ErpOrder.company_id == company_id,
|
||||
ErpOrder.order_date >= month_start,
|
||||
)
|
||||
)
|
||||
orders_count = (await db.execute(orders_count_q)).scalar() or 0
|
||||
|
||||
# 待出库发货数(状态为 pending)
|
||||
# 待出库发货数(按公司隔离)
|
||||
pending_shipping_q = select(func.count()).select_from(ErpOrder).where(
|
||||
and_(
|
||||
ErpOrder.is_deleted.is_(False),
|
||||
ErpOrder.company_id == company_id,
|
||||
ErpOrder.shipping_state == "pending",
|
||||
)
|
||||
)
|
||||
pending_shipping = (await db.execute(pending_shipping_q)).scalar() or 0
|
||||
|
||||
# 库存预警 SKU 数(stock_qty <= warning_threshold 且 warning_threshold > 0)
|
||||
warning_skus_q = select(func.count()).select_from(ProductSku).where(
|
||||
# 库存预警 SKU 数(从 erp_sku_inventory 查,按公司隔离)
|
||||
warning_skus_q = select(func.count()).select_from(ErpSkuInventory).where(
|
||||
and_(
|
||||
ProductSku.is_deleted.is_(False),
|
||||
ProductSku.warning_threshold > 0,
|
||||
ProductSku.stock_qty <= ProductSku.warning_threshold,
|
||||
ErpSkuInventory.company_id == company_id,
|
||||
ErpSkuInventory.warning_threshold > 0,
|
||||
ErpSkuInventory.stock_qty <= ErpSkuInventory.warning_threshold,
|
||||
)
|
||||
)
|
||||
warning_skus = (await db.execute(warning_skus_q)).scalar() or 0
|
||||
|
||||
# 本月预计营收(本月订单总金额)
|
||||
# 本月预计营收(按公司隔离)
|
||||
revenue_q = select(func.coalesce(func.sum(ErpOrder.total_amount), 0)).where(
|
||||
and_(
|
||||
ErpOrder.is_deleted.is_(False),
|
||||
ErpOrder.company_id == company_id,
|
||||
ErpOrder.order_date >= month_start,
|
||||
)
|
||||
)
|
||||
|
||||
+47
-2
@@ -1,18 +1,21 @@
|
||||
"""
|
||||
FastAPI 依赖注入 —— 权限拦截核心
|
||||
get_current_user: 解析 JWT → 查表获取完整权限上下文
|
||||
get_current_company_id: 从 X-Company-Id Header 提取公司 ID + IDOR 校验
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import Depends, Header
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import UnauthorizedException
|
||||
from app.core.exceptions import ForbiddenException, UnauthorizedException
|
||||
from app.core.security import decode_access_token
|
||||
from app.db.database import get_db
|
||||
from app.models.sys import SysUser
|
||||
from app.models.sys import SysCompany, SysUser, SysUserCompany
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
|
||||
|
||||
@@ -65,3 +68,45 @@ async def get_current_user(
|
||||
data_scope=user.role.data_scope if user.role else "self",
|
||||
menu_keys=user.role.menu_keys if user.role else [],
|
||||
)
|
||||
|
||||
|
||||
async def get_current_company_id(
|
||||
x_company_id: str = Header(..., alias="X-Company-Id", description="当前工作台的公司 ID"),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> uuid.UUID:
|
||||
"""
|
||||
公司视角依赖(IDOR 防护核心):
|
||||
1. 从 X-Company-Id Header 提取公司 UUID
|
||||
2. 校验当前用户是否归属于该公司(查 sys_user_companies)
|
||||
3. 校验公司是否启用
|
||||
"""
|
||||
# ── 解析 company_id ──
|
||||
try:
|
||||
company_uuid = uuid.UUID(x_company_id)
|
||||
except ValueError:
|
||||
raise UnauthorizedException("X-Company-Id 格式错误,需为合法 UUID")
|
||||
|
||||
# ── IDOR 防护:校验用户-公司归属 ──
|
||||
assoc = (await db.execute(
|
||||
select(SysUserCompany).where(
|
||||
SysUserCompany.user_id == current_user.user_id,
|
||||
SysUserCompany.company_id == company_uuid,
|
||||
)
|
||||
)).scalar_one_or_none()
|
||||
|
||||
if assoc is None:
|
||||
raise ForbiddenException("您无权访问该公司数据")
|
||||
|
||||
# ── 校验公司是否启用 ──
|
||||
company = (await db.execute(
|
||||
select(SysCompany).where(
|
||||
SysCompany.id == company_uuid,
|
||||
SysCompany.is_active.is_(True),
|
||||
)
|
||||
)).scalar_one_or_none()
|
||||
|
||||
if company is None:
|
||||
raise ForbiddenException("公司不存在或已停用")
|
||||
|
||||
return company_uuid
|
||||
|
||||
+407
-21
@@ -9,7 +9,7 @@ import time
|
||||
import base64
|
||||
from fastapi import APIRouter, Depends, Query, Body, File, UploadFile, Form
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.deps import get_current_user, get_current_company_id
|
||||
from app.db.database import get_db
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.finance import ExpenseCreate, ExpenseStatusUpdate, InvoiceCreate
|
||||
@@ -43,34 +43,96 @@ async def ocr_recognize(
|
||||
|
||||
file_url = f"/uploads/finance/{safe_filename}"
|
||||
|
||||
# 仅支持图片(png/jpg/jpeg)和 PDF,不再支持 MD/TXT
|
||||
supported = {".png", ".jpg", ".jpeg", ".pdf"}
|
||||
# 支持的格式:结构化零算力 > 文本 LLM > 图片 Vision
|
||||
supported = {".png", ".jpg", ".jpeg", ".pdf", ".md", ".ofd", ".xml", ".zip"}
|
||||
if ext not in supported:
|
||||
raise BizException(message=f"不支持的文件格式 {ext},仅支持: {', '.join(supported)}")
|
||||
|
||||
# 如果是 PDF,转成 PNG 再做 OCR
|
||||
ocr_bytes = file_bytes
|
||||
raise BizException(message=f"不支持的文件格式 {ext},仅支持: {', '.join(sorted(supported))}")
|
||||
|
||||
# ── 策略 A0: ZIP → 解包所有 XML 并逐个解析 ──
|
||||
if ext == ".zip":
|
||||
from app.services.invoice_parser import parse_zip_invoices
|
||||
results = parse_zip_invoices(file_bytes)
|
||||
return ok(data={"zip_results": [
|
||||
{"filename": r.get("filename", ""), "success": r.get("success", False),
|
||||
"ocr_data": r.get("data", {}), "needs_llm": r.get("needs_llm", False),
|
||||
"error": r.get("error")}
|
||||
for r in results
|
||||
], "file_url": file_url}, message=f"ZIP 解析完成:{sum(1 for r in results if r.get('success'))}/{len(results)} 成功")
|
||||
|
||||
# ── 策略 A: OFD / XML → 结构化零算力提取(最快最准)──
|
||||
if ext in (".ofd", ".xml"):
|
||||
from app.services.invoice_parser import parse_ofd_invoice, parse_xml_invoice
|
||||
parser = parse_ofd_invoice if ext == ".ofd" else parse_xml_invoice
|
||||
result = parser(file_bytes)
|
||||
print(f"[OCR] {ext.upper()} 解析: success={result.get('success')}")
|
||||
|
||||
if result.get("success"):
|
||||
# 如果解析器提取到 raw_text 且标记 needs_llm,交给 LLM 做字段提取
|
||||
if result.get("needs_llm") and result["data"].get("raw_text"):
|
||||
from app.services.ocr_service import extract_invoice_from_text
|
||||
llm_result = await extract_invoice_from_text(result["data"]["raw_text"], scene)
|
||||
if llm_result.get("success"):
|
||||
return ok(data={"ocr_data": llm_result["data"], "file_url": file_url}, message=f"AI 发票识别成功({ext.upper()} → LLM)")
|
||||
return ok(data={"ocr_data": llm_result.get("data", {}), "file_url": file_url}, message=llm_result.get("error", "LLM 解析失败"))
|
||||
return ok(data={"ocr_data": result["data"], "file_url": file_url}, message=f"发票识别成功({ext.upper()} 结构化提取)")
|
||||
return ok(data={"ocr_data": {}, "file_url": file_url}, message=result.get("error", f"{ext.upper()} 解析失败"))
|
||||
|
||||
# ── 策略 B: MD → 纯文本 LLM 理解(零 GPU Vision)──
|
||||
if ext == ".md":
|
||||
text = file_bytes.decode("utf-8", errors="replace").strip()
|
||||
print(f"[OCR] MD 文本: {len(text)} 字符")
|
||||
if len(text) < 20:
|
||||
return ok(data={"ocr_data": {}, "file_url": file_url}, message="MD 文件内容过少,无法识别")
|
||||
from app.services.ocr_service import extract_invoice_from_text
|
||||
result = await extract_invoice_from_text(text, scene)
|
||||
if result.get("success"):
|
||||
return ok(data={"ocr_data": result["data"], "file_url": file_url}, message="AI 发票识别成功(MD 文本解析)")
|
||||
return ok(data={"ocr_data": result.get("data", {}), "file_url": file_url}, message=result.get("error", "MD 文本解析失败"))
|
||||
|
||||
# ── 策略 C: PDF → PyMuPDF 提取文本 → LLM(零 GPU Vision)──
|
||||
if ext == ".pdf":
|
||||
try:
|
||||
import fitz # PyMuPDF
|
||||
doc = fitz.open(stream=file_bytes, filetype="pdf")
|
||||
page = doc[0] # 取第一页
|
||||
# 中等分辨率渲染(150 DPI,平衡质量与大小)
|
||||
pix = page.get_pixmap(dpi=150)
|
||||
ocr_bytes = pix.tobytes("png")
|
||||
text = ""
|
||||
for page in doc:
|
||||
text += page.get_text() + "\n"
|
||||
doc.close()
|
||||
print(f"[OCR] PDF 转 PNG 成功: {len(ocr_bytes)} bytes")
|
||||
text = text.strip()
|
||||
print(f"[OCR] PDF 文本提取: {len(text)} 字符")
|
||||
|
||||
if len(text) > 50: # 有足够文本内容
|
||||
from app.services.ocr_service import extract_invoice_from_text
|
||||
result = await extract_invoice_from_text(text, scene)
|
||||
if result.get("success"):
|
||||
return ok(data={"ocr_data": result["data"], "file_url": file_url}, message="AI 发票识别成功(PDF 文本解析)")
|
||||
return ok(data={"ocr_data": result.get("data", {}), "file_url": file_url}, message=result.get("error", "PDF 文本提取失败"))
|
||||
else:
|
||||
# PDF 是扫描件(无文字层),降级到图片 OCR
|
||||
print(f"[OCR] PDF 无文本层(仅 {len(text)} 字符),降级到图片 OCR")
|
||||
page = fitz.open(stream=file_bytes, filetype="pdf")[0]
|
||||
pix = page.get_pixmap(dpi=150)
|
||||
ocr_bytes = pix.tobytes("png")
|
||||
print(f"[OCR] PDF 转 PNG 成功: {len(ocr_bytes)} bytes")
|
||||
except Exception as e:
|
||||
print(f"[OCR] PDF 转换失败: {e}")
|
||||
return ok(data={"ocr_data": {}, "file_url": file_url}, message=f"PDF 转换失败: {e}")
|
||||
|
||||
# 转换为纯 base64 传给 OCR
|
||||
print(f"[OCR] PDF 处理失败: {e}")
|
||||
return ok(data={"ocr_data": {}, "file_url": file_url}, message=f"PDF 处理失败: {e}")
|
||||
else:
|
||||
ocr_bytes = file_bytes
|
||||
|
||||
# ── 策略 D: 图片/扫描PDF → Vision OCR(需要视觉模型)──
|
||||
from app.services.ocr_service import ocr_image
|
||||
image_base64 = base64.b64encode(ocr_bytes).decode("utf-8")
|
||||
result = await ocr_image(image_base64, scene)
|
||||
|
||||
if result.get("success"):
|
||||
return ok(data={"ocr_data": result["data"], "file_url": file_url}, message="AI OCR 识别成功")
|
||||
return ok(data={"ocr_data": result.get("data", {}), "file_url": file_url}, message=result.get("error", "OCR 识别失败"))
|
||||
|
||||
# Vision 失败时友好提示
|
||||
error_msg = result.get("error", "OCR 识别失败")
|
||||
if "模型进程崩溃" in error_msg or "unexpectedly stopped" in error_msg or "服务异常" in error_msg:
|
||||
error_msg += "。建议:请上传电子版 PDF/OFD/XML 发票,系统可零算力直接提取数据"
|
||||
return ok(data={"ocr_data": {}, "file_url": file_url}, message=error_msg)
|
||||
|
||||
|
||||
@router.post("/invoices", summary="上传票据入池(含 AI/OCR JSONB 数据)")
|
||||
@@ -78,8 +140,9 @@ async def create_invoice(
|
||||
body: InvoiceCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.create_invoice(db, current_user, body)
|
||||
result = await svc.create_invoice(db, current_user, body, company_id)
|
||||
return ok(data=result.model_dump(mode="json"), message="票据入池成功")
|
||||
|
||||
|
||||
@@ -91,8 +154,9 @@ async def list_invoices(
|
||||
is_used: bool | None = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.list_invoices(db, current_user, page, size, type, is_used)
|
||||
result = await svc.list_invoices(db, current_user, page, size, type, is_used, company_id)
|
||||
return ok(data=result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@@ -111,8 +175,9 @@ async def create_expense(
|
||||
body: ExpenseCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.create_expense(db, current_user, body)
|
||||
result = await svc.create_expense(db, current_user, body, company_id)
|
||||
return ok(data=result.model_dump(mode="json"), message=f"报销单 {result.system_no} 提交成功")
|
||||
|
||||
|
||||
@@ -124,8 +189,9 @@ async def list_expenses(
|
||||
applicant_id: uuid.UUID | None = Query(None, description="按申请人过滤(管理员用)"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.list_expenses(db, current_user, page, size, status, applicant_id)
|
||||
result = await svc.list_expenses(db, current_user, page, size, status, applicant_id, company_id)
|
||||
return ok(data=result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@@ -148,3 +214,323 @@ async def update_expense_status(
|
||||
) -> dict:
|
||||
msg = await svc.update_expense_status(db, current_user, expense_id, body)
|
||||
return ok(message=msg)
|
||||
|
||||
|
||||
# ════════════════════════════════════════════════════════════
|
||||
# 批量上传 + OCR 任务队列 API
|
||||
# ════════════════════════════════════════════════════════════
|
||||
|
||||
@router.post("/upload-batch", summary="批量上传发票(ZIP/XML 即时入池,图片PDF 入队列)")
|
||||
async def upload_batch(
|
||||
files: list[UploadFile] = File(...),
|
||||
scene: str = Form("invoice"),
|
||||
inv_type: str = Form("expense"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
from app.services.invoice_parser import parse_xml_invoice, parse_ofd_invoice, parse_zip_invoices
|
||||
from app.services.ocr_service import extract_invoice_from_text
|
||||
from app.models.finance import FinInvoicePool, FinOcrTask
|
||||
|
||||
upload_dir = "uploads/finance"
|
||||
os.makedirs(upload_dir, exist_ok=True)
|
||||
|
||||
results = [] # 返回给前端
|
||||
|
||||
for file in files:
|
||||
file_bytes = await file.read()
|
||||
ext = os.path.splitext(file.filename or "")[1].lower() or ".bin"
|
||||
ts = int(time.time())
|
||||
safe_fn = f"{ts}_{uuid.uuid4().hex[:8]}{ext}"
|
||||
file_path = os.path.join(upload_dir, safe_fn)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_bytes)
|
||||
file_url = f"/uploads/finance/{safe_fn}"
|
||||
|
||||
# ── ZIP: 解压内部 XML,逐个即时入池 ──
|
||||
if ext == ".zip":
|
||||
zip_results = parse_zip_invoices(file_bytes)
|
||||
for zr in zip_results:
|
||||
if zr.get("success") and not zr.get("needs_llm"):
|
||||
ai_data = zr.get("data", {})
|
||||
# 需要 LLM 的 zip 中的 xml 也立刻处理
|
||||
merchant = ai_data.get("merchant") or ai_data.get("merchant_name") or "(ZIP)"
|
||||
amount = float(ai_data.get("amount", 0) or 0)
|
||||
inv_date_str = ai_data.get("date")
|
||||
inv_date = None
|
||||
if inv_date_str:
|
||||
try:
|
||||
from datetime import date as d
|
||||
inv_date = d.fromisoformat(inv_date_str)
|
||||
except ValueError:
|
||||
pass
|
||||
inv = FinInvoicePool(
|
||||
uploader_id=current_user.user_id, company_id=company_id,
|
||||
file_url=file_url, merchant_name=merchant, amount=amount,
|
||||
invoice_date=inv_date, type=inv_type, ai_extracted_data=ai_data,
|
||||
)
|
||||
db.add(inv)
|
||||
results.append({"filename": zr.get("filename", file.filename), "action": "pooled",
|
||||
"status": "success", "message": f"✅ {merchant} ¥{amount}"})
|
||||
elif zr.get("needs_llm") and zr.get("data", {}).get("raw_text"):
|
||||
# LLM 文本理解(即时,<5s)
|
||||
try:
|
||||
llm_r = await extract_invoice_from_text(zr["data"]["raw_text"], scene)
|
||||
if llm_r.get("success"):
|
||||
ai_data = llm_r["data"]
|
||||
merchant = ai_data.get("merchant") or "(LLM)"
|
||||
amount = float(ai_data.get("amount", 0) or 0)
|
||||
inv = FinInvoicePool(
|
||||
uploader_id=current_user.user_id, company_id=company_id,
|
||||
file_url=file_url, merchant_name=merchant, amount=amount,
|
||||
type=inv_type, ai_extracted_data=ai_data,
|
||||
)
|
||||
db.add(inv)
|
||||
results.append({"filename": zr.get("filename"), "action": "pooled",
|
||||
"status": "success", "message": f"✅ {merchant} ¥{amount} (LLM)"})
|
||||
else:
|
||||
results.append({"filename": zr.get("filename"), "action": "failed",
|
||||
"status": "error", "message": llm_r.get("error", "LLM 解析失败")})
|
||||
except Exception as e:
|
||||
results.append({"filename": zr.get("filename"), "action": "failed",
|
||||
"status": "error", "message": str(e)})
|
||||
else:
|
||||
results.append({"filename": zr.get("filename", file.filename), "action": "failed",
|
||||
"status": "error", "message": zr.get("error", "解析失败")})
|
||||
continue
|
||||
|
||||
# ── XML / OFD: 零算力即时入池 ──
|
||||
if ext in (".xml", ".ofd"):
|
||||
parser = parse_xml_invoice if ext == ".xml" else parse_ofd_invoice
|
||||
r = parser(file_bytes)
|
||||
if r.get("success") and not r.get("needs_llm"):
|
||||
ai_data = r.get("data", {})
|
||||
merchant = ai_data.get("merchant") or ai_data.get("merchant_name") or "(解析)"
|
||||
amount = float(ai_data.get("amount", 0) or 0)
|
||||
inv_date_str = ai_data.get("date")
|
||||
inv_date = None
|
||||
if inv_date_str:
|
||||
try:
|
||||
from datetime import date as d
|
||||
inv_date = d.fromisoformat(inv_date_str)
|
||||
except ValueError:
|
||||
pass
|
||||
inv = FinInvoicePool(
|
||||
uploader_id=current_user.user_id, company_id=company_id,
|
||||
file_url=file_url, merchant_name=merchant, amount=amount,
|
||||
invoice_date=inv_date, type=inv_type, ai_extracted_data=ai_data,
|
||||
)
|
||||
db.add(inv)
|
||||
results.append({"filename": file.filename, "action": "pooled",
|
||||
"status": "success", "message": f"✅ {merchant} ¥{amount}"})
|
||||
elif r.get("needs_llm") and r.get("data", {}).get("raw_text"):
|
||||
try:
|
||||
llm_r = await extract_invoice_from_text(r["data"]["raw_text"], scene)
|
||||
if llm_r.get("success"):
|
||||
ai_data = llm_r["data"]
|
||||
merchant = ai_data.get("merchant") or "(LLM)"
|
||||
amount = float(ai_data.get("amount", 0) or 0)
|
||||
inv = FinInvoicePool(
|
||||
uploader_id=current_user.user_id, company_id=company_id,
|
||||
file_url=file_url, merchant_name=merchant, amount=amount,
|
||||
type=inv_type, ai_extracted_data=ai_data,
|
||||
)
|
||||
db.add(inv)
|
||||
results.append({"filename": file.filename, "action": "pooled",
|
||||
"status": "success", "message": f"✅ {merchant} ¥{amount} (LLM)"})
|
||||
else:
|
||||
results.append({"filename": file.filename, "action": "failed",
|
||||
"status": "error", "message": llm_r.get("error", "LLM 失败")})
|
||||
except Exception as e:
|
||||
results.append({"filename": file.filename, "action": "failed",
|
||||
"status": "error", "message": str(e)})
|
||||
else:
|
||||
results.append({"filename": file.filename, "action": "failed",
|
||||
"status": "error", "message": r.get("error", "解析失败")})
|
||||
continue
|
||||
|
||||
# ── 图片 / PDF : 写入 DB 任务队列 ──
|
||||
task = FinOcrTask(
|
||||
file_url=file_url, file_ext=ext,
|
||||
original_name=file.filename or "unknown",
|
||||
uploader_id=current_user.user_id,
|
||||
company_id=company_id,
|
||||
inv_type=inv_type,
|
||||
priority=50 if ext == ".pdf" else 100, # PDF 优先(可能有文字层)
|
||||
)
|
||||
db.add(task)
|
||||
await db.flush()
|
||||
results.append({"filename": file.filename, "action": "queued",
|
||||
"status": "pending", "task_id": str(task.id),
|
||||
"message": "🕐 已加入 OCR 处理队列"})
|
||||
|
||||
await db.commit()
|
||||
|
||||
pooled = sum(1 for r in results if r["action"] == "pooled")
|
||||
queued = sum(1 for r in results if r["action"] == "queued")
|
||||
failed = sum(1 for r in results if r["action"] == "failed")
|
||||
return ok(data={"results": results},
|
||||
message=f"批量处理完成:{pooled} 即时入池,{queued} 排队中,{failed} 失败")
|
||||
|
||||
|
||||
@router.get("/ocr-tasks", summary="OCR 任务队列列表")
|
||||
async def list_ocr_tasks(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
status: str | None = Query(None, description="pending/processing/success/failed/manual"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
from sqlalchemy import func, select
|
||||
from app.models.finance import FinOcrTask
|
||||
|
||||
where = [FinOcrTask.company_id == company_id, FinOcrTask.is_deleted.is_(False)]
|
||||
if current_user.data_scope == "self":
|
||||
where.append(FinOcrTask.uploader_id == current_user.user_id)
|
||||
if status:
|
||||
where.append(FinOcrTask.status == status)
|
||||
|
||||
total = (await db.execute(select(func.count()).select_from(FinOcrTask).where(*where))).scalar() or 0
|
||||
stmt = (
|
||||
select(FinOcrTask).where(*where)
|
||||
.order_by(FinOcrTask.priority, FinOcrTask.created_at.desc())
|
||||
.offset((page - 1) * size).limit(size)
|
||||
)
|
||||
tasks = (await db.execute(stmt)).scalars().all()
|
||||
|
||||
return ok(data={
|
||||
"total": total, "page": page, "size": size,
|
||||
"items": [{
|
||||
"id": str(t.id),
|
||||
"original_name": t.original_name,
|
||||
"file_ext": t.file_ext,
|
||||
"file_url": t.file_url,
|
||||
"status": t.status,
|
||||
"priority": t.priority,
|
||||
"retry_count": t.retry_count,
|
||||
"max_retries": t.max_retries,
|
||||
"error_message": t.error_message,
|
||||
"ocr_result": t.ocr_result,
|
||||
"invoice_pool_id": str(t.invoice_pool_id) if t.invoice_pool_id else None,
|
||||
"uploader_name": t.uploader.real_name if t.uploader else None,
|
||||
"inv_type": t.inv_type,
|
||||
"created_at": str(t.created_at),
|
||||
"updated_at": str(t.updated_at),
|
||||
} for t in tasks],
|
||||
})
|
||||
|
||||
|
||||
@router.post("/ocr-tasks/{task_id}/retry", summary="重试失败的 OCR 任务")
|
||||
async def retry_ocr_task(
|
||||
task_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
) -> dict:
|
||||
from sqlalchemy import select, update
|
||||
from app.models.finance import FinOcrTask
|
||||
|
||||
task = (await db.execute(
|
||||
select(FinOcrTask).where(FinOcrTask.id == task_id, FinOcrTask.is_deleted.is_(False))
|
||||
)).scalar_one_or_none()
|
||||
if not task:
|
||||
raise BizException(message="任务不存在")
|
||||
if task.status not in ("failed", "manual"):
|
||||
raise BizException(message=f"当前状态 [{task.status}] 不允许重试")
|
||||
|
||||
task.status = "pending"
|
||||
task.retry_count = 0
|
||||
task.error_message = None
|
||||
await db.commit()
|
||||
return ok(message="任务已重新入队")
|
||||
|
||||
|
||||
@router.post("/ocr-tasks/{task_id}/manual", summary="手动录入 OCR 结果并入池")
|
||||
async def manual_ocr_task(
|
||||
task_id: uuid.UUID,
|
||||
body: dict,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
) -> dict:
|
||||
from sqlalchemy import select
|
||||
from app.models.finance import FinOcrTask, FinInvoicePool
|
||||
|
||||
task = (await db.execute(
|
||||
select(FinOcrTask).where(FinOcrTask.id == task_id, FinOcrTask.is_deleted.is_(False))
|
||||
)).scalar_one_or_none()
|
||||
if not task:
|
||||
raise BizException(message="任务不存在")
|
||||
|
||||
merchant = body.get("merchant_name", "手动录入")
|
||||
amount = float(body.get("amount", 0))
|
||||
inv_date_str = body.get("invoice_date")
|
||||
inv_date = None
|
||||
if inv_date_str:
|
||||
try:
|
||||
from datetime import date as d
|
||||
inv_date = d.fromisoformat(inv_date_str)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
inv = FinInvoicePool(
|
||||
uploader_id=task.uploader_id, company_id=task.company_id,
|
||||
file_url=task.file_url, merchant_name=merchant, amount=amount,
|
||||
invoice_date=inv_date, type=task.inv_type, ai_extracted_data=body,
|
||||
)
|
||||
db.add(inv)
|
||||
await db.flush()
|
||||
|
||||
task.status = "manual"
|
||||
task.invoice_pool_id = inv.id
|
||||
task.ocr_result = body
|
||||
task.error_message = None
|
||||
await db.commit()
|
||||
|
||||
return ok(data={"invoice_pool_id": str(inv.id)}, message="手动录入成功,发票已入池")
|
||||
|
||||
|
||||
@router.put("/ocr-tasks/{task_id}/priority", summary="调整 OCR 任务优先级")
|
||||
async def update_ocr_task_priority(
|
||||
task_id: uuid.UUID,
|
||||
body: dict,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
) -> dict:
|
||||
from sqlalchemy import select
|
||||
from app.models.finance import FinOcrTask
|
||||
|
||||
task = (await db.execute(
|
||||
select(FinOcrTask).where(FinOcrTask.id == task_id, FinOcrTask.is_deleted.is_(False))
|
||||
)).scalar_one_or_none()
|
||||
if not task:
|
||||
raise BizException(message="任务不存在")
|
||||
if task.status not in ("pending",):
|
||||
raise BizException(message="仅待处理任务可调整优先级")
|
||||
|
||||
new_priority = body.get("priority", task.priority)
|
||||
task.priority = int(new_priority)
|
||||
await db.commit()
|
||||
return ok(message=f"优先级已调整为 {task.priority}")
|
||||
|
||||
|
||||
@router.delete("/ocr-tasks/{task_id}", summary="取消/删除 OCR 任务")
|
||||
async def delete_ocr_task(
|
||||
task_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
) -> dict:
|
||||
from sqlalchemy import select
|
||||
from app.models.finance import FinOcrTask
|
||||
|
||||
task = (await db.execute(
|
||||
select(FinOcrTask).where(FinOcrTask.id == task_id, FinOcrTask.is_deleted.is_(False))
|
||||
)).scalar_one_or_none()
|
||||
if not task:
|
||||
raise BizException(message="任务不存在")
|
||||
if task.status == "processing":
|
||||
raise BizException(message="正在处理中的任务无法取消")
|
||||
|
||||
task.is_deleted = True
|
||||
await db.commit()
|
||||
return ok(message="任务已取消")
|
||||
|
||||
@@ -56,7 +56,7 @@ async def import_products(
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
):
|
||||
from openpyxl import load_workbook
|
||||
from app.models.erp import ErpProductSku
|
||||
from app.models.erp import ProductSku
|
||||
|
||||
content = await file.read()
|
||||
wb = load_workbook(io.BytesIO(content))
|
||||
@@ -79,7 +79,6 @@ async def import_products(
|
||||
spec = str(row[2] or "").strip() or None
|
||||
standard_price = float(row[3] or 0)
|
||||
unit = str(row[4] or "桶").strip()
|
||||
warning_threshold = float(row[5] or 0)
|
||||
|
||||
if not sku_code or not name:
|
||||
skipped += 1
|
||||
@@ -87,22 +86,21 @@ async def import_products(
|
||||
|
||||
# 检查 sku_code 是否已存在
|
||||
exists = (await db.execute(
|
||||
select(func.count()).select_from(ErpProductSku).where(
|
||||
ErpProductSku.sku_code == sku_code,
|
||||
ErpProductSku.is_deleted.is_(False),
|
||||
select(func.count()).select_from(ProductSku).where(
|
||||
ProductSku.sku_code == sku_code,
|
||||
ProductSku.is_deleted.is_(False),
|
||||
)
|
||||
)).scalar()
|
||||
if exists:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
sku = ErpProductSku(
|
||||
sku = ProductSku(
|
||||
sku_code=sku_code,
|
||||
name=name,
|
||||
spec=spec,
|
||||
standard_price=standard_price,
|
||||
unit=unit,
|
||||
warning_threshold=warning_threshold,
|
||||
)
|
||||
db.add(sku)
|
||||
created += 1
|
||||
|
||||
+338
-4
@@ -6,7 +6,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.deps import get_current_user, get_current_company_id
|
||||
from app.db.database import get_db
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.order import OrderCreate
|
||||
@@ -32,8 +32,9 @@ async def create_order(
|
||||
body: OrderCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.create_order(db, current_user, body)
|
||||
result = await svc.create_order(db, current_user, body, company_id)
|
||||
return ok(data=result.model_dump(mode="json"), message=f"订单 {result.order_no} 创建成功")
|
||||
|
||||
|
||||
@@ -47,16 +48,349 @@ async def list_orders(
|
||||
keyword: str | None = Query(None, description="模糊搜索订单号"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.list_orders(db, current_user, page, size, customer_id, shipping_state, payment_state, keyword)
|
||||
result = await svc.list_orders(db, current_user, page, size, customer_id, shipping_state, payment_state, keyword, company_id)
|
||||
return ok(data=result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@router.get("/unlinked-invoices", summary="查询未关联订单的发票列表")
|
||||
async def list_unlinked_invoices(
|
||||
keyword: str | None = Query(None, description="发票号模糊搜索"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
from sqlalchemy import select
|
||||
from app.models.finance import FinSalesInvoice
|
||||
conditions = [
|
||||
FinSalesInvoice.company_id == company_id,
|
||||
FinSalesInvoice.is_deleted.is_(False),
|
||||
FinSalesInvoice.order_id.is_(None),
|
||||
]
|
||||
if keyword:
|
||||
conditions.append(FinSalesInvoice.invoice_number.ilike(f"%{keyword}%"))
|
||||
|
||||
stmt = (
|
||||
select(FinSalesInvoice)
|
||||
.where(*conditions)
|
||||
.order_by(FinSalesInvoice.created_at.desc())
|
||||
.limit(50)
|
||||
)
|
||||
invoices = (await db.execute(stmt)).scalars().all()
|
||||
return ok(data=[
|
||||
{
|
||||
"id": str(inv.id),
|
||||
"invoice_number": inv.invoice_number,
|
||||
"issuer": inv.issuer,
|
||||
"receiver_name": inv.receiver_customer.name if inv.receiver_customer else None,
|
||||
"amount": float(inv.amount),
|
||||
"billing_date": str(inv.billing_date),
|
||||
}
|
||||
for inv in invoices
|
||||
])
|
||||
|
||||
|
||||
@router.get("/{order_id}", summary="订单全景详情(关系预加载 items + customer)")
|
||||
async def get_order(
|
||||
order_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.get_order(db, current_user, order_id)
|
||||
result = await svc.get_order(db, current_user, order_id, company_id)
|
||||
return ok(data=result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@router.get("/{order_id}/invoices", summary="获取订单关联的销项发票")
|
||||
async def get_order_invoices(
|
||||
order_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
from sqlalchemy import select
|
||||
from app.models.finance import FinSalesInvoice
|
||||
stmt = (
|
||||
select(FinSalesInvoice)
|
||||
.where(
|
||||
FinSalesInvoice.order_id == order_id,
|
||||
FinSalesInvoice.company_id == company_id,
|
||||
FinSalesInvoice.is_deleted.is_(False),
|
||||
)
|
||||
.order_by(FinSalesInvoice.created_at.desc())
|
||||
)
|
||||
invoices = (await db.execute(stmt)).scalars().all()
|
||||
return ok(data=[
|
||||
{
|
||||
"id": str(inv.id),
|
||||
"invoice_number": inv.invoice_number,
|
||||
"issuer": inv.issuer,
|
||||
"receiver_name": inv.receiver_customer.name if inv.receiver_customer else None,
|
||||
"amount": float(inv.amount),
|
||||
"billing_date": str(inv.billing_date),
|
||||
"payment_status": inv.payment_status,
|
||||
"payment_date": str(inv.payment_date) if inv.payment_date else None,
|
||||
"payment_amount": float(inv.payment_amount or 0),
|
||||
"payment_due_date": str(inv.payment_due_date) if inv.payment_due_date else None,
|
||||
}
|
||||
for inv in invoices
|
||||
])
|
||||
|
||||
|
||||
@router.put("/{order_id}/payment", summary="更新订单收款状态")
|
||||
async def update_order_payment(
|
||||
order_id: uuid.UUID,
|
||||
body: dict,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
from sqlalchemy import select, update as sa_update
|
||||
from app.models.order import ErpOrder
|
||||
from datetime import datetime
|
||||
|
||||
order = (await db.execute(
|
||||
select(ErpOrder).where(
|
||||
ErpOrder.id == order_id,
|
||||
ErpOrder.company_id == company_id,
|
||||
ErpOrder.is_deleted.is_(False),
|
||||
)
|
||||
)).scalar_one_or_none()
|
||||
if order is None:
|
||||
from app.core.exceptions import NotFoundException
|
||||
raise NotFoundException("订单不存在")
|
||||
|
||||
values = {}
|
||||
if "paid_amount" in body:
|
||||
paid = float(body["paid_amount"])
|
||||
values["paid_amount"] = paid
|
||||
total = float(order.total_amount)
|
||||
if paid >= total:
|
||||
values["payment_state"] = "cleared"
|
||||
elif paid > 0:
|
||||
values["payment_state"] = "partial"
|
||||
else:
|
||||
values["payment_state"] = "unpaid"
|
||||
if "payment_state" in body:
|
||||
values["payment_state"] = body["payment_state"]
|
||||
if values:
|
||||
values["updated_at"] = datetime.utcnow()
|
||||
await db.execute(
|
||||
sa_update(ErpOrder).where(ErpOrder.id == order_id).values(**values)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return ok(message="收款状态已更新")
|
||||
|
||||
|
||||
@router.get("/{order_id}/invoice-detail-preview", summary="生成开票明细预览")
|
||||
async def invoice_detail_preview(
|
||||
order_id: uuid.UUID,
|
||||
mode: str = Query("full", pattern=r"^(full|batch)$", description="full=整体开票, batch=按发货批次"),
|
||||
shipping_id: uuid.UUID | None = Query(None, description="batch模式下必传发货单ID"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
"""根据模式生成开票明细: 整体=订单全部商品, 批次=指定发货单商品"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from app.models.order import ErpOrder, ErpOrderItem
|
||||
from app.models.shipping import ErpShippingRecord, ErpShippingItem
|
||||
from app.models.crm import CrmCustomer
|
||||
from app.models.sys import SysCompany
|
||||
from app.core.exceptions import NotFoundException, BizException
|
||||
|
||||
# 查订单
|
||||
order = (await db.execute(
|
||||
select(ErpOrder)
|
||||
.where(ErpOrder.id == order_id, ErpOrder.company_id == company_id, ErpOrder.is_deleted.is_(False))
|
||||
.options(
|
||||
selectinload(ErpOrder.items),
|
||||
selectinload(ErpOrder.customer),
|
||||
selectinload(ErpOrder.salesperson),
|
||||
)
|
||||
)).scalar_one_or_none()
|
||||
if not order:
|
||||
raise NotFoundException("订单不存在")
|
||||
|
||||
# 买方名称
|
||||
buyer_name = order.customer.name if order.customer else ""
|
||||
# 卖方名称
|
||||
company = (await db.execute(
|
||||
select(SysCompany).where(SysCompany.id == company_id)
|
||||
)).scalar_one_or_none()
|
||||
seller_name = company.name if company else ""
|
||||
|
||||
items_data = []
|
||||
total_amount = 0.0
|
||||
|
||||
if mode == "full":
|
||||
# 整体开票: 聚合全部订单明细
|
||||
for oi in (order.items or []):
|
||||
sub = float(oi.sub_total or 0)
|
||||
items_data.append({
|
||||
"sku_code": oi.sku.sku_code if oi.sku else "",
|
||||
"sku_name": oi.sku.name if oi.sku else "",
|
||||
"spec": oi.sku.spec if oi.sku else "",
|
||||
"unit": oi.sku.unit if oi.sku else "",
|
||||
"qty": float(oi.qty),
|
||||
"unit_price": float(oi.unit_price),
|
||||
"sub_total": sub,
|
||||
})
|
||||
total_amount += sub
|
||||
else:
|
||||
# 按发货批次
|
||||
if not shipping_id:
|
||||
raise BizException(message="batch模式需指定shipping_id")
|
||||
ship = (await db.execute(
|
||||
select(ErpShippingRecord)
|
||||
.where(
|
||||
ErpShippingRecord.id == shipping_id,
|
||||
ErpShippingRecord.order_id == order_id,
|
||||
ErpShippingRecord.is_deleted.is_(False),
|
||||
)
|
||||
.options(selectinload(ErpShippingRecord.items).selectinload(ErpShippingItem.sku))
|
||||
)).scalar_one_or_none()
|
||||
if not ship:
|
||||
raise NotFoundException("发货单不存在")
|
||||
|
||||
# 查对应的订单明细来获取单价
|
||||
order_item_map = {str(oi.id): oi for oi in (order.items or [])}
|
||||
for si in (ship.items or []):
|
||||
oi = order_item_map.get(str(si.order_item_id))
|
||||
unit_price = float(oi.unit_price) if oi else 0
|
||||
qty = float(si.shipped_qty)
|
||||
sub = round(qty * unit_price, 2)
|
||||
items_data.append({
|
||||
"sku_code": si.sku.sku_code if si.sku else "",
|
||||
"sku_name": si.sku.name if si.sku else "",
|
||||
"spec": si.sku.spec if si.sku else "",
|
||||
"unit": si.sku.unit if si.sku else "",
|
||||
"qty": qty,
|
||||
"unit_price": unit_price,
|
||||
"sub_total": sub,
|
||||
})
|
||||
total_amount += sub
|
||||
|
||||
return ok(data={
|
||||
"order_no": order.order_no,
|
||||
"buyer_name": buyer_name,
|
||||
"seller_name": seller_name,
|
||||
"customer_id": str(order.customer_id),
|
||||
"items": items_data,
|
||||
"total_amount": round(total_amount, 2),
|
||||
"shipping_id": str(shipping_id) if shipping_id else None,
|
||||
})
|
||||
|
||||
|
||||
@router.post("/{order_id}/invoices/link", summary="关联已有发票到订单")
|
||||
async def link_existing_invoice(
|
||||
order_id: uuid.UUID,
|
||||
body: dict,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
"""将已存在的销项发票关联到该订单"""
|
||||
from sqlalchemy import select, update as sa_update
|
||||
from app.models.finance import FinSalesInvoice
|
||||
from app.core.exceptions import NotFoundException, BizException
|
||||
from datetime import datetime
|
||||
|
||||
invoice_id = body.get("invoice_id")
|
||||
shipping_record_id = body.get("shipping_record_id")
|
||||
if not invoice_id:
|
||||
raise BizException(message="请提供 invoice_id")
|
||||
|
||||
inv = (await db.execute(
|
||||
select(FinSalesInvoice).where(
|
||||
FinSalesInvoice.id == uuid.UUID(invoice_id),
|
||||
FinSalesInvoice.company_id == company_id,
|
||||
FinSalesInvoice.is_deleted.is_(False),
|
||||
)
|
||||
)).scalar_one_or_none()
|
||||
if not inv:
|
||||
raise NotFoundException("发票不存在")
|
||||
|
||||
values = {"order_id": order_id, "updated_at": datetime.utcnow()}
|
||||
if shipping_record_id:
|
||||
values["shipping_record_id"] = uuid.UUID(shipping_record_id)
|
||||
|
||||
await db.execute(
|
||||
sa_update(FinSalesInvoice)
|
||||
.where(FinSalesInvoice.id == uuid.UUID(invoice_id))
|
||||
.values(**values)
|
||||
)
|
||||
await db.commit()
|
||||
return ok(message="发票已关联到订单")
|
||||
|
||||
|
||||
@router.post("/{order_id}/invoices/create", summary="直接创建发票并关联到订单")
|
||||
async def create_and_link_invoice(
|
||||
order_id: uuid.UUID,
|
||||
body: dict,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
"""创建新的销项发票,同时关联到当前订单"""
|
||||
from sqlalchemy import select
|
||||
from app.models.finance import FinSalesInvoice
|
||||
from app.models.order import ErpOrder
|
||||
from app.core.exceptions import NotFoundException, BizException
|
||||
from datetime import date as dt_date
|
||||
|
||||
order = (await db.execute(
|
||||
select(ErpOrder).where(
|
||||
ErpOrder.id == order_id,
|
||||
ErpOrder.company_id == company_id,
|
||||
ErpOrder.is_deleted.is_(False),
|
||||
)
|
||||
)).scalar_one_or_none()
|
||||
if not order:
|
||||
raise NotFoundException("订单不存在")
|
||||
|
||||
invoice_number = body.get("invoice_number", "").strip()
|
||||
amount = float(body.get("amount", 0))
|
||||
issuer = body.get("issuer", "").strip()
|
||||
receiver_customer_id = body.get("receiver_customer_id") or str(order.customer_id)
|
||||
billing_date_str = body.get("billing_date")
|
||||
shipping_record_id = body.get("shipping_record_id")
|
||||
remark = body.get("remark")
|
||||
|
||||
if not invoice_number:
|
||||
raise BizException(message="请填写发票号")
|
||||
if amount <= 0:
|
||||
raise BizException(message="开票金额需大于0")
|
||||
if not issuer:
|
||||
raise BizException(message="请填写开票方名称")
|
||||
|
||||
# 检查唯一性
|
||||
from sqlalchemy import func as sa_func
|
||||
existing = (await db.execute(
|
||||
select(sa_func.count()).select_from(FinSalesInvoice).where(
|
||||
FinSalesInvoice.invoice_number == invoice_number,
|
||||
FinSalesInvoice.is_deleted.is_(False),
|
||||
)
|
||||
)).scalar()
|
||||
if existing:
|
||||
raise BizException(message=f"发票号 {invoice_number} 已存在")
|
||||
|
||||
inv = FinSalesInvoice(
|
||||
issuer=issuer,
|
||||
receiver_customer_id=uuid.UUID(receiver_customer_id),
|
||||
invoice_number=invoice_number,
|
||||
amount=amount,
|
||||
billing_date=dt_date.fromisoformat(billing_date_str) if billing_date_str else dt_date.today(),
|
||||
remark=remark,
|
||||
order_id=order_id,
|
||||
shipping_record_id=uuid.UUID(shipping_record_id) if shipping_record_id else None,
|
||||
created_by=current_user.user_id,
|
||||
company_id=company_id,
|
||||
)
|
||||
db.add(inv)
|
||||
await db.commit()
|
||||
return ok(data={"id": str(inv.id), "invoice_number": invoice_number}, message="发票创建并关联成功")
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.deps import get_current_user, get_current_company_id
|
||||
from app.db.database import get_db
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.erp import CategoryCreate, CategoryUpdate, InventoryFlowCreate, SkuCreate, SkuUpdate
|
||||
@@ -64,8 +64,9 @@ async def list_skus(
|
||||
keyword: str | None = Query(None, description="模糊搜索 SKU 编码或名称"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.list_skus(db, page, size, category_id, keyword)
|
||||
result = await svc.list_skus(db, company_id, page, size, category_id, keyword)
|
||||
return ok(data=result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@@ -95,8 +96,9 @@ async def create_inventory_flow(
|
||||
body: InventoryFlowCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.create_inventory_flow(db, current_user, body)
|
||||
result = await svc.create_inventory_flow(db, current_user, body, company_id)
|
||||
return ok(data=result.model_dump(mode="json"), message="库存变更成功")
|
||||
|
||||
|
||||
@@ -107,6 +109,7 @@ async def get_inventory_flows(
|
||||
size: int = Query(50, ge=1, le=200),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.get_inventory_flows(db, sku_id, page, size)
|
||||
result = await svc.get_inventory_flows(db, sku_id, company_id, page, size)
|
||||
return ok(data=result)
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
利润核算路由 —— /api/profit
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.deps import get_current_user, get_current_company_id
|
||||
from app.db.database import get_db
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.response import ok
|
||||
from app.services import profit_service as svc
|
||||
|
||||
router = APIRouter(prefix="/profit", tags=["利润核算"])
|
||||
|
||||
|
||||
@router.get("/report", summary="利润报表(订单维度)")
|
||||
async def profit_report(
|
||||
start_date: str | None = Query(None, description="起始日期 YYYY-MM-DD"),
|
||||
end_date: str | None = Query(None, description="结束日期 YYYY-MM-DD"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.get_profit_report(db, company_id, start_date, end_date)
|
||||
return ok(data=result)
|
||||
|
||||
|
||||
@router.post("/snapshot/{order_id}", summary="为订单锚定成本快照")
|
||||
async def snapshot_costs(
|
||||
order_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.snapshot_order_item_costs(db, order_id, company_id)
|
||||
return ok(data=result, message=f"已为 {len(result)} 项明细锚定成本")
|
||||
+16
-10
@@ -14,7 +14,7 @@ from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.deps import get_current_user, get_current_company_id
|
||||
from app.db.database import get_db
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.response import ok
|
||||
@@ -28,15 +28,16 @@ async def generate_report(
|
||||
end_date: date = Body(..., embed=True),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
authorization: str | None = Header(None),
|
||||
):
|
||||
"""
|
||||
1. 聚合该用户在时间范围内的 sales_logs 内容
|
||||
1. 聚合该用户在时间范围内、涉及当前公司的 sales_logs 内容
|
||||
2. 调用 Dify Workflow (streaming) 生成复盘报告
|
||||
3. SSE 流式返回给前端
|
||||
"""
|
||||
return StreamingResponse(
|
||||
_report_sse_generator(db, current_user, start_date, end_date, authorization or ""),
|
||||
_report_sse_generator(db, current_user, start_date, end_date, authorization or "", company_id),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
@@ -47,20 +48,25 @@ async def _report_sse_generator(
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
authorization: str = "",
|
||||
company_id: uuid.UUID | None = None,
|
||||
):
|
||||
import httpx
|
||||
from app.core.config import settings
|
||||
from app.models.ai import SalesLog
|
||||
|
||||
# 1. 聚合日志
|
||||
# 1. 聚合日志 — 仅提取涉及当前公司的日志
|
||||
conditions = [
|
||||
SalesLog.salesperson_id == user.user_id,
|
||||
SalesLog.log_date >= start_date,
|
||||
SalesLog.log_date <= end_date,
|
||||
SalesLog.is_deleted.is_(False),
|
||||
]
|
||||
if company_id:
|
||||
conditions.append(SalesLog.involved_company_ids.any(company_id))
|
||||
|
||||
stmt = (
|
||||
select(SalesLog)
|
||||
.where(
|
||||
SalesLog.salesperson_id == user.user_id,
|
||||
SalesLog.log_date >= start_date,
|
||||
SalesLog.log_date <= end_date,
|
||||
SalesLog.is_deleted.is_(False),
|
||||
)
|
||||
.where(*conditions)
|
||||
.order_by(SalesLog.log_date)
|
||||
)
|
||||
logs = (await db.execute(stmt)).scalars().all()
|
||||
|
||||
@@ -10,7 +10,7 @@ from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.deps import get_current_user, get_current_company_id
|
||||
from app.db.database import get_db
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.sales_invoice import SalesInvoiceCreate, SalesInvoiceUpdate
|
||||
@@ -26,8 +26,9 @@ async def create_invoice(
|
||||
body: SalesInvoiceCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.create_invoice(db, current_user, body)
|
||||
result = await svc.create_invoice(db, current_user, body, company_id)
|
||||
return ok(data=result.model_dump(mode="json"), message="销项发票创建成功")
|
||||
|
||||
|
||||
@@ -42,10 +43,11 @@ async def list_invoices(
|
||||
end_date: date | None = Query(None, description="开票结束日期"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.list_invoices(
|
||||
db, page, size, customer_name, invoice_number,
|
||||
payment_status, start_date, end_date,
|
||||
payment_status, start_date, end_date, company_id,
|
||||
)
|
||||
return ok(data=result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@@ -3,11 +3,13 @@
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
import asyncio
|
||||
from fastapi import APIRouter, Depends, Body
|
||||
from fastapi import APIRouter, Depends, Body, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.deps import get_current_user, get_current_company_id
|
||||
from app.db.database import get_db
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.response import ok
|
||||
@@ -26,6 +28,7 @@ async def list_logs(
|
||||
end_date: str | None = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
):
|
||||
result = await sales_log_service.list_logs(
|
||||
db, current_user,
|
||||
@@ -34,18 +37,56 @@ async def list_logs(
|
||||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
company_id=company_id,
|
||||
)
|
||||
return ok(data=result)
|
||||
|
||||
|
||||
async def _resolve_company_ids(
|
||||
db: AsyncSession,
|
||||
company_id: uuid.UUID,
|
||||
customer_id: str | None,
|
||||
company_ids: list[str] | None,
|
||||
) -> list[uuid.UUID]:
|
||||
"""
|
||||
智能解析 involved_company_ids:
|
||||
1. 如果前端显式传了 company_ids,使用它
|
||||
2. 否则以当前视角公司为基础
|
||||
3. 如果选了客户,自动查客户 owner 所属的公司,合并进来
|
||||
"""
|
||||
if company_ids:
|
||||
resolved = set(uuid.UUID(cid) for cid in company_ids)
|
||||
else:
|
||||
resolved = {company_id}
|
||||
|
||||
# 自动关联客户 owner 所在公司
|
||||
if customer_id:
|
||||
from app.models.crm import CrmCustomer
|
||||
from app.models.sys import SysUserCompany
|
||||
cust = await db.get(CrmCustomer, uuid.UUID(customer_id))
|
||||
if cust and cust.owner_id:
|
||||
stmt = select(SysUserCompany.company_id).where(
|
||||
SysUserCompany.user_id == cust.owner_id
|
||||
)
|
||||
rows = (await db.execute(stmt)).scalars().all()
|
||||
for cid in rows:
|
||||
resolved.add(cid)
|
||||
|
||||
# 确保当前公司始终在内
|
||||
resolved.add(company_id)
|
||||
return list(resolved)
|
||||
|
||||
|
||||
@router.post("", summary="创建销售日志")
|
||||
async def create_log(
|
||||
content: str = Body(..., embed=True),
|
||||
customer_id: str | None = Body(None, embed=True),
|
||||
contact_ids: list[str] | None = Body(None, embed=True),
|
||||
log_date: str | None = Body(None, embed=True),
|
||||
company_ids: list[str] | None = Body(None, embed=True),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
):
|
||||
from datetime import date as date_type
|
||||
|
||||
@@ -53,17 +94,20 @@ async def create_log(
|
||||
if log_date:
|
||||
parsed_date = date_type.fromisoformat(log_date)
|
||||
|
||||
# 智能解析公司关联
|
||||
resolved_company_ids = await _resolve_company_ids(db, company_id, customer_id, company_ids)
|
||||
|
||||
result = await sales_log_service.create_log(
|
||||
db, current_user,
|
||||
content=content,
|
||||
customer_id=customer_id,
|
||||
contact_ids=contact_ids,
|
||||
log_date=parsed_date,
|
||||
company_ids=resolved_company_ids,
|
||||
)
|
||||
|
||||
# 异步触发 Dify 画像提取工作流(仅当关联了客户时)
|
||||
if customer_id:
|
||||
import uuid
|
||||
asyncio.create_task(
|
||||
sales_log_service.trigger_persona_workflow(
|
||||
log_id=uuid.UUID(result["id"]),
|
||||
@@ -75,3 +119,35 @@ async def create_log(
|
||||
)
|
||||
|
||||
return ok(data=result, message="日志创建成功")
|
||||
|
||||
|
||||
@router.put("/{log_id}", summary="编辑销售日志")
|
||||
async def update_log(
|
||||
log_id: uuid.UUID,
|
||||
content: str | None = Body(None, embed=True),
|
||||
customer_id: str | None = Body(None, embed=True),
|
||||
contact_ids: list[str] | None = Body(None, embed=True),
|
||||
log_date: str | None = Body(None, embed=True),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
):
|
||||
result = await sales_log_service.update_log(
|
||||
db, current_user, log_id,
|
||||
content=content,
|
||||
customer_id=customer_id,
|
||||
contact_ids=contact_ids,
|
||||
log_date=log_date,
|
||||
company_id=company_id,
|
||||
)
|
||||
return ok(data=result, message="日志更新成功")
|
||||
|
||||
|
||||
@router.delete("/{log_id}", summary="删除销售日志(软删除)")
|
||||
async def delete_log(
|
||||
log_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
):
|
||||
await sales_log_service.delete_log(db, current_user, log_id)
|
||||
return ok(message="日志已删除")
|
||||
|
||||
@@ -6,7 +6,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.deps import get_current_user, get_current_company_id
|
||||
from app.db.database import get_db
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.shipping import ShippingCreate
|
||||
@@ -21,8 +21,9 @@ async def create_shipping(
|
||||
body: ShippingCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
resp, new_state = await svc.create_shipping(db, current_user, body)
|
||||
resp, new_state = await svc.create_shipping(db, current_user, body, company_id)
|
||||
return ok(data=resp.model_dump(mode="json"), message=f"发货单 {resp.shipping_no} 创建成功,订单状态已更新为 {new_state}")
|
||||
|
||||
|
||||
@@ -34,8 +35,9 @@ async def list_shipping(
|
||||
tracking_no: str | None = Query(None, description="按物流单号搜索"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.list_shipping(db, current_user, page, size, order_no, tracking_no)
|
||||
result = await svc.list_shipping(db, current_user, page, size, order_no, tracking_no, company_id)
|
||||
return ok(data=result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@@ -44,6 +46,7 @@ async def get_shipping_by_order(
|
||||
order_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.get_shipping_by_order(db, current_user, order_id)
|
||||
result = await svc.get_shipping_by_order(db, current_user, order_id, company_id)
|
||||
return ok(data=result)
|
||||
|
||||
Reference in New Issue
Block a user