""" SHBL-ERP CRM 测试基础设施 ========================= 策略: 在 import app 之前, 先修改 settings 实例的 DATABASE_URL, 然后 monkey-patch app.db.database 模块, 再安全 import app。 """ from __future__ import annotations import asyncio import sys import uuid from datetime import date, datetime from typing import AsyncGenerator import pytest import pytest_asyncio # ═══════════════════════════════════════════════════════════ # Step 1: 设置环境变量 (在任何 app 代码被 import 之前) # ═══════════════════════════════════════════════════════════ import os os.environ["DATABASE_URL"] = "sqlite+aiosqlite://" os.environ["JWT_SECRET_KEY"] = "test-secret-key-for-unit-tests" os.environ["DEBUG"] = "false" os.environ["DIFY_API_BASE_URL"] = "" os.environ["DIFY_API_KEY"] = "" os.environ["DIFY_WORKFLOW_PERSONA_KEY"] = "" os.environ["DIFY_WORKFLOW_REPORT_KEY"] = "" # ═══════════════════════════════════════════════════════════ # Step 2: 先 import config (轻量级, 不会触发 DB 连接) # 然后 patch settings.DATABASE_URL # ═══════════════════════════════════════════════════════════ from app.core.config import settings # noqa: E402 # 确保 settings 指向 sqlite settings.DATABASE_URL = "sqlite+aiosqlite://" # ═══════════════════════════════════════════════════════════ # Step 3: 预先构造 database 模块并注入 sys.modules, # 绕开 database.py 模块级代码中的 PG 参数 # ═══════════════════════════════════════════════════════════ from sqlalchemy import event as sa_event from sqlalchemy.ext.asyncio import ( AsyncSession, async_sessionmaker, create_async_engine, ) from types import ModuleType from collections.abc import AsyncGenerator as AsyncGenType # ── PG → SQLite 类型编译适配 + bind/result processor ───── from sqlalchemy.ext.compiler import compiles from sqlalchemy.dialects.postgresql import JSONB, UUID as PG_UUID, ARRAY as PG_ARRAY import json as _json @compiles(JSONB, "sqlite") def _compile_jsonb_sqlite(element, compiler, **kw): return "TEXT" @compiles(PG_UUID, "sqlite") def _compile_uuid_sqlite(element, compiler, **kw): return "CHAR(36)" @compiles(PG_ARRAY, "sqlite") def _compile_array_sqlite(element, compiler, **kw): return "TEXT" # 让 UUID(as_uuid=True) 在 SQLite 下正确序列化/反序列化 _orig_uuid_bind = PG_UUID.bind_processor _orig_uuid_result = PG_UUID.result_processor def _uuid_bind_processor(self, dialect): if dialect.name == "sqlite": def process(value): if value is not None: return str(value) if not isinstance(value, str) else value return value return process return _orig_uuid_bind(self, dialect) def _uuid_result_processor(self, dialect, coltype): if dialect.name == "sqlite": def process(value): if value is not None and not isinstance(value, uuid.UUID): return uuid.UUID(str(value)) return value return process return _orig_uuid_result(self, dialect, coltype) PG_UUID.bind_processor = _uuid_bind_processor PG_UUID.result_processor = _uuid_result_processor # JSONB 在 SQLite 下序列化为 JSON 字符串 _orig_jsonb_bind = JSONB.bind_processor _orig_jsonb_result = JSONB.result_processor def _jsonb_bind_processor(self, dialect): if dialect.name == "sqlite": def process(value): if value is not None: return _json.dumps(value, ensure_ascii=False) if not isinstance(value, str) else value return value return process return _orig_jsonb_bind(self, dialect) def _jsonb_result_processor(self, dialect, coltype): if dialect.name == "sqlite": def process(value): if value is not None and isinstance(value, str): try: return _json.loads(value) except _json.JSONDecodeError: return value return value return process return _orig_jsonb_result(self, dialect, coltype) JSONB.bind_processor = _jsonb_bind_processor JSONB.result_processor = _jsonb_result_processor # ARRAY 在 SQLite 下序列化为 JSON 字符串 _orig_array_bind = PG_ARRAY.bind_processor _orig_array_result = PG_ARRAY.result_processor def _array_bind_processor(self, dialect): if dialect.name == "sqlite": def process(value): if value is not None: return _json.dumps([str(v) for v in value], ensure_ascii=False) if not isinstance(value, str) else value return value return process if hasattr(_orig_array_bind, '__func__'): return _orig_array_bind(self, dialect) return None def _array_result_processor(self, dialect, coltype): if dialect.name == "sqlite": def process(value): if value is not None and isinstance(value, str): try: return _json.loads(value) except _json.JSONDecodeError: return value return value return process if hasattr(_orig_array_result, '__func__'): return _orig_array_result(self, dialect, coltype) return None PG_ARRAY.bind_processor = _array_bind_processor PG_ARRAY.result_processor = _array_result_processor # 创建测试 SQLite 引擎 _test_engine = create_async_engine("sqlite+aiosqlite://", echo=False) @sa_event.listens_for(_test_engine.sync_engine, "connect") def _set_sqlite_pragma(dbapi_conn, connection_record): cursor = dbapi_conn.cursor() cursor.execute("PRAGMA foreign_keys=OFF") cursor.close() _test_session_factory = async_sessionmaker( _test_engine, class_=AsyncSession, expire_on_commit=False ) async def _test_get_db() -> AsyncGenType: async with _test_session_factory() as session: try: yield session finally: await session.close() # 创建 fake database 模块 _fake_db_mod = ModuleType("app.db.database") _fake_db_mod.engine = _test_engine _fake_db_mod.async_session_factory = _test_session_factory _fake_db_mod.get_db = _test_get_db # 注入到 sys.modules,后续所有 `from app.db.database import ...` 都会用这个 sys.modules["app.db.database"] = _fake_db_mod # 确保 parent 包也存在 if "app.db" not in sys.modules: _fake_db_pkg = ModuleType("app.db") _fake_db_pkg.database = _fake_db_mod sys.modules["app.db"] = _fake_db_pkg # ═══════════════════════════════════════════════════════════ # Step 4: 现在安全地 import app 及所有模型 # ═══════════════════════════════════════════════════════════ from app.core.security import create_access_token, hash_password # noqa: E402 from app.models.base import Base # noqa: E402 from app.main import app # noqa: E402 from app.models.sys import SysDepartment, SysRole, SysUser, SysCompany, SysUserCompany # noqa: F401 from app.models.crm import CrmCustomer, CrmContact # noqa: F401 from app.models.erp import ProductCategory, ProductSku, InventoryFlow, ErpSkuInventory # noqa: F401 from app.models.order import ErpOrder, ErpOrderItem # noqa: F401 from app.models.contract import ErpContract, ErpContractItem, ErpContractAttachment # noqa: F401 from app.models.finance import FinInvoicePool, FinExpenseRecord, FinExpenseDetail, FinSalesInvoice # noqa: F401 from app.models.shipping import ErpShippingRecord, ErpShippingItem # noqa: F401 from app.models.ai import AiChatSession, SalesLog, AiReportDraft # noqa: F401 from app.models.cost import ErpOrderItemCost # noqa: F401 from httpx import ASGITransport, AsyncClient # ═══════════════════════════════════════════════════════════ # 固定 UUID # ═══════════════════════════════════════════════════════════ ADMIN_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000000001") SALES_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000000002") COMPANY_ID = uuid.UUID("00000000-0000-0000-0000-000000000010") DEPT_ID = uuid.UUID("00000000-0000-0000-0000-000000000020") ADMIN_ROLE_ID = uuid.UUID("00000000-0000-0000-0000-000000000030") SALES_ROLE_ID = uuid.UUID("00000000-0000-0000-0000-000000000031") CUSTOMER_ID = uuid.UUID("00000000-0000-0000-0000-000000000040") SKU_ID = uuid.UUID("00000000-0000-0000-0000-000000000050") # ═══════════════════════════════════════════════════════════ # Fixtures # ═══════════════════════════════════════════════════════════ @pytest.fixture(scope="session") def event_loop(): loop = asyncio.new_event_loop() yield loop loop.close() @pytest_asyncio.fixture(scope="session", loop_scope="session") async def _create_tables(): """session 级: 只建表一次""" async with _test_engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) yield async with _test_engine.begin() as conn: await conn.run_sync(Base.metadata.drop_all) await _test_engine.dispose() @pytest_asyncio.fixture() async def db_session(_create_tables) -> AsyncGenerator[AsyncSession, None]: """ 每个测试用例用独立的事务包裹: - 开启一个连接级事务 (begin) - 在该事务上创建 session - 测试结束后 rollback 连接级事务,所有改动都会被撤销 """ async with _test_engine.connect() as conn: trans = await conn.begin() session = AsyncSession(bind=conn, expire_on_commit=False) try: yield session finally: await session.close() await trans.rollback() @pytest_asyncio.fixture() async def client(_create_tables, db_session) -> AsyncGenerator[AsyncClient, None]: async def override_get_db(): yield db_session app.dependency_overrides[_test_get_db] = override_get_db transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://testserver") as ac: yield ac app.dependency_overrides.clear() @pytest_asyncio.fixture() async def seed_data(db_session: AsyncSession): """种子数据: 部门/角色/公司/用户/客户/SKU""" dept = SysDepartment(id=DEPT_ID, name="销售部", sort_order=1) db_session.add(dept) admin_role = SysRole( id=ADMIN_ROLE_ID, role_name="管理员", data_scope="all", menu_keys=["dashboard", "customers", "orders", "contracts", "products", "shipping", "finance", "settings", "logs", "reports"] ) sales_role = SysRole( id=SALES_ROLE_ID, role_name="销售", data_scope="self", menu_keys=["dashboard", "customers", "orders", "logs"] ) db_session.add_all([admin_role, sales_role]) company = SysCompany( id=COMPANY_ID, name="测试润滑油有限公司", code="TEST-CO", full_info={"full_name": "天津测试润滑油有限公司", "tax_id": "91120000XXXX"} ) db_session.add(company) admin_user = SysUser( id=ADMIN_USER_ID, username="admin", password_hash=hash_password("admin123"), real_name="管理员", dept_id=DEPT_ID, role_id=ADMIN_ROLE_ID, status=1 ) sales_user = SysUser( id=SALES_USER_ID, username="sales01", password_hash=hash_password("sales123"), real_name="销售01", dept_id=DEPT_ID, role_id=SALES_ROLE_ID, status=1 ) db_session.add_all([admin_user, sales_user]) await db_session.flush() db_session.add(SysUserCompany(user_id=ADMIN_USER_ID, company_id=COMPANY_ID, is_default=True)) db_session.add(SysUserCompany(user_id=SALES_USER_ID, company_id=COMPANY_ID, is_default=True)) customer = CrmCustomer( id=CUSTOMER_ID, name="中石化天津分公司", level="A", industry="石油化工", contact="张经理", phone="13800138000", owner_id=SALES_USER_ID ) db_session.add(customer) sku = ProductSku( id=SKU_ID, sku_code="LUB-001", name="壳牌劲霸R4 15W-40", spec="18L/桶", standard_price=280.00, unit="桶" ) db_session.add(sku) await db_session.commit() return { "admin_user_id": ADMIN_USER_ID, "sales_user_id": SALES_USER_ID, "company_id": COMPANY_ID, "customer_id": CUSTOMER_ID, "sku_id": SKU_ID, } def make_auth_headers(user_id: uuid.UUID, company_id: uuid.UUID = COMPANY_ID) -> dict: token = create_access_token(data={"sub": str(user_id)}) return { "Authorization": f"Bearer {token}", "X-Company-Id": str(company_id), } @pytest_asyncio.fixture() def admin_headers(seed_data) -> dict: return make_auth_headers(ADMIN_USER_ID) @pytest_asyncio.fixture() def sales_headers(seed_data) -> dict: return make_auth_headers(SALES_USER_ID)