v0.2.0: CRM/ERP 系统升级 - 清理 .gitignore 并移除误提交的 venv/env/db 文件

- 更新 .gitignore:全面覆盖环境变量、数据库、日志、缓存、上传文件
- 移除误跟踪的 server/venv/、crm_data.db、.env 文件
- 新增 server/.env.example 模板
- 新增合同管理、利润核算、AI教练等功能模块
- 新增 Playwright e2e 测试套件
- 前后端多项功能升级和 bug 修复
This commit is contained in:
hankin
2026-05-11 07:24:19 +00:00
parent 0f4c6b7924
commit 815cbf9d8c
2526 changed files with 11875 additions and 804148 deletions
+45
View File
@@ -0,0 +1,45 @@
"""
AI 教练引擎路由 —— /api/ai-coaching
Dify 回调 + SSE 通知流
"""
from __future__ import annotations
import uuid
from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.db.database import get_db
from app.schemas.auth import CurrentUserPayload
from app.schemas.response import ok
from app.services import ai_coaching_service as svc
router = APIRouter(prefix="/ai-coaching", tags=["AI教练引擎"])
@router.post("/dify-callback/{sales_log_id}", summary="Dify Workflow 回调端点")
async def dify_coaching_callback(
sales_log_id: uuid.UUID,
request: Request,
db: AsyncSession = Depends(get_db),
) -> dict:
"""接收 Dify Workflow 的异步回调,写回教练反馈"""
import json
body = await request.json()
await svc.handle_dify_coaching_callback(db, sales_log_id, body)
return ok(message="教练反馈已回写")
@router.get("/notifications/stream", summary="SSE 通知流")
async def sse_notifications(
current_user: CurrentUserPayload = Depends(get_current_user),
):
"""Server-Sent Events 推送通知(AI 教练反馈、系统通知等)"""
return StreamingResponse(
svc.sse_notification_generator(current_user.user_id),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
+119
View File
@@ -0,0 +1,119 @@
"""
公司管理路由 —— /api/companies
"""
from __future__ import annotations
import uuid
from datetime import datetime
from fastapi import APIRouter, Body, Depends
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user, get_current_company_id
from app.db.database import get_db
from app.models.sys import SysCompany, SysUserCompany
from app.schemas.auth import CurrentUserPayload
from app.schemas.response import ok
router = APIRouter(prefix="/companies", tags=["公司管理"])
@router.get("", summary="获取当前用户可访问的公司列表")
async def list_companies(
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
) -> dict:
"""返回当前登录用户所关联的所有激活公司"""
stmt = (
select(SysCompany)
.join(SysUserCompany, SysUserCompany.company_id == SysCompany.id)
.where(
SysUserCompany.user_id == current_user.user_id,
SysCompany.is_active.is_(True),
)
.order_by(SysCompany.created_at)
)
companies = (await db.execute(stmt)).scalars().all()
# 查该用户的默认公司
default_stmt = (
select(SysUserCompany.company_id)
.where(
SysUserCompany.user_id == current_user.user_id,
SysUserCompany.is_default.is_(True),
)
)
default_id = (await db.execute(default_stmt)).scalar_one_or_none()
return ok(data={
"companies": [
{
"id": str(c.id),
"name": c.name,
"code": c.code,
"is_active": c.is_active,
}
for c in companies
],
"default_company_id": str(default_id) if default_id else (
str(companies[0].id) if companies else None
),
})
@router.get("/current", summary="获取当前公司详情(含 full_info)")
async def get_current_company(
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
company = (await db.execute(
select(SysCompany).where(SysCompany.id == company_id)
)).scalar_one_or_none()
if company is None:
return ok(data=None)
return ok(data={
"id": str(company.id),
"name": company.name,
"code": company.code,
"full_info": company.full_info or {},
"is_active": company.is_active,
})
@router.put("/current", summary="更新当前公司信息(含 full_info)")
async def update_current_company(
body: dict = Body(...),
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
# 仅管理员可编辑
if current_user.data_scope != "all":
from app.core.exceptions import ForbiddenException
raise ForbiddenException("仅管理员可编辑公司信息")
values: dict = {}
if "name" in body:
values["name"] = body["name"]
if "full_info" in body:
values["full_info"] = body["full_info"]
if values:
values["updated_at"] = datetime.utcnow()
await db.execute(
update(SysCompany).where(SysCompany.id == company_id).values(**values)
)
await db.commit()
# 返回更新后的数据
company = (await db.execute(
select(SysCompany).where(SysCompany.id == company_id)
)).scalar_one()
return ok(data={
"id": str(company.id),
"name": company.name,
"code": company.code,
"full_info": company.full_info or {},
"is_active": company.is_active,
}, message="公司信息已更新")
+156
View File
@@ -0,0 +1,156 @@
"""
合同管理路由 —— /api/contracts
"""
from __future__ import annotations
import uuid
from fastapi import APIRouter, Body, Depends, Query, UploadFile, File
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user, get_current_company_id
from app.db.database import get_db
from app.schemas.auth import CurrentUserPayload
from app.schemas.contract import ContractCreate, ContractUpdate
from app.schemas.response import ok
from app.services import contract_service as svc
router = APIRouter(prefix="/contracts", tags=["合同管理"])
@router.post("", summary="新增合同")
async def create_contract(
body: ContractCreate,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
result = await svc.create_contract(db, current_user, company_id, body)
return ok(data=result.model_dump(mode="json"), message="合同创建成功")
@router.get("", summary="合同列表(分页)")
async def list_contracts(
page: int = Query(1, ge=1),
size: int = Query(20, ge=1, le=100),
keyword: str | None = Query(None, description="合同编号搜索"),
status: str | None = Query(None, description="状态筛选"),
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
result = await svc.list_contracts(db, company_id, page, size, keyword, status)
return ok(data=result.model_dump(mode="json"))
@router.get("/{contract_id}", summary="合同详情(含执行进度)")
async def get_contract(
contract_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
result = await svc.get_contract(db, contract_id, company_id)
return ok(data=result.model_dump(mode="json"))
@router.put("/{contract_id}", summary="编辑合同")
async def update_contract(
contract_id: uuid.UUID,
body: ContractUpdate,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
result = await svc.update_contract(db, contract_id, company_id, body)
return ok(data=result.model_dump(mode="json"), message="合同已更新")
@router.delete("/{contract_id}", summary="删除合同")
async def delete_contract(
contract_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
await svc.delete_contract(db, contract_id, company_id)
return ok(message="合同已删除")
@router.post("/{contract_id}/generate-order", summary="一键从合同生成订单")
async def generate_order_from_contract(
contract_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
result = await svc.generate_order_from_contract(db, current_user, contract_id, company_id)
return ok(data=result, message="订单生成成功")
@router.get("/{contract_id}/generate", summary="生成合同 Word 文档下载")
async def generate_contract_document(
contract_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
):
from fastapi.responses import Response
docx_bytes = await svc.generate_contract_docx(db, contract_id, company_id)
return Response(
content=docx_bytes,
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
headers={"Content-Disposition": f"attachment; filename=contract_{contract_id}.docx"},
)
@router.post("/{contract_id}/upload-signed", summary="上传双签盖章版")
async def upload_signed_copy(
contract_id: uuid.UUID,
file: UploadFile = File(...),
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
import os
from app.models.contract import ErpContract, ErpContractAttachment
from sqlalchemy import update as sa_update
# 验证合同存在
from sqlalchemy import select as sa_select
contract = (await db.execute(
sa_select(ErpContract).where(
ErpContract.id == contract_id,
ErpContract.company_id == company_id,
ErpContract.is_deleted.is_(False),
)
)).scalar_one_or_none()
if contract is None:
raise Exception("合同不存在")
# 保存文件
upload_dir = f"uploads/contracts/{contract_id}"
os.makedirs(upload_dir, exist_ok=True)
file_path = f"{upload_dir}/{file.filename}"
with open(file_path, "wb") as f:
content = await file.read()
f.write(content)
file_url = f"/{file_path}"
# 记录附件
attachment = ErpContractAttachment(
contract_id=contract_id,
file_name=file.filename or "signed_copy",
file_url=file_url,
file_type="signed_copy",
uploader_id=current_user.user_id,
)
db.add(attachment)
# 更新合同签署状态
await db.execute(
sa_update(ErpContract)
.where(ErpContract.id == contract_id)
.values(is_signed=True, signed_file_url=file_url)
)
await db.commit()
return ok(message="双签盖章版上传成功", data={"file_url": file_url})
+15 -1
View File
@@ -4,7 +4,7 @@ CRM 客户模块路由 —— /api/customers
"""
from __future__ import annotations
import uuid
from fastapi import APIRouter, Depends, Query, Request
from fastapi import APIRouter, Body, Depends, Query, Request
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.db.database import get_db
@@ -91,6 +91,20 @@ async def restore_customer(
return ok(message="客户已恢复")
@router.put("/{customer_id}/transfer", summary="转移客户负责人(仅管理员)")
async def transfer_customer(
customer_id: uuid.UUID,
body: dict = Body(..., examples=[{"new_owner_id": "uuid-here"}]),
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
) -> dict:
new_owner_id = body.get("new_owner_id")
if not new_owner_id:
raise Exception("缺少 new_owner_id 参数")
result = await svc.transfer_customer(db, current_user, customer_id, uuid.UUID(str(new_owner_id)))
return ok(data=result.model_dump(mode="json"), message="客户转移成功")
@router.get("/{customer_id}/products", summary="获取客户关联产品(通过订单反查)")
async def get_customer_products(
customer_id: uuid.UUID,
+15 -10
View File
@@ -3,20 +3,21 @@ Dashboard 统计 API — /api/dashboard
"""
from __future__ import annotations
import uuid
from datetime import date, datetime
from fastapi import APIRouter, Depends
from sqlalchemy import func, select, and_, extract
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.api.deps import get_current_user, get_current_company_id
from app.db.database import get_db
from app.schemas.auth import CurrentUserPayload
from app.schemas.response import ok
from app.models.order import ErpOrder
from app.models.shipping import ErpShippingRecord
from app.models.erp import ProductSku
from app.models.erp import ErpSkuInventory
router = APIRouter(prefix="/dashboard", tags=["Dashboard"])
@@ -25,42 +26,46 @@ router = APIRouter(prefix="/dashboard", tags=["Dashboard"])
async def get_stats(
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
):
today = date.today()
month_start = today.replace(day=1)
# 本月新增订单数
# 本月新增订单数(按公司隔离)
orders_count_q = select(func.count()).select_from(ErpOrder).where(
and_(
ErpOrder.is_deleted.is_(False),
ErpOrder.company_id == company_id,
ErpOrder.order_date >= month_start,
)
)
orders_count = (await db.execute(orders_count_q)).scalar() or 0
# 待出库发货数(状态为 pending
# 待出库发货数(按公司隔离
pending_shipping_q = select(func.count()).select_from(ErpOrder).where(
and_(
ErpOrder.is_deleted.is_(False),
ErpOrder.company_id == company_id,
ErpOrder.shipping_state == "pending",
)
)
pending_shipping = (await db.execute(pending_shipping_q)).scalar() or 0
# 库存预警 SKU 数(stock_qty <= warning_threshold 且 warning_threshold > 0
warning_skus_q = select(func.count()).select_from(ProductSku).where(
# 库存预警 SKU 数(从 erp_sku_inventory 查,按公司隔离
warning_skus_q = select(func.count()).select_from(ErpSkuInventory).where(
and_(
ProductSku.is_deleted.is_(False),
ProductSku.warning_threshold > 0,
ProductSku.stock_qty <= ProductSku.warning_threshold,
ErpSkuInventory.company_id == company_id,
ErpSkuInventory.warning_threshold > 0,
ErpSkuInventory.stock_qty <= ErpSkuInventory.warning_threshold,
)
)
warning_skus = (await db.execute(warning_skus_q)).scalar() or 0
# 本月预计营收(本月订单总金额
# 本月预计营收(按公司隔离
revenue_q = select(func.coalesce(func.sum(ErpOrder.total_amount), 0)).where(
and_(
ErpOrder.is_deleted.is_(False),
ErpOrder.company_id == company_id,
ErpOrder.order_date >= month_start,
)
)
+47 -2
View File
@@ -1,18 +1,21 @@
"""
FastAPI 依赖注入 —— 权限拦截核心
get_current_user: 解析 JWT → 查表获取完整权限上下文
get_current_company_id: 从 X-Company-Id Header 提取公司 ID + IDOR 校验
"""
from __future__ import annotations
import uuid
from fastapi import Depends, Header
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import UnauthorizedException
from app.core.exceptions import ForbiddenException, UnauthorizedException
from app.core.security import decode_access_token
from app.db.database import get_db
from app.models.sys import SysUser
from app.models.sys import SysCompany, SysUser, SysUserCompany
from app.schemas.auth import CurrentUserPayload
@@ -65,3 +68,45 @@ async def get_current_user(
data_scope=user.role.data_scope if user.role else "self",
menu_keys=user.role.menu_keys if user.role else [],
)
async def get_current_company_id(
x_company_id: str = Header(..., alias="X-Company-Id", description="当前工作台的公司 ID"),
current_user: CurrentUserPayload = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> uuid.UUID:
"""
公司视角依赖(IDOR 防护核心):
1. 从 X-Company-Id Header 提取公司 UUID
2. 校验当前用户是否归属于该公司(查 sys_user_companies
3. 校验公司是否启用
"""
# ── 解析 company_id ──
try:
company_uuid = uuid.UUID(x_company_id)
except ValueError:
raise UnauthorizedException("X-Company-Id 格式错误,需为合法 UUID")
# ── IDOR 防护:校验用户-公司归属 ──
assoc = (await db.execute(
select(SysUserCompany).where(
SysUserCompany.user_id == current_user.user_id,
SysUserCompany.company_id == company_uuid,
)
)).scalar_one_or_none()
if assoc is None:
raise ForbiddenException("您无权访问该公司数据")
# ── 校验公司是否启用 ──
company = (await db.execute(
select(SysCompany).where(
SysCompany.id == company_uuid,
SysCompany.is_active.is_(True),
)
)).scalar_one_or_none()
if company is None:
raise ForbiddenException("公司不存在或已停用")
return company_uuid
+407 -21
View File
@@ -9,7 +9,7 @@ import time
import base64
from fastapi import APIRouter, Depends, Query, Body, File, UploadFile, Form
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.api.deps import get_current_user, get_current_company_id
from app.db.database import get_db
from app.schemas.auth import CurrentUserPayload
from app.schemas.finance import ExpenseCreate, ExpenseStatusUpdate, InvoiceCreate
@@ -43,34 +43,96 @@ async def ocr_recognize(
file_url = f"/uploads/finance/{safe_filename}"
# 支持图片(png/jpg/jpeg)和 PDF,不再支持 MD/TXT
supported = {".png", ".jpg", ".jpeg", ".pdf"}
# 支持的格式:结构化零算力 > 文本 LLM > 图片 Vision
supported = {".png", ".jpg", ".jpeg", ".pdf", ".md", ".ofd", ".xml", ".zip"}
if ext not in supported:
raise BizException(message=f"不支持的文件格式 {ext},仅支持: {', '.join(supported)}")
# 如果是 PDF,转成 PNG 再做 OCR
ocr_bytes = file_bytes
raise BizException(message=f"不支持的文件格式 {ext},仅支持: {', '.join(sorted(supported))}")
# ── 策略 A0: ZIP → 解包所有 XML 并逐个解析 ──
if ext == ".zip":
from app.services.invoice_parser import parse_zip_invoices
results = parse_zip_invoices(file_bytes)
return ok(data={"zip_results": [
{"filename": r.get("filename", ""), "success": r.get("success", False),
"ocr_data": r.get("data", {}), "needs_llm": r.get("needs_llm", False),
"error": r.get("error")}
for r in results
], "file_url": file_url}, message=f"ZIP 解析完成:{sum(1 for r in results if r.get('success'))}/{len(results)} 成功")
# ── 策略 A: OFD / XML → 结构化零算力提取(最快最准)──
if ext in (".ofd", ".xml"):
from app.services.invoice_parser import parse_ofd_invoice, parse_xml_invoice
parser = parse_ofd_invoice if ext == ".ofd" else parse_xml_invoice
result = parser(file_bytes)
print(f"[OCR] {ext.upper()} 解析: success={result.get('success')}")
if result.get("success"):
# 如果解析器提取到 raw_text 且标记 needs_llm,交给 LLM 做字段提取
if result.get("needs_llm") and result["data"].get("raw_text"):
from app.services.ocr_service import extract_invoice_from_text
llm_result = await extract_invoice_from_text(result["data"]["raw_text"], scene)
if llm_result.get("success"):
return ok(data={"ocr_data": llm_result["data"], "file_url": file_url}, message=f"AI 发票识别成功({ext.upper()} → LLM")
return ok(data={"ocr_data": llm_result.get("data", {}), "file_url": file_url}, message=llm_result.get("error", "LLM 解析失败"))
return ok(data={"ocr_data": result["data"], "file_url": file_url}, message=f"发票识别成功({ext.upper()} 结构化提取)")
return ok(data={"ocr_data": {}, "file_url": file_url}, message=result.get("error", f"{ext.upper()} 解析失败"))
# ── 策略 B: MD → 纯文本 LLM 理解(零 GPU Vision)──
if ext == ".md":
text = file_bytes.decode("utf-8", errors="replace").strip()
print(f"[OCR] MD 文本: {len(text)} 字符")
if len(text) < 20:
return ok(data={"ocr_data": {}, "file_url": file_url}, message="MD 文件内容过少,无法识别")
from app.services.ocr_service import extract_invoice_from_text
result = await extract_invoice_from_text(text, scene)
if result.get("success"):
return ok(data={"ocr_data": result["data"], "file_url": file_url}, message="AI 发票识别成功(MD 文本解析)")
return ok(data={"ocr_data": result.get("data", {}), "file_url": file_url}, message=result.get("error", "MD 文本解析失败"))
# ── 策略 C: PDF → PyMuPDF 提取文本 → LLM(零 GPU Vision)──
if ext == ".pdf":
try:
import fitz # PyMuPDF
doc = fitz.open(stream=file_bytes, filetype="pdf")
page = doc[0] # 取第一页
# 中等分辨率渲染(150 DPI,平衡质量与大小)
pix = page.get_pixmap(dpi=150)
ocr_bytes = pix.tobytes("png")
text = ""
for page in doc:
text += page.get_text() + "\n"
doc.close()
print(f"[OCR] PDF 转 PNG 成功: {len(ocr_bytes)} bytes")
text = text.strip()
print(f"[OCR] PDF 文本提取: {len(text)} 字符")
if len(text) > 50: # 有足够文本内容
from app.services.ocr_service import extract_invoice_from_text
result = await extract_invoice_from_text(text, scene)
if result.get("success"):
return ok(data={"ocr_data": result["data"], "file_url": file_url}, message="AI 发票识别成功(PDF 文本解析)")
return ok(data={"ocr_data": result.get("data", {}), "file_url": file_url}, message=result.get("error", "PDF 文本提取失败"))
else:
# PDF 是扫描件(无文字层),降级到图片 OCR
print(f"[OCR] PDF 无文本层(仅 {len(text)} 字符),降级到图片 OCR")
page = fitz.open(stream=file_bytes, filetype="pdf")[0]
pix = page.get_pixmap(dpi=150)
ocr_bytes = pix.tobytes("png")
print(f"[OCR] PDF 转 PNG 成功: {len(ocr_bytes)} bytes")
except Exception as e:
print(f"[OCR] PDF 转换失败: {e}")
return ok(data={"ocr_data": {}, "file_url": file_url}, message=f"PDF 转换失败: {e}")
# 转换为纯 base64 传给 OCR
print(f"[OCR] PDF 处理失败: {e}")
return ok(data={"ocr_data": {}, "file_url": file_url}, message=f"PDF 处理失败: {e}")
else:
ocr_bytes = file_bytes
# ── 策略 D: 图片/扫描PDF → Vision OCR(需要视觉模型)──
from app.services.ocr_service import ocr_image
image_base64 = base64.b64encode(ocr_bytes).decode("utf-8")
result = await ocr_image(image_base64, scene)
if result.get("success"):
return ok(data={"ocr_data": result["data"], "file_url": file_url}, message="AI OCR 识别成功")
return ok(data={"ocr_data": result.get("data", {}), "file_url": file_url}, message=result.get("error", "OCR 识别失败"))
# Vision 失败时友好提示
error_msg = result.get("error", "OCR 识别失败")
if "模型进程崩溃" in error_msg or "unexpectedly stopped" in error_msg or "服务异常" in error_msg:
error_msg += "。建议:请上传电子版 PDF/OFD/XML 发票,系统可零算力直接提取数据"
return ok(data={"ocr_data": {}, "file_url": file_url}, message=error_msg)
@router.post("/invoices", summary="上传票据入池(含 AI/OCR JSONB 数据)")
@@ -78,8 +140,9 @@ async def create_invoice(
body: InvoiceCreate,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
result = await svc.create_invoice(db, current_user, body)
result = await svc.create_invoice(db, current_user, body, company_id)
return ok(data=result.model_dump(mode="json"), message="票据入池成功")
@@ -91,8 +154,9 @@ async def list_invoices(
is_used: bool | None = Query(None),
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
result = await svc.list_invoices(db, current_user, page, size, type, is_used)
result = await svc.list_invoices(db, current_user, page, size, type, is_used, company_id)
return ok(data=result.model_dump(mode="json"))
@@ -111,8 +175,9 @@ async def create_expense(
body: ExpenseCreate,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
result = await svc.create_expense(db, current_user, body)
result = await svc.create_expense(db, current_user, body, company_id)
return ok(data=result.model_dump(mode="json"), message=f"报销单 {result.system_no} 提交成功")
@@ -124,8 +189,9 @@ async def list_expenses(
applicant_id: uuid.UUID | None = Query(None, description="按申请人过滤(管理员用)"),
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
result = await svc.list_expenses(db, current_user, page, size, status, applicant_id)
result = await svc.list_expenses(db, current_user, page, size, status, applicant_id, company_id)
return ok(data=result.model_dump(mode="json"))
@@ -148,3 +214,323 @@ async def update_expense_status(
) -> dict:
msg = await svc.update_expense_status(db, current_user, expense_id, body)
return ok(message=msg)
# ════════════════════════════════════════════════════════════
# 批量上传 + OCR 任务队列 API
# ════════════════════════════════════════════════════════════
@router.post("/upload-batch", summary="批量上传发票(ZIP/XML 即时入池,图片PDF 入队列)")
async def upload_batch(
files: list[UploadFile] = File(...),
scene: str = Form("invoice"),
inv_type: str = Form("expense"),
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
from app.services.invoice_parser import parse_xml_invoice, parse_ofd_invoice, parse_zip_invoices
from app.services.ocr_service import extract_invoice_from_text
from app.models.finance import FinInvoicePool, FinOcrTask
upload_dir = "uploads/finance"
os.makedirs(upload_dir, exist_ok=True)
results = [] # 返回给前端
for file in files:
file_bytes = await file.read()
ext = os.path.splitext(file.filename or "")[1].lower() or ".bin"
ts = int(time.time())
safe_fn = f"{ts}_{uuid.uuid4().hex[:8]}{ext}"
file_path = os.path.join(upload_dir, safe_fn)
with open(file_path, "wb") as f:
f.write(file_bytes)
file_url = f"/uploads/finance/{safe_fn}"
# ── ZIP: 解压内部 XML,逐个即时入池 ──
if ext == ".zip":
zip_results = parse_zip_invoices(file_bytes)
for zr in zip_results:
if zr.get("success") and not zr.get("needs_llm"):
ai_data = zr.get("data", {})
# 需要 LLM 的 zip 中的 xml 也立刻处理
merchant = ai_data.get("merchant") or ai_data.get("merchant_name") or "(ZIP)"
amount = float(ai_data.get("amount", 0) or 0)
inv_date_str = ai_data.get("date")
inv_date = None
if inv_date_str:
try:
from datetime import date as d
inv_date = d.fromisoformat(inv_date_str)
except ValueError:
pass
inv = FinInvoicePool(
uploader_id=current_user.user_id, company_id=company_id,
file_url=file_url, merchant_name=merchant, amount=amount,
invoice_date=inv_date, type=inv_type, ai_extracted_data=ai_data,
)
db.add(inv)
results.append({"filename": zr.get("filename", file.filename), "action": "pooled",
"status": "success", "message": f"{merchant} ¥{amount}"})
elif zr.get("needs_llm") and zr.get("data", {}).get("raw_text"):
# LLM 文本理解(即时,<5s
try:
llm_r = await extract_invoice_from_text(zr["data"]["raw_text"], scene)
if llm_r.get("success"):
ai_data = llm_r["data"]
merchant = ai_data.get("merchant") or "(LLM)"
amount = float(ai_data.get("amount", 0) or 0)
inv = FinInvoicePool(
uploader_id=current_user.user_id, company_id=company_id,
file_url=file_url, merchant_name=merchant, amount=amount,
type=inv_type, ai_extracted_data=ai_data,
)
db.add(inv)
results.append({"filename": zr.get("filename"), "action": "pooled",
"status": "success", "message": f"{merchant} ¥{amount} (LLM)"})
else:
results.append({"filename": zr.get("filename"), "action": "failed",
"status": "error", "message": llm_r.get("error", "LLM 解析失败")})
except Exception as e:
results.append({"filename": zr.get("filename"), "action": "failed",
"status": "error", "message": str(e)})
else:
results.append({"filename": zr.get("filename", file.filename), "action": "failed",
"status": "error", "message": zr.get("error", "解析失败")})
continue
# ── XML / OFD: 零算力即时入池 ──
if ext in (".xml", ".ofd"):
parser = parse_xml_invoice if ext == ".xml" else parse_ofd_invoice
r = parser(file_bytes)
if r.get("success") and not r.get("needs_llm"):
ai_data = r.get("data", {})
merchant = ai_data.get("merchant") or ai_data.get("merchant_name") or "(解析)"
amount = float(ai_data.get("amount", 0) or 0)
inv_date_str = ai_data.get("date")
inv_date = None
if inv_date_str:
try:
from datetime import date as d
inv_date = d.fromisoformat(inv_date_str)
except ValueError:
pass
inv = FinInvoicePool(
uploader_id=current_user.user_id, company_id=company_id,
file_url=file_url, merchant_name=merchant, amount=amount,
invoice_date=inv_date, type=inv_type, ai_extracted_data=ai_data,
)
db.add(inv)
results.append({"filename": file.filename, "action": "pooled",
"status": "success", "message": f"{merchant} ¥{amount}"})
elif r.get("needs_llm") and r.get("data", {}).get("raw_text"):
try:
llm_r = await extract_invoice_from_text(r["data"]["raw_text"], scene)
if llm_r.get("success"):
ai_data = llm_r["data"]
merchant = ai_data.get("merchant") or "(LLM)"
amount = float(ai_data.get("amount", 0) or 0)
inv = FinInvoicePool(
uploader_id=current_user.user_id, company_id=company_id,
file_url=file_url, merchant_name=merchant, amount=amount,
type=inv_type, ai_extracted_data=ai_data,
)
db.add(inv)
results.append({"filename": file.filename, "action": "pooled",
"status": "success", "message": f"{merchant} ¥{amount} (LLM)"})
else:
results.append({"filename": file.filename, "action": "failed",
"status": "error", "message": llm_r.get("error", "LLM 失败")})
except Exception as e:
results.append({"filename": file.filename, "action": "failed",
"status": "error", "message": str(e)})
else:
results.append({"filename": file.filename, "action": "failed",
"status": "error", "message": r.get("error", "解析失败")})
continue
# ── 图片 / PDF : 写入 DB 任务队列 ──
task = FinOcrTask(
file_url=file_url, file_ext=ext,
original_name=file.filename or "unknown",
uploader_id=current_user.user_id,
company_id=company_id,
inv_type=inv_type,
priority=50 if ext == ".pdf" else 100, # PDF 优先(可能有文字层)
)
db.add(task)
await db.flush()
results.append({"filename": file.filename, "action": "queued",
"status": "pending", "task_id": str(task.id),
"message": "🕐 已加入 OCR 处理队列"})
await db.commit()
pooled = sum(1 for r in results if r["action"] == "pooled")
queued = sum(1 for r in results if r["action"] == "queued")
failed = sum(1 for r in results if r["action"] == "failed")
return ok(data={"results": results},
message=f"批量处理完成:{pooled} 即时入池,{queued} 排队中,{failed} 失败")
@router.get("/ocr-tasks", summary="OCR 任务队列列表")
async def list_ocr_tasks(
page: int = Query(1, ge=1),
size: int = Query(20, ge=1, le=100),
status: str | None = Query(None, description="pending/processing/success/failed/manual"),
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
from sqlalchemy import func, select
from app.models.finance import FinOcrTask
where = [FinOcrTask.company_id == company_id, FinOcrTask.is_deleted.is_(False)]
if current_user.data_scope == "self":
where.append(FinOcrTask.uploader_id == current_user.user_id)
if status:
where.append(FinOcrTask.status == status)
total = (await db.execute(select(func.count()).select_from(FinOcrTask).where(*where))).scalar() or 0
stmt = (
select(FinOcrTask).where(*where)
.order_by(FinOcrTask.priority, FinOcrTask.created_at.desc())
.offset((page - 1) * size).limit(size)
)
tasks = (await db.execute(stmt)).scalars().all()
return ok(data={
"total": total, "page": page, "size": size,
"items": [{
"id": str(t.id),
"original_name": t.original_name,
"file_ext": t.file_ext,
"file_url": t.file_url,
"status": t.status,
"priority": t.priority,
"retry_count": t.retry_count,
"max_retries": t.max_retries,
"error_message": t.error_message,
"ocr_result": t.ocr_result,
"invoice_pool_id": str(t.invoice_pool_id) if t.invoice_pool_id else None,
"uploader_name": t.uploader.real_name if t.uploader else None,
"inv_type": t.inv_type,
"created_at": str(t.created_at),
"updated_at": str(t.updated_at),
} for t in tasks],
})
@router.post("/ocr-tasks/{task_id}/retry", summary="重试失败的 OCR 任务")
async def retry_ocr_task(
task_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
) -> dict:
from sqlalchemy import select, update
from app.models.finance import FinOcrTask
task = (await db.execute(
select(FinOcrTask).where(FinOcrTask.id == task_id, FinOcrTask.is_deleted.is_(False))
)).scalar_one_or_none()
if not task:
raise BizException(message="任务不存在")
if task.status not in ("failed", "manual"):
raise BizException(message=f"当前状态 [{task.status}] 不允许重试")
task.status = "pending"
task.retry_count = 0
task.error_message = None
await db.commit()
return ok(message="任务已重新入队")
@router.post("/ocr-tasks/{task_id}/manual", summary="手动录入 OCR 结果并入池")
async def manual_ocr_task(
task_id: uuid.UUID,
body: dict,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
) -> dict:
from sqlalchemy import select
from app.models.finance import FinOcrTask, FinInvoicePool
task = (await db.execute(
select(FinOcrTask).where(FinOcrTask.id == task_id, FinOcrTask.is_deleted.is_(False))
)).scalar_one_or_none()
if not task:
raise BizException(message="任务不存在")
merchant = body.get("merchant_name", "手动录入")
amount = float(body.get("amount", 0))
inv_date_str = body.get("invoice_date")
inv_date = None
if inv_date_str:
try:
from datetime import date as d
inv_date = d.fromisoformat(inv_date_str)
except ValueError:
pass
inv = FinInvoicePool(
uploader_id=task.uploader_id, company_id=task.company_id,
file_url=task.file_url, merchant_name=merchant, amount=amount,
invoice_date=inv_date, type=task.inv_type, ai_extracted_data=body,
)
db.add(inv)
await db.flush()
task.status = "manual"
task.invoice_pool_id = inv.id
task.ocr_result = body
task.error_message = None
await db.commit()
return ok(data={"invoice_pool_id": str(inv.id)}, message="手动录入成功,发票已入池")
@router.put("/ocr-tasks/{task_id}/priority", summary="调整 OCR 任务优先级")
async def update_ocr_task_priority(
task_id: uuid.UUID,
body: dict,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
) -> dict:
from sqlalchemy import select
from app.models.finance import FinOcrTask
task = (await db.execute(
select(FinOcrTask).where(FinOcrTask.id == task_id, FinOcrTask.is_deleted.is_(False))
)).scalar_one_or_none()
if not task:
raise BizException(message="任务不存在")
if task.status not in ("pending",):
raise BizException(message="仅待处理任务可调整优先级")
new_priority = body.get("priority", task.priority)
task.priority = int(new_priority)
await db.commit()
return ok(message=f"优先级已调整为 {task.priority}")
@router.delete("/ocr-tasks/{task_id}", summary="取消/删除 OCR 任务")
async def delete_ocr_task(
task_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
) -> dict:
from sqlalchemy import select
from app.models.finance import FinOcrTask
task = (await db.execute(
select(FinOcrTask).where(FinOcrTask.id == task_id, FinOcrTask.is_deleted.is_(False))
)).scalar_one_or_none()
if not task:
raise BizException(message="任务不存在")
if task.status == "processing":
raise BizException(message="正在处理中的任务无法取消")
task.is_deleted = True
await db.commit()
return ok(message="任务已取消")
+5 -7
View File
@@ -56,7 +56,7 @@ async def import_products(
current_user: CurrentUserPayload = Depends(get_current_user),
):
from openpyxl import load_workbook
from app.models.erp import ErpProductSku
from app.models.erp import ProductSku
content = await file.read()
wb = load_workbook(io.BytesIO(content))
@@ -79,7 +79,6 @@ async def import_products(
spec = str(row[2] or "").strip() or None
standard_price = float(row[3] or 0)
unit = str(row[4] or "").strip()
warning_threshold = float(row[5] or 0)
if not sku_code or not name:
skipped += 1
@@ -87,22 +86,21 @@ async def import_products(
# 检查 sku_code 是否已存在
exists = (await db.execute(
select(func.count()).select_from(ErpProductSku).where(
ErpProductSku.sku_code == sku_code,
ErpProductSku.is_deleted.is_(False),
select(func.count()).select_from(ProductSku).where(
ProductSku.sku_code == sku_code,
ProductSku.is_deleted.is_(False),
)
)).scalar()
if exists:
skipped += 1
continue
sku = ErpProductSku(
sku = ProductSku(
sku_code=sku_code,
name=name,
spec=spec,
standard_price=standard_price,
unit=unit,
warning_threshold=warning_threshold,
)
db.add(sku)
created += 1
+338 -4
View File
@@ -6,7 +6,7 @@ from __future__ import annotations
import uuid
from fastapi import APIRouter, Depends, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.api.deps import get_current_user, get_current_company_id
from app.db.database import get_db
from app.schemas.auth import CurrentUserPayload
from app.schemas.order import OrderCreate
@@ -32,8 +32,9 @@ async def create_order(
body: OrderCreate,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
result = await svc.create_order(db, current_user, body)
result = await svc.create_order(db, current_user, body, company_id)
return ok(data=result.model_dump(mode="json"), message=f"订单 {result.order_no} 创建成功")
@@ -47,16 +48,349 @@ async def list_orders(
keyword: str | None = Query(None, description="模糊搜索订单号"),
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
result = await svc.list_orders(db, current_user, page, size, customer_id, shipping_state, payment_state, keyword)
result = await svc.list_orders(db, current_user, page, size, customer_id, shipping_state, payment_state, keyword, company_id)
return ok(data=result.model_dump(mode="json"))
@router.get("/unlinked-invoices", summary="查询未关联订单的发票列表")
async def list_unlinked_invoices(
keyword: str | None = Query(None, description="发票号模糊搜索"),
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
from sqlalchemy import select
from app.models.finance import FinSalesInvoice
conditions = [
FinSalesInvoice.company_id == company_id,
FinSalesInvoice.is_deleted.is_(False),
FinSalesInvoice.order_id.is_(None),
]
if keyword:
conditions.append(FinSalesInvoice.invoice_number.ilike(f"%{keyword}%"))
stmt = (
select(FinSalesInvoice)
.where(*conditions)
.order_by(FinSalesInvoice.created_at.desc())
.limit(50)
)
invoices = (await db.execute(stmt)).scalars().all()
return ok(data=[
{
"id": str(inv.id),
"invoice_number": inv.invoice_number,
"issuer": inv.issuer,
"receiver_name": inv.receiver_customer.name if inv.receiver_customer else None,
"amount": float(inv.amount),
"billing_date": str(inv.billing_date),
}
for inv in invoices
])
@router.get("/{order_id}", summary="订单全景详情(关系预加载 items + customer")
async def get_order(
order_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
result = await svc.get_order(db, current_user, order_id)
result = await svc.get_order(db, current_user, order_id, company_id)
return ok(data=result.model_dump(mode="json"))
@router.get("/{order_id}/invoices", summary="获取订单关联的销项发票")
async def get_order_invoices(
order_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
from sqlalchemy import select
from app.models.finance import FinSalesInvoice
stmt = (
select(FinSalesInvoice)
.where(
FinSalesInvoice.order_id == order_id,
FinSalesInvoice.company_id == company_id,
FinSalesInvoice.is_deleted.is_(False),
)
.order_by(FinSalesInvoice.created_at.desc())
)
invoices = (await db.execute(stmt)).scalars().all()
return ok(data=[
{
"id": str(inv.id),
"invoice_number": inv.invoice_number,
"issuer": inv.issuer,
"receiver_name": inv.receiver_customer.name if inv.receiver_customer else None,
"amount": float(inv.amount),
"billing_date": str(inv.billing_date),
"payment_status": inv.payment_status,
"payment_date": str(inv.payment_date) if inv.payment_date else None,
"payment_amount": float(inv.payment_amount or 0),
"payment_due_date": str(inv.payment_due_date) if inv.payment_due_date else None,
}
for inv in invoices
])
@router.put("/{order_id}/payment", summary="更新订单收款状态")
async def update_order_payment(
order_id: uuid.UUID,
body: dict,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
from sqlalchemy import select, update as sa_update
from app.models.order import ErpOrder
from datetime import datetime
order = (await db.execute(
select(ErpOrder).where(
ErpOrder.id == order_id,
ErpOrder.company_id == company_id,
ErpOrder.is_deleted.is_(False),
)
)).scalar_one_or_none()
if order is None:
from app.core.exceptions import NotFoundException
raise NotFoundException("订单不存在")
values = {}
if "paid_amount" in body:
paid = float(body["paid_amount"])
values["paid_amount"] = paid
total = float(order.total_amount)
if paid >= total:
values["payment_state"] = "cleared"
elif paid > 0:
values["payment_state"] = "partial"
else:
values["payment_state"] = "unpaid"
if "payment_state" in body:
values["payment_state"] = body["payment_state"]
if values:
values["updated_at"] = datetime.utcnow()
await db.execute(
sa_update(ErpOrder).where(ErpOrder.id == order_id).values(**values)
)
await db.commit()
return ok(message="收款状态已更新")
@router.get("/{order_id}/invoice-detail-preview", summary="生成开票明细预览")
async def invoice_detail_preview(
order_id: uuid.UUID,
mode: str = Query("full", pattern=r"^(full|batch)$", description="full=整体开票, batch=按发货批次"),
shipping_id: uuid.UUID | None = Query(None, description="batch模式下必传发货单ID"),
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
"""根据模式生成开票明细: 整体=订单全部商品, 批次=指定发货单商品"""
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from app.models.order import ErpOrder, ErpOrderItem
from app.models.shipping import ErpShippingRecord, ErpShippingItem
from app.models.crm import CrmCustomer
from app.models.sys import SysCompany
from app.core.exceptions import NotFoundException, BizException
# 查订单
order = (await db.execute(
select(ErpOrder)
.where(ErpOrder.id == order_id, ErpOrder.company_id == company_id, ErpOrder.is_deleted.is_(False))
.options(
selectinload(ErpOrder.items),
selectinload(ErpOrder.customer),
selectinload(ErpOrder.salesperson),
)
)).scalar_one_or_none()
if not order:
raise NotFoundException("订单不存在")
# 买方名称
buyer_name = order.customer.name if order.customer else ""
# 卖方名称
company = (await db.execute(
select(SysCompany).where(SysCompany.id == company_id)
)).scalar_one_or_none()
seller_name = company.name if company else ""
items_data = []
total_amount = 0.0
if mode == "full":
# 整体开票: 聚合全部订单明细
for oi in (order.items or []):
sub = float(oi.sub_total or 0)
items_data.append({
"sku_code": oi.sku.sku_code if oi.sku else "",
"sku_name": oi.sku.name if oi.sku else "",
"spec": oi.sku.spec if oi.sku else "",
"unit": oi.sku.unit if oi.sku else "",
"qty": float(oi.qty),
"unit_price": float(oi.unit_price),
"sub_total": sub,
})
total_amount += sub
else:
# 按发货批次
if not shipping_id:
raise BizException(message="batch模式需指定shipping_id")
ship = (await db.execute(
select(ErpShippingRecord)
.where(
ErpShippingRecord.id == shipping_id,
ErpShippingRecord.order_id == order_id,
ErpShippingRecord.is_deleted.is_(False),
)
.options(selectinload(ErpShippingRecord.items).selectinload(ErpShippingItem.sku))
)).scalar_one_or_none()
if not ship:
raise NotFoundException("发货单不存在")
# 查对应的订单明细来获取单价
order_item_map = {str(oi.id): oi for oi in (order.items or [])}
for si in (ship.items or []):
oi = order_item_map.get(str(si.order_item_id))
unit_price = float(oi.unit_price) if oi else 0
qty = float(si.shipped_qty)
sub = round(qty * unit_price, 2)
items_data.append({
"sku_code": si.sku.sku_code if si.sku else "",
"sku_name": si.sku.name if si.sku else "",
"spec": si.sku.spec if si.sku else "",
"unit": si.sku.unit if si.sku else "",
"qty": qty,
"unit_price": unit_price,
"sub_total": sub,
})
total_amount += sub
return ok(data={
"order_no": order.order_no,
"buyer_name": buyer_name,
"seller_name": seller_name,
"customer_id": str(order.customer_id),
"items": items_data,
"total_amount": round(total_amount, 2),
"shipping_id": str(shipping_id) if shipping_id else None,
})
@router.post("/{order_id}/invoices/link", summary="关联已有发票到订单")
async def link_existing_invoice(
order_id: uuid.UUID,
body: dict,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
"""将已存在的销项发票关联到该订单"""
from sqlalchemy import select, update as sa_update
from app.models.finance import FinSalesInvoice
from app.core.exceptions import NotFoundException, BizException
from datetime import datetime
invoice_id = body.get("invoice_id")
shipping_record_id = body.get("shipping_record_id")
if not invoice_id:
raise BizException(message="请提供 invoice_id")
inv = (await db.execute(
select(FinSalesInvoice).where(
FinSalesInvoice.id == uuid.UUID(invoice_id),
FinSalesInvoice.company_id == company_id,
FinSalesInvoice.is_deleted.is_(False),
)
)).scalar_one_or_none()
if not inv:
raise NotFoundException("发票不存在")
values = {"order_id": order_id, "updated_at": datetime.utcnow()}
if shipping_record_id:
values["shipping_record_id"] = uuid.UUID(shipping_record_id)
await db.execute(
sa_update(FinSalesInvoice)
.where(FinSalesInvoice.id == uuid.UUID(invoice_id))
.values(**values)
)
await db.commit()
return ok(message="发票已关联到订单")
@router.post("/{order_id}/invoices/create", summary="直接创建发票并关联到订单")
async def create_and_link_invoice(
order_id: uuid.UUID,
body: dict,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
"""创建新的销项发票,同时关联到当前订单"""
from sqlalchemy import select
from app.models.finance import FinSalesInvoice
from app.models.order import ErpOrder
from app.core.exceptions import NotFoundException, BizException
from datetime import date as dt_date
order = (await db.execute(
select(ErpOrder).where(
ErpOrder.id == order_id,
ErpOrder.company_id == company_id,
ErpOrder.is_deleted.is_(False),
)
)).scalar_one_or_none()
if not order:
raise NotFoundException("订单不存在")
invoice_number = body.get("invoice_number", "").strip()
amount = float(body.get("amount", 0))
issuer = body.get("issuer", "").strip()
receiver_customer_id = body.get("receiver_customer_id") or str(order.customer_id)
billing_date_str = body.get("billing_date")
shipping_record_id = body.get("shipping_record_id")
remark = body.get("remark")
if not invoice_number:
raise BizException(message="请填写发票号")
if amount <= 0:
raise BizException(message="开票金额需大于0")
if not issuer:
raise BizException(message="请填写开票方名称")
# 检查唯一性
from sqlalchemy import func as sa_func
existing = (await db.execute(
select(sa_func.count()).select_from(FinSalesInvoice).where(
FinSalesInvoice.invoice_number == invoice_number,
FinSalesInvoice.is_deleted.is_(False),
)
)).scalar()
if existing:
raise BizException(message=f"发票号 {invoice_number} 已存在")
inv = FinSalesInvoice(
issuer=issuer,
receiver_customer_id=uuid.UUID(receiver_customer_id),
invoice_number=invoice_number,
amount=amount,
billing_date=dt_date.fromisoformat(billing_date_str) if billing_date_str else dt_date.today(),
remark=remark,
order_id=order_id,
shipping_record_id=uuid.UUID(shipping_record_id) if shipping_record_id else None,
created_by=current_user.user_id,
company_id=company_id,
)
db.add(inv)
await db.commit()
return ok(data={"id": str(inv.id), "invoice_number": invoice_number}, message="发票创建并关联成功")
+7 -4
View File
@@ -6,7 +6,7 @@ from __future__ import annotations
import uuid
from fastapi import APIRouter, Depends, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.api.deps import get_current_user, get_current_company_id
from app.db.database import get_db
from app.schemas.auth import CurrentUserPayload
from app.schemas.erp import CategoryCreate, CategoryUpdate, InventoryFlowCreate, SkuCreate, SkuUpdate
@@ -64,8 +64,9 @@ async def list_skus(
keyword: str | None = Query(None, description="模糊搜索 SKU 编码或名称"),
db: AsyncSession = Depends(get_db),
_: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
result = await svc.list_skus(db, page, size, category_id, keyword)
result = await svc.list_skus(db, company_id, page, size, category_id, keyword)
return ok(data=result.model_dump(mode="json"))
@@ -95,8 +96,9 @@ async def create_inventory_flow(
body: InventoryFlowCreate,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
result = await svc.create_inventory_flow(db, current_user, body)
result = await svc.create_inventory_flow(db, current_user, body, company_id)
return ok(data=result.model_dump(mode="json"), message="库存变更成功")
@@ -107,6 +109,7 @@ async def get_inventory_flows(
size: int = Query(50, ge=1, le=200),
db: AsyncSession = Depends(get_db),
_: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
result = await svc.get_inventory_flows(db, sku_id, page, size)
result = await svc.get_inventory_flows(db, sku_id, company_id, page, size)
return ok(data=result)
+37
View File
@@ -0,0 +1,37 @@
"""
利润核算路由 —— /api/profit
"""
from __future__ import annotations
import uuid
from fastapi import APIRouter, Depends, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user, get_current_company_id
from app.db.database import get_db
from app.schemas.auth import CurrentUserPayload
from app.schemas.response import ok
from app.services import profit_service as svc
router = APIRouter(prefix="/profit", tags=["利润核算"])
@router.get("/report", summary="利润报表(订单维度)")
async def profit_report(
start_date: str | None = Query(None, description="起始日期 YYYY-MM-DD"),
end_date: str | None = Query(None, description="结束日期 YYYY-MM-DD"),
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
result = await svc.get_profit_report(db, company_id, start_date, end_date)
return ok(data=result)
@router.post("/snapshot/{order_id}", summary="为订单锚定成本快照")
async def snapshot_costs(
order_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
result = await svc.snapshot_order_item_costs(db, order_id, company_id)
return ok(data=result, message=f"已为 {len(result)} 项明细锚定成本")
+16 -10
View File
@@ -14,7 +14,7 @@ from fastapi.responses import StreamingResponse
from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.api.deps import get_current_user, get_current_company_id
from app.db.database import get_db
from app.schemas.auth import CurrentUserPayload
from app.schemas.response import ok
@@ -28,15 +28,16 @@ async def generate_report(
end_date: date = Body(..., embed=True),
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
authorization: str | None = Header(None),
):
"""
1. 聚合该用户在时间范围内的 sales_logs 内容
1. 聚合该用户在时间范围内、涉及当前公司的 sales_logs 内容
2. 调用 Dify Workflow (streaming) 生成复盘报告
3. SSE 流式返回给前端
"""
return StreamingResponse(
_report_sse_generator(db, current_user, start_date, end_date, authorization or ""),
_report_sse_generator(db, current_user, start_date, end_date, authorization or "", company_id),
media_type="text/event-stream",
)
@@ -47,20 +48,25 @@ async def _report_sse_generator(
start_date: date,
end_date: date,
authorization: str = "",
company_id: uuid.UUID | None = None,
):
import httpx
from app.core.config import settings
from app.models.ai import SalesLog
# 1. 聚合日志
# 1. 聚合日志 — 仅提取涉及当前公司的日志
conditions = [
SalesLog.salesperson_id == user.user_id,
SalesLog.log_date >= start_date,
SalesLog.log_date <= end_date,
SalesLog.is_deleted.is_(False),
]
if company_id:
conditions.append(SalesLog.involved_company_ids.any(company_id))
stmt = (
select(SalesLog)
.where(
SalesLog.salesperson_id == user.user_id,
SalesLog.log_date >= start_date,
SalesLog.log_date <= end_date,
SalesLog.is_deleted.is_(False),
)
.where(*conditions)
.order_by(SalesLog.log_date)
)
logs = (await db.execute(stmt)).scalars().all()
+5 -3
View File
@@ -10,7 +10,7 @@ from fastapi import APIRouter, Depends, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.api.deps import get_current_user, get_current_company_id
from app.db.database import get_db
from app.schemas.auth import CurrentUserPayload
from app.schemas.sales_invoice import SalesInvoiceCreate, SalesInvoiceUpdate
@@ -26,8 +26,9 @@ async def create_invoice(
body: SalesInvoiceCreate,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
result = await svc.create_invoice(db, current_user, body)
result = await svc.create_invoice(db, current_user, body, company_id)
return ok(data=result.model_dump(mode="json"), message="销项发票创建成功")
@@ -42,10 +43,11 @@ async def list_invoices(
end_date: date | None = Query(None, description="开票结束日期"),
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
result = await svc.list_invoices(
db, page, size, customer_name, invoice_number,
payment_status, start_date, end_date,
payment_status, start_date, end_date, company_id,
)
return ok(data=result.model_dump(mode="json"))
+79 -3
View File
@@ -3,11 +3,13 @@
"""
from __future__ import annotations
import uuid
import asyncio
from fastapi import APIRouter, Depends, Body
from fastapi import APIRouter, Depends, Body, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.api.deps import get_current_user, get_current_company_id
from app.db.database import get_db
from app.schemas.auth import CurrentUserPayload
from app.schemas.response import ok
@@ -26,6 +28,7 @@ async def list_logs(
end_date: str | None = None,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
):
result = await sales_log_service.list_logs(
db, current_user,
@@ -34,18 +37,56 @@ async def list_logs(
user_id=user_id,
start_date=start_date,
end_date=end_date,
company_id=company_id,
)
return ok(data=result)
async def _resolve_company_ids(
db: AsyncSession,
company_id: uuid.UUID,
customer_id: str | None,
company_ids: list[str] | None,
) -> list[uuid.UUID]:
"""
智能解析 involved_company_ids
1. 如果前端显式传了 company_ids,使用它
2. 否则以当前视角公司为基础
3. 如果选了客户,自动查客户 owner 所属的公司,合并进来
"""
if company_ids:
resolved = set(uuid.UUID(cid) for cid in company_ids)
else:
resolved = {company_id}
# 自动关联客户 owner 所在公司
if customer_id:
from app.models.crm import CrmCustomer
from app.models.sys import SysUserCompany
cust = await db.get(CrmCustomer, uuid.UUID(customer_id))
if cust and cust.owner_id:
stmt = select(SysUserCompany.company_id).where(
SysUserCompany.user_id == cust.owner_id
)
rows = (await db.execute(stmt)).scalars().all()
for cid in rows:
resolved.add(cid)
# 确保当前公司始终在内
resolved.add(company_id)
return list(resolved)
@router.post("", summary="创建销售日志")
async def create_log(
content: str = Body(..., embed=True),
customer_id: str | None = Body(None, embed=True),
contact_ids: list[str] | None = Body(None, embed=True),
log_date: str | None = Body(None, embed=True),
company_ids: list[str] | None = Body(None, embed=True),
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
):
from datetime import date as date_type
@@ -53,17 +94,20 @@ async def create_log(
if log_date:
parsed_date = date_type.fromisoformat(log_date)
# 智能解析公司关联
resolved_company_ids = await _resolve_company_ids(db, company_id, customer_id, company_ids)
result = await sales_log_service.create_log(
db, current_user,
content=content,
customer_id=customer_id,
contact_ids=contact_ids,
log_date=parsed_date,
company_ids=resolved_company_ids,
)
# 异步触发 Dify 画像提取工作流(仅当关联了客户时)
if customer_id:
import uuid
asyncio.create_task(
sales_log_service.trigger_persona_workflow(
log_id=uuid.UUID(result["id"]),
@@ -75,3 +119,35 @@ async def create_log(
)
return ok(data=result, message="日志创建成功")
@router.put("/{log_id}", summary="编辑销售日志")
async def update_log(
log_id: uuid.UUID,
content: str | None = Body(None, embed=True),
customer_id: str | None = Body(None, embed=True),
contact_ids: list[str] | None = Body(None, embed=True),
log_date: str | None = Body(None, embed=True),
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
):
result = await sales_log_service.update_log(
db, current_user, log_id,
content=content,
customer_id=customer_id,
contact_ids=contact_ids,
log_date=log_date,
company_id=company_id,
)
return ok(data=result, message="日志更新成功")
@router.delete("/{log_id}", summary="删除销售日志(软删除)")
async def delete_log(
log_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
):
await sales_log_service.delete_log(db, current_user, log_id)
return ok(message="日志已删除")
+7 -4
View File
@@ -6,7 +6,7 @@ from __future__ import annotations
import uuid
from fastapi import APIRouter, Depends, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.api.deps import get_current_user, get_current_company_id
from app.db.database import get_db
from app.schemas.auth import CurrentUserPayload
from app.schemas.shipping import ShippingCreate
@@ -21,8 +21,9 @@ async def create_shipping(
body: ShippingCreate,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
resp, new_state = await svc.create_shipping(db, current_user, body)
resp, new_state = await svc.create_shipping(db, current_user, body, company_id)
return ok(data=resp.model_dump(mode="json"), message=f"发货单 {resp.shipping_no} 创建成功,订单状态已更新为 {new_state}")
@@ -34,8 +35,9 @@ async def list_shipping(
tracking_no: str | None = Query(None, description="按物流单号搜索"),
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
result = await svc.list_shipping(db, current_user, page, size, order_no, tracking_no)
result = await svc.list_shipping(db, current_user, page, size, order_no, tracking_no, company_id)
return ok(data=result.model_dump(mode="json"))
@@ -44,6 +46,7 @@ async def get_shipping_by_order(
order_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
current_user: CurrentUserPayload = Depends(get_current_user),
company_id: uuid.UUID = Depends(get_current_company_id),
) -> dict:
result = await svc.get_shipping_by_order(db, current_user, order_id)
result = await svc.get_shipping_by_order(db, current_user, order_id, company_id)
return ok(data=result)
+11
View File
@@ -26,6 +26,10 @@ from app.api.sales_invoice import router as sales_invoice_router
from app.api.reports import router as reports_router
from app.api.contacts import router as contacts_router
from app.api.dashboard import router as dashboard_router
from app.api.companies import router as companies_router
from app.api.contracts import router as contracts_router
from app.api.profit import router as profit_router
from app.api.ai_coaching import router as ai_coaching_router
@asynccontextmanager
@@ -33,8 +37,11 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
"""应用生命周期:启动/关闭时的钩子"""
# ── startup ──
print(f"🚀 {settings.APP_NAME} v{settings.APP_VERSION} 启动中...")
from app.services.ocr_worker import ocr_worker
ocr_worker.start()
yield
# ── shutdown ──
await ocr_worker.stop()
print("👋 服务正在关闭...")
@@ -81,6 +88,10 @@ app.include_router(sales_invoice_router, prefix="/api")
app.include_router(reports_router, prefix="/api")
app.include_router(contacts_router, prefix="/api")
app.include_router(dashboard_router, prefix="/api")
app.include_router(companies_router, prefix="/api")
app.include_router(contracts_router, prefix="/api")
app.include_router(profit_router, prefix="/api")
app.include_router(ai_coaching_router, prefix="/api")
# ── 健康检查 ──
+26 -1
View File
@@ -7,7 +7,7 @@ import uuid
from datetime import date, datetime
from sqlalchemy import Boolean, Date, DateTime, ForeignKey, SmallInteger, String, Text, func
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy.dialects.postgresql import UUID, JSONB, ARRAY
from sqlalchemy.orm import Mapped, mapped_column
from app.models.base import Base
@@ -30,11 +30,19 @@ class SalesLog(Base):
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
salesperson_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=False)
involved_company_ids: Mapped[list] = mapped_column(
ARRAY(UUID(as_uuid=True)), nullable=False, default=list,
comment="该篇日志涉及的公司ID列表"
)
customer_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("crm_customers.id"), nullable=True)
content: Mapped[str] = mapped_column(Text, nullable=False)
log_date: Mapped[date] = mapped_column(Date, default=date.today)
contact_ids: Mapped[list | None] = mapped_column(JSONB, default=list, nullable=True)
ai_processed: Mapped[bool] = mapped_column(Boolean, default=False)
ai_coaching_feedback: Mapped[dict | None] = mapped_column(
JSONB, default=dict, nullable=True,
comment="AI 教练引擎回写的指导反馈"
)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now())
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)
@@ -53,3 +61,20 @@ class AiReportDraft(Base):
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now())
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)
class KbObsidianVector(Base):
"""知识库向量表 —— pgvector 存储 Obsidian 文档分块向量"""
__tablename__ = "kb_obsidian_vectors"
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
company_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False, index=True
)
source_path: Mapped[str] = mapped_column(String(500), nullable=False, comment="源文件路径")
chunk_index: Mapped[int] = mapped_column(SmallInteger, default=0)
content: Mapped[str] = mapped_column(Text, nullable=False)
metadata_: Mapped[dict | None] = mapped_column("metadata", JSONB, default=dict)
# 向量字段使用 raw SQL 创建(vector(1536))因 SQLAlchemy 无原生 pgvector 类型
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)
+171
View File
@@ -0,0 +1,171 @@
"""
合同域 ORM 模型
映射: erp_contracts / erp_contract_items / erp_contract_attachments
"""
from __future__ import annotations
import uuid
from datetime import date, datetime
from sqlalchemy import (
Boolean,
Date,
DateTime,
ForeignKey,
Numeric,
String,
Text,
func,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import Base
# ── 付款条件枚举 ─────────────────────────────────────────
PAYMENT_TERMS = [
"预付全款订货",
"预付30%订货,到货前付清",
"预付50%订货,到货前付清",
"货到付全款",
"开具发票后30天内付款",
"开具发票45天付款",
"开具发票60天付款",
"开具发票90天付款",
]
# ── 运费条款枚举 ─────────────────────────────────────────
SHIPPING_TERMS = [
"买方自提",
"卖方免费送达天津指定地点",
"卖方免费送达指定地点",
"物流发货,运费买方承担",
]
class ErpContract(Base):
"""合同主表 —— B2B 交易防线核心"""
__tablename__ = "erp_contracts"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
contract_no: Mapped[str] = mapped_column(String(30), unique=True, nullable=False)
buyer_customer_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("crm_customers.id"), nullable=False,
comment="买方(CRM 客户)"
)
seller_company_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False,
comment="卖方(当前操作公司)"
)
company_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False, index=True,
comment="多租户隔离"
)
total_amount_excl_tax: Mapped[float] = mapped_column(Numeric(14, 2), default=0)
total_amount_incl_tax: Mapped[float] = mapped_column(Numeric(14, 2), default=0)
total_amount_cn: Mapped[str | None] = mapped_column(
String(100), nullable=True, comment="大写合计金额"
)
payment_terms: Mapped[str] = mapped_column(
String(50), nullable=False, default="货到付全款"
)
shipping_terms: Mapped[str] = mapped_column(
String(50), nullable=False, default="买方自提"
)
status: Mapped[str] = mapped_column(
String(20), nullable=False, default="draft",
comment="draft→active→completed→cancelled"
)
is_signed: Mapped[bool] = mapped_column(Boolean, default=False)
signed_file_url: Mapped[str | None] = mapped_column(String(500), nullable=True)
linked_order_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("erp_orders.id"), nullable=True,
comment="一键推单后回填"
)
salesperson_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=True
)
sign_date: Mapped[date | None] = mapped_column(Date, nullable=True)
remark: Mapped[str | None] = mapped_column(Text, nullable=True)
delivery_terms: Mapped[str | None] = mapped_column(
String(200), nullable=True, comment="货期(手动输入)"
)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.now(), onupdate=func.now()
)
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)
# 关系
buyer_customer: Mapped["CrmCustomer"] = relationship( # noqa: F821
"CrmCustomer", lazy="selectin"
)
seller_company: Mapped["SysCompany"] = relationship( # noqa: F821
"SysCompany", foreign_keys=[seller_company_id], lazy="selectin"
)
salesperson: Mapped["SysUser | None"] = relationship("SysUser", foreign_keys=[salesperson_id], lazy="selectin") # noqa: F821
linked_order: Mapped["ErpOrder | None"] = relationship("ErpOrder", foreign_keys=[linked_order_id], lazy="selectin") # noqa: F821
items: Mapped[list["ErpContractItem"]] = relationship(
"ErpContractItem", back_populates="contract", lazy="selectin"
)
attachments: Mapped[list["ErpContractAttachment"]] = relationship(
"ErpContractAttachment", back_populates="contract", lazy="selectin"
)
class ErpContractItem(Base):
"""合同明细行"""
__tablename__ = "erp_contract_items"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
contract_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("erp_contracts.id"), nullable=False
)
sku_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("erp_product_skus.id"), nullable=False
)
qty: Mapped[float] = mapped_column(Numeric(12, 2), nullable=False)
unit_price: Mapped[float] = mapped_column(Numeric(12, 2), nullable=False)
sub_total: Mapped[float] = mapped_column(Numeric(14, 2), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.now(), onupdate=func.now()
)
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)
# 关系
contract: Mapped[ErpContract] = relationship("ErpContract", back_populates="items")
sku: Mapped["ProductSku"] = relationship("ProductSku", lazy="selectin") # noqa: F821
class ErpContractAttachment(Base):
"""合同附件(双签盖章版等)"""
__tablename__ = "erp_contract_attachments"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
contract_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("erp_contracts.id"), nullable=False
)
file_name: Mapped[str] = mapped_column(String(200), nullable=False)
file_url: Mapped[str] = mapped_column(String(500), nullable=False)
file_type: Mapped[str] = mapped_column(
String(30), nullable=False, default="signed_copy",
comment="signed_copy / supplement / other"
)
uploader_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=True
)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)
# 关系
contract: Mapped[ErpContract] = relationship("ErpContract", back_populates="attachments")
uploader: Mapped["SysUser | None"] = relationship("SysUser", lazy="selectin") # noqa: F821
+41
View File
@@ -0,0 +1,41 @@
"""
成本域 ORM 模型
映射: erp_order_item_costs
"""
from __future__ import annotations
import uuid
from datetime import datetime
from sqlalchemy import DateTime, ForeignKey, Numeric, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import Base
class ErpOrderItemCost(Base):
"""订单明细成本快照表 —— 发货/确认瞬间锚定 MWA 成本"""
__tablename__ = "erp_order_item_costs"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
order_item_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("erp_order_items.id"), nullable=False, unique=True,
comment="关联订单明细"
)
purchase_unit_price: Mapped[float] = mapped_column(
Numeric(12, 4), nullable=False, comment="MWA 成本快照"
)
profit_amount: Mapped[float] = mapped_column(
Numeric(14, 2), default=0, comment="利润额 = (售价-成本)*数量"
)
profit_rate: Mapped[float] = mapped_column(
Numeric(5, 4), default=0, comment="利润率"
)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
# 关系
order_item: Mapped["ErpOrderItem"] = relationship("ErpOrderItem", lazy="selectin") # noqa: F821
+12
View File
@@ -29,6 +29,18 @@ class CrmCustomer(Base):
address: Mapped[str | None] = mapped_column(Text, nullable=True)
ai_score: Mapped[float] = mapped_column(Numeric(5, 2), default=0)
ai_persona: Mapped[dict | None] = mapped_column(JSONB, default=dict, nullable=True)
billing_info: Mapped[dict | None] = mapped_column(
JSONB, default=dict, nullable=True,
comment="客户开票信息: company_name/tax_id/address/phone/bank_name/bank_account"
)
health_score: Mapped[float] = mapped_column(
Numeric(5, 2), default=0,
comment="客户健康度评分 (AI 教练引擎计算)"
)
meddic_status: Mapped[dict | None] = mapped_column(
JSONB, default=dict, nullable=True,
comment="MEDDIC 六维评估状态"
)
owner_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=True
)
+42 -2
View File
@@ -12,11 +12,13 @@ from sqlalchemy import (
Boolean,
DateTime,
ForeignKey,
Index,
Integer,
Numeric,
SmallInteger,
String,
Text,
UniqueConstraint,
func,
)
from sqlalchemy.dialects.postgresql import UUID
@@ -56,8 +58,6 @@ class ProductSku(Base):
name: Mapped[str] = mapped_column(String(200), nullable=False)
spec: Mapped[str | None] = mapped_column(String(100), nullable=True)
standard_price: Mapped[float] = mapped_column(Numeric(12, 2), default=0)
stock_qty: Mapped[float] = mapped_column(Numeric(12, 2), default=0)
warning_threshold: Mapped[float] = mapped_column(Numeric(12, 2), default=0)
unit: Mapped[str] = mapped_column(String(20), default="")
status: Mapped[int] = mapped_column(SmallInteger, default=1)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
@@ -80,9 +80,18 @@ class InventoryFlow(Base):
sku_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("erp_product_skus.id"), nullable=False
)
company_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False, index=True
)
change_qty: Mapped[float] = mapped_column(Numeric(12, 2), nullable=False)
reason: Mapped[str] = mapped_column(String(50), nullable=False)
remark: Mapped[str | None] = mapped_column(Text, nullable=True)
purchase_unit_price: Mapped[float] = mapped_column(
Numeric(12, 2), default=0, comment="入库采购单价"
)
is_special_zero_cost: Mapped[bool] = mapped_column(
Boolean, default=False, comment="特殊零元入库标识,不参与 MWA 计算"
)
operator_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=True
)
@@ -94,3 +103,34 @@ class InventoryFlow(Base):
sku: Mapped[ProductSku | None] = relationship("ProductSku", lazy="selectin")
operator: Mapped["SysUser | None"] = relationship("SysUser", lazy="selectin") # noqa: F821
class ErpSkuInventory(Base):
"""SKU 分公司库存表 —— 同一 SKU 在不同公司有独立库存"""
__tablename__ = "erp_sku_inventory"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
sku_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("erp_product_skus.id"), nullable=False
)
company_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False, index=True
)
stock_qty: Mapped[float] = mapped_column(Numeric(12, 2), default=0)
warning_threshold: Mapped[float] = mapped_column(Numeric(12, 2), default=0)
mwa_unit_cost: Mapped[float] = mapped_column(
Numeric(12, 4), default=0,
comment="移动加权均价 (Moving Weighted Average)"
)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.now(), onupdate=func.now()
)
__table_args__ = (
UniqueConstraint("sku_id", "company_id", name="uq_sku_company"),
)
sku: Mapped[ProductSku | None] = relationship("ProductSku", lazy="selectin")
+60
View File
@@ -33,6 +33,9 @@ class FinInvoicePool(Base):
uploader_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=True
)
company_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False, index=True
)
file_url: Mapped[str | None] = mapped_column(String(500), nullable=True)
merchant_name: Mapped[str | None] = mapped_column(String(200), nullable=True)
amount: Mapped[float] = mapped_column(Numeric(14, 2), default=0)
@@ -59,6 +62,9 @@ class FinExpenseRecord(Base):
applicant_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=False
)
company_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False, index=True
)
total_amount: Mapped[float] = mapped_column(Numeric(14, 2), default=0)
status: Mapped[str] = mapped_column(String(20), nullable=False, default="draft")
remark: Mapped[str | None] = mapped_column(Text, nullable=True)
@@ -134,9 +140,23 @@ class FinSalesInvoice(Base):
payment_date: Mapped[date | None] = mapped_column(Date, nullable=True)
payment_amount: Mapped[float] = mapped_column(Numeric(14, 2), default=0)
remark: Mapped[str | None] = mapped_column(Text, nullable=True)
order_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("erp_orders.id"), nullable=True,
comment="关联订单"
)
shipping_record_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("erp_shipping_records.id"), nullable=True,
comment="关联发货单"
)
payment_due_date: Mapped[date | None] = mapped_column(
Date, nullable=True, comment="回款截止日(根据合同付款条件自动推算)"
)
created_by: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=True
)
company_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False, index=True
)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.now(), onupdate=func.now()
@@ -150,3 +170,43 @@ class FinSalesInvoice(Base):
creator: Mapped["SysUser | None"] = relationship( # noqa: F821
"SysUser", lazy="selectin"
)
class FinOcrTask(Base):
"""OCR 处理任务队列 — 持久化排队"""
__tablename__ = "fin_ocr_tasks"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
file_url: Mapped[str] = mapped_column(String(500), nullable=False)
file_ext: Mapped[str] = mapped_column(String(10), nullable=False, comment=".pdf/.png/.jpg")
original_name: Mapped[str] = mapped_column(String(200), nullable=False, default="")
status: Mapped[str] = mapped_column(
String(20), nullable=False, default="pending",
comment="pending/processing/success/failed/manual",
)
priority: Mapped[int] = mapped_column(default=100, comment="值越小越优先")
ocr_result: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
retry_count: Mapped[int] = mapped_column(default=0)
max_retries: Mapped[int] = mapped_column(default=3)
invoice_pool_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("fin_invoice_pool.id"), nullable=True,
comment="成功入池后关联的发票 ID",
)
uploader_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=True,
)
company_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False, index=True,
)
inv_type: Mapped[str] = mapped_column(String(30), nullable=False, default="expense")
scheduled_after: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.now(), onupdate=func.now()
)
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)
uploader: Mapped["SysUser | None"] = relationship("SysUser", lazy="selectin") # noqa: F821
+7
View File
@@ -37,6 +37,13 @@ class ErpOrder(Base):
salesperson_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=True
)
company_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False, index=True
)
contract_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("erp_contracts.id"), nullable=True,
comment="来源合同(一键推单后回填)"
)
total_amount: Mapped[float] = mapped_column(Numeric(14, 2), default=0)
shipping_state: Mapped[str] = mapped_column(
String(20), nullable=False, default="pending"
+3
View File
@@ -42,6 +42,9 @@ class ErpShippingRecord(Base):
operator_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=True
)
company_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False, index=True
)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.now(), onupdate=func.now()
+42 -1
View File
@@ -8,7 +8,7 @@ from __future__ import annotations
import uuid
from datetime import datetime
from sqlalchemy import Boolean, DateTime, ForeignKey, SmallInteger, String, Text, func
from sqlalchemy import Boolean, DateTime, ForeignKey, SmallInteger, String, Text, UniqueConstraint, func
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
@@ -97,3 +97,44 @@ class SysUser(Base):
"SysDepartment", lazy="selectin"
)
role: Mapped[SysRole | None] = relationship("SysRole", lazy="selectin")
class SysCompany(Base):
"""公司主体表 —— 多租户逻辑隔离核心"""
__tablename__ = "sys_companies"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
name: Mapped[str] = mapped_column(String(200), nullable=False)
code: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
full_info: Mapped[dict | None] = mapped_column(
JSONB, default=dict, nullable=True,
comment="公司完整信息: full_name/address/phone/bank_name/bank_account/tax_id"
)
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.now(), onupdate=func.now()
)
class SysUserCompany(Base):
"""用户-公司多对多关联 —— IDOR 防护核心"""
__tablename__ = "sys_user_companies"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
user_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_users.id"), nullable=False
)
company_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("sys_companies.id"), nullable=False
)
is_default: Mapped[bool] = mapped_column(Boolean, default=False)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
__table_args__ = (
UniqueConstraint("user_id", "company_id", name="uq_user_company"),
)
+106
View File
@@ -0,0 +1,106 @@
"""
合同域 Pydantic V2 Schemas
"""
from __future__ import annotations
import uuid
from datetime import date, datetime
from typing import Any
from pydantic import BaseModel, Field
# ── 合同明细行 ────────────────────────────────────────────
class ContractItemCreate(BaseModel):
sku_id: uuid.UUID
qty: float = Field(gt=0)
unit_price: float = Field(ge=0)
sub_total: float = Field(ge=0)
class ContractItemResponse(BaseModel):
id: uuid.UUID
sku_id: uuid.UUID
sku_code: str | None = None
sku_name: str | None = None
spec: str | None = None
unit: str | None = None
qty: float
unit_price: float
sub_total: float
model_config = {"from_attributes": True}
# ── 合同创建 ──────────────────────────────────────────────
class ContractCreate(BaseModel):
buyer_customer_id: uuid.UUID
items: list[ContractItemCreate] = Field(min_length=1)
payment_terms: str = "货到付全款"
shipping_terms: str = "买方自提"
remark: str | None = None
delivery_terms: str | None = None
sign_date: date | None = None
# ── 合同更新 ──────────────────────────────────────────────
class ContractUpdate(BaseModel):
buyer_customer_id: uuid.UUID | None = None
items: list[ContractItemCreate] | None = None
payment_terms: str | None = None
shipping_terms: str | None = None
status: str | None = None
is_signed: bool | None = None
remark: str | None = None
delivery_terms: str | None = None
sign_date: date | None = None
# ── 执行进度 ──────────────────────────────────────────────
class ContractProgressResponse(BaseModel):
is_signed: bool = False
has_order: bool = False
order_id: uuid.UUID | None = None
has_shipped: bool = False
has_invoice: bool = False
is_paid: bool = False
# ── 合同响应 ──────────────────────────────────────────────
class ContractResponse(BaseModel):
id: uuid.UUID
contract_no: str
buyer_customer_id: uuid.UUID
buyer_customer_name: str | None = None
seller_company_id: uuid.UUID
seller_company_name: str | None = None
company_id: uuid.UUID
total_amount_excl_tax: float = 0
total_amount_incl_tax: float = 0
total_amount_cn: str | None = None
payment_terms: str
shipping_terms: str
status: str
is_signed: bool = False
signed_file_url: str | None = None
linked_order_id: uuid.UUID | None = None
salesperson_id: uuid.UUID | None = None
salesperson_name: str | None = None
sign_date: date | None = None
remark: str | None = None
delivery_terms: str | None = None
items: list[ContractItemResponse] = []
progress: ContractProgressResponse | None = None
created_at: datetime
updated_at: datetime
model_config = {"from_attributes": True}
# ── 分页列表 ──────────────────────────────────────────────
class ContractListResponse(BaseModel):
total: int
items: list[ContractResponse]
page: int
size: int
+13
View File
@@ -12,6 +12,16 @@ from typing import Any
from pydantic import BaseModel, Field
# ── 开票信息子结构 ─────────────────────────────────────────
class BillingInfoSchema(BaseModel):
company_name: str | None = Field(default=None, max_length=200, description="开票公司全称")
tax_id: str | None = Field(default=None, max_length=50, description="纳税人识别号")
address: str | None = Field(default=None, max_length=300, description="地址")
phone: str | None = Field(default=None, max_length=30, description="电话")
bank_name: str | None = Field(default=None, max_length=200, description="开户行")
bank_account: str | None = Field(default=None, max_length=50, description="银行账号")
# ── 创建 ──────────────────────────────────────────────────
class CustomerCreate(BaseModel):
name: str = Field(..., min_length=1, max_length=200, examples=["中石化润滑油公司"])
@@ -21,6 +31,7 @@ class CustomerCreate(BaseModel):
phone: str | None = Field(default=None, max_length=30)
email: str | None = Field(default=None, max_length=100)
address: str | None = None
billing_info: BillingInfoSchema | None = None
status: int = Field(default=1, ge=0, le=1)
@@ -33,6 +44,7 @@ class CustomerUpdate(BaseModel):
phone: str | None = Field(default=None, max_length=30)
email: str | None = Field(default=None, max_length=100)
address: str | None = None
billing_info: BillingInfoSchema | None = None
status: int | None = Field(default=None, ge=0, le=1)
@@ -48,6 +60,7 @@ class CustomerResponse(BaseModel):
address: str | None = None
ai_score: float = 0
ai_persona: dict[str, Any] | None = None
billing_info: dict[str, Any] | None = None
owner_id: uuid.UUID | None = None
owner_name: str | None = None
status: int = 1
+2
View File
@@ -104,6 +104,8 @@ class InventoryFlowCreate(BaseModel):
examples=["purchase"],
)
remark: str | None = Field(default=None, description="备注")
purchase_unit_price: float = Field(default=0, ge=0, description="采购单价(仅入库时有意义)")
is_special_zero_cost: bool = Field(default=False, description="特殊零元入库标识,不参与 MWA 计算")
class InventoryFlowResponse(BaseModel):
+111
View File
@@ -0,0 +1,111 @@
"""
AI 教练引擎 事件总线 + Dify 回调
CQRS 解耦模式
1. 业务端 POST /api/sales-logs 立即 200 OK 发消息到 Redis Streams
2. Worker 消费消息 调用 Dify Workflow 写回 ai_coaching_feedback
3. 前端通过 SSE /api/notifications/stream 接收推送
"""
from __future__ import annotations
import json
import uuid
from datetime import datetime
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.ai import SalesLog
from app.models.crm import CrmCustomer
from app.schemas.auth import CurrentUserPayload
# ── Redis 事件发布 ───────────────────────────────────────
async def publish_coaching_event(
sales_log_id: uuid.UUID,
content: str,
customer_id: uuid.UUID | None = None,
salesperson_id: uuid.UUID | None = None,
) -> None:
"""将销售日志推送到 Redis Streams,供 Worker 异步消费"""
try:
import redis.asyncio as aioredis
import os
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
r = aioredis.from_url(redis_url, decode_responses=True)
await r.xadd(
"coaching:sales_logs",
{
"sales_log_id": str(sales_log_id),
"content": content[:2000], # 限长
"customer_id": str(customer_id) if customer_id else "",
"salesperson_id": str(salesperson_id) if salesperson_id else "",
"timestamp": datetime.utcnow().isoformat(),
},
)
await r.aclose()
except Exception as e:
# Redis 不可用时降级——不阻塞主流程
print(f"[AI EventBus] Redis 推送失败(降级): {e}")
# ── Dify 回调处理 ───────────────────────────────────────
async def handle_dify_coaching_callback(
db: AsyncSession,
sales_log_id: uuid.UUID,
feedback: dict,
) -> None:
"""Dify Workflow 回调 → 写回 SalesLog.ai_coaching_feedback"""
await db.execute(
update(SalesLog)
.where(SalesLog.id == sales_log_id)
.values(
ai_coaching_feedback=feedback,
ai_processed=True,
updated_at=datetime.utcnow(),
)
)
# 如果反馈中包含客户健康评分,同步更新 CrmCustomer
health_score = feedback.get("health_score")
meddic_status = feedback.get("meddic_status")
if health_score is not None or meddic_status is not None:
log = (await db.execute(
select(SalesLog).where(SalesLog.id == sales_log_id)
)).scalar_one_or_none()
if log and log.customer_id:
update_vals: dict = {}
if health_score is not None:
update_vals["health_score"] = float(health_score)
if meddic_status is not None:
update_vals["meddic_status"] = meddic_status
if update_vals:
await db.execute(
update(CrmCustomer)
.where(CrmCustomer.id == log.customer_id)
.values(**update_vals)
)
await db.commit()
# ── SSE 通知流 ──────────────────────────────────────────
async def sse_notification_generator(user_id: uuid.UUID):
"""服务端推送事件流(SSE)—— 监听 Redis PubSub 频道"""
import asyncio
try:
import redis.asyncio as aioredis
import os
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
r = aioredis.from_url(redis_url, decode_responses=True)
pubsub = r.pubsub()
channel = f"notifications:{user_id}"
await pubsub.subscribe(channel)
async for message in pubsub.listen():
if message["type"] == "message":
yield f"data: {message['data']}\n\n"
except Exception as e:
yield f"data: {json.dumps({'error': str(e)})}\n\n"
+762
View File
@@ -0,0 +1,762 @@
"""
合同管理 Service
核心逻辑CRUD + 一键推单 + 账期引擎 + 执行进度聚合
"""
from __future__ import annotations
import uuid
from datetime import date, datetime, timedelta
import re
from sqlalchemy import func, select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.core.exceptions import BizException, ForbiddenException, NotFoundException
from app.models.contract import ErpContract, ErpContractItem, ErpContractAttachment
from app.models.order import ErpOrder, ErpOrderItem
from app.models.shipping import ErpShippingRecord
from app.models.finance import FinSalesInvoice
from app.models.erp import ProductSku
from app.models.crm import CrmCustomer
from app.schemas.auth import CurrentUserPayload
from app.schemas.contract import (
ContractCreate,
ContractUpdate,
ContractItemResponse,
ContractListResponse,
ContractProgressResponse,
ContractResponse,
)
# ── 金额大写转换 ─────────────────────────────────────────
_CN_DIGITS = "零壹贰叁肆伍陆柒捌玖"
_CN_UNITS = ["", "", "", ""]
_CN_BIG_UNITS = ["", "", "亿", ""]
def amount_to_cn(amount: float) -> str:
"""将金额转为中文大写"""
if amount == 0:
return "零元整"
neg = ""
if amount < 0:
neg = ""
amount = -amount
yuan = int(amount)
jiao = int(amount * 10) % 10
fen = int(amount * 100) % 10
parts = []
if yuan > 0:
yuan_str = str(yuan)
n = len(yuan_str)
zero_flag = False
for i, ch in enumerate(yuan_str):
d = int(ch)
pos = n - 1 - i
big_idx = pos // 4
unit_idx = pos % 4
if d == 0:
zero_flag = True
if unit_idx == 0 and big_idx > 0:
parts.append(_CN_BIG_UNITS[big_idx])
else:
if zero_flag:
parts.append("")
zero_flag = False
parts.append(_CN_DIGITS[d] + _CN_UNITS[unit_idx])
if unit_idx == 0 and big_idx > 0:
parts.append(_CN_BIG_UNITS[big_idx])
parts.append("")
else:
parts.append("零元")
if jiao > 0:
parts.append(_CN_DIGITS[jiao] + "")
if fen > 0:
parts.append(_CN_DIGITS[fen] + "")
else:
if jiao == 0:
parts.append("")
return neg + "".join(parts)
# ── 生成合同编号 ─────────────────────────────────────────
async def _gen_contract_no(db: AsyncSession) -> str:
today_str = date.today().strftime("%Y%m%d")
prefix = f"HT-{today_str}-"
count_stmt = select(func.count()).select_from(ErpContract).where(
ErpContract.contract_no.like(f"{prefix}%")
)
count = (await db.execute(count_stmt)).scalar() or 0
return f"{prefix}{count + 1:03d}"
# ── 账期引擎 ────────────────────────────────────────────
def calc_payment_due_date(payment_terms: str, base_date: date) -> date | None:
"""根据付款条件枚举和基准日期(开票/发货)推算回款截止日"""
m = re.search(r"(\d+)天", payment_terms)
if m:
days = int(m.group(1))
return base_date + timedelta(days=days)
if "货到" in payment_terms or "全款" in payment_terms:
return base_date # 当天
return None
# ── ORM → Response ──────────────────────────────────────
def _item_to_response(item: ErpContractItem) -> ContractItemResponse:
sku = item.sku
return ContractItemResponse(
id=item.id,
sku_id=item.sku_id,
sku_code=sku.sku_code if sku else None,
sku_name=sku.name if sku else None,
spec=sku.spec if sku else None,
unit=sku.unit if sku else None,
qty=float(item.qty),
unit_price=float(item.unit_price),
sub_total=float(item.sub_total),
)
def _to_response(c: ErpContract, progress: ContractProgressResponse | None = None) -> ContractResponse:
return ContractResponse(
id=c.id,
contract_no=c.contract_no,
buyer_customer_id=c.buyer_customer_id,
buyer_customer_name=c.buyer_customer.name if c.buyer_customer else None,
seller_company_id=c.seller_company_id,
seller_company_name=c.seller_company.name if c.seller_company else None,
company_id=c.company_id,
total_amount_excl_tax=float(c.total_amount_excl_tax or 0),
total_amount_incl_tax=float(c.total_amount_incl_tax or 0),
total_amount_cn=c.total_amount_cn,
payment_terms=c.payment_terms,
shipping_terms=c.shipping_terms,
status=c.status,
is_signed=c.is_signed,
signed_file_url=c.signed_file_url,
linked_order_id=c.linked_order_id,
salesperson_id=c.salesperson_id,
salesperson_name=c.salesperson.real_name if c.salesperson else None,
sign_date=c.sign_date,
remark=c.remark,
delivery_terms=c.delivery_terms,
items=[_item_to_response(i) for i in (c.items or []) if not i.is_deleted],
progress=progress,
created_at=c.created_at,
updated_at=c.updated_at,
)
# ── 执行进度聚合 ────────────────────────────────────────
async def _get_progress(db: AsyncSession, contract: ErpContract) -> ContractProgressResponse:
progress = ContractProgressResponse(is_signed=contract.is_signed)
if contract.linked_order_id:
progress.has_order = True
progress.order_id = contract.linked_order_id
# 是否有发货
ship_count = (await db.execute(
select(func.count()).select_from(ErpShippingRecord).where(
ErpShippingRecord.order_id == contract.linked_order_id,
ErpShippingRecord.is_deleted.is_(False),
)
)).scalar() or 0
progress.has_shipped = ship_count > 0
# 是否有销项发票
inv_count = (await db.execute(
select(func.count()).select_from(FinSalesInvoice).where(
FinSalesInvoice.order_id == contract.linked_order_id,
FinSalesInvoice.is_deleted.is_(False),
)
)).scalar() or 0
progress.has_invoice = inv_count > 0
# 是否回款(检查订单回款状态)
order = (await db.execute(
select(ErpOrder).where(ErpOrder.id == contract.linked_order_id)
)).scalar_one_or_none()
if order and order.payment_state == "paid":
progress.is_paid = True
return progress
# ── 公共 eager-load 选项 ────────────────────────────────────
def _contract_load_options():
"""返回 selectinload 链,保证 commit 后仍可安全访问关系属性"""
return [
selectinload(ErpContract.buyer_customer),
selectinload(ErpContract.seller_company),
selectinload(ErpContract.salesperson),
selectinload(ErpContract.items).selectinload(ErpContractItem.sku),
]
# ── Service Functions ────────────────────────────────────
async def create_contract(
db: AsyncSession,
user: CurrentUserPayload,
company_id: uuid.UUID,
body: ContractCreate,
) -> ContractResponse:
contract_no = await _gen_contract_no(db)
# 计算合计
total = sum(item.sub_total for item in body.items)
contract = ErpContract(
contract_no=contract_no,
buyer_customer_id=body.buyer_customer_id,
seller_company_id=company_id,
company_id=company_id,
total_amount_excl_tax=total,
total_amount_incl_tax=total, # 含税金额默认同不含税,可后续区分
total_amount_cn=amount_to_cn(total),
payment_terms=body.payment_terms,
shipping_terms=body.shipping_terms,
sign_date=body.sign_date,
remark=body.remark,
delivery_terms=body.delivery_terms,
salesperson_id=user.user_id,
status="draft",
)
db.add(contract)
await db.flush()
# 添加明细行
for item_data in body.items:
item = ErpContractItem(
contract_id=contract.id,
sku_id=item_data.sku_id,
qty=item_data.qty,
unit_price=item_data.unit_price,
sub_total=item_data.sub_total,
)
db.add(item)
await db.commit()
# 重新查询并 eager-load 所有关系,避免 commit 后隐式 lazy load
fresh = (await db.execute(
select(ErpContract)
.where(ErpContract.id == contract.id)
.options(*_contract_load_options())
)).scalar_one()
return _to_response(fresh)
async def list_contracts(
db: AsyncSession,
company_id: uuid.UUID,
page: int = 1,
size: int = 20,
keyword: str | None = None,
status: str | None = None,
) -> ContractListResponse:
base_where = [
ErpContract.company_id == company_id,
ErpContract.is_deleted.is_(False),
]
if keyword:
base_where.append(ErpContract.contract_no.ilike(f"%{keyword}%"))
if status:
base_where.append(ErpContract.status == status)
total = (await db.execute(
select(func.count()).select_from(ErpContract).where(*base_where)
)).scalar() or 0
stmt = (
select(ErpContract)
.where(*base_where)
.options(*_contract_load_options())
.order_by(ErpContract.created_at.desc())
.offset((page - 1) * size)
.limit(size)
)
contracts = (await db.execute(stmt)).scalars().all()
return ContractListResponse(
total=total,
items=[_to_response(c) for c in contracts],
page=page,
size=size,
)
async def get_contract(
db: AsyncSession,
contract_id: uuid.UUID,
company_id: uuid.UUID,
) -> ContractResponse:
stmt = (
select(ErpContract)
.where(
ErpContract.id == contract_id,
ErpContract.company_id == company_id,
ErpContract.is_deleted.is_(False),
)
.options(*_contract_load_options())
)
contract = (await db.execute(stmt)).scalar_one_or_none()
if contract is None:
raise NotFoundException("合同不存在")
progress = await _get_progress(db, contract)
return _to_response(contract, progress)
async def update_contract(
db: AsyncSession,
contract_id: uuid.UUID,
company_id: uuid.UUID,
body: ContractUpdate,
) -> ContractResponse:
stmt = select(ErpContract).where(
ErpContract.id == contract_id,
ErpContract.company_id == company_id,
ErpContract.is_deleted.is_(False),
)
contract = (await db.execute(stmt)).scalar_one_or_none()
if contract is None:
raise NotFoundException("合同不存在")
# 更新主表字段
update_data = body.model_dump(exclude_unset=True, exclude={"items"})
if update_data:
update_data["updated_at"] = datetime.utcnow()
await db.execute(
update(ErpContract).where(ErpContract.id == contract_id).values(**update_data)
)
# 如果有明细行更新,删旧增新
if body.items is not None:
await db.execute(
update(ErpContractItem)
.where(ErpContractItem.contract_id == contract_id)
.values(is_deleted=True)
)
total = 0
for item_data in body.items:
item = ErpContractItem(
contract_id=contract_id,
sku_id=item_data.sku_id,
qty=item_data.qty,
unit_price=item_data.unit_price,
sub_total=item_data.sub_total,
)
total += item_data.sub_total
db.add(item)
await db.execute(
update(ErpContract).where(ErpContract.id == contract_id).values(
total_amount_excl_tax=total,
total_amount_incl_tax=total,
total_amount_cn=amount_to_cn(total),
)
)
await db.commit()
updated = (await db.execute(
select(ErpContract)
.where(ErpContract.id == contract_id)
.options(*_contract_load_options())
)).scalar_one()
return _to_response(updated)
async def delete_contract(
db: AsyncSession,
contract_id: uuid.UUID,
company_id: uuid.UUID,
) -> None:
stmt = select(ErpContract).where(
ErpContract.id == contract_id,
ErpContract.company_id == company_id,
ErpContract.is_deleted.is_(False),
)
contract = (await db.execute(stmt)).scalar_one_or_none()
if contract is None:
raise NotFoundException("合同不存在")
await db.execute(
update(ErpContract)
.where(ErpContract.id == contract_id)
.values(is_deleted=True, updated_at=datetime.utcnow())
)
await db.commit()
async def generate_order_from_contract(
db: AsyncSession,
user: CurrentUserPayload,
contract_id: uuid.UUID,
company_id: uuid.UUID,
) -> dict:
"""一键从合同生成订单 —— 防篡改推单逻辑"""
stmt = (
select(ErpContract)
.where(
ErpContract.id == contract_id,
ErpContract.company_id == company_id,
ErpContract.is_deleted.is_(False),
)
.options(*_contract_load_options())
)
contract = (await db.execute(stmt)).scalar_one_or_none()
if contract is None:
raise NotFoundException("合同不存在")
if contract.linked_order_id is not None:
raise BizException(message="该合同已关联订单,不可重复生成")
# 生成订单编号
today_str = date.today().strftime("%Y%m%d")
prefix = f"SO-{today_str}-"
count = (await db.execute(
select(func.count()).select_from(ErpOrder).where(
ErpOrder.order_no.like(f"{prefix}%")
)
)).scalar() or 0
order_no = f"{prefix}{count + 1:03d}"
# 创建订单
new_order = ErpOrder(
order_no=order_no,
customer_id=contract.buyer_customer_id,
salesperson_id=user.user_id,
company_id=company_id,
contract_id=contract_id,
total_amount=float(contract.total_amount_incl_tax or 0),
order_date=date.today(),
)
db.add(new_order)
await db.flush()
# 复制合同明细到订单明细
active_items = [i for i in (contract.items or []) if not i.is_deleted]
for ci in active_items:
oi = ErpOrderItem(
order_id=new_order.id,
sku_id=ci.sku_id,
qty=float(ci.qty),
unit_price=float(ci.unit_price),
sub_total=float(ci.sub_total),
)
db.add(oi)
# 回填合同 linked_order_id + 激活状态
await db.execute(
update(ErpContract)
.where(ErpContract.id == contract_id)
.values(
linked_order_id=new_order.id,
status="active",
updated_at=datetime.utcnow(),
)
)
await db.commit()
return {"order_id": str(new_order.id), "order_no": order_no}
# ── 数字转中文大写金额 ──────────────────────────────────────
def _amount_to_cn(amount: float) -> str:
"""将数字金额转换为中文大写"""
digits = "零壹贰叁肆伍陆柒捌玖"
units = ["", "", "", ""]
big_units = ["", "", "亿"]
if amount == 0:
return "零元整"
yuan = int(round(amount * 100))
jiao = (yuan % 100) // 10
fen = yuan % 10
yuan_part = yuan // 100
result = ""
if yuan_part > 0:
s = str(yuan_part)
n = len(s)
for i, ch in enumerate(s):
d = int(ch)
pos = n - i - 1
big_pos = pos // 4
unit_pos = pos % 4
if d != 0:
result += digits[d] + units[unit_pos]
else:
if result and not result.endswith(""):
result += ""
if unit_pos == 0 and big_pos > 0:
result = result.rstrip("") + big_units[big_pos]
result = result.rstrip("") + ""
else:
result = ""
if jiao == 0 and fen == 0:
result += ""
else:
if jiao > 0:
result += digits[jiao] + ""
if fen > 0:
result += digits[fen] + ""
return result
async def generate_contract_docx(
db: AsyncSession,
contract_id: uuid.UUID,
company_id: uuid.UUID,
) -> bytes:
"""纯代码生成合同 Word 文档(紧凑排版,2 页以内)"""
import io
from docx import Document as DocxDocument
from docx.shared import Pt, Cm, Emu, RGBColor
from docx.enum.table import WD_TABLE_ALIGNMENT
from docx.enum.text import WD_ALIGN_PARAGRAPH
from docx.oxml.ns import qn
from app.models.sys import SysCompany
# ── 1) 数据准备 ─────────────────────────────────────────
contract = (await db.execute(
select(ErpContract)
.where(
ErpContract.id == contract_id,
ErpContract.company_id == company_id,
ErpContract.is_deleted.is_(False),
)
.options(*_contract_load_options())
)).scalar_one_or_none()
if contract is None:
raise NotFoundException("合同不存在")
seller = (await db.execute(
select(SysCompany).where(SysCompany.id == contract.seller_company_id)
)).scalar_one_or_none()
seller_info = (seller.full_info or {}) if seller else {}
buyer = contract.buyer_customer
buyer_billing = {}
if buyer and hasattr(buyer, "billing_info") and buyer.billing_info:
buyer_billing = buyer.billing_info
total_incl = float(contract.total_amount_incl_tax or 0)
sign_date_str = (contract.sign_date or date.today()).strftime("%Y年%m月%d")
buyer_name = buyer_billing.get("company_name") or (buyer.name if buyer else "")
seller_name = seller_info.get("company_name") or (seller.name if seller else "")
items = [i for i in (contract.items or []) if not i.is_deleted]
# ── 2) 创建文档 ─────────────────────────────────────────
doc = DocxDocument()
# 页边距:上下2cm 左右2.5cm(紧凑)
for section in doc.sections:
section.top_margin = Cm(2)
section.bottom_margin = Cm(1.5)
section.left_margin = Cm(2.5)
section.right_margin = Cm(2.5)
# ── 辅助函数 ─────────────────────────────────────────────
# 小四 = 12pt, 1.5倍行距 = 18pt
def add_para(text: str, font_size: int = 12, bold: bool = False,
align=WD_ALIGN_PARAGRAPH.LEFT, space_before: int = 0,
space_after: int = 0, font_name: str = "宋体"):
p = doc.add_paragraph()
p.alignment = align
p.paragraph_format.space_before = Pt(space_before)
p.paragraph_format.space_after = Pt(space_after)
p.paragraph_format.line_spacing = Pt(18) # 1.5倍行距(12pt×1.5)
run = p.add_run(text)
run.font.size = Pt(font_size)
run.font.bold = bold
run.font.name = font_name
run._element.rPr.rFonts.set(qn("w:eastAsia"), font_name)
return p
def set_cell(cell, text: str, font_size: int = 12, bold: bool = False,
align=WD_ALIGN_PARAGRAPH.CENTER):
cell.text = ""
p = cell.paragraphs[0]
p.alignment = align
p.paragraph_format.space_before = Pt(0)
p.paragraph_format.space_after = Pt(0)
p.paragraph_format.line_spacing = Pt(18) # 1.5倍行距
run = p.add_run(text)
run.font.size = Pt(font_size)
run.font.bold = bold
run.font.name = "宋体"
run._element.rPr.rFonts.set(qn("w:eastAsia"), "宋体")
# ── 3) 标题 ──────────────────────────────────────────────
add_para("产 品 购 销 合 同", font_size=18, bold=True,
align=WD_ALIGN_PARAGRAPH.CENTER, space_after=4, font_name="黑体")
add_para(f"合同编号:{contract.contract_no}",
align=WD_ALIGN_PARAGRAPH.RIGHT, space_after=4)
# ── 4) 甲乙方信息(紧凑表格) ────────────────────────────
info_tbl = doc.add_table(rows=4, cols=4)
info_tbl.alignment = WD_TABLE_ALIGNMENT.CENTER
info_tbl.style = "Table Grid"
info_data = [
("买方(甲方)", buyer_name,
"卖方(乙方)", seller_name),
("税号", buyer_billing.get("tax_id", "") or "",
"税号", seller_info.get("tax_id", "") or ""),
("地址", buyer_billing.get("address", "") or "",
"地址", seller_info.get("address", "") or ""),
("开户行 / 账号",
f"{buyer_billing.get('bank_name', '') or ''} {buyer_billing.get('bank_account', '') or ''}".strip(),
"开户行 / 账号",
f"{seller_info.get('bank_name', '') or ''} {seller_info.get('bank_account', '') or ''}".strip()),
]
for ri, row_data in enumerate(info_data):
for ci, val in enumerate(row_data):
bold = ri == 0 and ci in (0, 2)
set_cell(info_tbl.cell(ri, ci), val, bold=bold,
align=WD_ALIGN_PARAGRAPH.LEFT)
# ── 5) 一、产品明细 ──────────────────────────────────────
add_para("一、产品明细", bold=True, space_before=6, space_after=2)
cols = 6
tbl = doc.add_table(rows=1 + len(items) + 1, cols=cols)
tbl.alignment = WD_TABLE_ALIGNMENT.CENTER
tbl.style = "Table Grid"
headers = ["序号", "产品名称", "规格", "数量", "单价(元)", "小计(元)"]
for ci, h in enumerate(headers):
set_cell(tbl.cell(0, ci), h, bold=True)
for ri, item in enumerate(items):
sku_name = item.sku.name if item.sku else ""
sku_spec = item.sku.spec if item.sku else ""
set_cell(tbl.cell(ri + 1, 0), str(ri + 1))
set_cell(tbl.cell(ri + 1, 1), sku_name, align=WD_ALIGN_PARAGRAPH.LEFT)
set_cell(tbl.cell(ri + 1, 2), sku_spec or "-")
set_cell(tbl.cell(ri + 1, 3), str(float(item.qty)))
set_cell(tbl.cell(ri + 1, 4), f"{float(item.unit_price):,.2f}",
align=WD_ALIGN_PARAGRAPH.RIGHT)
set_cell(tbl.cell(ri + 1, 5), f"{float(item.sub_total):,.2f}",
align=WD_ALIGN_PARAGRAPH.RIGHT)
# 合计行
last_row = len(items) + 1
set_cell(tbl.cell(last_row, 0), "合计", bold=True)
# 合并序号~单价列
for ci in range(1, 4):
set_cell(tbl.cell(last_row, ci), "")
set_cell(tbl.cell(last_row, 4), "", align=WD_ALIGN_PARAGRAPH.RIGHT)
set_cell(tbl.cell(last_row, 5), f"{total_incl:,.2f}", bold=True,
align=WD_ALIGN_PARAGRAPH.RIGHT)
# 大写金额
add_para(f"合计金额(大写):{_amount_to_cn(total_incl)} (含13%增值税)",
bold=True, space_before=2, space_after=2)
# ── 6) 二、交货及付款条件 ────────────────────────────────
add_para("二、交货及付款条件", bold=True, space_before=4, space_after=2)
delivery_text = contract.delivery_terms or "按双方约定"
add_para(f"1. 货 期:{delivery_text}")
add_para(f"2. 交货方式:{contract.shipping_terms or '买方自提'}")
add_para(f"3. 付款条件:{contract.payment_terms or '货到付全款'}")
# ── 7) 三、发票信息 ──────────────────────────────────────
add_para("三、发票信息", bold=True, space_before=4, space_after=2)
add_para("卖方给买方开具合同金额增值税专用发票(13%增值税)。")
# ── 8) 四、合同细则 ──────────────────────────────────────
add_para("四、合同细则", bold=True, space_before=4, space_after=2)
# 紧凑输出细则内容
terms = [
"第一条 质量标准:按照厂家标准执行,由于买方储存不当(如露天暴晒、混入杂质、超过保质期等)或未按产品说明书操作导致的质量问题,卖方不承担责任。",
"第二条 卖方对质量负责的条件及期限:自货到12个月。",
"第三条 包装标准包装物的供应与回收:产品包装均应采用国家或专业标准保护措施进行包装,以确保产品不受损害为原则,由于包装不善所引起的货物污染、损坏、损失均由卖方负担,采取装箱包装的应在包装箱内附一份详细装箱单和质量合格证,包装物不回收。",
"第四条 合理损耗标准及计算方法:标的货物送至买方指定地点前的合理损耗由卖方负责。",
"第五条 标的物所有权:在买方付清本合同项下全部货款之前,标的物的所有权仍属于卖方。",
"第六条 检验标准、方法、地点及期限:按第二条标准检验。",
"第七条 发票信息:卖方给买方开具合同金额增值税专用发票(13%增值税)。",
"第八条 本合同解除条件:合同执行完毕。",
(
"第九条 违约责任:\n"
"1、卖方应保证产品质量合格,买方有权在货到后7个工作日内且未开封状态下将卖方产品送质监局或第三方部门检验单位检验,"
"送检样品的取样过程必须经卖方现场确认或双方共同封样,否则检验结果无效。检验结果不合格,则所发生的所有检验费用,"
"均由卖方承担,买方可根据实际情况选择要求退货或更换。\n"
"赔偿限额:卖方对本合同项下违约责任的赔偿总额,以本合同约定的总货款金额为限,"
"且不承担任何间接损失(包括但不限于停工损失、利润损失等)。"
),
(
"第十条 合同争议的解决方式:本合同在履行过程中发生的争执,由双方当事人协商解决,"
"也可由当地工商行政管理部门调解;协商或调解不成的,按下列第二种方式解决。\n"
"(一)提交当地仲裁委员会仲裁;(二)依法向卖方所在地的人民法院起诉。"
),
"第十一条 本合同一式两份,自双方签字盖章起生效。",
(
"第十二条 其他约定事项:\n"
"1、卖方必须遵守国家有关能源管理的法律、法规;\n"
"2、卖方必须执行买方对其提出的对能源控制进行改善的要求;\n"
"3、卖方在运输途中和施工作业中的各种行为不应对能源造成浪费或负面影响;\n"
"4、如卖方提供货物存在质量问题,买方书面(包括但不限于传真、邮件)通知对方,"
"卖方在接到买方书面通知后3个工作日内要给与买方书面回复,否则将视为卖方已经认可买方提出的质量问题;"
"如果双方意见产生争议,由卖方负责安排经买方同意的第三方进行检验,否则视为卖方质量问题;\n"
"5、未经对方书面同意,不得将合同部分或者全部权利义务转给第三方。\n"
"6、如遇战争、原材料短缺、工厂停产、物流管制等不可抗力因素导致货期延长,卖方不承担违约责任。"
),
]
for term in terms:
add_para(term)
# ── 9) 签章区 ────────────────────────────────────────────
add_para("", space_before=6, space_after=0) # 小间距
sig_tbl = doc.add_table(rows=4, cols=2)
sig_tbl.alignment = WD_TABLE_ALIGNMENT.CENTER
# 去边框
for row in sig_tbl.rows:
for cell in row.cells:
for paragraph in cell.paragraphs:
paragraph.paragraph_format.space_before = Pt(0)
paragraph.paragraph_format.space_after = Pt(0)
set_cell(sig_tbl.cell(0, 0), "买方(盖章):", bold=True,
align=WD_ALIGN_PARAGRAPH.LEFT)
set_cell(sig_tbl.cell(0, 1), "卖方(盖章):", bold=True,
align=WD_ALIGN_PARAGRAPH.LEFT)
set_cell(sig_tbl.cell(1, 0), "授权代表签字:",
align=WD_ALIGN_PARAGRAPH.LEFT)
set_cell(sig_tbl.cell(1, 1), "授权代表签字:",
align=WD_ALIGN_PARAGRAPH.LEFT)
set_cell(sig_tbl.cell(2, 0), f"日期:{sign_date_str}",
align=WD_ALIGN_PARAGRAPH.LEFT)
set_cell(sig_tbl.cell(2, 1), f"日期:{sign_date_str}",
align=WD_ALIGN_PARAGRAPH.LEFT)
set_cell(sig_tbl.cell(3, 0), f"联系电话:{buyer_billing.get('phone', '') or ''}",
align=WD_ALIGN_PARAGRAPH.LEFT)
set_cell(sig_tbl.cell(3, 1), f"联系电话:{seller_info.get('phone', '') or ''}",
align=WD_ALIGN_PARAGRAPH.LEFT)
# ── 10) 输出 ─────────────────────────────────────────────
buffer = io.BytesIO()
doc.save(buffer)
buffer.seek(0)
return buffer.getvalue()
+116 -19
View File
@@ -13,6 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import BizException, ForbiddenException, NotFoundException
from app.models.crm import CrmCustomer
from app.models.sys import SysUser
from app.schemas.auth import CurrentUserPayload
from app.schemas.crm import (
CustomerCreate,
@@ -35,6 +36,7 @@ def _to_response(c: CrmCustomer) -> CustomerResponse:
address=c.address,
ai_score=float(c.ai_score or 0),
ai_persona=c.ai_persona,
billing_info=c.billing_info,
owner_id=c.owner_id,
owner_name=c.owner.real_name if c.owner else None,
status=c.status,
@@ -44,12 +46,48 @@ def _to_response(c: CrmCustomer) -> CustomerResponse:
)
# ── 递归查询本部门 + 子部门所有用户 ID ────────────────────
async def _get_dept_and_sub_user_ids(
db: AsyncSession, dept_id: uuid.UUID
) -> list[uuid.UUID]:
"""递归获取指定部门及其所有子部门下的用户 ID 列表"""
from app.models.sys import SysDepartment, SysUser
# 收集所有目标部门 ID(递归子部门)
dept_ids: list[uuid.UUID] = [dept_id]
queue = [dept_id]
while queue:
current = queue.pop(0)
children = (await db.execute(
select(SysDepartment.id).where(
SysDepartment.parent_id == current,
SysDepartment.is_deleted.is_(False),
)
)).scalars().all()
for child_id in children:
dept_ids.append(child_id)
queue.append(child_id)
# 查询这些部门下的所有用户 ID
user_ids = (await db.execute(
select(SysUser.id).where(
SysUser.dept_id.in_(dept_ids),
SysUser.is_deleted.is_(False),
)
)).scalars().all()
return list(user_ids)
# ── 权限校验 ─────────────────────────────────────────────
def _check_access(customer: CrmCustomer, user: CurrentUserPayload) -> None:
def _check_access(customer: CrmCustomer, user: CurrentUserPayload, *, dept_user_ids: list[uuid.UUID] | None = None) -> None:
if user.data_scope == "all":
return
if user.data_scope == "dept_and_sub":
return # 简化版:放通本部门
# 如果有预查询的部门用户列表,校验 owner 是否在列表内
if dept_user_ids is not None:
if customer.owner_id not in dept_user_ids:
raise ForbiddenException("无权访问该客户(数据权限:本部门及子部门)")
return
# data_scope == 'self'
if customer.owner_id != user.user_id:
raise ForbiddenException("无权访问该客户(数据权限:仅本人)")
@@ -70,6 +108,7 @@ async def create_customer(
phone=body.phone,
email=body.email,
address=body.address,
billing_info=body.billing_info.model_dump() if body.billing_info else None,
status=body.status,
owner_id=user.user_id,
)
@@ -98,12 +137,12 @@ async def list_customers(
base_where.append(CrmCustomer.owner_id == user.user_id)
elif user.data_scope == "dept_and_sub":
if user.dept_id is not None:
from app.models.sys import SysUser
sub = select(SysUser.id).where(
SysUser.dept_id == user.dept_id,
SysUser.is_deleted.is_(False),
)
base_where.append(CrmCustomer.owner_id.in_(sub))
dept_user_ids = await _get_dept_and_sub_user_ids(db, user.dept_id)
if dept_user_ids:
base_where.append(CrmCustomer.owner_id.in_(dept_user_ids))
else:
# 部门无用户 → 仅显示自己的
base_where.append(CrmCustomer.owner_id == user.user_id)
if keyword:
base_where.append(CrmCustomer.name.ilike(f"%{keyword}%"))
@@ -144,7 +183,11 @@ async def get_customer(
if customer is None:
raise NotFoundException("客户不存在或已被删除")
_check_access(customer, user)
# dept_and_sub 需要先查询部门用户列表
dept_user_ids = None
if user.data_scope == "dept_and_sub" and user.dept_id:
dept_user_ids = await _get_dept_and_sub_user_ids(db, user.dept_id)
_check_access(customer, user, dept_user_ids=dept_user_ids)
return _to_response(customer)
@@ -162,7 +205,10 @@ async def update_customer(
if customer is None:
raise NotFoundException("客户不存在或已被删除")
_check_access(customer, user)
dept_user_ids = None
if user.data_scope == "dept_and_sub" and user.dept_id:
dept_user_ids = await _get_dept_and_sub_user_ids(db, user.dept_id)
_check_access(customer, user, dept_user_ids=dept_user_ids)
update_data = body.model_dump(exclude_unset=True)
if not update_data:
@@ -193,7 +239,10 @@ async def delete_customer(
if customer is None:
raise NotFoundException("客户不存在或已被删除")
_check_access(customer, user)
dept_user_ids = None
if user.data_scope == "dept_and_sub" and user.dept_id:
dept_user_ids = await _get_dept_and_sub_user_ids(db, user.dept_id)
_check_access(customer, user, dept_user_ids=dept_user_ids)
await db.execute(
update(CrmCustomer)
@@ -216,7 +265,10 @@ async def restore_customer(
if customer is None:
raise NotFoundException("客户不存在或未被归档")
_check_access(customer, user)
dept_user_ids = None
if user.data_scope == "dept_and_sub" and user.dept_id:
dept_user_ids = await _get_dept_and_sub_user_ids(db, user.dept_id)
_check_access(customer, user, dept_user_ids=dept_user_ids)
await db.execute(
update(CrmCustomer)
@@ -226,6 +278,49 @@ async def restore_customer(
await db.commit()
async def transfer_customer(
db: AsyncSession,
user: CurrentUserPayload,
customer_id: uuid.UUID,
new_owner_id: uuid.UUID,
) -> CustomerResponse:
"""将客户转移至指定人员名下(仅管理员)"""
if user.data_scope != "all":
raise ForbiddenException("仅管理员可执行客户转移操作")
stmt = select(CrmCustomer).where(
CrmCustomer.id == customer_id,
CrmCustomer.is_deleted.is_(False),
)
customer = (await db.execute(stmt)).scalar_one_or_none()
if customer is None:
raise NotFoundException("客户不存在或已被归档")
if customer.owner_id == new_owner_id:
raise BizException(message="目标负责人与当前负责人相同,无需转移")
# 校验目标用户是否存在
from app.models.sys import SysUser
target = (await db.execute(
select(SysUser).where(SysUser.id == new_owner_id)
)).scalar_one_or_none()
if target is None:
raise NotFoundException("目标负责人不存在")
old_owner_name = customer.owner.real_name if customer.owner else "(无)"
await db.execute(
update(CrmCustomer)
.where(CrmCustomer.id == customer_id)
.values(owner_id=new_owner_id, updated_at=datetime.utcnow())
)
await db.commit()
await db.refresh(customer)
print(f"[客户转移] {customer.name}: {old_owner_name}{target.real_name} (操作人: {user.real_name})")
return _to_response(customer)
async def get_customer_products(
db: AsyncSession,
user: CurrentUserPayload,
@@ -241,7 +336,10 @@ async def get_customer_products(
customer = (await db.execute(stmt)).scalar_one_or_none()
if customer is None:
raise NotFoundException("客户不存在")
_check_access(customer, user)
dept_user_ids = None
if user.data_scope == "dept_and_sub" and user.dept_id:
dept_user_ids = await _get_dept_and_sub_user_ids(db, user.dept_id)
_check_access(customer, user, dept_user_ids=dept_user_ids)
# 聚合: 该客户所有订单中的 SKU,含总数量、最近下单时间
agg_stmt = (
@@ -299,12 +397,11 @@ async def search_customers(
base_where.append(CrmCustomer.owner_id == user.user_id)
elif user.data_scope == "dept_and_sub":
if user.dept_id is not None:
from app.models.sys import SysUser
sub = select(SysUser.id).where(
SysUser.dept_id == user.dept_id,
SysUser.is_deleted.is_(False),
)
base_where.append(CrmCustomer.owner_id.in_(sub))
dept_user_ids = await _get_dept_and_sub_user_ids(db, user.dept_id)
if dept_user_ids:
base_where.append(CrmCustomer.owner_id.in_(dept_user_ids))
else:
base_where.append(CrmCustomer.owner_id == user.user_id)
# 模糊搜索(名称 / 联系人 / 电话)
from sqlalchemy import or_
+11 -2
View File
@@ -9,6 +9,7 @@ from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import BizException, ForbiddenException, NotFoundException
from app.models.finance import FinExpenseDetail, FinExpenseRecord, FinInvoicePool
from app.models.sys import SysUser
from app.schemas.auth import CurrentUserPayload
from app.schemas.finance import (
ExpenseBriefResponse, ExpenseCreate, ExpenseDetailResponse,
@@ -84,12 +85,13 @@ async def _release_invoices(db: AsyncSession, expense_id: uuid.UUID, now: dateti
# ── Service Functions ────────────────────────────────────
async def create_invoice(db: AsyncSession, user: CurrentUserPayload, body: InvoiceCreate) -> InvoiceResponse:
async def create_invoice(db: AsyncSession, user: CurrentUserPayload, body: InvoiceCreate, company_id: uuid.UUID) -> InvoiceResponse:
invoice = FinInvoicePool(
uploader_id=user.user_id, file_url=body.file_url,
merchant_name=body.merchant_name, amount=body.amount,
invoice_date=body.invoice_date, type=body.type,
ai_extracted_data=body.ai_extracted_data, is_used=False,
company_id=company_id,
)
db.add(invoice)
await db.commit()
@@ -101,8 +103,11 @@ async def list_invoices(
db: AsyncSession, user: CurrentUserPayload,
page: int = 1, size: int = 20,
inv_type: str | None = None, is_used: bool | None = None,
company_id: uuid.UUID | None = None,
) -> InvoiceListResponse:
where = [FinInvoicePool.is_deleted.is_(False)]
if company_id:
where.append(FinInvoicePool.company_id == company_id)
if user.data_scope == "self":
where.append(FinInvoicePool.uploader_id == user.user_id)
elif user.data_scope == "dept_and_sub":
@@ -135,7 +140,7 @@ async def void_invoice(db: AsyncSession, user: CurrentUserPayload, invoice_id: u
await db.commit()
async def create_expense(db: AsyncSession, user: CurrentUserPayload, body: ExpenseCreate) -> ExpenseResponse:
async def create_expense(db: AsyncSession, user: CurrentUserPayload, body: ExpenseCreate, company_id: uuid.UUID) -> ExpenseResponse:
invoice_ids = [item.invoice_id for item in body.items]
try:
async with db.begin_nested():
@@ -154,6 +159,7 @@ async def create_expense(db: AsyncSession, user: CurrentUserPayload, body: Expen
system_no = await _generate_expense_no(db)
expense = FinExpenseRecord(
system_no=system_no, applicant_id=user.user_id,
company_id=company_id,
total_amount=body.total_amount, status="submitted", remark=body.remark,
)
db.add(expense)
@@ -184,8 +190,11 @@ async def list_expenses(
db: AsyncSession, user: CurrentUserPayload,
page: int = 1, size: int = 20,
status: str | None = None, applicant_id: uuid.UUID | None = None,
company_id: uuid.UUID | None = None,
) -> ExpenseListResponse:
where = [FinExpenseRecord.is_deleted.is_(False)]
if company_id:
where.append(FinExpenseRecord.company_id == company_id)
if user.data_scope == "self":
where.append(FinExpenseRecord.applicant_id == user.user_id)
elif user.data_scope == "dept_and_sub":
+211
View File
@@ -0,0 +1,211 @@
"""
发票结构化解析器 OFD / XML 零算力提取
OFD 文件本质是 ZIP 包含 XML直接解包提取发票字段
XML 电子发票数电票直接 XPath 提取
"""
from __future__ import annotations
import io
import os
import re
import zipfile
from xml.etree import ElementTree as ET
from typing import Optional
def parse_ofd_invoice(file_bytes: bytes) -> dict:
"""
解析 OFD 电子发票文件
OFD = ZIP 压缩包内含 XML 描述文件
提取发票关键字段返回结构化 dict
"""
result: dict = {}
try:
with zipfile.ZipFile(io.BytesIO(file_bytes)) as zf:
# 收集所有 XML 内容
all_text = ""
for name in zf.namelist():
if name.endswith(".xml"):
try:
xml_bytes = zf.read(name)
xml_text = xml_bytes.decode("utf-8", errors="replace")
all_text += xml_text + "\n"
# 尝试从 XML 标签中提取结构化数据
extracted = _extract_from_xml_text(xml_text)
if extracted:
result.update(extracted)
except Exception:
continue
# 如果解析出了字段就直接返回
if result.get("merchant") or result.get("amount"):
return {"success": True, "data": result}
# 降级:把所有 XML 文本当纯文本返回,交给 LLM 处理
if all_text.strip():
return {"success": True, "data": {"raw_text": all_text[:8000]}, "needs_llm": True}
return {"success": False, "data": {}, "error": "OFD 文件中未找到有效 XML 内容"}
except zipfile.BadZipFile:
return {"success": False, "data": {}, "error": "OFD 文件格式损坏或不是有效的 OFD 文件"}
except Exception as e:
return {"success": False, "data": {}, "error": f"OFD 解析失败: {e}"}
def parse_xml_invoice(file_bytes: bytes) -> dict:
"""
解析 XML 格式电子发票数电票
直接从 XML 标签提取所有发票字段
"""
try:
xml_text = file_bytes.decode("utf-8", errors="replace")
result = _extract_from_xml_text(xml_text)
if result and (result.get("merchant") or result.get("amount")):
return {"success": True, "data": result}
# 降级:XML 结构未匹配预设标签,交给 LLM
if xml_text.strip():
return {"success": True, "data": {"raw_text": xml_text[:8000]}, "needs_llm": True}
return {"success": False, "data": {}, "error": "XML 文件内容为空"}
except Exception as e:
return {"success": False, "data": {}, "error": f"XML 解析失败: {e}"}
def parse_zip_invoices(file_bytes: bytes) -> list[dict]:
"""
解析 ZIP 压缩包中的所有 XML 发票文件
返回列表每个元素 = {"filename": str, "success": bool, "data": dict, ...}
支持系统导出的 ZIP 格式内含多个 XML 发票
"""
results: list[dict] = []
try:
with zipfile.ZipFile(io.BytesIO(file_bytes)) as zf:
xml_names = [n for n in zf.namelist() if n.lower().endswith(".xml")]
if not xml_names:
return [{"filename": "(zip)", "success": False, "data": {}, "error": "ZIP 包中未找到 XML 文件"}]
for name in xml_names:
try:
xml_bytes = zf.read(name)
result = parse_xml_invoice(xml_bytes)
result["filename"] = os.path.basename(name)
results.append(result)
except Exception as e:
results.append({"filename": os.path.basename(name), "success": False, "data": {}, "error": str(e)})
except zipfile.BadZipFile:
return [{"filename": "(zip)", "success": False, "data": {}, "error": "不是有效的 ZIP 文件"}]
except Exception as e:
return [{"filename": "(zip)", "success": False, "data": {}, "error": f"ZIP 解析失败: {e}"}]
return results
# ── 内部工具函数 ──────────────────────────────────────
# 常见发票 XML 标签名映射(兼容多种数电票 XML 格式)
_FIELD_PATTERNS = {
"merchant": [
"SalesName", "SellerName", "销售方名称", "销方名称",
"开票方", "Seller", "salername", "xfmc",
],
"buyer": [
"BuyerName", "PurchaserName", "购买方名称", "购方名称",
"Buyer", "buyername", "gfmc",
],
"amount": [
"TotalAmount", "Amount", "InvoiceAmount", "金额",
"合计金额", "价税合计", "jshj", "hjje",
],
"tax_amount": [
"TotalTax", "TaxAmount", "Tax", "税额",
"合计税额", "hjse",
],
"date": [
"IssueDate", "InvoiceDate", "BillingDate", "开票日期",
"kprq",
],
"invoice_code": [
"InvoiceCode", "发票代码", "fpdm",
],
"invoice_number": [
"InvoiceNumber", "InvoiceNo", "发票号码", "fphm",
],
"items": [
"GoodsName", "ItemName", "商品名称", "货物名称", "spmc",
],
"tax_rate": [
"TaxRate", "税率", "sl",
],
"remark": [
"Remark", "备注", "bz",
],
}
def _extract_from_xml_text(xml_text: str) -> Optional[dict]:
"""从 XML 文本中用多种策略提取发票字段。"""
result: dict = {}
# 策略 1: 正则匹配 <TagName>Value</TagName> 格式
for field, tag_names in _FIELD_PATTERNS.items():
for tag in tag_names:
# 匹配 <Tag>value</Tag> 或 <ns:Tag>value</ns:Tag>
pattern = rf'<(?:\w+:)?{re.escape(tag)}[^>]*>([^<]+)</(?:\w+:)?{re.escape(tag)}>'
match = re.search(pattern, xml_text, re.IGNORECASE)
if match:
value = match.group(1).strip()
if value:
# 数字字段转数值
if field in ("amount", "tax_amount"):
try:
result[field] = float(value)
except ValueError:
result[field] = value
else:
result[field] = value
break # 找到一个就跳到下一个字段
# 策略 2: 尝试 ElementTree 解析
if not result:
try:
# 移除 XML 声明中可能的编码问题
cleaned = re.sub(r'<\?xml[^?]*\?>', '', xml_text).strip()
if cleaned:
root = ET.fromstring(cleaned)
_extract_from_element(root, result)
except ET.ParseError:
pass
return result if result else None
def _extract_from_element(elem: ET.Element, result: dict, depth: int = 0):
"""递归遍历 XML 元素树提取字段。"""
if depth > 10:
return
tag_local = elem.tag.split("}")[-1] if "}" in elem.tag else elem.tag
for field, tag_names in _FIELD_PATTERNS.items():
if field not in result:
for tn in tag_names:
if tag_local.lower() == tn.lower():
text = (elem.text or "").strip()
if text:
if field in ("amount", "tax_amount"):
try:
result[field] = float(text)
except ValueError:
result[field] = text
else:
result[field] = text
break
for child in elem:
_extract_from_element(child, result, depth + 1)
+22 -28
View File
@@ -72,11 +72,12 @@ async def ocr_image(
"messages": [
{
"role": "user",
"content": "/no_think\n" + prompt,
"content": prompt,
"images": [image_base64], # Ollama vision 格式
},
],
"stream": False,
"think": False, # 关闭思考模式:稳定输出、避免死循环、提速 2-5x
"options": {
"temperature": 0.1,
"num_predict": 2000,
@@ -87,19 +88,18 @@ async def ocr_image(
async with httpx.AsyncClient(timeout=120.0) as client:
resp = await client.post(url, json=payload)
if resp.status_code != 200:
print(f"[OCR] 3090 返回 {resp.status_code}: {resp.text[:200]}")
return {"success": False, "data": {}, "error": f"VL 模型返回 {resp.status_code}"}
detail = resp.text[:200]
print(f"[OCR] 3090 返回 {resp.status_code}: {detail}")
if "model runner" in detail:
return {"success": False, "data": {}, "error": "AI OCR 模型进程崩溃,请联系管理员重启 Ollama 服务"}
return {"success": False, "data": {}, "error": f"AI OCR 服务异常 (HTTP {resp.status_code}),请稍后重试"}
data = resp.json()
# Qwen3.5 的 CoT 推理放在 message.thinking,最终结果在 message.content
content = data.get("message", {}).get("content", "")
thinking = data.get("message", {}).get("thinking", "")
# 优先从 content 提取 JSON,回退到 thinking
for text_source in [content, thinking]:
if not text_source:
continue
cleaned = re.sub(r'<think>.*?</think>', '', text_source, flags=re.DOTALL).strip()
# 关闭思考模式后,结果直接在 content(无 thinking 字段)
if content:
cleaned = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
json_match = re.search(r'\{[\s\S]*\}', cleaned)
if json_match:
try:
@@ -107,16 +107,14 @@ async def ocr_image(
print(f"[OCR] 解析成功: {list(result.keys())}")
return {"success": True, "data": result}
except json.JSONDecodeError:
continue
pass
# 没有提取 JSON,返回原始文本
raw = content or thinking
print(f"[OCR] 未能提取 JSON, 内容长度: content={len(content)}, thinking={len(thinking)}")
return {"success": True, "data": {"raw_text": raw[:2000]}}
print(f"[OCR] 未能提取 JSON, content 长度: {len(content)}")
return {"success": True, "data": {"raw_text": content[:2000]}}
except httpx.TimeoutException:
print("[OCR] 3090 超时(60s")
return {"success": False, "data": {}, "error": "VL 模型响应超时"}
print("[OCR] 3090 超时(120s")
return {"success": False, "data": {}, "error": "AI OCR 响应超时(120s),模型可能负载过高,请稍后重试"}
except json.JSONDecodeError as e:
print(f"[OCR] JSON 解析失败: {e}")
return {"success": False, "data": {}, "error": f"JSON 解析失败: {e}"}
@@ -172,11 +170,11 @@ async def extract_invoice_from_text(
"messages": [
{
"role": "user",
"content": f"/no_think\n{prompt}\n\n--- 以下是发票文本内容 ---\n\n{truncated}",
# 不传 images —— 纯文本模式
"content": f"{prompt}\n\n--- 以下是发票文本内容 ---\n\n{truncated}",
},
],
"stream": False,
"think": False, # 关闭思考模式
"options": {
"temperature": 0.1,
"num_predict": 2000,
@@ -192,12 +190,9 @@ async def extract_invoice_from_text(
data = resp.json()
content = data.get("message", {}).get("content", "")
thinking = data.get("message", {}).get("thinking", "")
for text_source in [content, thinking]:
if not text_source:
continue
cleaned = re.sub(r'<think>.*?</think>', '', text_source, flags=re.DOTALL).strip()
if content:
cleaned = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
json_match = re.search(r'\{[\s\S]*\}', cleaned)
if json_match:
try:
@@ -205,11 +200,10 @@ async def extract_invoice_from_text(
print(f"[TextExtract] AI 提取成功: {list(result.keys())}")
return {"success": True, "data": result}
except json.JSONDecodeError:
continue
pass
raw = content or thinking
print(f"[TextExtract] 未能提取 JSON, 内容: {raw[:200]}")
return {"success": True, "data": {"raw_text": raw[:2000]}}
print(f"[TextExtract] 未能提取 JSON, content: {content[:200]}")
return {"success": True, "data": {"raw_text": content[:2000]}}
except httpx.TimeoutException:
print("[TextExtract] 3090 超时")
+266
View File
@@ -0,0 +1,266 @@
"""
OCR 后台 Worker asyncio 协程FastAPI lifespan 启动
策略 C: 工作时间限流(1并发 + 60s间隔)17:00-20:00 BJT 全速
"""
from __future__ import annotations
import asyncio
import os
import uuid
from datetime import datetime, timedelta
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import async_session_factory
from app.models.finance import FinInvoicePool, FinOcrTask
class OcrWorker:
"""后台 OCR 任务处理器"""
def __init__(self):
self.running = False
self.current_task_id: uuid.UUID | None = None
self._task: asyncio.Task | None = None
def start(self):
self.running = True
self._task = asyncio.create_task(self._run_loop())
print("[OcrWorker] 启动 — 策略 C: 工作时间限流, 17-20 BJT 全速")
async def stop(self):
self.running = False
if self._task:
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
print("[OcrWorker] 已停止")
async def _run_loop(self):
"""主循环:每 10 秒检查一次队列"""
while self.running:
try:
task = await self._pick_next_task()
if task:
await self._process_task(task)
# 限流:非高峰期间隔 60s
if not self._is_peak_time():
await asyncio.sleep(60)
else:
await asyncio.sleep(5)
else:
await asyncio.sleep(10)
except asyncio.CancelledError:
break
except Exception as e:
print(f"[OcrWorker] 循环异常: {e}")
await asyncio.sleep(30)
def _is_peak_time(self) -> bool:
"""17:00-20:00 BJT = 09:00-12:00 UTC"""
utc_hour = datetime.utcnow().hour
return 9 <= utc_hour < 12
async def _pick_next_task(self) -> dict | None:
"""从 DB 获取优先级最高的 pending 任务"""
async with async_session_factory() as db:
stmt = (
select(FinOcrTask)
.where(
FinOcrTask.status == "pending",
FinOcrTask.is_deleted.is_(False),
FinOcrTask.retry_count < FinOcrTask.max_retries,
)
.order_by(FinOcrTask.priority, FinOcrTask.created_at)
.limit(1)
)
task = (await db.execute(stmt)).scalar_one_or_none()
if not task:
return None
# 标记为 processing
task.status = "processing"
task.updated_at = datetime.utcnow()
await db.commit()
self.current_task_id = task.id
return {
"id": task.id,
"file_url": task.file_url,
"file_ext": task.file_ext,
"original_name": task.original_name,
"uploader_id": task.uploader_id,
"company_id": task.company_id,
"inv_type": task.inv_type,
"retry_count": task.retry_count,
}
async def _process_task(self, task_info: dict):
"""执行 OCR 并更新"""
task_id = task_info["id"]
file_url = task_info["file_url"]
file_ext = task_info["file_ext"]
print(f"[OcrWorker] 处理任务 {task_id} ({task_info['original_name']}, {file_ext})")
try:
# 读取文件
file_path = file_url.lstrip("/")
if not os.path.exists(file_path):
await self._mark_failed(task_id, f"文件不存在: {file_path}")
return
with open(file_path, "rb") as f:
file_bytes = f.read()
ocr_data = {}
message = ""
# PDF 处理
if file_ext == ".pdf":
ocr_data, message = await self._process_pdf(file_bytes)
# 图片处理
elif file_ext in (".png", ".jpg", ".jpeg"):
ocr_data, message = await self._process_image(file_bytes)
else:
await self._mark_failed(task_id, f"不支持的文件格式: {file_ext}")
return
if ocr_data and (ocr_data.get("merchant") or ocr_data.get("amount")):
# OCR 成功 → 自动入池
await self._mark_success_and_pool(task_id, task_info, ocr_data)
print(f"[OcrWorker] ✅ {task_info['original_name']} 入池成功")
else:
# OCR 完成但没提取到关键字段
await self._mark_failed(
task_id,
message or "AI 未能提取发票关键字段(开票方/金额),请手动录入",
ocr_data,
)
except Exception as e:
print(f"[OcrWorker] ❌ 任务 {task_id} 异常: {e}")
await self._mark_failed(task_id, str(e))
self.current_task_id = None
async def _process_pdf(self, file_bytes: bytes) -> tuple[dict, str]:
"""PDF: 先尝试文本提取,失败降级 Vision OCR"""
try:
import fitz
doc = fitz.open(stream=file_bytes, filetype="pdf")
text = ""
for page in doc:
text += page.get_text() + "\n"
doc.close()
text = text.strip()
if len(text) > 50:
from app.services.ocr_service import extract_invoice_from_text
result = await extract_invoice_from_text(text, "invoice")
if result.get("success") and result.get("data"):
return result["data"], "PDF 文本解析成功"
# 降级: 扫描件 → Vision OCR
doc2 = fitz.open(stream=file_bytes, filetype="pdf")
pix = doc2[0].get_pixmap(dpi=150)
ocr_bytes = pix.tobytes("png")
doc2.close()
return await self._vision_ocr(ocr_bytes)
except Exception as e:
return {}, f"PDF 处理失败: {e}"
async def _process_image(self, file_bytes: bytes) -> tuple[dict, str]:
"""图片: Vision OCR"""
return await self._vision_ocr(file_bytes)
async def _vision_ocr(self, image_bytes: bytes) -> tuple[dict, str]:
"""调用 3090 Vision OCR"""
import base64
from app.services.ocr_service import ocr_image
image_b64 = base64.b64encode(image_bytes).decode("utf-8")
result = await ocr_image(image_b64, "invoice")
if result.get("success"):
return result.get("data", {}), "Vision OCR 成功"
return {}, result.get("error", "OCR 失败")
async def _mark_success_and_pool(self, task_id: uuid.UUID, task_info: dict, ocr_data: dict):
"""标记成功 + 自动入池"""
async with async_session_factory() as db:
merchant = ocr_data.get("merchant") or ocr_data.get("merchant_name") or "AI 提取)"
amount = 0
try:
amount = float(ocr_data.get("amount", 0))
except (ValueError, TypeError):
pass
invoice_date_str = ocr_data.get("date")
invoice_date = None
if invoice_date_str:
try:
from datetime import date as dt_date
invoice_date = dt_date.fromisoformat(invoice_date_str)
except ValueError:
pass
inv = FinInvoicePool(
uploader_id=task_info["uploader_id"],
company_id=task_info["company_id"],
file_url=task_info["file_url"],
merchant_name=merchant,
amount=amount,
invoice_date=invoice_date,
type=task_info["inv_type"],
ai_extracted_data=ocr_data,
is_used=False,
)
db.add(inv)
await db.flush()
await db.execute(
update(FinOcrTask)
.where(FinOcrTask.id == task_id)
.values(
status="success",
ocr_result=ocr_data,
invoice_pool_id=inv.id,
error_message=None,
updated_at=datetime.utcnow(),
)
)
await db.commit()
async def _mark_failed(self, task_id: uuid.UUID, error: str, partial_data: dict | None = None):
"""标记失败 + retry_count+1"""
async with async_session_factory() as db:
task = (await db.execute(
select(FinOcrTask).where(FinOcrTask.id == task_id)
)).scalar_one_or_none()
if not task:
return
new_retry = task.retry_count + 1
new_status = "failed" if new_retry >= task.max_retries else "pending"
await db.execute(
update(FinOcrTask)
.where(FinOcrTask.id == task_id)
.values(
status=new_status,
retry_count=new_retry,
error_message=error,
ocr_result=partial_data or task.ocr_result,
updated_at=datetime.utcnow(),
)
)
await db.commit()
if new_status == "pending":
print(f"[OcrWorker] ⚠️ 任务 {task_id}{new_retry} 次重试入队")
else:
print(f"[OcrWorker] ❌ 任务 {task_id} 已达最大重试次数,标记失败")
# 单例
ocr_worker = OcrWorker()
+14 -4
View File
@@ -16,6 +16,7 @@ from app.core.exceptions import BizException, ForbiddenException, NotFoundExcept
from app.models.crm import CrmCustomer
from app.models.erp import ProductSku
from app.models.order import ErpOrder, ErpOrderItem
from app.models.sys import SysUser
from app.schemas.auth import CurrentUserPayload
from app.schemas.order import (
OrderBriefResponse,
@@ -156,6 +157,7 @@ async def create_order(
db: AsyncSession,
user: CurrentUserPayload,
body: OrderCreate,
company_id: uuid.UUID,
) -> OrderResponse:
# 校验客户存在
cust = (
@@ -193,6 +195,7 @@ async def create_order(
order_no=order_no,
customer_id=body.customer_id,
salesperson_id=user.user_id,
company_id=company_id,
total_amount=total,
shipping_state="pending",
payment_state="unpaid",
@@ -236,8 +239,11 @@ async def list_orders(
shipping_state: str | None = None,
payment_state: str | None = None,
keyword: str | None = None,
company_id: uuid.UUID | None = None,
) -> OrderListResponse:
where: list[Any] = [ErpOrder.is_deleted.is_(False)]
if company_id:
where.append(ErpOrder.company_id == company_id)
if user.data_scope == "self":
where.append(ErpOrder.salesperson_id == user.user_id)
@@ -284,13 +290,17 @@ async def get_order(
db: AsyncSession,
user: CurrentUserPayload,
order_id: uuid.UUID,
company_id: uuid.UUID | None = None,
) -> OrderResponse:
where_clause = [
ErpOrder.id == order_id,
ErpOrder.is_deleted.is_(False),
]
if company_id:
where_clause.append(ErpOrder.company_id == company_id)
order = (
await db.execute(
select(ErpOrder).where(
ErpOrder.id == order_id,
ErpOrder.is_deleted.is_(False),
)
select(ErpOrder).where(*where_clause)
)
).scalar_one_or_none()
if order is None:
+91 -26
View File
@@ -14,7 +14,8 @@ from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import BizException, NotFoundException
from app.models.erp import InventoryFlow, ProductCategory, ProductSku
from app.models.erp import ErpSkuInventory, InventoryFlow, ProductCategory, ProductSku
from app.models.sys import SysUser
from app.schemas.auth import CurrentUserPayload
from app.schemas.erp import (
CategoryCreate,
@@ -31,7 +32,10 @@ from app.schemas.erp import (
# ── ORM → Response ───────────────────────────────────────
def _sku_to_response(s: ProductSku) -> SkuResponse:
def _sku_to_response(
s: ProductSku,
inv: ErpSkuInventory | None = None,
) -> SkuResponse:
return SkuResponse(
id=s.id,
sku_code=s.sku_code,
@@ -40,8 +44,8 @@ def _sku_to_response(s: ProductSku) -> SkuResponse:
category_name=s.category.name if s.category else None,
spec=s.spec,
standard_price=float(s.standard_price or 0),
stock_qty=float(s.stock_qty or 0),
warning_threshold=float(s.warning_threshold or 0),
stock_qty=float(inv.stock_qty) if inv else 0.0,
warning_threshold=float(inv.warning_threshold) if inv else 0.0,
unit=s.unit,
status=s.status,
created_at=s.created_at,
@@ -200,11 +204,13 @@ async def delete_category(db: AsyncSession, cat_id: uuid.UUID) -> None:
async def list_skus(
db: AsyncSession,
company_id: uuid.UUID,
page: int = 1,
size: int = 20,
category_id: uuid.UUID | None = None,
keyword: str | None = None,
) -> SkuListResponse:
"""LEFT JOIN erp_sku_inventory 获取当前公司库存,COALESCE 兜底为 0"""
where: list[Any] = [ProductSku.is_deleted.is_(False)]
if category_id:
where.append(ProductSku.category_id == category_id)
@@ -218,24 +224,31 @@ async def list_skus(
await db.execute(select(func.count()).select_from(ProductSku).where(*where))
).scalar() or 0
# LEFT JOIN erp_sku_inventory 带出当前公司库存
stmt = (
select(ProductSku)
select(ProductSku, ErpSkuInventory)
.outerjoin(
ErpSkuInventory,
(ErpSkuInventory.sku_id == ProductSku.id)
& (ErpSkuInventory.company_id == company_id),
)
.where(*where)
.order_by(ProductSku.created_at.desc())
.offset((page - 1) * size)
.limit(size)
)
rows = (await db.execute(stmt)).scalars().all()
rows = (await db.execute(stmt)).all()
return SkuListResponse(
total=total,
items=[_sku_to_response(s) for s in rows],
items=[_sku_to_response(sku, inv) for sku, inv in rows],
page=page,
size=size,
)
async def create_sku(db: AsyncSession, body: SkuCreate) -> SkuResponse:
"""创建 SKU(不创建库存行,LEFT JOIN 查询自动兜底为 0)"""
exists = (
await db.execute(
select(ProductSku.id).where(
@@ -253,8 +266,6 @@ async def create_sku(db: AsyncSession, body: SkuCreate) -> SkuResponse:
category_id=body.category_id,
spec=body.spec,
standard_price=body.standard_price,
stock_qty=body.stock_qty,
warning_threshold=body.warning_threshold,
unit=body.unit,
status=body.status,
)
@@ -299,7 +310,9 @@ async def create_inventory_flow(
db: AsyncSession,
user: CurrentUserPayload,
body: InventoryFlowCreate,
company_id: uuid.UUID,
) -> InventoryFlowResponse:
"""库存变更(upsert erp_sku_inventory + 写流水)"""
sku = (
await db.execute(
select(ProductSku).where(
@@ -310,35 +323,74 @@ async def create_inventory_flow(
if sku is None:
raise NotFoundException("产品 SKU 不存在")
if body.change_qty < 0:
current_stock = float(sku.stock_qty or 0)
if current_stock + body.change_qty < 0:
raise BizException(
message=f"库存不足:当前库存 {current_stock},请求出库 {abs(body.change_qty)}"
)
try:
async with db.begin_nested():
# ── upsert: 查找或创建当前公司的库存行 ──
inv = (
await db.execute(
select(ErpSkuInventory)
.where(
ErpSkuInventory.sku_id == body.sku_id,
ErpSkuInventory.company_id == company_id,
)
.with_for_update()
)
).scalar_one_or_none()
if inv is None:
# 首次操作该 SKU:自动创建 0 库存行
inv = ErpSkuInventory(
sku_id=body.sku_id,
company_id=company_id,
stock_qty=0,
warning_threshold=0,
)
db.add(inv)
await db.flush()
# 重新锁行
inv = (
await db.execute(
select(ErpSkuInventory)
.where(ErpSkuInventory.id == inv.id)
.with_for_update()
)
).scalar_one()
# ── 校验库存 ──
current_stock = float(inv.stock_qty or 0)
if body.change_qty < 0 and current_stock + body.change_qty < 0:
raise BizException(
message=f"库存不足:当前库存 {current_stock},请求出库 {abs(body.change_qty)}"
)
# ── 更新库存 ──
await db.execute(
update(ErpSkuInventory)
.where(ErpSkuInventory.id == inv.id)
.values(
stock_qty=ErpSkuInventory.stock_qty + Decimal(str(body.change_qty)),
updated_at=datetime.utcnow(),
)
)
# ── 写流水 ──
flow = InventoryFlow(
sku_id=body.sku_id,
company_id=company_id,
change_qty=body.change_qty,
reason=body.reason,
remark=body.remark,
purchase_unit_price=body.purchase_unit_price if body.change_qty > 0 else 0,
is_special_zero_cost=body.is_special_zero_cost if body.change_qty > 0 else False,
operator_id=user.user_id,
)
db.add(flow)
await db.flush()
await db.execute(
update(ProductSku)
.where(ProductSku.id == body.sku_id)
.values(
stock_qty=ProductSku.stock_qty + Decimal(str(body.change_qty)),
updated_at=datetime.utcnow(),
)
)
await db.commit()
except BizException:
await db.rollback()
raise
except Exception as e:
await db.rollback()
raise BizException(code=500, message=f"库存变更事务失败: {e!s}") from e
@@ -352,9 +404,11 @@ async def create_inventory_flow(
async def get_inventory_flows(
db: AsyncSession,
sku_id: uuid.UUID,
company_id: uuid.UUID,
page: int = 1,
size: int = 50,
) -> dict[str, Any]:
"""获取单个 SKU 在当前公司的库存流水"""
sku = (
await db.execute(
select(ProductSku).where(
@@ -365,8 +419,19 @@ async def get_inventory_flows(
if sku is None:
raise NotFoundException("产品 SKU 不存在")
# 查当前公司库存
inv = (
await db.execute(
select(ErpSkuInventory).where(
ErpSkuInventory.sku_id == sku_id,
ErpSkuInventory.company_id == company_id,
)
)
).scalar_one_or_none()
where: list[Any] = [
InventoryFlow.sku_id == sku_id,
InventoryFlow.company_id == company_id,
InventoryFlow.is_deleted.is_(False),
]
@@ -389,7 +454,7 @@ async def get_inventory_flows(
"total": total,
"sku_code": sku.sku_code,
"sku_name": sku.name,
"current_stock": float(sku.stock_qty or 0),
"current_stock": float(inv.stock_qty) if inv else 0.0,
"items": [_flow_to_response(f).model_dump(mode="json") for f in flows],
"page": page,
"size": size,
+226
View File
@@ -0,0 +1,226 @@
"""
库存与利润核算 Service
- MWA 入库事务悲观锁 FOR UPDATE + 零元隔离
- 订单利润快照
- 利润报表聚合
"""
from __future__ import annotations
import uuid
from datetime import datetime
from sqlalchemy import func, select, update, text
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import BizException, NotFoundException
from app.models.erp import ErpSkuInventory, InventoryFlow, ProductSku
from app.models.cost import ErpOrderItemCost
from app.models.order import ErpOrder, ErpOrderItem
from app.schemas.auth import CurrentUserPayload
# ── MWA 入库事务 ────────────────────────────────────────
async def process_inbound_with_mwa(
db: AsyncSession,
sku_id: uuid.UUID,
company_id: uuid.UUID,
qty: float,
purchase_unit_price: float,
operator_id: uuid.UUID | None = None,
remark: str | None = None,
is_special_zero_cost: bool = False,
) -> dict:
"""
入库事务悲观锁 + MWA
1. SELECT ... FOR UPDATE 锁定库存行
2. 如果非零元特殊计算新 MWA
3. 更新库存 + 记录流水
"""
# 悲观锁获取库存记录
inv_stmt = (
select(ErpSkuInventory)
.where(
ErpSkuInventory.sku_id == sku_id,
ErpSkuInventory.company_id == company_id,
)
.with_for_update()
)
inv = (await db.execute(inv_stmt)).scalar_one_or_none()
if inv is None:
# 首次入库,创建库存记录
inv = ErpSkuInventory(
sku_id=sku_id,
company_id=company_id,
stock_qty=0,
mwa_unit_cost=0,
)
db.add(inv)
await db.flush()
# 重新锁定
inv = (await db.execute(inv_stmt)).scalar_one()
old_qty = float(inv.stock_qty or 0)
old_mwa = float(inv.mwa_unit_cost or 0)
new_qty = old_qty + qty
# MWA 计算(零元特殊入库不参与)
if is_special_zero_cost or purchase_unit_price == 0:
new_mwa = old_mwa # 保持原有 MWA
else:
if new_qty > 0:
new_mwa = (old_qty * old_mwa + qty * purchase_unit_price) / new_qty
else:
new_mwa = purchase_unit_price
# 更新库存
inv.stock_qty = new_qty
inv.mwa_unit_cost = round(new_mwa, 4)
inv.updated_at = datetime.utcnow()
# 记录流水
flow = InventoryFlow(
sku_id=sku_id,
company_id=company_id,
flow_type="in",
change_qty=qty,
reason="purchase_in",
purchase_unit_price=purchase_unit_price,
is_special_zero_cost=is_special_zero_cost,
operator_id=operator_id,
remark=remark or f"入库 {qty} 件 @ ¥{purchase_unit_price}",
)
db.add(flow)
await db.commit()
return {
"sku_id": str(sku_id),
"old_qty": old_qty,
"new_qty": new_qty,
"old_mwa": old_mwa,
"new_mwa": round(new_mwa, 4),
"is_special_zero_cost": is_special_zero_cost,
}
# ── 订单明细成本快照 ────────────────────────────────────
async def snapshot_order_item_costs(
db: AsyncSession,
order_id: uuid.UUID,
company_id: uuid.UUID,
) -> list[dict]:
"""为订单的所有明细行锚定 MWA 成本快照"""
items_stmt = select(ErpOrderItem).where(
ErpOrderItem.order_id == order_id,
ErpOrderItem.is_deleted.is_(False),
)
items = (await db.execute(items_stmt)).scalars().all()
results = []
for item in items:
# 查当前 MWA
inv = (await db.execute(
select(ErpSkuInventory).where(
ErpSkuInventory.sku_id == item.sku_id,
ErpSkuInventory.company_id == company_id,
)
)).scalar_one_or_none()
mwa_cost = float(inv.mwa_unit_cost or 0) if inv else 0
sell_price = float(item.unit_price or 0)
qty = float(item.qty or 0)
profit = (sell_price - mwa_cost) * qty
profit_rate = (sell_price - mwa_cost) / sell_price if sell_price > 0 else 0
# 检查是否已有快照
existing = (await db.execute(
select(ErpOrderItemCost).where(
ErpOrderItemCost.order_item_id == item.id
)
)).scalar_one_or_none()
if existing:
existing.purchase_unit_price = mwa_cost
existing.profit_amount = round(profit, 2)
existing.profit_rate = round(profit_rate, 4)
else:
cost_snap = ErpOrderItemCost(
order_item_id=item.id,
purchase_unit_price=mwa_cost,
profit_amount=round(profit, 2),
profit_rate=round(profit_rate, 4),
)
db.add(cost_snap)
results.append({
"sku_id": str(item.sku_id),
"qty": qty,
"sell_price": sell_price,
"mwa_cost": mwa_cost,
"profit": round(profit, 2),
"profit_rate": round(profit_rate * 100, 2),
})
await db.commit()
return results
# ── 利润报表 ────────────────────────────────────────────
async def get_profit_report(
db: AsyncSession,
company_id: uuid.UUID,
start_date: str | None = None,
end_date: str | None = None,
) -> dict:
"""聚合利润报表"""
base_where = [
ErpOrder.company_id == company_id,
ErpOrder.is_deleted.is_(False),
]
if start_date:
base_where.append(ErpOrder.order_date >= start_date)
if end_date:
base_where.append(ErpOrder.order_date <= end_date)
# 聚合:每笔订单的利润
stmt = (
select(
ErpOrder.id.label("order_id"),
ErpOrder.order_no,
ErpOrder.order_date,
ErpOrder.total_amount,
func.sum(ErpOrderItemCost.profit_amount).label("total_profit"),
)
.join(ErpOrderItem, ErpOrderItem.order_id == ErpOrder.id)
.join(ErpOrderItemCost, ErpOrderItemCost.order_item_id == ErpOrderItem.id)
.where(*base_where)
.group_by(ErpOrder.id, ErpOrder.order_no, ErpOrder.order_date, ErpOrder.total_amount)
.order_by(ErpOrder.order_date.desc())
)
rows = (await db.execute(stmt)).all()
orders = []
total_revenue = 0
total_profit = 0
for r in rows:
revenue = float(r.total_amount or 0)
profit = float(r.total_profit or 0)
total_revenue += revenue
total_profit += profit
orders.append({
"order_id": str(r.order_id),
"order_no": r.order_no,
"order_date": r.order_date.isoformat() if r.order_date else None,
"revenue": revenue,
"profit": profit,
"profit_rate": round(profit / revenue * 100, 2) if revenue > 0 else 0,
})
return {
"total_revenue": round(total_revenue, 2),
"total_profit": round(total_profit, 2),
"overall_profit_rate": round(total_profit / total_revenue * 100, 2) if total_revenue > 0 else 0,
"orders": orders,
}
+10 -1
View File
@@ -12,6 +12,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import BizException, NotFoundException
from app.models.finance import FinSalesInvoice
from app.models.sys import SysUser
from app.models.crm import CrmCustomer
from app.schemas.auth import CurrentUserPayload
from app.schemas.sales_invoice import (
SalesInvoiceCreate,
@@ -45,6 +47,7 @@ async def create_invoice(
db: AsyncSession,
user: CurrentUserPayload,
body: SalesInvoiceCreate,
company_id: uuid.UUID | None = None,
) -> SalesInvoiceResponse:
# 检查发票号唯一性
existing = (await db.execute(
@@ -56,7 +59,7 @@ async def create_invoice(
if existing:
raise BizException(message=f"发票号 {body.invoice_number} 已存在")
inv = FinSalesInvoice(
kwargs: dict = dict(
issuer=body.issuer,
receiver_customer_id=body.receiver_customer_id,
invoice_number=body.invoice_number,
@@ -65,6 +68,9 @@ async def create_invoice(
remark=body.remark,
created_by=user.user_id,
)
if company_id is not None:
kwargs["company_id"] = company_id
inv = FinSalesInvoice(**kwargs)
db.add(inv)
await db.commit()
await db.refresh(inv)
@@ -80,8 +86,11 @@ async def list_invoices(
payment_status: str | None = None,
start_date: date | None = None,
end_date: date | None = None,
company_id: uuid.UUID | None = None,
) -> SalesInvoiceListResponse:
conditions = [FinSalesInvoice.is_deleted.is_(False)]
if company_id:
conditions.append(FinSalesInvoice.company_id == company_id)
if invoice_number:
conditions.append(FinSalesInvoice.invoice_number.ilike(f"%{invoice_number}%"))
+99 -5
View File
@@ -22,6 +22,7 @@ async def create_log(
customer_id: str | None = None,
contact_ids: list[str] | None = None,
log_date: date | None = None,
company_ids: list[uuid.UUID] | None = None,
) -> dict:
"""创建销售日志"""
log = SalesLog(
@@ -30,6 +31,7 @@ async def create_log(
contact_ids=contact_ids or [],
content=content,
log_date=log_date or date.today(),
involved_company_ids=company_ids or [],
)
db.add(log)
await db.commit()
@@ -46,9 +48,17 @@ async def list_logs(
user_id: str | None = None,
start_date: str | None = None,
end_date: str | None = None,
company_id: uuid.UUID | None = None,
) -> dict:
"""查询销售日志列表"""
"""查询销售日志列表(按 involved_company_ids 包含过滤)"""
from sqlalchemy.orm import aliased
from app.models.crm import CrmCustomer
from app.models.sys import SysUser
conditions = [SalesLog.is_deleted.is_(False)]
if company_id:
# ARRAY contains: 过滤涉及当前公司的日志
conditions.append(SalesLog.involved_company_ids.any(company_id))
# 数据权限
if user.data_scope == "self":
@@ -69,24 +79,107 @@ async def list_logs(
count_stmt = select(func.count()).select_from(SalesLog).where(where)
total = (await db.execute(count_stmt)).scalar() or 0
# data
# data — LEFT JOIN customer + user to get names
Author = aliased(SysUser)
stmt = (
select(SalesLog)
select(
SalesLog,
CrmCustomer.name.label("customer_name"),
Author.real_name.label("author_name"),
)
.outerjoin(CrmCustomer, SalesLog.customer_id == CrmCustomer.id)
.outerjoin(Author, SalesLog.salesperson_id == Author.id)
.where(where)
.order_by(desc(SalesLog.created_at))
.offset((page - 1) * size)
.limit(size)
)
rows = (await db.execute(stmt)).scalars().all()
rows = (await db.execute(stmt)).all()
items = []
for log, cust_name, auth_name in rows:
d = _to_dict(log)
d["customer_name"] = cust_name
d["author_name"] = auth_name
items.append(d)
return {
"total": total,
"page": page,
"size": size,
"items": [_to_dict(r) for r in rows],
"items": items,
}
async def update_log(
db: AsyncSession,
user: CurrentUserPayload,
log_id: uuid.UUID,
content: str | None = None,
customer_id: str | None = None,
contact_ids: list[str] | None = None,
log_date: str | None = None,
company_id: uuid.UUID | None = None,
) -> dict:
"""编辑销售日志 — 员工只能改自己的,管理员可改所有"""
from app.models.crm import CrmCustomer
from app.models.sys import SysUserCompany
log = await db.get(SalesLog, log_id)
if not log or log.is_deleted:
raise Exception("日志不存在")
# 权限检查
if user.data_scope != "all" and log.salesperson_id != user.user_id:
raise Exception("您无权编辑此日志")
if content is not None:
log.content = content
if contact_ids is not None:
log.contact_ids = contact_ids
if log_date is not None:
log.log_date = date.fromisoformat(log_date)
# 更新客户关联 + 自动重算 involved_company_ids
if customer_id is not None:
log.customer_id = uuid.UUID(customer_id) if customer_id else None
# 重新关联公司
resolved = set(log.involved_company_ids or [])
if company_id:
resolved.add(company_id)
if customer_id:
cust = await db.get(CrmCustomer, uuid.UUID(customer_id))
if cust and cust.owner_id:
stmt = select(SysUserCompany.company_id).where(
SysUserCompany.user_id == cust.owner_id
)
rows = (await db.execute(stmt)).scalars().all()
for cid in rows:
resolved.add(cid)
log.involved_company_ids = list(resolved)
await db.commit()
await db.refresh(log)
return _to_dict(log)
async def delete_log(
db: AsyncSession,
user: CurrentUserPayload,
log_id: uuid.UUID,
) -> None:
"""软删除销售日志 — 员工只能删自己的,管理员可删所有"""
log = await db.get(SalesLog, log_id)
if not log or log.is_deleted:
raise Exception("日志不存在")
if user.data_scope != "all" and log.salesperson_id != user.user_id:
raise Exception("您无权删除此日志")
log.is_deleted = True
await db.commit()
async def trigger_persona_workflow(
log_id: uuid.UUID,
customer_id: uuid.UUID,
@@ -157,6 +250,7 @@ def _to_dict(log: SalesLog) -> dict:
"salesperson_id": str(log.salesperson_id),
"customer_id": str(log.customer_id) if log.customer_id else None,
"contact_ids": log.contact_ids or [],
"involved_company_ids": [str(c) for c in (log.involved_company_ids or [])],
"content": log.content,
"log_date": log.log_date.isoformat() if log.log_date else None,
"ai_processed": log.ai_processed,
+49 -15
View File
@@ -10,9 +10,11 @@ from typing import Any
from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import BizException, ForbiddenException, NotFoundException
from app.models.erp import InventoryFlow, ProductSku
from app.models.erp import ErpSkuInventory, InventoryFlow, ProductSku
from app.models.order import ErpOrder, ErpOrderItem
from app.models.shipping import ErpShippingItem, ErpShippingRecord
from app.models.sys import SysUser
from app.models.crm import CrmCustomer
from app.schemas.auth import CurrentUserPayload
from app.schemas.shipping import (
ShippingBriefResponse, ShippingCreate, ShippingItemResponse,
@@ -75,10 +77,15 @@ def _check_shipping_access(order: ErpOrder, user: CurrentUserPayload) -> None:
async def create_shipping(
db: AsyncSession, user: CurrentUserPayload, body: ShippingCreate,
company_id: uuid.UUID,
) -> tuple[ShippingResponse, str]:
"""返回 (response, new_shipping_state)"""
"""返回 (response, new_shipping_state)。库存从 erp_sku_inventory 扣减"""
order = (await db.execute(
select(ErpOrder).where(ErpOrder.id == body.order_id, ErpOrder.is_deleted.is_(False))
select(ErpOrder).where(
ErpOrder.id == body.order_id,
ErpOrder.is_deleted.is_(False),
ErpOrder.company_id == company_id,
)
)).scalar_one_or_none()
if order is None:
raise NotFoundException("订单不存在")
@@ -114,6 +121,7 @@ async def create_shipping(
carrier=body.carrier, tracking_no=body.tracking_no,
status="transit", ship_date=body.ship_date or date.today(),
remark=body.remark, operator_id=user.user_id,
company_id=company_id,
)
db.add(record)
await db.flush()
@@ -125,22 +133,41 @@ async def create_shipping(
)
db.add(si)
result = await db.execute(
update(ProductSku).where(
ProductSku.id == item.sku_id,
ProductSku.stock_qty >= item.shipped_qty,
).values(
stock_qty=ProductSku.stock_qty - Decimal(str(item.shipped_qty)),
# ── 从 erp_sku_inventory 扣减库存(行锁) ──
inv = (
await db.execute(
select(ErpSkuInventory)
.where(
ErpSkuInventory.sku_id == item.sku_id,
ErpSkuInventory.company_id == company_id,
)
.with_for_update()
)
).scalar_one_or_none()
current_stock = float(inv.stock_qty) if inv else 0
if current_stock < item.shipped_qty:
raise BizException(
message=f"库存不足无法发货: SKU {item.sku_id}"
f"当前库存 {current_stock},请求出库 {item.shipped_qty}"
)
if inv is None:
# 不应出现此情况,但防御性处理
raise BizException(message=f"SKU {item.sku_id} 在当前公司无库存记录")
await db.execute(
update(ErpSkuInventory)
.where(ErpSkuInventory.id == inv.id)
.values(
stock_qty=ErpSkuInventory.stock_qty - Decimal(str(item.shipped_qty)),
updated_at=now,
)
)
if result.rowcount == 0:
sku = (await db.execute(select(ProductSku).where(ProductSku.id == item.sku_id))).scalar_one_or_none()
current_stock = float(sku.stock_qty) if sku else 0
raise BizException(message=f"库存不足无法发货: SKU {item.sku_id},当前库存 {current_stock},请求出库 {item.shipped_qty}")
db.add(InventoryFlow(
sku_id=item.sku_id, change_qty=-item.shipped_qty,
sku_id=item.sku_id, company_id=company_id,
change_qty=-item.shipped_qty,
reason="shipment", remark=f"订单发货出库 - 发货单 {shipping_no}",
operator_id=user.user_id,
))
@@ -178,8 +205,11 @@ async def list_shipping(
db: AsyncSession, user: CurrentUserPayload,
page: int = 1, size: int = 20,
order_no: str | None = None, tracking_no: str | None = None,
company_id: uuid.UUID | None = None,
) -> ShippingListResponse:
where: list[Any] = [ErpShippingRecord.is_deleted.is_(False)]
if company_id:
where.append(ErpShippingRecord.company_id == company_id)
if user.data_scope == "self":
my_orders = select(ErpOrder.id).where(ErpOrder.salesperson_id == user.user_id, ErpOrder.is_deleted.is_(False))
where.append(ErpShippingRecord.order_id.in_(my_orders))
@@ -203,9 +233,13 @@ async def list_shipping(
async def get_shipping_by_order(
db: AsyncSession, user: CurrentUserPayload, order_id: uuid.UUID,
company_id: uuid.UUID | None = None,
) -> dict[str, Any]:
where_clause = [ErpOrder.id == order_id, ErpOrder.is_deleted.is_(False)]
if company_id:
where_clause.append(ErpOrder.company_id == company_id)
order = (await db.execute(
select(ErpOrder).where(ErpOrder.id == order_id, ErpOrder.is_deleted.is_(False))
select(ErpOrder).where(*where_clause)
)).scalar_one_or_none()
if order is None:
raise NotFoundException("订单不存在")