v0.2.0: CRM/ERP 系统升级 - 清理 .gitignore 并移除误提交的 venv/env/db 文件
- 更新 .gitignore:全面覆盖环境变量、数据库、日志、缓存、上传文件 - 移除误跟踪的 server/venv/、crm_data.db、.env 文件 - 新增 server/.env.example 模板 - 新增合同管理、利润核算、AI教练等功能模块 - 新增 Playwright e2e 测试套件 - 前后端多项功能升级和 bug 修复
This commit is contained in:
@@ -0,0 +1,266 @@
|
||||
"""
|
||||
OCR 后台 Worker — asyncio 协程,FastAPI lifespan 启动
|
||||
策略 C: 工作时间限流(1并发 + 60s间隔),17:00-20:00 BJT 全速
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.database import async_session_factory
|
||||
from app.models.finance import FinInvoicePool, FinOcrTask
|
||||
|
||||
|
||||
class OcrWorker:
|
||||
"""后台 OCR 任务处理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.running = False
|
||||
self.current_task_id: uuid.UUID | None = None
|
||||
self._task: asyncio.Task | None = None
|
||||
|
||||
def start(self):
|
||||
self.running = True
|
||||
self._task = asyncio.create_task(self._run_loop())
|
||||
print("[OcrWorker] 启动 — 策略 C: 工作时间限流, 17-20 BJT 全速")
|
||||
|
||||
async def stop(self):
|
||||
self.running = False
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
print("[OcrWorker] 已停止")
|
||||
|
||||
async def _run_loop(self):
|
||||
"""主循环:每 10 秒检查一次队列"""
|
||||
while self.running:
|
||||
try:
|
||||
task = await self._pick_next_task()
|
||||
if task:
|
||||
await self._process_task(task)
|
||||
# 限流:非高峰期间隔 60s
|
||||
if not self._is_peak_time():
|
||||
await asyncio.sleep(60)
|
||||
else:
|
||||
await asyncio.sleep(5)
|
||||
else:
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"[OcrWorker] 循环异常: {e}")
|
||||
await asyncio.sleep(30)
|
||||
|
||||
def _is_peak_time(self) -> bool:
|
||||
"""17:00-20:00 BJT = 09:00-12:00 UTC"""
|
||||
utc_hour = datetime.utcnow().hour
|
||||
return 9 <= utc_hour < 12
|
||||
|
||||
async def _pick_next_task(self) -> dict | None:
|
||||
"""从 DB 获取优先级最高的 pending 任务"""
|
||||
async with async_session_factory() as db:
|
||||
stmt = (
|
||||
select(FinOcrTask)
|
||||
.where(
|
||||
FinOcrTask.status == "pending",
|
||||
FinOcrTask.is_deleted.is_(False),
|
||||
FinOcrTask.retry_count < FinOcrTask.max_retries,
|
||||
)
|
||||
.order_by(FinOcrTask.priority, FinOcrTask.created_at)
|
||||
.limit(1)
|
||||
)
|
||||
task = (await db.execute(stmt)).scalar_one_or_none()
|
||||
if not task:
|
||||
return None
|
||||
|
||||
# 标记为 processing
|
||||
task.status = "processing"
|
||||
task.updated_at = datetime.utcnow()
|
||||
await db.commit()
|
||||
|
||||
self.current_task_id = task.id
|
||||
return {
|
||||
"id": task.id,
|
||||
"file_url": task.file_url,
|
||||
"file_ext": task.file_ext,
|
||||
"original_name": task.original_name,
|
||||
"uploader_id": task.uploader_id,
|
||||
"company_id": task.company_id,
|
||||
"inv_type": task.inv_type,
|
||||
"retry_count": task.retry_count,
|
||||
}
|
||||
|
||||
async def _process_task(self, task_info: dict):
|
||||
"""执行 OCR 并更新"""
|
||||
task_id = task_info["id"]
|
||||
file_url = task_info["file_url"]
|
||||
file_ext = task_info["file_ext"]
|
||||
print(f"[OcrWorker] 处理任务 {task_id} ({task_info['original_name']}, {file_ext})")
|
||||
|
||||
try:
|
||||
# 读取文件
|
||||
file_path = file_url.lstrip("/")
|
||||
if not os.path.exists(file_path):
|
||||
await self._mark_failed(task_id, f"文件不存在: {file_path}")
|
||||
return
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
file_bytes = f.read()
|
||||
|
||||
ocr_data = {}
|
||||
message = ""
|
||||
|
||||
# PDF 处理
|
||||
if file_ext == ".pdf":
|
||||
ocr_data, message = await self._process_pdf(file_bytes)
|
||||
# 图片处理
|
||||
elif file_ext in (".png", ".jpg", ".jpeg"):
|
||||
ocr_data, message = await self._process_image(file_bytes)
|
||||
else:
|
||||
await self._mark_failed(task_id, f"不支持的文件格式: {file_ext}")
|
||||
return
|
||||
|
||||
if ocr_data and (ocr_data.get("merchant") or ocr_data.get("amount")):
|
||||
# OCR 成功 → 自动入池
|
||||
await self._mark_success_and_pool(task_id, task_info, ocr_data)
|
||||
print(f"[OcrWorker] ✅ {task_info['original_name']} 入池成功")
|
||||
else:
|
||||
# OCR 完成但没提取到关键字段
|
||||
await self._mark_failed(
|
||||
task_id,
|
||||
message or "AI 未能提取发票关键字段(开票方/金额),请手动录入",
|
||||
ocr_data,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[OcrWorker] ❌ 任务 {task_id} 异常: {e}")
|
||||
await self._mark_failed(task_id, str(e))
|
||||
|
||||
self.current_task_id = None
|
||||
|
||||
async def _process_pdf(self, file_bytes: bytes) -> tuple[dict, str]:
|
||||
"""PDF: 先尝试文本提取,失败降级 Vision OCR"""
|
||||
try:
|
||||
import fitz
|
||||
doc = fitz.open(stream=file_bytes, filetype="pdf")
|
||||
text = ""
|
||||
for page in doc:
|
||||
text += page.get_text() + "\n"
|
||||
doc.close()
|
||||
text = text.strip()
|
||||
|
||||
if len(text) > 50:
|
||||
from app.services.ocr_service import extract_invoice_from_text
|
||||
result = await extract_invoice_from_text(text, "invoice")
|
||||
if result.get("success") and result.get("data"):
|
||||
return result["data"], "PDF 文本解析成功"
|
||||
|
||||
# 降级: 扫描件 → Vision OCR
|
||||
doc2 = fitz.open(stream=file_bytes, filetype="pdf")
|
||||
pix = doc2[0].get_pixmap(dpi=150)
|
||||
ocr_bytes = pix.tobytes("png")
|
||||
doc2.close()
|
||||
return await self._vision_ocr(ocr_bytes)
|
||||
|
||||
except Exception as e:
|
||||
return {}, f"PDF 处理失败: {e}"
|
||||
|
||||
async def _process_image(self, file_bytes: bytes) -> tuple[dict, str]:
|
||||
"""图片: Vision OCR"""
|
||||
return await self._vision_ocr(file_bytes)
|
||||
|
||||
async def _vision_ocr(self, image_bytes: bytes) -> tuple[dict, str]:
|
||||
"""调用 3090 Vision OCR"""
|
||||
import base64
|
||||
from app.services.ocr_service import ocr_image
|
||||
image_b64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
result = await ocr_image(image_b64, "invoice")
|
||||
if result.get("success"):
|
||||
return result.get("data", {}), "Vision OCR 成功"
|
||||
return {}, result.get("error", "OCR 失败")
|
||||
|
||||
async def _mark_success_and_pool(self, task_id: uuid.UUID, task_info: dict, ocr_data: dict):
|
||||
"""标记成功 + 自动入池"""
|
||||
async with async_session_factory() as db:
|
||||
merchant = ocr_data.get("merchant") or ocr_data.get("merchant_name") or "(AI 提取)"
|
||||
amount = 0
|
||||
try:
|
||||
amount = float(ocr_data.get("amount", 0))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
invoice_date_str = ocr_data.get("date")
|
||||
invoice_date = None
|
||||
if invoice_date_str:
|
||||
try:
|
||||
from datetime import date as dt_date
|
||||
invoice_date = dt_date.fromisoformat(invoice_date_str)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
inv = FinInvoicePool(
|
||||
uploader_id=task_info["uploader_id"],
|
||||
company_id=task_info["company_id"],
|
||||
file_url=task_info["file_url"],
|
||||
merchant_name=merchant,
|
||||
amount=amount,
|
||||
invoice_date=invoice_date,
|
||||
type=task_info["inv_type"],
|
||||
ai_extracted_data=ocr_data,
|
||||
is_used=False,
|
||||
)
|
||||
db.add(inv)
|
||||
await db.flush()
|
||||
|
||||
await db.execute(
|
||||
update(FinOcrTask)
|
||||
.where(FinOcrTask.id == task_id)
|
||||
.values(
|
||||
status="success",
|
||||
ocr_result=ocr_data,
|
||||
invoice_pool_id=inv.id,
|
||||
error_message=None,
|
||||
updated_at=datetime.utcnow(),
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async def _mark_failed(self, task_id: uuid.UUID, error: str, partial_data: dict | None = None):
|
||||
"""标记失败 + retry_count+1"""
|
||||
async with async_session_factory() as db:
|
||||
task = (await db.execute(
|
||||
select(FinOcrTask).where(FinOcrTask.id == task_id)
|
||||
)).scalar_one_or_none()
|
||||
if not task:
|
||||
return
|
||||
|
||||
new_retry = task.retry_count + 1
|
||||
new_status = "failed" if new_retry >= task.max_retries else "pending"
|
||||
|
||||
await db.execute(
|
||||
update(FinOcrTask)
|
||||
.where(FinOcrTask.id == task_id)
|
||||
.values(
|
||||
status=new_status,
|
||||
retry_count=new_retry,
|
||||
error_message=error,
|
||||
ocr_result=partial_data or task.ocr_result,
|
||||
updated_at=datetime.utcnow(),
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
if new_status == "pending":
|
||||
print(f"[OcrWorker] ⚠️ 任务 {task_id} 第 {new_retry} 次重试入队")
|
||||
else:
|
||||
print(f"[OcrWorker] ❌ 任务 {task_id} 已达最大重试次数,标记失败")
|
||||
|
||||
|
||||
# Module-level singleton instance of the worker.
|
||||
ocr_worker = OcrWorker()
|
||||
Reference in New Issue
Block a user