Files
crm_project/server/app/services/ocr_worker.py
T
hankin 815cbf9d8c v0.2.0: CRM/ERP 系统升级 - 清理 .gitignore 并移除误提交的 venv/env/db 文件
- 更新 .gitignore:全面覆盖环境变量、数据库、日志、缓存、上传文件
- 移除误跟踪的 server/venv/、crm_data.db、.env 文件
- 新增 server/.env.example 模板
- 新增合同管理、利润核算、AI教练等功能模块
- 新增 Playwright e2e 测试套件
- 前后端多项功能升级和 bug 修复
2026-05-11 07:24:19 +00:00

267 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
OCR 后台 Worker — asyncio 协程,FastAPI lifespan 启动
策略 C: 工作时间限流(1并发 + 60s间隔)17:00-20:00 BJT 全速
"""
from __future__ import annotations
import asyncio
import os
import uuid
from datetime import datetime, timedelta
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import async_session_factory
from app.models.finance import FinInvoicePool, FinOcrTask
class OcrWorker:
    """Background OCR task processor driven by a single asyncio task.

    The worker polls ``FinOcrTask`` for pending rows, runs OCR on the file a
    row points at and, on success, inserts a ``FinInvoicePool`` row.

    Throttling ("strategy C"): after each processed task the loop sleeps
    ``OFF_PEAK_DELAY_SECONDS`` outside the peak window and
    ``PEAK_DELAY_SECONDS`` inside it. The peak window is 09:00-12:00 UTC,
    i.e. 17:00-20:00 BJT.
    """

    # Sleep when the queue is empty.
    IDLE_POLL_SECONDS = 10
    # Sleep after a processed task, outside / inside the peak window.
    OFF_PEAK_DELAY_SECONDS = 60
    PEAK_DELAY_SECONDS = 5
    # Sleep after an unexpected error in the polling loop.
    ERROR_BACKOFF_SECONDS = 30
    # Peak window as a half-open range of UTC hours: [9, 12).
    PEAK_UTC_HOUR_START = 9
    PEAK_UTC_HOUR_END = 12
    # Extensions routed to Vision OCR as images (compared lower-cased).
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self):
        # True while the polling loop is supposed to keep running.
        self.running = False
        # ID of the task currently being processed, None when idle.
        self.current_task_id: uuid.UUID | None = None
        # Handle of the polling task created by start().
        self._task: asyncio.Task | None = None

    def start(self):
        """Spawn the polling loop as an asyncio task.

        Must be called while an event loop is running. Calling it again
        while the loop task is still alive is a no-op, so a double start
        cannot create two competing polling loops.
        """
        if self._task is not None and not self._task.done():
            return
        self.running = True
        self._task = asyncio.create_task(self._run_loop())
        print("[OcrWorker] 启动 — 策略 C: 工作时间限流, 17-20 BJT 全速")

    async def stop(self):
        """Cancel the polling loop and wait for it to finish."""
        self.running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            # Drop the finished handle so start() can be called again.
            self._task = None
        print("[OcrWorker] 已停止")

    async def _run_loop(self):
        """Main loop: pick one task at a time, process it, then throttle."""
        while self.running:
            try:
                task = await self._pick_next_task()
                if task is None:
                    await asyncio.sleep(self.IDLE_POLL_SECONDS)
                    continue
                await self._process_task(task)
                if self._is_peak_time():
                    delay = self.PEAK_DELAY_SECONDS
                else:
                    delay = self.OFF_PEAK_DELAY_SECONDS
                await asyncio.sleep(delay)
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Top-level boundary: log and back off instead of killing the loop.
                print(f"[OcrWorker] 循环异常: {e}")
                await asyncio.sleep(self.ERROR_BACKOFF_SECONDS)

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Produces the same value ``datetime.utcnow()`` did, without calling
        that deprecated function. The result stays naive so values written
        to the database are unchanged.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None)

    def _is_peak_time(self, now: datetime | None = None) -> bool:
        """Return True when ``now`` falls in the 09:00-12:00 UTC peak window.

        Args:
            now: Time to test, expressed in UTC. Defaults to the current time.
        """
        current = self._utcnow() if now is None else now
        return self.PEAK_UTC_HOUR_START <= current.hour < self.PEAK_UTC_HOUR_END

    async def _pick_next_task(self) -> dict | None:
        """Claim the next pending task and return a plain-dict snapshot of it.

        Picks the first non-deleted ``pending`` row with retries left,
        ordered by ``priority`` then ``created_at``, marks it ``processing``
        and records its id in ``current_task_id``.

        Returns:
            A dict of the task fields needed for processing, or None when
            the queue is empty.
        """
        async with async_session_factory() as db:
            stmt = (
                select(FinOcrTask)
                .where(
                    FinOcrTask.status == "pending",
                    FinOcrTask.is_deleted.is_(False),
                    FinOcrTask.retry_count < FinOcrTask.max_retries,
                )
                .order_by(FinOcrTask.priority, FinOcrTask.created_at)
                .limit(1)
            )
            task = (await db.execute(stmt)).scalar_one_or_none()
            if task is None:
                return None
            # Snapshot BEFORE commit. If the session factory uses the default
            # expire_on_commit=True, reading attributes after commit triggers
            # a reload, which fails under asyncio. The factory config is not
            # visible here; reading first is safe either way.
            task_info = {
                "id": task.id,
                "file_url": task.file_url,
                "file_ext": task.file_ext,
                "original_name": task.original_name,
                "uploader_id": task.uploader_id,
                "company_id": task.company_id,
                "inv_type": task.inv_type,
                "retry_count": task.retry_count,
            }
            # Mark as processing.
            task.status = "processing"
            task.updated_at = self._utcnow()
            await db.commit()
        self.current_task_id = task_info["id"]
        return task_info

    async def _process_task(self, task_info: dict):
        """Run OCR for one claimed task and persist the outcome.

        Args:
            task_info: Snapshot dict produced by ``_pick_next_task``.

        The task is marked successful (and pooled) when OCR yields a
        ``merchant`` or ``amount``; otherwise it is marked failed.
        ``current_task_id`` is cleared on every exit path.
        """
        task_id = task_info["id"]
        file_url = task_info["file_url"]
        raw_ext = task_info["file_ext"]
        # Match extensions case-insensitively so ".PDF" / ".JPG" are accepted.
        ext = (raw_ext or "").lower()
        print(f"[OcrWorker] 处理任务 {task_id} ({task_info['original_name']}, {raw_ext})")
        try:
            # Read the file; file_url is treated as a path relative to the CWD.
            file_path = file_url.lstrip("/")
            if not os.path.exists(file_path):
                await self._mark_failed(task_id, f"文件不存在: {file_path}")
                return
            with open(file_path, "rb") as f:
                file_bytes = f.read()

            if ext == ".pdf":
                ocr_data, message = await self._process_pdf(file_bytes)
            elif ext in self.IMAGE_EXTENSIONS:
                ocr_data, message = await self._process_image(file_bytes)
            else:
                await self._mark_failed(task_id, f"不支持的文件格式: {raw_ext}")
                return

            if ocr_data and (ocr_data.get("merchant") or ocr_data.get("amount")):
                # OCR succeeded: add the invoice to the pool automatically.
                await self._mark_success_and_pool(task_id, task_info, ocr_data)
                print(f"[OcrWorker] ✅ {task_info['original_name']} 入池成功")
            else:
                # OCR finished but no key field was extracted.
                await self._mark_failed(
                    task_id,
                    message or "AI 未能提取发票关键字段(开票方/金额),请手动录入",
                    ocr_data,
                )
        except Exception as e:
            print(f"[OcrWorker] ❌ 任务 {task_id} 异常: {e}")
            await self._mark_failed(task_id, str(e))
        finally:
            # Always clear, including on early returns and when _mark_failed raises.
            self.current_task_id = None

    async def _process_pdf(self, file_bytes: bytes) -> tuple[dict, str]:
        """Extract invoice data from a PDF.

        Tries the embedded text layer first (when it has more than 50
        characters); if that yields nothing, renders the first page to PNG
        at 150 dpi and falls back to Vision OCR.

        Returns:
            ``(data, message)``; ``data`` is empty on failure. Exceptions
            are never raised, they are reported in ``message``.
        """
        try:
            import fitz

            doc = fitz.open(stream=file_bytes, filetype="pdf")
            try:
                text = "\n".join(page.get_text() for page in doc).strip()
                if len(text) > 50:
                    from app.services.ocr_service import extract_invoice_from_text

                    result = await extract_invoice_from_text(text, "invoice")
                    if result.get("success") and result.get("data"):
                        return result["data"], "PDF 文本解析成功"
                # Fallback for scanned PDFs: render page 1 for Vision OCR.
                pix = doc[0].get_pixmap(dpi=150)
                ocr_bytes = pix.tobytes("png")
            finally:
                # Close the document on every path, including errors.
                doc.close()
            return await self._vision_ocr(ocr_bytes)
        except Exception as e:
            return {}, f"PDF 处理失败: {e}"

    async def _process_image(self, file_bytes: bytes) -> tuple[dict, str]:
        """Extract invoice data from an image via Vision OCR."""
        return await self._vision_ocr(file_bytes)

    async def _vision_ocr(self, image_bytes: bytes) -> tuple[dict, str]:
        """Send the image, base64-encoded, to ``ocr_image``.

        Returns:
            ``(data, message)``; ``data`` is an empty dict when the service
            reports failure.
        """
        import base64
        from app.services.ocr_service import ocr_image

        image_b64 = base64.b64encode(image_bytes).decode("utf-8")
        result = await ocr_image(image_b64, "invoice")
        if result.get("success"):
            # "or {}" keeps the declared dict return type if data is None.
            return result.get("data") or {}, "Vision OCR 成功"
        return {}, result.get("error") or "OCR 失败"

    @staticmethod
    def _parse_amount(ocr_data: dict):
        """Return ``ocr_data["amount"]`` as a float, or 0 when missing or unparseable."""
        try:
            return float(ocr_data.get("amount", 0))
        except (ValueError, TypeError):
            return 0

    @staticmethod
    def _parse_invoice_date(ocr_data: dict):
        """Return ``ocr_data["date"]`` parsed as an ISO date, or None.

        Missing values, malformed strings and non-string values (e.g. an
        int from the OCR model) all yield None.
        """
        raw = ocr_data.get("date")
        if not raw:
            return None
        from datetime import date as dt_date

        try:
            return dt_date.fromisoformat(raw)
        except (ValueError, TypeError):
            # TypeError: fromisoformat() only accepts str.
            return None

    async def _mark_success_and_pool(self, task_id: uuid.UUID, task_info: dict, ocr_data: dict):
        """Insert a ``FinInvoicePool`` row and mark the task ``success``.

        Both writes are committed in one transaction.

        Args:
            task_id: Primary key of the ``FinOcrTask`` row.
            task_info: Snapshot dict produced by ``_pick_next_task``.
            ocr_data: Fields extracted by OCR; stored verbatim as well.
        """
        async with async_session_factory() as db:
            merchant = ocr_data.get("merchant") or ocr_data.get("merchant_name") or "AI 提取)"
            inv = FinInvoicePool(
                uploader_id=task_info["uploader_id"],
                company_id=task_info["company_id"],
                file_url=task_info["file_url"],
                merchant_name=merchant,
                amount=self._parse_amount(ocr_data),
                invoice_date=self._parse_invoice_date(ocr_data),
                type=task_info["inv_type"],
                ai_extracted_data=ocr_data,
                is_used=False,
            )
            db.add(inv)
            # Flush so inv.id is populated before it is referenced below.
            await db.flush()
            await db.execute(
                update(FinOcrTask)
                .where(FinOcrTask.id == task_id)
                .values(
                    status="success",
                    ocr_result=ocr_data,
                    invoice_pool_id=inv.id,
                    error_message=None,
                    updated_at=self._utcnow(),
                )
            )
            await db.commit()

    async def _mark_failed(self, task_id: uuid.UUID, error: str, partial_data: dict | None = None):
        """Record a failed attempt and increment ``retry_count``.

        The task goes back to ``pending`` while retries remain and becomes
        ``failed`` once ``retry_count`` reaches ``max_retries``. Does
        nothing if the task row no longer exists.

        Args:
            task_id: Primary key of the ``FinOcrTask`` row.
            error: Message stored in ``error_message``.
            partial_data: Partial OCR output to store; when falsy the
                existing ``ocr_result`` is kept.
        """
        async with async_session_factory() as db:
            task = (await db.execute(
                select(FinOcrTask).where(FinOcrTask.id == task_id)
            )).scalar_one_or_none()
            if task is None:
                return
            # Treat a NULL retry_count as 0 rather than raising TypeError.
            new_retry = (task.retry_count or 0) + 1
            new_status = "failed" if new_retry >= task.max_retries else "pending"
            await db.execute(
                update(FinOcrTask)
                .where(FinOcrTask.id == task_id)
                .values(
                    status=new_status,
                    retry_count=new_retry,
                    error_message=error,
                    ocr_result=partial_data or task.ocr_result,
                    updated_at=self._utcnow(),
                )
            )
            await db.commit()
        if new_status == "pending":
            # Separator added: the original fused the task id and retry number.
            print(f"[OcrWorker] ⚠️ 任务 {task_id} 第 {new_retry} 次重试入队")
        else:
            print(f"[OcrWorker] ❌ 任务 {task_id} 已达最大重试次数,标记失败")
# Module-level singleton instance of the worker.
ocr_worker = OcrWorker()