815cbf9d8c
- 更新 .gitignore:全面覆盖环境变量、数据库、日志、缓存、上传文件 - 移除误跟踪的 server/venv/、crm_data.db、.env 文件 - 新增 server/.env.example 模板 - 新增合同管理、利润核算、AI教练等功能模块 - 新增 Playwright e2e 测试套件 - 前后端多项功能升级和 bug 修复
267 lines
9.6 KiB
Python
267 lines
9.6 KiB
Python
"""
OCR background worker — an asyncio coroutine started from the FastAPI lifespan.

Strategy C: throttled during working hours (1 concurrent task + 60 s interval);
full speed during 17:00-20:00 BJT.
"""
|
||
from __future__ import annotations

import asyncio
import math
import os
import uuid
from datetime import date, datetime, timedelta, timezone

from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.database import async_session_factory
from app.models.finance import FinInvoicePool, FinOcrTask
|
||
|
||
|
||
class OcrWorker:
|
||
"""后台 OCR 任务处理器"""
|
||
|
||
def __init__(self):
|
||
self.running = False
|
||
self.current_task_id: uuid.UUID | None = None
|
||
self._task: asyncio.Task | None = None
|
||
|
||
def start(self):
|
||
self.running = True
|
||
self._task = asyncio.create_task(self._run_loop())
|
||
print("[OcrWorker] 启动 — 策略 C: 工作时间限流, 17-20 BJT 全速")
|
||
|
||
async def stop(self):
|
||
self.running = False
|
||
if self._task:
|
||
self._task.cancel()
|
||
try:
|
||
await self._task
|
||
except asyncio.CancelledError:
|
||
pass
|
||
print("[OcrWorker] 已停止")
|
||
|
||
async def _run_loop(self):
|
||
"""主循环:每 10 秒检查一次队列"""
|
||
while self.running:
|
||
try:
|
||
task = await self._pick_next_task()
|
||
if task:
|
||
await self._process_task(task)
|
||
# 限流:非高峰期间隔 60s
|
||
if not self._is_peak_time():
|
||
await asyncio.sleep(60)
|
||
else:
|
||
await asyncio.sleep(5)
|
||
else:
|
||
await asyncio.sleep(10)
|
||
except asyncio.CancelledError:
|
||
break
|
||
except Exception as e:
|
||
print(f"[OcrWorker] 循环异常: {e}")
|
||
await asyncio.sleep(30)
|
||
|
||
def _is_peak_time(self) -> bool:
|
||
"""17:00-20:00 BJT = 09:00-12:00 UTC"""
|
||
utc_hour = datetime.utcnow().hour
|
||
return 9 <= utc_hour < 12
|
||
|
||
async def _pick_next_task(self) -> dict | None:
|
||
"""从 DB 获取优先级最高的 pending 任务"""
|
||
async with async_session_factory() as db:
|
||
stmt = (
|
||
select(FinOcrTask)
|
||
.where(
|
||
FinOcrTask.status == "pending",
|
||
FinOcrTask.is_deleted.is_(False),
|
||
FinOcrTask.retry_count < FinOcrTask.max_retries,
|
||
)
|
||
.order_by(FinOcrTask.priority, FinOcrTask.created_at)
|
||
.limit(1)
|
||
)
|
||
task = (await db.execute(stmt)).scalar_one_or_none()
|
||
if not task:
|
||
return None
|
||
|
||
# 标记为 processing
|
||
task.status = "processing"
|
||
task.updated_at = datetime.utcnow()
|
||
await db.commit()
|
||
|
||
self.current_task_id = task.id
|
||
return {
|
||
"id": task.id,
|
||
"file_url": task.file_url,
|
||
"file_ext": task.file_ext,
|
||
"original_name": task.original_name,
|
||
"uploader_id": task.uploader_id,
|
||
"company_id": task.company_id,
|
||
"inv_type": task.inv_type,
|
||
"retry_count": task.retry_count,
|
||
}
|
||
|
||
async def _process_task(self, task_info: dict):
|
||
"""执行 OCR 并更新"""
|
||
task_id = task_info["id"]
|
||
file_url = task_info["file_url"]
|
||
file_ext = task_info["file_ext"]
|
||
print(f"[OcrWorker] 处理任务 {task_id} ({task_info['original_name']}, {file_ext})")
|
||
|
||
try:
|
||
# 读取文件
|
||
file_path = file_url.lstrip("/")
|
||
if not os.path.exists(file_path):
|
||
await self._mark_failed(task_id, f"文件不存在: {file_path}")
|
||
return
|
||
|
||
with open(file_path, "rb") as f:
|
||
file_bytes = f.read()
|
||
|
||
ocr_data = {}
|
||
message = ""
|
||
|
||
# PDF 处理
|
||
if file_ext == ".pdf":
|
||
ocr_data, message = await self._process_pdf(file_bytes)
|
||
# 图片处理
|
||
elif file_ext in (".png", ".jpg", ".jpeg"):
|
||
ocr_data, message = await self._process_image(file_bytes)
|
||
else:
|
||
await self._mark_failed(task_id, f"不支持的文件格式: {file_ext}")
|
||
return
|
||
|
||
if ocr_data and (ocr_data.get("merchant") or ocr_data.get("amount")):
|
||
# OCR 成功 → 自动入池
|
||
await self._mark_success_and_pool(task_id, task_info, ocr_data)
|
||
print(f"[OcrWorker] ✅ {task_info['original_name']} 入池成功")
|
||
else:
|
||
# OCR 完成但没提取到关键字段
|
||
await self._mark_failed(
|
||
task_id,
|
||
message or "AI 未能提取发票关键字段(开票方/金额),请手动录入",
|
||
ocr_data,
|
||
)
|
||
|
||
except Exception as e:
|
||
print(f"[OcrWorker] ❌ 任务 {task_id} 异常: {e}")
|
||
await self._mark_failed(task_id, str(e))
|
||
|
||
self.current_task_id = None
|
||
|
||
async def _process_pdf(self, file_bytes: bytes) -> tuple[dict, str]:
|
||
"""PDF: 先尝试文本提取,失败降级 Vision OCR"""
|
||
try:
|
||
import fitz
|
||
doc = fitz.open(stream=file_bytes, filetype="pdf")
|
||
text = ""
|
||
for page in doc:
|
||
text += page.get_text() + "\n"
|
||
doc.close()
|
||
text = text.strip()
|
||
|
||
if len(text) > 50:
|
||
from app.services.ocr_service import extract_invoice_from_text
|
||
result = await extract_invoice_from_text(text, "invoice")
|
||
if result.get("success") and result.get("data"):
|
||
return result["data"], "PDF 文本解析成功"
|
||
|
||
# 降级: 扫描件 → Vision OCR
|
||
doc2 = fitz.open(stream=file_bytes, filetype="pdf")
|
||
pix = doc2[0].get_pixmap(dpi=150)
|
||
ocr_bytes = pix.tobytes("png")
|
||
doc2.close()
|
||
return await self._vision_ocr(ocr_bytes)
|
||
|
||
except Exception as e:
|
||
return {}, f"PDF 处理失败: {e}"
|
||
|
||
async def _process_image(self, file_bytes: bytes) -> tuple[dict, str]:
|
||
"""图片: Vision OCR"""
|
||
return await self._vision_ocr(file_bytes)
|
||
|
||
async def _vision_ocr(self, image_bytes: bytes) -> tuple[dict, str]:
|
||
"""调用 3090 Vision OCR"""
|
||
import base64
|
||
from app.services.ocr_service import ocr_image
|
||
image_b64 = base64.b64encode(image_bytes).decode("utf-8")
|
||
result = await ocr_image(image_b64, "invoice")
|
||
if result.get("success"):
|
||
return result.get("data", {}), "Vision OCR 成功"
|
||
return {}, result.get("error", "OCR 失败")
|
||
|
||
async def _mark_success_and_pool(self, task_id: uuid.UUID, task_info: dict, ocr_data: dict):
|
||
"""标记成功 + 自动入池"""
|
||
async with async_session_factory() as db:
|
||
merchant = ocr_data.get("merchant") or ocr_data.get("merchant_name") or "(AI 提取)"
|
||
amount = 0
|
||
try:
|
||
amount = float(ocr_data.get("amount", 0))
|
||
except (ValueError, TypeError):
|
||
pass
|
||
invoice_date_str = ocr_data.get("date")
|
||
invoice_date = None
|
||
if invoice_date_str:
|
||
try:
|
||
from datetime import date as dt_date
|
||
invoice_date = dt_date.fromisoformat(invoice_date_str)
|
||
except ValueError:
|
||
pass
|
||
|
||
inv = FinInvoicePool(
|
||
uploader_id=task_info["uploader_id"],
|
||
company_id=task_info["company_id"],
|
||
file_url=task_info["file_url"],
|
||
merchant_name=merchant,
|
||
amount=amount,
|
||
invoice_date=invoice_date,
|
||
type=task_info["inv_type"],
|
||
ai_extracted_data=ocr_data,
|
||
is_used=False,
|
||
)
|
||
db.add(inv)
|
||
await db.flush()
|
||
|
||
await db.execute(
|
||
update(FinOcrTask)
|
||
.where(FinOcrTask.id == task_id)
|
||
.values(
|
||
status="success",
|
||
ocr_result=ocr_data,
|
||
invoice_pool_id=inv.id,
|
||
error_message=None,
|
||
updated_at=datetime.utcnow(),
|
||
)
|
||
)
|
||
await db.commit()
|
||
|
||
async def _mark_failed(self, task_id: uuid.UUID, error: str, partial_data: dict | None = None):
|
||
"""标记失败 + retry_count+1"""
|
||
async with async_session_factory() as db:
|
||
task = (await db.execute(
|
||
select(FinOcrTask).where(FinOcrTask.id == task_id)
|
||
)).scalar_one_or_none()
|
||
if not task:
|
||
return
|
||
|
||
new_retry = task.retry_count + 1
|
||
new_status = "failed" if new_retry >= task.max_retries else "pending"
|
||
|
||
await db.execute(
|
||
update(FinOcrTask)
|
||
.where(FinOcrTask.id == task_id)
|
||
.values(
|
||
status=new_status,
|
||
retry_count=new_retry,
|
||
error_message=error,
|
||
ocr_result=partial_data or task.ocr_result,
|
||
updated_at=datetime.utcnow(),
|
||
)
|
||
)
|
||
await db.commit()
|
||
if new_status == "pending":
|
||
print(f"[OcrWorker] ⚠️ 任务 {task_id} 第 {new_retry} 次重试入队")
|
||
else:
|
||
print(f"[OcrWorker] ❌ 任务 {task_id} 已达最大重试次数,标记失败")
|
||
|
||
|
||
# Module-level singleton instance.
ocr_worker = OcrWorker()
|