Initial commit: packaging-review POC, Docker, and the frontend/backend
Made-with: Cursor

backend/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""Backend package for the AI field validation preview tool."""

backend/app/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""Backend application package."""

backend/app/barcode_detector.py (new file, 103 lines)
@@ -0,0 +1,103 @@
"""Detect and decode barcodes / QR codes from an image file using zxing-cpp."""
from __future__ import annotations

import logging
from dataclasses import dataclass
from pathlib import Path

logger = logging.getLogger(__name__)

# Human-readable names for zxing BarcodeFormat values
_FORMAT_NAMES: dict[str, str] = {
    "QRCode": "二维码 (QR Code)",
    "DataMatrix": "DataMatrix",
    "Aztec": "Aztec",
    "PDF417": "PDF417",
    "MicroQRCode": "微型二维码",
    "RMQRCode": "R型QR码",
    "EAN8": "EAN-8",
    "EAN13": "EAN-13",
    "UPCE": "UPC-E",
    "UPCA": "UPC-A",
    "Code39": "Code 39",
    "Code93": "Code 93",
    "Code128": "Code 128",
    "ITF": "ITF(交叉二五码)",
    "Codabar": "Codabar",
    "DataBar": "DataBar",
    "DataBarExpanded": "DataBar Expanded",
    "MaxiCode": "MaxiCode",
    "DXFilmEdge": "DXFilmEdge",
    "LinearCodes": "一维码",
    "MatrixCodes": "矩阵码",
}


@dataclass
class BarcodeResult:
    format: str        # zxing format string, e.g. "EAN13"
    format_label: str  # Chinese-friendly label
    text: str          # decoded text / number
    # bounding box in image pixels (top-left origin)
    x0: int
    y0: int
    x1: int
    y1: int
    valid: bool        # zxing isValid


def detect_barcodes(image_path: Path) -> list[BarcodeResult]:
    """Scan *image_path* for all barcodes and QR codes.

    Returns a list of :class:`BarcodeResult`, one entry per detected code.
    Returns an empty list when nothing is found or on error.
    """
    try:
        import zxingcpp
    except ImportError:
        logger.warning("zxing-cpp not installed; barcode detection skipped")
        return []

    try:
        from PIL import Image
        img = Image.open(image_path).convert("RGB")
    except Exception as exc:
        logger.warning("barcode_detector: cannot open image %s: %s", image_path, exc)
        return []

    try:
        results = zxingcpp.read_barcodes(img)
    except Exception as exc:
        logger.warning("barcode_detector: zxing scan failed: %s", exc)
        return []

    output: list[BarcodeResult] = []
    for r in results:
        fmt_str = str(r.format).replace("BarcodeFormat.", "")
        label = _FORMAT_NAMES.get(fmt_str, fmt_str)

        # zxing-cpp position: r.position is a quadrilateral with four points
        try:
            pts = r.position
            xs = [pts.top_left.x, pts.top_right.x, pts.bottom_right.x, pts.bottom_left.x]
            ys = [pts.top_left.y, pts.top_right.y, pts.bottom_right.y, pts.bottom_left.y]
            x0, y0, x1, y1 = int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))
        except Exception:
            x0 = y0 = x1 = y1 = 0

        output.append(BarcodeResult(
            format=fmt_str,
            format_label=label,
            text=r.text,
            x0=x0, y0=y0, x1=x1, y1=y1,
            valid=r.valid,
        ))
        logger.info(
            "barcode_detector: found %s text=%r bbox=(%d,%d,%d,%d)",
            fmt_str, r.text, x0, y0, x1, y1,
        )

    if not output:
        logger.info("barcode_detector: no barcode/QR found in %s", image_path.name)

    return output
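
A minimal sketch of calling this module from a script, assuming a local label.png (the path is illustrative) and that zxing-cpp and Pillow are installed:

    from pathlib import Path

    from backend.app.barcode_detector import detect_barcodes

    for code in detect_barcodes(Path("label.png")):  # assumed sample image
        print(f"{code.format_label}: {code.text!r} valid={code.valid} "
              f"bbox=({code.x0},{code.y0})-({code.x1},{code.y1})")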

backend/app/image_classifier.py (new file, 132 lines)
@@ -0,0 +1,132 @@
"""Semantic image classification with Qwen VL: decide whether an image is a QR code / barcode.

Usage
-----
::

    from backend.app.image_classifier import is_qr_code

    result = is_qr_code(Path("crop.png"), api_key="sk-...")
    if result:
        # hand off to the barcode decoding module
        ...

Design principles
-----------------
* Make a single yes/no judgement only; never decode content
  (decoding is barcode_detector's job).
* Reuse the API key / base_url loading logic that already exists in
  region_detector.
* Return False when the network or model call fails, so the pipeline
  can degrade gracefully.
"""
from __future__ import annotations

import base64
import io
import logging
from pathlib import Path

logger = logging.getLogger(__name__)

# Use the lightweight 7B vision model: fast and cheap
_DEFAULT_MODEL = "qwen2.5-vl-7b-instruct"

_CLASSIFY_PROMPT = (
    "请仔细观察这张图片。\n"
    "问题:图片中是否包含二维码(QR Code)或任何类型的条形码?\n"
    '请只回答"是"或"否",不要输出其他任何内容。'
)


def _encode_image(image_path: Path, max_side: int = 512) -> str:
    """Scale the image down and encode it as a base64 PNG string.

    Small images (e.g. blocks cropped out by MinerU) keep their original size;
    large images are scaled proportionally to reduce token usage.
    """
    from PIL import Image

    with Image.open(image_path) as img:
        img = img.convert("RGB")
        w, h = img.size
        if max(w, h) > max_side:
            scale = max_side / max(w, h)
            img = img.resize((max(1, round(w * scale)), max(1, round(h * scale))), Image.LANCZOS)

        buf = io.BytesIO()
        img.save(buf, format="PNG")

    return base64.b64encode(buf.getvalue()).decode()


def is_qr_code(
    image_path: Path,
    api_key: str | None = None,
    model: str = _DEFAULT_MODEL,
) -> bool:
    """Ask Qwen VL whether the image contains a QR code or barcode.

    Parameters
    ----------
    image_path:
        Path of the image to classify.
    api_key:
        DashScope API key; when None it is read from the environment / .env file.
    model:
        Model name to use; defaults to qwen2.5-vl-7b-instruct.

    Returns
    -------
    bool
        True  → the model believes the image contains a QR code / barcode
        False → it does not, or the call failed (graceful degradation)
    """
    # Import lazily so an unconfigured environment does not break module loading
    from backend.app.region_detector import _get_api_key, _get_base_url
    from openai import OpenAI

    key = api_key or _get_api_key()
    if not key:
        logger.warning("image_classifier: DASHSCOPE_API_KEY 未配置,跳过 QR 语义判断")
        return False

    try:
        b64 = _encode_image(image_path)
    except Exception as exc:
        logger.warning("image_classifier: 图片编码失败 (%s),跳过分类", exc)
        return False

    client = OpenAI(api_key=key, base_url=_get_base_url())
    try:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/png;base64,{b64}"},
                        },
                        {"type": "text", "text": _CLASSIFY_PROMPT},
                    ],
                }
            ],
            max_tokens=10,
            temperature=0.0,
        )
    except Exception as exc:
        logger.warning("image_classifier: Qwen VL 调用失败 (%s),跳过分类", exc)
        return False

    raw = (response.choices[0].message.content or "").strip()
    logger.debug("image_classifier: 模型原始回复 = %r", raw)

    # Accept "是"/"否" as well as "Yes"/"No" style answers
    answer = raw.lower()
    result = answer.startswith("是") or answer.startswith("yes")
    logger.info(
        "image_classifier: %s → %s(原始回复:%r)",
        image_path.name,
        "二维码/条码" if result else "非二维码",
        raw,
    )
    return result
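
A sketch of the intended gate-then-decode chain, assuming crop.png exists and DASHSCOPE_API_KEY is configured (both illustrative):

    from pathlib import Path

    from backend.app.barcode_detector import detect_barcodes
    from backend.app.image_classifier import is_qr_code

    crop = Path("crop.png")  # assumed crop produced by an earlier step
    if is_qr_code(crop):                    # cheap semantic yes/no via Qwen VL
        for code in detect_barcodes(crop):  # exact decoding via zxing-cpp
            print(code.format_label, code.text)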

backend/app/main.py (new file, 255 lines)
@@ -0,0 +1,255 @@
"""FastAPI application entry point."""
from __future__ import annotations

import asyncio
import json
import logging
import queue as thread_queue
import shutil
import threading
import uuid
from datetime import datetime
from pathlib import Path
from typing import Optional

from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles

from backend.app import pipeline

# --------------------------------------------------------------------------- #
#  Logging + SSE broadcast                                                     #
# --------------------------------------------------------------------------- #

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
)
logger = logging.getLogger(__name__)

_log_buffer: list[dict] = []  # last 200 entries, replayed to new connections
_log_queues: list[thread_queue.Queue] = []
_log_lock = threading.Lock()


class _BroadcastHandler(logging.Handler):
    """Broadcast every log record to all connected SSE clients."""

    def emit(self, record: logging.LogRecord) -> None:
        entry = {
            "time": datetime.fromtimestamp(record.created).strftime("%H:%M:%S"),
            "level": record.levelname,
            "name": record.name.replace("backend.app.", ""),
            "msg": record.getMessage(),
        }
        with _log_lock:
            _log_buffer.append(entry)
            if len(_log_buffer) > 200:
                _log_buffer.pop(0)
            for q in _log_queues:
                try:
                    q.put_nowait(entry)
                except thread_queue.Full:
                    pass


# Attach to the root logger so logs from every module are captured
_broadcast_handler = _BroadcastHandler()
logging.getLogger().addHandler(_broadcast_handler)

# --------------------------------------------------------------------------- #
#  Paths & constants                                                           #
# --------------------------------------------------------------------------- #

_ROOT = Path(__file__).resolve().parents[2]
UPLOADS_DIR = _ROOT / "data" / "uploads"
OUTPUTS_DIR = _ROOT / "data" / "outputs"

_DEFAULT_AI_NAME = "【2026-04-09】端午 - 背标 - 天问.ai"
_DEFAULT_WORD_NAME = "天问礼品粽【260331】.docx"
_DEFAULT_AI = _ROOT / _DEFAULT_AI_NAME
_DEFAULT_WORD = _ROOT / _DEFAULT_WORD_NAME

ALLOWED_AI_EXT = {".ai", ".pdf"}
ALLOWED_WORD_EXT = {".docx"}

# --------------------------------------------------------------------------- #
#  App                                                                         #
# --------------------------------------------------------------------------- #

app = FastAPI(
    title="诸老大包装审核 API",
    description="Upload an Illustrator file and a Word document to validate packaging copy.",
    version="2.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# --------------------------------------------------------------------------- #
#  Helpers                                                                     #
# --------------------------------------------------------------------------- #

def _save_upload(upload: UploadFile, dest: Path) -> None:
    dest.parent.mkdir(parents=True, exist_ok=True)
    with dest.open("wb") as fh:
        fh.write(upload.file.read())


def _copy_default(src: Optional[Path], dest: Path, label: str) -> None:
    if src is None or not src.exists():
        raise HTTPException(
            status_code=400,
            detail=f"未上传{label}且找不到默认样例文件,请上传文件后重试。",
        )
    dest.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(src, dest)


# --------------------------------------------------------------------------- #
#  Endpoints                                                                   #
# --------------------------------------------------------------------------- #

@app.get("/api/logs/stream")
async def log_stream() -> StreamingResponse:
    """SSE endpoint: stream backend logs to the frontend sidebar in real time."""
    q: thread_queue.Queue = thread_queue.Queue(maxsize=500)
    with _log_lock:
        _log_queues.append(q)
        recent = list(_log_buffer)

    async def generate():
        try:
            # Replay the buffered history first
            for entry in recent:
                yield f"data: {json.dumps(entry, ensure_ascii=False)}\n\n"

            # Then keep pushing new entries
            while True:
                batch: list[dict] = []
                try:
                    while True:
                        batch.append(q.get_nowait())
                except thread_queue.Empty:
                    pass

                for entry in batch:
                    yield f"data: {json.dumps(entry, ensure_ascii=False)}\n\n"

                if not batch:
                    yield ": keepalive\n\n"

                await asyncio.sleep(0.25)
        finally:
            with _log_lock:
                if q in _log_queues:
                    _log_queues.remove(q)

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
            "Connection": "keep-alive",
        },
    )


@app.post("/api/process")
async def process_endpoint(
    ai_file: Optional[UploadFile] = File(None),
    word_file: Optional[UploadFile] = File(None),
) -> dict:
    """Run the full pipeline: AI → PDF → MinerU → Word validation."""
    job_id = uuid.uuid4().hex
    upload_dir = UPLOADS_DIR / job_id
    output_dir = OUTPUTS_DIR / job_id

    logger.info("POST /api/process job_id=%s", job_id)

    # ── Resolve AI file ──────────────────────────────────────────────────── #
    if ai_file is not None:
        original_name = Path(ai_file.filename or "source.ai").name
        suffix = Path(original_name).suffix.lower()
        if suffix not in ALLOWED_AI_EXT:
            raise HTTPException(
                status_code=400,
                detail=f"不支持的 AI 文件格式 '{suffix}',请上传 .ai 或 PDF。",
            )
        ai_path = upload_dir / original_name
        _save_upload(ai_file, ai_path)
    else:
        ai_path = upload_dir / (_DEFAULT_AI.name if _DEFAULT_AI else "source.ai")
        _copy_default(_DEFAULT_AI, ai_path, "AI 设计文件")

    # ── Resolve Word file ────────────────────────────────────────────────── #
    if word_file is not None:
        suffix = Path(word_file.filename or "").suffix.lower()
        if suffix not in ALLOWED_WORD_EXT:
            raise HTTPException(
                status_code=400,
                detail=f"不支持的 Word 文件格式 '{suffix}',请上传 .docx。",
            )
        word_path = upload_dir / f"reference{suffix}"
        _save_upload(word_file, word_path)
    else:
        word_path = upload_dir / (_DEFAULT_WORD.name if _DEFAULT_WORD else "reference.docx")
        _copy_default(_DEFAULT_WORD, word_path, "Word 校对稿")

    # ── Run pipeline in a thread pool (keeps the event loop free so SSE logs still flow) ── #
    try:
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None,
            pipeline.process_document,
            ai_path,
            word_path,
            output_dir,
            job_id,
        )
    except FileNotFoundError as exc:
        logger.exception("Pipeline error (not found): job_id=%s", job_id)
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except RuntimeError as exc:
        logger.exception("Pipeline error (runtime): job_id=%s", job_id)
        raise HTTPException(status_code=422, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("Pipeline error (unexpected): job_id=%s", job_id)
        raise HTTPException(status_code=500, detail=f"处理失败:{exc}") from exc

    return result


@app.get("/api/files/{job_id}/{file_path:path}")
async def serve_file(job_id: str, file_path: str) -> FileResponse:
    """Serve job artifacts (preview PDF, JSON, etc.)."""
    target = OUTPUTS_DIR / job_id / file_path
    if not target.exists() or not target.is_file():
        raise HTTPException(status_code=404, detail="文件不存在")

    suffix = target.suffix.lower()
    media_type = {
        ".pdf": "application/pdf",
        ".json": "application/json",
        ".md": "text/markdown",
    }.get(suffix, "application/octet-stream")

    return FileResponse(target, media_type=media_type)


@app.get("/api/health")
async def health() -> dict:
    return {"status": "ok"}


# Production image: the Vite build output is served same-origin with the API,
# so VITE_API_BASE_URL needs no configuration
_dist = _ROOT / "frontend" / "dist"
if _dist.is_dir():
    app.mount("/", StaticFiles(directory=str(_dist), html=True), name="frontend")
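
A sketch of consuming the /api/logs/stream SSE endpoint from Python; the host and port are illustrative and requests is assumed to be installed:

    import json

    import requests

    # Stream server-sent events; each log record arrives as one "data: {...}" line.
    with requests.get("http://localhost:8000/api/logs/stream", stream=True) as resp:
        for raw in resp.iter_lines(decode_unicode=True):
            if not raw or not raw.startswith("data: "):
                continue  # skip ": keepalive" comments and blank event separators
            entry = json.loads(raw[len("data: "):])
            print(entry["time"], entry["level"], entry["name"], entry["msg"])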

backend/app/mineru_client.py (new file, 248 lines)
@@ -0,0 +1,248 @@
from __future__ import annotations

import json
import logging
import time
import zipfile
from pathlib import Path

import requests

logger = logging.getLogger(__name__)

MINERU_BASE = "https://mineru.net/api/v4"

TERMINAL_STATES = {"done", "failed"}
IN_PROGRESS_STATE_LABELS = {
    "waiting-file": "等待文件上传",
    "pending": "排队中",
    "running": "解析中",
    "converting": "格式转换中",
}


class MineruClientError(RuntimeError):
    pass


class MineruClient:
    """Client for MinerU's high-accuracy parsing API (token required).

    Full flow for a local image file:
    1. POST /file-urls/batch → obtain a batch_id plus a signed OSS upload URL
    2. PUT the image to OSS → the service detects the upload and submits the parse job automatically
    3. GET /extract-results/batch/{batch_id}, polling until state=done
    4. Download full_zip_url and extract the structured JSON

    File limits: ≤ 200 MB, ≤ 600 pages
    Supported formats: PDF, images (png/jpg/jpeg/jp2/webp/gif/bmp), Doc, Docx, Ppt, PPTx
    """

    def __init__(
        self,
        api_key: str,
        model_version: str = "vlm",
        language: str = "ch",
        enable_table: bool = True,
        is_ocr: bool = True,
        enable_formula: bool = True,
        poll_interval: float = 3.0,
        timeout: float = 300.0,
    ) -> None:
        self.api_key = api_key
        self.model_version = model_version
        self.language = language
        self.enable_table = enable_table
        self.is_ocr = is_ocr
        self.enable_formula = enable_formula
        self.poll_interval = poll_interval
        self.timeout = timeout

    def parse_image(self, image_path: Path, output_dir: Path) -> dict:
        """Parse a local image file and return the structured JSON data.

        Parameters
        ----------
        image_path:
            Local image path (png/jpg/jpeg/jp2/webp/gif/bmp)
        output_dir:
            Directory for intermediate artifacts (zip, extraction dir)

        Returns
        -------
        dict
            Structured JSON containing pdf_info (layout.json or content_list.json)
        """
        image_path = Path(image_path)
        if not image_path.exists():
            raise FileNotFoundError(f"图片文件不存在: {image_path}")

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        logger.info("MinerU 精准解析开始: %s", image_path.name)
        batch_id, upload_url = self._request_upload_url(image_path.name)
        logger.info("MinerU 批次已创建: batch_id=%s", batch_id)

        self._upload_file(upload_url, image_path)
        logger.info("MinerU 文件上传完成: %s(系统自动提交解析)", image_path.name)

        zip_url = self._poll_batch_until_done(batch_id)
        logger.info("MinerU 解析完成: batch_id=%s", batch_id)

        zip_path = self._download_zip(zip_url, output_dir)
        extract_dir = output_dir / "result"
        self._extract_zip(zip_path, extract_dir)

        result = self._load_structured_json(extract_dir)
        logger.info("MinerU 结构化 JSON 加载完毕")
        return result

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _auth_headers(self) -> dict[str, str]:
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _request_upload_url(self, file_name: str) -> tuple[str, str]:
        """Request batch upload URLs; returns (batch_id, oss_upload_url)."""
        payload = {
            "files": [{"name": file_name, "is_ocr": self.is_ocr}],
            "model_version": self.model_version,
            "language": self.language,
            "enable_table": self.enable_table,
            "enable_formula": self.enable_formula,
        }
        try:
            resp = requests.post(
                f"{MINERU_BASE}/file-urls/batch",
                headers=self._auth_headers(),
                json=payload,
                timeout=30,
            )
            resp.raise_for_status()
        except requests.RequestException as exc:
            raise MineruClientError(f"MinerU 申请上传 URL 失败: {exc}") from exc

        body = resp.json()
        if body.get("code") != 0:
            raise MineruClientError(f"MinerU 申请上传 URL 失败: {body.get('msg')}")

        data = body.get("data", {})
        batch_id = data.get("batch_id")
        file_urls = data.get("file_urls", [])
        if not batch_id or not file_urls:
            raise MineruClientError("MinerU 返回的 batch_id 或 file_urls 为空")

        return batch_id, file_urls[0]

    def _upload_file(self, upload_url: str, image_path: Path) -> None:
        """PUT the image to OSS. No Content-Type header is needed for the upload."""
        try:
            with image_path.open("rb") as f:
                resp = requests.put(upload_url, data=f, timeout=120)
        except requests.RequestException as exc:
            raise MineruClientError(f"MinerU 文件上传网络错误: {exc}") from exc

        if resp.status_code not in (200, 201):
            raise MineruClientError(
                f"MinerU 文件上传失败: HTTP {resp.status_code} {resp.text[:200]}"
            )

    def _poll_batch_until_done(self, batch_id: str) -> str:
        """Poll the batch result until a terminal state; returns full_zip_url."""
        url = f"{MINERU_BASE}/extract-results/batch/{batch_id}"
        deadline = time.monotonic() + self.timeout

        while time.monotonic() < deadline:
            try:
                resp = requests.get(url, headers=self._auth_headers(), timeout=30)
                resp.raise_for_status()
            except requests.RequestException as exc:
                raise MineruClientError(f"MinerU 查询批次状态失败: {exc}") from exc

            body = resp.json()
            if body.get("code") != 0:
                raise MineruClientError(f"MinerU 查询批次失败: {body.get('msg')}")

            results: list[dict] = body.get("data", {}).get("extract_result", [])
            if not results:
                time.sleep(self.poll_interval)
                continue

            item = results[0]
            state = item.get("state", "")
            label = IN_PROGRESS_STATE_LABELS.get(state, state)
            logger.info("MinerU 批次状态: batch_id=%s state=%s (%s)", batch_id, state, label)

            if state == "done":
                zip_url = item.get("full_zip_url")
                if not zip_url:
                    raise MineruClientError("MinerU 完成但未返回 full_zip_url")
                return zip_url

            if state == "failed":
                err_msg = item.get("err_msg") or "未知错误"
                raise MineruClientError(f"MinerU 解析失败: {err_msg}")

            time.sleep(self.poll_interval)

        raise MineruClientError(
            f"MinerU 轮询超时 ({self.timeout:.0f}s): batch_id={batch_id}"
        )

    def _download_zip(self, zip_url: str, output_dir: Path) -> Path:
        """Download the result zip locally."""
        target = output_dir / "mineru_result.zip"
        try:
            resp = requests.get(zip_url, timeout=120, stream=True)
            resp.raise_for_status()
            with target.open("wb") as f:
                for chunk in resp.iter_content(chunk_size=8192):
                    f.write(chunk)
        except requests.RequestException as exc:
            raise MineruClientError(f"MinerU zip 下载失败: {exc}") from exc
        logger.info("MinerU zip 下载完毕: %s", target)
        return target

    def _extract_zip(self, zip_path: Path, extract_dir: Path) -> None:
        extract_dir.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(zip_path) as archive:
            archive.extractall(extract_dir)
        logger.info("MinerU zip 解压完毕: %s", extract_dir)

    def _load_structured_json(self, extract_dir: Path) -> dict:
        """Find and load the structured JSON containing pdf_info from the extracted directory.

        Layout of a MinerU result zip:
            layout.json          → intermediate result (corresponds to middle.json)
            *_content_list.json  → content list
            *_model.json         → model inference output
            full.md              → Markdown rendering
        """
        candidates = [
            *sorted(extract_dir.rglob("layout.json")),
            *sorted(extract_dir.rglob("*layout*.json")),
            *sorted(extract_dir.rglob("*_content_list*.json")),
            *sorted(extract_dir.rglob("*.json")),
        ]
        seen: set[Path] = set()
        for candidate in candidates:
            if candidate in seen:
                continue
            seen.add(candidate)
            try:
                parsed = json.loads(candidate.read_text(encoding="utf-8"))
            except (json.JSONDecodeError, OSError):
                continue
            if isinstance(parsed, dict) and isinstance(parsed.get("pdf_info"), list):
                logger.info("MinerU 结构化 JSON 选用: %s", candidate.name)
                return parsed

        raise MineruClientError(
            "MinerU 结果 zip 中未找到包含 pdf_info 的结构化 JSON"
        )
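
A sketch of driving the client directly, assuming MINERU_API_KEY is set and a local label.png exists (both illustrative):

    import os
    from pathlib import Path

    from backend.app.mineru_client import MineruClient, MineruClientError

    client = MineruClient(api_key=os.environ["MINERU_API_KEY"])
    try:
        data = client.parse_image(Path("label.png"), Path("out/mineru"))
        print(f"parsed {len(data['pdf_info'])} page(s)")
    except MineruClientError as exc:
        print(f"MinerU failed: {exc}")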

backend/app/mineru_parser.py (new file, 299 lines)
@@ -0,0 +1,299 @@
"""Parse MinerU structured JSON (layout.json / middle.json) into field records."""
from __future__ import annotations

import logging
import re
from dataclasses import dataclass

from bs4 import BeautifulSoup

logger = logging.getLogger(__name__)


def _extract_table_text(html: str) -> str:
    """Convert table HTML into a multi-line string suitable for text matching.

    Each row has the form: cell1|cell2|cell3
    Cells within a row are joined by |; rows are separated by newlines.
    """
    try:
        soup = BeautifulSoup(html, "html.parser")
        rows = []
        for tr in soup.find_all("tr"):
            cells = [td.get_text(strip=True) for td in tr.find_all(["td", "th"])]
            if any(cells):
                rows.append("|".join(cells))
        return "\n".join(rows)
    except Exception:
        # On parse failure fall back to a crude regex strip
        return re.sub(r"<[^>]+>", " ", html).strip()


# 1 pt = 0.352778 mm
PT_TO_MM = 0.352778

# LaTeX inline-equation → Unicode map (only symbols commonly seen in label files)
_LATEX_TO_UNICODE: dict[str, str] = {
    r"\times": "×",
    r"\div": "÷",
    r"\pm": "±",
    r"\mp": "∓",
    r"\cdot": "·",
    r"\leq": "≤",
    r"\geq": "≥",
    r"\neq": "≠",
    r"\approx": "≈",
    r"\infty": "∞",
    r"\circ": "°",
    r"\degree": "°",
    r"\alpha": "α",
    r"\beta": "β",
    r"\gamma": "γ",
    r"\delta": "δ",
    r"\mu": "μ",
    r"\%": "%",
}

# MinerU sometimes emits ^{\circ} as ^{circ} (missing backslash);
# one regex matches both spellings
_SUPERSCRIPT_DEGREE_RE = re.compile(r"\^\{\\?circ\}", re.IGNORECASE)


@dataclass
class MineruDocument:
    page_width: float   # points
    page_height: float  # points
    fields: list[dict]  # list of field dicts ready for the API response


def _page_size(page: dict) -> tuple[float, float]:
    """Return (width, height) in points for a MinerU page entry."""
    # MinerU stores page size as [width, height] in `page_size`
    size = page.get("page_size") or page.get("page_size_pt") or []
    if isinstance(size, (list, tuple)) and len(size) >= 2:
        return float(size[0]), float(size[1])
    # Fallback: assume A4
    return 595.0, 842.0


def _latex_to_text(expr: str) -> str:
    """Convert a simple LaTeX expression into readable text (replace known symbols one by one)."""
    result = expr.strip()
    # Handle superscript degrees first: ^{circ} or ^{\circ} → °
    result = _SUPERSCRIPT_DEGREE_RE.sub("°", result)
    # Other superscripts ^{...} / subscripts _{...}: drop the wrapper, keep the content
    result = re.sub(r"[\^_]\{([^}]*)\}", r"\1", result)
    for latex, uni in _LATEX_TO_UNICODE.items():
        result = result.replace(latex, uni)
    # Unrecognised commands (e.g. \foo): strip the backslash, degrade to the raw letters
    result = re.sub(r"\\([A-Za-z]+)", r"\1", result)
    return result


def _span_content(span: dict) -> str:
    """Extract matchable text content from a span.

    - type == "table": parse the html field into row/column text
    - type == "inline_equation": LaTeX → Unicode text
    - anything else: take the content field and repair common LaTeX
      superscript leftovers (e.g. ^{circ})
    """
    span_type = span.get("type") or ""
    if span_type == "table":
        html = span.get("html") or ""
        return _extract_table_text(html) if html else ""
    if span_type == "inline_equation":
        return _latex_to_text((span.get("content") or "").strip())
    # Plain text span: MinerU sometimes embeds LaTeX superscripts (e.g. ^{circ}) directly in content
    raw = (span.get("content") or "").strip()
    return _SUPERSCRIPT_DEGREE_RE.sub("°", raw)


def _iter_lines(block: dict):
    """Yield (line, block) tuples for all lines in a block.

    Handles two MinerU structures:
    - Flat:   block → lines → spans (text/title/etc.)
    - Nested: block → blocks → lines → spans (table blocks)
    """
    lines = block.get("lines")
    if lines:
        for line in lines:
            yield line, block
    else:
        # Table blocks (and some other types) have a nested `blocks` layer
        for inner in block.get("blocks", []):
            for line in inner.get("lines", []):
                yield line, block


def _iter_line_fields(page: dict):
    """Yield one record per non-empty *line* across the whole page.

    Each yielded tuple is ``(merged_text, line, first_text_span, block, table_html)`` where:
    - ``merged_text`` – all span contents concatenated (LaTeX already converted)
    - ``line`` – the MinerU line dict (carries the authoritative bbox)
    - ``first_text_span`` – first span that has font metadata, or ``None``
    - ``block`` – the containing block (carries ``type``)
    - ``table_html`` – raw table HTML when the line holds a table span, else ``None``

    Merging at the line level correctly handles footer / title blocks where a
    single printed sentence is split across many spans (e.g. text + inline_equation
    + text …). Table blocks still produce one record per table because they have
    exactly one span (type="table") per line.
    """
    def _process_block_set(blocks_iter):
        for block in blocks_iter:
            for line, src_block in _iter_lines(block):
                spans = line.get("spans", [])
                if not spans:
                    continue

                parts: list[str] = []
                first_text_span: dict | None = None
                table_html: str | None = None
                for span in spans:
                    content = _span_content(span)
                    if content:
                        parts.append(content)
                    if span.get("type") == "table":
                        # Keep the raw HTML so the frontend can render complex
                        # tables with colspan/rowspan
                        table_html = span.get("html") or None
                    elif first_text_span is None:
                        first_text_span = span

                merged = "".join(parts)
                if merged:
                    yield merged, line, first_text_span, src_block, table_html

    yield from _process_block_set(page.get("para_blocks", []))
    yield from _process_block_set(page.get("blocks", []))


def _bbox(obj: dict) -> tuple[float, float, float, float]:
    """Return (x0, y0, x1, y1) from an object's bbox field."""
    bbox = obj.get("bbox") or [0, 0, 0, 0]
    if isinstance(bbox, (list, tuple)) and len(bbox) >= 4:
        return float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3])
    return 0.0, 0.0, 0.0, 0.0


def parse_mineru_fields(data: dict) -> MineruDocument:
    """Convert raw MinerU structured JSON into a :class:`MineruDocument`.

    Parameters
    ----------
    data:
        The parsed JSON dict returned by :class:`~backend.app.mineru_client.MineruClient`.
        Must contain a ``pdf_info`` list with one entry per page.

    Returns
    -------
    MineruDocument
        Holds page dimensions and a flat list of text field dicts.
    """
    pdf_info: list[dict] = data.get("pdf_info", [])
    if not pdf_info:
        logger.warning("MinerU JSON contains empty pdf_info")
        return MineruDocument(page_width=595.0, page_height=842.0, fields=[])

    # Use the first page's dimensions for the preview
    first_page = pdf_info[0]
    page_width, page_height = _page_size(first_page)

    fields: list[dict] = []
    for page in pdf_info:
        page_idx = int(page.get("page_idx", 0))
        page_num = page_idx + 1
        pw, ph = _page_size(page)

        for content, line, font_span, _block, table_html in _iter_line_fields(page):
            # bbox comes from the line (covers all spans in one visual row)
            x0, y0, x1, y1 = _bbox(line)

            font_size_pt: float | None = None
            font_name: str | None = None
            if font_span is not None:
                raw_size = font_span.get("size") or font_span.get("font_size")
                if raw_size is not None:
                    try:
                        font_size_pt = float(raw_size)
                    except (TypeError, ValueError):
                        pass
                font_name = font_span.get("font") or font_span.get("font_name") or None

            font_height_mm: float | None = (
                round(font_size_pt * PT_TO_MM, 2) if font_size_pt else None
            )

            block_type = (_block.get("type") or "text").strip() or "text"

            fields.append(
                {
                    "page": page_num,
                    "block_type": block_type,
                    "text": content,
                    "table_html": table_html,
                    "font_name": font_name,
                    "font_size_pt": round(font_size_pt, 2) if font_size_pt else None,
                    "font_height_mm": font_height_mm,
                    "x0_pt": round(x0, 2),
                    "top_pt": round(y0, 2),
                    "x1_pt": round(x1, 2),
                    "bottom_pt": round(y1, 2),
                }
            )

    logger.info(
        "MinerU parser extracted %d fields across %d page(s)",
        len(fields),
        len(pdf_info),
    )
    return MineruDocument(
        page_width=page_width,
        page_height=page_height,
        fields=fields,
    )


def parse_mineru_image_blocks(data: dict) -> list[dict]:
    """Extract every image-type block from MinerU structured JSON.

    Returns
    -------
    list of dict, each entry containing:
        - page : page number (1-based)
        - block_type : "image"
        - img_path : relative path MinerU recorded inside the zip (may be None)
        - x0_pt, top_pt, x1_pt, bottom_pt : block bounding box (same coordinate system as text fields)
    """
    pdf_info: list[dict] = data.get("pdf_info", [])
    images: list[dict] = []

    for page in pdf_info:
        page_idx = int(page.get("page_idx", 0))
        page_num = page_idx + 1

        for blocks_key in ("para_blocks", "blocks"):
            for block in page.get(blocks_key, []):
                if (block.get("type") or "").strip().lower() != "image":
                    continue
                x0, y0, x1, y1 = _bbox(block)
                # MinerU may store the image path under any of these keys
                img_path = (
                    block.get("img_path")
                    or block.get("image_path")
                    or block.get("path")
                    or None
                )
                images.append(
                    {
                        "page": page_num,
                        "block_type": "image",
                        "img_path": img_path,
                        "x0_pt": round(x0, 2),
                        "top_pt": round(y0, 2),
                        "x1_pt": round(x1, 2),
                        "bottom_pt": round(y1, 2),
                    }
                )

    logger.info("MinerU parser found %d image block(s)", len(images))
    return images
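
To make the parser's expected input and output shapes concrete, a minimal hand-built example (the single text line and its values are illustrative, not a real MinerU response); note 9 pt × 0.352778 ≈ 3.18 mm:

    from backend.app.mineru_parser import parse_mineru_fields

    data = {"pdf_info": [{
        "page_idx": 0,
        "page_size": [595, 842],
        "para_blocks": [{
            "type": "text",
            "lines": [{
                "bbox": [10, 10, 100, 24],
                "spans": [{"type": "text", "content": "净含量:100g", "size": 9.0, "font": "SimHei"}],
            }],
        }],
    }]}

    doc = parse_mineru_fields(data)
    print(doc.fields[0]["text"])            # 净含量:100g
    print(doc.fields[0]["font_height_mm"])  # 3.18  (9 pt × 0.352778, rounded)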

backend/app/pipeline.py (new file, 507 lines)
@@ -0,0 +1,507 @@
"""Core processing pipeline: AI → PDF → PNG → Qwen crop → MinerU → validate."""
from __future__ import annotations

import logging
import os
import shutil
import subprocess
from pathlib import Path

from backend.app.barcode_detector import detect_barcodes
from backend.app.image_classifier import is_qr_code
from backend.app.mineru_client import MineruClient, MineruClientError
from backend.app.mineru_parser import parse_mineru_fields, parse_mineru_image_blocks
from backend.app.text_validation import validate_field_against_word
from backend.app.word_parser import extract_word_html, extract_word_text

logger = logging.getLogger(__name__)

# --------------------------------------------------------------------------- #
#  Environment helpers                                                         #
# --------------------------------------------------------------------------- #

def _get_mineru_api_key() -> str:
    """Read MINERU_API_KEY from the process environment or the project .env file."""
    value = os.environ.get("MINERU_API_KEY", "").strip()
    if value:
        return value

    for candidate in (
        Path(__file__).resolve().parents[2] / ".env",
        Path(__file__).resolve().parents[3] / ".env",
    ):
        if not candidate.exists():
            continue
        for raw in candidate.read_text(encoding="utf-8").splitlines():
            line = raw.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, val = line.split("=", 1)
            if key.strip() == "MINERU_API_KEY":
                cleaned = val.strip().strip('"').strip("'")
                if cleaned:
                    logger.info("Loaded MINERU_API_KEY from %s", candidate)
                    return cleaned
    return ""


# --------------------------------------------------------------------------- #
#  AI → PDF conversion                                                         #
# --------------------------------------------------------------------------- #

def _ai_to_pdf(ai_path: Path, output_dir: Path) -> Path:
    """Convert an Adobe Illustrator file to PDF, keeping the original filename stem.

    Modern .ai files (CS and later) are internally PDF-based; pypdf can copy
    them directly. Legacy EPS-based .ai files require Ghostscript.
    If the uploaded file is already a PDF it is copied as-is.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    pdf_path = output_dir / f"{ai_path.stem}.pdf"

    with ai_path.open("rb") as fh:
        header = fh.read(8)

    if header.startswith(b"%PDF-"):
        # PDF-based .ai or an actual PDF – re-write with pypdf for cleanliness
        try:
            from pypdf import PdfReader, PdfWriter

            reader = PdfReader(str(ai_path))
            writer = PdfWriter()
            for page in reader.pages:
                writer.add_page(page)
            with pdf_path.open("wb") as fh:
                writer.write(fh)
            logger.info("Converted PDF-based .ai via pypdf: %s", ai_path.name)
        except Exception as exc:
            logger.warning("pypdf failed (%s), falling back to direct copy", exc)
            shutil.copy2(ai_path, pdf_path)
    else:
        # Legacy EPS-based .ai → Ghostscript
        gs = shutil.which("/opt/homebrew/bin/gs") or shutil.which("gs") or shutil.which("ghostscript")
        if gs is None:
            raise RuntimeError(
                "Cannot convert legacy .ai file: Ghostscript is not installed. "
                "Run: brew install ghostscript"
            )

        result = subprocess.run(
            [gs, "-dNOPAUSE", "-dBATCH", "-dSAFER",
             "-sDEVICE=pdfwrite", f"-sOutputFile={pdf_path}", str(ai_path)],
            capture_output=True, text=True, timeout=120,
        )
        if result.returncode != 0:
            raise RuntimeError(
                f"Ghostscript failed (exit {result.returncode}):\n{result.stderr.strip()}"
            )
        logger.info("Converted legacy .ai via Ghostscript: %s", ai_path.name)

    return pdf_path


# --------------------------------------------------------------------------- #
#  PDF → PNG rasterisation                                                     #
# --------------------------------------------------------------------------- #

def _pdf_to_png(pdf_path: Path, output_dir: Path, dpi: int = 150) -> Path:
    """Rasterise the first page of a PDF to a PNG.

    Tries, in order:
    1. Ghostscript (if installed)
    2. PyMuPDF (pip install pymupdf)

    Uses a safe output filename ``page1.png`` to avoid issues with special
    characters in the source PDF name.
    Returns the path of the generated PNG.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    # Use a safe filename – special chars / spaces in the PDF stem can cause
    # Ghostscript to silently produce no output.
    png_path = output_dir / "page1.png"

    # ── 1. Ghostscript ────────────────────────────────────────────────────── #
    gs = (
        shutil.which("/opt/homebrew/bin/gs")
        or shutil.which("/usr/local/bin/gs")
        or shutil.which("ghostscript")
    )
    if gs:
        result = subprocess.run(
            [
                gs, "-dNOPAUSE", "-dBATCH", "-dSAFER",
                "-sDEVICE=png16m", f"-r{dpi}",
                "-dFirstPage=1", "-dLastPage=1",
                f"-sOutputFile={png_path}", str(pdf_path),
            ],
            capture_output=True, text=True, timeout=60,
        )
        if result.returncode == 0 and png_path.exists():
            w, h = _png_size(png_path)
            logger.info(
                "Rasterised PDF → PNG via Ghostscript at %d DPI: %dx%d px (%d KB)",
                dpi, w, h, png_path.stat().st_size // 1024,
            )
            return png_path
        logger.warning("Ghostscript rasterisation failed (exit %d): %s",
                       result.returncode, result.stderr[:300])

    # ── 2. PyMuPDF fallback ───────────────────────────────────────────────── #
    try:
        import fitz  # PyMuPDF

        doc = fitz.open(str(pdf_path))
        page = doc[0]
        zoom = dpi / 72.0
        mat = fitz.Matrix(zoom, zoom)
        pix = page.get_pixmap(matrix=mat, alpha=False)
        pix.save(str(png_path))
        doc.close()
        w, h = _png_size(png_path)
        logger.info(
            "Rasterised PDF → PNG via PyMuPDF at %d DPI: %dx%d px (%d KB)",
            dpi, w, h, png_path.stat().st_size // 1024,
        )
        return png_path
    except ImportError:
        raise RuntimeError(
            "Cannot rasterise PDF to PNG: neither Ghostscript nor PyMuPDF is "
            "available. Run: pip install pymupdf OR brew install ghostscript"
        )
    except Exception as exc:
        raise RuntimeError(f"Cannot rasterise PDF to PNG: {exc}") from exc


def _png_size(png_path: Path) -> tuple[int, int]:
    """Return (width, height) in pixels of a PNG file."""
    from PIL import Image
    with Image.open(png_path) as img:
        return img.size  # (width, height)


# --------------------------------------------------------------------------- #
#  Qwen VL region crop                                                         #
# --------------------------------------------------------------------------- #

def _crop_label_region(png_path: Path, output_dir: Path) -> Path:
    """Detect the main label area with Qwen VL and crop to it.

    If DASHSCOPE_API_KEY is missing or detection fails, returns the original
    PNG unchanged so the pipeline continues without interruption.
    """
    from backend.app.region_detector import (
        _get_api_key,
        crop_and_save,
        detect_regions,
        merge_regions,
    )

    api_key = _get_api_key()
    if not api_key:
        logger.info("DASHSCOPE_API_KEY not configured – skipping AI crop, using full image")
        return png_path

    try:
        regions, _ = detect_regions(png_path, api_key=api_key, api_max_side=1024)
    except Exception as exc:
        logger.warning("Qwen region detection failed (%s) – using full image", exc)
        return png_path

    if not regions:
        logger.warning("No regions detected by Qwen – using full image")
        return png_path

    merged = merge_regions(regions)
    output_dir.mkdir(parents=True, exist_ok=True)
    cropped_png = output_dir / "cropped_label.png"

    # crop_and_save writes to numbered files; rename for predictability
    results = crop_and_save(png_path, [merged], output_dir / "_tmp")
    if not results:
        return png_path

    shutil.move(results[0]["path"], str(cropped_png))

    w, h = _png_size(cropped_png)
    logger.info(
        "Qwen crop: bbox=(%d,%d)-(%d,%d) → %s (%dx%d px)",
        merged.x1, merged.y1, merged.x2, merged.y2,
        cropped_png.name, w, h,
    )
    return cropped_png


# --------------------------------------------------------------------------- #
#  MinerU image-block QR processing                                            #
# --------------------------------------------------------------------------- #

def _process_image_blocks(
    mineru_data: dict,
    source_image: Path,
    output_dir: Path,
) -> list[dict]:
    """Run the QR-code recognition flow for every image-type block MinerU extracted.

    Flow
    ----
    1. Extract all image blocks (with bbox coordinates) from mineru_data.
    2. Crop each bbox out of source_image (the high-res cropped label) into a temporary PNG.
    3. Ask Qwen VL whether the crop is a QR code / barcode.
    4. If the answer is yes, run the zxing barcode module for exact decoding.
    5. Return one result entry per image block.

    Parameters
    ----------
    mineru_data:
        MinerU structured JSON (contains pdf_info).
    source_image:
        High-res source image used for cropping (the PNG that was sent to MinerU).
    output_dir:
        Directory for the temporary crops.

    Returns
    -------
    list of dict
        One entry per image block, containing:
        - page, block_type, x0_pt, top_pt, x1_pt, bottom_pt
        - is_qr_code : bool — the VLM's semantic judgement
        - barcodes   : list — zxing decode results (empty when is_qr_code=False)
        - crop_path  : str — path of the crop (for debugging)
    """
    from PIL import Image

    image_blocks = parse_mineru_image_blocks(mineru_data)
    if not image_blocks:
        return []

    output_dir.mkdir(parents=True, exist_ok=True)
    results: list[dict] = []

    with Image.open(source_image) as src_img:
        img_w, img_h = src_img.size

        for idx, block in enumerate(image_blocks, start=1):
            # ── Crop ──────────────────────────────────────────────────────── #
            x0 = max(0, int(block["x0_pt"]))
            y0 = max(0, int(block["top_pt"]))
            x1 = min(img_w, int(block["x1_pt"]))
            y1 = min(img_h, int(block["bottom_pt"]))

            if x1 <= x0 or y1 <= y0:
                logger.warning(
                    "_process_image_blocks: block %d 边界框无效 (%d,%d)-(%d,%d),跳过",
                    idx, x0, y0, x1, y1,
                )
                results.append({**block, "is_qr_code": False, "barcodes": [], "crop_path": None})
                continue

            crop = src_img.crop((x0, y0, x1, y1))
            crop_file = output_dir / f"block_{idx:03d}_p{block['page']}.png"
            crop.save(crop_file)
            logger.info(
                "_process_image_blocks: block %d saved crop %s (%dx%d px)",
                idx, crop_file.name, x1 - x0, y1 - y0,
            )

            # ── Qwen VL semantic check ────────────────────────────────────── #
            qr_detected = is_qr_code(crop_file)

            # ── Barcode decoding (only when the semantic check says QR) ───── #
            barcodes: list[dict] = []
            if qr_detected:
                logger.info(
                    "_process_image_blocks: block %d 被识别为二维码,启动条码解码",
                    idx,
                )
                raw_barcodes = detect_barcodes(crop_file)
                barcodes = [
                    {
                        "format": b.format,
                        "format_label": b.format_label,
                        "text": b.text,
                        "x0": b.x0,
                        "y0": b.y0,
                        "x1": b.x1,
                        "y1": b.y1,
                        "valid": b.valid,
                    }
                    for b in raw_barcodes
                ]
                if barcodes:
                    logger.info(
                        "_process_image_blocks: block %d 条码解码成功,共 %d 条",
                        idx, len(barcodes),
                    )
                else:
                    logger.warning(
                        "_process_image_blocks: block %d 语义判断为二维码,但 zxing 未能解码",
                        idx,
                    )

            results.append(
                {
                    **block,
                    "is_qr_code": qr_detected,
                    "barcodes": barcodes,
                    "crop_path": str(crop_file),
                }
            )

    return results


# --------------------------------------------------------------------------- #
#  Public API                                                                  #
# --------------------------------------------------------------------------- #

def process_document(
    ai_path: Path,
    word_path: Path,
    output_dir: Path,
    job_id: str,
) -> dict:
    """Full pipeline: AI → PDF → PNG → Qwen crop → MinerU → validate.

    Steps
    -----
    1. AI / PDF file → clean PDF
    2. PDF → high-res PNG (Ghostscript, 150 DPI)
    3. PNG → Qwen VL detects main label area → cropped PNG
       (graceful fallback to full PNG when key is absent)
    4. Cropped PNG → MinerU structured-JSON extraction
    5. MinerU fields → validate against Word reference document

    Returns
    -------
    dict
        ``{ preview: {...}, fields: [...] }`` matching the frontend
        ``ProcessResponse`` type. ``preview.type`` is ``"png"`` and
        ``pageWidthPt`` / ``pageHeightPt`` hold the cropped image dimensions
        in pixels (coord system is pixel-aligned for the PNG overlay).
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # ── 1. AI → PDF ──────────────────────────────────────────────────────── #
    logger.info("Step 1/5 – Converting AI to PDF: %s", ai_path.name)
    pdf_path = _ai_to_pdf(ai_path, output_dir)

    # ── 2. PDF → PNG ─────────────────────────────────────────────────────── #
    logger.info("Step 2/5 – Rasterising PDF to PNG (150 DPI)")
    png_path = _pdf_to_png(pdf_path, output_dir / "raster", dpi=150)

    # ── 3. Qwen VL crop ───────────────────────────────────────────────────── #
    logger.info("Step 3/5 – AI region detection & crop")
    cropped_path = _crop_label_region(png_path, output_dir / "crop")

    # Relative URL fragment understood by /api/files/{job_id}/{file_path}
    cropped_rel = cropped_path.relative_to(output_dir).as_posix()
    img_w, img_h = _png_size(cropped_path)

    # ── 3b. Barcode detection ─────────────────────────────────────────────── #
    logger.info("Step 3b – Scanning for barcodes / QR codes")
    barcodes = detect_barcodes(cropped_path)

    # Crop each barcode region for frontend display
    barcode_crops_dir = output_dir / "barcode_crops"
    barcode_crops_dir.mkdir(parents=True, exist_ok=True)
    from PIL import Image as _PILImage  # noqa: PLC0415
    with _PILImage.open(cropped_path) as _src_img:
        _src_w, _src_h = _src_img.size
        for _bi, _b in enumerate(barcodes):
            _pad = 12
            _cx0 = max(0, _b.x0 - _pad)
            _cy0 = max(0, _b.y0 - _pad)
            _cx1 = min(_src_w, _b.x1 + _pad)
            _cy1 = min(_src_h, _b.y1 + _pad)
            _crop = _src_img.crop((_cx0, _cy0, _cx1, _cy1))
            _crop.save(barcode_crops_dir / f"barcode_{_bi}.png")

    barcode_results = [
        {
            "format": b.format,
            "format_label": b.format_label,
            "text": b.text,
            "x0": b.x0,
            "y0": b.y0,
            "x1": b.x1,
            "y1": b.y1,
            "valid": b.valid,
            "crop_url": f"/api/files/{job_id}/barcode_crops/barcode_{i}.png",
        }
        for i, b in enumerate(barcodes)
    ]
    logger.info("Step 3b – Found %d barcode(s)", len(barcode_results))

    # ── 4. MinerU parsing ────────────────────────────────────────────────── #
    logger.info("Step 4/5 – Sending cropped PNG to MinerU: %s", cropped_path.name)
    mineru_api_key = _get_mineru_api_key()
    if not mineru_api_key:
        raise RuntimeError("MINERU_API_KEY is not configured")

    mineru_dir = output_dir / "mineru"
    client = MineruClient(api_key=mineru_api_key)
    mineru_data = client.parse_image(cropped_path, mineru_dir)

    # ── 5. Parse + validate ───────────────────────────────────────────────── #
    logger.info("Step 5/5 – Parsing MinerU result and validating against Word")
    doc = parse_mineru_fields(mineru_data)
    word_text = extract_word_text(word_path)
    word_html = extract_word_html(word_path)

    fields: list[dict] = []
    for idx, field in enumerate(doc.fields, start=1):
        validation = validate_field_against_word(field["text"], word_text)
        fields.append(
            {
                "id": f"field-{idx}",
                **field,
                "normalized_text": validation.normalized_text,
                "validation_status": validation.status,
                "validation_reason": validation.reason,
                "matched_excerpt": validation.matched_excerpt,
            }
        )

    _STATUS_RANK = {"matched": 0, "unmatched": 1, "empty_or_garbled": 2}
    fields.sort(key=lambda f: (
        _STATUS_RANK.get(f["validation_status"], 9),
        f["page"],
        f["top_pt"],
        f["x0_pt"],
    ))

    logger.info(
        "Pipeline done: job_id=%s fields=%d matched=%d unmatched=%d garbled=%d",
        job_id,
        len(fields),
        sum(1 for f in fields if f["validation_status"] == "matched"),
        sum(1 for f in fields if f["validation_status"] == "unmatched"),
        sum(1 for f in fields if f["validation_status"] == "empty_or_garbled"),
    )

    # ── 5b. Image blocks: QR semantic check → barcode decode ─────────────── #
    image_block_results = _process_image_blocks(
        mineru_data=mineru_data,
        source_image=cropped_path,
        output_dir=output_dir / "image_blocks",
    )
    logger.info("Step 5b – Processed %d image block(s) from MinerU", len(image_block_results))

    return {
        "preview": {
            # type='png': frontend renders <img> + overlay (not PDF canvas)
            "type": "png",
            "url": f"/api/files/{job_id}/{cropped_rel}",
            # For PNG the "pt" fields carry pixel dimensions so overlay
            # scale factors remain 1:1 at 100% zoom.
            "pageWidthPt": img_w,
            "pageHeightPt": img_h,
        },
        "fields": fields,
        "word_text": word_text,
        "word_html": word_html,
        "barcodes": barcode_results,
        "image_blocks": image_block_results,
    }
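
A sketch of invoking the pipeline without the HTTP layer, assuming sample input files and a configured MINERU_API_KEY environment variable (file names are illustrative):

    from pathlib import Path

    from backend.app.pipeline import process_document

    result = process_document(
        ai_path=Path("design.ai"),         # assumed sample Illustrator/PDF file
        word_path=Path("reference.docx"),  # assumed reference document
        output_dir=Path("out/job-001"),
        job_id="job-001",
    )
    print(result["preview"]["url"], len(result["fields"]), "field(s)")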
372
backend/app/region_detector.py
Normal file
372
backend/app/region_detector.py
Normal file
@@ -0,0 +1,372 @@
|
||||
"""Detect main regions in a label image via Qwen2.5-VL (DashScope).
|
||||
|
||||
Workflow
|
||||
--------
|
||||
1. Read the original image; record its exact dimensions (orig_w × orig_h).
|
||||
2. Downscale a copy to fit within ``api_max_side`` for the API call
|
||||
(faster upload, lower token cost). Record api_w × api_h.
|
||||
3. Send the downscaled image to Qwen VL.
|
||||
4. Parse the response coordinates (which are relative to the api image):
|
||||
a. Qwen2.5-VL grounding tokens <|box_start|>(x1,y1),(x2,y2)<|box_end|>
|
||||
– normalised to [0, 1000] of the *api* image.
|
||||
b. Fallback: JSON array ``[{"label": "...", "bbox": [x1,y1,x2,y2]}, ...]``
|
||||
– pixel values in the *api* image coordinate space.
|
||||
5. Scale coordinates back to the original image space:
|
||||
x_orig = round(x_api * orig_w / api_w)
|
||||
6. Crop from the **original** high-resolution file → full-quality output.
|
||||
"""
from __future__ import annotations

import base64
import json
import logging
import os
import re
from pathlib import Path
from typing import NamedTuple

logger = logging.getLogger(__name__)

_DASHSCOPE_BASE_URL_DEFAULT = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# Model alias: the 7B variant is faster ("flash"); swap for 72B for higher accuracy
DEFAULT_MODEL = "qwen2.5-vl-7b-instruct"

_GROUNDING_RE = re.compile(
    r"<\|object_ref_start\|>(.*?)<\|object_ref_end\|>"
    r"<\|box_start\|>\((\d+),(\d+)\),\((\d+),(\d+)\)<\|box_end\|>",
    re.DOTALL,
)

_DEFAULT_PROMPT = (
    "请检测图像中食品包装标签的所有主要内容区域(如:主产品信息表格、"
    "营养成分表、标题、配料表、厂商信息、条码区等)。"
    "以JSON列表输出,格式为:\n"
    '[{"label": "区域名称", "bbox": [x1, y1, x2, y2]}, ...]'
    "\n坐标为实际像素值(整数),原点在左上角。"
)


class Region(NamedTuple):
    label: str
    x1: int
    y1: int
    x2: int
    y2: int

    @property
    def width(self) -> int:
        return self.x2 - self.x1

    @property
    def height(self) -> int:
        return self.y2 - self.y1


def _read_dotenv(path: Path) -> dict[str, str]:
    """Parse a simple KEY=VALUE .env file into a dict."""
    result: dict[str, str] = {}
    if not path.exists():
        return result
    for raw in path.read_text(encoding="utf-8").splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        k, v = line.split("=", 1)
        result[k.strip()] = v.strip().strip('"').strip("'")
    return result


def _load_env() -> dict[str, str]:
    """Merge .env files (project root → parent → home) into a single dict."""
    merged: dict[str, str] = {}
    for p in [
        Path(__file__).resolve().parents[2] / ".env",
        Path(__file__).resolve().parents[3] / ".env",
        Path.home() / ".env",
    ]:
        merged.update(_read_dotenv(p))
    return merged


def _get_api_key() -> str:
    """Read DASHSCOPE_API_KEY from env vars, then .env files."""
    val = os.environ.get("DASHSCOPE_API_KEY", "").strip()
    if val:
        return val
    return _load_env().get("DASHSCOPE_API_KEY", "")


def _get_base_url() -> str:
    """Read DASHSCOPE_BASE_URL from env vars, then .env files."""
    val = os.environ.get("DASHSCOPE_BASE_URL", "").strip()
    if val:
        return val
    return _load_env().get("DASHSCOPE_BASE_URL", _DASHSCOPE_BASE_URL_DEFAULT)


def _encode_image_for_api(
    image_path: Path,
    max_side: int = 1024,
) -> tuple[str, int, int]:
    """Downscale image to fit within *max_side* × *max_side*, encode as PNG base64.

    Returns
    -------
    b64 : str
        Base64-encoded PNG of the (possibly resized) image.
    api_w : int
        Width of the image that was actually sent to the API.
    api_h : int
        Height of the image that was actually sent to the API.
    """
    import io
    from PIL import Image

    with Image.open(image_path) as img:
        # Convert to RGB so PNG encoding always works
        img = img.convert("RGB")
        orig_w, orig_h = img.size

        if max(orig_w, orig_h) > max_side:
            scale = max_side / max(orig_w, orig_h)
            api_w = max(1, round(orig_w * scale))
            api_h = max(1, round(orig_h * scale))
            api_img = img.resize((api_w, api_h), Image.LANCZOS)
        else:
            api_w, api_h = orig_w, orig_h
            api_img = img

        buf = io.BytesIO()
        api_img.save(buf, format="PNG")

    b64 = base64.b64encode(buf.getvalue()).decode()
    return b64, api_w, api_h
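
# Downscale example (hypothetical numbers): a 4000×3000 input with max_side=1024
# yields scale = 0.256 and an api image of 1024×768; a 900×600 input is sent
# unresized because its longer side already fits within max_side.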


def _parse_grounding_tokens(text: str, api_w: int, api_h: int) -> list[Region]:
    """Parse <|box_start|>(x1,y1),(x2,y2)<|box_end|> tokens.

    Qwen2.5-VL normalises coordinates to [0, 1000] of the *api* image.
    Returns pixel coordinates in the api image space.
    """
    regions: list[Region] = []
    for m in _GROUNDING_RE.finditer(text):
        label = m.group(1).strip()
        x1 = round(int(m.group(2)) * api_w / 1000)
        y1 = round(int(m.group(3)) * api_h / 1000)
        x2 = round(int(m.group(4)) * api_w / 1000)
        y2 = round(int(m.group(5)) * api_h / 1000)
        regions.append(Region(label, x1, y1, x2, y2))
    return regions
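
# Grounding-token example (hypothetical): with api_w=1024, a token coordinate
# x=500 on the [0, 1000] grid becomes round(500 * 1024 / 1000) = 512 px.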


def _parse_json_regions(text: str) -> list[Region]:
    """Fallback: extract bbox from a JSON object or array in the response."""
    clean = re.sub(r"<\|[^|]+\|>", "", text)
    clean = re.sub(r"```[a-z]*", "", clean).strip("`").strip()

    def _extract_bbox(item: dict) -> list | None:
        """Try multiple known bbox key names, including nested dicts."""
        for key in ("bbox", "bbox_2d", "box", "coordinates", "bounding_box"):
            v = item.get(key)
            if isinstance(v, (list, tuple)) and len(v) >= 4:
                return list(v)
            # e.g. {"label": {"bbox_2d": [...]}}
            if isinstance(v, dict):
                inner = _extract_bbox(v)
                if inner:
                    return inner
        # Recurse into all dict values
        for v in item.values():
            if isinstance(v, dict):
                inner = _extract_bbox(v)
                if inner:
                    return inner
        return None

    def _region_from_dict(item: dict) -> Region | None:
        bbox = _extract_bbox(item)
        if not bbox or len(bbox) < 4:
            return None
        # Label: try common keys; fall back to the default when the value is not a string
        raw_label = (item.get("label") or item.get("name") or item.get("type") or "主内容区")
        label = raw_label if isinstance(raw_label, str) else "主内容区"
        x1, y1, x2, y2 = (int(v) for v in bbox[:4])
        return Region(label, x1, y1, x2, y2)

    # Try a single JSON object first (some responses return one object instead of an array)
    obj_start, obj_end = clean.find("{"), clean.rfind("}")
    if obj_start != -1 and obj_end > obj_start:
        try:
            obj = json.loads(clean[obj_start : obj_end + 1])
            if isinstance(obj, dict):
                r = _region_from_dict(obj)
                if r:
                    return [r]
        except json.JSONDecodeError:
            pass

    # Fall back to a JSON array
    arr_start, arr_end = clean.find("["), clean.rfind("]")
    if arr_start == -1 or arr_end <= arr_start:
        return []
    try:
        items = json.loads(clean[arr_start : arr_end + 1])
    except json.JSONDecodeError:
        return []
    regions: list[Region] = []
    for item in items:
        if isinstance(item, dict):
            r = _region_from_dict(item)
            if r:
                regions.append(r)
    return regions
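
# Fallback example (hypothetical response): the string
# '[{"label": "营养成分表", "bbox": [10, 20, 300, 400]}]'
# parses to [Region("营养成分表", 10, 20, 300, 400)].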


def detect_regions(
    image_path: Path,
    api_key: str | None = None,
    model: str = DEFAULT_MODEL,
    prompt: str = _DEFAULT_PROMPT,
    api_max_side: int = 1024,
) -> tuple[list[Region], str]:
    """Call Qwen VL to detect main regions.

    The image is downscaled to *api_max_side* before the API call for speed
    and cost efficiency. Returned ``Region`` coordinates are always mapped
    back to the **original** image pixel space, so ``crop_and_save`` will
    produce full-resolution output.

    Parameters
    ----------
    api_max_side:
        Maximum side length (px) of the image sent to the API.
        Increase it for very large originals where detection needs more detail.

    Returns
    -------
    regions : list[Region]
        Bounding boxes in **original** image coordinates.
    raw_response : str
        Full model text (for debugging).
    """
    from openai import OpenAI
    from PIL import Image

    key = api_key or _get_api_key()
    if not key:
        raise RuntimeError(
            "DASHSCOPE_API_KEY not set. "
            "Add it to the project .env or set the environment variable."
        )

    # ── 1. Original dimensions ────────────────────────────────────────────
    with Image.open(image_path) as img:
        orig_w, orig_h = img.size

    # ── 2. Downscale for API; remember api dims for coordinate mapping ─────
    b64, api_w, api_h = _encode_image_for_api(image_path, max_side=api_max_side)
    logger.info(
        "Calling %s on %s orig=%dx%d → api=%dx%d (scale=%.3f)",
        model, image_path.name, orig_w, orig_h, api_w, api_h, api_w / orig_w,
    )

    # ── 3. API call ───────────────────────────────────────────────────────
    client = OpenAI(api_key=key, base_url=_get_base_url())
    response = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{b64}"},
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ],
    )

    raw = response.choices[0].message.content or ""
    logger.debug("Qwen VL raw response:\n%s", raw)

    # ── 4. Parse — coordinates are in api-image space ─────────────────────
    regions = _parse_grounding_tokens(raw, api_w, api_h)
    if regions:
        logger.info("Parsed %d region(s) from grounding tokens", len(regions))
    else:
        regions = _parse_json_regions(raw)
        if regions:
            logger.info("Parsed %d region(s) from JSON fallback", len(regions))

    if not regions:
        logger.warning("No regions parsed from response:\n%s", raw[:400])
        return [], raw

    # ── 5. Scale coordinates back to original image space ─────────────────
    sx, sy = orig_w / api_w, orig_h / api_h
    original_regions = [
        Region(r.label,
               round(r.x1 * sx), round(r.y1 * sy),
               round(r.x2 * sx), round(r.y2 * sy))
        for r in regions
    ]
    logger.info(
        "Coordinates remapped api(%dx%d) → orig(%dx%d)",
        api_w, api_h, orig_w, orig_h,
    )
    return original_regions, raw


def merge_regions(regions: list[Region], label: str = "主内容区") -> Region:
    """Return the union bounding box of all regions as a single Region."""
    if not regions:
        raise ValueError("Cannot merge empty region list")
    x1 = min(r.x1 for r in regions)
    y1 = min(r.y1 for r in regions)
    x2 = max(r.x2 for r in regions)
    y2 = max(r.y2 for r in regions)
    return Region(label, x1, y1, x2, y2)
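
# Union example: merging Region("a", 10, 10, 50, 40) and Region("b", 30, 5, 80, 60)
# yields Region("主内容区", 10, 5, 80, 60).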


def crop_and_save(
    image_path: Path,
    regions: list[Region],
    output_dir: Path,
) -> list[dict]:
    """Crop each region and save as PNG.

    Returns
    -------
    list of dicts with keys: label, bbox, path
    """
    from PIL import Image

    output_dir.mkdir(parents=True, exist_ok=True)
    results: list[dict] = []

    with Image.open(image_path) as img:
        img_w, img_h = img.size
        for i, region in enumerate(regions, start=1):
            # Clamp to image bounds
            x1 = max(0, region.x1)
            y1 = max(0, region.y1)
            x2 = min(img_w, region.x2)
            y2 = min(img_h, region.y2)
            if x2 <= x1 or y2 <= y1:
                logger.warning("Skipping zero-area region: %s", region.label)
                continue
            cropped = img.crop((x1, y1, x2, y2))
            safe_name = re.sub(r"[^\w\u4e00-\u9fff-]", "_", region.label)[:40]
            out_path = output_dir / f"{i:02d}_{safe_name}.png"
            cropped.save(out_path)
            logger.info("Saved region [%s] → %s", region.label, out_path.name)
            results.append({
                "label": region.label,
                "bbox": [x1, y1, x2, y2],
                "path": str(out_path),
            })

    return results
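A minimal end-to-end sketch of this module, assuming DASHSCOPE_API_KEY is configured and that label.png is a local test image (both assumptions, not part of the commit):

from pathlib import Path

from backend.app.region_detector import crop_and_save, detect_regions, merge_regions

regions, raw = detect_regions(Path("label.png"))    # calls Qwen VL via DashScope
if regions:
    merged = merge_regions(regions)                 # union box over all regions
    crops = crop_and_save(Path("label.png"), [merged], Path("crops"))
    print([c["path"] for c in crops])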
266
backend/app/text_validation.py
Normal file
266
backend/app/text_validation.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""Validate extracted text blocks against a Word document's content."""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import unicodedata
|
||||
from dataclasses import dataclass
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
|
||||
# Minimum SequenceMatcher ratio to count as a match (strict: content must be nearly identical)
|
||||
MATCH_THRESHOLD = 0.95
|
||||
# For multi-row tables: individual row match threshold
|
||||
TABLE_ROW_SINGLE_THRESHOLD = 0.95
|
||||
# For multi-row tables: fraction of valid rows that must match
|
||||
TABLE_ROW_MATCH_THRESHOLD = 0.5
|
||||
# Characters below this length are treated as too short to validate
|
||||
MIN_TEXT_LENGTH = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
status: str # "matched" | "unmatched" | "empty_or_garbled"
|
||||
reason: str
|
||||
normalized_text: str
|
||||
matched_excerpt: str | None
|
||||
|
||||
|
||||
# 圆圈序号 ①②③...⑳(NFKC 之前处理,避免转为数字后难以区分)
|
||||
_CIRCLED_NUM_RE = re.compile(r"^[①-⑳]")
|
||||
# 数字列表前缀:"1. " "2." "3. " 等(NFKC 之后处理)
|
||||
_LIST_NUM_RE = re.compile(r"^\d{1,2}[.\s]+")
|
||||
# 句末/列表标点(中英文等价符,忽略差异;保留小数点和冒号)
|
||||
_PUNCT_RE = re.compile(r"[,。;!?、…,;!?]")


def _normalize(text: str) -> str:
    """Collapse whitespace and normalise unicode for comparison.

    Additional handling:
    - Strip leading circled numbers (①②③) and numbered-list prefixes (1. 2.)
    - Ignore CJK-vs-ASCII punctuation differences (,。; vs ,.)
    - Unify dash variants and drop spaces around dashes (50 – 60 → 50-60)
    """
    text = text.lstrip()
    # Strip circled numbers first (before NFKC, so ③ → 3 cannot be confused
    # with an ordinary digit)
    text = _CIRCLED_NUM_RE.sub("", text).lstrip()
    # Unicode normalisation (full-width → half-width, ① → 1, : → :, ( → ( etc.)
    text = unicodedata.normalize("NFKC", text)
    # Strip markdown bold/italic markers
    text = re.sub(r"\*+", "", text)
    # Normalise dash variants: en-dash / em-dash / minus sign → hyphen
    text = re.sub(r"[–—−]", "-", text)
    # Drop spaces around dashes: "50 - 60" → "50-60"
    text = re.sub(r"\s*-\s*", "-", text)
    # Strip numbered-list prefixes (after NFKC, e.g. "3. ", "4. ")
    text = _LIST_NUM_RE.sub("", text.lstrip())
    # Ignore sentence/list punctuation differences
    text = _PUNCT_RE.sub("", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text
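
# Worked example: _normalize("① 能量  50 – 60kJ,") → "能量 50-60kJ"; the circled
# number, dash spacing, extra whitespace, and the trailing comma are all removed.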


def _is_garbled(text: str) -> bool:
    """Return True when text is empty, too short, or mostly non-printable."""
    if not text or len(text) < MIN_TEXT_LENGTH:
        return True
    printable = sum(1 for c in text if not unicodedata.category(c).startswith("C"))
    return printable / len(text) < 0.5


def _word_lines(word_text: str) -> list[str]:
    """Split Word Markdown into non-empty normalised lines for matching.

    Grid-table separator rows (e.g. ``+-----+-----+``) are filtered out
    because they carry no semantic content and would skew similarity scores.
    """
    _SEP_RE = re.compile(r"^[+\-=| ]+$")
    lines = []
    for raw in word_text.splitlines():
        norm = _normalize(raw)
        if not norm:
            continue
        # Skip pandoc grid-table separator rows
        if _SEP_RE.match(norm.replace(" ", "")):
            continue
        lines.append(norm)
    return lines


def _match_against_line(needle: str, line: str) -> tuple[float, str]:
    """Return (ratio, excerpt) for needle vs a single Word line.

    When the needle (MinerU row) is significantly shorter than the Word line
    (because the Word table has more product columns), a plain
    SequenceMatcher ratio under-counts matching content. We also compute
    *needle coverage* — the fraction of the needle's characters that appear
    in the line — and take the higher of the two scores.
    """
    # Exact substring
    if needle in line:
        idx = line.index(needle)
        return 1.0, line[idx: idx + len(needle) + 20].strip()

    matcher = SequenceMatcher(None, needle, line, autojunk=False)
    ratio = matcher.ratio()

    # Coverage ratio: useful when a MinerU row is a partial view of a wider table
    if len(needle) > 0 and len(needle) < len(line):
        match_chars = sum(t for _, _, t in matcher.get_matching_blocks())
        coverage = match_chars / len(needle)
        # Apply a small discount to avoid false positives on very short needles
        ratio = max(ratio, coverage * 0.95)

    # Table rows (containing | separators) can be very long; return more
    # context so the frontend can render them in full
    max_len = 400 if line.lstrip().startswith("|") else 120
    return ratio, line[:max_len].strip()
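
# Coverage example: for a 10-char needle whose characters all occur in order
# (but not contiguously) in a 40-char line, SequenceMatcher.ratio() is only
# 2*10/(10+40) = 0.4, while coverage is 10/10 = 1.0, so the final score is
# max(0.4, 1.0 * 0.95) = 0.95.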


def _match_single_line(norm: str, word_lines: list[str]) -> tuple[float, str]:
    """Find the word_lines entry most similar to norm; return (best_ratio, best_excerpt)."""
    best_ratio = 0.0
    best_excerpt = ""
    for line in word_lines:
        ratio, excerpt = _match_against_line(norm, line)
        if ratio > best_ratio:
            best_ratio = ratio
            best_excerpt = excerpt
        if best_ratio == 1.0:
            break
    return best_ratio, best_excerpt


def _validate_table_against_word(raw_rows: list[str], word_text: str) -> ValidationResult:
    """Match a multi-row table row by row and aggregate the hit rate.

    Strategy
    --------
    - Run the single-line matcher on each row; a row counts as a hit when it
      reaches the threshold.
    - The table as a whole matches when the hit rate is
      ≥ TABLE_ROW_MATCH_THRESHOLD (50%).
    - matched_excerpt collects the Word excerpts of matched rows so the
      frontend can render them as a table.
    """
    word_lines = _word_lines(word_text)
    if not word_lines:
        norm_full = _normalize(" ".join(raw_rows))
        return ValidationResult(
            status="unmatched",
            reason="Word 文档为空",
            normalized_text=norm_full,
            matched_excerpt=None,
        )

    matched = 0
    skipped = 0
    excerpts: list[str] = []
    seen_excerpts: set[str] = set()

    for row in raw_rows:
        norm_row = _normalize(row)
        if _is_garbled(norm_row):
            skipped += 1
            continue
        ratio, exc = _match_single_line(norm_row, word_lines)
        if ratio >= TABLE_ROW_SINGLE_THRESHOLD:
            matched += 1
            if exc and exc not in seen_excerpts:
                excerpts.append(exc)
                seen_excerpts.add(exc)

    valid_count = len(raw_rows) - skipped
    norm_full = _normalize(" ".join(raw_rows))

    if valid_count == 0:
        return ValidationResult(
            status="empty_or_garbled",
            reason="表格文本为空或全部为乱码",
            normalized_text=norm_full,
            matched_excerpt=None,
        )

    match_rate = matched / valid_count
    excerpt_text = "\n".join(excerpts) if excerpts else None

    if match_rate >= TABLE_ROW_MATCH_THRESHOLD:
        return ValidationResult(
            status="matched",
            reason=f"表格 {matched}/{valid_count} 行与 Word 匹配(命中率 {match_rate:.0%})",
            normalized_text=norm_full,
            matched_excerpt=excerpt_text,
        )

    return ValidationResult(
        status="unmatched",
        reason=f"表格仅 {matched}/{valid_count} 行与 Word 匹配(命中率 {match_rate:.0%},阈值 {TABLE_ROW_MATCH_THRESHOLD:.0%})",
        normalized_text=norm_full,
        matched_excerpt=excerpt_text,
    )
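
# Aggregation example: 5 OCR rows, 1 garbled (skipped), and 3 of the remaining
# 4 reach TABLE_ROW_SINGLE_THRESHOLD → match_rate = 3/4 = 75% ≥ 50%, so the
# table is reported as "matched".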


def validate_field_against_word(text: str, word_text: str) -> ValidationResult:
    """Check whether *text* matches any line of *word_text*.

    - Single-line text: find the most similar Word line; similarity
      ≥ MATCH_THRESHOLD (95%) counts as a match.
    - Multi-line text (tables): match row by row; a hit rate ≥ 50% counts as
      an overall match.

    Parameters
    ----------
    text:
        The OCR-extracted text block to validate.
    word_text:
        Full Markdown text extracted from the reference Word document.

    Returns
    -------
    ValidationResult
        Contains status, a human-readable reason, the normalised text,
        and the best-matching line from the Word document (if any).
    """
    # Multi-line text (tables): match row by row
    raw_rows = [r.strip() for r in text.splitlines() if r.strip()]
    if len(raw_rows) > 1:
        return _validate_table_against_word(raw_rows, word_text)

    # Single-line matching
    norm = _normalize(text)

    if _is_garbled(norm):
        return ValidationResult(
            status="empty_or_garbled",
            reason="文本为空或包含乱码",
            normalized_text=norm,
            matched_excerpt=None,
        )

    word_lines = _word_lines(word_text)
    if not word_lines:
        return ValidationResult(
            status="unmatched",
            reason="Word 文档为空",
            normalized_text=norm,
            matched_excerpt=None,
        )

    best_ratio, best_excerpt = _match_single_line(norm, word_lines)

    if best_ratio == 1.0:
        return ValidationResult(
            status="matched",
            reason="与 Word 某行内容完全匹配",
            normalized_text=norm,
            matched_excerpt=best_excerpt,
        )

    if best_ratio >= MATCH_THRESHOLD:
        return ValidationResult(
            status="matched",
            reason=f"与 Word 某行相似度 {best_ratio:.0%},判定为匹配",
            normalized_text=norm,
            matched_excerpt=best_excerpt,
        )

    return ValidationResult(
        status="unmatched",
        reason=f"在 Word 中未找到匹配行(最高相似度 {best_ratio:.0%})",
        normalized_text=norm,
        matched_excerpt=best_excerpt or None,
    )
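A minimal usage sketch for this module; the strings below are illustrative, and word_text would normally come from word_parser.extract_word_text:

from backend.app.text_validation import validate_field_against_word

word_text = "产品名称:酸奶\n净含量:200g"
result = validate_field_against_word("净含量:200g。", word_text)
print(result.status)  # "matched": identical once punctuation is normalised away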
147
backend/app/word_parser.py
Normal file
147
backend/app/word_parser.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Extract text / HTML from a Word (.docx) document via pandoc."""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def extract_word_text(path: Path) -> str:
|
||||
"""Convert *path* (.docx) to Markdown with pandoc and return the result.
|
||||
|
||||
Pandoc preserves tables, lists, bold/italic, and paragraph structure far
|
||||
better than python-docx plain-text extraction. The returned string is
|
||||
cleaned of pandoc-specific span attributes (e.g. ``{.mark}``) that are
|
||||
irrelevant for text matching.
|
||||
|
||||
Falls back to python-docx plain-text extraction if pandoc is not installed.
|
||||
"""
|
||||
pandoc = shutil.which("pandoc")
|
||||
if pandoc:
|
||||
return _extract_via_pandoc(path, pandoc)
|
||||
return _extract_via_docx(path)


def extract_word_html(path: Path) -> str | None:
    """Convert *path* (.docx) to an HTML fragment preserving merged table cells.

    Uses pandoc (``-t html5``), which correctly maps Word's ``<w:gridSpan>`` /
    ``<w:vMerge>`` to HTML ``colspan`` / ``rowspan`` attributes.

    Returns ``None`` when pandoc is unavailable or conversion fails.
    The returned string is a ``<body>`` fragment (no ``<html>`` / ``<head>``),
    with inline ``style`` attributes and ``<colgroup>`` stripped so that the
    frontend can apply its own CSS.
    """
    pandoc = shutil.which("pandoc")
    if not pandoc:
        return None
    try:
        result = subprocess.run(
            [pandoc, str(path), "-f", "docx", "-t", "html5", "--wrap=none"],
            capture_output=True,
            text=True,
            timeout=60,
        )
    except Exception:
        return None

    if result.returncode != 0:
        return None

    return _clean_word_html(result.stdout)
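
# The subprocess call above is equivalent to running
#   pandoc input.docx -f docx -t html5 --wrap=none
# on the command line ("input.docx" stands in for the real path).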


def _clean_word_html(html: str) -> str:
    """Extract <body> content and strip noise added by pandoc."""
    # Take the <body> content
    m = re.search(r"<body[^>]*>(.*?)</body>", html, re.DOTALL | re.IGNORECASE)
    body = m.group(1).strip() if m else html

    # Drop <colgroup> blocks (they carry column-width inline styles;
    # the frontend CSS takes over)
    body = re.sub(r"<colgroup[^>]*>.*?</colgroup>", "", body, flags=re.DOTALL | re.IGNORECASE)
    # Drop all style="..." attributes
    body = re.sub(r'\s+style="[^"]*"', "", body)
    # Drop empty <p></p> emitted by pandoc
    body = re.sub(r"<p>\s*</p>", "", body)

    return body.strip()


# --------------------------------------------------------------------------- #
# Pandoc path                                                                 #
# --------------------------------------------------------------------------- #

def _extract_via_pandoc(path: Path, pandoc: str) -> str:
    result = subprocess.run(
        [
            pandoc,
            str(path),
            "-f", "docx",
            "-t", "markdown",
            "--wrap=none",
        ],
        capture_output=True,
        text=True,
        timeout=60,
    )
    if result.returncode != 0:
        raise RuntimeError(
            f"pandoc failed (exit {result.returncode}):\n{result.stderr.strip()}"
        )
    pandoc_text = _clean_pandoc_markdown(result.stdout)

    # Pandoc drops the text of paragraphs that contain floating shapes
    # (AlternateContent / WPS drawings). Use python-docx to fill the gap:
    # find paragraph texts that pandoc did not emit and append them at the
    # end. This has no side effect on text matching (the worst case is a
    # slight duplication, which does not affect SequenceMatcher results).
    try:
        from docx import Document  # type: ignore
        doc = Document(str(path))
        missing: list[str] = []
        for para in doc.paragraphs:
            text = para.text.strip()
            if text and text not in pandoc_text:
                missing.append(text)
        if missing:
            pandoc_text = pandoc_text + "\n" + "\n".join(missing)
    except Exception:
        pass  # Degrade silently when python-docx is unavailable; the pandoc result is still valid

    return pandoc_text


def _clean_pandoc_markdown(text: str) -> str:
    """Remove pandoc-specific inline attributes that noise up text matching."""
    # [text]{.mark} / [text]{#id .cls key=val} → text
    text = re.sub(r"\[([^\]]*)\]\{[^}]*\}", r"\1", text)
    # Leftover bare {…} attribute blocks on their own
    text = re.sub(r"\{[^}]*\}", "", text)
    return text


# --------------------------------------------------------------------------- #
# python-docx fallback                                                        #
# --------------------------------------------------------------------------- #

def _extract_via_docx(path: Path) -> str:
    from docx import Document  # type: ignore

    doc = Document(str(path))
    lines = [para.text for para in doc.paragraphs if para.text.strip()]

    # python-docx returns the same _Cell object once per grid column it spans,
    # so merged cells would otherwise be emitted repeatedly; dedupe by identity.
    seen_cells: set[int] = set()
    for table in doc.tables:
        for row in table.rows:
            cells: list[str] = []
            for cell in row.cells:
                if id(cell) in seen_cells:
                    continue
                seen_cells.add(id(cell))
                text = cell.text.strip()
                if text:
                    cells.append(text)
            if cells:
                lines.append("|".join(cells))

    return "\n".join(lines)