v0.2.0: CRM/ERP 系统升级 - 清理 .gitignore 并移除误提交的 venv/env/db 文件

- 更新 .gitignore:全面覆盖环境变量、数据库、日志、缓存、上传文件
- 移除误跟踪的 server/venv/、crm_data.db、.env 文件
- 新增 server/.env.example 模板
- 新增合同管理、利润核算、AI教练等功能模块
- 新增 Playwright e2e 测试套件
- 前后端多项功能升级和 bug 修复
This commit is contained in:
hankin
2026-05-11 07:24:19 +00:00
parent 0f4c6b7924
commit 815cbf9d8c
2526 changed files with 11875 additions and 804148 deletions
+111
View File
@@ -0,0 +1,111 @@
"""
AI 教练引擎 — 事件总线 + Dify 回调
CQRS 解耦模式:
1. 业务端 POST /api/sales-logs → 立即 200 OK → 发消息到 Redis Streams
2. Worker 消费消息 → 调用 Dify Workflow → 写回 ai_coaching_feedback
3. 前端通过 SSE /api/notifications/stream 接收推送
"""
from __future__ import annotations
import json
import uuid
from datetime import datetime
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.ai import SalesLog
from app.models.crm import CrmCustomer
from app.schemas.auth import CurrentUserPayload
# ── Redis 事件发布 ───────────────────────────────────────
async def publish_coaching_event(
sales_log_id: uuid.UUID,
content: str,
customer_id: uuid.UUID | None = None,
salesperson_id: uuid.UUID | None = None,
) -> None:
"""将销售日志推送到 Redis Streams,供 Worker 异步消费"""
try:
import redis.asyncio as aioredis
import os
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
r = aioredis.from_url(redis_url, decode_responses=True)
await r.xadd(
"coaching:sales_logs",
{
"sales_log_id": str(sales_log_id),
"content": content[:2000], # 限长
"customer_id": str(customer_id) if customer_id else "",
"salesperson_id": str(salesperson_id) if salesperson_id else "",
"timestamp": datetime.utcnow().isoformat(),
},
)
await r.aclose()
except Exception as e:
# Redis 不可用时降级——不阻塞主流程
print(f"[AI EventBus] Redis 推送失败(降级): {e}")
# ── Dify 回调处理 ───────────────────────────────────────
async def handle_dify_coaching_callback(
db: AsyncSession,
sales_log_id: uuid.UUID,
feedback: dict,
) -> None:
"""Dify Workflow 回调 → 写回 SalesLog.ai_coaching_feedback"""
await db.execute(
update(SalesLog)
.where(SalesLog.id == sales_log_id)
.values(
ai_coaching_feedback=feedback,
ai_processed=True,
updated_at=datetime.utcnow(),
)
)
# 如果反馈中包含客户健康评分,同步更新 CrmCustomer
health_score = feedback.get("health_score")
meddic_status = feedback.get("meddic_status")
if health_score is not None or meddic_status is not None:
log = (await db.execute(
select(SalesLog).where(SalesLog.id == sales_log_id)
)).scalar_one_or_none()
if log and log.customer_id:
update_vals: dict = {}
if health_score is not None:
update_vals["health_score"] = float(health_score)
if meddic_status is not None:
update_vals["meddic_status"] = meddic_status
if update_vals:
await db.execute(
update(CrmCustomer)
.where(CrmCustomer.id == log.customer_id)
.values(**update_vals)
)
await db.commit()
# ── SSE 通知流 ──────────────────────────────────────────
async def sse_notification_generator(user_id: uuid.UUID):
"""服务端推送事件流(SSE)—— 监听 Redis PubSub 频道"""
import asyncio
try:
import redis.asyncio as aioredis
import os
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
r = aioredis.from_url(redis_url, decode_responses=True)
pubsub = r.pubsub()
channel = f"notifications:{user_id}"
await pubsub.subscribe(channel)
async for message in pubsub.listen():
if message["type"] == "message":
yield f"data: {message['data']}\n\n"
except Exception as e:
yield f"data: {json.dumps({'error': str(e)})}\n\n"
+762
View File
@@ -0,0 +1,762 @@
"""
合同管理 Service 层
核心逻辑:CRUD + 一键推单 + 账期引擎 + 执行进度聚合
"""
from __future__ import annotations
import uuid
from datetime import date, datetime, timedelta
import re
from sqlalchemy import func, select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.core.exceptions import BizException, ForbiddenException, NotFoundException
from app.models.contract import ErpContract, ErpContractItem, ErpContractAttachment
from app.models.order import ErpOrder, ErpOrderItem
from app.models.shipping import ErpShippingRecord
from app.models.finance import FinSalesInvoice
from app.models.erp import ProductSku
from app.models.crm import CrmCustomer
from app.schemas.auth import CurrentUserPayload
from app.schemas.contract import (
ContractCreate,
ContractUpdate,
ContractItemResponse,
ContractListResponse,
ContractProgressResponse,
ContractResponse,
)
# ── 金额大写转换 ─────────────────────────────────────────
_CN_DIGITS = "零壹贰叁肆伍陆柒捌玖"
_CN_UNITS = ["", "", "", ""]
_CN_BIG_UNITS = ["", "", "亿", ""]
def amount_to_cn(amount: float) -> str:
"""将金额转为中文大写"""
if amount == 0:
return "零元整"
neg = ""
if amount < 0:
neg = ""
amount = -amount
yuan = int(amount)
jiao = int(amount * 10) % 10
fen = int(amount * 100) % 10
parts = []
if yuan > 0:
yuan_str = str(yuan)
n = len(yuan_str)
zero_flag = False
for i, ch in enumerate(yuan_str):
d = int(ch)
pos = n - 1 - i
big_idx = pos // 4
unit_idx = pos % 4
if d == 0:
zero_flag = True
if unit_idx == 0 and big_idx > 0:
parts.append(_CN_BIG_UNITS[big_idx])
else:
if zero_flag:
parts.append("")
zero_flag = False
parts.append(_CN_DIGITS[d] + _CN_UNITS[unit_idx])
if unit_idx == 0 and big_idx > 0:
parts.append(_CN_BIG_UNITS[big_idx])
parts.append("")
else:
parts.append("零元")
if jiao > 0:
parts.append(_CN_DIGITS[jiao] + "")
if fen > 0:
parts.append(_CN_DIGITS[fen] + "")
else:
if jiao == 0:
parts.append("")
return neg + "".join(parts)
# ── 生成合同编号 ─────────────────────────────────────────
async def _gen_contract_no(db: AsyncSession) -> str:
today_str = date.today().strftime("%Y%m%d")
prefix = f"HT-{today_str}-"
count_stmt = select(func.count()).select_from(ErpContract).where(
ErpContract.contract_no.like(f"{prefix}%")
)
count = (await db.execute(count_stmt)).scalar() or 0
return f"{prefix}{count + 1:03d}"
# ── 账期引擎 ────────────────────────────────────────────
def calc_payment_due_date(payment_terms: str, base_date: date) -> date | None:
"""根据付款条件枚举和基准日期(开票/发货)推算回款截止日"""
m = re.search(r"(\d+)天", payment_terms)
if m:
days = int(m.group(1))
return base_date + timedelta(days=days)
if "货到" in payment_terms or "全款" in payment_terms:
return base_date # 当天
return None
# ── ORM → Response ──────────────────────────────────────
def _item_to_response(item: ErpContractItem) -> ContractItemResponse:
sku = item.sku
return ContractItemResponse(
id=item.id,
sku_id=item.sku_id,
sku_code=sku.sku_code if sku else None,
sku_name=sku.name if sku else None,
spec=sku.spec if sku else None,
unit=sku.unit if sku else None,
qty=float(item.qty),
unit_price=float(item.unit_price),
sub_total=float(item.sub_total),
)
def _to_response(c: ErpContract, progress: ContractProgressResponse | None = None) -> ContractResponse:
return ContractResponse(
id=c.id,
contract_no=c.contract_no,
buyer_customer_id=c.buyer_customer_id,
buyer_customer_name=c.buyer_customer.name if c.buyer_customer else None,
seller_company_id=c.seller_company_id,
seller_company_name=c.seller_company.name if c.seller_company else None,
company_id=c.company_id,
total_amount_excl_tax=float(c.total_amount_excl_tax or 0),
total_amount_incl_tax=float(c.total_amount_incl_tax or 0),
total_amount_cn=c.total_amount_cn,
payment_terms=c.payment_terms,
shipping_terms=c.shipping_terms,
status=c.status,
is_signed=c.is_signed,
signed_file_url=c.signed_file_url,
linked_order_id=c.linked_order_id,
salesperson_id=c.salesperson_id,
salesperson_name=c.salesperson.real_name if c.salesperson else None,
sign_date=c.sign_date,
remark=c.remark,
delivery_terms=c.delivery_terms,
items=[_item_to_response(i) for i in (c.items or []) if not i.is_deleted],
progress=progress,
created_at=c.created_at,
updated_at=c.updated_at,
)
# ── 执行进度聚合 ────────────────────────────────────────
async def _get_progress(db: AsyncSession, contract: ErpContract) -> ContractProgressResponse:
progress = ContractProgressResponse(is_signed=contract.is_signed)
if contract.linked_order_id:
progress.has_order = True
progress.order_id = contract.linked_order_id
# 是否有发货
ship_count = (await db.execute(
select(func.count()).select_from(ErpShippingRecord).where(
ErpShippingRecord.order_id == contract.linked_order_id,
ErpShippingRecord.is_deleted.is_(False),
)
)).scalar() or 0
progress.has_shipped = ship_count > 0
# 是否有销项发票
inv_count = (await db.execute(
select(func.count()).select_from(FinSalesInvoice).where(
FinSalesInvoice.order_id == contract.linked_order_id,
FinSalesInvoice.is_deleted.is_(False),
)
)).scalar() or 0
progress.has_invoice = inv_count > 0
# 是否回款(检查订单回款状态)
order = (await db.execute(
select(ErpOrder).where(ErpOrder.id == contract.linked_order_id)
)).scalar_one_or_none()
if order and order.payment_state == "paid":
progress.is_paid = True
return progress
# ── 公共 eager-load 选项 ────────────────────────────────────
def _contract_load_options():
"""返回 selectinload 链,保证 commit 后仍可安全访问关系属性"""
return [
selectinload(ErpContract.buyer_customer),
selectinload(ErpContract.seller_company),
selectinload(ErpContract.salesperson),
selectinload(ErpContract.items).selectinload(ErpContractItem.sku),
]
# ── Service Functions ────────────────────────────────────
async def create_contract(
db: AsyncSession,
user: CurrentUserPayload,
company_id: uuid.UUID,
body: ContractCreate,
) -> ContractResponse:
contract_no = await _gen_contract_no(db)
# 计算合计
total = sum(item.sub_total for item in body.items)
contract = ErpContract(
contract_no=contract_no,
buyer_customer_id=body.buyer_customer_id,
seller_company_id=company_id,
company_id=company_id,
total_amount_excl_tax=total,
total_amount_incl_tax=total, # 含税金额默认同不含税,可后续区分
total_amount_cn=amount_to_cn(total),
payment_terms=body.payment_terms,
shipping_terms=body.shipping_terms,
sign_date=body.sign_date,
remark=body.remark,
delivery_terms=body.delivery_terms,
salesperson_id=user.user_id,
status="draft",
)
db.add(contract)
await db.flush()
# 添加明细行
for item_data in body.items:
item = ErpContractItem(
contract_id=contract.id,
sku_id=item_data.sku_id,
qty=item_data.qty,
unit_price=item_data.unit_price,
sub_total=item_data.sub_total,
)
db.add(item)
await db.commit()
# 重新查询并 eager-load 所有关系,避免 commit 后隐式 lazy load
fresh = (await db.execute(
select(ErpContract)
.where(ErpContract.id == contract.id)
.options(*_contract_load_options())
)).scalar_one()
return _to_response(fresh)
async def list_contracts(
db: AsyncSession,
company_id: uuid.UUID,
page: int = 1,
size: int = 20,
keyword: str | None = None,
status: str | None = None,
) -> ContractListResponse:
base_where = [
ErpContract.company_id == company_id,
ErpContract.is_deleted.is_(False),
]
if keyword:
base_where.append(ErpContract.contract_no.ilike(f"%{keyword}%"))
if status:
base_where.append(ErpContract.status == status)
total = (await db.execute(
select(func.count()).select_from(ErpContract).where(*base_where)
)).scalar() or 0
stmt = (
select(ErpContract)
.where(*base_where)
.options(*_contract_load_options())
.order_by(ErpContract.created_at.desc())
.offset((page - 1) * size)
.limit(size)
)
contracts = (await db.execute(stmt)).scalars().all()
return ContractListResponse(
total=total,
items=[_to_response(c) for c in contracts],
page=page,
size=size,
)
async def get_contract(
db: AsyncSession,
contract_id: uuid.UUID,
company_id: uuid.UUID,
) -> ContractResponse:
stmt = (
select(ErpContract)
.where(
ErpContract.id == contract_id,
ErpContract.company_id == company_id,
ErpContract.is_deleted.is_(False),
)
.options(*_contract_load_options())
)
contract = (await db.execute(stmt)).scalar_one_or_none()
if contract is None:
raise NotFoundException("合同不存在")
progress = await _get_progress(db, contract)
return _to_response(contract, progress)
async def update_contract(
db: AsyncSession,
contract_id: uuid.UUID,
company_id: uuid.UUID,
body: ContractUpdate,
) -> ContractResponse:
stmt = select(ErpContract).where(
ErpContract.id == contract_id,
ErpContract.company_id == company_id,
ErpContract.is_deleted.is_(False),
)
contract = (await db.execute(stmt)).scalar_one_or_none()
if contract is None:
raise NotFoundException("合同不存在")
# 更新主表字段
update_data = body.model_dump(exclude_unset=True, exclude={"items"})
if update_data:
update_data["updated_at"] = datetime.utcnow()
await db.execute(
update(ErpContract).where(ErpContract.id == contract_id).values(**update_data)
)
# 如果有明细行更新,删旧增新
if body.items is not None:
await db.execute(
update(ErpContractItem)
.where(ErpContractItem.contract_id == contract_id)
.values(is_deleted=True)
)
total = 0
for item_data in body.items:
item = ErpContractItem(
contract_id=contract_id,
sku_id=item_data.sku_id,
qty=item_data.qty,
unit_price=item_data.unit_price,
sub_total=item_data.sub_total,
)
total += item_data.sub_total
db.add(item)
await db.execute(
update(ErpContract).where(ErpContract.id == contract_id).values(
total_amount_excl_tax=total,
total_amount_incl_tax=total,
total_amount_cn=amount_to_cn(total),
)
)
await db.commit()
updated = (await db.execute(
select(ErpContract)
.where(ErpContract.id == contract_id)
.options(*_contract_load_options())
)).scalar_one()
return _to_response(updated)
async def delete_contract(
db: AsyncSession,
contract_id: uuid.UUID,
company_id: uuid.UUID,
) -> None:
stmt = select(ErpContract).where(
ErpContract.id == contract_id,
ErpContract.company_id == company_id,
ErpContract.is_deleted.is_(False),
)
contract = (await db.execute(stmt)).scalar_one_or_none()
if contract is None:
raise NotFoundException("合同不存在")
await db.execute(
update(ErpContract)
.where(ErpContract.id == contract_id)
.values(is_deleted=True, updated_at=datetime.utcnow())
)
await db.commit()
async def generate_order_from_contract(
db: AsyncSession,
user: CurrentUserPayload,
contract_id: uuid.UUID,
company_id: uuid.UUID,
) -> dict:
"""一键从合同生成订单 —— 防篡改推单逻辑"""
stmt = (
select(ErpContract)
.where(
ErpContract.id == contract_id,
ErpContract.company_id == company_id,
ErpContract.is_deleted.is_(False),
)
.options(*_contract_load_options())
)
contract = (await db.execute(stmt)).scalar_one_or_none()
if contract is None:
raise NotFoundException("合同不存在")
if contract.linked_order_id is not None:
raise BizException(message="该合同已关联订单,不可重复生成")
# 生成订单编号
today_str = date.today().strftime("%Y%m%d")
prefix = f"SO-{today_str}-"
count = (await db.execute(
select(func.count()).select_from(ErpOrder).where(
ErpOrder.order_no.like(f"{prefix}%")
)
)).scalar() or 0
order_no = f"{prefix}{count + 1:03d}"
# 创建订单
new_order = ErpOrder(
order_no=order_no,
customer_id=contract.buyer_customer_id,
salesperson_id=user.user_id,
company_id=company_id,
contract_id=contract_id,
total_amount=float(contract.total_amount_incl_tax or 0),
order_date=date.today(),
)
db.add(new_order)
await db.flush()
# 复制合同明细到订单明细
active_items = [i for i in (contract.items or []) if not i.is_deleted]
for ci in active_items:
oi = ErpOrderItem(
order_id=new_order.id,
sku_id=ci.sku_id,
qty=float(ci.qty),
unit_price=float(ci.unit_price),
sub_total=float(ci.sub_total),
)
db.add(oi)
# 回填合同 linked_order_id + 激活状态
await db.execute(
update(ErpContract)
.where(ErpContract.id == contract_id)
.values(
linked_order_id=new_order.id,
status="active",
updated_at=datetime.utcnow(),
)
)
await db.commit()
return {"order_id": str(new_order.id), "order_no": order_no}
# ── 数字转中文大写金额 ──────────────────────────────────────
def _amount_to_cn(amount: float) -> str:
"""将数字金额转换为中文大写"""
digits = "零壹贰叁肆伍陆柒捌玖"
units = ["", "", "", ""]
big_units = ["", "", "亿"]
if amount == 0:
return "零元整"
yuan = int(round(amount * 100))
jiao = (yuan % 100) // 10
fen = yuan % 10
yuan_part = yuan // 100
result = ""
if yuan_part > 0:
s = str(yuan_part)
n = len(s)
for i, ch in enumerate(s):
d = int(ch)
pos = n - i - 1
big_pos = pos // 4
unit_pos = pos % 4
if d != 0:
result += digits[d] + units[unit_pos]
else:
if result and not result.endswith(""):
result += ""
if unit_pos == 0 and big_pos > 0:
result = result.rstrip("") + big_units[big_pos]
result = result.rstrip("") + ""
else:
result = ""
if jiao == 0 and fen == 0:
result += ""
else:
if jiao > 0:
result += digits[jiao] + ""
if fen > 0:
result += digits[fen] + ""
return result
async def generate_contract_docx(
db: AsyncSession,
contract_id: uuid.UUID,
company_id: uuid.UUID,
) -> bytes:
"""纯代码生成合同 Word 文档(紧凑排版,2 页以内)"""
import io
from docx import Document as DocxDocument
from docx.shared import Pt, Cm, Emu, RGBColor
from docx.enum.table import WD_TABLE_ALIGNMENT
from docx.enum.text import WD_ALIGN_PARAGRAPH
from docx.oxml.ns import qn
from app.models.sys import SysCompany
# ── 1) 数据准备 ─────────────────────────────────────────
contract = (await db.execute(
select(ErpContract)
.where(
ErpContract.id == contract_id,
ErpContract.company_id == company_id,
ErpContract.is_deleted.is_(False),
)
.options(*_contract_load_options())
)).scalar_one_or_none()
if contract is None:
raise NotFoundException("合同不存在")
seller = (await db.execute(
select(SysCompany).where(SysCompany.id == contract.seller_company_id)
)).scalar_one_or_none()
seller_info = (seller.full_info or {}) if seller else {}
buyer = contract.buyer_customer
buyer_billing = {}
if buyer and hasattr(buyer, "billing_info") and buyer.billing_info:
buyer_billing = buyer.billing_info
total_incl = float(contract.total_amount_incl_tax or 0)
sign_date_str = (contract.sign_date or date.today()).strftime("%Y年%m月%d")
buyer_name = buyer_billing.get("company_name") or (buyer.name if buyer else "")
seller_name = seller_info.get("company_name") or (seller.name if seller else "")
items = [i for i in (contract.items or []) if not i.is_deleted]
# ── 2) 创建文档 ─────────────────────────────────────────
doc = DocxDocument()
# 页边距:上下2cm 左右2.5cm(紧凑)
for section in doc.sections:
section.top_margin = Cm(2)
section.bottom_margin = Cm(1.5)
section.left_margin = Cm(2.5)
section.right_margin = Cm(2.5)
# ── 辅助函数 ─────────────────────────────────────────────
# 小四 = 12pt, 1.5倍行距 = 18pt
def add_para(text: str, font_size: int = 12, bold: bool = False,
align=WD_ALIGN_PARAGRAPH.LEFT, space_before: int = 0,
space_after: int = 0, font_name: str = "宋体"):
p = doc.add_paragraph()
p.alignment = align
p.paragraph_format.space_before = Pt(space_before)
p.paragraph_format.space_after = Pt(space_after)
p.paragraph_format.line_spacing = Pt(18) # 1.5倍行距(12pt×1.5)
run = p.add_run(text)
run.font.size = Pt(font_size)
run.font.bold = bold
run.font.name = font_name
run._element.rPr.rFonts.set(qn("w:eastAsia"), font_name)
return p
def set_cell(cell, text: str, font_size: int = 12, bold: bool = False,
align=WD_ALIGN_PARAGRAPH.CENTER):
cell.text = ""
p = cell.paragraphs[0]
p.alignment = align
p.paragraph_format.space_before = Pt(0)
p.paragraph_format.space_after = Pt(0)
p.paragraph_format.line_spacing = Pt(18) # 1.5倍行距
run = p.add_run(text)
run.font.size = Pt(font_size)
run.font.bold = bold
run.font.name = "宋体"
run._element.rPr.rFonts.set(qn("w:eastAsia"), "宋体")
# ── 3) 标题 ──────────────────────────────────────────────
add_para("产 品 购 销 合 同", font_size=18, bold=True,
align=WD_ALIGN_PARAGRAPH.CENTER, space_after=4, font_name="黑体")
add_para(f"合同编号:{contract.contract_no}",
align=WD_ALIGN_PARAGRAPH.RIGHT, space_after=4)
# ── 4) 甲乙方信息(紧凑表格) ────────────────────────────
info_tbl = doc.add_table(rows=4, cols=4)
info_tbl.alignment = WD_TABLE_ALIGNMENT.CENTER
info_tbl.style = "Table Grid"
info_data = [
("买方(甲方)", buyer_name,
"卖方(乙方)", seller_name),
("税号", buyer_billing.get("tax_id", "") or "",
"税号", seller_info.get("tax_id", "") or ""),
("地址", buyer_billing.get("address", "") or "",
"地址", seller_info.get("address", "") or ""),
("开户行 / 账号",
f"{buyer_billing.get('bank_name', '') or ''} {buyer_billing.get('bank_account', '') or ''}".strip(),
"开户行 / 账号",
f"{seller_info.get('bank_name', '') or ''} {seller_info.get('bank_account', '') or ''}".strip()),
]
for ri, row_data in enumerate(info_data):
for ci, val in enumerate(row_data):
bold = ri == 0 and ci in (0, 2)
set_cell(info_tbl.cell(ri, ci), val, bold=bold,
align=WD_ALIGN_PARAGRAPH.LEFT)
# ── 5) 一、产品明细 ──────────────────────────────────────
add_para("一、产品明细", bold=True, space_before=6, space_after=2)
cols = 6
tbl = doc.add_table(rows=1 + len(items) + 1, cols=cols)
tbl.alignment = WD_TABLE_ALIGNMENT.CENTER
tbl.style = "Table Grid"
headers = ["序号", "产品名称", "规格", "数量", "单价(元)", "小计(元)"]
for ci, h in enumerate(headers):
set_cell(tbl.cell(0, ci), h, bold=True)
for ri, item in enumerate(items):
sku_name = item.sku.name if item.sku else ""
sku_spec = item.sku.spec if item.sku else ""
set_cell(tbl.cell(ri + 1, 0), str(ri + 1))
set_cell(tbl.cell(ri + 1, 1), sku_name, align=WD_ALIGN_PARAGRAPH.LEFT)
set_cell(tbl.cell(ri + 1, 2), sku_spec or "-")
set_cell(tbl.cell(ri + 1, 3), str(float(item.qty)))
set_cell(tbl.cell(ri + 1, 4), f"{float(item.unit_price):,.2f}",
align=WD_ALIGN_PARAGRAPH.RIGHT)
set_cell(tbl.cell(ri + 1, 5), f"{float(item.sub_total):,.2f}",
align=WD_ALIGN_PARAGRAPH.RIGHT)
# 合计行
last_row = len(items) + 1
set_cell(tbl.cell(last_row, 0), "合计", bold=True)
# 合并序号~单价列
for ci in range(1, 4):
set_cell(tbl.cell(last_row, ci), "")
set_cell(tbl.cell(last_row, 4), "", align=WD_ALIGN_PARAGRAPH.RIGHT)
set_cell(tbl.cell(last_row, 5), f"{total_incl:,.2f}", bold=True,
align=WD_ALIGN_PARAGRAPH.RIGHT)
# 大写金额
add_para(f"合计金额(大写):{_amount_to_cn(total_incl)} (含13%增值税)",
bold=True, space_before=2, space_after=2)
# ── 6) 二、交货及付款条件 ────────────────────────────────
add_para("二、交货及付款条件", bold=True, space_before=4, space_after=2)
delivery_text = contract.delivery_terms or "按双方约定"
add_para(f"1. 货 期:{delivery_text}")
add_para(f"2. 交货方式:{contract.shipping_terms or '买方自提'}")
add_para(f"3. 付款条件:{contract.payment_terms or '货到付全款'}")
# ── 7) 三、发票信息 ──────────────────────────────────────
add_para("三、发票信息", bold=True, space_before=4, space_after=2)
add_para("卖方给买方开具合同金额增值税专用发票(13%增值税)。")
# ── 8) 四、合同细则 ──────────────────────────────────────
add_para("四、合同细则", bold=True, space_before=4, space_after=2)
# 紧凑输出细则内容
terms = [
"第一条 质量标准:按照厂家标准执行,由于买方储存不当(如露天暴晒、混入杂质、超过保质期等)或未按产品说明书操作导致的质量问题,卖方不承担责任。",
"第二条 卖方对质量负责的条件及期限:自货到12个月。",
"第三条 包装标准包装物的供应与回收:产品包装均应采用国家或专业标准保护措施进行包装,以确保产品不受损害为原则,由于包装不善所引起的货物污染、损坏、损失均由卖方负担,采取装箱包装的应在包装箱内附一份详细装箱单和质量合格证,包装物不回收。",
"第四条 合理损耗标准及计算方法:标的货物送至买方指定地点前的合理损耗由卖方负责。",
"第五条 标的物所有权:在买方付清本合同项下全部货款之前,标的物的所有权仍属于卖方。",
"第六条 检验标准、方法、地点及期限:按第二条标准检验。",
"第七条 发票信息:卖方给买方开具合同金额增值税专用发票(13%增值税)。",
"第八条 本合同解除条件:合同执行完毕。",
(
"第九条 违约责任:\n"
"1、卖方应保证产品质量合格,买方有权在货到后7个工作日内且未开封状态下将卖方产品送质监局或第三方部门检验单位检验,"
"送检样品的取样过程必须经卖方现场确认或双方共同封样,否则检验结果无效。检验结果不合格,则所发生的所有检验费用,"
"均由卖方承担,买方可根据实际情况选择要求退货或更换。\n"
"赔偿限额:卖方对本合同项下违约责任的赔偿总额,以本合同约定的总货款金额为限,"
"且不承担任何间接损失(包括但不限于停工损失、利润损失等)。"
),
(
"第十条 合同争议的解决方式:本合同在履行过程中发生的争执,由双方当事人协商解决,"
"也可由当地工商行政管理部门调解;协商或调解不成的,按下列第二种方式解决。\n"
"(一)提交当地仲裁委员会仲裁;(二)依法向卖方所在地的人民法院起诉。"
),
"第十一条 本合同一式两份,自双方签字盖章起生效。",
(
"第十二条 其他约定事项:\n"
"1、卖方必须遵守国家有关能源管理的法律、法规;\n"
"2、卖方必须执行买方对其提出的对能源控制进行改善的要求;\n"
"3、卖方在运输途中和施工作业中的各种行为不应对能源造成浪费或负面影响;\n"
"4、如卖方提供货物存在质量问题,买方书面(包括但不限于传真、邮件)通知对方,"
"卖方在接到买方书面通知后3个工作日内要给与买方书面回复,否则将视为卖方已经认可买方提出的质量问题;"
"如果双方意见产生争议,由卖方负责安排经买方同意的第三方进行检验,否则视为卖方质量问题;\n"
"5、未经对方书面同意,不得将合同部分或者全部权利义务转给第三方。\n"
"6、如遇战争、原材料短缺、工厂停产、物流管制等不可抗力因素导致货期延长,卖方不承担违约责任。"
),
]
for term in terms:
add_para(term)
# ── 9) 签章区 ────────────────────────────────────────────
add_para("", space_before=6, space_after=0) # 小间距
sig_tbl = doc.add_table(rows=4, cols=2)
sig_tbl.alignment = WD_TABLE_ALIGNMENT.CENTER
# 去边框
for row in sig_tbl.rows:
for cell in row.cells:
for paragraph in cell.paragraphs:
paragraph.paragraph_format.space_before = Pt(0)
paragraph.paragraph_format.space_after = Pt(0)
set_cell(sig_tbl.cell(0, 0), "买方(盖章):", bold=True,
align=WD_ALIGN_PARAGRAPH.LEFT)
set_cell(sig_tbl.cell(0, 1), "卖方(盖章):", bold=True,
align=WD_ALIGN_PARAGRAPH.LEFT)
set_cell(sig_tbl.cell(1, 0), "授权代表签字:",
align=WD_ALIGN_PARAGRAPH.LEFT)
set_cell(sig_tbl.cell(1, 1), "授权代表签字:",
align=WD_ALIGN_PARAGRAPH.LEFT)
set_cell(sig_tbl.cell(2, 0), f"日期:{sign_date_str}",
align=WD_ALIGN_PARAGRAPH.LEFT)
set_cell(sig_tbl.cell(2, 1), f"日期:{sign_date_str}",
align=WD_ALIGN_PARAGRAPH.LEFT)
set_cell(sig_tbl.cell(3, 0), f"联系电话:{buyer_billing.get('phone', '') or ''}",
align=WD_ALIGN_PARAGRAPH.LEFT)
set_cell(sig_tbl.cell(3, 1), f"联系电话:{seller_info.get('phone', '') or ''}",
align=WD_ALIGN_PARAGRAPH.LEFT)
# ── 10) 输出 ─────────────────────────────────────────────
buffer = io.BytesIO()
doc.save(buffer)
buffer.seek(0)
return buffer.getvalue()
+116 -19
View File
@@ -13,6 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import BizException, ForbiddenException, NotFoundException
from app.models.crm import CrmCustomer
from app.models.sys import SysUser
from app.schemas.auth import CurrentUserPayload
from app.schemas.crm import (
CustomerCreate,
@@ -35,6 +36,7 @@ def _to_response(c: CrmCustomer) -> CustomerResponse:
address=c.address,
ai_score=float(c.ai_score or 0),
ai_persona=c.ai_persona,
billing_info=c.billing_info,
owner_id=c.owner_id,
owner_name=c.owner.real_name if c.owner else None,
status=c.status,
@@ -44,12 +46,48 @@ def _to_response(c: CrmCustomer) -> CustomerResponse:
)
# ── 递归查询本部门 + 子部门所有用户 ID ────────────────────
async def _get_dept_and_sub_user_ids(
db: AsyncSession, dept_id: uuid.UUID
) -> list[uuid.UUID]:
"""递归获取指定部门及其所有子部门下的用户 ID 列表"""
from app.models.sys import SysDepartment, SysUser
# 收集所有目标部门 ID(递归子部门)
dept_ids: list[uuid.UUID] = [dept_id]
queue = [dept_id]
while queue:
current = queue.pop(0)
children = (await db.execute(
select(SysDepartment.id).where(
SysDepartment.parent_id == current,
SysDepartment.is_deleted.is_(False),
)
)).scalars().all()
for child_id in children:
dept_ids.append(child_id)
queue.append(child_id)
# 查询这些部门下的所有用户 ID
user_ids = (await db.execute(
select(SysUser.id).where(
SysUser.dept_id.in_(dept_ids),
SysUser.is_deleted.is_(False),
)
)).scalars().all()
return list(user_ids)
# ── 权限校验 ─────────────────────────────────────────────
def _check_access(customer: CrmCustomer, user: CurrentUserPayload) -> None:
def _check_access(customer: CrmCustomer, user: CurrentUserPayload, *, dept_user_ids: list[uuid.UUID] | None = None) -> None:
if user.data_scope == "all":
return
if user.data_scope == "dept_and_sub":
return # 简化版:放通本部门
# 如果有预查询的部门用户列表,校验 owner 是否在列表内
if dept_user_ids is not None:
if customer.owner_id not in dept_user_ids:
raise ForbiddenException("无权访问该客户(数据权限:本部门及子部门)")
return
# data_scope == 'self'
if customer.owner_id != user.user_id:
raise ForbiddenException("无权访问该客户(数据权限:仅本人)")
@@ -70,6 +108,7 @@ async def create_customer(
phone=body.phone,
email=body.email,
address=body.address,
billing_info=body.billing_info.model_dump() if body.billing_info else None,
status=body.status,
owner_id=user.user_id,
)
@@ -98,12 +137,12 @@ async def list_customers(
base_where.append(CrmCustomer.owner_id == user.user_id)
elif user.data_scope == "dept_and_sub":
if user.dept_id is not None:
from app.models.sys import SysUser
sub = select(SysUser.id).where(
SysUser.dept_id == user.dept_id,
SysUser.is_deleted.is_(False),
)
base_where.append(CrmCustomer.owner_id.in_(sub))
dept_user_ids = await _get_dept_and_sub_user_ids(db, user.dept_id)
if dept_user_ids:
base_where.append(CrmCustomer.owner_id.in_(dept_user_ids))
else:
# 部门无用户 → 仅显示自己的
base_where.append(CrmCustomer.owner_id == user.user_id)
if keyword:
base_where.append(CrmCustomer.name.ilike(f"%{keyword}%"))
@@ -144,7 +183,11 @@ async def get_customer(
if customer is None:
raise NotFoundException("客户不存在或已被删除")
_check_access(customer, user)
# dept_and_sub 需要先查询部门用户列表
dept_user_ids = None
if user.data_scope == "dept_and_sub" and user.dept_id:
dept_user_ids = await _get_dept_and_sub_user_ids(db, user.dept_id)
_check_access(customer, user, dept_user_ids=dept_user_ids)
return _to_response(customer)
@@ -162,7 +205,10 @@ async def update_customer(
if customer is None:
raise NotFoundException("客户不存在或已被删除")
_check_access(customer, user)
dept_user_ids = None
if user.data_scope == "dept_and_sub" and user.dept_id:
dept_user_ids = await _get_dept_and_sub_user_ids(db, user.dept_id)
_check_access(customer, user, dept_user_ids=dept_user_ids)
update_data = body.model_dump(exclude_unset=True)
if not update_data:
@@ -193,7 +239,10 @@ async def delete_customer(
if customer is None:
raise NotFoundException("客户不存在或已被删除")
_check_access(customer, user)
dept_user_ids = None
if user.data_scope == "dept_and_sub" and user.dept_id:
dept_user_ids = await _get_dept_and_sub_user_ids(db, user.dept_id)
_check_access(customer, user, dept_user_ids=dept_user_ids)
await db.execute(
update(CrmCustomer)
@@ -216,7 +265,10 @@ async def restore_customer(
if customer is None:
raise NotFoundException("客户不存在或未被归档")
_check_access(customer, user)
dept_user_ids = None
if user.data_scope == "dept_and_sub" and user.dept_id:
dept_user_ids = await _get_dept_and_sub_user_ids(db, user.dept_id)
_check_access(customer, user, dept_user_ids=dept_user_ids)
await db.execute(
update(CrmCustomer)
@@ -226,6 +278,49 @@ async def restore_customer(
await db.commit()
async def transfer_customer(
db: AsyncSession,
user: CurrentUserPayload,
customer_id: uuid.UUID,
new_owner_id: uuid.UUID,
) -> CustomerResponse:
"""将客户转移至指定人员名下(仅管理员)"""
if user.data_scope != "all":
raise ForbiddenException("仅管理员可执行客户转移操作")
stmt = select(CrmCustomer).where(
CrmCustomer.id == customer_id,
CrmCustomer.is_deleted.is_(False),
)
customer = (await db.execute(stmt)).scalar_one_or_none()
if customer is None:
raise NotFoundException("客户不存在或已被归档")
if customer.owner_id == new_owner_id:
raise BizException(message="目标负责人与当前负责人相同,无需转移")
# 校验目标用户是否存在
from app.models.sys import SysUser
target = (await db.execute(
select(SysUser).where(SysUser.id == new_owner_id)
)).scalar_one_or_none()
if target is None:
raise NotFoundException("目标负责人不存在")
old_owner_name = customer.owner.real_name if customer.owner else "(无)"
await db.execute(
update(CrmCustomer)
.where(CrmCustomer.id == customer_id)
.values(owner_id=new_owner_id, updated_at=datetime.utcnow())
)
await db.commit()
await db.refresh(customer)
print(f"[客户转移] {customer.name}: {old_owner_name}{target.real_name} (操作人: {user.real_name})")
return _to_response(customer)
async def get_customer_products(
db: AsyncSession,
user: CurrentUserPayload,
@@ -241,7 +336,10 @@ async def get_customer_products(
customer = (await db.execute(stmt)).scalar_one_or_none()
if customer is None:
raise NotFoundException("客户不存在")
_check_access(customer, user)
dept_user_ids = None
if user.data_scope == "dept_and_sub" and user.dept_id:
dept_user_ids = await _get_dept_and_sub_user_ids(db, user.dept_id)
_check_access(customer, user, dept_user_ids=dept_user_ids)
# 聚合: 该客户所有订单中的 SKU,含总数量、最近下单时间
agg_stmt = (
@@ -299,12 +397,11 @@ async def search_customers(
base_where.append(CrmCustomer.owner_id == user.user_id)
elif user.data_scope == "dept_and_sub":
if user.dept_id is not None:
from app.models.sys import SysUser
sub = select(SysUser.id).where(
SysUser.dept_id == user.dept_id,
SysUser.is_deleted.is_(False),
)
base_where.append(CrmCustomer.owner_id.in_(sub))
dept_user_ids = await _get_dept_and_sub_user_ids(db, user.dept_id)
if dept_user_ids:
base_where.append(CrmCustomer.owner_id.in_(dept_user_ids))
else:
base_where.append(CrmCustomer.owner_id == user.user_id)
# 模糊搜索(名称 / 联系人 / 电话)
from sqlalchemy import or_
+11 -2
View File
@@ -9,6 +9,7 @@ from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import BizException, ForbiddenException, NotFoundException
from app.models.finance import FinExpenseDetail, FinExpenseRecord, FinInvoicePool
from app.models.sys import SysUser
from app.schemas.auth import CurrentUserPayload
from app.schemas.finance import (
ExpenseBriefResponse, ExpenseCreate, ExpenseDetailResponse,
@@ -84,12 +85,13 @@ async def _release_invoices(db: AsyncSession, expense_id: uuid.UUID, now: dateti
# ── Service Functions ────────────────────────────────────
async def create_invoice(db: AsyncSession, user: CurrentUserPayload, body: InvoiceCreate) -> InvoiceResponse:
async def create_invoice(db: AsyncSession, user: CurrentUserPayload, body: InvoiceCreate, company_id: uuid.UUID) -> InvoiceResponse:
invoice = FinInvoicePool(
uploader_id=user.user_id, file_url=body.file_url,
merchant_name=body.merchant_name, amount=body.amount,
invoice_date=body.invoice_date, type=body.type,
ai_extracted_data=body.ai_extracted_data, is_used=False,
company_id=company_id,
)
db.add(invoice)
await db.commit()
@@ -101,8 +103,11 @@ async def list_invoices(
db: AsyncSession, user: CurrentUserPayload,
page: int = 1, size: int = 20,
inv_type: str | None = None, is_used: bool | None = None,
company_id: uuid.UUID | None = None,
) -> InvoiceListResponse:
where = [FinInvoicePool.is_deleted.is_(False)]
if company_id:
where.append(FinInvoicePool.company_id == company_id)
if user.data_scope == "self":
where.append(FinInvoicePool.uploader_id == user.user_id)
elif user.data_scope == "dept_and_sub":
@@ -135,7 +140,7 @@ async def void_invoice(db: AsyncSession, user: CurrentUserPayload, invoice_id: u
await db.commit()
async def create_expense(db: AsyncSession, user: CurrentUserPayload, body: ExpenseCreate) -> ExpenseResponse:
async def create_expense(db: AsyncSession, user: CurrentUserPayload, body: ExpenseCreate, company_id: uuid.UUID) -> ExpenseResponse:
invoice_ids = [item.invoice_id for item in body.items]
try:
async with db.begin_nested():
@@ -154,6 +159,7 @@ async def create_expense(db: AsyncSession, user: CurrentUserPayload, body: Expen
system_no = await _generate_expense_no(db)
expense = FinExpenseRecord(
system_no=system_no, applicant_id=user.user_id,
company_id=company_id,
total_amount=body.total_amount, status="submitted", remark=body.remark,
)
db.add(expense)
@@ -184,8 +190,11 @@ async def list_expenses(
db: AsyncSession, user: CurrentUserPayload,
page: int = 1, size: int = 20,
status: str | None = None, applicant_id: uuid.UUID | None = None,
company_id: uuid.UUID | None = None,
) -> ExpenseListResponse:
where = [FinExpenseRecord.is_deleted.is_(False)]
if company_id:
where.append(FinExpenseRecord.company_id == company_id)
if user.data_scope == "self":
where.append(FinExpenseRecord.applicant_id == user.user_id)
elif user.data_scope == "dept_and_sub":
+211
View File
@@ -0,0 +1,211 @@
"""
发票结构化解析器 — OFD / XML 零算力提取
OFD 文件本质是 ZIP 包含 XML,直接解包提取发票字段。
XML 电子发票(数电票)直接 XPath 提取。
"""
from __future__ import annotations
import io
import os
import re
import zipfile
from xml.etree import ElementTree as ET
from typing import Optional
def parse_ofd_invoice(file_bytes: bytes) -> dict:
"""
解析 OFD 电子发票文件。
OFD = ZIP 压缩包,内含 XML 描述文件。
提取发票关键字段,返回结构化 dict。
"""
result: dict = {}
try:
with zipfile.ZipFile(io.BytesIO(file_bytes)) as zf:
# 收集所有 XML 内容
all_text = ""
for name in zf.namelist():
if name.endswith(".xml"):
try:
xml_bytes = zf.read(name)
xml_text = xml_bytes.decode("utf-8", errors="replace")
all_text += xml_text + "\n"
# 尝试从 XML 标签中提取结构化数据
extracted = _extract_from_xml_text(xml_text)
if extracted:
result.update(extracted)
except Exception:
continue
# 如果解析出了字段就直接返回
if result.get("merchant") or result.get("amount"):
return {"success": True, "data": result}
# 降级:把所有 XML 文本当纯文本返回,交给 LLM 处理
if all_text.strip():
return {"success": True, "data": {"raw_text": all_text[:8000]}, "needs_llm": True}
return {"success": False, "data": {}, "error": "OFD 文件中未找到有效 XML 内容"}
except zipfile.BadZipFile:
return {"success": False, "data": {}, "error": "OFD 文件格式损坏或不是有效的 OFD 文件"}
except Exception as e:
return {"success": False, "data": {}, "error": f"OFD 解析失败: {e}"}
def parse_xml_invoice(file_bytes: bytes) -> dict:
"""
解析 XML 格式电子发票(数电票)。
直接从 XML 标签提取所有发票字段。
"""
try:
xml_text = file_bytes.decode("utf-8", errors="replace")
result = _extract_from_xml_text(xml_text)
if result and (result.get("merchant") or result.get("amount")):
return {"success": True, "data": result}
# 降级:XML 结构未匹配预设标签,交给 LLM
if xml_text.strip():
return {"success": True, "data": {"raw_text": xml_text[:8000]}, "needs_llm": True}
return {"success": False, "data": {}, "error": "XML 文件内容为空"}
except Exception as e:
return {"success": False, "data": {}, "error": f"XML 解析失败: {e}"}
def parse_zip_invoices(file_bytes: bytes) -> list[dict]:
"""
解析 ZIP 压缩包中的所有 XML 发票文件。
返回列表,每个元素 = {"filename": str, "success": bool, "data": dict, ...}
支持系统导出的 ZIP 格式(内含多个 XML 发票)。
"""
results: list[dict] = []
try:
with zipfile.ZipFile(io.BytesIO(file_bytes)) as zf:
xml_names = [n for n in zf.namelist() if n.lower().endswith(".xml")]
if not xml_names:
return [{"filename": "(zip)", "success": False, "data": {}, "error": "ZIP 包中未找到 XML 文件"}]
for name in xml_names:
try:
xml_bytes = zf.read(name)
result = parse_xml_invoice(xml_bytes)
result["filename"] = os.path.basename(name)
results.append(result)
except Exception as e:
results.append({"filename": os.path.basename(name), "success": False, "data": {}, "error": str(e)})
except zipfile.BadZipFile:
return [{"filename": "(zip)", "success": False, "data": {}, "error": "不是有效的 ZIP 文件"}]
except Exception as e:
return [{"filename": "(zip)", "success": False, "data": {}, "error": f"ZIP 解析失败: {e}"}]
return results
# ── 内部工具函数 ──────────────────────────────────────
# 常见发票 XML 标签名映射(兼容多种数电票 XML 格式)
_FIELD_PATTERNS = {
"merchant": [
"SalesName", "SellerName", "销售方名称", "销方名称",
"开票方", "Seller", "salername", "xfmc",
],
"buyer": [
"BuyerName", "PurchaserName", "购买方名称", "购方名称",
"Buyer", "buyername", "gfmc",
],
"amount": [
"TotalAmount", "Amount", "InvoiceAmount", "金额",
"合计金额", "价税合计", "jshj", "hjje",
],
"tax_amount": [
"TotalTax", "TaxAmount", "Tax", "税额",
"合计税额", "hjse",
],
"date": [
"IssueDate", "InvoiceDate", "BillingDate", "开票日期",
"kprq",
],
"invoice_code": [
"InvoiceCode", "发票代码", "fpdm",
],
"invoice_number": [
"InvoiceNumber", "InvoiceNo", "发票号码", "fphm",
],
"items": [
"GoodsName", "ItemName", "商品名称", "货物名称", "spmc",
],
"tax_rate": [
"TaxRate", "税率", "sl",
],
"remark": [
"Remark", "备注", "bz",
],
}
def _extract_from_xml_text(xml_text: str) -> Optional[dict]:
"""从 XML 文本中用多种策略提取发票字段。"""
result: dict = {}
# 策略 1: 正则匹配 <TagName>Value</TagName> 格式
for field, tag_names in _FIELD_PATTERNS.items():
for tag in tag_names:
# 匹配 <Tag>value</Tag> 或 <ns:Tag>value</ns:Tag>
pattern = rf'<(?:\w+:)?{re.escape(tag)}[^>]*>([^<]+)</(?:\w+:)?{re.escape(tag)}>'
match = re.search(pattern, xml_text, re.IGNORECASE)
if match:
value = match.group(1).strip()
if value:
# 数字字段转数值
if field in ("amount", "tax_amount"):
try:
result[field] = float(value)
except ValueError:
result[field] = value
else:
result[field] = value
break # 找到一个就跳到下一个字段
# 策略 2: 尝试 ElementTree 解析
if not result:
try:
# 移除 XML 声明中可能的编码问题
cleaned = re.sub(r'<\?xml[^?]*\?>', '', xml_text).strip()
if cleaned:
root = ET.fromstring(cleaned)
_extract_from_element(root, result)
except ET.ParseError:
pass
return result if result else None
def _extract_from_element(elem: ET.Element, result: dict, depth: int = 0):
"""递归遍历 XML 元素树提取字段。"""
if depth > 10:
return
tag_local = elem.tag.split("}")[-1] if "}" in elem.tag else elem.tag
for field, tag_names in _FIELD_PATTERNS.items():
if field not in result:
for tn in tag_names:
if tag_local.lower() == tn.lower():
text = (elem.text or "").strip()
if text:
if field in ("amount", "tax_amount"):
try:
result[field] = float(text)
except ValueError:
result[field] = text
else:
result[field] = text
break
for child in elem:
_extract_from_element(child, result, depth + 1)
+22 -28
View File
@@ -72,11 +72,12 @@ async def ocr_image(
"messages": [
{
"role": "user",
"content": "/no_think\n" + prompt,
"content": prompt,
"images": [image_base64], # Ollama vision 格式
},
],
"stream": False,
"think": False, # 关闭思考模式:稳定输出、避免死循环、提速 2-5x
"options": {
"temperature": 0.1,
"num_predict": 2000,
@@ -87,19 +88,18 @@ async def ocr_image(
async with httpx.AsyncClient(timeout=120.0) as client:
resp = await client.post(url, json=payload)
if resp.status_code != 200:
print(f"[OCR] 3090 返回 {resp.status_code}: {resp.text[:200]}")
return {"success": False, "data": {}, "error": f"VL 模型返回 {resp.status_code}"}
detail = resp.text[:200]
print(f"[OCR] 3090 返回 {resp.status_code}: {detail}")
if "model runner" in detail:
return {"success": False, "data": {}, "error": "AI OCR 模型进程崩溃,请联系管理员重启 Ollama 服务"}
return {"success": False, "data": {}, "error": f"AI OCR 服务异常 (HTTP {resp.status_code}),请稍后重试"}
data = resp.json()
# Qwen3.5 的 CoT 推理放在 message.thinking,最终结果在 message.content
content = data.get("message", {}).get("content", "")
thinking = data.get("message", {}).get("thinking", "")
# 优先从 content 提取 JSON,回退到 thinking
for text_source in [content, thinking]:
if not text_source:
continue
cleaned = re.sub(r'<think>.*?</think>', '', text_source, flags=re.DOTALL).strip()
# 关闭思考模式后,结果直接在 content(无 thinking 字段)
if content:
cleaned = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
json_match = re.search(r'\{[\s\S]*\}', cleaned)
if json_match:
try:
@@ -107,16 +107,14 @@ async def ocr_image(
print(f"[OCR] 解析成功: {list(result.keys())}")
return {"success": True, "data": result}
except json.JSONDecodeError:
continue
pass
# 没有提取 JSON,返回原始文本
raw = content or thinking
print(f"[OCR] 未能提取 JSON, 内容长度: content={len(content)}, thinking={len(thinking)}")
return {"success": True, "data": {"raw_text": raw[:2000]}}
print(f"[OCR] 未能提取 JSON, content 长度: {len(content)}")
return {"success": True, "data": {"raw_text": content[:2000]}}
except httpx.TimeoutException:
print("[OCR] 3090 超时(60s")
return {"success": False, "data": {}, "error": "VL 模型响应超时"}
print("[OCR] 3090 超时(120s")
return {"success": False, "data": {}, "error": "AI OCR 响应超时(120s),模型可能负载过高,请稍后重试"}
except json.JSONDecodeError as e:
print(f"[OCR] JSON 解析失败: {e}")
return {"success": False, "data": {}, "error": f"JSON 解析失败: {e}"}
@@ -172,11 +170,11 @@ async def extract_invoice_from_text(
"messages": [
{
"role": "user",
"content": f"/no_think\n{prompt}\n\n--- 以下是发票文本内容 ---\n\n{truncated}",
# 不传 images —— 纯文本模式
"content": f"{prompt}\n\n--- 以下是发票文本内容 ---\n\n{truncated}",
},
],
"stream": False,
"think": False, # 关闭思考模式
"options": {
"temperature": 0.1,
"num_predict": 2000,
@@ -192,12 +190,9 @@ async def extract_invoice_from_text(
data = resp.json()
content = data.get("message", {}).get("content", "")
thinking = data.get("message", {}).get("thinking", "")
for text_source in [content, thinking]:
if not text_source:
continue
cleaned = re.sub(r'<think>.*?</think>', '', text_source, flags=re.DOTALL).strip()
if content:
cleaned = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
json_match = re.search(r'\{[\s\S]*\}', cleaned)
if json_match:
try:
@@ -205,11 +200,10 @@ async def extract_invoice_from_text(
print(f"[TextExtract] AI 提取成功: {list(result.keys())}")
return {"success": True, "data": result}
except json.JSONDecodeError:
continue
pass
raw = content or thinking
print(f"[TextExtract] 未能提取 JSON, 内容: {raw[:200]}")
return {"success": True, "data": {"raw_text": raw[:2000]}}
print(f"[TextExtract] 未能提取 JSON, content: {content[:200]}")
return {"success": True, "data": {"raw_text": content[:2000]}}
except httpx.TimeoutException:
print("[TextExtract] 3090 超时")
+266
View File
@@ -0,0 +1,266 @@
"""
OCR 后台 Worker — asyncio 协程,FastAPI lifespan 启动
策略 C: 工作时间限流(1并发 + 60s间隔)17:00-20:00 BJT 全速
"""
from __future__ import annotations
import asyncio
import os
import uuid
from datetime import datetime, timedelta
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import async_session_factory
from app.models.finance import FinInvoicePool, FinOcrTask
class OcrWorker:
"""后台 OCR 任务处理器"""
def __init__(self):
self.running = False
self.current_task_id: uuid.UUID | None = None
self._task: asyncio.Task | None = None
def start(self):
self.running = True
self._task = asyncio.create_task(self._run_loop())
print("[OcrWorker] 启动 — 策略 C: 工作时间限流, 17-20 BJT 全速")
async def stop(self):
self.running = False
if self._task:
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
print("[OcrWorker] 已停止")
async def _run_loop(self):
"""主循环:每 10 秒检查一次队列"""
while self.running:
try:
task = await self._pick_next_task()
if task:
await self._process_task(task)
# 限流:非高峰期间隔 60s
if not self._is_peak_time():
await asyncio.sleep(60)
else:
await asyncio.sleep(5)
else:
await asyncio.sleep(10)
except asyncio.CancelledError:
break
except Exception as e:
print(f"[OcrWorker] 循环异常: {e}")
await asyncio.sleep(30)
def _is_peak_time(self) -> bool:
"""17:00-20:00 BJT = 09:00-12:00 UTC"""
utc_hour = datetime.utcnow().hour
return 9 <= utc_hour < 12
async def _pick_next_task(self) -> dict | None:
"""从 DB 获取优先级最高的 pending 任务"""
async with async_session_factory() as db:
stmt = (
select(FinOcrTask)
.where(
FinOcrTask.status == "pending",
FinOcrTask.is_deleted.is_(False),
FinOcrTask.retry_count < FinOcrTask.max_retries,
)
.order_by(FinOcrTask.priority, FinOcrTask.created_at)
.limit(1)
)
task = (await db.execute(stmt)).scalar_one_or_none()
if not task:
return None
# 标记为 processing
task.status = "processing"
task.updated_at = datetime.utcnow()
await db.commit()
self.current_task_id = task.id
return {
"id": task.id,
"file_url": task.file_url,
"file_ext": task.file_ext,
"original_name": task.original_name,
"uploader_id": task.uploader_id,
"company_id": task.company_id,
"inv_type": task.inv_type,
"retry_count": task.retry_count,
}
async def _process_task(self, task_info: dict):
"""执行 OCR 并更新"""
task_id = task_info["id"]
file_url = task_info["file_url"]
file_ext = task_info["file_ext"]
print(f"[OcrWorker] 处理任务 {task_id} ({task_info['original_name']}, {file_ext})")
try:
# 读取文件
file_path = file_url.lstrip("/")
if not os.path.exists(file_path):
await self._mark_failed(task_id, f"文件不存在: {file_path}")
return
with open(file_path, "rb") as f:
file_bytes = f.read()
ocr_data = {}
message = ""
# PDF 处理
if file_ext == ".pdf":
ocr_data, message = await self._process_pdf(file_bytes)
# 图片处理
elif file_ext in (".png", ".jpg", ".jpeg"):
ocr_data, message = await self._process_image(file_bytes)
else:
await self._mark_failed(task_id, f"不支持的文件格式: {file_ext}")
return
if ocr_data and (ocr_data.get("merchant") or ocr_data.get("amount")):
# OCR 成功 → 自动入池
await self._mark_success_and_pool(task_id, task_info, ocr_data)
print(f"[OcrWorker] ✅ {task_info['original_name']} 入池成功")
else:
# OCR 完成但没提取到关键字段
await self._mark_failed(
task_id,
message or "AI 未能提取发票关键字段(开票方/金额),请手动录入",
ocr_data,
)
except Exception as e:
print(f"[OcrWorker] ❌ 任务 {task_id} 异常: {e}")
await self._mark_failed(task_id, str(e))
self.current_task_id = None
async def _process_pdf(self, file_bytes: bytes) -> tuple[dict, str]:
"""PDF: 先尝试文本提取,失败降级 Vision OCR"""
try:
import fitz
doc = fitz.open(stream=file_bytes, filetype="pdf")
text = ""
for page in doc:
text += page.get_text() + "\n"
doc.close()
text = text.strip()
if len(text) > 50:
from app.services.ocr_service import extract_invoice_from_text
result = await extract_invoice_from_text(text, "invoice")
if result.get("success") and result.get("data"):
return result["data"], "PDF 文本解析成功"
# 降级: 扫描件 → Vision OCR
doc2 = fitz.open(stream=file_bytes, filetype="pdf")
pix = doc2[0].get_pixmap(dpi=150)
ocr_bytes = pix.tobytes("png")
doc2.close()
return await self._vision_ocr(ocr_bytes)
except Exception as e:
return {}, f"PDF 处理失败: {e}"
async def _process_image(self, file_bytes: bytes) -> tuple[dict, str]:
"""图片: Vision OCR"""
return await self._vision_ocr(file_bytes)
async def _vision_ocr(self, image_bytes: bytes) -> tuple[dict, str]:
"""调用 3090 Vision OCR"""
import base64
from app.services.ocr_service import ocr_image
image_b64 = base64.b64encode(image_bytes).decode("utf-8")
result = await ocr_image(image_b64, "invoice")
if result.get("success"):
return result.get("data", {}), "Vision OCR 成功"
return {}, result.get("error", "OCR 失败")
async def _mark_success_and_pool(self, task_id: uuid.UUID, task_info: dict, ocr_data: dict):
"""标记成功 + 自动入池"""
async with async_session_factory() as db:
merchant = ocr_data.get("merchant") or ocr_data.get("merchant_name") or "AI 提取)"
amount = 0
try:
amount = float(ocr_data.get("amount", 0))
except (ValueError, TypeError):
pass
invoice_date_str = ocr_data.get("date")
invoice_date = None
if invoice_date_str:
try:
from datetime import date as dt_date
invoice_date = dt_date.fromisoformat(invoice_date_str)
except ValueError:
pass
inv = FinInvoicePool(
uploader_id=task_info["uploader_id"],
company_id=task_info["company_id"],
file_url=task_info["file_url"],
merchant_name=merchant,
amount=amount,
invoice_date=invoice_date,
type=task_info["inv_type"],
ai_extracted_data=ocr_data,
is_used=False,
)
db.add(inv)
await db.flush()
await db.execute(
update(FinOcrTask)
.where(FinOcrTask.id == task_id)
.values(
status="success",
ocr_result=ocr_data,
invoice_pool_id=inv.id,
error_message=None,
updated_at=datetime.utcnow(),
)
)
await db.commit()
async def _mark_failed(self, task_id: uuid.UUID, error: str, partial_data: dict | None = None):
"""标记失败 + retry_count+1"""
async with async_session_factory() as db:
task = (await db.execute(
select(FinOcrTask).where(FinOcrTask.id == task_id)
)).scalar_one_or_none()
if not task:
return
new_retry = task.retry_count + 1
new_status = "failed" if new_retry >= task.max_retries else "pending"
await db.execute(
update(FinOcrTask)
.where(FinOcrTask.id == task_id)
.values(
status=new_status,
retry_count=new_retry,
error_message=error,
ocr_result=partial_data or task.ocr_result,
updated_at=datetime.utcnow(),
)
)
await db.commit()
if new_status == "pending":
print(f"[OcrWorker] ⚠️ 任务 {task_id}{new_retry} 次重试入队")
else:
print(f"[OcrWorker] ❌ 任务 {task_id} 已达最大重试次数,标记失败")
# 单例
ocr_worker = OcrWorker()
+14 -4
View File
@@ -16,6 +16,7 @@ from app.core.exceptions import BizException, ForbiddenException, NotFoundExcept
from app.models.crm import CrmCustomer
from app.models.erp import ProductSku
from app.models.order import ErpOrder, ErpOrderItem
from app.models.sys import SysUser
from app.schemas.auth import CurrentUserPayload
from app.schemas.order import (
OrderBriefResponse,
@@ -156,6 +157,7 @@ async def create_order(
db: AsyncSession,
user: CurrentUserPayload,
body: OrderCreate,
company_id: uuid.UUID,
) -> OrderResponse:
# 校验客户存在
cust = (
@@ -193,6 +195,7 @@ async def create_order(
order_no=order_no,
customer_id=body.customer_id,
salesperson_id=user.user_id,
company_id=company_id,
total_amount=total,
shipping_state="pending",
payment_state="unpaid",
@@ -236,8 +239,11 @@ async def list_orders(
shipping_state: str | None = None,
payment_state: str | None = None,
keyword: str | None = None,
company_id: uuid.UUID | None = None,
) -> OrderListResponse:
where: list[Any] = [ErpOrder.is_deleted.is_(False)]
if company_id:
where.append(ErpOrder.company_id == company_id)
if user.data_scope == "self":
where.append(ErpOrder.salesperson_id == user.user_id)
@@ -284,13 +290,17 @@ async def get_order(
db: AsyncSession,
user: CurrentUserPayload,
order_id: uuid.UUID,
company_id: uuid.UUID | None = None,
) -> OrderResponse:
where_clause = [
ErpOrder.id == order_id,
ErpOrder.is_deleted.is_(False),
]
if company_id:
where_clause.append(ErpOrder.company_id == company_id)
order = (
await db.execute(
select(ErpOrder).where(
ErpOrder.id == order_id,
ErpOrder.is_deleted.is_(False),
)
select(ErpOrder).where(*where_clause)
)
).scalar_one_or_none()
if order is None:
+91 -26
View File
@@ -14,7 +14,8 @@ from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import BizException, NotFoundException
from app.models.erp import InventoryFlow, ProductCategory, ProductSku
from app.models.erp import ErpSkuInventory, InventoryFlow, ProductCategory, ProductSku
from app.models.sys import SysUser
from app.schemas.auth import CurrentUserPayload
from app.schemas.erp import (
CategoryCreate,
@@ -31,7 +32,10 @@ from app.schemas.erp import (
# ── ORM → Response ───────────────────────────────────────
def _sku_to_response(s: ProductSku) -> SkuResponse:
def _sku_to_response(
s: ProductSku,
inv: ErpSkuInventory | None = None,
) -> SkuResponse:
return SkuResponse(
id=s.id,
sku_code=s.sku_code,
@@ -40,8 +44,8 @@ def _sku_to_response(s: ProductSku) -> SkuResponse:
category_name=s.category.name if s.category else None,
spec=s.spec,
standard_price=float(s.standard_price or 0),
stock_qty=float(s.stock_qty or 0),
warning_threshold=float(s.warning_threshold or 0),
stock_qty=float(inv.stock_qty) if inv else 0.0,
warning_threshold=float(inv.warning_threshold) if inv else 0.0,
unit=s.unit,
status=s.status,
created_at=s.created_at,
@@ -200,11 +204,13 @@ async def delete_category(db: AsyncSession, cat_id: uuid.UUID) -> None:
async def list_skus(
db: AsyncSession,
company_id: uuid.UUID,
page: int = 1,
size: int = 20,
category_id: uuid.UUID | None = None,
keyword: str | None = None,
) -> SkuListResponse:
"""LEFT JOIN erp_sku_inventory 获取当前公司库存,COALESCE 兜底为 0"""
where: list[Any] = [ProductSku.is_deleted.is_(False)]
if category_id:
where.append(ProductSku.category_id == category_id)
@@ -218,24 +224,31 @@ async def list_skus(
await db.execute(select(func.count()).select_from(ProductSku).where(*where))
).scalar() or 0
# LEFT JOIN erp_sku_inventory 带出当前公司库存
stmt = (
select(ProductSku)
select(ProductSku, ErpSkuInventory)
.outerjoin(
ErpSkuInventory,
(ErpSkuInventory.sku_id == ProductSku.id)
& (ErpSkuInventory.company_id == company_id),
)
.where(*where)
.order_by(ProductSku.created_at.desc())
.offset((page - 1) * size)
.limit(size)
)
rows = (await db.execute(stmt)).scalars().all()
rows = (await db.execute(stmt)).all()
return SkuListResponse(
total=total,
items=[_sku_to_response(s) for s in rows],
items=[_sku_to_response(sku, inv) for sku, inv in rows],
page=page,
size=size,
)
async def create_sku(db: AsyncSession, body: SkuCreate) -> SkuResponse:
"""创建 SKU(不创建库存行,LEFT JOIN 查询自动兜底为 0)"""
exists = (
await db.execute(
select(ProductSku.id).where(
@@ -253,8 +266,6 @@ async def create_sku(db: AsyncSession, body: SkuCreate) -> SkuResponse:
category_id=body.category_id,
spec=body.spec,
standard_price=body.standard_price,
stock_qty=body.stock_qty,
warning_threshold=body.warning_threshold,
unit=body.unit,
status=body.status,
)
@@ -299,7 +310,9 @@ async def create_inventory_flow(
db: AsyncSession,
user: CurrentUserPayload,
body: InventoryFlowCreate,
company_id: uuid.UUID,
) -> InventoryFlowResponse:
"""库存变更(upsert erp_sku_inventory + 写流水)"""
sku = (
await db.execute(
select(ProductSku).where(
@@ -310,35 +323,74 @@ async def create_inventory_flow(
if sku is None:
raise NotFoundException("产品 SKU 不存在")
if body.change_qty < 0:
current_stock = float(sku.stock_qty or 0)
if current_stock + body.change_qty < 0:
raise BizException(
message=f"库存不足:当前库存 {current_stock},请求出库 {abs(body.change_qty)}"
)
try:
async with db.begin_nested():
# ── upsert: 查找或创建当前公司的库存行 ──
inv = (
await db.execute(
select(ErpSkuInventory)
.where(
ErpSkuInventory.sku_id == body.sku_id,
ErpSkuInventory.company_id == company_id,
)
.with_for_update()
)
).scalar_one_or_none()
if inv is None:
# 首次操作该 SKU:自动创建 0 库存行
inv = ErpSkuInventory(
sku_id=body.sku_id,
company_id=company_id,
stock_qty=0,
warning_threshold=0,
)
db.add(inv)
await db.flush()
# 重新锁行
inv = (
await db.execute(
select(ErpSkuInventory)
.where(ErpSkuInventory.id == inv.id)
.with_for_update()
)
).scalar_one()
# ── 校验库存 ──
current_stock = float(inv.stock_qty or 0)
if body.change_qty < 0 and current_stock + body.change_qty < 0:
raise BizException(
message=f"库存不足:当前库存 {current_stock},请求出库 {abs(body.change_qty)}"
)
# ── 更新库存 ──
await db.execute(
update(ErpSkuInventory)
.where(ErpSkuInventory.id == inv.id)
.values(
stock_qty=ErpSkuInventory.stock_qty + Decimal(str(body.change_qty)),
updated_at=datetime.utcnow(),
)
)
# ── 写流水 ──
flow = InventoryFlow(
sku_id=body.sku_id,
company_id=company_id,
change_qty=body.change_qty,
reason=body.reason,
remark=body.remark,
purchase_unit_price=body.purchase_unit_price if body.change_qty > 0 else 0,
is_special_zero_cost=body.is_special_zero_cost if body.change_qty > 0 else False,
operator_id=user.user_id,
)
db.add(flow)
await db.flush()
await db.execute(
update(ProductSku)
.where(ProductSku.id == body.sku_id)
.values(
stock_qty=ProductSku.stock_qty + Decimal(str(body.change_qty)),
updated_at=datetime.utcnow(),
)
)
await db.commit()
except BizException:
await db.rollback()
raise
except Exception as e:
await db.rollback()
raise BizException(code=500, message=f"库存变更事务失败: {e!s}") from e
@@ -352,9 +404,11 @@ async def create_inventory_flow(
async def get_inventory_flows(
db: AsyncSession,
sku_id: uuid.UUID,
company_id: uuid.UUID,
page: int = 1,
size: int = 50,
) -> dict[str, Any]:
"""获取单个 SKU 在当前公司的库存流水"""
sku = (
await db.execute(
select(ProductSku).where(
@@ -365,8 +419,19 @@ async def get_inventory_flows(
if sku is None:
raise NotFoundException("产品 SKU 不存在")
# 查当前公司库存
inv = (
await db.execute(
select(ErpSkuInventory).where(
ErpSkuInventory.sku_id == sku_id,
ErpSkuInventory.company_id == company_id,
)
)
).scalar_one_or_none()
where: list[Any] = [
InventoryFlow.sku_id == sku_id,
InventoryFlow.company_id == company_id,
InventoryFlow.is_deleted.is_(False),
]
@@ -389,7 +454,7 @@ async def get_inventory_flows(
"total": total,
"sku_code": sku.sku_code,
"sku_name": sku.name,
"current_stock": float(sku.stock_qty or 0),
"current_stock": float(inv.stock_qty) if inv else 0.0,
"items": [_flow_to_response(f).model_dump(mode="json") for f in flows],
"page": page,
"size": size,
+226
View File
@@ -0,0 +1,226 @@
"""
库存与利润核算 Service 层
- MWA 入库事务(悲观锁 FOR UPDATE + 零元隔离)
- 订单利润快照
- 利润报表聚合
"""
from __future__ import annotations
import uuid
from datetime import datetime
from sqlalchemy import func, select, update, text
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import BizException, NotFoundException
from app.models.erp import ErpSkuInventory, InventoryFlow, ProductSku
from app.models.cost import ErpOrderItemCost
from app.models.order import ErpOrder, ErpOrderItem
from app.schemas.auth import CurrentUserPayload
# ── MWA 入库事务 ────────────────────────────────────────
async def process_inbound_with_mwa(
db: AsyncSession,
sku_id: uuid.UUID,
company_id: uuid.UUID,
qty: float,
purchase_unit_price: float,
operator_id: uuid.UUID | None = None,
remark: str | None = None,
is_special_zero_cost: bool = False,
) -> dict:
"""
入库事务(悲观锁 + MWA
1. SELECT ... FOR UPDATE 锁定库存行
2. 如果非零元特殊,计算新 MWA
3. 更新库存 + 记录流水
"""
# 悲观锁获取库存记录
inv_stmt = (
select(ErpSkuInventory)
.where(
ErpSkuInventory.sku_id == sku_id,
ErpSkuInventory.company_id == company_id,
)
.with_for_update()
)
inv = (await db.execute(inv_stmt)).scalar_one_or_none()
if inv is None:
# 首次入库,创建库存记录
inv = ErpSkuInventory(
sku_id=sku_id,
company_id=company_id,
stock_qty=0,
mwa_unit_cost=0,
)
db.add(inv)
await db.flush()
# 重新锁定
inv = (await db.execute(inv_stmt)).scalar_one()
old_qty = float(inv.stock_qty or 0)
old_mwa = float(inv.mwa_unit_cost or 0)
new_qty = old_qty + qty
# MWA 计算(零元特殊入库不参与)
if is_special_zero_cost or purchase_unit_price == 0:
new_mwa = old_mwa # 保持原有 MWA
else:
if new_qty > 0:
new_mwa = (old_qty * old_mwa + qty * purchase_unit_price) / new_qty
else:
new_mwa = purchase_unit_price
# 更新库存
inv.stock_qty = new_qty
inv.mwa_unit_cost = round(new_mwa, 4)
inv.updated_at = datetime.utcnow()
# 记录流水
flow = InventoryFlow(
sku_id=sku_id,
company_id=company_id,
flow_type="in",
change_qty=qty,
reason="purchase_in",
purchase_unit_price=purchase_unit_price,
is_special_zero_cost=is_special_zero_cost,
operator_id=operator_id,
remark=remark or f"入库 {qty} 件 @ ¥{purchase_unit_price}",
)
db.add(flow)
await db.commit()
return {
"sku_id": str(sku_id),
"old_qty": old_qty,
"new_qty": new_qty,
"old_mwa": old_mwa,
"new_mwa": round(new_mwa, 4),
"is_special_zero_cost": is_special_zero_cost,
}
# ── 订单明细成本快照 ────────────────────────────────────
async def snapshot_order_item_costs(
db: AsyncSession,
order_id: uuid.UUID,
company_id: uuid.UUID,
) -> list[dict]:
"""为订单的所有明细行锚定 MWA 成本快照"""
items_stmt = select(ErpOrderItem).where(
ErpOrderItem.order_id == order_id,
ErpOrderItem.is_deleted.is_(False),
)
items = (await db.execute(items_stmt)).scalars().all()
results = []
for item in items:
# 查当前 MWA
inv = (await db.execute(
select(ErpSkuInventory).where(
ErpSkuInventory.sku_id == item.sku_id,
ErpSkuInventory.company_id == company_id,
)
)).scalar_one_or_none()
mwa_cost = float(inv.mwa_unit_cost or 0) if inv else 0
sell_price = float(item.unit_price or 0)
qty = float(item.qty or 0)
profit = (sell_price - mwa_cost) * qty
profit_rate = (sell_price - mwa_cost) / sell_price if sell_price > 0 else 0
# 检查是否已有快照
existing = (await db.execute(
select(ErpOrderItemCost).where(
ErpOrderItemCost.order_item_id == item.id
)
)).scalar_one_or_none()
if existing:
existing.purchase_unit_price = mwa_cost
existing.profit_amount = round(profit, 2)
existing.profit_rate = round(profit_rate, 4)
else:
cost_snap = ErpOrderItemCost(
order_item_id=item.id,
purchase_unit_price=mwa_cost,
profit_amount=round(profit, 2),
profit_rate=round(profit_rate, 4),
)
db.add(cost_snap)
results.append({
"sku_id": str(item.sku_id),
"qty": qty,
"sell_price": sell_price,
"mwa_cost": mwa_cost,
"profit": round(profit, 2),
"profit_rate": round(profit_rate * 100, 2),
})
await db.commit()
return results
# ── 利润报表 ────────────────────────────────────────────
async def get_profit_report(
db: AsyncSession,
company_id: uuid.UUID,
start_date: str | None = None,
end_date: str | None = None,
) -> dict:
"""聚合利润报表"""
base_where = [
ErpOrder.company_id == company_id,
ErpOrder.is_deleted.is_(False),
]
if start_date:
base_where.append(ErpOrder.order_date >= start_date)
if end_date:
base_where.append(ErpOrder.order_date <= end_date)
# 聚合:每笔订单的利润
stmt = (
select(
ErpOrder.id.label("order_id"),
ErpOrder.order_no,
ErpOrder.order_date,
ErpOrder.total_amount,
func.sum(ErpOrderItemCost.profit_amount).label("total_profit"),
)
.join(ErpOrderItem, ErpOrderItem.order_id == ErpOrder.id)
.join(ErpOrderItemCost, ErpOrderItemCost.order_item_id == ErpOrderItem.id)
.where(*base_where)
.group_by(ErpOrder.id, ErpOrder.order_no, ErpOrder.order_date, ErpOrder.total_amount)
.order_by(ErpOrder.order_date.desc())
)
rows = (await db.execute(stmt)).all()
orders = []
total_revenue = 0
total_profit = 0
for r in rows:
revenue = float(r.total_amount or 0)
profit = float(r.total_profit or 0)
total_revenue += revenue
total_profit += profit
orders.append({
"order_id": str(r.order_id),
"order_no": r.order_no,
"order_date": r.order_date.isoformat() if r.order_date else None,
"revenue": revenue,
"profit": profit,
"profit_rate": round(profit / revenue * 100, 2) if revenue > 0 else 0,
})
return {
"total_revenue": round(total_revenue, 2),
"total_profit": round(total_profit, 2),
"overall_profit_rate": round(total_profit / total_revenue * 100, 2) if total_revenue > 0 else 0,
"orders": orders,
}
+10 -1
View File
@@ -12,6 +12,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import BizException, NotFoundException
from app.models.finance import FinSalesInvoice
from app.models.sys import SysUser
from app.models.crm import CrmCustomer
from app.schemas.auth import CurrentUserPayload
from app.schemas.sales_invoice import (
SalesInvoiceCreate,
@@ -45,6 +47,7 @@ async def create_invoice(
db: AsyncSession,
user: CurrentUserPayload,
body: SalesInvoiceCreate,
company_id: uuid.UUID | None = None,
) -> SalesInvoiceResponse:
# 检查发票号唯一性
existing = (await db.execute(
@@ -56,7 +59,7 @@ async def create_invoice(
if existing:
raise BizException(message=f"发票号 {body.invoice_number} 已存在")
inv = FinSalesInvoice(
kwargs: dict = dict(
issuer=body.issuer,
receiver_customer_id=body.receiver_customer_id,
invoice_number=body.invoice_number,
@@ -65,6 +68,9 @@ async def create_invoice(
remark=body.remark,
created_by=user.user_id,
)
if company_id is not None:
kwargs["company_id"] = company_id
inv = FinSalesInvoice(**kwargs)
db.add(inv)
await db.commit()
await db.refresh(inv)
@@ -80,8 +86,11 @@ async def list_invoices(
payment_status: str | None = None,
start_date: date | None = None,
end_date: date | None = None,
company_id: uuid.UUID | None = None,
) -> SalesInvoiceListResponse:
conditions = [FinSalesInvoice.is_deleted.is_(False)]
if company_id:
conditions.append(FinSalesInvoice.company_id == company_id)
if invoice_number:
conditions.append(FinSalesInvoice.invoice_number.ilike(f"%{invoice_number}%"))
+99 -5
View File
@@ -22,6 +22,7 @@ async def create_log(
customer_id: str | None = None,
contact_ids: list[str] | None = None,
log_date: date | None = None,
company_ids: list[uuid.UUID] | None = None,
) -> dict:
"""创建销售日志"""
log = SalesLog(
@@ -30,6 +31,7 @@ async def create_log(
contact_ids=contact_ids or [],
content=content,
log_date=log_date or date.today(),
involved_company_ids=company_ids or [],
)
db.add(log)
await db.commit()
@@ -46,9 +48,17 @@ async def list_logs(
user_id: str | None = None,
start_date: str | None = None,
end_date: str | None = None,
company_id: uuid.UUID | None = None,
) -> dict:
"""查询销售日志列表"""
"""查询销售日志列表(按 involved_company_ids 包含过滤)"""
from sqlalchemy.orm import aliased
from app.models.crm import CrmCustomer
from app.models.sys import SysUser
conditions = [SalesLog.is_deleted.is_(False)]
if company_id:
# ARRAY contains: 过滤涉及当前公司的日志
conditions.append(SalesLog.involved_company_ids.any(company_id))
# 数据权限
if user.data_scope == "self":
@@ -69,24 +79,107 @@ async def list_logs(
count_stmt = select(func.count()).select_from(SalesLog).where(where)
total = (await db.execute(count_stmt)).scalar() or 0
# data
# data — LEFT JOIN customer + user to get names
Author = aliased(SysUser)
stmt = (
select(SalesLog)
select(
SalesLog,
CrmCustomer.name.label("customer_name"),
Author.real_name.label("author_name"),
)
.outerjoin(CrmCustomer, SalesLog.customer_id == CrmCustomer.id)
.outerjoin(Author, SalesLog.salesperson_id == Author.id)
.where(where)
.order_by(desc(SalesLog.created_at))
.offset((page - 1) * size)
.limit(size)
)
rows = (await db.execute(stmt)).scalars().all()
rows = (await db.execute(stmt)).all()
items = []
for log, cust_name, auth_name in rows:
d = _to_dict(log)
d["customer_name"] = cust_name
d["author_name"] = auth_name
items.append(d)
return {
"total": total,
"page": page,
"size": size,
"items": [_to_dict(r) for r in rows],
"items": items,
}
async def update_log(
db: AsyncSession,
user: CurrentUserPayload,
log_id: uuid.UUID,
content: str | None = None,
customer_id: str | None = None,
contact_ids: list[str] | None = None,
log_date: str | None = None,
company_id: uuid.UUID | None = None,
) -> dict:
"""编辑销售日志 — 员工只能改自己的,管理员可改所有"""
from app.models.crm import CrmCustomer
from app.models.sys import SysUserCompany
log = await db.get(SalesLog, log_id)
if not log or log.is_deleted:
raise Exception("日志不存在")
# 权限检查
if user.data_scope != "all" and log.salesperson_id != user.user_id:
raise Exception("您无权编辑此日志")
if content is not None:
log.content = content
if contact_ids is not None:
log.contact_ids = contact_ids
if log_date is not None:
log.log_date = date.fromisoformat(log_date)
# 更新客户关联 + 自动重算 involved_company_ids
if customer_id is not None:
log.customer_id = uuid.UUID(customer_id) if customer_id else None
# 重新关联公司
resolved = set(log.involved_company_ids or [])
if company_id:
resolved.add(company_id)
if customer_id:
cust = await db.get(CrmCustomer, uuid.UUID(customer_id))
if cust and cust.owner_id:
stmt = select(SysUserCompany.company_id).where(
SysUserCompany.user_id == cust.owner_id
)
rows = (await db.execute(stmt)).scalars().all()
for cid in rows:
resolved.add(cid)
log.involved_company_ids = list(resolved)
await db.commit()
await db.refresh(log)
return _to_dict(log)
async def delete_log(
db: AsyncSession,
user: CurrentUserPayload,
log_id: uuid.UUID,
) -> None:
"""软删除销售日志 — 员工只能删自己的,管理员可删所有"""
log = await db.get(SalesLog, log_id)
if not log or log.is_deleted:
raise Exception("日志不存在")
if user.data_scope != "all" and log.salesperson_id != user.user_id:
raise Exception("您无权删除此日志")
log.is_deleted = True
await db.commit()
async def trigger_persona_workflow(
log_id: uuid.UUID,
customer_id: uuid.UUID,
@@ -157,6 +250,7 @@ def _to_dict(log: SalesLog) -> dict:
"salesperson_id": str(log.salesperson_id),
"customer_id": str(log.customer_id) if log.customer_id else None,
"contact_ids": log.contact_ids or [],
"involved_company_ids": [str(c) for c in (log.involved_company_ids or [])],
"content": log.content,
"log_date": log.log_date.isoformat() if log.log_date else None,
"ai_processed": log.ai_processed,
+49 -15
View File
@@ -10,9 +10,11 @@ from typing import Any
from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import BizException, ForbiddenException, NotFoundException
from app.models.erp import InventoryFlow, ProductSku
from app.models.erp import ErpSkuInventory, InventoryFlow, ProductSku
from app.models.order import ErpOrder, ErpOrderItem
from app.models.shipping import ErpShippingItem, ErpShippingRecord
from app.models.sys import SysUser
from app.models.crm import CrmCustomer
from app.schemas.auth import CurrentUserPayload
from app.schemas.shipping import (
ShippingBriefResponse, ShippingCreate, ShippingItemResponse,
@@ -75,10 +77,15 @@ def _check_shipping_access(order: ErpOrder, user: CurrentUserPayload) -> None:
async def create_shipping(
db: AsyncSession, user: CurrentUserPayload, body: ShippingCreate,
company_id: uuid.UUID,
) -> tuple[ShippingResponse, str]:
"""返回 (response, new_shipping_state)"""
"""返回 (response, new_shipping_state)。库存从 erp_sku_inventory 扣减"""
order = (await db.execute(
select(ErpOrder).where(ErpOrder.id == body.order_id, ErpOrder.is_deleted.is_(False))
select(ErpOrder).where(
ErpOrder.id == body.order_id,
ErpOrder.is_deleted.is_(False),
ErpOrder.company_id == company_id,
)
)).scalar_one_or_none()
if order is None:
raise NotFoundException("订单不存在")
@@ -114,6 +121,7 @@ async def create_shipping(
carrier=body.carrier, tracking_no=body.tracking_no,
status="transit", ship_date=body.ship_date or date.today(),
remark=body.remark, operator_id=user.user_id,
company_id=company_id,
)
db.add(record)
await db.flush()
@@ -125,22 +133,41 @@ async def create_shipping(
)
db.add(si)
result = await db.execute(
update(ProductSku).where(
ProductSku.id == item.sku_id,
ProductSku.stock_qty >= item.shipped_qty,
).values(
stock_qty=ProductSku.stock_qty - Decimal(str(item.shipped_qty)),
# ── 从 erp_sku_inventory 扣减库存(行锁) ──
inv = (
await db.execute(
select(ErpSkuInventory)
.where(
ErpSkuInventory.sku_id == item.sku_id,
ErpSkuInventory.company_id == company_id,
)
.with_for_update()
)
).scalar_one_or_none()
current_stock = float(inv.stock_qty) if inv else 0
if current_stock < item.shipped_qty:
raise BizException(
message=f"库存不足无法发货: SKU {item.sku_id}"
f"当前库存 {current_stock},请求出库 {item.shipped_qty}"
)
if inv is None:
# 不应出现此情况,但防御性处理
raise BizException(message=f"SKU {item.sku_id} 在当前公司无库存记录")
await db.execute(
update(ErpSkuInventory)
.where(ErpSkuInventory.id == inv.id)
.values(
stock_qty=ErpSkuInventory.stock_qty - Decimal(str(item.shipped_qty)),
updated_at=now,
)
)
if result.rowcount == 0:
sku = (await db.execute(select(ProductSku).where(ProductSku.id == item.sku_id))).scalar_one_or_none()
current_stock = float(sku.stock_qty) if sku else 0
raise BizException(message=f"库存不足无法发货: SKU {item.sku_id},当前库存 {current_stock},请求出库 {item.shipped_qty}")
db.add(InventoryFlow(
sku_id=item.sku_id, change_qty=-item.shipped_qty,
sku_id=item.sku_id, company_id=company_id,
change_qty=-item.shipped_qty,
reason="shipment", remark=f"订单发货出库 - 发货单 {shipping_no}",
operator_id=user.user_id,
))
@@ -178,8 +205,11 @@ async def list_shipping(
db: AsyncSession, user: CurrentUserPayload,
page: int = 1, size: int = 20,
order_no: str | None = None, tracking_no: str | None = None,
company_id: uuid.UUID | None = None,
) -> ShippingListResponse:
where: list[Any] = [ErpShippingRecord.is_deleted.is_(False)]
if company_id:
where.append(ErpShippingRecord.company_id == company_id)
if user.data_scope == "self":
my_orders = select(ErpOrder.id).where(ErpOrder.salesperson_id == user.user_id, ErpOrder.is_deleted.is_(False))
where.append(ErpShippingRecord.order_id.in_(my_orders))
@@ -203,9 +233,13 @@ async def list_shipping(
async def get_shipping_by_order(
db: AsyncSession, user: CurrentUserPayload, order_id: uuid.UUID,
company_id: uuid.UUID | None = None,
) -> dict[str, Any]:
where_clause = [ErpOrder.id == order_id, ErpOrder.is_deleted.is_(False)]
if company_id:
where_clause.append(ErpOrder.company_id == company_id)
order = (await db.execute(
select(ErpOrder).where(ErpOrder.id == order_id, ErpOrder.is_deleted.is_(False))
select(ErpOrder).where(*where_clause)
)).scalar_one_or_none()
if order is None:
raise NotFoundException("订单不存在")