v0.2.0: CRM/ERP 系统升级 - 清理 .gitignore 并移除误提交的 venv/env/db 文件
- 更新 .gitignore:全面覆盖环境变量、数据库、日志、缓存、上传文件 - 移除误跟踪的 server/venv/、crm_data.db、.env 文件 - 新增 server/.env.example 模板 - 新增合同管理、利润核算、AI教练等功能模块 - 新增 Playwright e2e 测试套件 - 前后端多项功能升级和 bug 修复
This commit is contained in:
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
AI 教练引擎路由 —— /api/ai-coaching
|
||||
Dify 回调 + SSE 通知流
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.deps import get_current_user
|
||||
from app.db.database import get_db
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.response import ok
|
||||
from app.services import ai_coaching_service as svc
|
||||
|
||||
router = APIRouter(prefix="/ai-coaching", tags=["AI教练引擎"])
|
||||
|
||||
|
||||
@router.post("/dify-callback/{sales_log_id}", summary="Dify Workflow 回调端点")
|
||||
async def dify_coaching_callback(
|
||||
sales_log_id: uuid.UUID,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> dict:
|
||||
"""接收 Dify Workflow 的异步回调,写回教练反馈"""
|
||||
import json
|
||||
body = await request.json()
|
||||
await svc.handle_dify_coaching_callback(db, sales_log_id, body)
|
||||
return ok(message="教练反馈已回写")
|
||||
|
||||
|
||||
@router.get("/notifications/stream", summary="SSE 通知流")
|
||||
async def sse_notifications(
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
):
|
||||
"""Server-Sent Events 推送通知(AI 教练反馈、系统通知等)"""
|
||||
return StreamingResponse(
|
||||
svc.sse_notification_generator(current_user.user_id),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
公司管理路由 —— /api/companies
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Body, Depends
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user, get_current_company_id
|
||||
from app.db.database import get_db
|
||||
from app.models.sys import SysCompany, SysUserCompany
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.response import ok
|
||||
|
||||
router = APIRouter(prefix="/companies", tags=["公司管理"])
|
||||
|
||||
|
||||
@router.get("", summary="获取当前用户可访问的公司列表")
|
||||
async def list_companies(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""返回当前登录用户所关联的所有激活公司"""
|
||||
stmt = (
|
||||
select(SysCompany)
|
||||
.join(SysUserCompany, SysUserCompany.company_id == SysCompany.id)
|
||||
.where(
|
||||
SysUserCompany.user_id == current_user.user_id,
|
||||
SysCompany.is_active.is_(True),
|
||||
)
|
||||
.order_by(SysCompany.created_at)
|
||||
)
|
||||
companies = (await db.execute(stmt)).scalars().all()
|
||||
|
||||
# 查该用户的默认公司
|
||||
default_stmt = (
|
||||
select(SysUserCompany.company_id)
|
||||
.where(
|
||||
SysUserCompany.user_id == current_user.user_id,
|
||||
SysUserCompany.is_default.is_(True),
|
||||
)
|
||||
)
|
||||
default_id = (await db.execute(default_stmt)).scalar_one_or_none()
|
||||
|
||||
return ok(data={
|
||||
"companies": [
|
||||
{
|
||||
"id": str(c.id),
|
||||
"name": c.name,
|
||||
"code": c.code,
|
||||
"is_active": c.is_active,
|
||||
}
|
||||
for c in companies
|
||||
],
|
||||
"default_company_id": str(default_id) if default_id else (
|
||||
str(companies[0].id) if companies else None
|
||||
),
|
||||
})
|
||||
|
||||
|
||||
@router.get("/current", summary="获取当前公司详情(含 full_info)")
|
||||
async def get_current_company(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
company = (await db.execute(
|
||||
select(SysCompany).where(SysCompany.id == company_id)
|
||||
)).scalar_one_or_none()
|
||||
if company is None:
|
||||
return ok(data=None)
|
||||
return ok(data={
|
||||
"id": str(company.id),
|
||||
"name": company.name,
|
||||
"code": company.code,
|
||||
"full_info": company.full_info or {},
|
||||
"is_active": company.is_active,
|
||||
})
|
||||
|
||||
|
||||
@router.put("/current", summary="更新当前公司信息(含 full_info)")
|
||||
async def update_current_company(
|
||||
body: dict = Body(...),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
# 仅管理员可编辑
|
||||
if current_user.data_scope != "all":
|
||||
from app.core.exceptions import ForbiddenException
|
||||
raise ForbiddenException("仅管理员可编辑公司信息")
|
||||
|
||||
values: dict = {}
|
||||
if "name" in body:
|
||||
values["name"] = body["name"]
|
||||
if "full_info" in body:
|
||||
values["full_info"] = body["full_info"]
|
||||
if values:
|
||||
values["updated_at"] = datetime.utcnow()
|
||||
await db.execute(
|
||||
update(SysCompany).where(SysCompany.id == company_id).values(**values)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
# 返回更新后的数据
|
||||
company = (await db.execute(
|
||||
select(SysCompany).where(SysCompany.id == company_id)
|
||||
)).scalar_one()
|
||||
return ok(data={
|
||||
"id": str(company.id),
|
||||
"name": company.name,
|
||||
"code": company.code,
|
||||
"full_info": company.full_info or {},
|
||||
"is_active": company.is_active,
|
||||
}, message="公司信息已更新")
|
||||
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
合同管理路由 —— /api/contracts
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import uuid
|
||||
from fastapi import APIRouter, Body, Depends, Query, UploadFile, File
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.deps import get_current_user, get_current_company_id
|
||||
from app.db.database import get_db
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.contract import ContractCreate, ContractUpdate
|
||||
from app.schemas.response import ok
|
||||
from app.services import contract_service as svc
|
||||
|
||||
router = APIRouter(prefix="/contracts", tags=["合同管理"])
|
||||
|
||||
|
||||
@router.post("", summary="新增合同")
|
||||
async def create_contract(
|
||||
body: ContractCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.create_contract(db, current_user, company_id, body)
|
||||
return ok(data=result.model_dump(mode="json"), message="合同创建成功")
|
||||
|
||||
|
||||
@router.get("", summary="合同列表(分页)")
|
||||
async def list_contracts(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
keyword: str | None = Query(None, description="合同编号搜索"),
|
||||
status: str | None = Query(None, description="状态筛选"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.list_contracts(db, company_id, page, size, keyword, status)
|
||||
return ok(data=result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@router.get("/{contract_id}", summary="合同详情(含执行进度)")
|
||||
async def get_contract(
|
||||
contract_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.get_contract(db, contract_id, company_id)
|
||||
return ok(data=result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@router.put("/{contract_id}", summary="编辑合同")
|
||||
async def update_contract(
|
||||
contract_id: uuid.UUID,
|
||||
body: ContractUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.update_contract(db, contract_id, company_id, body)
|
||||
return ok(data=result.model_dump(mode="json"), message="合同已更新")
|
||||
|
||||
|
||||
@router.delete("/{contract_id}", summary="删除合同")
|
||||
async def delete_contract(
|
||||
contract_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
await svc.delete_contract(db, contract_id, company_id)
|
||||
return ok(message="合同已删除")
|
||||
|
||||
|
||||
@router.post("/{contract_id}/generate-order", summary="一键从合同生成订单")
|
||||
async def generate_order_from_contract(
|
||||
contract_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.generate_order_from_contract(db, current_user, contract_id, company_id)
|
||||
return ok(data=result, message="订单生成成功")
|
||||
|
||||
|
||||
@router.get("/{contract_id}/generate", summary="生成合同 Word 文档下载")
|
||||
async def generate_contract_document(
|
||||
contract_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
):
|
||||
from fastapi.responses import Response
|
||||
docx_bytes = await svc.generate_contract_docx(db, contract_id, company_id)
|
||||
return Response(
|
||||
content=docx_bytes,
|
||||
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
headers={"Content-Disposition": f"attachment; filename=contract_{contract_id}.docx"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{contract_id}/upload-signed", summary="上传双签盖章版")
|
||||
async def upload_signed_copy(
|
||||
contract_id: uuid.UUID,
|
||||
file: UploadFile = File(...),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
import os
|
||||
from app.models.contract import ErpContract, ErpContractAttachment
|
||||
from sqlalchemy import update as sa_update
|
||||
|
||||
# 验证合同存在
|
||||
from sqlalchemy import select as sa_select
|
||||
contract = (await db.execute(
|
||||
sa_select(ErpContract).where(
|
||||
ErpContract.id == contract_id,
|
||||
ErpContract.company_id == company_id,
|
||||
ErpContract.is_deleted.is_(False),
|
||||
)
|
||||
)).scalar_one_or_none()
|
||||
if contract is None:
|
||||
raise Exception("合同不存在")
|
||||
|
||||
# 保存文件
|
||||
upload_dir = f"uploads/contracts/{contract_id}"
|
||||
os.makedirs(upload_dir, exist_ok=True)
|
||||
file_path = f"{upload_dir}/{file.filename}"
|
||||
with open(file_path, "wb") as f:
|
||||
content = await file.read()
|
||||
f.write(content)
|
||||
|
||||
file_url = f"/{file_path}"
|
||||
|
||||
# 记录附件
|
||||
attachment = ErpContractAttachment(
|
||||
contract_id=contract_id,
|
||||
file_name=file.filename or "signed_copy",
|
||||
file_url=file_url,
|
||||
file_type="signed_copy",
|
||||
uploader_id=current_user.user_id,
|
||||
)
|
||||
db.add(attachment)
|
||||
|
||||
# 更新合同签署状态
|
||||
await db.execute(
|
||||
sa_update(ErpContract)
|
||||
.where(ErpContract.id == contract_id)
|
||||
.values(is_signed=True, signed_file_url=file_url)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return ok(message="双签盖章版上传成功", data={"file_url": file_url})
|
||||
@@ -4,7 +4,7 @@ CRM 客户模块路由 —— /api/customers
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from fastapi import APIRouter, Body, Depends, Query, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.deps import get_current_user
|
||||
from app.db.database import get_db
|
||||
@@ -91,6 +91,20 @@ async def restore_customer(
|
||||
return ok(message="客户已恢复")
|
||||
|
||||
|
||||
@router.put("/{customer_id}/transfer", summary="转移客户负责人(仅管理员)")
|
||||
async def transfer_customer(
|
||||
customer_id: uuid.UUID,
|
||||
body: dict = Body(..., examples=[{"new_owner_id": "uuid-here"}]),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
) -> dict:
|
||||
new_owner_id = body.get("new_owner_id")
|
||||
if not new_owner_id:
|
||||
raise Exception("缺少 new_owner_id 参数")
|
||||
result = await svc.transfer_customer(db, current_user, customer_id, uuid.UUID(str(new_owner_id)))
|
||||
return ok(data=result.model_dump(mode="json"), message="客户转移成功")
|
||||
|
||||
|
||||
@router.get("/{customer_id}/products", summary="获取客户关联产品(通过订单反查)")
|
||||
async def get_customer_products(
|
||||
customer_id: uuid.UUID,
|
||||
|
||||
+15
-10
@@ -3,20 +3,21 @@ Dashboard 统计 API — /api/dashboard
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import date, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import func, select, and_, extract
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.deps import get_current_user, get_current_company_id
|
||||
from app.db.database import get_db
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.response import ok
|
||||
|
||||
from app.models.order import ErpOrder
|
||||
from app.models.shipping import ErpShippingRecord
|
||||
from app.models.erp import ProductSku
|
||||
from app.models.erp import ErpSkuInventory
|
||||
|
||||
router = APIRouter(prefix="/dashboard", tags=["Dashboard"])
|
||||
|
||||
@@ -25,42 +26,46 @@ router = APIRouter(prefix="/dashboard", tags=["Dashboard"])
|
||||
async def get_stats(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
):
|
||||
today = date.today()
|
||||
month_start = today.replace(day=1)
|
||||
|
||||
# 本月新增订单数
|
||||
# 本月新增订单数(按公司隔离)
|
||||
orders_count_q = select(func.count()).select_from(ErpOrder).where(
|
||||
and_(
|
||||
ErpOrder.is_deleted.is_(False),
|
||||
ErpOrder.company_id == company_id,
|
||||
ErpOrder.order_date >= month_start,
|
||||
)
|
||||
)
|
||||
orders_count = (await db.execute(orders_count_q)).scalar() or 0
|
||||
|
||||
# 待出库发货数(状态为 pending)
|
||||
# 待出库发货数(按公司隔离)
|
||||
pending_shipping_q = select(func.count()).select_from(ErpOrder).where(
|
||||
and_(
|
||||
ErpOrder.is_deleted.is_(False),
|
||||
ErpOrder.company_id == company_id,
|
||||
ErpOrder.shipping_state == "pending",
|
||||
)
|
||||
)
|
||||
pending_shipping = (await db.execute(pending_shipping_q)).scalar() or 0
|
||||
|
||||
# 库存预警 SKU 数(stock_qty <= warning_threshold 且 warning_threshold > 0)
|
||||
warning_skus_q = select(func.count()).select_from(ProductSku).where(
|
||||
# 库存预警 SKU 数(从 erp_sku_inventory 查,按公司隔离)
|
||||
warning_skus_q = select(func.count()).select_from(ErpSkuInventory).where(
|
||||
and_(
|
||||
ProductSku.is_deleted.is_(False),
|
||||
ProductSku.warning_threshold > 0,
|
||||
ProductSku.stock_qty <= ProductSku.warning_threshold,
|
||||
ErpSkuInventory.company_id == company_id,
|
||||
ErpSkuInventory.warning_threshold > 0,
|
||||
ErpSkuInventory.stock_qty <= ErpSkuInventory.warning_threshold,
|
||||
)
|
||||
)
|
||||
warning_skus = (await db.execute(warning_skus_q)).scalar() or 0
|
||||
|
||||
# 本月预计营收(本月订单总金额)
|
||||
# 本月预计营收(按公司隔离)
|
||||
revenue_q = select(func.coalesce(func.sum(ErpOrder.total_amount), 0)).where(
|
||||
and_(
|
||||
ErpOrder.is_deleted.is_(False),
|
||||
ErpOrder.company_id == company_id,
|
||||
ErpOrder.order_date >= month_start,
|
||||
)
|
||||
)
|
||||
|
||||
+47
-2
@@ -1,18 +1,21 @@
|
||||
"""
|
||||
FastAPI 依赖注入 —— 权限拦截核心
|
||||
get_current_user: 解析 JWT → 查表获取完整权限上下文
|
||||
get_current_company_id: 从 X-Company-Id Header 提取公司 ID + IDOR 校验
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import Depends, Header
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import UnauthorizedException
|
||||
from app.core.exceptions import ForbiddenException, UnauthorizedException
|
||||
from app.core.security import decode_access_token
|
||||
from app.db.database import get_db
|
||||
from app.models.sys import SysUser
|
||||
from app.models.sys import SysCompany, SysUser, SysUserCompany
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
|
||||
|
||||
@@ -65,3 +68,45 @@ async def get_current_user(
|
||||
data_scope=user.role.data_scope if user.role else "self",
|
||||
menu_keys=user.role.menu_keys if user.role else [],
|
||||
)
|
||||
|
||||
|
||||
async def get_current_company_id(
|
||||
x_company_id: str = Header(..., alias="X-Company-Id", description="当前工作台的公司 ID"),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> uuid.UUID:
|
||||
"""
|
||||
公司视角依赖(IDOR 防护核心):
|
||||
1. 从 X-Company-Id Header 提取公司 UUID
|
||||
2. 校验当前用户是否归属于该公司(查 sys_user_companies)
|
||||
3. 校验公司是否启用
|
||||
"""
|
||||
# ── 解析 company_id ──
|
||||
try:
|
||||
company_uuid = uuid.UUID(x_company_id)
|
||||
except ValueError:
|
||||
raise UnauthorizedException("X-Company-Id 格式错误,需为合法 UUID")
|
||||
|
||||
# ── IDOR 防护:校验用户-公司归属 ──
|
||||
assoc = (await db.execute(
|
||||
select(SysUserCompany).where(
|
||||
SysUserCompany.user_id == current_user.user_id,
|
||||
SysUserCompany.company_id == company_uuid,
|
||||
)
|
||||
)).scalar_one_or_none()
|
||||
|
||||
if assoc is None:
|
||||
raise ForbiddenException("您无权访问该公司数据")
|
||||
|
||||
# ── 校验公司是否启用 ──
|
||||
company = (await db.execute(
|
||||
select(SysCompany).where(
|
||||
SysCompany.id == company_uuid,
|
||||
SysCompany.is_active.is_(True),
|
||||
)
|
||||
)).scalar_one_or_none()
|
||||
|
||||
if company is None:
|
||||
raise ForbiddenException("公司不存在或已停用")
|
||||
|
||||
return company_uuid
|
||||
|
||||
+407
-21
@@ -9,7 +9,7 @@ import time
|
||||
import base64
|
||||
from fastapi import APIRouter, Depends, Query, Body, File, UploadFile, Form
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.deps import get_current_user, get_current_company_id
|
||||
from app.db.database import get_db
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.finance import ExpenseCreate, ExpenseStatusUpdate, InvoiceCreate
|
||||
@@ -43,34 +43,96 @@ async def ocr_recognize(
|
||||
|
||||
file_url = f"/uploads/finance/{safe_filename}"
|
||||
|
||||
# 仅支持图片(png/jpg/jpeg)和 PDF,不再支持 MD/TXT
|
||||
supported = {".png", ".jpg", ".jpeg", ".pdf"}
|
||||
# 支持的格式:结构化零算力 > 文本 LLM > 图片 Vision
|
||||
supported = {".png", ".jpg", ".jpeg", ".pdf", ".md", ".ofd", ".xml", ".zip"}
|
||||
if ext not in supported:
|
||||
raise BizException(message=f"不支持的文件格式 {ext},仅支持: {', '.join(supported)}")
|
||||
|
||||
# 如果是 PDF,转成 PNG 再做 OCR
|
||||
ocr_bytes = file_bytes
|
||||
raise BizException(message=f"不支持的文件格式 {ext},仅支持: {', '.join(sorted(supported))}")
|
||||
|
||||
# ── 策略 A0: ZIP → 解包所有 XML 并逐个解析 ──
|
||||
if ext == ".zip":
|
||||
from app.services.invoice_parser import parse_zip_invoices
|
||||
results = parse_zip_invoices(file_bytes)
|
||||
return ok(data={"zip_results": [
|
||||
{"filename": r.get("filename", ""), "success": r.get("success", False),
|
||||
"ocr_data": r.get("data", {}), "needs_llm": r.get("needs_llm", False),
|
||||
"error": r.get("error")}
|
||||
for r in results
|
||||
], "file_url": file_url}, message=f"ZIP 解析完成:{sum(1 for r in results if r.get('success'))}/{len(results)} 成功")
|
||||
|
||||
# ── 策略 A: OFD / XML → 结构化零算力提取(最快最准)──
|
||||
if ext in (".ofd", ".xml"):
|
||||
from app.services.invoice_parser import parse_ofd_invoice, parse_xml_invoice
|
||||
parser = parse_ofd_invoice if ext == ".ofd" else parse_xml_invoice
|
||||
result = parser(file_bytes)
|
||||
print(f"[OCR] {ext.upper()} 解析: success={result.get('success')}")
|
||||
|
||||
if result.get("success"):
|
||||
# 如果解析器提取到 raw_text 且标记 needs_llm,交给 LLM 做字段提取
|
||||
if result.get("needs_llm") and result["data"].get("raw_text"):
|
||||
from app.services.ocr_service import extract_invoice_from_text
|
||||
llm_result = await extract_invoice_from_text(result["data"]["raw_text"], scene)
|
||||
if llm_result.get("success"):
|
||||
return ok(data={"ocr_data": llm_result["data"], "file_url": file_url}, message=f"AI 发票识别成功({ext.upper()} → LLM)")
|
||||
return ok(data={"ocr_data": llm_result.get("data", {}), "file_url": file_url}, message=llm_result.get("error", "LLM 解析失败"))
|
||||
return ok(data={"ocr_data": result["data"], "file_url": file_url}, message=f"发票识别成功({ext.upper()} 结构化提取)")
|
||||
return ok(data={"ocr_data": {}, "file_url": file_url}, message=result.get("error", f"{ext.upper()} 解析失败"))
|
||||
|
||||
# ── 策略 B: MD → 纯文本 LLM 理解(零 GPU Vision)──
|
||||
if ext == ".md":
|
||||
text = file_bytes.decode("utf-8", errors="replace").strip()
|
||||
print(f"[OCR] MD 文本: {len(text)} 字符")
|
||||
if len(text) < 20:
|
||||
return ok(data={"ocr_data": {}, "file_url": file_url}, message="MD 文件内容过少,无法识别")
|
||||
from app.services.ocr_service import extract_invoice_from_text
|
||||
result = await extract_invoice_from_text(text, scene)
|
||||
if result.get("success"):
|
||||
return ok(data={"ocr_data": result["data"], "file_url": file_url}, message="AI 发票识别成功(MD 文本解析)")
|
||||
return ok(data={"ocr_data": result.get("data", {}), "file_url": file_url}, message=result.get("error", "MD 文本解析失败"))
|
||||
|
||||
# ── 策略 C: PDF → PyMuPDF 提取文本 → LLM(零 GPU Vision)──
|
||||
if ext == ".pdf":
|
||||
try:
|
||||
import fitz # PyMuPDF
|
||||
doc = fitz.open(stream=file_bytes, filetype="pdf")
|
||||
page = doc[0] # 取第一页
|
||||
# 中等分辨率渲染(150 DPI,平衡质量与大小)
|
||||
pix = page.get_pixmap(dpi=150)
|
||||
ocr_bytes = pix.tobytes("png")
|
||||
text = ""
|
||||
for page in doc:
|
||||
text += page.get_text() + "\n"
|
||||
doc.close()
|
||||
print(f"[OCR] PDF 转 PNG 成功: {len(ocr_bytes)} bytes")
|
||||
text = text.strip()
|
||||
print(f"[OCR] PDF 文本提取: {len(text)} 字符")
|
||||
|
||||
if len(text) > 50: # 有足够文本内容
|
||||
from app.services.ocr_service import extract_invoice_from_text
|
||||
result = await extract_invoice_from_text(text, scene)
|
||||
if result.get("success"):
|
||||
return ok(data={"ocr_data": result["data"], "file_url": file_url}, message="AI 发票识别成功(PDF 文本解析)")
|
||||
return ok(data={"ocr_data": result.get("data", {}), "file_url": file_url}, message=result.get("error", "PDF 文本提取失败"))
|
||||
else:
|
||||
# PDF 是扫描件(无文字层),降级到图片 OCR
|
||||
print(f"[OCR] PDF 无文本层(仅 {len(text)} 字符),降级到图片 OCR")
|
||||
page = fitz.open(stream=file_bytes, filetype="pdf")[0]
|
||||
pix = page.get_pixmap(dpi=150)
|
||||
ocr_bytes = pix.tobytes("png")
|
||||
print(f"[OCR] PDF 转 PNG 成功: {len(ocr_bytes)} bytes")
|
||||
except Exception as e:
|
||||
print(f"[OCR] PDF 转换失败: {e}")
|
||||
return ok(data={"ocr_data": {}, "file_url": file_url}, message=f"PDF 转换失败: {e}")
|
||||
|
||||
# 转换为纯 base64 传给 OCR
|
||||
print(f"[OCR] PDF 处理失败: {e}")
|
||||
return ok(data={"ocr_data": {}, "file_url": file_url}, message=f"PDF 处理失败: {e}")
|
||||
else:
|
||||
ocr_bytes = file_bytes
|
||||
|
||||
# ── 策略 D: 图片/扫描PDF → Vision OCR(需要视觉模型)──
|
||||
from app.services.ocr_service import ocr_image
|
||||
image_base64 = base64.b64encode(ocr_bytes).decode("utf-8")
|
||||
result = await ocr_image(image_base64, scene)
|
||||
|
||||
if result.get("success"):
|
||||
return ok(data={"ocr_data": result["data"], "file_url": file_url}, message="AI OCR 识别成功")
|
||||
return ok(data={"ocr_data": result.get("data", {}), "file_url": file_url}, message=result.get("error", "OCR 识别失败"))
|
||||
|
||||
# Vision 失败时友好提示
|
||||
error_msg = result.get("error", "OCR 识别失败")
|
||||
if "模型进程崩溃" in error_msg or "unexpectedly stopped" in error_msg or "服务异常" in error_msg:
|
||||
error_msg += "。建议:请上传电子版 PDF/OFD/XML 发票,系统可零算力直接提取数据"
|
||||
return ok(data={"ocr_data": {}, "file_url": file_url}, message=error_msg)
|
||||
|
||||
|
||||
@router.post("/invoices", summary="上传票据入池(含 AI/OCR JSONB 数据)")
|
||||
@@ -78,8 +140,9 @@ async def create_invoice(
|
||||
body: InvoiceCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.create_invoice(db, current_user, body)
|
||||
result = await svc.create_invoice(db, current_user, body, company_id)
|
||||
return ok(data=result.model_dump(mode="json"), message="票据入池成功")
|
||||
|
||||
|
||||
@@ -91,8 +154,9 @@ async def list_invoices(
|
||||
is_used: bool | None = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.list_invoices(db, current_user, page, size, type, is_used)
|
||||
result = await svc.list_invoices(db, current_user, page, size, type, is_used, company_id)
|
||||
return ok(data=result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@@ -111,8 +175,9 @@ async def create_expense(
|
||||
body: ExpenseCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.create_expense(db, current_user, body)
|
||||
result = await svc.create_expense(db, current_user, body, company_id)
|
||||
return ok(data=result.model_dump(mode="json"), message=f"报销单 {result.system_no} 提交成功")
|
||||
|
||||
|
||||
@@ -124,8 +189,9 @@ async def list_expenses(
|
||||
applicant_id: uuid.UUID | None = Query(None, description="按申请人过滤(管理员用)"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.list_expenses(db, current_user, page, size, status, applicant_id)
|
||||
result = await svc.list_expenses(db, current_user, page, size, status, applicant_id, company_id)
|
||||
return ok(data=result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@@ -148,3 +214,323 @@ async def update_expense_status(
|
||||
) -> dict:
|
||||
msg = await svc.update_expense_status(db, current_user, expense_id, body)
|
||||
return ok(message=msg)
|
||||
|
||||
|
||||
# ════════════════════════════════════════════════════════════
|
||||
# 批量上传 + OCR 任务队列 API
|
||||
# ════════════════════════════════════════════════════════════
|
||||
|
||||
@router.post("/upload-batch", summary="批量上传发票(ZIP/XML 即时入池,图片PDF 入队列)")
|
||||
async def upload_batch(
|
||||
files: list[UploadFile] = File(...),
|
||||
scene: str = Form("invoice"),
|
||||
inv_type: str = Form("expense"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
from app.services.invoice_parser import parse_xml_invoice, parse_ofd_invoice, parse_zip_invoices
|
||||
from app.services.ocr_service import extract_invoice_from_text
|
||||
from app.models.finance import FinInvoicePool, FinOcrTask
|
||||
|
||||
upload_dir = "uploads/finance"
|
||||
os.makedirs(upload_dir, exist_ok=True)
|
||||
|
||||
results = [] # 返回给前端
|
||||
|
||||
for file in files:
|
||||
file_bytes = await file.read()
|
||||
ext = os.path.splitext(file.filename or "")[1].lower() or ".bin"
|
||||
ts = int(time.time())
|
||||
safe_fn = f"{ts}_{uuid.uuid4().hex[:8]}{ext}"
|
||||
file_path = os.path.join(upload_dir, safe_fn)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_bytes)
|
||||
file_url = f"/uploads/finance/{safe_fn}"
|
||||
|
||||
# ── ZIP: 解压内部 XML,逐个即时入池 ──
|
||||
if ext == ".zip":
|
||||
zip_results = parse_zip_invoices(file_bytes)
|
||||
for zr in zip_results:
|
||||
if zr.get("success") and not zr.get("needs_llm"):
|
||||
ai_data = zr.get("data", {})
|
||||
# 需要 LLM 的 zip 中的 xml 也立刻处理
|
||||
merchant = ai_data.get("merchant") or ai_data.get("merchant_name") or "(ZIP)"
|
||||
amount = float(ai_data.get("amount", 0) or 0)
|
||||
inv_date_str = ai_data.get("date")
|
||||
inv_date = None
|
||||
if inv_date_str:
|
||||
try:
|
||||
from datetime import date as d
|
||||
inv_date = d.fromisoformat(inv_date_str)
|
||||
except ValueError:
|
||||
pass
|
||||
inv = FinInvoicePool(
|
||||
uploader_id=current_user.user_id, company_id=company_id,
|
||||
file_url=file_url, merchant_name=merchant, amount=amount,
|
||||
invoice_date=inv_date, type=inv_type, ai_extracted_data=ai_data,
|
||||
)
|
||||
db.add(inv)
|
||||
results.append({"filename": zr.get("filename", file.filename), "action": "pooled",
|
||||
"status": "success", "message": f"✅ {merchant} ¥{amount}"})
|
||||
elif zr.get("needs_llm") and zr.get("data", {}).get("raw_text"):
|
||||
# LLM 文本理解(即时,<5s)
|
||||
try:
|
||||
llm_r = await extract_invoice_from_text(zr["data"]["raw_text"], scene)
|
||||
if llm_r.get("success"):
|
||||
ai_data = llm_r["data"]
|
||||
merchant = ai_data.get("merchant") or "(LLM)"
|
||||
amount = float(ai_data.get("amount", 0) or 0)
|
||||
inv = FinInvoicePool(
|
||||
uploader_id=current_user.user_id, company_id=company_id,
|
||||
file_url=file_url, merchant_name=merchant, amount=amount,
|
||||
type=inv_type, ai_extracted_data=ai_data,
|
||||
)
|
||||
db.add(inv)
|
||||
results.append({"filename": zr.get("filename"), "action": "pooled",
|
||||
"status": "success", "message": f"✅ {merchant} ¥{amount} (LLM)"})
|
||||
else:
|
||||
results.append({"filename": zr.get("filename"), "action": "failed",
|
||||
"status": "error", "message": llm_r.get("error", "LLM 解析失败")})
|
||||
except Exception as e:
|
||||
results.append({"filename": zr.get("filename"), "action": "failed",
|
||||
"status": "error", "message": str(e)})
|
||||
else:
|
||||
results.append({"filename": zr.get("filename", file.filename), "action": "failed",
|
||||
"status": "error", "message": zr.get("error", "解析失败")})
|
||||
continue
|
||||
|
||||
# ── XML / OFD: 零算力即时入池 ──
|
||||
if ext in (".xml", ".ofd"):
|
||||
parser = parse_xml_invoice if ext == ".xml" else parse_ofd_invoice
|
||||
r = parser(file_bytes)
|
||||
if r.get("success") and not r.get("needs_llm"):
|
||||
ai_data = r.get("data", {})
|
||||
merchant = ai_data.get("merchant") or ai_data.get("merchant_name") or "(解析)"
|
||||
amount = float(ai_data.get("amount", 0) or 0)
|
||||
inv_date_str = ai_data.get("date")
|
||||
inv_date = None
|
||||
if inv_date_str:
|
||||
try:
|
||||
from datetime import date as d
|
||||
inv_date = d.fromisoformat(inv_date_str)
|
||||
except ValueError:
|
||||
pass
|
||||
inv = FinInvoicePool(
|
||||
uploader_id=current_user.user_id, company_id=company_id,
|
||||
file_url=file_url, merchant_name=merchant, amount=amount,
|
||||
invoice_date=inv_date, type=inv_type, ai_extracted_data=ai_data,
|
||||
)
|
||||
db.add(inv)
|
||||
results.append({"filename": file.filename, "action": "pooled",
|
||||
"status": "success", "message": f"✅ {merchant} ¥{amount}"})
|
||||
elif r.get("needs_llm") and r.get("data", {}).get("raw_text"):
|
||||
try:
|
||||
llm_r = await extract_invoice_from_text(r["data"]["raw_text"], scene)
|
||||
if llm_r.get("success"):
|
||||
ai_data = llm_r["data"]
|
||||
merchant = ai_data.get("merchant") or "(LLM)"
|
||||
amount = float(ai_data.get("amount", 0) or 0)
|
||||
inv = FinInvoicePool(
|
||||
uploader_id=current_user.user_id, company_id=company_id,
|
||||
file_url=file_url, merchant_name=merchant, amount=amount,
|
||||
type=inv_type, ai_extracted_data=ai_data,
|
||||
)
|
||||
db.add(inv)
|
||||
results.append({"filename": file.filename, "action": "pooled",
|
||||
"status": "success", "message": f"✅ {merchant} ¥{amount} (LLM)"})
|
||||
else:
|
||||
results.append({"filename": file.filename, "action": "failed",
|
||||
"status": "error", "message": llm_r.get("error", "LLM 失败")})
|
||||
except Exception as e:
|
||||
results.append({"filename": file.filename, "action": "failed",
|
||||
"status": "error", "message": str(e)})
|
||||
else:
|
||||
results.append({"filename": file.filename, "action": "failed",
|
||||
"status": "error", "message": r.get("error", "解析失败")})
|
||||
continue
|
||||
|
||||
# ── 图片 / PDF : 写入 DB 任务队列 ──
|
||||
task = FinOcrTask(
|
||||
file_url=file_url, file_ext=ext,
|
||||
original_name=file.filename or "unknown",
|
||||
uploader_id=current_user.user_id,
|
||||
company_id=company_id,
|
||||
inv_type=inv_type,
|
||||
priority=50 if ext == ".pdf" else 100, # PDF 优先(可能有文字层)
|
||||
)
|
||||
db.add(task)
|
||||
await db.flush()
|
||||
results.append({"filename": file.filename, "action": "queued",
|
||||
"status": "pending", "task_id": str(task.id),
|
||||
"message": "🕐 已加入 OCR 处理队列"})
|
||||
|
||||
await db.commit()
|
||||
|
||||
pooled = sum(1 for r in results if r["action"] == "pooled")
|
||||
queued = sum(1 for r in results if r["action"] == "queued")
|
||||
failed = sum(1 for r in results if r["action"] == "failed")
|
||||
return ok(data={"results": results},
|
||||
message=f"批量处理完成:{pooled} 即时入池,{queued} 排队中,{failed} 失败")
|
||||
|
||||
|
||||
@router.get("/ocr-tasks", summary="OCR 任务队列列表")
|
||||
async def list_ocr_tasks(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
status: str | None = Query(None, description="pending/processing/success/failed/manual"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
from sqlalchemy import func, select
|
||||
from app.models.finance import FinOcrTask
|
||||
|
||||
where = [FinOcrTask.company_id == company_id, FinOcrTask.is_deleted.is_(False)]
|
||||
if current_user.data_scope == "self":
|
||||
where.append(FinOcrTask.uploader_id == current_user.user_id)
|
||||
if status:
|
||||
where.append(FinOcrTask.status == status)
|
||||
|
||||
total = (await db.execute(select(func.count()).select_from(FinOcrTask).where(*where))).scalar() or 0
|
||||
stmt = (
|
||||
select(FinOcrTask).where(*where)
|
||||
.order_by(FinOcrTask.priority, FinOcrTask.created_at.desc())
|
||||
.offset((page - 1) * size).limit(size)
|
||||
)
|
||||
tasks = (await db.execute(stmt)).scalars().all()
|
||||
|
||||
return ok(data={
|
||||
"total": total, "page": page, "size": size,
|
||||
"items": [{
|
||||
"id": str(t.id),
|
||||
"original_name": t.original_name,
|
||||
"file_ext": t.file_ext,
|
||||
"file_url": t.file_url,
|
||||
"status": t.status,
|
||||
"priority": t.priority,
|
||||
"retry_count": t.retry_count,
|
||||
"max_retries": t.max_retries,
|
||||
"error_message": t.error_message,
|
||||
"ocr_result": t.ocr_result,
|
||||
"invoice_pool_id": str(t.invoice_pool_id) if t.invoice_pool_id else None,
|
||||
"uploader_name": t.uploader.real_name if t.uploader else None,
|
||||
"inv_type": t.inv_type,
|
||||
"created_at": str(t.created_at),
|
||||
"updated_at": str(t.updated_at),
|
||||
} for t in tasks],
|
||||
})
|
||||
|
||||
|
||||
@router.post("/ocr-tasks/{task_id}/retry", summary="重试失败的 OCR 任务")
|
||||
async def retry_ocr_task(
|
||||
task_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
) -> dict:
|
||||
from sqlalchemy import select, update
|
||||
from app.models.finance import FinOcrTask
|
||||
|
||||
task = (await db.execute(
|
||||
select(FinOcrTask).where(FinOcrTask.id == task_id, FinOcrTask.is_deleted.is_(False))
|
||||
)).scalar_one_or_none()
|
||||
if not task:
|
||||
raise BizException(message="任务不存在")
|
||||
if task.status not in ("failed", "manual"):
|
||||
raise BizException(message=f"当前状态 [{task.status}] 不允许重试")
|
||||
|
||||
task.status = "pending"
|
||||
task.retry_count = 0
|
||||
task.error_message = None
|
||||
await db.commit()
|
||||
return ok(message="任务已重新入队")
|
||||
|
||||
|
||||
@router.post("/ocr-tasks/{task_id}/manual", summary="手动录入 OCR 结果并入池")
|
||||
async def manual_ocr_task(
|
||||
task_id: uuid.UUID,
|
||||
body: dict,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
) -> dict:
|
||||
from sqlalchemy import select
|
||||
from app.models.finance import FinOcrTask, FinInvoicePool
|
||||
|
||||
task = (await db.execute(
|
||||
select(FinOcrTask).where(FinOcrTask.id == task_id, FinOcrTask.is_deleted.is_(False))
|
||||
)).scalar_one_or_none()
|
||||
if not task:
|
||||
raise BizException(message="任务不存在")
|
||||
|
||||
merchant = body.get("merchant_name", "手动录入")
|
||||
amount = float(body.get("amount", 0))
|
||||
inv_date_str = body.get("invoice_date")
|
||||
inv_date = None
|
||||
if inv_date_str:
|
||||
try:
|
||||
from datetime import date as d
|
||||
inv_date = d.fromisoformat(inv_date_str)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
inv = FinInvoicePool(
|
||||
uploader_id=task.uploader_id, company_id=task.company_id,
|
||||
file_url=task.file_url, merchant_name=merchant, amount=amount,
|
||||
invoice_date=inv_date, type=task.inv_type, ai_extracted_data=body,
|
||||
)
|
||||
db.add(inv)
|
||||
await db.flush()
|
||||
|
||||
task.status = "manual"
|
||||
task.invoice_pool_id = inv.id
|
||||
task.ocr_result = body
|
||||
task.error_message = None
|
||||
await db.commit()
|
||||
|
||||
return ok(data={"invoice_pool_id": str(inv.id)}, message="手动录入成功,发票已入池")
|
||||
|
||||
|
||||
@router.put("/ocr-tasks/{task_id}/priority", summary="调整 OCR 任务优先级")
|
||||
async def update_ocr_task_priority(
|
||||
task_id: uuid.UUID,
|
||||
body: dict,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
) -> dict:
|
||||
from sqlalchemy import select
|
||||
from app.models.finance import FinOcrTask
|
||||
|
||||
task = (await db.execute(
|
||||
select(FinOcrTask).where(FinOcrTask.id == task_id, FinOcrTask.is_deleted.is_(False))
|
||||
)).scalar_one_or_none()
|
||||
if not task:
|
||||
raise BizException(message="任务不存在")
|
||||
if task.status not in ("pending",):
|
||||
raise BizException(message="仅待处理任务可调整优先级")
|
||||
|
||||
new_priority = body.get("priority", task.priority)
|
||||
task.priority = int(new_priority)
|
||||
await db.commit()
|
||||
return ok(message=f"优先级已调整为 {task.priority}")
|
||||
|
||||
|
||||
@router.delete("/ocr-tasks/{task_id}", summary="取消/删除 OCR 任务")
|
||||
async def delete_ocr_task(
|
||||
task_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
) -> dict:
|
||||
from sqlalchemy import select
|
||||
from app.models.finance import FinOcrTask
|
||||
|
||||
task = (await db.execute(
|
||||
select(FinOcrTask).where(FinOcrTask.id == task_id, FinOcrTask.is_deleted.is_(False))
|
||||
)).scalar_one_or_none()
|
||||
if not task:
|
||||
raise BizException(message="任务不存在")
|
||||
if task.status == "processing":
|
||||
raise BizException(message="正在处理中的任务无法取消")
|
||||
|
||||
task.is_deleted = True
|
||||
await db.commit()
|
||||
return ok(message="任务已取消")
|
||||
|
||||
@@ -56,7 +56,7 @@ async def import_products(
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
):
|
||||
from openpyxl import load_workbook
|
||||
from app.models.erp import ErpProductSku
|
||||
from app.models.erp import ProductSku
|
||||
|
||||
content = await file.read()
|
||||
wb = load_workbook(io.BytesIO(content))
|
||||
@@ -79,7 +79,6 @@ async def import_products(
|
||||
spec = str(row[2] or "").strip() or None
|
||||
standard_price = float(row[3] or 0)
|
||||
unit = str(row[4] or "桶").strip()
|
||||
warning_threshold = float(row[5] or 0)
|
||||
|
||||
if not sku_code or not name:
|
||||
skipped += 1
|
||||
@@ -87,22 +86,21 @@ async def import_products(
|
||||
|
||||
# 检查 sku_code 是否已存在
|
||||
exists = (await db.execute(
|
||||
select(func.count()).select_from(ErpProductSku).where(
|
||||
ErpProductSku.sku_code == sku_code,
|
||||
ErpProductSku.is_deleted.is_(False),
|
||||
select(func.count()).select_from(ProductSku).where(
|
||||
ProductSku.sku_code == sku_code,
|
||||
ProductSku.is_deleted.is_(False),
|
||||
)
|
||||
)).scalar()
|
||||
if exists:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
sku = ErpProductSku(
|
||||
sku = ProductSku(
|
||||
sku_code=sku_code,
|
||||
name=name,
|
||||
spec=spec,
|
||||
standard_price=standard_price,
|
||||
unit=unit,
|
||||
warning_threshold=warning_threshold,
|
||||
)
|
||||
db.add(sku)
|
||||
created += 1
|
||||
|
||||
+338
-4
@@ -6,7 +6,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.deps import get_current_user, get_current_company_id
|
||||
from app.db.database import get_db
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.order import OrderCreate
|
||||
@@ -32,8 +32,9 @@ async def create_order(
|
||||
body: OrderCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.create_order(db, current_user, body)
|
||||
result = await svc.create_order(db, current_user, body, company_id)
|
||||
return ok(data=result.model_dump(mode="json"), message=f"订单 {result.order_no} 创建成功")
|
||||
|
||||
|
||||
@@ -47,16 +48,349 @@ async def list_orders(
|
||||
keyword: str | None = Query(None, description="模糊搜索订单号"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.list_orders(db, current_user, page, size, customer_id, shipping_state, payment_state, keyword)
|
||||
result = await svc.list_orders(db, current_user, page, size, customer_id, shipping_state, payment_state, keyword, company_id)
|
||||
return ok(data=result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@router.get("/unlinked-invoices", summary="查询未关联订单的发票列表")
|
||||
async def list_unlinked_invoices(
|
||||
keyword: str | None = Query(None, description="发票号模糊搜索"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
from sqlalchemy import select
|
||||
from app.models.finance import FinSalesInvoice
|
||||
conditions = [
|
||||
FinSalesInvoice.company_id == company_id,
|
||||
FinSalesInvoice.is_deleted.is_(False),
|
||||
FinSalesInvoice.order_id.is_(None),
|
||||
]
|
||||
if keyword:
|
||||
conditions.append(FinSalesInvoice.invoice_number.ilike(f"%{keyword}%"))
|
||||
|
||||
stmt = (
|
||||
select(FinSalesInvoice)
|
||||
.where(*conditions)
|
||||
.order_by(FinSalesInvoice.created_at.desc())
|
||||
.limit(50)
|
||||
)
|
||||
invoices = (await db.execute(stmt)).scalars().all()
|
||||
return ok(data=[
|
||||
{
|
||||
"id": str(inv.id),
|
||||
"invoice_number": inv.invoice_number,
|
||||
"issuer": inv.issuer,
|
||||
"receiver_name": inv.receiver_customer.name if inv.receiver_customer else None,
|
||||
"amount": float(inv.amount),
|
||||
"billing_date": str(inv.billing_date),
|
||||
}
|
||||
for inv in invoices
|
||||
])
|
||||
|
||||
|
||||
@router.get("/{order_id}", summary="订单全景详情(关系预加载 items + customer)")
|
||||
async def get_order(
|
||||
order_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.get_order(db, current_user, order_id)
|
||||
result = await svc.get_order(db, current_user, order_id, company_id)
|
||||
return ok(data=result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@router.get("/{order_id}/invoices", summary="获取订单关联的销项发票")
|
||||
async def get_order_invoices(
|
||||
order_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
from sqlalchemy import select
|
||||
from app.models.finance import FinSalesInvoice
|
||||
stmt = (
|
||||
select(FinSalesInvoice)
|
||||
.where(
|
||||
FinSalesInvoice.order_id == order_id,
|
||||
FinSalesInvoice.company_id == company_id,
|
||||
FinSalesInvoice.is_deleted.is_(False),
|
||||
)
|
||||
.order_by(FinSalesInvoice.created_at.desc())
|
||||
)
|
||||
invoices = (await db.execute(stmt)).scalars().all()
|
||||
return ok(data=[
|
||||
{
|
||||
"id": str(inv.id),
|
||||
"invoice_number": inv.invoice_number,
|
||||
"issuer": inv.issuer,
|
||||
"receiver_name": inv.receiver_customer.name if inv.receiver_customer else None,
|
||||
"amount": float(inv.amount),
|
||||
"billing_date": str(inv.billing_date),
|
||||
"payment_status": inv.payment_status,
|
||||
"payment_date": str(inv.payment_date) if inv.payment_date else None,
|
||||
"payment_amount": float(inv.payment_amount or 0),
|
||||
"payment_due_date": str(inv.payment_due_date) if inv.payment_due_date else None,
|
||||
}
|
||||
for inv in invoices
|
||||
])
|
||||
|
||||
|
||||
@router.put("/{order_id}/payment", summary="更新订单收款状态")
|
||||
async def update_order_payment(
|
||||
order_id: uuid.UUID,
|
||||
body: dict,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
from sqlalchemy import select, update as sa_update
|
||||
from app.models.order import ErpOrder
|
||||
from datetime import datetime
|
||||
|
||||
order = (await db.execute(
|
||||
select(ErpOrder).where(
|
||||
ErpOrder.id == order_id,
|
||||
ErpOrder.company_id == company_id,
|
||||
ErpOrder.is_deleted.is_(False),
|
||||
)
|
||||
)).scalar_one_or_none()
|
||||
if order is None:
|
||||
from app.core.exceptions import NotFoundException
|
||||
raise NotFoundException("订单不存在")
|
||||
|
||||
values = {}
|
||||
if "paid_amount" in body:
|
||||
paid = float(body["paid_amount"])
|
||||
values["paid_amount"] = paid
|
||||
total = float(order.total_amount)
|
||||
if paid >= total:
|
||||
values["payment_state"] = "cleared"
|
||||
elif paid > 0:
|
||||
values["payment_state"] = "partial"
|
||||
else:
|
||||
values["payment_state"] = "unpaid"
|
||||
if "payment_state" in body:
|
||||
values["payment_state"] = body["payment_state"]
|
||||
if values:
|
||||
values["updated_at"] = datetime.utcnow()
|
||||
await db.execute(
|
||||
sa_update(ErpOrder).where(ErpOrder.id == order_id).values(**values)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return ok(message="收款状态已更新")
|
||||
|
||||
|
||||
@router.get("/{order_id}/invoice-detail-preview", summary="生成开票明细预览")
|
||||
async def invoice_detail_preview(
|
||||
order_id: uuid.UUID,
|
||||
mode: str = Query("full", pattern=r"^(full|batch)$", description="full=整体开票, batch=按发货批次"),
|
||||
shipping_id: uuid.UUID | None = Query(None, description="batch模式下必传发货单ID"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
"""根据模式生成开票明细: 整体=订单全部商品, 批次=指定发货单商品"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from app.models.order import ErpOrder, ErpOrderItem
|
||||
from app.models.shipping import ErpShippingRecord, ErpShippingItem
|
||||
from app.models.crm import CrmCustomer
|
||||
from app.models.sys import SysCompany
|
||||
from app.core.exceptions import NotFoundException, BizException
|
||||
|
||||
# 查订单
|
||||
order = (await db.execute(
|
||||
select(ErpOrder)
|
||||
.where(ErpOrder.id == order_id, ErpOrder.company_id == company_id, ErpOrder.is_deleted.is_(False))
|
||||
.options(
|
||||
selectinload(ErpOrder.items),
|
||||
selectinload(ErpOrder.customer),
|
||||
selectinload(ErpOrder.salesperson),
|
||||
)
|
||||
)).scalar_one_or_none()
|
||||
if not order:
|
||||
raise NotFoundException("订单不存在")
|
||||
|
||||
# 买方名称
|
||||
buyer_name = order.customer.name if order.customer else ""
|
||||
# 卖方名称
|
||||
company = (await db.execute(
|
||||
select(SysCompany).where(SysCompany.id == company_id)
|
||||
)).scalar_one_or_none()
|
||||
seller_name = company.name if company else ""
|
||||
|
||||
items_data = []
|
||||
total_amount = 0.0
|
||||
|
||||
if mode == "full":
|
||||
# 整体开票: 聚合全部订单明细
|
||||
for oi in (order.items or []):
|
||||
sub = float(oi.sub_total or 0)
|
||||
items_data.append({
|
||||
"sku_code": oi.sku.sku_code if oi.sku else "",
|
||||
"sku_name": oi.sku.name if oi.sku else "",
|
||||
"spec": oi.sku.spec if oi.sku else "",
|
||||
"unit": oi.sku.unit if oi.sku else "",
|
||||
"qty": float(oi.qty),
|
||||
"unit_price": float(oi.unit_price),
|
||||
"sub_total": sub,
|
||||
})
|
||||
total_amount += sub
|
||||
else:
|
||||
# 按发货批次
|
||||
if not shipping_id:
|
||||
raise BizException(message="batch模式需指定shipping_id")
|
||||
ship = (await db.execute(
|
||||
select(ErpShippingRecord)
|
||||
.where(
|
||||
ErpShippingRecord.id == shipping_id,
|
||||
ErpShippingRecord.order_id == order_id,
|
||||
ErpShippingRecord.is_deleted.is_(False),
|
||||
)
|
||||
.options(selectinload(ErpShippingRecord.items).selectinload(ErpShippingItem.sku))
|
||||
)).scalar_one_or_none()
|
||||
if not ship:
|
||||
raise NotFoundException("发货单不存在")
|
||||
|
||||
# 查对应的订单明细来获取单价
|
||||
order_item_map = {str(oi.id): oi for oi in (order.items or [])}
|
||||
for si in (ship.items or []):
|
||||
oi = order_item_map.get(str(si.order_item_id))
|
||||
unit_price = float(oi.unit_price) if oi else 0
|
||||
qty = float(si.shipped_qty)
|
||||
sub = round(qty * unit_price, 2)
|
||||
items_data.append({
|
||||
"sku_code": si.sku.sku_code if si.sku else "",
|
||||
"sku_name": si.sku.name if si.sku else "",
|
||||
"spec": si.sku.spec if si.sku else "",
|
||||
"unit": si.sku.unit if si.sku else "",
|
||||
"qty": qty,
|
||||
"unit_price": unit_price,
|
||||
"sub_total": sub,
|
||||
})
|
||||
total_amount += sub
|
||||
|
||||
return ok(data={
|
||||
"order_no": order.order_no,
|
||||
"buyer_name": buyer_name,
|
||||
"seller_name": seller_name,
|
||||
"customer_id": str(order.customer_id),
|
||||
"items": items_data,
|
||||
"total_amount": round(total_amount, 2),
|
||||
"shipping_id": str(shipping_id) if shipping_id else None,
|
||||
})
|
||||
|
||||
|
||||
@router.post("/{order_id}/invoices/link", summary="关联已有发票到订单")
|
||||
async def link_existing_invoice(
|
||||
order_id: uuid.UUID,
|
||||
body: dict,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
"""将已存在的销项发票关联到该订单"""
|
||||
from sqlalchemy import select, update as sa_update
|
||||
from app.models.finance import FinSalesInvoice
|
||||
from app.core.exceptions import NotFoundException, BizException
|
||||
from datetime import datetime
|
||||
|
||||
invoice_id = body.get("invoice_id")
|
||||
shipping_record_id = body.get("shipping_record_id")
|
||||
if not invoice_id:
|
||||
raise BizException(message="请提供 invoice_id")
|
||||
|
||||
inv = (await db.execute(
|
||||
select(FinSalesInvoice).where(
|
||||
FinSalesInvoice.id == uuid.UUID(invoice_id),
|
||||
FinSalesInvoice.company_id == company_id,
|
||||
FinSalesInvoice.is_deleted.is_(False),
|
||||
)
|
||||
)).scalar_one_or_none()
|
||||
if not inv:
|
||||
raise NotFoundException("发票不存在")
|
||||
|
||||
values = {"order_id": order_id, "updated_at": datetime.utcnow()}
|
||||
if shipping_record_id:
|
||||
values["shipping_record_id"] = uuid.UUID(shipping_record_id)
|
||||
|
||||
await db.execute(
|
||||
sa_update(FinSalesInvoice)
|
||||
.where(FinSalesInvoice.id == uuid.UUID(invoice_id))
|
||||
.values(**values)
|
||||
)
|
||||
await db.commit()
|
||||
return ok(message="发票已关联到订单")
|
||||
|
||||
|
||||
@router.post("/{order_id}/invoices/create", summary="直接创建发票并关联到订单")
|
||||
async def create_and_link_invoice(
|
||||
order_id: uuid.UUID,
|
||||
body: dict,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
"""创建新的销项发票,同时关联到当前订单"""
|
||||
from sqlalchemy import select
|
||||
from app.models.finance import FinSalesInvoice
|
||||
from app.models.order import ErpOrder
|
||||
from app.core.exceptions import NotFoundException, BizException
|
||||
from datetime import date as dt_date
|
||||
|
||||
order = (await db.execute(
|
||||
select(ErpOrder).where(
|
||||
ErpOrder.id == order_id,
|
||||
ErpOrder.company_id == company_id,
|
||||
ErpOrder.is_deleted.is_(False),
|
||||
)
|
||||
)).scalar_one_or_none()
|
||||
if not order:
|
||||
raise NotFoundException("订单不存在")
|
||||
|
||||
invoice_number = body.get("invoice_number", "").strip()
|
||||
amount = float(body.get("amount", 0))
|
||||
issuer = body.get("issuer", "").strip()
|
||||
receiver_customer_id = body.get("receiver_customer_id") or str(order.customer_id)
|
||||
billing_date_str = body.get("billing_date")
|
||||
shipping_record_id = body.get("shipping_record_id")
|
||||
remark = body.get("remark")
|
||||
|
||||
if not invoice_number:
|
||||
raise BizException(message="请填写发票号")
|
||||
if amount <= 0:
|
||||
raise BizException(message="开票金额需大于0")
|
||||
if not issuer:
|
||||
raise BizException(message="请填写开票方名称")
|
||||
|
||||
# 检查唯一性
|
||||
from sqlalchemy import func as sa_func
|
||||
existing = (await db.execute(
|
||||
select(sa_func.count()).select_from(FinSalesInvoice).where(
|
||||
FinSalesInvoice.invoice_number == invoice_number,
|
||||
FinSalesInvoice.is_deleted.is_(False),
|
||||
)
|
||||
)).scalar()
|
||||
if existing:
|
||||
raise BizException(message=f"发票号 {invoice_number} 已存在")
|
||||
|
||||
inv = FinSalesInvoice(
|
||||
issuer=issuer,
|
||||
receiver_customer_id=uuid.UUID(receiver_customer_id),
|
||||
invoice_number=invoice_number,
|
||||
amount=amount,
|
||||
billing_date=dt_date.fromisoformat(billing_date_str) if billing_date_str else dt_date.today(),
|
||||
remark=remark,
|
||||
order_id=order_id,
|
||||
shipping_record_id=uuid.UUID(shipping_record_id) if shipping_record_id else None,
|
||||
created_by=current_user.user_id,
|
||||
company_id=company_id,
|
||||
)
|
||||
db.add(inv)
|
||||
await db.commit()
|
||||
return ok(data={"id": str(inv.id), "invoice_number": invoice_number}, message="发票创建并关联成功")
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.deps import get_current_user, get_current_company_id
|
||||
from app.db.database import get_db
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.erp import CategoryCreate, CategoryUpdate, InventoryFlowCreate, SkuCreate, SkuUpdate
|
||||
@@ -64,8 +64,9 @@ async def list_skus(
|
||||
keyword: str | None = Query(None, description="模糊搜索 SKU 编码或名称"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.list_skus(db, page, size, category_id, keyword)
|
||||
result = await svc.list_skus(db, company_id, page, size, category_id, keyword)
|
||||
return ok(data=result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@@ -95,8 +96,9 @@ async def create_inventory_flow(
|
||||
body: InventoryFlowCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.create_inventory_flow(db, current_user, body)
|
||||
result = await svc.create_inventory_flow(db, current_user, body, company_id)
|
||||
return ok(data=result.model_dump(mode="json"), message="库存变更成功")
|
||||
|
||||
|
||||
@@ -107,6 +109,7 @@ async def get_inventory_flows(
|
||||
size: int = Query(50, ge=1, le=200),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.get_inventory_flows(db, sku_id, page, size)
|
||||
result = await svc.get_inventory_flows(db, sku_id, company_id, page, size)
|
||||
return ok(data=result)
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
利润核算路由 —— /api/profit
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.deps import get_current_user, get_current_company_id
|
||||
from app.db.database import get_db
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.response import ok
|
||||
from app.services import profit_service as svc
|
||||
|
||||
router = APIRouter(prefix="/profit", tags=["利润核算"])
|
||||
|
||||
|
||||
@router.get("/report", summary="利润报表(订单维度)")
|
||||
async def profit_report(
|
||||
start_date: str | None = Query(None, description="起始日期 YYYY-MM-DD"),
|
||||
end_date: str | None = Query(None, description="结束日期 YYYY-MM-DD"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.get_profit_report(db, company_id, start_date, end_date)
|
||||
return ok(data=result)
|
||||
|
||||
|
||||
@router.post("/snapshot/{order_id}", summary="为订单锚定成本快照")
|
||||
async def snapshot_costs(
|
||||
order_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.snapshot_order_item_costs(db, order_id, company_id)
|
||||
return ok(data=result, message=f"已为 {len(result)} 项明细锚定成本")
|
||||
+16
-10
@@ -14,7 +14,7 @@ from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.deps import get_current_user, get_current_company_id
|
||||
from app.db.database import get_db
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.response import ok
|
||||
@@ -28,15 +28,16 @@ async def generate_report(
|
||||
end_date: date = Body(..., embed=True),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
authorization: str | None = Header(None),
|
||||
):
|
||||
"""
|
||||
1. 聚合该用户在时间范围内的 sales_logs 内容
|
||||
1. 聚合该用户在时间范围内、涉及当前公司的 sales_logs 内容
|
||||
2. 调用 Dify Workflow (streaming) 生成复盘报告
|
||||
3. SSE 流式返回给前端
|
||||
"""
|
||||
return StreamingResponse(
|
||||
_report_sse_generator(db, current_user, start_date, end_date, authorization or ""),
|
||||
_report_sse_generator(db, current_user, start_date, end_date, authorization or "", company_id),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
@@ -47,20 +48,25 @@ async def _report_sse_generator(
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
authorization: str = "",
|
||||
company_id: uuid.UUID | None = None,
|
||||
):
|
||||
import httpx
|
||||
from app.core.config import settings
|
||||
from app.models.ai import SalesLog
|
||||
|
||||
# 1. 聚合日志
|
||||
# 1. 聚合日志 — 仅提取涉及当前公司的日志
|
||||
conditions = [
|
||||
SalesLog.salesperson_id == user.user_id,
|
||||
SalesLog.log_date >= start_date,
|
||||
SalesLog.log_date <= end_date,
|
||||
SalesLog.is_deleted.is_(False),
|
||||
]
|
||||
if company_id:
|
||||
conditions.append(SalesLog.involved_company_ids.any(company_id))
|
||||
|
||||
stmt = (
|
||||
select(SalesLog)
|
||||
.where(
|
||||
SalesLog.salesperson_id == user.user_id,
|
||||
SalesLog.log_date >= start_date,
|
||||
SalesLog.log_date <= end_date,
|
||||
SalesLog.is_deleted.is_(False),
|
||||
)
|
||||
.where(*conditions)
|
||||
.order_by(SalesLog.log_date)
|
||||
)
|
||||
logs = (await db.execute(stmt)).scalars().all()
|
||||
|
||||
@@ -10,7 +10,7 @@ from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.deps import get_current_user, get_current_company_id
|
||||
from app.db.database import get_db
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.sales_invoice import SalesInvoiceCreate, SalesInvoiceUpdate
|
||||
@@ -26,8 +26,9 @@ async def create_invoice(
|
||||
body: SalesInvoiceCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.create_invoice(db, current_user, body)
|
||||
result = await svc.create_invoice(db, current_user, body, company_id)
|
||||
return ok(data=result.model_dump(mode="json"), message="销项发票创建成功")
|
||||
|
||||
|
||||
@@ -42,10 +43,11 @@ async def list_invoices(
|
||||
end_date: date | None = Query(None, description="开票结束日期"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.list_invoices(
|
||||
db, page, size, customer_name, invoice_number,
|
||||
payment_status, start_date, end_date,
|
||||
payment_status, start_date, end_date, company_id,
|
||||
)
|
||||
return ok(data=result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@@ -3,11 +3,13 @@
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
import asyncio
|
||||
from fastapi import APIRouter, Depends, Body
|
||||
from fastapi import APIRouter, Depends, Body, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.deps import get_current_user, get_current_company_id
|
||||
from app.db.database import get_db
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.response import ok
|
||||
@@ -26,6 +28,7 @@ async def list_logs(
|
||||
end_date: str | None = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
):
|
||||
result = await sales_log_service.list_logs(
|
||||
db, current_user,
|
||||
@@ -34,18 +37,56 @@ async def list_logs(
|
||||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
company_id=company_id,
|
||||
)
|
||||
return ok(data=result)
|
||||
|
||||
|
||||
async def _resolve_company_ids(
|
||||
db: AsyncSession,
|
||||
company_id: uuid.UUID,
|
||||
customer_id: str | None,
|
||||
company_ids: list[str] | None,
|
||||
) -> list[uuid.UUID]:
|
||||
"""
|
||||
智能解析 involved_company_ids:
|
||||
1. 如果前端显式传了 company_ids,使用它
|
||||
2. 否则以当前视角公司为基础
|
||||
3. 如果选了客户,自动查客户 owner 所属的公司,合并进来
|
||||
"""
|
||||
if company_ids:
|
||||
resolved = set(uuid.UUID(cid) for cid in company_ids)
|
||||
else:
|
||||
resolved = {company_id}
|
||||
|
||||
# 自动关联客户 owner 所在公司
|
||||
if customer_id:
|
||||
from app.models.crm import CrmCustomer
|
||||
from app.models.sys import SysUserCompany
|
||||
cust = await db.get(CrmCustomer, uuid.UUID(customer_id))
|
||||
if cust and cust.owner_id:
|
||||
stmt = select(SysUserCompany.company_id).where(
|
||||
SysUserCompany.user_id == cust.owner_id
|
||||
)
|
||||
rows = (await db.execute(stmt)).scalars().all()
|
||||
for cid in rows:
|
||||
resolved.add(cid)
|
||||
|
||||
# 确保当前公司始终在内
|
||||
resolved.add(company_id)
|
||||
return list(resolved)
|
||||
|
||||
|
||||
@router.post("", summary="创建销售日志")
|
||||
async def create_log(
|
||||
content: str = Body(..., embed=True),
|
||||
customer_id: str | None = Body(None, embed=True),
|
||||
contact_ids: list[str] | None = Body(None, embed=True),
|
||||
log_date: str | None = Body(None, embed=True),
|
||||
company_ids: list[str] | None = Body(None, embed=True),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
):
|
||||
from datetime import date as date_type
|
||||
|
||||
@@ -53,17 +94,20 @@ async def create_log(
|
||||
if log_date:
|
||||
parsed_date = date_type.fromisoformat(log_date)
|
||||
|
||||
# 智能解析公司关联
|
||||
resolved_company_ids = await _resolve_company_ids(db, company_id, customer_id, company_ids)
|
||||
|
||||
result = await sales_log_service.create_log(
|
||||
db, current_user,
|
||||
content=content,
|
||||
customer_id=customer_id,
|
||||
contact_ids=contact_ids,
|
||||
log_date=parsed_date,
|
||||
company_ids=resolved_company_ids,
|
||||
)
|
||||
|
||||
# 异步触发 Dify 画像提取工作流(仅当关联了客户时)
|
||||
if customer_id:
|
||||
import uuid
|
||||
asyncio.create_task(
|
||||
sales_log_service.trigger_persona_workflow(
|
||||
log_id=uuid.UUID(result["id"]),
|
||||
@@ -75,3 +119,35 @@ async def create_log(
|
||||
)
|
||||
|
||||
return ok(data=result, message="日志创建成功")
|
||||
|
||||
|
||||
@router.put("/{log_id}", summary="编辑销售日志")
|
||||
async def update_log(
|
||||
log_id: uuid.UUID,
|
||||
content: str | None = Body(None, embed=True),
|
||||
customer_id: str | None = Body(None, embed=True),
|
||||
contact_ids: list[str] | None = Body(None, embed=True),
|
||||
log_date: str | None = Body(None, embed=True),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
):
|
||||
result = await sales_log_service.update_log(
|
||||
db, current_user, log_id,
|
||||
content=content,
|
||||
customer_id=customer_id,
|
||||
contact_ids=contact_ids,
|
||||
log_date=log_date,
|
||||
company_id=company_id,
|
||||
)
|
||||
return ok(data=result, message="日志更新成功")
|
||||
|
||||
|
||||
@router.delete("/{log_id}", summary="删除销售日志(软删除)")
|
||||
async def delete_log(
|
||||
log_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
):
|
||||
await sales_log_service.delete_log(db, current_user, log_id)
|
||||
return ok(message="日志已删除")
|
||||
|
||||
@@ -6,7 +6,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.deps import get_current_user, get_current_company_id
|
||||
from app.db.database import get_db
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.shipping import ShippingCreate
|
||||
@@ -21,8 +21,9 @@ async def create_shipping(
|
||||
body: ShippingCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
resp, new_state = await svc.create_shipping(db, current_user, body)
|
||||
resp, new_state = await svc.create_shipping(db, current_user, body, company_id)
|
||||
return ok(data=resp.model_dump(mode="json"), message=f"发货单 {resp.shipping_no} 创建成功,订单状态已更新为 {new_state}")
|
||||
|
||||
|
||||
@@ -34,8 +35,9 @@ async def list_shipping(
|
||||
tracking_no: str | None = Query(None, description="按物流单号搜索"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.list_shipping(db, current_user, page, size, order_no, tracking_no)
|
||||
result = await svc.list_shipping(db, current_user, page, size, order_no, tracking_no, company_id)
|
||||
return ok(data=result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@@ -44,6 +46,7 @@ async def get_shipping_by_order(
|
||||
order_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: CurrentUserPayload = Depends(get_current_user),
|
||||
company_id: uuid.UUID = Depends(get_current_company_id),
|
||||
) -> dict:
|
||||
result = await svc.get_shipping_by_order(db, current_user, order_id)
|
||||
result = await svc.get_shipping_by_order(db, current_user, order_id, company_id)
|
||||
return ok(data=result)
|
||||
|
||||
@@ -26,6 +26,10 @@ from app.api.sales_invoice import router as sales_invoice_router
|
||||
from app.api.reports import router as reports_router
|
||||
from app.api.contacts import router as contacts_router
|
||||
from app.api.dashboard import router as dashboard_router
|
||||
from app.api.companies import router as companies_router
|
||||
from app.api.contracts import router as contracts_router
|
||||
from app.api.profit import router as profit_router
|
||||
from app.api.ai_coaching import router as ai_coaching_router
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -33,8 +37,11 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
|
||||
"""应用生命周期:启动/关闭时的钩子"""
|
||||
# ── startup ──
|
||||
print(f"🚀 {settings.APP_NAME} v{settings.APP_VERSION} 启动中...")
|
||||
from app.services.ocr_worker import ocr_worker
|
||||
ocr_worker.start()
|
||||
yield
|
||||
# ── shutdown ──
|
||||
await ocr_worker.stop()
|
||||
print("👋 服务正在关闭...")
|
||||
|
||||
|
||||
@@ -81,6 +88,10 @@ app.include_router(sales_invoice_router, prefix="/api")
|
||||
app.include_router(reports_router, prefix="/api")
|
||||
app.include_router(contacts_router, prefix="/api")
|
||||
app.include_router(dashboard_router, prefix="/api")
|
||||
app.include_router(companies_router, prefix="/api")
|
||||
app.include_router(contracts_router, prefix="/api")
|
||||
app.include_router(profit_router, prefix="/api")
|
||||
app.include_router(ai_coaching_router, prefix="/api")
|
||||
|
||||
|
||||
# ── 健康检查 ──
|
||||
|
||||
+26
-1
@@ -7,7 +7,7 @@ import uuid
|
||||
from datetime import date, datetime
|
||||
|
||||
from sqlalchemy import Boolean, Date, DateTime, ForeignKey, SmallInteger, String, Text, func
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB, ARRAY
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.models.base import Base
|
||||
@@ -30,11 +30,19 @@ class SalesLog(Base):
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
salesperson_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=False)
|
||||
involved_company_ids: Mapped[list] = mapped_column(
|
||||
ARRAY(UUID(as_uuid=True)), nullable=False, default=list,
|
||||
comment="该篇日志涉及的公司ID列表"
|
||||
)
|
||||
customer_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("crm_customers.id"), nullable=True)
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
log_date: Mapped[date] = mapped_column(Date, default=date.today)
|
||||
contact_ids: Mapped[list | None] = mapped_column(JSONB, default=list, nullable=True)
|
||||
ai_processed: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
ai_coaching_feedback: Mapped[dict | None] = mapped_column(
|
||||
JSONB, default=dict, nullable=True,
|
||||
comment="AI 教练引擎回写的指导反馈"
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now())
|
||||
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
@@ -53,3 +61,20 @@ class AiReportDraft(Base):
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now())
|
||||
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
|
||||
class KbObsidianVector(Base):
|
||||
"""知识库向量表 —— pgvector 存储 Obsidian 文档分块向量"""
|
||||
__tablename__ = "kb_obsidian_vectors"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
company_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False, index=True
|
||||
)
|
||||
source_path: Mapped[str] = mapped_column(String(500), nullable=False, comment="源文件路径")
|
||||
chunk_index: Mapped[int] = mapped_column(SmallInteger, default=0)
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
metadata_: Mapped[dict | None] = mapped_column("metadata", JSONB, default=dict)
|
||||
# 向量字段使用 raw SQL 创建(vector(1536))因 SQLAlchemy 无原生 pgvector 类型
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
@@ -0,0 +1,171 @@
|
||||
"""
|
||||
合同域 ORM 模型
|
||||
映射: erp_contracts / erp_contract_items / erp_contract_attachments
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import date, datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Date,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import Base
|
||||
|
||||
|
||||
# ── 付款条件枚举 ─────────────────────────────────────────
|
||||
PAYMENT_TERMS = [
|
||||
"预付全款订货",
|
||||
"预付30%订货,到货前付清",
|
||||
"预付50%订货,到货前付清",
|
||||
"货到付全款",
|
||||
"开具发票后30天内付款",
|
||||
"开具发票45天付款",
|
||||
"开具发票60天付款",
|
||||
"开具发票90天付款",
|
||||
]
|
||||
|
||||
# ── 运费条款枚举 ─────────────────────────────────────────
|
||||
SHIPPING_TERMS = [
|
||||
"买方自提",
|
||||
"卖方免费送达天津指定地点",
|
||||
"卖方免费送达指定地点",
|
||||
"物流发货,运费买方承担",
|
||||
]
|
||||
|
||||
|
||||
class ErpContract(Base):
|
||||
"""合同主表 —— B2B 交易防线核心"""
|
||||
__tablename__ = "erp_contracts"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
contract_no: Mapped[str] = mapped_column(String(30), unique=True, nullable=False)
|
||||
buyer_customer_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("crm_customers.id"), nullable=False,
|
||||
comment="买方(CRM 客户)"
|
||||
)
|
||||
seller_company_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False,
|
||||
comment="卖方(当前操作公司)"
|
||||
)
|
||||
company_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False, index=True,
|
||||
comment="多租户隔离"
|
||||
)
|
||||
total_amount_excl_tax: Mapped[float] = mapped_column(Numeric(14, 2), default=0)
|
||||
total_amount_incl_tax: Mapped[float] = mapped_column(Numeric(14, 2), default=0)
|
||||
total_amount_cn: Mapped[str | None] = mapped_column(
|
||||
String(100), nullable=True, comment="大写合计金额"
|
||||
)
|
||||
payment_terms: Mapped[str] = mapped_column(
|
||||
String(50), nullable=False, default="货到付全款"
|
||||
)
|
||||
shipping_terms: Mapped[str] = mapped_column(
|
||||
String(50), nullable=False, default="买方自提"
|
||||
)
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(20), nullable=False, default="draft",
|
||||
comment="draft→active→completed→cancelled"
|
||||
)
|
||||
is_signed: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
signed_file_url: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||
linked_order_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("erp_orders.id"), nullable=True,
|
||||
comment="一键推单后回填"
|
||||
)
|
||||
salesperson_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=True
|
||||
)
|
||||
sign_date: Mapped[date | None] = mapped_column(Date, nullable=True)
|
||||
remark: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
delivery_terms: Mapped[str | None] = mapped_column(
|
||||
String(200), nullable=True, comment="货期(手动输入)"
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# 关系
|
||||
buyer_customer: Mapped["CrmCustomer"] = relationship( # noqa: F821
|
||||
"CrmCustomer", lazy="selectin"
|
||||
)
|
||||
seller_company: Mapped["SysCompany"] = relationship( # noqa: F821
|
||||
"SysCompany", foreign_keys=[seller_company_id], lazy="selectin"
|
||||
)
|
||||
salesperson: Mapped["SysUser | None"] = relationship("SysUser", foreign_keys=[salesperson_id], lazy="selectin") # noqa: F821
|
||||
linked_order: Mapped["ErpOrder | None"] = relationship("ErpOrder", foreign_keys=[linked_order_id], lazy="selectin") # noqa: F821
|
||||
items: Mapped[list["ErpContractItem"]] = relationship(
|
||||
"ErpContractItem", back_populates="contract", lazy="selectin"
|
||||
)
|
||||
attachments: Mapped[list["ErpContractAttachment"]] = relationship(
|
||||
"ErpContractAttachment", back_populates="contract", lazy="selectin"
|
||||
)
|
||||
|
||||
|
||||
class ErpContractItem(Base):
|
||||
"""合同明细行"""
|
||||
__tablename__ = "erp_contract_items"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
contract_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("erp_contracts.id"), nullable=False
|
||||
)
|
||||
sku_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("erp_product_skus.id"), nullable=False
|
||||
)
|
||||
qty: Mapped[float] = mapped_column(Numeric(12, 2), nullable=False)
|
||||
unit_price: Mapped[float] = mapped_column(Numeric(12, 2), nullable=False)
|
||||
sub_total: Mapped[float] = mapped_column(Numeric(14, 2), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# 关系
|
||||
contract: Mapped[ErpContract] = relationship("ErpContract", back_populates="items")
|
||||
sku: Mapped["ProductSku"] = relationship("ProductSku", lazy="selectin") # noqa: F821
|
||||
|
||||
|
||||
class ErpContractAttachment(Base):
|
||||
"""合同附件(双签盖章版等)"""
|
||||
__tablename__ = "erp_contract_attachments"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
contract_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("erp_contracts.id"), nullable=False
|
||||
)
|
||||
file_name: Mapped[str] = mapped_column(String(200), nullable=False)
|
||||
file_url: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
file_type: Mapped[str] = mapped_column(
|
||||
String(30), nullable=False, default="signed_copy",
|
||||
comment="signed_copy / supplement / other"
|
||||
)
|
||||
uploader_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=True
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# 关系
|
||||
contract: Mapped[ErpContract] = relationship("ErpContract", back_populates="attachments")
|
||||
uploader: Mapped["SysUser | None"] = relationship("SysUser", lazy="selectin") # noqa: F821
|
||||
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
成本域 ORM 模型
|
||||
映射: erp_order_item_costs
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, Numeric, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import Base
|
||||
|
||||
|
||||
class ErpOrderItemCost(Base):
|
||||
"""订单明细成本快照表 —— 发货/确认瞬间锚定 MWA 成本"""
|
||||
__tablename__ = "erp_order_item_costs"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
order_item_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("erp_order_items.id"), nullable=False, unique=True,
|
||||
comment="关联订单明细"
|
||||
)
|
||||
purchase_unit_price: Mapped[float] = mapped_column(
|
||||
Numeric(12, 4), nullable=False, comment="MWA 成本快照"
|
||||
)
|
||||
profit_amount: Mapped[float] = mapped_column(
|
||||
Numeric(14, 2), default=0, comment="利润额 = (售价-成本)*数量"
|
||||
)
|
||||
profit_rate: Mapped[float] = mapped_column(
|
||||
Numeric(5, 4), default=0, comment="利润率"
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
|
||||
# 关系
|
||||
order_item: Mapped["ErpOrderItem"] = relationship("ErpOrderItem", lazy="selectin") # noqa: F821
|
||||
@@ -29,6 +29,18 @@ class CrmCustomer(Base):
|
||||
address: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
ai_score: Mapped[float] = mapped_column(Numeric(5, 2), default=0)
|
||||
ai_persona: Mapped[dict | None] = mapped_column(JSONB, default=dict, nullable=True)
|
||||
billing_info: Mapped[dict | None] = mapped_column(
|
||||
JSONB, default=dict, nullable=True,
|
||||
comment="客户开票信息: company_name/tax_id/address/phone/bank_name/bank_account"
|
||||
)
|
||||
health_score: Mapped[float] = mapped_column(
|
||||
Numeric(5, 2), default=0,
|
||||
comment="客户健康度评分 (AI 教练引擎计算)"
|
||||
)
|
||||
meddic_status: Mapped[dict | None] = mapped_column(
|
||||
JSONB, default=dict, nullable=True,
|
||||
comment="MEDDIC 六维评估状态"
|
||||
)
|
||||
owner_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=True
|
||||
)
|
||||
|
||||
@@ -12,11 +12,13 @@ from sqlalchemy import (
|
||||
Boolean,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
Numeric,
|
||||
SmallInteger,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
@@ -56,8 +58,6 @@ class ProductSku(Base):
|
||||
name: Mapped[str] = mapped_column(String(200), nullable=False)
|
||||
spec: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
standard_price: Mapped[float] = mapped_column(Numeric(12, 2), default=0)
|
||||
stock_qty: Mapped[float] = mapped_column(Numeric(12, 2), default=0)
|
||||
warning_threshold: Mapped[float] = mapped_column(Numeric(12, 2), default=0)
|
||||
unit: Mapped[str] = mapped_column(String(20), default="桶")
|
||||
status: Mapped[int] = mapped_column(SmallInteger, default=1)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
@@ -80,9 +80,18 @@ class InventoryFlow(Base):
|
||||
sku_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("erp_product_skus.id"), nullable=False
|
||||
)
|
||||
company_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False, index=True
|
||||
)
|
||||
change_qty: Mapped[float] = mapped_column(Numeric(12, 2), nullable=False)
|
||||
reason: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
remark: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
purchase_unit_price: Mapped[float] = mapped_column(
|
||||
Numeric(12, 2), default=0, comment="入库采购单价"
|
||||
)
|
||||
is_special_zero_cost: Mapped[bool] = mapped_column(
|
||||
Boolean, default=False, comment="特殊零元入库标识,不参与 MWA 计算"
|
||||
)
|
||||
operator_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=True
|
||||
)
|
||||
@@ -94,3 +103,34 @@ class InventoryFlow(Base):
|
||||
|
||||
sku: Mapped[ProductSku | None] = relationship("ProductSku", lazy="selectin")
|
||||
operator: Mapped["SysUser | None"] = relationship("SysUser", lazy="selectin") # noqa: F821
|
||||
|
||||
|
||||
class ErpSkuInventory(Base):
|
||||
"""SKU 分公司库存表 —— 同一 SKU 在不同公司有独立库存"""
|
||||
__tablename__ = "erp_sku_inventory"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
sku_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("erp_product_skus.id"), nullable=False
|
||||
)
|
||||
company_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False, index=True
|
||||
)
|
||||
stock_qty: Mapped[float] = mapped_column(Numeric(12, 2), default=0)
|
||||
warning_threshold: Mapped[float] = mapped_column(Numeric(12, 2), default=0)
|
||||
mwa_unit_cost: Mapped[float] = mapped_column(
|
||||
Numeric(12, 4), default=0,
|
||||
comment="移动加权均价 (Moving Weighted Average)"
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("sku_id", "company_id", name="uq_sku_company"),
|
||||
)
|
||||
|
||||
sku: Mapped[ProductSku | None] = relationship("ProductSku", lazy="selectin")
|
||||
|
||||
@@ -33,6 +33,9 @@ class FinInvoicePool(Base):
|
||||
uploader_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=True
|
||||
)
|
||||
company_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False, index=True
|
||||
)
|
||||
file_url: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||
merchant_name: Mapped[str | None] = mapped_column(String(200), nullable=True)
|
||||
amount: Mapped[float] = mapped_column(Numeric(14, 2), default=0)
|
||||
@@ -59,6 +62,9 @@ class FinExpenseRecord(Base):
|
||||
applicant_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=False
|
||||
)
|
||||
company_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False, index=True
|
||||
)
|
||||
total_amount: Mapped[float] = mapped_column(Numeric(14, 2), default=0)
|
||||
status: Mapped[str] = mapped_column(String(20), nullable=False, default="draft")
|
||||
remark: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
@@ -134,9 +140,23 @@ class FinSalesInvoice(Base):
|
||||
payment_date: Mapped[date | None] = mapped_column(Date, nullable=True)
|
||||
payment_amount: Mapped[float] = mapped_column(Numeric(14, 2), default=0)
|
||||
remark: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
order_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("erp_orders.id"), nullable=True,
|
||||
comment="关联订单"
|
||||
)
|
||||
shipping_record_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("erp_shipping_records.id"), nullable=True,
|
||||
comment="关联发货单"
|
||||
)
|
||||
payment_due_date: Mapped[date | None] = mapped_column(
|
||||
Date, nullable=True, comment="回款截止日(根据合同付款条件自动推算)"
|
||||
)
|
||||
created_by: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=True
|
||||
)
|
||||
company_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False, index=True
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.now(), onupdate=func.now()
|
||||
@@ -150,3 +170,43 @@ class FinSalesInvoice(Base):
|
||||
creator: Mapped["SysUser | None"] = relationship( # noqa: F821
|
||||
"SysUser", lazy="selectin"
|
||||
)
|
||||
|
||||
|
||||
class FinOcrTask(Base):
|
||||
"""OCR 处理任务队列 — 持久化排队"""
|
||||
__tablename__ = "fin_ocr_tasks"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
file_url: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
file_ext: Mapped[str] = mapped_column(String(10), nullable=False, comment=".pdf/.png/.jpg")
|
||||
original_name: Mapped[str] = mapped_column(String(200), nullable=False, default="")
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(20), nullable=False, default="pending",
|
||||
comment="pending/processing/success/failed/manual",
|
||||
)
|
||||
priority: Mapped[int] = mapped_column(default=100, comment="值越小越优先")
|
||||
ocr_result: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
retry_count: Mapped[int] = mapped_column(default=0)
|
||||
max_retries: Mapped[int] = mapped_column(default=3)
|
||||
invoice_pool_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("fin_invoice_pool.id"), nullable=True,
|
||||
comment="成功入池后关联的发票 ID",
|
||||
)
|
||||
uploader_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=True,
|
||||
)
|
||||
company_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False, index=True,
|
||||
)
|
||||
inv_type: Mapped[str] = mapped_column(String(30), nullable=False, default="expense")
|
||||
scheduled_after: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
uploader: Mapped["SysUser | None"] = relationship("SysUser", lazy="selectin") # noqa: F821
|
||||
|
||||
@@ -37,6 +37,13 @@ class ErpOrder(Base):
|
||||
salesperson_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=True
|
||||
)
|
||||
company_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False, index=True
|
||||
)
|
||||
contract_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("erp_contracts.id"), nullable=True,
|
||||
comment="来源合同(一键推单后回填)"
|
||||
)
|
||||
total_amount: Mapped[float] = mapped_column(Numeric(14, 2), default=0)
|
||||
shipping_state: Mapped[str] = mapped_column(
|
||||
String(20), nullable=False, default="pending"
|
||||
|
||||
@@ -42,6 +42,9 @@ class ErpShippingRecord(Base):
|
||||
operator_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=True
|
||||
)
|
||||
company_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False, index=True
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.now(), onupdate=func.now()
|
||||
|
||||
@@ -8,7 +8,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, SmallInteger, String, Text, func
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, SmallInteger, String, Text, UniqueConstraint, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
@@ -97,3 +97,44 @@ class SysUser(Base):
|
||||
"SysDepartment", lazy="selectin"
|
||||
)
|
||||
role: Mapped[SysRole | None] = relationship("SysRole", lazy="selectin")
|
||||
|
||||
|
||||
class SysCompany(Base):
|
||||
"""公司主体表 —— 多租户逻辑隔离核心"""
|
||||
__tablename__ = "sys_companies"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(200), nullable=False)
|
||||
code: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
|
||||
full_info: Mapped[dict | None] = mapped_column(
|
||||
JSONB, default=dict, nullable=True,
|
||||
comment="公司完整信息: full_name/address/phone/bank_name/bank_account/tax_id"
|
||||
)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class SysUserCompany(Base):
|
||||
"""用户-公司多对多关联 —— IDOR 防护核心"""
|
||||
__tablename__ = "sys_user_companies"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=False
|
||||
)
|
||||
company_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False
|
||||
)
|
||||
is_default: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("user_id", "company_id", name="uq_user_company"),
|
||||
)
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
合同域 Pydantic V2 Schemas
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import date, datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ── 合同明细行 ────────────────────────────────────────────
|
||||
class ContractItemCreate(BaseModel):
|
||||
sku_id: uuid.UUID
|
||||
qty: float = Field(gt=0)
|
||||
unit_price: float = Field(ge=0)
|
||||
sub_total: float = Field(ge=0)
|
||||
|
||||
|
||||
class ContractItemResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
sku_id: uuid.UUID
|
||||
sku_code: str | None = None
|
||||
sku_name: str | None = None
|
||||
spec: str | None = None
|
||||
unit: str | None = None
|
||||
qty: float
|
||||
unit_price: float
|
||||
sub_total: float
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ── 合同创建 ──────────────────────────────────────────────
|
||||
class ContractCreate(BaseModel):
|
||||
buyer_customer_id: uuid.UUID
|
||||
items: list[ContractItemCreate] = Field(min_length=1)
|
||||
payment_terms: str = "货到付全款"
|
||||
shipping_terms: str = "买方自提"
|
||||
remark: str | None = None
|
||||
delivery_terms: str | None = None
|
||||
sign_date: date | None = None
|
||||
|
||||
|
||||
# ── 合同更新 ──────────────────────────────────────────────
|
||||
class ContractUpdate(BaseModel):
|
||||
buyer_customer_id: uuid.UUID | None = None
|
||||
items: list[ContractItemCreate] | None = None
|
||||
payment_terms: str | None = None
|
||||
shipping_terms: str | None = None
|
||||
status: str | None = None
|
||||
is_signed: bool | None = None
|
||||
remark: str | None = None
|
||||
delivery_terms: str | None = None
|
||||
sign_date: date | None = None
|
||||
|
||||
|
||||
# ── 执行进度 ──────────────────────────────────────────────
|
||||
class ContractProgressResponse(BaseModel):
|
||||
is_signed: bool = False
|
||||
has_order: bool = False
|
||||
order_id: uuid.UUID | None = None
|
||||
has_shipped: bool = False
|
||||
has_invoice: bool = False
|
||||
is_paid: bool = False
|
||||
|
||||
|
||||
# ── 合同响应 ──────────────────────────────────────────────
|
||||
class ContractResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
contract_no: str
|
||||
buyer_customer_id: uuid.UUID
|
||||
buyer_customer_name: str | None = None
|
||||
seller_company_id: uuid.UUID
|
||||
seller_company_name: str | None = None
|
||||
company_id: uuid.UUID
|
||||
total_amount_excl_tax: float = 0
|
||||
total_amount_incl_tax: float = 0
|
||||
total_amount_cn: str | None = None
|
||||
payment_terms: str
|
||||
shipping_terms: str
|
||||
status: str
|
||||
is_signed: bool = False
|
||||
signed_file_url: str | None = None
|
||||
linked_order_id: uuid.UUID | None = None
|
||||
salesperson_id: uuid.UUID | None = None
|
||||
salesperson_name: str | None = None
|
||||
sign_date: date | None = None
|
||||
remark: str | None = None
|
||||
delivery_terms: str | None = None
|
||||
items: list[ContractItemResponse] = []
|
||||
progress: ContractProgressResponse | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ── 分页列表 ──────────────────────────────────────────────
|
||||
class ContractListResponse(BaseModel):
|
||||
total: int
|
||||
items: list[ContractResponse]
|
||||
page: int
|
||||
size: int
|
||||
@@ -12,6 +12,16 @@ from typing import Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ── 开票信息子结构 ─────────────────────────────────────────
|
||||
class BillingInfoSchema(BaseModel):
|
||||
company_name: str | None = Field(default=None, max_length=200, description="开票公司全称")
|
||||
tax_id: str | None = Field(default=None, max_length=50, description="纳税人识别号")
|
||||
address: str | None = Field(default=None, max_length=300, description="地址")
|
||||
phone: str | None = Field(default=None, max_length=30, description="电话")
|
||||
bank_name: str | None = Field(default=None, max_length=200, description="开户行")
|
||||
bank_account: str | None = Field(default=None, max_length=50, description="银行账号")
|
||||
|
||||
|
||||
# ── 创建 ──────────────────────────────────────────────────
|
||||
class CustomerCreate(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=200, examples=["中石化润滑油公司"])
|
||||
@@ -21,6 +31,7 @@ class CustomerCreate(BaseModel):
|
||||
phone: str | None = Field(default=None, max_length=30)
|
||||
email: str | None = Field(default=None, max_length=100)
|
||||
address: str | None = None
|
||||
billing_info: BillingInfoSchema | None = None
|
||||
status: int = Field(default=1, ge=0, le=1)
|
||||
|
||||
|
||||
@@ -33,6 +44,7 @@ class CustomerUpdate(BaseModel):
|
||||
phone: str | None = Field(default=None, max_length=30)
|
||||
email: str | None = Field(default=None, max_length=100)
|
||||
address: str | None = None
|
||||
billing_info: BillingInfoSchema | None = None
|
||||
status: int | None = Field(default=None, ge=0, le=1)
|
||||
|
||||
|
||||
@@ -48,6 +60,7 @@ class CustomerResponse(BaseModel):
|
||||
address: str | None = None
|
||||
ai_score: float = 0
|
||||
ai_persona: dict[str, Any] | None = None
|
||||
billing_info: dict[str, Any] | None = None
|
||||
owner_id: uuid.UUID | None = None
|
||||
owner_name: str | None = None
|
||||
status: int = 1
|
||||
|
||||
@@ -104,6 +104,8 @@ class InventoryFlowCreate(BaseModel):
|
||||
examples=["purchase"],
|
||||
)
|
||||
remark: str | None = Field(default=None, description="备注")
|
||||
purchase_unit_price: float = Field(default=0, ge=0, description="采购单价(仅入库时有意义)")
|
||||
is_special_zero_cost: bool = Field(default=False, description="特殊零元入库标识,不参与 MWA 计算")
|
||||
|
||||
|
||||
class InventoryFlowResponse(BaseModel):
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
"""
|
||||
AI 教练引擎 — 事件总线 + Dify 回调
|
||||
CQRS 解耦模式:
|
||||
1. 业务端 POST /api/sales-logs → 立即 200 OK → 发消息到 Redis Streams
|
||||
2. Worker 消费消息 → 调用 Dify Workflow → 写回 ai_coaching_feedback
|
||||
3. 前端通过 SSE /api/notifications/stream 接收推送
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.ai import SalesLog
|
||||
from app.models.crm import CrmCustomer
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
|
||||
|
||||
# ── Redis 事件发布 ───────────────────────────────────────
|
||||
async def publish_coaching_event(
|
||||
sales_log_id: uuid.UUID,
|
||||
content: str,
|
||||
customer_id: uuid.UUID | None = None,
|
||||
salesperson_id: uuid.UUID | None = None,
|
||||
) -> None:
|
||||
"""将销售日志推送到 Redis Streams,供 Worker 异步消费"""
|
||||
try:
|
||||
import redis.asyncio as aioredis
|
||||
import os
|
||||
|
||||
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
|
||||
r = aioredis.from_url(redis_url, decode_responses=True)
|
||||
await r.xadd(
|
||||
"coaching:sales_logs",
|
||||
{
|
||||
"sales_log_id": str(sales_log_id),
|
||||
"content": content[:2000], # 限长
|
||||
"customer_id": str(customer_id) if customer_id else "",
|
||||
"salesperson_id": str(salesperson_id) if salesperson_id else "",
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
},
|
||||
)
|
||||
await r.aclose()
|
||||
except Exception as e:
|
||||
# Redis 不可用时降级——不阻塞主流程
|
||||
print(f"[AI EventBus] Redis 推送失败(降级): {e}")
|
||||
|
||||
|
||||
# ── Dify 回调处理 ───────────────────────────────────────
|
||||
async def handle_dify_coaching_callback(
|
||||
db: AsyncSession,
|
||||
sales_log_id: uuid.UUID,
|
||||
feedback: dict,
|
||||
) -> None:
|
||||
"""Dify Workflow 回调 → 写回 SalesLog.ai_coaching_feedback"""
|
||||
await db.execute(
|
||||
update(SalesLog)
|
||||
.where(SalesLog.id == sales_log_id)
|
||||
.values(
|
||||
ai_coaching_feedback=feedback,
|
||||
ai_processed=True,
|
||||
updated_at=datetime.utcnow(),
|
||||
)
|
||||
)
|
||||
|
||||
# 如果反馈中包含客户健康评分,同步更新 CrmCustomer
|
||||
health_score = feedback.get("health_score")
|
||||
meddic_status = feedback.get("meddic_status")
|
||||
if health_score is not None or meddic_status is not None:
|
||||
log = (await db.execute(
|
||||
select(SalesLog).where(SalesLog.id == sales_log_id)
|
||||
)).scalar_one_or_none()
|
||||
if log and log.customer_id:
|
||||
update_vals: dict = {}
|
||||
if health_score is not None:
|
||||
update_vals["health_score"] = float(health_score)
|
||||
if meddic_status is not None:
|
||||
update_vals["meddic_status"] = meddic_status
|
||||
if update_vals:
|
||||
await db.execute(
|
||||
update(CrmCustomer)
|
||||
.where(CrmCustomer.id == log.customer_id)
|
||||
.values(**update_vals)
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
|
||||
|
||||
# ── SSE 通知流 ──────────────────────────────────────────
|
||||
async def sse_notification_generator(user_id: uuid.UUID):
|
||||
"""服务端推送事件流(SSE)—— 监听 Redis PubSub 频道"""
|
||||
import asyncio
|
||||
try:
|
||||
import redis.asyncio as aioredis
|
||||
import os
|
||||
|
||||
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
|
||||
r = aioredis.from_url(redis_url, decode_responses=True)
|
||||
pubsub = r.pubsub()
|
||||
channel = f"notifications:{user_id}"
|
||||
await pubsub.subscribe(channel)
|
||||
|
||||
async for message in pubsub.listen():
|
||||
if message["type"] == "message":
|
||||
yield f"data: {message['data']}\n\n"
|
||||
except Exception as e:
|
||||
yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
||||
@@ -0,0 +1,762 @@
|
||||
"""
|
||||
合同管理 Service 层
|
||||
核心逻辑:CRUD + 一键推单 + 账期引擎 + 执行进度聚合
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import date, datetime, timedelta
|
||||
import re
|
||||
|
||||
from sqlalchemy import func, select, update, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.core.exceptions import BizException, ForbiddenException, NotFoundException
|
||||
from app.models.contract import ErpContract, ErpContractItem, ErpContractAttachment
|
||||
from app.models.order import ErpOrder, ErpOrderItem
|
||||
from app.models.shipping import ErpShippingRecord
|
||||
from app.models.finance import FinSalesInvoice
|
||||
from app.models.erp import ProductSku
|
||||
from app.models.crm import CrmCustomer
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.contract import (
|
||||
ContractCreate,
|
||||
ContractUpdate,
|
||||
ContractItemResponse,
|
||||
ContractListResponse,
|
||||
ContractProgressResponse,
|
||||
ContractResponse,
|
||||
)
|
||||
|
||||
|
||||
# ── 金额大写转换 ─────────────────────────────────────────
|
||||
_CN_DIGITS = "零壹贰叁肆伍陆柒捌玖"
|
||||
_CN_UNITS = ["", "拾", "佰", "仟"]
|
||||
_CN_BIG_UNITS = ["", "万", "亿", "兆"]
|
||||
|
||||
|
||||
def amount_to_cn(amount: float) -> str:
|
||||
"""将金额转为中文大写"""
|
||||
if amount == 0:
|
||||
return "零元整"
|
||||
neg = ""
|
||||
if amount < 0:
|
||||
neg = "负"
|
||||
amount = -amount
|
||||
|
||||
yuan = int(amount)
|
||||
jiao = int(amount * 10) % 10
|
||||
fen = int(amount * 100) % 10
|
||||
|
||||
parts = []
|
||||
if yuan > 0:
|
||||
yuan_str = str(yuan)
|
||||
n = len(yuan_str)
|
||||
zero_flag = False
|
||||
for i, ch in enumerate(yuan_str):
|
||||
d = int(ch)
|
||||
pos = n - 1 - i
|
||||
big_idx = pos // 4
|
||||
unit_idx = pos % 4
|
||||
if d == 0:
|
||||
zero_flag = True
|
||||
if unit_idx == 0 and big_idx > 0:
|
||||
parts.append(_CN_BIG_UNITS[big_idx])
|
||||
else:
|
||||
if zero_flag:
|
||||
parts.append("零")
|
||||
zero_flag = False
|
||||
parts.append(_CN_DIGITS[d] + _CN_UNITS[unit_idx])
|
||||
if unit_idx == 0 and big_idx > 0:
|
||||
parts.append(_CN_BIG_UNITS[big_idx])
|
||||
parts.append("元")
|
||||
else:
|
||||
parts.append("零元")
|
||||
|
||||
if jiao > 0:
|
||||
parts.append(_CN_DIGITS[jiao] + "角")
|
||||
if fen > 0:
|
||||
parts.append(_CN_DIGITS[fen] + "分")
|
||||
else:
|
||||
if jiao == 0:
|
||||
parts.append("整")
|
||||
|
||||
return neg + "".join(parts)
|
||||
|
||||
|
||||
# ── 生成合同编号 ─────────────────────────────────────────
|
||||
async def _gen_contract_no(db: AsyncSession) -> str:
|
||||
today_str = date.today().strftime("%Y%m%d")
|
||||
prefix = f"HT-{today_str}-"
|
||||
count_stmt = select(func.count()).select_from(ErpContract).where(
|
||||
ErpContract.contract_no.like(f"{prefix}%")
|
||||
)
|
||||
count = (await db.execute(count_stmt)).scalar() or 0
|
||||
return f"{prefix}{count + 1:03d}"
|
||||
|
||||
|
||||
# ── 账期引擎 ────────────────────────────────────────────
|
||||
def calc_payment_due_date(payment_terms: str, base_date: date) -> date | None:
|
||||
"""根据付款条件枚举和基准日期(开票/发货)推算回款截止日"""
|
||||
m = re.search(r"(\d+)天", payment_terms)
|
||||
if m:
|
||||
days = int(m.group(1))
|
||||
return base_date + timedelta(days=days)
|
||||
if "货到" in payment_terms or "全款" in payment_terms:
|
||||
return base_date # 当天
|
||||
return None
|
||||
|
||||
|
||||
# ── ORM → Response ──────────────────────────────────────
|
||||
def _item_to_response(item: ErpContractItem) -> ContractItemResponse:
|
||||
sku = item.sku
|
||||
return ContractItemResponse(
|
||||
id=item.id,
|
||||
sku_id=item.sku_id,
|
||||
sku_code=sku.sku_code if sku else None,
|
||||
sku_name=sku.name if sku else None,
|
||||
spec=sku.spec if sku else None,
|
||||
unit=sku.unit if sku else None,
|
||||
qty=float(item.qty),
|
||||
unit_price=float(item.unit_price),
|
||||
sub_total=float(item.sub_total),
|
||||
)
|
||||
|
||||
|
||||
def _to_response(c: ErpContract, progress: ContractProgressResponse | None = None) -> ContractResponse:
|
||||
return ContractResponse(
|
||||
id=c.id,
|
||||
contract_no=c.contract_no,
|
||||
buyer_customer_id=c.buyer_customer_id,
|
||||
buyer_customer_name=c.buyer_customer.name if c.buyer_customer else None,
|
||||
seller_company_id=c.seller_company_id,
|
||||
seller_company_name=c.seller_company.name if c.seller_company else None,
|
||||
company_id=c.company_id,
|
||||
total_amount_excl_tax=float(c.total_amount_excl_tax or 0),
|
||||
total_amount_incl_tax=float(c.total_amount_incl_tax or 0),
|
||||
total_amount_cn=c.total_amount_cn,
|
||||
payment_terms=c.payment_terms,
|
||||
shipping_terms=c.shipping_terms,
|
||||
status=c.status,
|
||||
is_signed=c.is_signed,
|
||||
signed_file_url=c.signed_file_url,
|
||||
linked_order_id=c.linked_order_id,
|
||||
salesperson_id=c.salesperson_id,
|
||||
salesperson_name=c.salesperson.real_name if c.salesperson else None,
|
||||
sign_date=c.sign_date,
|
||||
remark=c.remark,
|
||||
delivery_terms=c.delivery_terms,
|
||||
items=[_item_to_response(i) for i in (c.items or []) if not i.is_deleted],
|
||||
progress=progress,
|
||||
created_at=c.created_at,
|
||||
updated_at=c.updated_at,
|
||||
)
|
||||
|
||||
|
||||
# ── 执行进度聚合 ────────────────────────────────────────
|
||||
async def _get_progress(db: AsyncSession, contract: ErpContract) -> ContractProgressResponse:
|
||||
progress = ContractProgressResponse(is_signed=contract.is_signed)
|
||||
|
||||
if contract.linked_order_id:
|
||||
progress.has_order = True
|
||||
progress.order_id = contract.linked_order_id
|
||||
|
||||
# 是否有发货
|
||||
ship_count = (await db.execute(
|
||||
select(func.count()).select_from(ErpShippingRecord).where(
|
||||
ErpShippingRecord.order_id == contract.linked_order_id,
|
||||
ErpShippingRecord.is_deleted.is_(False),
|
||||
)
|
||||
)).scalar() or 0
|
||||
progress.has_shipped = ship_count > 0
|
||||
|
||||
# 是否有销项发票
|
||||
inv_count = (await db.execute(
|
||||
select(func.count()).select_from(FinSalesInvoice).where(
|
||||
FinSalesInvoice.order_id == contract.linked_order_id,
|
||||
FinSalesInvoice.is_deleted.is_(False),
|
||||
)
|
||||
)).scalar() or 0
|
||||
progress.has_invoice = inv_count > 0
|
||||
|
||||
# 是否回款(检查订单回款状态)
|
||||
order = (await db.execute(
|
||||
select(ErpOrder).where(ErpOrder.id == contract.linked_order_id)
|
||||
)).scalar_one_or_none()
|
||||
if order and order.payment_state == "paid":
|
||||
progress.is_paid = True
|
||||
|
||||
return progress
|
||||
|
||||
|
||||
# ── 公共 eager-load 选项 ────────────────────────────────────
|
||||
def _contract_load_options():
|
||||
"""返回 selectinload 链,保证 commit 后仍可安全访问关系属性"""
|
||||
return [
|
||||
selectinload(ErpContract.buyer_customer),
|
||||
selectinload(ErpContract.seller_company),
|
||||
selectinload(ErpContract.salesperson),
|
||||
selectinload(ErpContract.items).selectinload(ErpContractItem.sku),
|
||||
]
|
||||
|
||||
|
||||
# ── Service Functions ────────────────────────────────────
|
||||
|
||||
async def create_contract(
|
||||
db: AsyncSession,
|
||||
user: CurrentUserPayload,
|
||||
company_id: uuid.UUID,
|
||||
body: ContractCreate,
|
||||
) -> ContractResponse:
|
||||
contract_no = await _gen_contract_no(db)
|
||||
|
||||
# 计算合计
|
||||
total = sum(item.sub_total for item in body.items)
|
||||
|
||||
contract = ErpContract(
|
||||
contract_no=contract_no,
|
||||
buyer_customer_id=body.buyer_customer_id,
|
||||
seller_company_id=company_id,
|
||||
company_id=company_id,
|
||||
total_amount_excl_tax=total,
|
||||
total_amount_incl_tax=total, # 含税金额默认同不含税,可后续区分
|
||||
total_amount_cn=amount_to_cn(total),
|
||||
payment_terms=body.payment_terms,
|
||||
shipping_terms=body.shipping_terms,
|
||||
sign_date=body.sign_date,
|
||||
remark=body.remark,
|
||||
delivery_terms=body.delivery_terms,
|
||||
salesperson_id=user.user_id,
|
||||
status="draft",
|
||||
)
|
||||
db.add(contract)
|
||||
await db.flush()
|
||||
|
||||
# 添加明细行
|
||||
for item_data in body.items:
|
||||
item = ErpContractItem(
|
||||
contract_id=contract.id,
|
||||
sku_id=item_data.sku_id,
|
||||
qty=item_data.qty,
|
||||
unit_price=item_data.unit_price,
|
||||
sub_total=item_data.sub_total,
|
||||
)
|
||||
db.add(item)
|
||||
|
||||
await db.commit()
|
||||
|
||||
# 重新查询并 eager-load 所有关系,避免 commit 后隐式 lazy load
|
||||
fresh = (await db.execute(
|
||||
select(ErpContract)
|
||||
.where(ErpContract.id == contract.id)
|
||||
.options(*_contract_load_options())
|
||||
)).scalar_one()
|
||||
return _to_response(fresh)
|
||||
|
||||
|
||||
async def list_contracts(
|
||||
db: AsyncSession,
|
||||
company_id: uuid.UUID,
|
||||
page: int = 1,
|
||||
size: int = 20,
|
||||
keyword: str | None = None,
|
||||
status: str | None = None,
|
||||
) -> ContractListResponse:
|
||||
base_where = [
|
||||
ErpContract.company_id == company_id,
|
||||
ErpContract.is_deleted.is_(False),
|
||||
]
|
||||
if keyword:
|
||||
base_where.append(ErpContract.contract_no.ilike(f"%{keyword}%"))
|
||||
if status:
|
||||
base_where.append(ErpContract.status == status)
|
||||
|
||||
total = (await db.execute(
|
||||
select(func.count()).select_from(ErpContract).where(*base_where)
|
||||
)).scalar() or 0
|
||||
|
||||
stmt = (
|
||||
select(ErpContract)
|
||||
.where(*base_where)
|
||||
.options(*_contract_load_options())
|
||||
.order_by(ErpContract.created_at.desc())
|
||||
.offset((page - 1) * size)
|
||||
.limit(size)
|
||||
)
|
||||
contracts = (await db.execute(stmt)).scalars().all()
|
||||
|
||||
return ContractListResponse(
|
||||
total=total,
|
||||
items=[_to_response(c) for c in contracts],
|
||||
page=page,
|
||||
size=size,
|
||||
)
|
||||
|
||||
|
||||
async def get_contract(
|
||||
db: AsyncSession,
|
||||
contract_id: uuid.UUID,
|
||||
company_id: uuid.UUID,
|
||||
) -> ContractResponse:
|
||||
stmt = (
|
||||
select(ErpContract)
|
||||
.where(
|
||||
ErpContract.id == contract_id,
|
||||
ErpContract.company_id == company_id,
|
||||
ErpContract.is_deleted.is_(False),
|
||||
)
|
||||
.options(*_contract_load_options())
|
||||
)
|
||||
contract = (await db.execute(stmt)).scalar_one_or_none()
|
||||
if contract is None:
|
||||
raise NotFoundException("合同不存在")
|
||||
|
||||
progress = await _get_progress(db, contract)
|
||||
return _to_response(contract, progress)
|
||||
|
||||
|
||||
async def update_contract(
|
||||
db: AsyncSession,
|
||||
contract_id: uuid.UUID,
|
||||
company_id: uuid.UUID,
|
||||
body: ContractUpdate,
|
||||
) -> ContractResponse:
|
||||
stmt = select(ErpContract).where(
|
||||
ErpContract.id == contract_id,
|
||||
ErpContract.company_id == company_id,
|
||||
ErpContract.is_deleted.is_(False),
|
||||
)
|
||||
contract = (await db.execute(stmt)).scalar_one_or_none()
|
||||
if contract is None:
|
||||
raise NotFoundException("合同不存在")
|
||||
|
||||
# 更新主表字段
|
||||
update_data = body.model_dump(exclude_unset=True, exclude={"items"})
|
||||
if update_data:
|
||||
update_data["updated_at"] = datetime.utcnow()
|
||||
await db.execute(
|
||||
update(ErpContract).where(ErpContract.id == contract_id).values(**update_data)
|
||||
)
|
||||
|
||||
# 如果有明细行更新,删旧增新
|
||||
if body.items is not None:
|
||||
await db.execute(
|
||||
update(ErpContractItem)
|
||||
.where(ErpContractItem.contract_id == contract_id)
|
||||
.values(is_deleted=True)
|
||||
)
|
||||
total = 0
|
||||
for item_data in body.items:
|
||||
item = ErpContractItem(
|
||||
contract_id=contract_id,
|
||||
sku_id=item_data.sku_id,
|
||||
qty=item_data.qty,
|
||||
unit_price=item_data.unit_price,
|
||||
sub_total=item_data.sub_total,
|
||||
)
|
||||
total += item_data.sub_total
|
||||
db.add(item)
|
||||
|
||||
await db.execute(
|
||||
update(ErpContract).where(ErpContract.id == contract_id).values(
|
||||
total_amount_excl_tax=total,
|
||||
total_amount_incl_tax=total,
|
||||
total_amount_cn=amount_to_cn(total),
|
||||
)
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
|
||||
updated = (await db.execute(
|
||||
select(ErpContract)
|
||||
.where(ErpContract.id == contract_id)
|
||||
.options(*_contract_load_options())
|
||||
)).scalar_one()
|
||||
return _to_response(updated)
|
||||
|
||||
|
||||
async def delete_contract(
|
||||
db: AsyncSession,
|
||||
contract_id: uuid.UUID,
|
||||
company_id: uuid.UUID,
|
||||
) -> None:
|
||||
stmt = select(ErpContract).where(
|
||||
ErpContract.id == contract_id,
|
||||
ErpContract.company_id == company_id,
|
||||
ErpContract.is_deleted.is_(False),
|
||||
)
|
||||
contract = (await db.execute(stmt)).scalar_one_or_none()
|
||||
if contract is None:
|
||||
raise NotFoundException("合同不存在")
|
||||
|
||||
await db.execute(
|
||||
update(ErpContract)
|
||||
.where(ErpContract.id == contract_id)
|
||||
.values(is_deleted=True, updated_at=datetime.utcnow())
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def generate_order_from_contract(
|
||||
db: AsyncSession,
|
||||
user: CurrentUserPayload,
|
||||
contract_id: uuid.UUID,
|
||||
company_id: uuid.UUID,
|
||||
) -> dict:
|
||||
"""一键从合同生成订单 —— 防篡改推单逻辑"""
|
||||
stmt = (
|
||||
select(ErpContract)
|
||||
.where(
|
||||
ErpContract.id == contract_id,
|
||||
ErpContract.company_id == company_id,
|
||||
ErpContract.is_deleted.is_(False),
|
||||
)
|
||||
.options(*_contract_load_options())
|
||||
)
|
||||
contract = (await db.execute(stmt)).scalar_one_or_none()
|
||||
if contract is None:
|
||||
raise NotFoundException("合同不存在")
|
||||
|
||||
if contract.linked_order_id is not None:
|
||||
raise BizException(message="该合同已关联订单,不可重复生成")
|
||||
|
||||
# 生成订单编号
|
||||
today_str = date.today().strftime("%Y%m%d")
|
||||
prefix = f"SO-{today_str}-"
|
||||
count = (await db.execute(
|
||||
select(func.count()).select_from(ErpOrder).where(
|
||||
ErpOrder.order_no.like(f"{prefix}%")
|
||||
)
|
||||
)).scalar() or 0
|
||||
order_no = f"{prefix}{count + 1:03d}"
|
||||
|
||||
# 创建订单
|
||||
new_order = ErpOrder(
|
||||
order_no=order_no,
|
||||
customer_id=contract.buyer_customer_id,
|
||||
salesperson_id=user.user_id,
|
||||
company_id=company_id,
|
||||
contract_id=contract_id,
|
||||
total_amount=float(contract.total_amount_incl_tax or 0),
|
||||
order_date=date.today(),
|
||||
)
|
||||
db.add(new_order)
|
||||
await db.flush()
|
||||
|
||||
# 复制合同明细到订单明细
|
||||
active_items = [i for i in (contract.items or []) if not i.is_deleted]
|
||||
for ci in active_items:
|
||||
oi = ErpOrderItem(
|
||||
order_id=new_order.id,
|
||||
sku_id=ci.sku_id,
|
||||
qty=float(ci.qty),
|
||||
unit_price=float(ci.unit_price),
|
||||
sub_total=float(ci.sub_total),
|
||||
)
|
||||
db.add(oi)
|
||||
|
||||
# 回填合同 linked_order_id + 激活状态
|
||||
await db.execute(
|
||||
update(ErpContract)
|
||||
.where(ErpContract.id == contract_id)
|
||||
.values(
|
||||
linked_order_id=new_order.id,
|
||||
status="active",
|
||||
updated_at=datetime.utcnow(),
|
||||
)
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
return {"order_id": str(new_order.id), "order_no": order_no}
|
||||
|
||||
|
||||
# ── 数字转中文大写金额 ──────────────────────────────────────
|
||||
def _amount_to_cn(amount: float) -> str:
|
||||
"""将数字金额转换为中文大写"""
|
||||
digits = "零壹贰叁肆伍陆柒捌玖"
|
||||
units = ["", "拾", "佰", "仟"]
|
||||
big_units = ["", "万", "亿"]
|
||||
|
||||
if amount == 0:
|
||||
return "零元整"
|
||||
|
||||
yuan = int(round(amount * 100))
|
||||
jiao = (yuan % 100) // 10
|
||||
fen = yuan % 10
|
||||
yuan_part = yuan // 100
|
||||
|
||||
result = ""
|
||||
if yuan_part > 0:
|
||||
s = str(yuan_part)
|
||||
n = len(s)
|
||||
for i, ch in enumerate(s):
|
||||
d = int(ch)
|
||||
pos = n - i - 1
|
||||
big_pos = pos // 4
|
||||
unit_pos = pos % 4
|
||||
if d != 0:
|
||||
result += digits[d] + units[unit_pos]
|
||||
else:
|
||||
if result and not result.endswith("零"):
|
||||
result += "零"
|
||||
if unit_pos == 0 and big_pos > 0:
|
||||
result = result.rstrip("零") + big_units[big_pos]
|
||||
result = result.rstrip("零") + "元"
|
||||
else:
|
||||
result = ""
|
||||
|
||||
if jiao == 0 and fen == 0:
|
||||
result += "整"
|
||||
else:
|
||||
if jiao > 0:
|
||||
result += digits[jiao] + "角"
|
||||
if fen > 0:
|
||||
result += digits[fen] + "分"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def generate_contract_docx(
|
||||
db: AsyncSession,
|
||||
contract_id: uuid.UUID,
|
||||
company_id: uuid.UUID,
|
||||
) -> bytes:
|
||||
"""纯代码生成合同 Word 文档(紧凑排版,2 页以内)"""
|
||||
import io
|
||||
from docx import Document as DocxDocument
|
||||
from docx.shared import Pt, Cm, Emu, RGBColor
|
||||
from docx.enum.table import WD_TABLE_ALIGNMENT
|
||||
from docx.enum.text import WD_ALIGN_PARAGRAPH
|
||||
from docx.oxml.ns import qn
|
||||
|
||||
from app.models.sys import SysCompany
|
||||
|
||||
# ── 1) 数据准备 ─────────────────────────────────────────
|
||||
contract = (await db.execute(
|
||||
select(ErpContract)
|
||||
.where(
|
||||
ErpContract.id == contract_id,
|
||||
ErpContract.company_id == company_id,
|
||||
ErpContract.is_deleted.is_(False),
|
||||
)
|
||||
.options(*_contract_load_options())
|
||||
)).scalar_one_or_none()
|
||||
if contract is None:
|
||||
raise NotFoundException("合同不存在")
|
||||
|
||||
seller = (await db.execute(
|
||||
select(SysCompany).where(SysCompany.id == contract.seller_company_id)
|
||||
)).scalar_one_or_none()
|
||||
seller_info = (seller.full_info or {}) if seller else {}
|
||||
|
||||
buyer = contract.buyer_customer
|
||||
buyer_billing = {}
|
||||
if buyer and hasattr(buyer, "billing_info") and buyer.billing_info:
|
||||
buyer_billing = buyer.billing_info
|
||||
|
||||
total_incl = float(contract.total_amount_incl_tax or 0)
|
||||
sign_date_str = (contract.sign_date or date.today()).strftime("%Y年%m月%d日")
|
||||
buyer_name = buyer_billing.get("company_name") or (buyer.name if buyer else "")
|
||||
seller_name = seller_info.get("company_name") or (seller.name if seller else "")
|
||||
items = [i for i in (contract.items or []) if not i.is_deleted]
|
||||
|
||||
# ── 2) 创建文档 ─────────────────────────────────────────
|
||||
doc = DocxDocument()
|
||||
|
||||
# 页边距:上下2cm 左右2.5cm(紧凑)
|
||||
for section in doc.sections:
|
||||
section.top_margin = Cm(2)
|
||||
section.bottom_margin = Cm(1.5)
|
||||
section.left_margin = Cm(2.5)
|
||||
section.right_margin = Cm(2.5)
|
||||
|
||||
# ── 辅助函数 ─────────────────────────────────────────────
|
||||
# 小四 = 12pt, 1.5倍行距 = 18pt
|
||||
def add_para(text: str, font_size: int = 12, bold: bool = False,
|
||||
align=WD_ALIGN_PARAGRAPH.LEFT, space_before: int = 0,
|
||||
space_after: int = 0, font_name: str = "宋体"):
|
||||
p = doc.add_paragraph()
|
||||
p.alignment = align
|
||||
p.paragraph_format.space_before = Pt(space_before)
|
||||
p.paragraph_format.space_after = Pt(space_after)
|
||||
p.paragraph_format.line_spacing = Pt(18) # 1.5倍行距(12pt×1.5)
|
||||
run = p.add_run(text)
|
||||
run.font.size = Pt(font_size)
|
||||
run.font.bold = bold
|
||||
run.font.name = font_name
|
||||
run._element.rPr.rFonts.set(qn("w:eastAsia"), font_name)
|
||||
return p
|
||||
|
||||
def set_cell(cell, text: str, font_size: int = 12, bold: bool = False,
|
||||
align=WD_ALIGN_PARAGRAPH.CENTER):
|
||||
cell.text = ""
|
||||
p = cell.paragraphs[0]
|
||||
p.alignment = align
|
||||
p.paragraph_format.space_before = Pt(0)
|
||||
p.paragraph_format.space_after = Pt(0)
|
||||
p.paragraph_format.line_spacing = Pt(18) # 1.5倍行距
|
||||
run = p.add_run(text)
|
||||
run.font.size = Pt(font_size)
|
||||
run.font.bold = bold
|
||||
run.font.name = "宋体"
|
||||
run._element.rPr.rFonts.set(qn("w:eastAsia"), "宋体")
|
||||
|
||||
# ── 3) 标题 ──────────────────────────────────────────────
|
||||
add_para("产 品 购 销 合 同", font_size=18, bold=True,
|
||||
align=WD_ALIGN_PARAGRAPH.CENTER, space_after=4, font_name="黑体")
|
||||
|
||||
add_para(f"合同编号:{contract.contract_no}",
|
||||
align=WD_ALIGN_PARAGRAPH.RIGHT, space_after=4)
|
||||
|
||||
# ── 4) 甲乙方信息(紧凑表格) ────────────────────────────
|
||||
info_tbl = doc.add_table(rows=4, cols=4)
|
||||
info_tbl.alignment = WD_TABLE_ALIGNMENT.CENTER
|
||||
info_tbl.style = "Table Grid"
|
||||
|
||||
info_data = [
|
||||
("买方(甲方)", buyer_name,
|
||||
"卖方(乙方)", seller_name),
|
||||
("税号", buyer_billing.get("tax_id", "") or "",
|
||||
"税号", seller_info.get("tax_id", "") or ""),
|
||||
("地址", buyer_billing.get("address", "") or "",
|
||||
"地址", seller_info.get("address", "") or ""),
|
||||
("开户行 / 账号",
|
||||
f"{buyer_billing.get('bank_name', '') or ''} {buyer_billing.get('bank_account', '') or ''}".strip(),
|
||||
"开户行 / 账号",
|
||||
f"{seller_info.get('bank_name', '') or ''} {seller_info.get('bank_account', '') or ''}".strip()),
|
||||
]
|
||||
for ri, row_data in enumerate(info_data):
|
||||
for ci, val in enumerate(row_data):
|
||||
bold = ri == 0 and ci in (0, 2)
|
||||
set_cell(info_tbl.cell(ri, ci), val, bold=bold,
|
||||
align=WD_ALIGN_PARAGRAPH.LEFT)
|
||||
|
||||
# ── 5) 一、产品明细 ──────────────────────────────────────
|
||||
add_para("一、产品明细", bold=True, space_before=6, space_after=2)
|
||||
|
||||
cols = 6
|
||||
tbl = doc.add_table(rows=1 + len(items) + 1, cols=cols)
|
||||
tbl.alignment = WD_TABLE_ALIGNMENT.CENTER
|
||||
tbl.style = "Table Grid"
|
||||
|
||||
headers = ["序号", "产品名称", "规格", "数量", "单价(元)", "小计(元)"]
|
||||
for ci, h in enumerate(headers):
|
||||
set_cell(tbl.cell(0, ci), h, bold=True)
|
||||
|
||||
for ri, item in enumerate(items):
|
||||
sku_name = item.sku.name if item.sku else ""
|
||||
sku_spec = item.sku.spec if item.sku else ""
|
||||
set_cell(tbl.cell(ri + 1, 0), str(ri + 1))
|
||||
set_cell(tbl.cell(ri + 1, 1), sku_name, align=WD_ALIGN_PARAGRAPH.LEFT)
|
||||
set_cell(tbl.cell(ri + 1, 2), sku_spec or "-")
|
||||
set_cell(tbl.cell(ri + 1, 3), str(float(item.qty)))
|
||||
set_cell(tbl.cell(ri + 1, 4), f"{float(item.unit_price):,.2f}",
|
||||
align=WD_ALIGN_PARAGRAPH.RIGHT)
|
||||
set_cell(tbl.cell(ri + 1, 5), f"{float(item.sub_total):,.2f}",
|
||||
align=WD_ALIGN_PARAGRAPH.RIGHT)
|
||||
|
||||
# 合计行
|
||||
last_row = len(items) + 1
|
||||
set_cell(tbl.cell(last_row, 0), "合计", bold=True)
|
||||
# 合并序号~单价列
|
||||
for ci in range(1, 4):
|
||||
set_cell(tbl.cell(last_row, ci), "")
|
||||
set_cell(tbl.cell(last_row, 4), "", align=WD_ALIGN_PARAGRAPH.RIGHT)
|
||||
set_cell(tbl.cell(last_row, 5), f"{total_incl:,.2f}", bold=True,
|
||||
align=WD_ALIGN_PARAGRAPH.RIGHT)
|
||||
|
||||
# 大写金额
|
||||
add_para(f"合计金额(大写):{_amount_to_cn(total_incl)} (含13%增值税)",
|
||||
bold=True, space_before=2, space_after=2)
|
||||
|
||||
# ── 6) 二、交货及付款条件 ────────────────────────────────
|
||||
add_para("二、交货及付款条件", bold=True, space_before=4, space_after=2)
|
||||
delivery_text = contract.delivery_terms or "按双方约定"
|
||||
add_para(f"1. 货 期:{delivery_text}")
|
||||
add_para(f"2. 交货方式:{contract.shipping_terms or '买方自提'}")
|
||||
add_para(f"3. 付款条件:{contract.payment_terms or '货到付全款'}")
|
||||
|
||||
# ── 7) 三、发票信息 ──────────────────────────────────────
|
||||
add_para("三、发票信息", bold=True, space_before=4, space_after=2)
|
||||
add_para("卖方给买方开具合同金额增值税专用发票(13%增值税)。")
|
||||
|
||||
# ── 8) 四、合同细则 ──────────────────────────────────────
|
||||
add_para("四、合同细则", bold=True, space_before=4, space_after=2)
|
||||
|
||||
# 紧凑输出细则内容
|
||||
terms = [
|
||||
"第一条 质量标准:按照厂家标准执行,由于买方储存不当(如露天暴晒、混入杂质、超过保质期等)或未按产品说明书操作导致的质量问题,卖方不承担责任。",
|
||||
"第二条 卖方对质量负责的条件及期限:自货到12个月。",
|
||||
"第三条 包装标准包装物的供应与回收:产品包装均应采用国家或专业标准保护措施进行包装,以确保产品不受损害为原则,由于包装不善所引起的货物污染、损坏、损失均由卖方负担,采取装箱包装的应在包装箱内附一份详细装箱单和质量合格证,包装物不回收。",
|
||||
"第四条 合理损耗标准及计算方法:标的货物送至买方指定地点前的合理损耗由卖方负责。",
|
||||
"第五条 标的物所有权:在买方付清本合同项下全部货款之前,标的物的所有权仍属于卖方。",
|
||||
"第六条 检验标准、方法、地点及期限:按第二条标准检验。",
|
||||
"第七条 发票信息:卖方给买方开具合同金额增值税专用发票(13%增值税)。",
|
||||
"第八条 本合同解除条件:合同执行完毕。",
|
||||
(
|
||||
"第九条 违约责任:\n"
|
||||
"1、卖方应保证产品质量合格,买方有权在货到后7个工作日内且未开封状态下将卖方产品送质监局或第三方部门检验单位检验,"
|
||||
"送检样品的取样过程必须经卖方现场确认或双方共同封样,否则检验结果无效。检验结果不合格,则所发生的所有检验费用,"
|
||||
"均由卖方承担,买方可根据实际情况选择要求退货或更换。\n"
|
||||
"赔偿限额:卖方对本合同项下违约责任的赔偿总额,以本合同约定的总货款金额为限,"
|
||||
"且不承担任何间接损失(包括但不限于停工损失、利润损失等)。"
|
||||
),
|
||||
(
|
||||
"第十条 合同争议的解决方式:本合同在履行过程中发生的争执,由双方当事人协商解决,"
|
||||
"也可由当地工商行政管理部门调解;协商或调解不成的,按下列第二种方式解决。\n"
|
||||
"(一)提交当地仲裁委员会仲裁;(二)依法向卖方所在地的人民法院起诉。"
|
||||
),
|
||||
"第十一条 本合同一式两份,自双方签字盖章起生效。",
|
||||
(
|
||||
"第十二条 其他约定事项:\n"
|
||||
"1、卖方必须遵守国家有关能源管理的法律、法规;\n"
|
||||
"2、卖方必须执行买方对其提出的对能源控制进行改善的要求;\n"
|
||||
"3、卖方在运输途中和施工作业中的各种行为不应对能源造成浪费或负面影响;\n"
|
||||
"4、如卖方提供货物存在质量问题,买方书面(包括但不限于传真、邮件)通知对方,"
|
||||
"卖方在接到买方书面通知后3个工作日内要给与买方书面回复,否则将视为卖方已经认可买方提出的质量问题;"
|
||||
"如果双方意见产生争议,由卖方负责安排经买方同意的第三方进行检验,否则视为卖方质量问题;\n"
|
||||
"5、未经对方书面同意,不得将合同部分或者全部权利义务转给第三方。\n"
|
||||
"6、如遇战争、原材料短缺、工厂停产、物流管制等不可抗力因素导致货期延长,卖方不承担违约责任。"
|
||||
),
|
||||
]
|
||||
|
||||
for term in terms:
|
||||
add_para(term)
|
||||
|
||||
# ── 9) 签章区 ────────────────────────────────────────────
|
||||
add_para("", space_before=6, space_after=0) # 小间距
|
||||
|
||||
sig_tbl = doc.add_table(rows=4, cols=2)
|
||||
sig_tbl.alignment = WD_TABLE_ALIGNMENT.CENTER
|
||||
# 去边框
|
||||
for row in sig_tbl.rows:
|
||||
for cell in row.cells:
|
||||
for paragraph in cell.paragraphs:
|
||||
paragraph.paragraph_format.space_before = Pt(0)
|
||||
paragraph.paragraph_format.space_after = Pt(0)
|
||||
|
||||
set_cell(sig_tbl.cell(0, 0), "买方(盖章):", bold=True,
|
||||
align=WD_ALIGN_PARAGRAPH.LEFT)
|
||||
set_cell(sig_tbl.cell(0, 1), "卖方(盖章):", bold=True,
|
||||
align=WD_ALIGN_PARAGRAPH.LEFT)
|
||||
set_cell(sig_tbl.cell(1, 0), "授权代表签字:",
|
||||
align=WD_ALIGN_PARAGRAPH.LEFT)
|
||||
set_cell(sig_tbl.cell(1, 1), "授权代表签字:",
|
||||
align=WD_ALIGN_PARAGRAPH.LEFT)
|
||||
set_cell(sig_tbl.cell(2, 0), f"日期:{sign_date_str}",
|
||||
align=WD_ALIGN_PARAGRAPH.LEFT)
|
||||
set_cell(sig_tbl.cell(2, 1), f"日期:{sign_date_str}",
|
||||
align=WD_ALIGN_PARAGRAPH.LEFT)
|
||||
set_cell(sig_tbl.cell(3, 0), f"联系电话:{buyer_billing.get('phone', '') or ''}",
|
||||
align=WD_ALIGN_PARAGRAPH.LEFT)
|
||||
set_cell(sig_tbl.cell(3, 1), f"联系电话:{seller_info.get('phone', '') or ''}",
|
||||
align=WD_ALIGN_PARAGRAPH.LEFT)
|
||||
|
||||
# ── 10) 输出 ─────────────────────────────────────────────
|
||||
buffer = io.BytesIO()
|
||||
doc.save(buffer)
|
||||
buffer.seek(0)
|
||||
return buffer.getvalue()
|
||||
|
||||
@@ -13,6 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import BizException, ForbiddenException, NotFoundException
|
||||
from app.models.crm import CrmCustomer
|
||||
from app.models.sys import SysUser
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.crm import (
|
||||
CustomerCreate,
|
||||
@@ -35,6 +36,7 @@ def _to_response(c: CrmCustomer) -> CustomerResponse:
|
||||
address=c.address,
|
||||
ai_score=float(c.ai_score or 0),
|
||||
ai_persona=c.ai_persona,
|
||||
billing_info=c.billing_info,
|
||||
owner_id=c.owner_id,
|
||||
owner_name=c.owner.real_name if c.owner else None,
|
||||
status=c.status,
|
||||
@@ -44,12 +46,48 @@ def _to_response(c: CrmCustomer) -> CustomerResponse:
|
||||
)
|
||||
|
||||
|
||||
# ── 递归查询本部门 + 子部门所有用户 ID ────────────────────
|
||||
async def _get_dept_and_sub_user_ids(
|
||||
db: AsyncSession, dept_id: uuid.UUID
|
||||
) -> list[uuid.UUID]:
|
||||
"""递归获取指定部门及其所有子部门下的用户 ID 列表"""
|
||||
from app.models.sys import SysDepartment, SysUser
|
||||
|
||||
# 收集所有目标部门 ID(递归子部门)
|
||||
dept_ids: list[uuid.UUID] = [dept_id]
|
||||
queue = [dept_id]
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
children = (await db.execute(
|
||||
select(SysDepartment.id).where(
|
||||
SysDepartment.parent_id == current,
|
||||
SysDepartment.is_deleted.is_(False),
|
||||
)
|
||||
)).scalars().all()
|
||||
for child_id in children:
|
||||
dept_ids.append(child_id)
|
||||
queue.append(child_id)
|
||||
|
||||
# 查询这些部门下的所有用户 ID
|
||||
user_ids = (await db.execute(
|
||||
select(SysUser.id).where(
|
||||
SysUser.dept_id.in_(dept_ids),
|
||||
SysUser.is_deleted.is_(False),
|
||||
)
|
||||
)).scalars().all()
|
||||
return list(user_ids)
|
||||
|
||||
|
||||
# ── 权限校验 ─────────────────────────────────────────────
|
||||
def _check_access(customer: CrmCustomer, user: CurrentUserPayload) -> None:
|
||||
def _check_access(customer: CrmCustomer, user: CurrentUserPayload, *, dept_user_ids: list[uuid.UUID] | None = None) -> None:
|
||||
if user.data_scope == "all":
|
||||
return
|
||||
if user.data_scope == "dept_and_sub":
|
||||
return # 简化版:放通本部门
|
||||
# 如果有预查询的部门用户列表,校验 owner 是否在列表内
|
||||
if dept_user_ids is not None:
|
||||
if customer.owner_id not in dept_user_ids:
|
||||
raise ForbiddenException("无权访问该客户(数据权限:本部门及子部门)")
|
||||
return
|
||||
# data_scope == 'self'
|
||||
if customer.owner_id != user.user_id:
|
||||
raise ForbiddenException("无权访问该客户(数据权限:仅本人)")
|
||||
@@ -70,6 +108,7 @@ async def create_customer(
|
||||
phone=body.phone,
|
||||
email=body.email,
|
||||
address=body.address,
|
||||
billing_info=body.billing_info.model_dump() if body.billing_info else None,
|
||||
status=body.status,
|
||||
owner_id=user.user_id,
|
||||
)
|
||||
@@ -98,12 +137,12 @@ async def list_customers(
|
||||
base_where.append(CrmCustomer.owner_id == user.user_id)
|
||||
elif user.data_scope == "dept_and_sub":
|
||||
if user.dept_id is not None:
|
||||
from app.models.sys import SysUser
|
||||
sub = select(SysUser.id).where(
|
||||
SysUser.dept_id == user.dept_id,
|
||||
SysUser.is_deleted.is_(False),
|
||||
)
|
||||
base_where.append(CrmCustomer.owner_id.in_(sub))
|
||||
dept_user_ids = await _get_dept_and_sub_user_ids(db, user.dept_id)
|
||||
if dept_user_ids:
|
||||
base_where.append(CrmCustomer.owner_id.in_(dept_user_ids))
|
||||
else:
|
||||
# 部门无用户 → 仅显示自己的
|
||||
base_where.append(CrmCustomer.owner_id == user.user_id)
|
||||
|
||||
if keyword:
|
||||
base_where.append(CrmCustomer.name.ilike(f"%{keyword}%"))
|
||||
@@ -144,7 +183,11 @@ async def get_customer(
|
||||
if customer is None:
|
||||
raise NotFoundException("客户不存在或已被删除")
|
||||
|
||||
_check_access(customer, user)
|
||||
# dept_and_sub 需要先查询部门用户列表
|
||||
dept_user_ids = None
|
||||
if user.data_scope == "dept_and_sub" and user.dept_id:
|
||||
dept_user_ids = await _get_dept_and_sub_user_ids(db, user.dept_id)
|
||||
_check_access(customer, user, dept_user_ids=dept_user_ids)
|
||||
return _to_response(customer)
|
||||
|
||||
|
||||
@@ -162,7 +205,10 @@ async def update_customer(
|
||||
if customer is None:
|
||||
raise NotFoundException("客户不存在或已被删除")
|
||||
|
||||
_check_access(customer, user)
|
||||
dept_user_ids = None
|
||||
if user.data_scope == "dept_and_sub" and user.dept_id:
|
||||
dept_user_ids = await _get_dept_and_sub_user_ids(db, user.dept_id)
|
||||
_check_access(customer, user, dept_user_ids=dept_user_ids)
|
||||
|
||||
update_data = body.model_dump(exclude_unset=True)
|
||||
if not update_data:
|
||||
@@ -193,7 +239,10 @@ async def delete_customer(
|
||||
if customer is None:
|
||||
raise NotFoundException("客户不存在或已被删除")
|
||||
|
||||
_check_access(customer, user)
|
||||
dept_user_ids = None
|
||||
if user.data_scope == "dept_and_sub" and user.dept_id:
|
||||
dept_user_ids = await _get_dept_and_sub_user_ids(db, user.dept_id)
|
||||
_check_access(customer, user, dept_user_ids=dept_user_ids)
|
||||
|
||||
await db.execute(
|
||||
update(CrmCustomer)
|
||||
@@ -216,7 +265,10 @@ async def restore_customer(
|
||||
if customer is None:
|
||||
raise NotFoundException("客户不存在或未被归档")
|
||||
|
||||
_check_access(customer, user)
|
||||
dept_user_ids = None
|
||||
if user.data_scope == "dept_and_sub" and user.dept_id:
|
||||
dept_user_ids = await _get_dept_and_sub_user_ids(db, user.dept_id)
|
||||
_check_access(customer, user, dept_user_ids=dept_user_ids)
|
||||
|
||||
await db.execute(
|
||||
update(CrmCustomer)
|
||||
@@ -226,6 +278,49 @@ async def restore_customer(
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def transfer_customer(
|
||||
db: AsyncSession,
|
||||
user: CurrentUserPayload,
|
||||
customer_id: uuid.UUID,
|
||||
new_owner_id: uuid.UUID,
|
||||
) -> CustomerResponse:
|
||||
"""将客户转移至指定人员名下(仅管理员)"""
|
||||
if user.data_scope != "all":
|
||||
raise ForbiddenException("仅管理员可执行客户转移操作")
|
||||
|
||||
stmt = select(CrmCustomer).where(
|
||||
CrmCustomer.id == customer_id,
|
||||
CrmCustomer.is_deleted.is_(False),
|
||||
)
|
||||
customer = (await db.execute(stmt)).scalar_one_or_none()
|
||||
if customer is None:
|
||||
raise NotFoundException("客户不存在或已被归档")
|
||||
|
||||
if customer.owner_id == new_owner_id:
|
||||
raise BizException(message="目标负责人与当前负责人相同,无需转移")
|
||||
|
||||
# 校验目标用户是否存在
|
||||
from app.models.sys import SysUser
|
||||
target = (await db.execute(
|
||||
select(SysUser).where(SysUser.id == new_owner_id)
|
||||
)).scalar_one_or_none()
|
||||
if target is None:
|
||||
raise NotFoundException("目标负责人不存在")
|
||||
|
||||
old_owner_name = customer.owner.real_name if customer.owner else "(无)"
|
||||
|
||||
await db.execute(
|
||||
update(CrmCustomer)
|
||||
.where(CrmCustomer.id == customer_id)
|
||||
.values(owner_id=new_owner_id, updated_at=datetime.utcnow())
|
||||
)
|
||||
await db.commit()
|
||||
await db.refresh(customer)
|
||||
|
||||
print(f"[客户转移] {customer.name}: {old_owner_name} → {target.real_name} (操作人: {user.real_name})")
|
||||
return _to_response(customer)
|
||||
|
||||
|
||||
async def get_customer_products(
|
||||
db: AsyncSession,
|
||||
user: CurrentUserPayload,
|
||||
@@ -241,7 +336,10 @@ async def get_customer_products(
|
||||
customer = (await db.execute(stmt)).scalar_one_or_none()
|
||||
if customer is None:
|
||||
raise NotFoundException("客户不存在")
|
||||
_check_access(customer, user)
|
||||
dept_user_ids = None
|
||||
if user.data_scope == "dept_and_sub" and user.dept_id:
|
||||
dept_user_ids = await _get_dept_and_sub_user_ids(db, user.dept_id)
|
||||
_check_access(customer, user, dept_user_ids=dept_user_ids)
|
||||
|
||||
# 聚合: 该客户所有订单中的 SKU,含总数量、最近下单时间
|
||||
agg_stmt = (
|
||||
@@ -299,12 +397,11 @@ async def search_customers(
|
||||
base_where.append(CrmCustomer.owner_id == user.user_id)
|
||||
elif user.data_scope == "dept_and_sub":
|
||||
if user.dept_id is not None:
|
||||
from app.models.sys import SysUser
|
||||
sub = select(SysUser.id).where(
|
||||
SysUser.dept_id == user.dept_id,
|
||||
SysUser.is_deleted.is_(False),
|
||||
)
|
||||
base_where.append(CrmCustomer.owner_id.in_(sub))
|
||||
dept_user_ids = await _get_dept_and_sub_user_ids(db, user.dept_id)
|
||||
if dept_user_ids:
|
||||
base_where.append(CrmCustomer.owner_id.in_(dept_user_ids))
|
||||
else:
|
||||
base_where.append(CrmCustomer.owner_id == user.user_id)
|
||||
|
||||
# 模糊搜索(名称 / 联系人 / 电话)
|
||||
from sqlalchemy import or_
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy import func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.core.exceptions import BizException, ForbiddenException, NotFoundException
|
||||
from app.models.finance import FinExpenseDetail, FinExpenseRecord, FinInvoicePool
|
||||
from app.models.sys import SysUser
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.finance import (
|
||||
ExpenseBriefResponse, ExpenseCreate, ExpenseDetailResponse,
|
||||
@@ -84,12 +85,13 @@ async def _release_invoices(db: AsyncSession, expense_id: uuid.UUID, now: dateti
|
||||
|
||||
# ── Service Functions ────────────────────────────────────
|
||||
|
||||
async def create_invoice(db: AsyncSession, user: CurrentUserPayload, body: InvoiceCreate) -> InvoiceResponse:
|
||||
async def create_invoice(db: AsyncSession, user: CurrentUserPayload, body: InvoiceCreate, company_id: uuid.UUID) -> InvoiceResponse:
|
||||
invoice = FinInvoicePool(
|
||||
uploader_id=user.user_id, file_url=body.file_url,
|
||||
merchant_name=body.merchant_name, amount=body.amount,
|
||||
invoice_date=body.invoice_date, type=body.type,
|
||||
ai_extracted_data=body.ai_extracted_data, is_used=False,
|
||||
company_id=company_id,
|
||||
)
|
||||
db.add(invoice)
|
||||
await db.commit()
|
||||
@@ -101,8 +103,11 @@ async def list_invoices(
|
||||
db: AsyncSession, user: CurrentUserPayload,
|
||||
page: int = 1, size: int = 20,
|
||||
inv_type: str | None = None, is_used: bool | None = None,
|
||||
company_id: uuid.UUID | None = None,
|
||||
) -> InvoiceListResponse:
|
||||
where = [FinInvoicePool.is_deleted.is_(False)]
|
||||
if company_id:
|
||||
where.append(FinInvoicePool.company_id == company_id)
|
||||
if user.data_scope == "self":
|
||||
where.append(FinInvoicePool.uploader_id == user.user_id)
|
||||
elif user.data_scope == "dept_and_sub":
|
||||
@@ -135,7 +140,7 @@ async def void_invoice(db: AsyncSession, user: CurrentUserPayload, invoice_id: u
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def create_expense(db: AsyncSession, user: CurrentUserPayload, body: ExpenseCreate) -> ExpenseResponse:
|
||||
async def create_expense(db: AsyncSession, user: CurrentUserPayload, body: ExpenseCreate, company_id: uuid.UUID) -> ExpenseResponse:
|
||||
invoice_ids = [item.invoice_id for item in body.items]
|
||||
try:
|
||||
async with db.begin_nested():
|
||||
@@ -154,6 +159,7 @@ async def create_expense(db: AsyncSession, user: CurrentUserPayload, body: Expen
|
||||
system_no = await _generate_expense_no(db)
|
||||
expense = FinExpenseRecord(
|
||||
system_no=system_no, applicant_id=user.user_id,
|
||||
company_id=company_id,
|
||||
total_amount=body.total_amount, status="submitted", remark=body.remark,
|
||||
)
|
||||
db.add(expense)
|
||||
@@ -184,8 +190,11 @@ async def list_expenses(
|
||||
db: AsyncSession, user: CurrentUserPayload,
|
||||
page: int = 1, size: int = 20,
|
||||
status: str | None = None, applicant_id: uuid.UUID | None = None,
|
||||
company_id: uuid.UUID | None = None,
|
||||
) -> ExpenseListResponse:
|
||||
where = [FinExpenseRecord.is_deleted.is_(False)]
|
||||
if company_id:
|
||||
where.append(FinExpenseRecord.company_id == company_id)
|
||||
if user.data_scope == "self":
|
||||
where.append(FinExpenseRecord.applicant_id == user.user_id)
|
||||
elif user.data_scope == "dept_and_sub":
|
||||
|
||||
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
发票结构化解析器 — OFD / XML 零算力提取
|
||||
OFD 文件本质是 ZIP 包含 XML,直接解包提取发票字段。
|
||||
XML 电子发票(数电票)直接 XPath 提取。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import zipfile
|
||||
from xml.etree import ElementTree as ET
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def parse_ofd_invoice(file_bytes: bytes) -> dict:
|
||||
"""
|
||||
解析 OFD 电子发票文件。
|
||||
OFD = ZIP 压缩包,内含 XML 描述文件。
|
||||
提取发票关键字段,返回结构化 dict。
|
||||
"""
|
||||
result: dict = {}
|
||||
try:
|
||||
with zipfile.ZipFile(io.BytesIO(file_bytes)) as zf:
|
||||
# 收集所有 XML 内容
|
||||
all_text = ""
|
||||
for name in zf.namelist():
|
||||
if name.endswith(".xml"):
|
||||
try:
|
||||
xml_bytes = zf.read(name)
|
||||
xml_text = xml_bytes.decode("utf-8", errors="replace")
|
||||
all_text += xml_text + "\n"
|
||||
|
||||
# 尝试从 XML 标签中提取结构化数据
|
||||
extracted = _extract_from_xml_text(xml_text)
|
||||
if extracted:
|
||||
result.update(extracted)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 如果解析出了字段就直接返回
|
||||
if result.get("merchant") or result.get("amount"):
|
||||
return {"success": True, "data": result}
|
||||
|
||||
# 降级:把所有 XML 文本当纯文本返回,交给 LLM 处理
|
||||
if all_text.strip():
|
||||
return {"success": True, "data": {"raw_text": all_text[:8000]}, "needs_llm": True}
|
||||
|
||||
return {"success": False, "data": {}, "error": "OFD 文件中未找到有效 XML 内容"}
|
||||
|
||||
except zipfile.BadZipFile:
|
||||
return {"success": False, "data": {}, "error": "OFD 文件格式损坏或不是有效的 OFD 文件"}
|
||||
except Exception as e:
|
||||
return {"success": False, "data": {}, "error": f"OFD 解析失败: {e}"}
|
||||
|
||||
|
||||
def parse_xml_invoice(file_bytes: bytes) -> dict:
|
||||
"""
|
||||
解析 XML 格式电子发票(数电票)。
|
||||
直接从 XML 标签提取所有发票字段。
|
||||
"""
|
||||
try:
|
||||
xml_text = file_bytes.decode("utf-8", errors="replace")
|
||||
result = _extract_from_xml_text(xml_text)
|
||||
|
||||
if result and (result.get("merchant") or result.get("amount")):
|
||||
return {"success": True, "data": result}
|
||||
|
||||
# 降级:XML 结构未匹配预设标签,交给 LLM
|
||||
if xml_text.strip():
|
||||
return {"success": True, "data": {"raw_text": xml_text[:8000]}, "needs_llm": True}
|
||||
|
||||
return {"success": False, "data": {}, "error": "XML 文件内容为空"}
|
||||
|
||||
except Exception as e:
|
||||
return {"success": False, "data": {}, "error": f"XML 解析失败: {e}"}
|
||||
|
||||
|
||||
def parse_zip_invoices(file_bytes: bytes) -> list[dict]:
|
||||
"""
|
||||
解析 ZIP 压缩包中的所有 XML 发票文件。
|
||||
返回列表,每个元素 = {"filename": str, "success": bool, "data": dict, ...}
|
||||
支持系统导出的 ZIP 格式(内含多个 XML 发票)。
|
||||
"""
|
||||
results: list[dict] = []
|
||||
try:
|
||||
with zipfile.ZipFile(io.BytesIO(file_bytes)) as zf:
|
||||
xml_names = [n for n in zf.namelist() if n.lower().endswith(".xml")]
|
||||
if not xml_names:
|
||||
return [{"filename": "(zip)", "success": False, "data": {}, "error": "ZIP 包中未找到 XML 文件"}]
|
||||
|
||||
for name in xml_names:
|
||||
try:
|
||||
xml_bytes = zf.read(name)
|
||||
result = parse_xml_invoice(xml_bytes)
|
||||
result["filename"] = os.path.basename(name)
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
results.append({"filename": os.path.basename(name), "success": False, "data": {}, "error": str(e)})
|
||||
|
||||
except zipfile.BadZipFile:
|
||||
return [{"filename": "(zip)", "success": False, "data": {}, "error": "不是有效的 ZIP 文件"}]
|
||||
except Exception as e:
|
||||
return [{"filename": "(zip)", "success": False, "data": {}, "error": f"ZIP 解析失败: {e}"}]
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ── 内部工具函数 ──────────────────────────────────────
|
||||
|
||||
# 常见发票 XML 标签名映射(兼容多种数电票 XML 格式)
|
||||
_FIELD_PATTERNS = {
|
||||
"merchant": [
|
||||
"SalesName", "SellerName", "销售方名称", "销方名称",
|
||||
"开票方", "Seller", "salername", "xfmc",
|
||||
],
|
||||
"buyer": [
|
||||
"BuyerName", "PurchaserName", "购买方名称", "购方名称",
|
||||
"Buyer", "buyername", "gfmc",
|
||||
],
|
||||
"amount": [
|
||||
"TotalAmount", "Amount", "InvoiceAmount", "金额",
|
||||
"合计金额", "价税合计", "jshj", "hjje",
|
||||
],
|
||||
"tax_amount": [
|
||||
"TotalTax", "TaxAmount", "Tax", "税额",
|
||||
"合计税额", "hjse",
|
||||
],
|
||||
"date": [
|
||||
"IssueDate", "InvoiceDate", "BillingDate", "开票日期",
|
||||
"kprq",
|
||||
],
|
||||
"invoice_code": [
|
||||
"InvoiceCode", "发票代码", "fpdm",
|
||||
],
|
||||
"invoice_number": [
|
||||
"InvoiceNumber", "InvoiceNo", "发票号码", "fphm",
|
||||
],
|
||||
"items": [
|
||||
"GoodsName", "ItemName", "商品名称", "货物名称", "spmc",
|
||||
],
|
||||
"tax_rate": [
|
||||
"TaxRate", "税率", "sl",
|
||||
],
|
||||
"remark": [
|
||||
"Remark", "备注", "bz",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _extract_from_xml_text(xml_text: str) -> Optional[dict]:
|
||||
"""从 XML 文本中用多种策略提取发票字段。"""
|
||||
result: dict = {}
|
||||
|
||||
# 策略 1: 正则匹配 <TagName>Value</TagName> 格式
|
||||
for field, tag_names in _FIELD_PATTERNS.items():
|
||||
for tag in tag_names:
|
||||
# 匹配 <Tag>value</Tag> 或 <ns:Tag>value</ns:Tag>
|
||||
pattern = rf'<(?:\w+:)?{re.escape(tag)}[^>]*>([^<]+)</(?:\w+:)?{re.escape(tag)}>'
|
||||
match = re.search(pattern, xml_text, re.IGNORECASE)
|
||||
if match:
|
||||
value = match.group(1).strip()
|
||||
if value:
|
||||
# 数字字段转数值
|
||||
if field in ("amount", "tax_amount"):
|
||||
try:
|
||||
result[field] = float(value)
|
||||
except ValueError:
|
||||
result[field] = value
|
||||
else:
|
||||
result[field] = value
|
||||
break # 找到一个就跳到下一个字段
|
||||
|
||||
# 策略 2: 尝试 ElementTree 解析
|
||||
if not result:
|
||||
try:
|
||||
# 移除 XML 声明中可能的编码问题
|
||||
cleaned = re.sub(r'<\?xml[^?]*\?>', '', xml_text).strip()
|
||||
if cleaned:
|
||||
root = ET.fromstring(cleaned)
|
||||
_extract_from_element(root, result)
|
||||
except ET.ParseError:
|
||||
pass
|
||||
|
||||
return result if result else None
|
||||
|
||||
|
||||
def _extract_from_element(elem: ET.Element, result: dict, depth: int = 0):
|
||||
"""递归遍历 XML 元素树提取字段。"""
|
||||
if depth > 10:
|
||||
return
|
||||
|
||||
tag_local = elem.tag.split("}")[-1] if "}" in elem.tag else elem.tag
|
||||
|
||||
for field, tag_names in _FIELD_PATTERNS.items():
|
||||
if field not in result:
|
||||
for tn in tag_names:
|
||||
if tag_local.lower() == tn.lower():
|
||||
text = (elem.text or "").strip()
|
||||
if text:
|
||||
if field in ("amount", "tax_amount"):
|
||||
try:
|
||||
result[field] = float(text)
|
||||
except ValueError:
|
||||
result[field] = text
|
||||
else:
|
||||
result[field] = text
|
||||
break
|
||||
|
||||
for child in elem:
|
||||
_extract_from_element(child, result, depth + 1)
|
||||
@@ -72,11 +72,12 @@ async def ocr_image(
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "/no_think\n" + prompt,
|
||||
"content": prompt,
|
||||
"images": [image_base64], # Ollama vision 格式
|
||||
},
|
||||
],
|
||||
"stream": False,
|
||||
"think": False, # 关闭思考模式:稳定输出、避免死循环、提速 2-5x
|
||||
"options": {
|
||||
"temperature": 0.1,
|
||||
"num_predict": 2000,
|
||||
@@ -87,19 +88,18 @@ async def ocr_image(
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
resp = await client.post(url, json=payload)
|
||||
if resp.status_code != 200:
|
||||
print(f"[OCR] 3090 返回 {resp.status_code}: {resp.text[:200]}")
|
||||
return {"success": False, "data": {}, "error": f"VL 模型返回 {resp.status_code}"}
|
||||
detail = resp.text[:200]
|
||||
print(f"[OCR] 3090 返回 {resp.status_code}: {detail}")
|
||||
if "model runner" in detail:
|
||||
return {"success": False, "data": {}, "error": "AI OCR 模型进程崩溃,请联系管理员重启 Ollama 服务"}
|
||||
return {"success": False, "data": {}, "error": f"AI OCR 服务异常 (HTTP {resp.status_code}),请稍后重试"}
|
||||
|
||||
data = resp.json()
|
||||
# Qwen3.5 的 CoT 推理放在 message.thinking,最终结果在 message.content
|
||||
content = data.get("message", {}).get("content", "")
|
||||
thinking = data.get("message", {}).get("thinking", "")
|
||||
|
||||
# 优先从 content 提取 JSON,回退到 thinking
|
||||
for text_source in [content, thinking]:
|
||||
if not text_source:
|
||||
continue
|
||||
cleaned = re.sub(r'<think>.*?</think>', '', text_source, flags=re.DOTALL).strip()
|
||||
# 关闭思考模式后,结果直接在 content(无 thinking 字段)
|
||||
if content:
|
||||
cleaned = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
|
||||
json_match = re.search(r'\{[\s\S]*\}', cleaned)
|
||||
if json_match:
|
||||
try:
|
||||
@@ -107,16 +107,14 @@ async def ocr_image(
|
||||
print(f"[OCR] 解析成功: {list(result.keys())}")
|
||||
return {"success": True, "data": result}
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
pass
|
||||
|
||||
# 没有提取到 JSON,返回原始文本
|
||||
raw = content or thinking
|
||||
print(f"[OCR] 未能提取 JSON, 内容长度: content={len(content)}, thinking={len(thinking)}")
|
||||
return {"success": True, "data": {"raw_text": raw[:2000]}}
|
||||
print(f"[OCR] 未能提取 JSON, content 长度: {len(content)}")
|
||||
return {"success": True, "data": {"raw_text": content[:2000]}}
|
||||
|
||||
except httpx.TimeoutException:
|
||||
print("[OCR] 3090 超时(60s)")
|
||||
return {"success": False, "data": {}, "error": "VL 模型响应超时"}
|
||||
print("[OCR] 3090 超时(120s)")
|
||||
return {"success": False, "data": {}, "error": "AI OCR 响应超时(120s),模型可能负载过高,请稍后重试"}
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"[OCR] JSON 解析失败: {e}")
|
||||
return {"success": False, "data": {}, "error": f"JSON 解析失败: {e}"}
|
||||
@@ -172,11 +170,11 @@ async def extract_invoice_from_text(
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"/no_think\n{prompt}\n\n--- 以下是发票文本内容 ---\n\n{truncated}",
|
||||
# 不传 images —— 纯文本模式
|
||||
"content": f"{prompt}\n\n--- 以下是发票文本内容 ---\n\n{truncated}",
|
||||
},
|
||||
],
|
||||
"stream": False,
|
||||
"think": False, # 关闭思考模式
|
||||
"options": {
|
||||
"temperature": 0.1,
|
||||
"num_predict": 2000,
|
||||
@@ -192,12 +190,9 @@ async def extract_invoice_from_text(
|
||||
|
||||
data = resp.json()
|
||||
content = data.get("message", {}).get("content", "")
|
||||
thinking = data.get("message", {}).get("thinking", "")
|
||||
|
||||
for text_source in [content, thinking]:
|
||||
if not text_source:
|
||||
continue
|
||||
cleaned = re.sub(r'<think>.*?</think>', '', text_source, flags=re.DOTALL).strip()
|
||||
if content:
|
||||
cleaned = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
|
||||
json_match = re.search(r'\{[\s\S]*\}', cleaned)
|
||||
if json_match:
|
||||
try:
|
||||
@@ -205,11 +200,10 @@ async def extract_invoice_from_text(
|
||||
print(f"[TextExtract] AI 提取成功: {list(result.keys())}")
|
||||
return {"success": True, "data": result}
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
pass
|
||||
|
||||
raw = content or thinking
|
||||
print(f"[TextExtract] 未能提取 JSON, 内容: {raw[:200]}")
|
||||
return {"success": True, "data": {"raw_text": raw[:2000]}}
|
||||
print(f"[TextExtract] 未能提取 JSON, content: {content[:200]}")
|
||||
return {"success": True, "data": {"raw_text": content[:2000]}}
|
||||
|
||||
except httpx.TimeoutException:
|
||||
print("[TextExtract] 3090 超时")
|
||||
|
||||
@@ -0,0 +1,266 @@
|
||||
"""
|
||||
OCR 后台 Worker — asyncio 协程,FastAPI lifespan 启动
|
||||
策略 C: 工作时间限流(1并发 + 60s间隔),17:00-20:00 BJT 全速
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.database import async_session_factory
|
||||
from app.models.finance import FinInvoicePool, FinOcrTask
|
||||
|
||||
|
||||
class OcrWorker:
|
||||
"""后台 OCR 任务处理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.running = False
|
||||
self.current_task_id: uuid.UUID | None = None
|
||||
self._task: asyncio.Task | None = None
|
||||
|
||||
def start(self):
|
||||
self.running = True
|
||||
self._task = asyncio.create_task(self._run_loop())
|
||||
print("[OcrWorker] 启动 — 策略 C: 工作时间限流, 17-20 BJT 全速")
|
||||
|
||||
async def stop(self):
|
||||
self.running = False
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
print("[OcrWorker] 已停止")
|
||||
|
||||
async def _run_loop(self):
|
||||
"""主循环:每 10 秒检查一次队列"""
|
||||
while self.running:
|
||||
try:
|
||||
task = await self._pick_next_task()
|
||||
if task:
|
||||
await self._process_task(task)
|
||||
# 限流:非高峰期间隔 60s
|
||||
if not self._is_peak_time():
|
||||
await asyncio.sleep(60)
|
||||
else:
|
||||
await asyncio.sleep(5)
|
||||
else:
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"[OcrWorker] 循环异常: {e}")
|
||||
await asyncio.sleep(30)
|
||||
|
||||
def _is_peak_time(self) -> bool:
|
||||
"""17:00-20:00 BJT = 09:00-12:00 UTC"""
|
||||
utc_hour = datetime.utcnow().hour
|
||||
return 9 <= utc_hour < 12
|
||||
|
||||
async def _pick_next_task(self) -> dict | None:
|
||||
"""从 DB 获取优先级最高的 pending 任务"""
|
||||
async with async_session_factory() as db:
|
||||
stmt = (
|
||||
select(FinOcrTask)
|
||||
.where(
|
||||
FinOcrTask.status == "pending",
|
||||
FinOcrTask.is_deleted.is_(False),
|
||||
FinOcrTask.retry_count < FinOcrTask.max_retries,
|
||||
)
|
||||
.order_by(FinOcrTask.priority, FinOcrTask.created_at)
|
||||
.limit(1)
|
||||
)
|
||||
task = (await db.execute(stmt)).scalar_one_or_none()
|
||||
if not task:
|
||||
return None
|
||||
|
||||
# 标记为 processing
|
||||
task.status = "processing"
|
||||
task.updated_at = datetime.utcnow()
|
||||
await db.commit()
|
||||
|
||||
self.current_task_id = task.id
|
||||
return {
|
||||
"id": task.id,
|
||||
"file_url": task.file_url,
|
||||
"file_ext": task.file_ext,
|
||||
"original_name": task.original_name,
|
||||
"uploader_id": task.uploader_id,
|
||||
"company_id": task.company_id,
|
||||
"inv_type": task.inv_type,
|
||||
"retry_count": task.retry_count,
|
||||
}
|
||||
|
||||
async def _process_task(self, task_info: dict):
|
||||
"""执行 OCR 并更新"""
|
||||
task_id = task_info["id"]
|
||||
file_url = task_info["file_url"]
|
||||
file_ext = task_info["file_ext"]
|
||||
print(f"[OcrWorker] 处理任务 {task_id} ({task_info['original_name']}, {file_ext})")
|
||||
|
||||
try:
|
||||
# 读取文件
|
||||
file_path = file_url.lstrip("/")
|
||||
if not os.path.exists(file_path):
|
||||
await self._mark_failed(task_id, f"文件不存在: {file_path}")
|
||||
return
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
file_bytes = f.read()
|
||||
|
||||
ocr_data = {}
|
||||
message = ""
|
||||
|
||||
# PDF 处理
|
||||
if file_ext == ".pdf":
|
||||
ocr_data, message = await self._process_pdf(file_bytes)
|
||||
# 图片处理
|
||||
elif file_ext in (".png", ".jpg", ".jpeg"):
|
||||
ocr_data, message = await self._process_image(file_bytes)
|
||||
else:
|
||||
await self._mark_failed(task_id, f"不支持的文件格式: {file_ext}")
|
||||
return
|
||||
|
||||
if ocr_data and (ocr_data.get("merchant") or ocr_data.get("amount")):
|
||||
# OCR 成功 → 自动入池
|
||||
await self._mark_success_and_pool(task_id, task_info, ocr_data)
|
||||
print(f"[OcrWorker] ✅ {task_info['original_name']} 入池成功")
|
||||
else:
|
||||
# OCR 完成但没提取到关键字段
|
||||
await self._mark_failed(
|
||||
task_id,
|
||||
message or "AI 未能提取发票关键字段(开票方/金额),请手动录入",
|
||||
ocr_data,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[OcrWorker] ❌ 任务 {task_id} 异常: {e}")
|
||||
await self._mark_failed(task_id, str(e))
|
||||
|
||||
self.current_task_id = None
|
||||
|
||||
async def _process_pdf(self, file_bytes: bytes) -> tuple[dict, str]:
|
||||
"""PDF: 先尝试文本提取,失败降级 Vision OCR"""
|
||||
try:
|
||||
import fitz
|
||||
doc = fitz.open(stream=file_bytes, filetype="pdf")
|
||||
text = ""
|
||||
for page in doc:
|
||||
text += page.get_text() + "\n"
|
||||
doc.close()
|
||||
text = text.strip()
|
||||
|
||||
if len(text) > 50:
|
||||
from app.services.ocr_service import extract_invoice_from_text
|
||||
result = await extract_invoice_from_text(text, "invoice")
|
||||
if result.get("success") and result.get("data"):
|
||||
return result["data"], "PDF 文本解析成功"
|
||||
|
||||
# 降级: 扫描件 → Vision OCR
|
||||
doc2 = fitz.open(stream=file_bytes, filetype="pdf")
|
||||
pix = doc2[0].get_pixmap(dpi=150)
|
||||
ocr_bytes = pix.tobytes("png")
|
||||
doc2.close()
|
||||
return await self._vision_ocr(ocr_bytes)
|
||||
|
||||
except Exception as e:
|
||||
return {}, f"PDF 处理失败: {e}"
|
||||
|
||||
async def _process_image(self, file_bytes: bytes) -> tuple[dict, str]:
|
||||
"""图片: Vision OCR"""
|
||||
return await self._vision_ocr(file_bytes)
|
||||
|
||||
async def _vision_ocr(self, image_bytes: bytes) -> tuple[dict, str]:
|
||||
"""调用 3090 Vision OCR"""
|
||||
import base64
|
||||
from app.services.ocr_service import ocr_image
|
||||
image_b64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
result = await ocr_image(image_b64, "invoice")
|
||||
if result.get("success"):
|
||||
return result.get("data", {}), "Vision OCR 成功"
|
||||
return {}, result.get("error", "OCR 失败")
|
||||
|
||||
async def _mark_success_and_pool(self, task_id: uuid.UUID, task_info: dict, ocr_data: dict):
|
||||
"""标记成功 + 自动入池"""
|
||||
async with async_session_factory() as db:
|
||||
merchant = ocr_data.get("merchant") or ocr_data.get("merchant_name") or "(AI 提取)"
|
||||
amount = 0
|
||||
try:
|
||||
amount = float(ocr_data.get("amount", 0))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
invoice_date_str = ocr_data.get("date")
|
||||
invoice_date = None
|
||||
if invoice_date_str:
|
||||
try:
|
||||
from datetime import date as dt_date
|
||||
invoice_date = dt_date.fromisoformat(invoice_date_str)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
inv = FinInvoicePool(
|
||||
uploader_id=task_info["uploader_id"],
|
||||
company_id=task_info["company_id"],
|
||||
file_url=task_info["file_url"],
|
||||
merchant_name=merchant,
|
||||
amount=amount,
|
||||
invoice_date=invoice_date,
|
||||
type=task_info["inv_type"],
|
||||
ai_extracted_data=ocr_data,
|
||||
is_used=False,
|
||||
)
|
||||
db.add(inv)
|
||||
await db.flush()
|
||||
|
||||
await db.execute(
|
||||
update(FinOcrTask)
|
||||
.where(FinOcrTask.id == task_id)
|
||||
.values(
|
||||
status="success",
|
||||
ocr_result=ocr_data,
|
||||
invoice_pool_id=inv.id,
|
||||
error_message=None,
|
||||
updated_at=datetime.utcnow(),
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async def _mark_failed(self, task_id: uuid.UUID, error: str, partial_data: dict | None = None):
|
||||
"""标记失败 + retry_count+1"""
|
||||
async with async_session_factory() as db:
|
||||
task = (await db.execute(
|
||||
select(FinOcrTask).where(FinOcrTask.id == task_id)
|
||||
)).scalar_one_or_none()
|
||||
if not task:
|
||||
return
|
||||
|
||||
new_retry = task.retry_count + 1
|
||||
new_status = "failed" if new_retry >= task.max_retries else "pending"
|
||||
|
||||
await db.execute(
|
||||
update(FinOcrTask)
|
||||
.where(FinOcrTask.id == task_id)
|
||||
.values(
|
||||
status=new_status,
|
||||
retry_count=new_retry,
|
||||
error_message=error,
|
||||
ocr_result=partial_data or task.ocr_result,
|
||||
updated_at=datetime.utcnow(),
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
if new_status == "pending":
|
||||
print(f"[OcrWorker] ⚠️ 任务 {task_id} 第 {new_retry} 次重试入队")
|
||||
else:
|
||||
print(f"[OcrWorker] ❌ 任务 {task_id} 已达最大重试次数,标记失败")
|
||||
|
||||
|
||||
# 单例
|
||||
ocr_worker = OcrWorker()
|
||||
@@ -16,6 +16,7 @@ from app.core.exceptions import BizException, ForbiddenException, NotFoundExcept
|
||||
from app.models.crm import CrmCustomer
|
||||
from app.models.erp import ProductSku
|
||||
from app.models.order import ErpOrder, ErpOrderItem
|
||||
from app.models.sys import SysUser
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.order import (
|
||||
OrderBriefResponse,
|
||||
@@ -156,6 +157,7 @@ async def create_order(
|
||||
db: AsyncSession,
|
||||
user: CurrentUserPayload,
|
||||
body: OrderCreate,
|
||||
company_id: uuid.UUID,
|
||||
) -> OrderResponse:
|
||||
# 校验客户存在
|
||||
cust = (
|
||||
@@ -193,6 +195,7 @@ async def create_order(
|
||||
order_no=order_no,
|
||||
customer_id=body.customer_id,
|
||||
salesperson_id=user.user_id,
|
||||
company_id=company_id,
|
||||
total_amount=total,
|
||||
shipping_state="pending",
|
||||
payment_state="unpaid",
|
||||
@@ -236,8 +239,11 @@ async def list_orders(
|
||||
shipping_state: str | None = None,
|
||||
payment_state: str | None = None,
|
||||
keyword: str | None = None,
|
||||
company_id: uuid.UUID | None = None,
|
||||
) -> OrderListResponse:
|
||||
where: list[Any] = [ErpOrder.is_deleted.is_(False)]
|
||||
if company_id:
|
||||
where.append(ErpOrder.company_id == company_id)
|
||||
|
||||
if user.data_scope == "self":
|
||||
where.append(ErpOrder.salesperson_id == user.user_id)
|
||||
@@ -284,13 +290,17 @@ async def get_order(
|
||||
db: AsyncSession,
|
||||
user: CurrentUserPayload,
|
||||
order_id: uuid.UUID,
|
||||
company_id: uuid.UUID | None = None,
|
||||
) -> OrderResponse:
|
||||
where_clause = [
|
||||
ErpOrder.id == order_id,
|
||||
ErpOrder.is_deleted.is_(False),
|
||||
]
|
||||
if company_id:
|
||||
where_clause.append(ErpOrder.company_id == company_id)
|
||||
order = (
|
||||
await db.execute(
|
||||
select(ErpOrder).where(
|
||||
ErpOrder.id == order_id,
|
||||
ErpOrder.is_deleted.is_(False),
|
||||
)
|
||||
select(ErpOrder).where(*where_clause)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if order is None:
|
||||
|
||||
@@ -14,7 +14,8 @@ from sqlalchemy import func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import BizException, NotFoundException
|
||||
from app.models.erp import InventoryFlow, ProductCategory, ProductSku
|
||||
from app.models.erp import ErpSkuInventory, InventoryFlow, ProductCategory, ProductSku
|
||||
from app.models.sys import SysUser
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.erp import (
|
||||
CategoryCreate,
|
||||
@@ -31,7 +32,10 @@ from app.schemas.erp import (
|
||||
|
||||
# ── ORM → Response ───────────────────────────────────────
|
||||
|
||||
def _sku_to_response(s: ProductSku) -> SkuResponse:
|
||||
def _sku_to_response(
|
||||
s: ProductSku,
|
||||
inv: ErpSkuInventory | None = None,
|
||||
) -> SkuResponse:
|
||||
return SkuResponse(
|
||||
id=s.id,
|
||||
sku_code=s.sku_code,
|
||||
@@ -40,8 +44,8 @@ def _sku_to_response(s: ProductSku) -> SkuResponse:
|
||||
category_name=s.category.name if s.category else None,
|
||||
spec=s.spec,
|
||||
standard_price=float(s.standard_price or 0),
|
||||
stock_qty=float(s.stock_qty or 0),
|
||||
warning_threshold=float(s.warning_threshold or 0),
|
||||
stock_qty=float(inv.stock_qty) if inv else 0.0,
|
||||
warning_threshold=float(inv.warning_threshold) if inv else 0.0,
|
||||
unit=s.unit,
|
||||
status=s.status,
|
||||
created_at=s.created_at,
|
||||
@@ -200,11 +204,13 @@ async def delete_category(db: AsyncSession, cat_id: uuid.UUID) -> None:
|
||||
|
||||
async def list_skus(
|
||||
db: AsyncSession,
|
||||
company_id: uuid.UUID,
|
||||
page: int = 1,
|
||||
size: int = 20,
|
||||
category_id: uuid.UUID | None = None,
|
||||
keyword: str | None = None,
|
||||
) -> SkuListResponse:
|
||||
"""LEFT JOIN erp_sku_inventory 获取当前公司库存,COALESCE 兜底为 0"""
|
||||
where: list[Any] = [ProductSku.is_deleted.is_(False)]
|
||||
if category_id:
|
||||
where.append(ProductSku.category_id == category_id)
|
||||
@@ -218,24 +224,31 @@ async def list_skus(
|
||||
await db.execute(select(func.count()).select_from(ProductSku).where(*where))
|
||||
).scalar() or 0
|
||||
|
||||
# LEFT JOIN erp_sku_inventory 带出当前公司库存
|
||||
stmt = (
|
||||
select(ProductSku)
|
||||
select(ProductSku, ErpSkuInventory)
|
||||
.outerjoin(
|
||||
ErpSkuInventory,
|
||||
(ErpSkuInventory.sku_id == ProductSku.id)
|
||||
& (ErpSkuInventory.company_id == company_id),
|
||||
)
|
||||
.where(*where)
|
||||
.order_by(ProductSku.created_at.desc())
|
||||
.offset((page - 1) * size)
|
||||
.limit(size)
|
||||
)
|
||||
rows = (await db.execute(stmt)).scalars().all()
|
||||
rows = (await db.execute(stmt)).all()
|
||||
|
||||
return SkuListResponse(
|
||||
total=total,
|
||||
items=[_sku_to_response(s) for s in rows],
|
||||
items=[_sku_to_response(sku, inv) for sku, inv in rows],
|
||||
page=page,
|
||||
size=size,
|
||||
)
|
||||
|
||||
|
||||
async def create_sku(db: AsyncSession, body: SkuCreate) -> SkuResponse:
|
||||
"""创建 SKU(不创建库存行,LEFT JOIN 查询自动兜底为 0)"""
|
||||
exists = (
|
||||
await db.execute(
|
||||
select(ProductSku.id).where(
|
||||
@@ -253,8 +266,6 @@ async def create_sku(db: AsyncSession, body: SkuCreate) -> SkuResponse:
|
||||
category_id=body.category_id,
|
||||
spec=body.spec,
|
||||
standard_price=body.standard_price,
|
||||
stock_qty=body.stock_qty,
|
||||
warning_threshold=body.warning_threshold,
|
||||
unit=body.unit,
|
||||
status=body.status,
|
||||
)
|
||||
@@ -299,7 +310,9 @@ async def create_inventory_flow(
|
||||
db: AsyncSession,
|
||||
user: CurrentUserPayload,
|
||||
body: InventoryFlowCreate,
|
||||
company_id: uuid.UUID,
|
||||
) -> InventoryFlowResponse:
|
||||
"""库存变更(upsert erp_sku_inventory + 写流水)"""
|
||||
sku = (
|
||||
await db.execute(
|
||||
select(ProductSku).where(
|
||||
@@ -310,35 +323,74 @@ async def create_inventory_flow(
|
||||
if sku is None:
|
||||
raise NotFoundException("产品 SKU 不存在")
|
||||
|
||||
if body.change_qty < 0:
|
||||
current_stock = float(sku.stock_qty or 0)
|
||||
if current_stock + body.change_qty < 0:
|
||||
raise BizException(
|
||||
message=f"库存不足:当前库存 {current_stock},请求出库 {abs(body.change_qty)}"
|
||||
)
|
||||
|
||||
try:
|
||||
async with db.begin_nested():
|
||||
# ── upsert: 查找或创建当前公司的库存行 ──
|
||||
inv = (
|
||||
await db.execute(
|
||||
select(ErpSkuInventory)
|
||||
.where(
|
||||
ErpSkuInventory.sku_id == body.sku_id,
|
||||
ErpSkuInventory.company_id == company_id,
|
||||
)
|
||||
.with_for_update()
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if inv is None:
|
||||
# 首次操作该 SKU:自动创建 0 库存行
|
||||
inv = ErpSkuInventory(
|
||||
sku_id=body.sku_id,
|
||||
company_id=company_id,
|
||||
stock_qty=0,
|
||||
warning_threshold=0,
|
||||
)
|
||||
db.add(inv)
|
||||
await db.flush()
|
||||
# 重新锁行
|
||||
inv = (
|
||||
await db.execute(
|
||||
select(ErpSkuInventory)
|
||||
.where(ErpSkuInventory.id == inv.id)
|
||||
.with_for_update()
|
||||
)
|
||||
).scalar_one()
|
||||
|
||||
# ── 校验库存 ──
|
||||
current_stock = float(inv.stock_qty or 0)
|
||||
if body.change_qty < 0 and current_stock + body.change_qty < 0:
|
||||
raise BizException(
|
||||
message=f"库存不足:当前库存 {current_stock},请求出库 {abs(body.change_qty)}"
|
||||
)
|
||||
|
||||
# ── 更新库存 ──
|
||||
await db.execute(
|
||||
update(ErpSkuInventory)
|
||||
.where(ErpSkuInventory.id == inv.id)
|
||||
.values(
|
||||
stock_qty=ErpSkuInventory.stock_qty + Decimal(str(body.change_qty)),
|
||||
updated_at=datetime.utcnow(),
|
||||
)
|
||||
)
|
||||
|
||||
# ── 写流水 ──
|
||||
flow = InventoryFlow(
|
||||
sku_id=body.sku_id,
|
||||
company_id=company_id,
|
||||
change_qty=body.change_qty,
|
||||
reason=body.reason,
|
||||
remark=body.remark,
|
||||
purchase_unit_price=body.purchase_unit_price if body.change_qty > 0 else 0,
|
||||
is_special_zero_cost=body.is_special_zero_cost if body.change_qty > 0 else False,
|
||||
operator_id=user.user_id,
|
||||
)
|
||||
db.add(flow)
|
||||
await db.flush()
|
||||
|
||||
await db.execute(
|
||||
update(ProductSku)
|
||||
.where(ProductSku.id == body.sku_id)
|
||||
.values(
|
||||
stock_qty=ProductSku.stock_qty + Decimal(str(body.change_qty)),
|
||||
updated_at=datetime.utcnow(),
|
||||
)
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
except BizException:
|
||||
await db.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
raise BizException(code=500, message=f"库存变更事务失败: {e!s}") from e
|
||||
@@ -352,9 +404,11 @@ async def create_inventory_flow(
|
||||
async def get_inventory_flows(
|
||||
db: AsyncSession,
|
||||
sku_id: uuid.UUID,
|
||||
company_id: uuid.UUID,
|
||||
page: int = 1,
|
||||
size: int = 50,
|
||||
) -> dict[str, Any]:
|
||||
"""获取单个 SKU 在当前公司的库存流水"""
|
||||
sku = (
|
||||
await db.execute(
|
||||
select(ProductSku).where(
|
||||
@@ -365,8 +419,19 @@ async def get_inventory_flows(
|
||||
if sku is None:
|
||||
raise NotFoundException("产品 SKU 不存在")
|
||||
|
||||
# 查当前公司库存
|
||||
inv = (
|
||||
await db.execute(
|
||||
select(ErpSkuInventory).where(
|
||||
ErpSkuInventory.sku_id == sku_id,
|
||||
ErpSkuInventory.company_id == company_id,
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
where: list[Any] = [
|
||||
InventoryFlow.sku_id == sku_id,
|
||||
InventoryFlow.company_id == company_id,
|
||||
InventoryFlow.is_deleted.is_(False),
|
||||
]
|
||||
|
||||
@@ -389,7 +454,7 @@ async def get_inventory_flows(
|
||||
"total": total,
|
||||
"sku_code": sku.sku_code,
|
||||
"sku_name": sku.name,
|
||||
"current_stock": float(sku.stock_qty or 0),
|
||||
"current_stock": float(inv.stock_qty) if inv else 0.0,
|
||||
"items": [_flow_to_response(f).model_dump(mode="json") for f in flows],
|
||||
"page": page,
|
||||
"size": size,
|
||||
|
||||
@@ -0,0 +1,226 @@
|
||||
"""
|
||||
库存与利润核算 Service 层
|
||||
- MWA 入库事务(悲观锁 FOR UPDATE + 零元隔离)
|
||||
- 订单利润快照
|
||||
- 利润报表聚合
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import func, select, update, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import BizException, NotFoundException
|
||||
from app.models.erp import ErpSkuInventory, InventoryFlow, ProductSku
|
||||
from app.models.cost import ErpOrderItemCost
|
||||
from app.models.order import ErpOrder, ErpOrderItem
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
|
||||
|
||||
# ── MWA 入库事务 ────────────────────────────────────────
|
||||
async def process_inbound_with_mwa(
|
||||
db: AsyncSession,
|
||||
sku_id: uuid.UUID,
|
||||
company_id: uuid.UUID,
|
||||
qty: float,
|
||||
purchase_unit_price: float,
|
||||
operator_id: uuid.UUID | None = None,
|
||||
remark: str | None = None,
|
||||
is_special_zero_cost: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
入库事务(悲观锁 + MWA)
|
||||
1. SELECT ... FOR UPDATE 锁定库存行
|
||||
2. 如果非零元特殊,计算新 MWA
|
||||
3. 更新库存 + 记录流水
|
||||
"""
|
||||
# 悲观锁获取库存记录
|
||||
inv_stmt = (
|
||||
select(ErpSkuInventory)
|
||||
.where(
|
||||
ErpSkuInventory.sku_id == sku_id,
|
||||
ErpSkuInventory.company_id == company_id,
|
||||
)
|
||||
.with_for_update()
|
||||
)
|
||||
inv = (await db.execute(inv_stmt)).scalar_one_or_none()
|
||||
|
||||
if inv is None:
|
||||
# 首次入库,创建库存记录
|
||||
inv = ErpSkuInventory(
|
||||
sku_id=sku_id,
|
||||
company_id=company_id,
|
||||
stock_qty=0,
|
||||
mwa_unit_cost=0,
|
||||
)
|
||||
db.add(inv)
|
||||
await db.flush()
|
||||
# 重新锁定
|
||||
inv = (await db.execute(inv_stmt)).scalar_one()
|
||||
|
||||
old_qty = float(inv.stock_qty or 0)
|
||||
old_mwa = float(inv.mwa_unit_cost or 0)
|
||||
new_qty = old_qty + qty
|
||||
|
||||
# MWA 计算(零元特殊入库不参与)
|
||||
if is_special_zero_cost or purchase_unit_price == 0:
|
||||
new_mwa = old_mwa # 保持原有 MWA
|
||||
else:
|
||||
if new_qty > 0:
|
||||
new_mwa = (old_qty * old_mwa + qty * purchase_unit_price) / new_qty
|
||||
else:
|
||||
new_mwa = purchase_unit_price
|
||||
|
||||
# 更新库存
|
||||
inv.stock_qty = new_qty
|
||||
inv.mwa_unit_cost = round(new_mwa, 4)
|
||||
inv.updated_at = datetime.utcnow()
|
||||
|
||||
# 记录流水
|
||||
flow = InventoryFlow(
|
||||
sku_id=sku_id,
|
||||
company_id=company_id,
|
||||
flow_type="in",
|
||||
change_qty=qty,
|
||||
reason="purchase_in",
|
||||
purchase_unit_price=purchase_unit_price,
|
||||
is_special_zero_cost=is_special_zero_cost,
|
||||
operator_id=operator_id,
|
||||
remark=remark or f"入库 {qty} 件 @ ¥{purchase_unit_price}",
|
||||
)
|
||||
db.add(flow)
|
||||
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"sku_id": str(sku_id),
|
||||
"old_qty": old_qty,
|
||||
"new_qty": new_qty,
|
||||
"old_mwa": old_mwa,
|
||||
"new_mwa": round(new_mwa, 4),
|
||||
"is_special_zero_cost": is_special_zero_cost,
|
||||
}
|
||||
|
||||
|
||||
# ── 订单明细成本快照 ────────────────────────────────────
|
||||
async def snapshot_order_item_costs(
|
||||
db: AsyncSession,
|
||||
order_id: uuid.UUID,
|
||||
company_id: uuid.UUID,
|
||||
) -> list[dict]:
|
||||
"""为订单的所有明细行锚定 MWA 成本快照"""
|
||||
items_stmt = select(ErpOrderItem).where(
|
||||
ErpOrderItem.order_id == order_id,
|
||||
ErpOrderItem.is_deleted.is_(False),
|
||||
)
|
||||
items = (await db.execute(items_stmt)).scalars().all()
|
||||
|
||||
results = []
|
||||
for item in items:
|
||||
# 查当前 MWA
|
||||
inv = (await db.execute(
|
||||
select(ErpSkuInventory).where(
|
||||
ErpSkuInventory.sku_id == item.sku_id,
|
||||
ErpSkuInventory.company_id == company_id,
|
||||
)
|
||||
)).scalar_one_or_none()
|
||||
|
||||
mwa_cost = float(inv.mwa_unit_cost or 0) if inv else 0
|
||||
sell_price = float(item.unit_price or 0)
|
||||
qty = float(item.qty or 0)
|
||||
profit = (sell_price - mwa_cost) * qty
|
||||
profit_rate = (sell_price - mwa_cost) / sell_price if sell_price > 0 else 0
|
||||
|
||||
# 检查是否已有快照
|
||||
existing = (await db.execute(
|
||||
select(ErpOrderItemCost).where(
|
||||
ErpOrderItemCost.order_item_id == item.id
|
||||
)
|
||||
)).scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
existing.purchase_unit_price = mwa_cost
|
||||
existing.profit_amount = round(profit, 2)
|
||||
existing.profit_rate = round(profit_rate, 4)
|
||||
else:
|
||||
cost_snap = ErpOrderItemCost(
|
||||
order_item_id=item.id,
|
||||
purchase_unit_price=mwa_cost,
|
||||
profit_amount=round(profit, 2),
|
||||
profit_rate=round(profit_rate, 4),
|
||||
)
|
||||
db.add(cost_snap)
|
||||
|
||||
results.append({
|
||||
"sku_id": str(item.sku_id),
|
||||
"qty": qty,
|
||||
"sell_price": sell_price,
|
||||
"mwa_cost": mwa_cost,
|
||||
"profit": round(profit, 2),
|
||||
"profit_rate": round(profit_rate * 100, 2),
|
||||
})
|
||||
|
||||
await db.commit()
|
||||
return results
|
||||
|
||||
|
||||
# ── 利润报表 ────────────────────────────────────────────
|
||||
async def get_profit_report(
|
||||
db: AsyncSession,
|
||||
company_id: uuid.UUID,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
) -> dict:
|
||||
"""聚合利润报表"""
|
||||
base_where = [
|
||||
ErpOrder.company_id == company_id,
|
||||
ErpOrder.is_deleted.is_(False),
|
||||
]
|
||||
if start_date:
|
||||
base_where.append(ErpOrder.order_date >= start_date)
|
||||
if end_date:
|
||||
base_where.append(ErpOrder.order_date <= end_date)
|
||||
|
||||
# 聚合:每笔订单的利润
|
||||
stmt = (
|
||||
select(
|
||||
ErpOrder.id.label("order_id"),
|
||||
ErpOrder.order_no,
|
||||
ErpOrder.order_date,
|
||||
ErpOrder.total_amount,
|
||||
func.sum(ErpOrderItemCost.profit_amount).label("total_profit"),
|
||||
)
|
||||
.join(ErpOrderItem, ErpOrderItem.order_id == ErpOrder.id)
|
||||
.join(ErpOrderItemCost, ErpOrderItemCost.order_item_id == ErpOrderItem.id)
|
||||
.where(*base_where)
|
||||
.group_by(ErpOrder.id, ErpOrder.order_no, ErpOrder.order_date, ErpOrder.total_amount)
|
||||
.order_by(ErpOrder.order_date.desc())
|
||||
)
|
||||
rows = (await db.execute(stmt)).all()
|
||||
|
||||
orders = []
|
||||
total_revenue = 0
|
||||
total_profit = 0
|
||||
for r in rows:
|
||||
revenue = float(r.total_amount or 0)
|
||||
profit = float(r.total_profit or 0)
|
||||
total_revenue += revenue
|
||||
total_profit += profit
|
||||
orders.append({
|
||||
"order_id": str(r.order_id),
|
||||
"order_no": r.order_no,
|
||||
"order_date": r.order_date.isoformat() if r.order_date else None,
|
||||
"revenue": revenue,
|
||||
"profit": profit,
|
||||
"profit_rate": round(profit / revenue * 100, 2) if revenue > 0 else 0,
|
||||
})
|
||||
|
||||
return {
|
||||
"total_revenue": round(total_revenue, 2),
|
||||
"total_profit": round(total_profit, 2),
|
||||
"overall_profit_rate": round(total_profit / total_revenue * 100, 2) if total_revenue > 0 else 0,
|
||||
"orders": orders,
|
||||
}
|
||||
@@ -12,6 +12,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import BizException, NotFoundException
|
||||
from app.models.finance import FinSalesInvoice
|
||||
from app.models.sys import SysUser
|
||||
from app.models.crm import CrmCustomer
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.sales_invoice import (
|
||||
SalesInvoiceCreate,
|
||||
@@ -45,6 +47,7 @@ async def create_invoice(
|
||||
db: AsyncSession,
|
||||
user: CurrentUserPayload,
|
||||
body: SalesInvoiceCreate,
|
||||
company_id: uuid.UUID | None = None,
|
||||
) -> SalesInvoiceResponse:
|
||||
# 检查发票号唯一性
|
||||
existing = (await db.execute(
|
||||
@@ -56,7 +59,7 @@ async def create_invoice(
|
||||
if existing:
|
||||
raise BizException(message=f"发票号 {body.invoice_number} 已存在")
|
||||
|
||||
inv = FinSalesInvoice(
|
||||
kwargs: dict = dict(
|
||||
issuer=body.issuer,
|
||||
receiver_customer_id=body.receiver_customer_id,
|
||||
invoice_number=body.invoice_number,
|
||||
@@ -65,6 +68,9 @@ async def create_invoice(
|
||||
remark=body.remark,
|
||||
created_by=user.user_id,
|
||||
)
|
||||
if company_id is not None:
|
||||
kwargs["company_id"] = company_id
|
||||
inv = FinSalesInvoice(**kwargs)
|
||||
db.add(inv)
|
||||
await db.commit()
|
||||
await db.refresh(inv)
|
||||
@@ -80,8 +86,11 @@ async def list_invoices(
|
||||
payment_status: str | None = None,
|
||||
start_date: date | None = None,
|
||||
end_date: date | None = None,
|
||||
company_id: uuid.UUID | None = None,
|
||||
) -> SalesInvoiceListResponse:
|
||||
conditions = [FinSalesInvoice.is_deleted.is_(False)]
|
||||
if company_id:
|
||||
conditions.append(FinSalesInvoice.company_id == company_id)
|
||||
|
||||
if invoice_number:
|
||||
conditions.append(FinSalesInvoice.invoice_number.ilike(f"%{invoice_number}%"))
|
||||
|
||||
@@ -22,6 +22,7 @@ async def create_log(
|
||||
customer_id: str | None = None,
|
||||
contact_ids: list[str] | None = None,
|
||||
log_date: date | None = None,
|
||||
company_ids: list[uuid.UUID] | None = None,
|
||||
) -> dict:
|
||||
"""创建销售日志"""
|
||||
log = SalesLog(
|
||||
@@ -30,6 +31,7 @@ async def create_log(
|
||||
contact_ids=contact_ids or [],
|
||||
content=content,
|
||||
log_date=log_date or date.today(),
|
||||
involved_company_ids=company_ids or [],
|
||||
)
|
||||
db.add(log)
|
||||
await db.commit()
|
||||
@@ -46,9 +48,17 @@ async def list_logs(
|
||||
user_id: str | None = None,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
company_id: uuid.UUID | None = None,
|
||||
) -> dict:
|
||||
"""查询销售日志列表"""
|
||||
"""查询销售日志列表(按 involved_company_ids 包含过滤)"""
|
||||
from sqlalchemy.orm import aliased
|
||||
from app.models.crm import CrmCustomer
|
||||
from app.models.sys import SysUser
|
||||
|
||||
conditions = [SalesLog.is_deleted.is_(False)]
|
||||
if company_id:
|
||||
# ARRAY contains: 过滤涉及当前公司的日志
|
||||
conditions.append(SalesLog.involved_company_ids.any(company_id))
|
||||
|
||||
# 数据权限
|
||||
if user.data_scope == "self":
|
||||
@@ -69,24 +79,107 @@ async def list_logs(
|
||||
count_stmt = select(func.count()).select_from(SalesLog).where(where)
|
||||
total = (await db.execute(count_stmt)).scalar() or 0
|
||||
|
||||
# data
|
||||
# data — LEFT JOIN customer + user to get names
|
||||
Author = aliased(SysUser)
|
||||
stmt = (
|
||||
select(SalesLog)
|
||||
select(
|
||||
SalesLog,
|
||||
CrmCustomer.name.label("customer_name"),
|
||||
Author.real_name.label("author_name"),
|
||||
)
|
||||
.outerjoin(CrmCustomer, SalesLog.customer_id == CrmCustomer.id)
|
||||
.outerjoin(Author, SalesLog.salesperson_id == Author.id)
|
||||
.where(where)
|
||||
.order_by(desc(SalesLog.created_at))
|
||||
.offset((page - 1) * size)
|
||||
.limit(size)
|
||||
)
|
||||
rows = (await db.execute(stmt)).scalars().all()
|
||||
rows = (await db.execute(stmt)).all()
|
||||
|
||||
items = []
|
||||
for log, cust_name, auth_name in rows:
|
||||
d = _to_dict(log)
|
||||
d["customer_name"] = cust_name
|
||||
d["author_name"] = auth_name
|
||||
items.append(d)
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"page": page,
|
||||
"size": size,
|
||||
"items": [_to_dict(r) for r in rows],
|
||||
"items": items,
|
||||
}
|
||||
|
||||
|
||||
async def update_log(
|
||||
db: AsyncSession,
|
||||
user: CurrentUserPayload,
|
||||
log_id: uuid.UUID,
|
||||
content: str | None = None,
|
||||
customer_id: str | None = None,
|
||||
contact_ids: list[str] | None = None,
|
||||
log_date: str | None = None,
|
||||
company_id: uuid.UUID | None = None,
|
||||
) -> dict:
|
||||
"""编辑销售日志 — 员工只能改自己的,管理员可改所有"""
|
||||
from app.models.crm import CrmCustomer
|
||||
from app.models.sys import SysUserCompany
|
||||
|
||||
log = await db.get(SalesLog, log_id)
|
||||
if not log or log.is_deleted:
|
||||
raise Exception("日志不存在")
|
||||
|
||||
# 权限检查
|
||||
if user.data_scope != "all" and log.salesperson_id != user.user_id:
|
||||
raise Exception("您无权编辑此日志")
|
||||
|
||||
if content is not None:
|
||||
log.content = content
|
||||
if contact_ids is not None:
|
||||
log.contact_ids = contact_ids
|
||||
if log_date is not None:
|
||||
log.log_date = date.fromisoformat(log_date)
|
||||
|
||||
# 更新客户关联 + 自动重算 involved_company_ids
|
||||
if customer_id is not None:
|
||||
log.customer_id = uuid.UUID(customer_id) if customer_id else None
|
||||
# 重新关联公司
|
||||
resolved = set(log.involved_company_ids or [])
|
||||
if company_id:
|
||||
resolved.add(company_id)
|
||||
if customer_id:
|
||||
cust = await db.get(CrmCustomer, uuid.UUID(customer_id))
|
||||
if cust and cust.owner_id:
|
||||
stmt = select(SysUserCompany.company_id).where(
|
||||
SysUserCompany.user_id == cust.owner_id
|
||||
)
|
||||
rows = (await db.execute(stmt)).scalars().all()
|
||||
for cid in rows:
|
||||
resolved.add(cid)
|
||||
log.involved_company_ids = list(resolved)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(log)
|
||||
return _to_dict(log)
|
||||
|
||||
|
||||
async def delete_log(
|
||||
db: AsyncSession,
|
||||
user: CurrentUserPayload,
|
||||
log_id: uuid.UUID,
|
||||
) -> None:
|
||||
"""软删除销售日志 — 员工只能删自己的,管理员可删所有"""
|
||||
log = await db.get(SalesLog, log_id)
|
||||
if not log or log.is_deleted:
|
||||
raise Exception("日志不存在")
|
||||
|
||||
if user.data_scope != "all" and log.salesperson_id != user.user_id:
|
||||
raise Exception("您无权删除此日志")
|
||||
|
||||
log.is_deleted = True
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def trigger_persona_workflow(
|
||||
log_id: uuid.UUID,
|
||||
customer_id: uuid.UUID,
|
||||
@@ -157,6 +250,7 @@ def _to_dict(log: SalesLog) -> dict:
|
||||
"salesperson_id": str(log.salesperson_id),
|
||||
"customer_id": str(log.customer_id) if log.customer_id else None,
|
||||
"contact_ids": log.contact_ids or [],
|
||||
"involved_company_ids": [str(c) for c in (log.involved_company_ids or [])],
|
||||
"content": log.content,
|
||||
"log_date": log.log_date.isoformat() if log.log_date else None,
|
||||
"ai_processed": log.ai_processed,
|
||||
|
||||
@@ -10,9 +10,11 @@ from typing import Any
|
||||
from sqlalchemy import func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.core.exceptions import BizException, ForbiddenException, NotFoundException
|
||||
from app.models.erp import InventoryFlow, ProductSku
|
||||
from app.models.erp import ErpSkuInventory, InventoryFlow, ProductSku
|
||||
from app.models.order import ErpOrder, ErpOrderItem
|
||||
from app.models.shipping import ErpShippingItem, ErpShippingRecord
|
||||
from app.models.sys import SysUser
|
||||
from app.models.crm import CrmCustomer
|
||||
from app.schemas.auth import CurrentUserPayload
|
||||
from app.schemas.shipping import (
|
||||
ShippingBriefResponse, ShippingCreate, ShippingItemResponse,
|
||||
@@ -75,10 +77,15 @@ def _check_shipping_access(order: ErpOrder, user: CurrentUserPayload) -> None:
|
||||
|
||||
async def create_shipping(
|
||||
db: AsyncSession, user: CurrentUserPayload, body: ShippingCreate,
|
||||
company_id: uuid.UUID,
|
||||
) -> tuple[ShippingResponse, str]:
|
||||
"""返回 (response, new_shipping_state)"""
|
||||
"""返回 (response, new_shipping_state)。库存从 erp_sku_inventory 扣减"""
|
||||
order = (await db.execute(
|
||||
select(ErpOrder).where(ErpOrder.id == body.order_id, ErpOrder.is_deleted.is_(False))
|
||||
select(ErpOrder).where(
|
||||
ErpOrder.id == body.order_id,
|
||||
ErpOrder.is_deleted.is_(False),
|
||||
ErpOrder.company_id == company_id,
|
||||
)
|
||||
)).scalar_one_or_none()
|
||||
if order is None:
|
||||
raise NotFoundException("订单不存在")
|
||||
@@ -114,6 +121,7 @@ async def create_shipping(
|
||||
carrier=body.carrier, tracking_no=body.tracking_no,
|
||||
status="transit", ship_date=body.ship_date or date.today(),
|
||||
remark=body.remark, operator_id=user.user_id,
|
||||
company_id=company_id,
|
||||
)
|
||||
db.add(record)
|
||||
await db.flush()
|
||||
@@ -125,22 +133,41 @@ async def create_shipping(
|
||||
)
|
||||
db.add(si)
|
||||
|
||||
result = await db.execute(
|
||||
update(ProductSku).where(
|
||||
ProductSku.id == item.sku_id,
|
||||
ProductSku.stock_qty >= item.shipped_qty,
|
||||
).values(
|
||||
stock_qty=ProductSku.stock_qty - Decimal(str(item.shipped_qty)),
|
||||
# ── 从 erp_sku_inventory 扣减库存(行锁) ──
|
||||
inv = (
|
||||
await db.execute(
|
||||
select(ErpSkuInventory)
|
||||
.where(
|
||||
ErpSkuInventory.sku_id == item.sku_id,
|
||||
ErpSkuInventory.company_id == company_id,
|
||||
)
|
||||
.with_for_update()
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
current_stock = float(inv.stock_qty) if inv else 0
|
||||
if current_stock < item.shipped_qty:
|
||||
raise BizException(
|
||||
message=f"库存不足无法发货: SKU {item.sku_id},"
|
||||
f"当前库存 {current_stock},请求出库 {item.shipped_qty}"
|
||||
)
|
||||
|
||||
if inv is None:
|
||||
# 不应出现此情况,但防御性处理
|
||||
raise BizException(message=f"SKU {item.sku_id} 在当前公司无库存记录")
|
||||
|
||||
await db.execute(
|
||||
update(ErpSkuInventory)
|
||||
.where(ErpSkuInventory.id == inv.id)
|
||||
.values(
|
||||
stock_qty=ErpSkuInventory.stock_qty - Decimal(str(item.shipped_qty)),
|
||||
updated_at=now,
|
||||
)
|
||||
)
|
||||
if result.rowcount == 0:
|
||||
sku = (await db.execute(select(ProductSku).where(ProductSku.id == item.sku_id))).scalar_one_or_none()
|
||||
current_stock = float(sku.stock_qty) if sku else 0
|
||||
raise BizException(message=f"库存不足无法发货: SKU {item.sku_id},当前库存 {current_stock},请求出库 {item.shipped_qty}")
|
||||
|
||||
db.add(InventoryFlow(
|
||||
sku_id=item.sku_id, change_qty=-item.shipped_qty,
|
||||
sku_id=item.sku_id, company_id=company_id,
|
||||
change_qty=-item.shipped_qty,
|
||||
reason="shipment", remark=f"订单发货出库 - 发货单 {shipping_no}",
|
||||
operator_id=user.user_id,
|
||||
))
|
||||
@@ -178,8 +205,11 @@ async def list_shipping(
|
||||
db: AsyncSession, user: CurrentUserPayload,
|
||||
page: int = 1, size: int = 20,
|
||||
order_no: str | None = None, tracking_no: str | None = None,
|
||||
company_id: uuid.UUID | None = None,
|
||||
) -> ShippingListResponse:
|
||||
where: list[Any] = [ErpShippingRecord.is_deleted.is_(False)]
|
||||
if company_id:
|
||||
where.append(ErpShippingRecord.company_id == company_id)
|
||||
if user.data_scope == "self":
|
||||
my_orders = select(ErpOrder.id).where(ErpOrder.salesperson_id == user.user_id, ErpOrder.is_deleted.is_(False))
|
||||
where.append(ErpShippingRecord.order_id.in_(my_orders))
|
||||
@@ -203,9 +233,13 @@ async def list_shipping(
|
||||
|
||||
async def get_shipping_by_order(
|
||||
db: AsyncSession, user: CurrentUserPayload, order_id: uuid.UUID,
|
||||
company_id: uuid.UUID | None = None,
|
||||
) -> dict[str, Any]:
|
||||
where_clause = [ErpOrder.id == order_id, ErpOrder.is_deleted.is_(False)]
|
||||
if company_id:
|
||||
where_clause.append(ErpOrder.company_id == company_id)
|
||||
order = (await db.execute(
|
||||
select(ErpOrder).where(ErpOrder.id == order_id, ErpOrder.is_deleted.is_(False))
|
||||
select(ErpOrder).where(*where_clause)
|
||||
)).scalar_one_or_none()
|
||||
if order is None:
|
||||
raise NotFoundException("订单不存在")
|
||||
|
||||
Reference in New Issue
Block a user