Files
ZLD_POC/backend/app/region_detector.py
2026-04-15 17:18:49 +08:00

373 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Detect main regions in a label image via Qwen2.5-VL (DashScope).
Workflow
--------
1. Read the original image; record its exact dimensions (orig_w × orig_h).
2. Downscale a copy to fit within ``api_max_side`` for the API call
(faster upload, lower token cost). Record api_w × api_h.
3. Send the downscaled image to Qwen VL.
4. Parse the response coordinates (which are relative to the api image):
a. Qwen2.5-VL grounding tokens <|box_start|>(x1,y1),(x2,y2)<|box_end|>
normalised to [0, 1000] of the *api* image.
b. Fallback: JSON array ``[{"label": "...", "bbox": [x1,y1,x2,y2]}, ...]``
pixel values in the *api* image coordinate space.
5. Scale coordinates back to the original image space:
x_orig = round(x_api * orig_w / api_w)
6. Crop from the **original** high-resolution file → full-quality output.
"""
from __future__ import annotations
import base64
import json
import logging
import os
import re
from pathlib import Path
from typing import NamedTuple
logger = logging.getLogger(__name__)
# Default DashScope OpenAI-compatible endpoint; can be overridden via the
# DASHSCOPE_BASE_URL environment variable or a .env file (see _get_base_url).
_DASHSCOPE_BASE_URL_DEFAULT = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# Model alias: the 7B variant is faster ("flash"); swap for 72B for higher accuracy
DEFAULT_MODEL = "qwen2.5-vl-7b-instruct"

# Matches Qwen2.5-VL grounding output of the form
#   <|object_ref_start|>LABEL<|object_ref_end|><|box_start|>(x1,y1),(x2,y2)<|box_end|>
# Group 1 = label text, groups 2-5 = box corner coordinates
# (normalised to [0, 1000] of the image sent to the API).
_GROUNDING_RE = re.compile(
    r"<\|object_ref_start\|>(.*?)<\|object_ref_end\|>"
    r"<\|box_start\|>\((\d+),(\d+)\),\((\d+),(\d+)\)<\|box_end\|>",
    re.DOTALL,
)

# Default detection prompt (Chinese, sent verbatim to the model): asks for all
# main content regions of a food-packaging label as a JSON list of labelled
# pixel-space bounding boxes, origin at the top-left corner.
_DEFAULT_PROMPT = (
    "请检测图像中食品包装标签的所有主要内容区域(如:主产品信息表格、"
    "营养成分表、标题、配料表、厂商信息、条码区等)。"
    "以JSON列表输出格式为\n"
    '[{"label": "区域名称", "bbox": [x1, y1, x2, y2]}, ...]'
    "\n坐标为实际像素值(整数),原点在左上角。"
)
class Region(NamedTuple):
    """A labelled axis-aligned bounding box in pixel coordinates.

    ``(x1, y1)`` is the top-left corner and ``(x2, y2)`` the bottom-right,
    with the origin at the image's top-left corner.
    """

    label: str  # human-readable region name (often Chinese, e.g. "主内容区")
    x1: int  # left edge (px)
    y1: int  # top edge (px)
    x2: int  # right edge (px)
    y2: int  # bottom edge (px)

    @property
    def width(self) -> int:
        """Box width in pixels (``x2 - x1``)."""
        return self.x2 - self.x1

    @property
    def height(self) -> int:
        """Box height in pixels (``y2 - y1``)."""
        return self.y2 - self.y1
def _read_dotenv(path: Path) -> dict[str, str]:
"""Parse a simple KEY=VALUE .env file into a dict."""
result: dict[str, str] = {}
if not path.exists():
return result
for raw in path.read_text(encoding="utf-8").splitlines():
line = raw.strip()
if not line or line.startswith("#") or "=" not in line:
continue
k, v = line.split("=", 1)
result[k.strip()] = v.strip().strip('"').strip("'")
return result
def _load_env() -> dict[str, str]:
    """Merge .env files (project root → parent → home) into a single dict.

    Later files win on duplicate keys: the home-directory file overrides
    the parent directory's, which overrides the project root's.
    """
    here = Path(__file__).resolve()
    candidates = (
        here.parents[2] / ".env",
        here.parents[3] / ".env",
        Path.home() / ".env",
    )
    merged: dict[str, str] = {}
    for candidate in candidates:
        merged.update(_read_dotenv(candidate))
    return merged
def _get_api_key() -> str:
    """Read DASHSCOPE_API_KEY from env vars then .env files.

    Returns an empty string when the key is defined nowhere.
    """
    from_env = os.environ.get("DASHSCOPE_API_KEY", "").strip()
    return from_env if from_env else _load_env().get("DASHSCOPE_API_KEY", "")
def _get_base_url() -> str:
    """Read DASHSCOPE_BASE_URL from env vars then .env files.

    Falls back to the public DashScope compatible-mode endpoint when
    neither source defines the variable.
    """
    from_env = os.environ.get("DASHSCOPE_BASE_URL", "").strip()
    if from_env:
        return from_env
    return _load_env().get("DASHSCOPE_BASE_URL", _DASHSCOPE_BASE_URL_DEFAULT)
def _encode_image_for_api(
    image_path: Path,
    max_side: int = 1024,
) -> tuple[str, int, int]:
    """Downscale image to fit within *max_side* × *max_side*, encode as PNG base64.

    Returns
    -------
    b64 : str
        Base64-encoded PNG of the (possibly resized) image.
    api_w : int
        Width of the image that was actually sent to the API.
    api_h : int
        Height of the image that was actually sent to the API.
    """
    import io

    from PIL import Image

    with Image.open(image_path) as img:
        # RGB conversion guarantees PNG encoding works for any input mode.
        rgb = img.convert("RGB")
        orig_w, orig_h = rgb.size
        longest = max(orig_w, orig_h)
        if longest > max_side:
            ratio = max_side / longest
            # max(1, ...) guards against degenerate 0-px dimensions on
            # extremely elongated images.
            api_w = max(1, round(orig_w * ratio))
            api_h = max(1, round(orig_h * ratio))
            api_img = rgb.resize((api_w, api_h), Image.LANCZOS)
        else:
            api_w, api_h = orig_w, orig_h
            api_img = rgb
        buffer = io.BytesIO()
        api_img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode(), api_w, api_h
def _parse_grounding_tokens(text: str, api_w: int, api_h: int) -> list[Region]:
    """Parse <|box_start|>(x1,y1),(x2,y2)<|box_end|> tokens.

    Qwen2.5-VL normalises coordinates to [0, 1000] of the *api* image.
    Returns pixel coordinates in the api image space.
    """
    def _to_px(raw: str, side: int) -> int:
        # Rescale a [0, 1000]-normalised coordinate to pixels.
        return round(int(raw) * side / 1000)

    return [
        Region(
            match.group(1).strip(),
            _to_px(match.group(2), api_w),
            _to_px(match.group(3), api_h),
            _to_px(match.group(4), api_w),
            _to_px(match.group(5), api_h),
        )
        for match in _GROUNDING_RE.finditer(text)
    ]
def _parse_json_regions(text: str) -> list[Region]:
    """Fallback: extract bbox from a JSON object or array in the response.

    Strips grounding tokens and markdown code fences, then tries a single
    JSON object first (the default prompt returns one object) and falls
    back to a JSON array. Items whose bbox values are not numeric are
    skipped instead of aborting the whole parse.
    """
    # Remove special <|...|> tokens and markdown ``` fences before parsing.
    clean = re.sub(r"<\|[^|]+\|>", "", text)
    clean = re.sub(r"```[a-z]*", "", clean).strip("`").strip()

    def _extract_bbox(item: dict) -> list | None:
        """Try multiple known bbox key names, including nested dicts."""
        for key in ("bbox", "bbox_2d", "box", "coordinates", "bounding_box"):
            v = item.get(key)
            if isinstance(v, (list, tuple)) and len(v) >= 4:
                return list(v)
            # e.g. {"label": {"bbox_2d": [...]}}
            if isinstance(v, dict):
                inner = _extract_bbox(v)
                if inner:
                    return inner
        # Recurse into all dict values
        for v in item.values():
            if isinstance(v, dict):
                inner = _extract_bbox(v)
                if inner:
                    return inner
        return None

    def _region_from_dict(item: dict) -> Region | None:
        bbox = _extract_bbox(item)
        if not bbox or len(bbox) < 4:
            return None
        # Label: try common keys; skip if value is a nested dict
        raw_label = (item.get("label") or item.get("name") or item.get("type") or "主内容区")
        label = raw_label if isinstance(raw_label, str) else "主内容区"
        try:
            # int() handles ints, floats and integer strings; anything else
            # (None, "x1", nested lists, ...) marks the item as malformed.
            # Previously such values raised an uncaught ValueError/TypeError
            # that aborted the whole parse.
            x1, y1, x2, y2 = (int(v) for v in bbox[:4])
        except (TypeError, ValueError):
            return None
        return Region(label, x1, y1, x2, y2)

    # Try single JSON object first (default prompt returns one object)
    obj_start, obj_end = clean.find("{"), clean.rfind("}")
    if obj_start != -1 and obj_end > obj_start:
        try:
            obj = json.loads(clean[obj_start : obj_end + 1])
            if isinstance(obj, dict):
                r = _region_from_dict(obj)
                if r:
                    return [r]
        except json.JSONDecodeError:
            pass
    # Fallback to JSON array
    arr_start, arr_end = clean.find("["), clean.rfind("]")
    if arr_start == -1 or arr_end <= arr_start:
        return []
    try:
        items = json.loads(clean[arr_start : arr_end + 1])
    except json.JSONDecodeError:
        return []
    regions: list[Region] = []
    for item in items:
        if isinstance(item, dict):
            r = _region_from_dict(item)
            if r:
                regions.append(r)
    return regions
def detect_regions(
    image_path: Path,
    api_key: str | None = None,
    model: str = DEFAULT_MODEL,
    prompt: str = _DEFAULT_PROMPT,
    api_max_side: int = 1024,
) -> tuple[list[Region], str]:
    """Call Qwen VL to detect main regions.

    The image is downscaled to *api_max_side* before the API call for speed
    and cost efficiency. Returned ``Region`` coordinates are always mapped
    back to the **original** image pixel space, so ``crop_and_save`` will
    produce full-resolution output.

    Parameters
    ----------
    api_max_side:
        Maximum side length (px) of the image sent to the API.
        Increase for very large originals where detection needs more detail.

    Returns
    -------
    regions : list[Region]
        Bounding boxes in **original** image coordinates.
    raw_response : str
        Full model text (for debugging).

    Raises
    ------
    RuntimeError
        If no API key is supplied and none can be found in the environment.
    """
    from openai import OpenAI
    from PIL import Image

    key = api_key or _get_api_key()
    if not key:
        raise RuntimeError(
            "DASHSCOPE_API_KEY not set. "
            "Add it to the project .env or set the environment variable."
        )

    # ── 1. Original dimensions ────────────────────────────────────────────
    with Image.open(image_path) as src:
        src_w, src_h = src.size

    # ── 2. Downscale for API; remember api dims for coordinate mapping ─────
    b64, api_w, api_h = _encode_image_for_api(image_path, max_side=api_max_side)
    logger.info(
        "Calling %s on %s orig=%dx%d → api=%dx%d (scale=%.3f)",
        model, image_path.name, src_w, src_h, api_w, api_h, api_w / src_w,
    )

    # ── 3. API call ───────────────────────────────────────────────────────
    client = OpenAI(api_key=key, base_url=_get_base_url())
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/png;base64,{b64}"},
    }
    text_part = {"type": "text", "text": prompt}
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": [image_part, text_part]}],
    )
    raw = response.choices[0].message.content or ""
    logger.debug("Qwen VL raw response:\n%s", raw)

    # ── 4. Parse — coordinates are in api-image space ─────────────────────
    regions = _parse_grounding_tokens(raw, api_w, api_h)
    if regions:
        logger.info("Parsed %d region(s) from grounding tokens", len(regions))
    else:
        regions = _parse_json_regions(raw)
        if regions:
            logger.info("Parsed %d region(s) from JSON fallback", len(regions))
    if not regions:
        logger.warning("No regions parsed from response:\n%s", raw[:400])
        return [], raw

    # ── 5. Scale coordinates back to original image space ─────────────────
    fx, fy = src_w / api_w, src_h / api_h
    original_regions = [
        Region(
            r.label,
            round(r.x1 * fx), round(r.y1 * fy),
            round(r.x2 * fx), round(r.y2 * fy),
        )
        for r in regions
    ]
    logger.info(
        "Coordinates remapped api(%dx%d) → orig(%dx%d)",
        api_w, api_h, src_w, src_h,
    )
    return original_regions, raw
def merge_regions(regions: list[Region], label: str = "主内容区") -> Region:
    """Return the union bounding box of all regions as a single Region.

    Raises
    ------
    ValueError
        If *regions* is empty.
    """
    if not regions:
        raise ValueError("Cannot merge empty region list")
    # Transpose the corner coordinates so each axis can be reduced at once.
    lefts, tops, rights, bottoms = zip(
        *((r.x1, r.y1, r.x2, r.y2) for r in regions)
    )
    return Region(label, min(lefts), min(tops), max(rights), max(bottoms))
def crop_and_save(
    image_path: Path,
    regions: list[Region],
    output_dir: Path,
) -> list[dict]:
    """Crop each region and save as PNG.

    Regions are clamped to the image bounds; regions that collapse to a
    zero (or negative) area after clamping are skipped with a warning.

    Returns
    -------
    list of dicts with keys: label, bbox, path
    """
    from PIL import Image

    output_dir.mkdir(parents=True, exist_ok=True)
    saved: list[dict] = []
    with Image.open(image_path) as img:
        img_w, img_h = img.size
        for idx, region in enumerate(regions, start=1):
            # Clamp to image bounds
            left = max(0, region.x1)
            top = max(0, region.y1)
            right = min(img_w, region.x2)
            bottom = min(img_h, region.y2)
            if right <= left or bottom <= top:
                logger.warning("Skipping zero-area region: %s", region.label)
                continue
            # Replace filesystem-unfriendly characters; word chars, CJK
            # (U+4E00–U+9FFF) and '-' are kept; name capped at 40 chars.
            safe_name = re.sub(r"[^\w\u4e00-\u9fff-]", "_", region.label)[:40]
            out_path = output_dir / f"{idx:02d}_{safe_name}.png"
            img.crop((left, top, right, bottom)).save(out_path)
            logger.info("Saved region [%s] → %s", region.label, out_path.name)
            saved.append({
                "label": region.label,
                "bbox": [left, top, right, bottom],
                "path": str(out_path),
            })
    return saved