""" 产品与库存 Service 层 REST API 路由 和 MCP 工具 共用此层函数 """ from __future__ import annotations import uuid from datetime import datetime from decimal import Decimal from typing import Any from sqlalchemy import func, select, update from sqlalchemy.ext.asyncio import AsyncSession from app.core.exceptions import BizException, NotFoundException from app.models.erp import InventoryFlow, ProductCategory, ProductSku from app.schemas.auth import CurrentUserPayload from app.schemas.erp import ( CategoryCreate, CategoryNode, CategoryUpdate, InventoryFlowCreate, InventoryFlowResponse, SkuCreate, SkuListResponse, SkuResponse, SkuUpdate, ) # ── ORM → Response ─────────────────────────────────────── def _sku_to_response(s: ProductSku) -> SkuResponse: return SkuResponse( id=s.id, sku_code=s.sku_code, name=s.name, category_id=s.category_id, category_name=s.category.name if s.category else None, spec=s.spec, standard_price=float(s.standard_price or 0), stock_qty=float(s.stock_qty or 0), warning_threshold=float(s.warning_threshold or 0), unit=s.unit, status=s.status, created_at=s.created_at, updated_at=s.updated_at, ) def _flow_to_response(f: InventoryFlow) -> InventoryFlowResponse: return InventoryFlowResponse( id=f.id, sku_id=f.sku_id, sku_code=f.sku.sku_code if f.sku else None, sku_name=f.sku.name if f.sku else None, change_qty=float(f.change_qty), reason=f.reason, remark=f.remark, operator_id=f.operator_id, operator_name=f.operator.real_name if f.operator else None, created_at=f.created_at, ) def _build_tree( items: list[ProductCategory], parent_id: uuid.UUID | None = None, ) -> list[CategoryNode]: nodes: list[CategoryNode] = [] for item in items: if item.parent_id == parent_id: children = _build_tree(items, item.id) nodes.append( CategoryNode( id=item.id, parent_id=item.parent_id, name=item.name, sort_order=item.sort_order, children=children, ) ) nodes.sort(key=lambda n: n.sort_order) return nodes # ── Service Functions ──────────────────────────────────── async def get_category_tree(db: AsyncSession) -> list[dict[str, Any]]: stmt = ( select(ProductCategory) .where(ProductCategory.is_deleted.is_(False)) .order_by(ProductCategory.sort_order) ) categories = list((await db.execute(stmt)).scalars().all()) tree = _build_tree(categories, parent_id=None) return [n.model_dump(mode="json") for n in tree] async def create_category( db: AsyncSession, body: CategoryCreate, ) -> dict[str, Any]: if body.parent_id: parent = ( await db.execute( select(ProductCategory).where( ProductCategory.id == body.parent_id, ProductCategory.is_deleted.is_(False), ) ) ).scalar_one_or_none() if parent is None: raise NotFoundException("父级分类不存在") cat = ProductCategory( name=body.name, parent_id=body.parent_id, sort_order=body.sort_order, ) db.add(cat) await db.commit() await db.refresh(cat) return { "id": str(cat.id), "name": cat.name, "parent_id": str(cat.parent_id) if cat.parent_id else None, } async def update_category( db: AsyncSession, cat_id: uuid.UUID, body: CategoryUpdate, ) -> None: cat = ( await db.execute( select(ProductCategory).where( ProductCategory.id == cat_id, ProductCategory.is_deleted.is_(False), ) ) ).scalar_one_or_none() if cat is None: raise NotFoundException("分类不存在或已被删除") update_data = body.model_dump(exclude_unset=True) if not update_data: raise BizException(message="未提供任何需要更新的字段") update_data["updated_at"] = datetime.utcnow() await db.execute( update(ProductCategory).where(ProductCategory.id == cat_id).values(**update_data) ) await db.commit() async def delete_category(db: AsyncSession, cat_id: uuid.UUID) -> None: cat = ( await db.execute( select(ProductCategory).where( ProductCategory.id == cat_id, ProductCategory.is_deleted.is_(False), ) ) ).scalar_one_or_none() if cat is None: raise NotFoundException("分类不存在或已被删除") child_count = ( await db.execute( select(func.count()).select_from(ProductCategory).where( ProductCategory.parent_id == cat_id, ProductCategory.is_deleted.is_(False), ) ) ).scalar() or 0 if child_count > 0: raise BizException(message=f"该分类下有 {child_count} 个子分类,无法删除") sku_count = ( await db.execute( select(func.count()).select_from(ProductSku).where( ProductSku.category_id == cat_id, ProductSku.is_deleted.is_(False), ) ) ).scalar() or 0 if sku_count > 0: raise BizException(message=f"该分类下有 {sku_count} 个产品 SKU,无法删除") await db.execute( update(ProductCategory) .where(ProductCategory.id == cat_id) .values(is_deleted=True, updated_at=datetime.utcnow()) ) await db.commit() async def list_skus( db: AsyncSession, page: int = 1, size: int = 20, category_id: uuid.UUID | None = None, keyword: str | None = None, ) -> SkuListResponse: where: list[Any] = [ProductSku.is_deleted.is_(False)] if category_id: where.append(ProductSku.category_id == category_id) if keyword: where.append( ProductSku.name.ilike(f"%{keyword}%") | ProductSku.sku_code.ilike(f"%{keyword}%") ) total = ( await db.execute(select(func.count()).select_from(ProductSku).where(*where)) ).scalar() or 0 stmt = ( select(ProductSku) .where(*where) .order_by(ProductSku.created_at.desc()) .offset((page - 1) * size) .limit(size) ) rows = (await db.execute(stmt)).scalars().all() return SkuListResponse( total=total, items=[_sku_to_response(s) for s in rows], page=page, size=size, ) async def create_sku(db: AsyncSession, body: SkuCreate) -> SkuResponse: exists = ( await db.execute( select(ProductSku.id).where( ProductSku.sku_code == body.sku_code, ProductSku.is_deleted.is_(False), ) ) ).scalar_one_or_none() if exists: raise BizException(message=f"SKU 编码 '{body.sku_code}' 已存在") sku = ProductSku( sku_code=body.sku_code, name=body.name, category_id=body.category_id, spec=body.spec, standard_price=body.standard_price, stock_qty=body.stock_qty, warning_threshold=body.warning_threshold, unit=body.unit, status=body.status, ) db.add(sku) await db.commit() await db.refresh(sku) return _sku_to_response(sku) async def update_sku( db: AsyncSession, sku_id: uuid.UUID, body: SkuUpdate, ) -> SkuResponse: sku = ( await db.execute( select(ProductSku).where( ProductSku.id == sku_id, ProductSku.is_deleted.is_(False) ) ) ).scalar_one_or_none() if sku is None: raise NotFoundException("产品不存在或已被删除") update_data = body.model_dump(exclude_unset=True) if not update_data: raise BizException(message="未提供任何需要更新的字段") update_data["updated_at"] = datetime.utcnow() await db.execute( update(ProductSku).where(ProductSku.id == sku_id).values(**update_data) ) await db.commit() refreshed = ( await db.execute(select(ProductSku).where(ProductSku.id == sku_id)) ).scalar_one() return _sku_to_response(refreshed) async def create_inventory_flow( db: AsyncSession, user: CurrentUserPayload, body: InventoryFlowCreate, ) -> InventoryFlowResponse: sku = ( await db.execute( select(ProductSku).where( ProductSku.id == body.sku_id, ProductSku.is_deleted.is_(False) ) ) ).scalar_one_or_none() if sku is None: raise NotFoundException("产品 SKU 不存在") if body.change_qty < 0: current_stock = float(sku.stock_qty or 0) if current_stock + body.change_qty < 0: raise BizException( message=f"库存不足:当前库存 {current_stock},请求出库 {abs(body.change_qty)}" ) try: async with db.begin_nested(): flow = InventoryFlow( sku_id=body.sku_id, change_qty=body.change_qty, reason=body.reason, remark=body.remark, operator_id=user.user_id, ) db.add(flow) await db.flush() await db.execute( update(ProductSku) .where(ProductSku.id == body.sku_id) .values( stock_qty=ProductSku.stock_qty + Decimal(str(body.change_qty)), updated_at=datetime.utcnow(), ) ) await db.commit() except Exception as e: await db.rollback() raise BizException(code=500, message=f"库存变更事务失败: {e!s}") from e refreshed = ( await db.execute(select(InventoryFlow).where(InventoryFlow.id == flow.id)) ).scalar_one() return _flow_to_response(refreshed) async def get_inventory_flows( db: AsyncSession, sku_id: uuid.UUID, page: int = 1, size: int = 50, ) -> dict[str, Any]: sku = ( await db.execute( select(ProductSku).where( ProductSku.id == sku_id, ProductSku.is_deleted.is_(False) ) ) ).scalar_one_or_none() if sku is None: raise NotFoundException("产品 SKU 不存在") where: list[Any] = [ InventoryFlow.sku_id == sku_id, InventoryFlow.is_deleted.is_(False), ] total = ( await db.execute( select(func.count()).select_from(InventoryFlow).where(*where) ) ).scalar() or 0 stmt = ( select(InventoryFlow) .where(*where) .order_by(InventoryFlow.created_at.desc()) .offset((page - 1) * size) .limit(size) ) flows = (await db.execute(stmt)).scalars().all() return { "total": total, "sku_code": sku.sku_code, "sku_name": sku.name, "current_stock": float(sku.stock_qty or 0), "items": [_flow_to_response(f).model_dump(mode="json") for f in flows], "page": page, "size": size, }