815cbf9d8c
- 更新 .gitignore:全面覆盖环境变量、数据库、日志、缓存、上传文件 - 移除误跟踪的 server/venv/、crm_data.db、.env 文件 - 新增 server/.env.example 模板 - 新增合同管理、利润核算、AI教练等功能模块 - 新增 Playwright e2e 测试套件 - 前后端多项功能升级和 bug 修复
351 lines
14 KiB
Python
351 lines
14 KiB
Python
"""
|
|
SHBL-ERP CRM 测试基础设施
|
|
=========================
|
|
策略: 在 import app 之前, 先修改 settings 实例的 DATABASE_URL,
|
|
然后 monkey-patch app.db.database 模块, 再安全 import app。
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import sys
|
|
import uuid
|
|
from datetime import date, datetime
|
|
from typing import AsyncGenerator
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
|
|
# ═══════════════════════════════════════════════════════════
|
|
# Step 1: 设置环境变量 (在任何 app 代码被 import 之前)
|
|
# ═══════════════════════════════════════════════════════════
|
|
import os
|
|
# Pin a hermetic test configuration before any application module reads the
# environment: in-memory SQLite, a fixed JWT secret, debug off, and all Dify
# integration settings blanked out.
os.environ.update({
    "DATABASE_URL": "sqlite+aiosqlite://",
    "JWT_SECRET_KEY": "test-secret-key-for-unit-tests",
    "DEBUG": "false",
    "DIFY_API_BASE_URL": "",
    "DIFY_API_KEY": "",
    "DIFY_WORKFLOW_PERSONA_KEY": "",
    "DIFY_WORKFLOW_REPORT_KEY": "",
})
|
|
|
|
# ═══════════════════════════════════════════════════════════
|
|
# Step 2: 先 import config (轻量级, 不会触发 DB 连接)
|
|
# 然后 patch settings.DATABASE_URL
|
|
# ═══════════════════════════════════════════════════════════
|
|
from app.core.config import settings # noqa: E402
|
|
# 确保 settings 指向 sqlite
|
|
# Defensive override: force the already-constructed settings instance onto the
# in-memory SQLite URL too, regardless of how it resolved DATABASE_URL.
# NOTE(review): presumably guards against another config source (e.g. a .env
# file) winning over the env var set above — confirm how Settings loads values.
settings.DATABASE_URL = "sqlite+aiosqlite://"
|
|
|
|
# ═══════════════════════════════════════════════════════════
|
|
# Step 3: 预先构造 database 模块并注入 sys.modules,
|
|
# 绕开 database.py 模块级代码中的 PG 参数
|
|
# ═══════════════════════════════════════════════════════════
|
|
from sqlalchemy import event as sa_event
|
|
from sqlalchemy.ext.asyncio import (
|
|
AsyncSession,
|
|
async_sessionmaker,
|
|
create_async_engine,
|
|
)
|
|
from types import ModuleType
|
|
from collections.abc import AsyncGenerator as AsyncGenType
|
|
|
|
# ── PG → SQLite 类型编译适配 + bind/result processor ─────
|
|
from sqlalchemy.ext.compiler import compiles
|
|
from sqlalchemy.dialects.postgresql import JSONB, UUID as PG_UUID, ARRAY as PG_ARRAY
|
|
import json as _json
|
|
|
|
@compiles(JSONB, "sqlite")
def _compile_jsonb_sqlite(element, compiler, **_kw):
    """Emit ``TEXT`` for PostgreSQL JSONB columns when generating SQLite DDL.

    SQLite's type compiler has no JSONB rule of its own, so without this hook
    ``Base.metadata.create_all`` could not render these columns.
    """
    return "TEXT"
|
|
|
|
@compiles(PG_UUID, "sqlite")
def _compile_uuid_sqlite(element, compiler, **_kw):
    """Emit ``CHAR(36)`` for PostgreSQL UUID columns in SQLite DDL.

    36 characters is the length of the canonical hyphenated UUID text form.
    """
    return "CHAR(36)"
|
|
|
|
@compiles(PG_ARRAY, "sqlite")
def _compile_array_sqlite(element, compiler, **_kw):
    """Emit ``TEXT`` for PostgreSQL ARRAY columns in SQLite DDL.

    SQLite has no array type; array values are stored as serialized text.
    """
    return "TEXT"
|
|
|
|
# Make UUID(as_uuid=True) columns serialize/deserialize correctly on SQLite.
# Keep handles to SQLAlchemy's stock processors so the patched versions can
# delegate to them for every non-SQLite dialect.
_orig_uuid_bind, _orig_uuid_result = (
    PG_UUID.bind_processor,
    PG_UUID.result_processor,
)
|
|
|
|
def _uuid_bind_processor(self, dialect):
    """Return a bind processor for PostgreSQL ``UUID`` parameters.

    On SQLite, non-string values (e.g. ``uuid.UUID``) are converted with
    ``str()``. Strings and ``None`` pass through untouched. Any other dialect
    uses the original SQLAlchemy implementation saved in ``_orig_uuid_bind``.
    """
    if dialect.name != "sqlite":
        return _orig_uuid_bind(self, dialect)

    def process(value):
        if value is None or isinstance(value, str):
            return value
        return str(value)

    return process
|
|
|
|
def _uuid_result_processor(self, dialect, coltype):
    """Return a result processor for PostgreSQL ``UUID`` columns.

    On SQLite the column holds text, so values read back are rebuilt into
    ``uuid.UUID`` objects, but only for columns declared ``as_uuid=True``.
    Columns declared ``UUID(as_uuid=False)`` expect plain strings, so their
    values are returned unchanged. ``None`` and values that are already
    ``uuid.UUID`` are always returned as-is.

    Any other dialect uses the original implementation saved in
    ``_orig_uuid_result``.
    """
    if dialect.name == "sqlite":
        # Respect the column's as_uuid flag; default to True when the attribute
        # is missing so such types keep being converted.
        wants_uuid = getattr(self, "as_uuid", True)

        def process(value):
            if wants_uuid and value is not None and not isinstance(value, uuid.UUID):
                return uuid.UUID(str(value))
            return value

        return process

    return _orig_uuid_result(self, dialect, coltype)
|
|
|
|
# Install the SQLite-aware processors on the PostgreSQL UUID class itself, so
# every UUID column in this interpreter picks them up.
PG_UUID.bind_processor, PG_UUID.result_processor = (
    _uuid_bind_processor,
    _uuid_result_processor,
)
|
|
|
|
# Serialize JSONB values to JSON text on SQLite.
# Keep handles to the stock processors for non-SQLite dialects.
_orig_jsonb_bind, _orig_jsonb_result = (
    JSONB.bind_processor,
    JSONB.result_processor,
)
|
|
|
|
def _jsonb_bind_processor(self, dialect):
    """Return a bind processor for PostgreSQL ``JSONB`` parameters.

    On SQLite every non-NULL value, including a plain ``str``, is encoded with
    ``json.dumps`` (non-ASCII characters kept as-is). This mirrors SQLAlchemy's
    JSON type, which also encodes Python strings as JSON strings.

    Passing strings through unencoded made them ambiguous when decoded on read:
    the string ``"123"`` came back as the int ``123`` and ``"null"`` as ``None``.
    ``None`` is passed through as SQL NULL.

    Any other dialect uses the original implementation saved in
    ``_orig_jsonb_bind``.
    """
    if dialect.name == "sqlite":
        def process(value):
            if value is None:
                return None
            return _json.dumps(value, ensure_ascii=False)

        return process

    return _orig_jsonb_bind(self, dialect)
|
|
|
|
def _jsonb_result_processor(self, dialect, coltype):
    """Return a result processor for PostgreSQL ``JSONB`` columns.

    On SQLite, string values are decoded with ``json.loads``. A string that is
    not valid JSON is returned verbatim, and non-string values (including
    ``None``) are returned unchanged. Any other dialect uses the original
    implementation saved in ``_orig_jsonb_result``.
    """
    if dialect.name != "sqlite":
        return _orig_jsonb_result(self, dialect, coltype)

    def process(value):
        if not isinstance(value, str):
            return value
        try:
            return _json.loads(value)
        except _json.JSONDecodeError:
            return value

    return process
|
|
|
|
# Install the SQLite-aware processors on the PostgreSQL JSONB class itself.
JSONB.bind_processor, JSONB.result_processor = (
    _jsonb_bind_processor,
    _jsonb_result_processor,
)
|
|
|
|
# Serialize ARRAY values to JSON text on SQLite.
# Keep handles to the stock processors for non-SQLite dialects.
_orig_array_bind, _orig_array_result = (
    PG_ARRAY.bind_processor,
    PG_ARRAY.result_processor,
)
|
|
|
|
def _array_bind_processor(self, dialect):
    """Return a bind processor for PostgreSQL ``ARRAY`` parameters.

    SQLite has no array type, so on SQLite the value is stored as a JSON array.
    Elements JSON supports natively (str, int, float, bool, None, nested
    lists/dicts) keep their type. Anything else (e.g. ``uuid.UUID``,
    ``datetime``, ``Decimal``) is written as ``str(element)``. ``None`` and
    values that are already strings are passed through unchanged.

    Any other dialect uses the original implementation saved in
    ``_orig_array_bind``.
    """
    if dialect.name == "sqlite":
        def process(value):
            if value is None or isinstance(value, str):
                return value
            # default=str applies only to non-JSON-native elements, so ints,
            # bools and None are not turned into "1" / "True" / "None".
            return _json.dumps(list(value), ensure_ascii=False, default=str)

        return process

    # Delegate directly, like the UUID/JSONB patches. The saved original is a
    # plain function fetched from the class, so it has no ``__func__``
    # attribute; guarding on that would always skip it.
    return _orig_array_bind(self, dialect)
|
|
|
|
def _array_result_processor(self, dialect, coltype):
    """Return a result processor for PostgreSQL ``ARRAY`` columns.

    On SQLite, string values are decoded with ``json.loads``. A string that is
    not valid JSON is returned verbatim, and non-string values (including
    ``None``) are returned unchanged.

    Any other dialect uses the original implementation saved in
    ``_orig_array_result``.
    """
    if dialect.name == "sqlite":
        def process(value):
            if isinstance(value, str):
                try:
                    return _json.loads(value)
                except _json.JSONDecodeError:
                    return value
            return value

        return process

    # Delegate directly: the saved original is a plain function (fetched from
    # the class) with no ``__func__``, so a hasattr(..., "__func__") guard
    # would always skip it and drop ARRAY processing on other dialects.
    return _orig_array_result(self, dialect, coltype)
|
|
|
|
# Install the SQLite-aware processors on the PostgreSQL ARRAY class itself.
PG_ARRAY.bind_processor, PG_ARRAY.result_processor = (
    _array_bind_processor,
    _array_result_processor,
)
|
|
|
|
|
|
# Test engine: an in-memory SQLite database (URL has no file path), SQL echo off.
# NOTE(review): every test shares one database only if the aiosqlite dialect
# reuses a single connection for in-memory URLs (StaticPool); confirm for the
# installed SQLAlchemy version.
_test_engine = create_async_engine(url="sqlite+aiosqlite://", echo=False)
|
|
|
|
@sa_event.listens_for(_test_engine.sync_engine, "connect")
def _set_sqlite_pragma(dbapi_conn, connection_record):
    """Turn off SQLite foreign-key enforcement on every new test connection.

    Presumably this lets tests insert rows without having to populate every
    referenced table first.
    """
    cur = dbapi_conn.cursor()
    cur.execute("PRAGMA foreign_keys=OFF")
    cur.close()
|
|
|
|
# Session factory bound to the test engine. Objects are not expired on commit,
# so their attributes remain readable after commit.
_test_session_factory = async_sessionmaker(
    bind=_test_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
|
|
|
|
async def _test_get_db() -> AsyncGenType:
    """Stand-in for ``app.db.database.get_db`` backed by the SQLite test factory.

    Yields exactly one session. The ``client`` fixture uses this function as
    the key in ``app.dependency_overrides``.
    """
    # ``async with`` already closes the session on exit, including when the
    # generator is closed or an exception is thrown in at the ``yield``.
    # An extra ``finally: await session.close()`` would only close it twice.
    async with _test_session_factory() as session:
        yield session
|
|
|
|
# Build a stand-in for ``app.db.database``. It exposes the names the
# application imports from it (engine, async_session_factory, get_db), backed
# by the in-memory SQLite engine defined above.
_fake_db_mod = ModuleType("app.db.database")
_fake_db_mod.engine = _test_engine
_fake_db_mod.async_session_factory = _test_session_factory
_fake_db_mod.get_db = _test_get_db

# Register it so every later ``from app.db.database import ...`` resolves to the
# stand-in instead of executing the real module.
sys.modules["app.db.database"] = _fake_db_mod

if "app.db" not in sys.modules:
    # Placeholder parent package, created without running app/db/__init__.py.
    import importlib.util

    _fake_db_pkg = ModuleType("app.db")
    # Give it the real package search path so sibling submodules (app.db.<x>)
    # remain importable. Without ``__path__`` Python rejects them with
    # "'app.db' is not a package". find_spec locates the package without
    # executing it.
    _db_spec = importlib.util.find_spec("app.db")
    _fake_db_pkg.__path__ = (
        list(_db_spec.submodule_search_locations or [])
        if _db_spec is not None
        else []
    )
    sys.modules["app.db"] = _fake_db_pkg
    # Attach it to the parent so attribute access ``app.db`` works too.
    if "app" in sys.modules:
        sys.modules["app"].db = _fake_db_pkg

# Keep attribute access (``app.db.database``) consistent with sys.modules,
# whether the parent package is the placeholder or an already-loaded real one.
sys.modules["app.db"].database = _fake_db_mod
|
|
|
|
# ═══════════════════════════════════════════════════════════
|
|
# Step 4: 现在安全地 import app 及所有模型
|
|
# ═══════════════════════════════════════════════════════════
|
|
from app.core.security import create_access_token, hash_password # noqa: E402
|
|
from app.models.base import Base # noqa: E402
|
|
from app.main import app # noqa: E402
|
|
|
|
from app.models.sys import SysDepartment, SysRole, SysUser, SysCompany, SysUserCompany # noqa: F401
|
|
from app.models.crm import CrmCustomer, CrmContact # noqa: F401
|
|
from app.models.erp import ProductCategory, ProductSku, InventoryFlow, ErpSkuInventory # noqa: F401
|
|
from app.models.order import ErpOrder, ErpOrderItem # noqa: F401
|
|
from app.models.contract import ErpContract, ErpContractItem, ErpContractAttachment # noqa: F401
|
|
from app.models.finance import FinInvoicePool, FinExpenseRecord, FinExpenseDetail, FinSalesInvoice # noqa: F401
|
|
from app.models.shipping import ErpShippingRecord, ErpShippingItem # noqa: F401
|
|
from app.models.ai import AiChatSession, SalesLog, AiReportDraft # noqa: F401
|
|
from app.models.cost import ErpOrderItemCost # noqa: F401
|
|
|
|
from httpx import ASGITransport, AsyncClient
|
|
|
|
# ═══════════════════════════════════════════════════════════
|
|
# 固定 UUID
|
|
# ═══════════════════════════════════════════════════════════
|
|
# Deterministic primary keys shared by the seed data and the tests.
# uuid.UUID(int=N) renders as "00000000-0000-0000-0000-" followed by N as
# 12 hex digits, e.g. int=0x10 -> "00000000-0000-0000-0000-000000000010".
ADMIN_USER_ID = uuid.UUID(int=0x01)
SALES_USER_ID = uuid.UUID(int=0x02)
COMPANY_ID = uuid.UUID(int=0x10)
DEPT_ID = uuid.UUID(int=0x20)
ADMIN_ROLE_ID = uuid.UUID(int=0x30)
SALES_ROLE_ID = uuid.UUID(int=0x31)
CUSTOMER_ID = uuid.UUID(int=0x40)
SKU_ID = uuid.UUID(int=0x50)
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════
|
|
# Fixtures
|
|
# ═══════════════════════════════════════════════════════════
|
|
|
|
@pytest.fixture(scope="session")
def event_loop():
    """Provide one asyncio event loop for the whole test session.

    The loop is created once, shared by every test that requests it, and
    closed after the session ends.

    NOTE(review): recent pytest-asyncio releases deprecate overriding
    ``event_loop`` in favour of ``loop_scope``. Confirm against the pinned
    version.
    """
    shared_loop = asyncio.new_event_loop()
    yield shared_loop
    shared_loop.close()
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def _create_tables():
    """Session-scoped schema lifecycle: create all tables once, drop them at the end.

    After dropping the tables the test engine is disposed, releasing its
    pooled connections.
    """
    metadata = Base.metadata

    async with _test_engine.begin() as conn:
        await conn.run_sync(metadata.create_all)

    yield

    async with _test_engine.begin() as conn:
        await conn.run_sync(metadata.drop_all)
    await _test_engine.dispose()
|
|
|
|
|
|
@pytest_asyncio.fixture()
async def db_session(_create_tables) -> AsyncGenerator[AsyncSession, None]:
    """Give each test a session wrapped in its own connection-level transaction.

    Steps:
      - open a connection and BEGIN an outer transaction on it;
      - bind an ``AsyncSession`` to that connection;
      - after the test, close the session and roll the outer transaction back,
        discarding every change the test made.

    NOTE(review): this relies on SQLAlchemy's default ``join_transaction_mode``
    so that ``session.commit()`` inside tests/app code does not commit the outer
    transaction. Confirm for the installed SQLAlchemy version.
    """
    async with _test_engine.connect() as connection:
        outer_txn = await connection.begin()
        session = AsyncSession(bind=connection, expire_on_commit=False)
        try:
            yield session
        finally:
            await session.close()
            await outer_txn.rollback()
|
|
|
|
|
|
@pytest_asyncio.fixture()
async def client(_create_tables, db_session) -> AsyncGenerator[AsyncClient, None]:
    """HTTP client that calls the app in-process, sharing the test's DB session.

    The app's ``get_db`` dependency (the fake ``_test_get_db``) is overridden to
    yield ``db_session``, so requests see the test's seed data and their writes
    are rolled back with the test's transaction.
    """
    async def override_get_db():
        yield db_session

    app.dependency_overrides[_test_get_db] = override_get_db
    # try/finally: the overrides are global app state. Clear them even if
    # building or tearing down the client raises, so they cannot leak into
    # the next test.
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://testserver") as ac:
            yield ac
    finally:
        app.dependency_overrides.clear()
|
|
|
|
|
|
@pytest_asyncio.fixture()
async def seed_data(db_session: AsyncSession):
    """Insert the baseline rows every API test relies on and return their ids.

    Rows created:
      - one department, an admin role (data_scope "all") and a sales role
        (data_scope "self");
      - one company;
      - an admin user and a sales user, each linked to the company as default;
      - one customer owned by the sales user;
      - one product SKU.

    Returns:
        dict with keys "admin_user_id", "sales_user_id", "company_id",
        "customer_id" and "sku_id", mapped to the fixed UUIDs used here.
    """
    session = db_session

    # Organisation structure: department, roles, company.
    session.add(SysDepartment(id=DEPT_ID, name="销售部", sort_order=1))
    session.add_all([
        SysRole(
            id=ADMIN_ROLE_ID,
            role_name="管理员",
            data_scope="all",
            menu_keys=["dashboard", "customers", "orders", "contracts", "products",
                       "shipping", "finance", "settings", "logs", "reports"],
        ),
        SysRole(
            id=SALES_ROLE_ID,
            role_name="销售",
            data_scope="self",
            menu_keys=["dashboard", "customers", "orders", "logs"],
        ),
    ])
    session.add(SysCompany(
        id=COMPANY_ID,
        name="测试润滑油有限公司",
        code="TEST-CO",
        full_info={"full_name": "天津测试润滑油有限公司", "tax_id": "91120000XXXX"},
    ))

    # Login accounts. Pending rows are flushed before the user-company links
    # are added (presumably so the user rows exist first).
    session.add_all([
        SysUser(
            id=ADMIN_USER_ID, username="admin", password_hash=hash_password("admin123"),
            real_name="管理员", dept_id=DEPT_ID, role_id=ADMIN_ROLE_ID, status=1,
        ),
        SysUser(
            id=SALES_USER_ID, username="sales01", password_hash=hash_password("sales123"),
            real_name="销售01", dept_id=DEPT_ID, role_id=SALES_ROLE_ID, status=1,
        ),
    ])
    await session.flush()

    for member_id in (ADMIN_USER_ID, SALES_USER_ID):
        session.add(SysUserCompany(user_id=member_id, company_id=COMPANY_ID, is_default=True))

    # Business data: a customer owned by the sales user and one SKU.
    session.add(CrmCustomer(
        id=CUSTOMER_ID, name="中石化天津分公司", level="A",
        industry="石油化工", contact="张经理", phone="13800138000",
        owner_id=SALES_USER_ID,
    ))
    session.add(ProductSku(
        id=SKU_ID, sku_code="LUB-001", name="壳牌劲霸R4 15W-40",
        spec="18L/桶", standard_price=280.00, unit="桶",
    ))

    await session.commit()

    return {
        "admin_user_id": ADMIN_USER_ID,
        "sales_user_id": SALES_USER_ID,
        "company_id": COMPANY_ID,
        "customer_id": CUSTOMER_ID,
        "sku_id": SKU_ID,
    }
|
|
|
|
|
|
def make_auth_headers(user_id: uuid.UUID, company_id: uuid.UUID = COMPANY_ID) -> dict:
    """Build request headers that authenticate as *user_id* within *company_id*.

    Returns a dict with a Bearer token whose ``sub`` claim is the user id
    string, plus the ``X-Company-Id`` header selecting the company.
    """
    claims = {"sub": str(user_id)}
    bearer = create_access_token(data=claims)
    headers = {}
    headers["Authorization"] = f"Bearer {bearer}"
    headers["X-Company-Id"] = str(company_id)
    return headers
|
|
|
|
|
|
@pytest.fixture()
def admin_headers(seed_data) -> dict:
    """Auth headers for the seeded admin user.

    Depends on ``seed_data`` so the admin user row exists before requests are
    made. This is a plain synchronous fixture, so it uses ``pytest.fixture``
    rather than the asyncio-specific ``pytest_asyncio.fixture`` decorator.
    """
    return make_auth_headers(ADMIN_USER_ID)
|
|
|
|
|
|
@pytest.fixture()
def sales_headers(seed_data) -> dict:
    """Auth headers for the seeded sales user.

    Depends on ``seed_data`` so the sales user row exists before requests are
    made. This is a plain synchronous fixture, so it uses ``pytest.fixture``
    rather than the asyncio-specific ``pytest_asyncio.fixture`` decorator.
    """
    return make_auth_headers(SALES_USER_ID)
|