""" 产品与库存 Service 层 REST API 路由 和 MCP 工具 共用此层函数 """ from __future__ import annotations import uuid from datetime import datetime from decimal import Decimal from typing import Any from sqlalchemy import func, select, update from sqlalchemy.ext.asyncio import AsyncSession from app.core.exceptions import BizException, NotFoundException from app.models.erp import ErpSkuInventory, InventoryFlow, ProductCategory, ProductSku from app.models.sys import SysUser from app.schemas.auth import CurrentUserPayload from app.schemas.erp import ( CategoryCreate, CategoryNode, CategoryUpdate, InventoryFlowCreate, InventoryFlowResponse, SkuCreate, SkuListResponse, SkuResponse, SkuUpdate, ) # ── ORM → Response ─────────────────────────────────────── def _sku_to_response( s: ProductSku, inv: ErpSkuInventory | None = None, ) -> SkuResponse: return SkuResponse( id=s.id, sku_code=s.sku_code, name=s.name, category_id=s.category_id, category_name=s.category.name if s.category else None, spec=s.spec, standard_price=float(s.standard_price or 0), stock_qty=float(inv.stock_qty) if inv else 0.0, warning_threshold=float(inv.warning_threshold) if inv else 0.0, unit=s.unit, status=s.status, created_at=s.created_at, updated_at=s.updated_at, ) def _flow_to_response(f: InventoryFlow) -> InventoryFlowResponse: return InventoryFlowResponse( id=f.id, sku_id=f.sku_id, sku_code=f.sku.sku_code if f.sku else None, sku_name=f.sku.name if f.sku else None, change_qty=float(f.change_qty), reason=f.reason, remark=f.remark, operator_id=f.operator_id, operator_name=f.operator.real_name if f.operator else None, created_at=f.created_at, ) def _build_tree( items: list[ProductCategory], parent_id: uuid.UUID | None = None, ) -> list[CategoryNode]: nodes: list[CategoryNode] = [] for item in items: if item.parent_id == parent_id: children = _build_tree(items, item.id) nodes.append( CategoryNode( id=item.id, parent_id=item.parent_id, name=item.name, sort_order=item.sort_order, children=children, ) ) nodes.sort(key=lambda n: n.sort_order) return nodes # ── Service Functions ──────────────────────────────────── async def get_category_tree(db: AsyncSession) -> list[dict[str, Any]]: stmt = ( select(ProductCategory) .where(ProductCategory.is_deleted.is_(False)) .order_by(ProductCategory.sort_order) ) categories = list((await db.execute(stmt)).scalars().all()) tree = _build_tree(categories, parent_id=None) return [n.model_dump(mode="json") for n in tree] async def create_category( db: AsyncSession, body: CategoryCreate, ) -> dict[str, Any]: if body.parent_id: parent = ( await db.execute( select(ProductCategory).where( ProductCategory.id == body.parent_id, ProductCategory.is_deleted.is_(False), ) ) ).scalar_one_or_none() if parent is None: raise NotFoundException("父级分类不存在") cat = ProductCategory( name=body.name, parent_id=body.parent_id, sort_order=body.sort_order, ) db.add(cat) await db.commit() await db.refresh(cat) return { "id": str(cat.id), "name": cat.name, "parent_id": str(cat.parent_id) if cat.parent_id else None, } async def update_category( db: AsyncSession, cat_id: uuid.UUID, body: CategoryUpdate, ) -> None: cat = ( await db.execute( select(ProductCategory).where( ProductCategory.id == cat_id, ProductCategory.is_deleted.is_(False), ) ) ).scalar_one_or_none() if cat is None: raise NotFoundException("分类不存在或已被删除") update_data = body.model_dump(exclude_unset=True) if not update_data: raise BizException(message="未提供任何需要更新的字段") update_data["updated_at"] = datetime.utcnow() await db.execute( update(ProductCategory).where(ProductCategory.id == cat_id).values(**update_data) ) await db.commit() async def delete_category(db: AsyncSession, cat_id: uuid.UUID) -> None: cat = ( await db.execute( select(ProductCategory).where( ProductCategory.id == cat_id, ProductCategory.is_deleted.is_(False), ) ) ).scalar_one_or_none() if cat is None: raise NotFoundException("分类不存在或已被删除") child_count = ( await db.execute( select(func.count()).select_from(ProductCategory).where( ProductCategory.parent_id == cat_id, ProductCategory.is_deleted.is_(False), ) ) ).scalar() or 0 if child_count > 0: raise BizException(message=f"该分类下有 {child_count} 个子分类,无法删除") sku_count = ( await db.execute( select(func.count()).select_from(ProductSku).where( ProductSku.category_id == cat_id, ProductSku.is_deleted.is_(False), ) ) ).scalar() or 0 if sku_count > 0: raise BizException(message=f"该分类下有 {sku_count} 个产品 SKU,无法删除") await db.execute( update(ProductCategory) .where(ProductCategory.id == cat_id) .values(is_deleted=True, updated_at=datetime.utcnow()) ) await db.commit() async def list_skus( db: AsyncSession, company_id: uuid.UUID, page: int = 1, size: int = 20, category_id: uuid.UUID | None = None, keyword: str | None = None, ) -> SkuListResponse: """LEFT JOIN erp_sku_inventory 获取当前公司库存,COALESCE 兜底为 0""" where: list[Any] = [ProductSku.is_deleted.is_(False)] if category_id: where.append(ProductSku.category_id == category_id) if keyword: where.append( ProductSku.name.ilike(f"%{keyword}%") | ProductSku.sku_code.ilike(f"%{keyword}%") ) total = ( await db.execute(select(func.count()).select_from(ProductSku).where(*where)) ).scalar() or 0 # LEFT JOIN erp_sku_inventory 带出当前公司库存 stmt = ( select(ProductSku, ErpSkuInventory) .outerjoin( ErpSkuInventory, (ErpSkuInventory.sku_id == ProductSku.id) & (ErpSkuInventory.company_id == company_id), ) .where(*where) .order_by(ProductSku.created_at.desc()) .offset((page - 1) * size) .limit(size) ) rows = (await db.execute(stmt)).all() return SkuListResponse( total=total, items=[_sku_to_response(sku, inv) for sku, inv in rows], page=page, size=size, ) async def create_sku(db: AsyncSession, body: SkuCreate) -> SkuResponse: """创建 SKU(不创建库存行,LEFT JOIN 查询自动兜底为 0)""" exists = ( await db.execute( select(ProductSku.id).where( ProductSku.sku_code == body.sku_code, ProductSku.is_deleted.is_(False), ) ) ).scalar_one_or_none() if exists: raise BizException(message=f"SKU 编码 '{body.sku_code}' 已存在") sku = ProductSku( sku_code=body.sku_code, name=body.name, category_id=body.category_id, spec=body.spec, standard_price=body.standard_price, unit=body.unit, status=body.status, ) db.add(sku) await db.commit() await db.refresh(sku) return _sku_to_response(sku) async def update_sku( db: AsyncSession, sku_id: uuid.UUID, body: SkuUpdate, ) -> SkuResponse: sku = ( await db.execute( select(ProductSku).where( ProductSku.id == sku_id, ProductSku.is_deleted.is_(False) ) ) ).scalar_one_or_none() if sku is None: raise NotFoundException("产品不存在或已被删除") update_data = body.model_dump(exclude_unset=True) if not update_data: raise BizException(message="未提供任何需要更新的字段") update_data["updated_at"] = datetime.utcnow() await db.execute( update(ProductSku).where(ProductSku.id == sku_id).values(**update_data) ) await db.commit() refreshed = ( await db.execute(select(ProductSku).where(ProductSku.id == sku_id)) ).scalar_one() return _sku_to_response(refreshed) async def create_inventory_flow( db: AsyncSession, user: CurrentUserPayload, body: InventoryFlowCreate, company_id: uuid.UUID, ) -> InventoryFlowResponse: """库存变更(upsert erp_sku_inventory + 写流水)""" sku = ( await db.execute( select(ProductSku).where( ProductSku.id == body.sku_id, ProductSku.is_deleted.is_(False) ) ) ).scalar_one_or_none() if sku is None: raise NotFoundException("产品 SKU 不存在") try: async with db.begin_nested(): # ── upsert: 查找或创建当前公司的库存行 ── inv = ( await db.execute( select(ErpSkuInventory) .where( ErpSkuInventory.sku_id == body.sku_id, ErpSkuInventory.company_id == company_id, ) .with_for_update() ) ).scalar_one_or_none() if inv is None: # 首次操作该 SKU:自动创建 0 库存行 inv = ErpSkuInventory( sku_id=body.sku_id, company_id=company_id, stock_qty=0, warning_threshold=0, ) db.add(inv) await db.flush() # 重新锁行 inv = ( await db.execute( select(ErpSkuInventory) .where(ErpSkuInventory.id == inv.id) .with_for_update() ) ).scalar_one() # ── 校验库存 ── current_stock = float(inv.stock_qty or 0) if body.change_qty < 0 and current_stock + body.change_qty < 0: raise BizException( message=f"库存不足:当前库存 {current_stock},请求出库 {abs(body.change_qty)}" ) # ── 更新库存 ── await db.execute( update(ErpSkuInventory) .where(ErpSkuInventory.id == inv.id) .values( stock_qty=ErpSkuInventory.stock_qty + Decimal(str(body.change_qty)), updated_at=datetime.utcnow(), ) ) # ── 写流水 ── flow = InventoryFlow( sku_id=body.sku_id, company_id=company_id, change_qty=body.change_qty, reason=body.reason, remark=body.remark, purchase_unit_price=body.purchase_unit_price if body.change_qty > 0 else 0, is_special_zero_cost=body.is_special_zero_cost if body.change_qty > 0 else False, operator_id=user.user_id, ) db.add(flow) await db.flush() await db.commit() except BizException: await db.rollback() raise except Exception as e: await db.rollback() raise BizException(code=500, message=f"库存变更事务失败: {e!s}") from e refreshed = ( await db.execute(select(InventoryFlow).where(InventoryFlow.id == flow.id)) ).scalar_one() return _flow_to_response(refreshed) async def get_inventory_flows( db: AsyncSession, sku_id: uuid.UUID, company_id: uuid.UUID, page: int = 1, size: int = 50, ) -> dict[str, Any]: """获取单个 SKU 在当前公司的库存流水""" sku = ( await db.execute( select(ProductSku).where( ProductSku.id == sku_id, ProductSku.is_deleted.is_(False) ) ) ).scalar_one_or_none() if sku is None: raise NotFoundException("产品 SKU 不存在") # 查当前公司库存 inv = ( await db.execute( select(ErpSkuInventory).where( ErpSkuInventory.sku_id == sku_id, ErpSkuInventory.company_id == company_id, ) ) ).scalar_one_or_none() where: list[Any] = [ InventoryFlow.sku_id == sku_id, InventoryFlow.company_id == company_id, InventoryFlow.is_deleted.is_(False), ] total = ( await db.execute( select(func.count()).select_from(InventoryFlow).where(*where) ) ).scalar() or 0 stmt = ( select(InventoryFlow) .where(*where) .order_by(InventoryFlow.created_at.desc()) .offset((page - 1) * size) .limit(size) ) flows = (await db.execute(stmt)).scalars().all() return { "total": total, "sku_code": sku.sku_code, "sku_name": sku.name, "current_stock": float(inv.stock_qty) if inv else 0.0, "items": [_flow_to_response(f).model_dump(mode="json") for f in flows], "page": page, "size": size, }