""" OCR 后台 Worker — asyncio 协程,FastAPI lifespan 启动 策略 C: 工作时间限流(1并发 + 60s间隔),17:00-20:00 BJT 全速 """ from __future__ import annotations import asyncio import os import uuid from datetime import datetime, timedelta from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession from app.db.database import async_session_factory from app.models.finance import FinInvoicePool, FinOcrTask class OcrWorker: """后台 OCR 任务处理器""" def __init__(self): self.running = False self.current_task_id: uuid.UUID | None = None self._task: asyncio.Task | None = None def start(self): self.running = True self._task = asyncio.create_task(self._run_loop()) print("[OcrWorker] 启动 — 策略 C: 工作时间限流, 17-20 BJT 全速") async def stop(self): self.running = False if self._task: self._task.cancel() try: await self._task except asyncio.CancelledError: pass print("[OcrWorker] 已停止") async def _run_loop(self): """主循环:每 10 秒检查一次队列""" while self.running: try: task = await self._pick_next_task() if task: await self._process_task(task) # 限流:非高峰期间隔 60s if not self._is_peak_time(): await asyncio.sleep(60) else: await asyncio.sleep(5) else: await asyncio.sleep(10) except asyncio.CancelledError: break except Exception as e: print(f"[OcrWorker] 循环异常: {e}") await asyncio.sleep(30) def _is_peak_time(self) -> bool: """17:00-20:00 BJT = 09:00-12:00 UTC""" utc_hour = datetime.utcnow().hour return 9 <= utc_hour < 12 async def _pick_next_task(self) -> dict | None: """从 DB 获取优先级最高的 pending 任务""" async with async_session_factory() as db: stmt = ( select(FinOcrTask) .where( FinOcrTask.status == "pending", FinOcrTask.is_deleted.is_(False), FinOcrTask.retry_count < FinOcrTask.max_retries, ) .order_by(FinOcrTask.priority, FinOcrTask.created_at) .limit(1) ) task = (await db.execute(stmt)).scalar_one_or_none() if not task: return None # 标记为 processing task.status = "processing" task.updated_at = datetime.utcnow() await db.commit() self.current_task_id = task.id return { "id": task.id, "file_url": task.file_url, "file_ext": task.file_ext, "original_name": task.original_name, "uploader_id": task.uploader_id, "company_id": task.company_id, "inv_type": task.inv_type, "retry_count": task.retry_count, } async def _process_task(self, task_info: dict): """执行 OCR 并更新""" task_id = task_info["id"] file_url = task_info["file_url"] file_ext = task_info["file_ext"] print(f"[OcrWorker] 处理任务 {task_id} ({task_info['original_name']}, {file_ext})") try: # 读取文件 file_path = file_url.lstrip("/") if not os.path.exists(file_path): await self._mark_failed(task_id, f"文件不存在: {file_path}") return with open(file_path, "rb") as f: file_bytes = f.read() ocr_data = {} message = "" # PDF 处理 if file_ext == ".pdf": ocr_data, message = await self._process_pdf(file_bytes) # 图片处理 elif file_ext in (".png", ".jpg", ".jpeg"): ocr_data, message = await self._process_image(file_bytes) else: await self._mark_failed(task_id, f"不支持的文件格式: {file_ext}") return if ocr_data and (ocr_data.get("merchant") or ocr_data.get("amount")): # OCR 成功 → 自动入池 await self._mark_success_and_pool(task_id, task_info, ocr_data) print(f"[OcrWorker] ✅ {task_info['original_name']} 入池成功") else: # OCR 完成但没提取到关键字段 await self._mark_failed( task_id, message or "AI 未能提取发票关键字段(开票方/金额),请手动录入", ocr_data, ) except Exception as e: print(f"[OcrWorker] ❌ 任务 {task_id} 异常: {e}") await self._mark_failed(task_id, str(e)) self.current_task_id = None async def _process_pdf(self, file_bytes: bytes) -> tuple[dict, str]: """PDF: 先尝试文本提取,失败降级 Vision OCR""" try: import fitz doc = fitz.open(stream=file_bytes, filetype="pdf") text = "" for page in doc: text += page.get_text() + "\n" doc.close() text = text.strip() if len(text) > 50: from app.services.ocr_service import extract_invoice_from_text result = await extract_invoice_from_text(text, "invoice") if result.get("success") and result.get("data"): return result["data"], "PDF 文本解析成功" # 降级: 扫描件 → Vision OCR doc2 = fitz.open(stream=file_bytes, filetype="pdf") pix = doc2[0].get_pixmap(dpi=150) ocr_bytes = pix.tobytes("png") doc2.close() return await self._vision_ocr(ocr_bytes) except Exception as e: return {}, f"PDF 处理失败: {e}" async def _process_image(self, file_bytes: bytes) -> tuple[dict, str]: """图片: Vision OCR""" return await self._vision_ocr(file_bytes) async def _vision_ocr(self, image_bytes: bytes) -> tuple[dict, str]: """调用 3090 Vision OCR""" import base64 from app.services.ocr_service import ocr_image image_b64 = base64.b64encode(image_bytes).decode("utf-8") result = await ocr_image(image_b64, "invoice") if result.get("success"): return result.get("data", {}), "Vision OCR 成功" return {}, result.get("error", "OCR 失败") async def _mark_success_and_pool(self, task_id: uuid.UUID, task_info: dict, ocr_data: dict): """标记成功 + 自动入池""" async with async_session_factory() as db: merchant = ocr_data.get("merchant") or ocr_data.get("merchant_name") or "(AI 提取)" amount = 0 try: amount = float(ocr_data.get("amount", 0)) except (ValueError, TypeError): pass invoice_date_str = ocr_data.get("date") invoice_date = None if invoice_date_str: try: from datetime import date as dt_date invoice_date = dt_date.fromisoformat(invoice_date_str) except ValueError: pass inv = FinInvoicePool( uploader_id=task_info["uploader_id"], company_id=task_info["company_id"], file_url=task_info["file_url"], merchant_name=merchant, amount=amount, invoice_date=invoice_date, type=task_info["inv_type"], ai_extracted_data=ocr_data, is_used=False, ) db.add(inv) await db.flush() await db.execute( update(FinOcrTask) .where(FinOcrTask.id == task_id) .values( status="success", ocr_result=ocr_data, invoice_pool_id=inv.id, error_message=None, updated_at=datetime.utcnow(), ) ) await db.commit() async def _mark_failed(self, task_id: uuid.UUID, error: str, partial_data: dict | None = None): """标记失败 + retry_count+1""" async with async_session_factory() as db: task = (await db.execute( select(FinOcrTask).where(FinOcrTask.id == task_id) )).scalar_one_or_none() if not task: return new_retry = task.retry_count + 1 new_status = "failed" if new_retry >= task.max_retries else "pending" await db.execute( update(FinOcrTask) .where(FinOcrTask.id == task_id) .values( status=new_status, retry_count=new_retry, error_message=error, ocr_result=partial_data or task.ocr_result, updated_at=datetime.utcnow(), ) ) await db.commit() if new_status == "pending": print(f"[OcrWorker] ⚠️ 任务 {task_id} 第 {new_retry} 次重试入队") else: print(f"[OcrWorker] ❌ 任务 {task_id} 已达最大重试次数,标记失败") # 单例 ocr_worker = OcrWorker()