Initial commit: 包装审核 POC、Docker 与前后端
Made-with: Cursor
This commit is contained in:
372
backend/app/region_detector.py
Normal file
372
backend/app/region_detector.py
Normal file
@@ -0,0 +1,372 @@
|
||||
"""Detect main regions in a label image via Qwen2.5-VL (DashScope).
|
||||
|
||||
Workflow
|
||||
--------
|
||||
1. Read the original image; record its exact dimensions (orig_w × orig_h).
|
||||
2. Downscale a copy to fit within ``api_max_side`` for the API call
|
||||
(faster upload, lower token cost). Record api_w × api_h.
|
||||
3. Send the downscaled image to Qwen VL.
|
||||
4. Parse the response coordinates (which are relative to the api image):
|
||||
a. Qwen2.5-VL grounding tokens <|box_start|>(x1,y1),(x2,y2)<|box_end|>
|
||||
– normalised to [0, 1000] of the *api* image.
|
||||
b. Fallback: JSON array ``[{"label": "...", "bbox": [x1,y1,x2,y2]}, ...]``
|
||||
– pixel values in the *api* image coordinate space.
|
||||
5. Scale coordinates back to the original image space:
|
||||
x_orig = round(x_api * orig_w / api_w)
|
||||
6. Crop from the **original** high-resolution file → full-quality output.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import NamedTuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OpenAI-compatible DashScope endpoint used when DASHSCOPE_BASE_URL is not set.
_DASHSCOPE_BASE_URL_DEFAULT = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# Model alias: the 7B variant is faster ("flash"); swap for 72B for higher accuracy
DEFAULT_MODEL = "qwen2.5-vl-7b-instruct"

# Matches Qwen2.5-VL grounding output:
#   <|object_ref_start|>label<|object_ref_end|><|box_start|>(x1,y1),(x2,y2)<|box_end|>
# Group 1 is the label; groups 2-5 are the box corner coordinates,
# normalised to [0, 1000] of the image sent to the API.
_GROUNDING_RE = re.compile(
    r"<\|object_ref_start\|>(.*?)<\|object_ref_end\|>"
    r"<\|box_start\|>\((\d+),(\d+)\),\((\d+),(\d+)\)<\|box_end\|>",
    re.DOTALL,
)

# Default detection prompt (Chinese; runtime string sent to the model — do not
# translate). Asks for all major label regions as a JSON list of
# {"label": ..., "bbox": [x1, y1, x2, y2]} with pixel coordinates,
# origin at the top-left.
_DEFAULT_PROMPT = (
    "请检测图像中食品包装标签的所有主要内容区域(如:主产品信息表格、"
    "营养成分表、标题、配料表、厂商信息、条码区等)。"
    "以JSON列表输出,格式为:\n"
    '[{"label": "区域名称", "bbox": [x1, y1, x2, y2]}, ...]'
    "\n坐标为实际像素值(整数),原点在左上角。"
)
|
||||
|
||||
|
||||
class Region(NamedTuple):
    """A labelled axis-aligned bounding box.

    Coordinates are pixel values with the origin at the top-left corner;
    ``(x1, y1)`` is the upper-left corner and ``(x2, y2)`` the lower-right.
    """

    label: str
    x1: int
    y1: int
    x2: int
    y2: int

    @property
    def width(self) -> int:
        """Horizontal extent in pixels (``x2 - x1``)."""
        return self.x2 - self.x1

    @property
    def height(self) -> int:
        """Vertical extent in pixels (``y2 - y1``)."""
        return self.y2 - self.y1
|
||||
|
||||
|
||||
def _read_dotenv(path: Path) -> dict[str, str]:
|
||||
"""Parse a simple KEY=VALUE .env file into a dict."""
|
||||
result: dict[str, str] = {}
|
||||
if not path.exists():
|
||||
return result
|
||||
for raw in path.read_text(encoding="utf-8").splitlines():
|
||||
line = raw.strip()
|
||||
if not line or line.startswith("#") or "=" not in line:
|
||||
continue
|
||||
k, v = line.split("=", 1)
|
||||
result[k.strip()] = v.strip().strip('"').strip("'")
|
||||
return result
|
||||
|
||||
|
||||
def _load_env() -> dict[str, str]:
    """Merge .env files (project root → parent → home) into a single dict.

    Later files win on key collisions, so ``~/.env`` overrides the
    project-level files. Missing files contribute nothing.
    """
    here = Path(__file__).resolve()
    candidates = (
        here.parents[2] / ".env",
        here.parents[3] / ".env",
        Path.home() / ".env",
    )
    combined: dict[str, str] = {}
    for candidate in candidates:
        combined.update(_read_dotenv(candidate))
    return combined
|
||||
|
||||
|
||||
def _get_api_key() -> str:
|
||||
"""Read DASHSCOPE_API_KEY from env vars then .env files."""
|
||||
val = os.environ.get("DASHSCOPE_API_KEY", "").strip()
|
||||
if val:
|
||||
return val
|
||||
return _load_env().get("DASHSCOPE_API_KEY", "")
|
||||
|
||||
|
||||
def _get_base_url() -> str:
|
||||
"""Read DASHSCOPE_BASE_URL from env vars then .env files."""
|
||||
val = os.environ.get("DASHSCOPE_BASE_URL", "").strip()
|
||||
if val:
|
||||
return val
|
||||
return _load_env().get("DASHSCOPE_BASE_URL", _DASHSCOPE_BASE_URL_DEFAULT)
|
||||
|
||||
|
||||
def _encode_image_for_api(
    image_path: Path,
    max_side: int = 1024,
) -> tuple[str, int, int]:
    """Downscale image to fit within *max_side* × *max_side*, encode as PNG base64.

    Returns
    -------
    b64 : str
        Base64-encoded PNG of the (possibly resized) image.
    api_w : int
        Width of the image that was actually sent to the API.
    api_h : int
        Height of the image that was actually sent to the API.
    """
    import io

    from PIL import Image

    with Image.open(image_path) as source:
        # RGB guarantees PNG encoding succeeds regardless of the input mode.
        rgb = source.convert("RGB")
        orig_w, orig_h = rgb.size

        longest = max(orig_w, orig_h)
        if longest > max_side:
            factor = max_side / longest
            # Guard against a zero dimension after rounding very thin images.
            api_w = max(1, round(orig_w * factor))
            api_h = max(1, round(orig_h * factor))
            api_img = rgb.resize((api_w, api_h), Image.LANCZOS)
        else:
            api_img, api_w, api_h = rgb, orig_w, orig_h

        buffer = io.BytesIO()
        api_img.save(buffer, format="PNG")

    return base64.b64encode(buffer.getvalue()).decode(), api_w, api_h
|
||||
|
||||
|
||||
def _parse_grounding_tokens(text: str, api_w: int, api_h: int) -> list[Region]:
    """Parse <|box_start|>(x1,y1),(x2,y2)<|box_end|> tokens.

    Qwen2.5-VL normalises coordinates to [0, 1000] of the *api* image.
    Returns pixel coordinates in the api image space.
    """
    found: list[Region] = []
    for match in _GROUNDING_RE.finditer(text):
        name = match.group(1).strip()
        # Groups 2-5 are the normalised corners; map [0, 1000] → pixels.
        nx1, ny1, nx2, ny2 = (int(g) for g in match.groups()[1:5])
        found.append(
            Region(
                name,
                round(nx1 * api_w / 1000),
                round(ny1 * api_h / 1000),
                round(nx2 * api_w / 1000),
                round(ny2 * api_h / 1000),
            )
        )
    return found
|
||||
|
||||
|
||||
def _parse_json_regions(text: str) -> list[Region]:
    """Fallback: extract bbox from a JSON object or array in the response.

    Handles both a single object ``{"label": ..., "bbox": [...]}`` and an
    array of such objects. Bbox entries may be ints, floats, or numeric
    strings (models sometimes emit ``"12.5"``); floats are rounded to the
    nearest pixel, and items with non-numeric bbox entries are skipped
    instead of raising (the original ``int(v)`` crashed on float strings).
    """
    # Strip grounding/special tokens and markdown code fences before
    # searching for JSON payloads.
    clean = re.sub(r"<\|[^|]+\|>", "", text)
    clean = re.sub(r"```[a-z]*", "", clean).strip("`").strip()

    def _to_int(v) -> int | None:
        """Coerce an int/float/numeric-string coordinate to int; None if invalid."""
        try:
            return round(float(v))
        except (TypeError, ValueError):
            return None

    def _extract_bbox(item: dict) -> list | None:
        """Try multiple known bbox key names, including nested dicts."""
        for key in ("bbox", "bbox_2d", "box", "coordinates", "bounding_box"):
            v = item.get(key)
            if isinstance(v, (list, tuple)) and len(v) >= 4:
                return list(v)
            # e.g. {"label": {"bbox_2d": [...]}}
            if isinstance(v, dict):
                inner = _extract_bbox(v)
                if inner:
                    return inner
        # Recurse into all dict values
        for v in item.values():
            if isinstance(v, dict):
                inner = _extract_bbox(v)
                if inner:
                    return inner
        return None

    def _region_from_dict(item: dict) -> Region | None:
        """Build a Region from one JSON object, or None when no usable bbox."""
        bbox = _extract_bbox(item)
        if not bbox or len(bbox) < 4:
            return None
        # Label: try common keys; fall back when the value is not a string
        raw_label = (item.get("label") or item.get("name") or item.get("type") or "主内容区")
        label = raw_label if isinstance(raw_label, str) else "主内容区"
        coords = [_to_int(v) for v in bbox[:4]]
        if any(c is None for c in coords):
            # Non-numeric bbox entry — reject this item rather than crash.
            return None
        x1, y1, x2, y2 = coords
        return Region(label, x1, y1, x2, y2)

    # Try single JSON object first (default prompt returns one object)
    obj_start, obj_end = clean.find("{"), clean.rfind("}")
    if obj_start != -1 and obj_end > obj_start:
        try:
            obj = json.loads(clean[obj_start : obj_end + 1])
            if isinstance(obj, dict):
                r = _region_from_dict(obj)
                if r:
                    return [r]
        except json.JSONDecodeError:
            pass

    # Fallback to JSON array
    arr_start, arr_end = clean.find("["), clean.rfind("]")
    if arr_start == -1 or arr_end <= arr_start:
        return []
    try:
        items = json.loads(clean[arr_start : arr_end + 1])
    except json.JSONDecodeError:
        return []
    regions: list[Region] = []
    for item in items:
        if isinstance(item, dict):
            r = _region_from_dict(item)
            if r:
                regions.append(r)
    return regions
|
||||
|
||||
|
||||
def detect_regions(
    image_path: Path,
    api_key: str | None = None,
    model: str = DEFAULT_MODEL,
    prompt: str = _DEFAULT_PROMPT,
    api_max_side: int = 1024,
) -> tuple[list[Region], str]:
    """Call Qwen VL to detect main regions.

    The image is downscaled to *api_max_side* before the API call for speed
    and cost efficiency. Returned ``Region`` coordinates are always mapped
    back to the **original** image pixel space, so ``crop_and_save`` will
    produce full-resolution output.

    Parameters
    ----------
    api_max_side:
        Maximum side length (px) of the image sent to the API.
        Increase for very large originals where detection needs more detail.

    Returns
    -------
    regions : list[Region]
        Bounding boxes in **original** image coordinates.
    raw_response : str
        Full model text (for debugging).

    Raises
    ------
    RuntimeError
        When no API key is supplied and none is configured.
    """
    from openai import OpenAI
    from PIL import Image

    resolved_key = api_key or _get_api_key()
    if not resolved_key:
        raise RuntimeError(
            "DASHSCOPE_API_KEY not set. "
            "Add it to the project .env or set the environment variable."
        )

    # ── 1. Original dimensions ────────────────────────────────────────────
    with Image.open(image_path) as source:
        orig_w, orig_h = source.size

    # ── 2. Downscale for API; remember api dims for coordinate mapping ─────
    b64, api_w, api_h = _encode_image_for_api(image_path, max_side=api_max_side)
    logger.info(
        "Calling %s on %s orig=%dx%d → api=%dx%d (scale=%.3f)",
        model, image_path.name, orig_w, orig_h, api_w, api_h, api_w / orig_w,
    )

    # ── 3. API call ───────────────────────────────────────────────────────
    client = OpenAI(api_key=resolved_key, base_url=_get_base_url())
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/png;base64,{b64}"},
    }
    text_part = {"type": "text", "text": prompt}
    completion = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": [image_part, text_part]}],
    )

    raw = completion.choices[0].message.content or ""
    logger.debug("Qwen VL raw response:\n%s", raw)

    # ── 4. Parse — coordinates are in api-image space ─────────────────────
    regions = _parse_grounding_tokens(raw, api_w, api_h)
    if regions:
        logger.info("Parsed %d region(s) from grounding tokens", len(regions))
    else:
        regions = _parse_json_regions(raw)
        if regions:
            logger.info("Parsed %d region(s) from JSON fallback", len(regions))

    if not regions:
        logger.warning("No regions parsed from response:\n%s", raw[:400])
        return [], raw

    # ── 5. Scale coordinates back to original image space ─────────────────
    sx, sy = orig_w / api_w, orig_h / api_h
    remapped = [
        Region(
            r.label,
            round(r.x1 * sx), round(r.y1 * sy),
            round(r.x2 * sx), round(r.y2 * sy),
        )
        for r in regions
    ]
    logger.info(
        "Coordinates remapped api(%dx%d) → orig(%dx%d)",
        api_w, api_h, orig_w, orig_h,
    )
    return remapped, raw
|
||||
|
||||
|
||||
def merge_regions(regions: list[Region], label: str = "主内容区") -> Region:
    """Return the union bounding box of all regions as a single Region.

    Raises
    ------
    ValueError
        If *regions* is empty.
    """
    if not regions:
        raise ValueError("Cannot merge empty region list")
    # Transpose the corner coordinates, then take the outer extremes.
    xs1, ys1, xs2, ys2 = zip(*((r.x1, r.y1, r.x2, r.y2) for r in regions))
    return Region(label, min(xs1), min(ys1), max(xs2), max(ys2))
|
||||
|
||||
|
||||
def crop_and_save(
    image_path: Path,
    regions: list[Region],
    output_dir: Path,
) -> list[dict]:
    """Crop each region and save as PNG.

    Regions are clamped to the image bounds; regions with zero area after
    clamping are skipped with a warning. *output_dir* is created if needed.

    Returns
    -------
    list of dicts with keys: label, bbox, path
    """
    from PIL import Image

    output_dir.mkdir(parents=True, exist_ok=True)
    saved: list[dict] = []

    with Image.open(image_path) as img:
        img_w, img_h = img.size
        for index, region in enumerate(regions, start=1):
            # Clamp to image bounds
            left, top = max(0, region.x1), max(0, region.y1)
            right, bottom = min(img_w, region.x2), min(img_h, region.y2)
            if right <= left or bottom <= top:
                logger.warning("Skipping zero-area region: %s", region.label)
                continue
            crop = img.crop((left, top, right, bottom))
            # Keep word characters and CJK in the file name; cap at 40 chars.
            safe_name = re.sub(r"[^\w\u4e00-\u9fff-]", "_", region.label)[:40]
            out_path = output_dir / f"{index:02d}_{safe_name}.png"
            crop.save(out_path)
            logger.info("Saved region [%s] → %s", region.label, out_path.name)
            saved.append({
                "label": region.label,
                "bbox": [left, top, right, bottom],
                "path": str(out_path),
            })

    return saved
|
||||
Reference in New Issue
Block a user