"""multi-tenant company isolation

Revision ID: a1b2c3d4e5f6
Revises: 03d8dcc2d72a
Create Date: 2026-03-18 08:45:00.000000

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.
revision: str = 'a1b2c3d4e5f6'
down_revision: Union[str, Sequence[str], None] = '03d8dcc2d72a'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

# Fixed UUID of the default company; all pre-existing rows are assigned to it.
DEFAULT_COMPANY_ID = 'aaaaaaaa-bbbb-cccc-dddd-eeeeeeee0001'

# Business tables that gain a NOT NULL ``company_id`` column (with a foreign
# key to ``sys_companies.id`` and an index). Shared by upgrade() and
# downgrade() so the two lists cannot drift apart.
TABLES_WITH_COMPANY_ID = (
    'erp_orders',
    'erp_inventory_flows',
    'erp_shipping_records',
    'fin_invoice_pool',
    'fin_expense_records',
    'finance_sales_invoices',
    'sales_logs',
)


def upgrade() -> None:
    """Introduce company-level (multi-tenant) isolation.

    Creates ``sys_companies``, ``sys_user_companies`` and the per-company
    ``erp_sku_inventory`` table, assigns all existing data to the default
    company, adds ``company_id`` to every table in TABLES_WITH_COMPANY_ID,
    and finally moves the inventory columns off ``erp_product_skus``.
    """
    # ═══════════════════════════════════════════════════════════════
    # Step 1: create the sys_companies table (company entities)
    # ═══════════════════════════════════════════════════════════════
    op.create_table(
        'sys_companies',
        sa.Column('id', postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column('name', sa.String(200), nullable=False),
        sa.Column('code', sa.String(50), unique=True, nullable=False),
        sa.Column('is_active', sa.Boolean(), server_default=sa.text('true'), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    )

    # Step 2: insert the default company
    op.execute(f"""
        INSERT INTO sys_companies (id, name, code, is_active)
        VALUES ('{DEFAULT_COMPANY_ID}', '天津硕博霖', 'SHBL-TJ', true)
    """)

    # ═══════════════════════════════════════════════════════════════
    # Step 3: create the sys_user_companies user <-> company link table
    # ═══════════════════════════════════════════════════════════════
    op.create_table(
        'sys_user_companies',
        sa.Column('id', postgresql.UUID(as_uuid=True), primary_key=True,
                  server_default=sa.text('gen_random_uuid()')),
        sa.Column('user_id', postgresql.UUID(as_uuid=True),
                  sa.ForeignKey('sys_users.id'), nullable=False),
        sa.Column('company_id', postgresql.UUID(as_uuid=True),
                  sa.ForeignKey('sys_companies.id'), nullable=False),
        sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.UniqueConstraint('user_id', 'company_id', name='uq_user_company'),
    )

    # Step 4: link every existing (non-deleted) user to the default company,
    # marking it as that user's default.
    op.execute(f"""
        INSERT INTO sys_user_companies (id, user_id, company_id, is_default)
        SELECT gen_random_uuid(), id, '{DEFAULT_COMPANY_ID}'::uuid, true
        FROM sys_users
        WHERE is_deleted = false
    """)

    # ═══════════════════════════════════════════════════════════════
    # Step 5: create the erp_sku_inventory per-company inventory table
    # ═══════════════════════════════════════════════════════════════
    op.create_table(
        'erp_sku_inventory',
        sa.Column('id', postgresql.UUID(as_uuid=True), primary_key=True,
                  server_default=sa.text('gen_random_uuid()')),
        sa.Column('sku_id', postgresql.UUID(as_uuid=True),
                  sa.ForeignKey('erp_product_skus.id'), nullable=False),
        sa.Column('company_id', postgresql.UUID(as_uuid=True),
                  sa.ForeignKey('sys_companies.id'), nullable=False),
        sa.Column('stock_qty', sa.Numeric(12, 2), server_default=sa.text('0'), nullable=False),
        sa.Column('warning_threshold', sa.Numeric(12, 2), server_default=sa.text('0'), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.UniqueConstraint('sku_id', 'company_id', name='uq_sku_company'),
    )
    op.create_index('ix_erp_sku_inventory_company_id', 'erp_sku_inventory', ['company_id'])

    # Step 6: copy inventory data from erp_product_skus into erp_sku_inventory
    # under the default company.
    # Every SKU row is migrated, soft-deleted ones included: Step 8 drops the
    # source columns, so any row skipped here would lose its stock data
    # permanently (and downgrade() could only restore it as 0).
    # COALESCE guards against NULLs in case the legacy columns were nullable,
    # since the new columns are NOT NULL.
    op.execute(f"""
        INSERT INTO erp_sku_inventory (id, sku_id, company_id, stock_qty, warning_threshold)
        SELECT gen_random_uuid(), id, '{DEFAULT_COMPANY_ID}'::uuid,
               COALESCE(stock_qty, 0), COALESCE(warning_threshold, 0)
        FROM erp_product_skus
    """)

    # ═══════════════════════════════════════════════════════════════
    # Step 7: add company_id to business tables
    #         (add as nullable -> backfill -> set NOT NULL)
    # ═══════════════════════════════════════════════════════════════
    for table in TABLES_WITH_COMPANY_ID:
        # Add the column as nullable first so existing rows are accepted.
        op.add_column(table, sa.Column('company_id', postgresql.UUID(as_uuid=True), nullable=True))

        # Backfill existing rows with the default company.
        op.execute(f"""
            UPDATE {table}
            SET company_id = '{DEFAULT_COMPANY_ID}'::uuid
            WHERE company_id IS NULL
        """)

        # Every row now has a value, so the constraint can be tightened.
        op.alter_column(table, 'company_id', nullable=False)

        op.create_foreign_key(
            f'fk_{table}_company_id',
            table, 'sys_companies',
            ['company_id'], ['id'],
        )

        op.create_index(f'ix_{table}_company_id', table, ['company_id'])

    # ═══════════════════════════════════════════════════════════════
    # Step 8: drop the inventory columns migrated off erp_product_skus
    # ═══════════════════════════════════════════════════════════════
    op.drop_column('erp_product_skus', 'stock_qty')
    op.drop_column('erp_product_skus', 'warning_threshold')


def downgrade() -> None:
    """Revert upgrade().

    Only the default company's inventory is copied back onto
    ``erp_product_skus``; inventory of any other company, the per-table
    ``company_id`` values, the user-company links and the companies
    themselves are discarded.
    """
    # Restore the inventory columns on erp_product_skus.
    op.add_column('erp_product_skus', sa.Column('stock_qty', sa.Numeric(12, 2),
                  server_default=sa.text('0'), nullable=False))
    op.add_column('erp_product_skus', sa.Column('warning_threshold', sa.Numeric(12, 2),
                  server_default=sa.text('0'), nullable=False))

    # Copy the default company's inventory back from erp_sku_inventory.
    # SKUs without a matching row keep the server default of 0.
    op.execute(f"""
        UPDATE erp_product_skus
        SET stock_qty = inv.stock_qty,
            warning_threshold = inv.warning_threshold
        FROM erp_sku_inventory inv
        WHERE erp_product_skus.id = inv.sku_id
          AND inv.company_id = '{DEFAULT_COMPANY_ID}'::uuid
    """)

    # Drop the company_id columns. The index and FK go first, in the reverse
    # of the order upgrade() created them.
    for table in TABLES_WITH_COMPANY_ID:
        op.drop_index(f'ix_{table}_company_id', table_name=table)
        op.drop_constraint(f'fk_{table}_company_id', table, type_='foreignkey')
        op.drop_column(table, 'company_id')

    # Drop the newly created tables; sys_companies goes last because the
    # other two reference it.
    op.drop_index('ix_erp_sku_inventory_company_id', table_name='erp_sku_inventory')
    op.drop_table('erp_sku_inventory')
    op.drop_table('sys_user_companies')
    op.drop_table('sys_companies')