373 lines
12 KiB
Python
373 lines
12 KiB
Python
"""Detect main regions in a label image via Qwen2.5-VL (DashScope).
|
||
|
||
Workflow
|
||
--------
|
||
1. Read the original image; record its exact dimensions (orig_w × orig_h).
|
||
2. Downscale a copy to fit within ``api_max_side`` for the API call
|
||
(faster upload, lower token cost). Record api_w × api_h.
|
||
3. Send the downscaled image to Qwen VL.
|
||
4. Parse the response coordinates (which are relative to the api image):
|
||
a. Qwen2.5-VL grounding tokens <|box_start|>(x1,y1),(x2,y2)<|box_end|>
|
||
– normalised to [0, 1000] of the *api* image.
|
||
b. Fallback: JSON array ``[{"label": "...", "bbox": [x1,y1,x2,y2]}, ...]``
|
||
– pixel values in the *api* image coordinate space.
|
||
5. Scale coordinates back to the original image space:
|
||
x_orig = round(x_api * orig_w / api_w)
|
||
6. Crop from the **original** high-resolution file → full-quality output.
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import base64
|
||
import json
|
||
import logging
|
||
import os
|
||
import re
|
||
from pathlib import Path
|
||
from typing import NamedTuple
|
||
|
||
logger = logging.getLogger(__name__)


# Default OpenAI-compatible DashScope endpoint, used when DASHSCOPE_BASE_URL
# is not configured.
_DASHSCOPE_BASE_URL_DEFAULT = "https://dashscope.aliyuncs.com/compatible-mode/v1"

# Model alias: the 7B variant is faster ("flash"); swap for 72B for higher accuracy
DEFAULT_MODEL = "qwen2.5-vl-7b-instruct"

# One grounding span:
#   <|object_ref_start|>LABEL<|object_ref_end|><|box_start|>(x1,y1),(x2,y2)<|box_end|>
# Groups: 1 = label (may span lines, hence DOTALL), 2..5 = x1, y1, x2, y2 as
# non-negative integer text. Whitespace is tolerated around the numbers, the
# commas and between the tokens, so "(12, 34), (56, 78)" matches as well as
# "(12,34),(56,78)". A box without a preceding object reference does not match.
_GROUNDING_RE = re.compile(
    r"<\|object_ref_start\|>(.*?)<\|object_ref_end\|>\s*"
    r"<\|box_start\|>\s*"
    r"\(\s*(\d+)\s*,\s*(\d+)\s*\)\s*,\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)"
    r"\s*<\|box_end\|>",
    re.DOTALL,
)

# Default instruction (Chinese): detect all main content regions of the food
# packaging label (e.g. product-info table, nutrition facts, title,
# ingredients, manufacturer info, barcode area) and answer with a JSON list of
# {"label", "bbox": [x1, y1, x2, y2]} using integer pixel coordinates with the
# origin at the top-left corner.
_DEFAULT_PROMPT = (
    "请检测图像中食品包装标签的所有主要内容区域(如:主产品信息表格、"
    "营养成分表、标题、配料表、厂商信息、条码区等)。"
    "以JSON列表输出,格式为:\n"
    '[{"label": "区域名称", "bbox": [x1, y1, x2, y2]}, ...]'
    "\n坐标为实际像素值(整数),原点在左上角。"
)
|
||
|
||
|
||
class Region(NamedTuple):
    """A labelled rectangle in integer pixel coordinates.

    The four coordinates follow the PIL crop-box layout
    ``(left, upper, right, lower)`` with the origin at the top-left corner.
    Nothing enforces ``x1 <= x2`` / ``y1 <= y2``; a box with swapped corners
    reports a negative width/height.
    """

    label: str  # region name, e.g. as produced by the model
    x1: int  # left edge
    y1: int  # top edge
    x2: int  # right edge
    y2: int  # bottom edge

    @property
    def width(self) -> int:
        """Horizontal extent, ``x2 - x1``."""
        _, left, _, right, _ = self
        return right - left

    @property
    def height(self) -> int:
        """Vertical extent, ``y2 - y1``."""
        _, _, top, _, bottom = self
        return bottom - top
|
||
|
||
|
||
def _read_dotenv(path: Path) -> dict[str, str]:
    """Parse a simple ``KEY=VALUE`` .env file into a dict.

    Supported syntax:

    * blank lines, lines starting with ``#`` and lines without ``=`` are ignored;
    * only the first ``=`` separates key and value (values may contain ``=``);
    * a shell-style ``export`` prefix on the key is removed;
    * a value wrapped in a *matching* pair of single or double quotes is
      unwrapped; any other value is kept verbatim after trimming whitespace;
    * a UTF-8 byte-order mark at the start of the file is ignored.

    A missing file yields an empty dict. So does an unreadable one (a
    directory, no permission, or invalid UTF-8), after a logged warning.
    This lets callers probe several candidate locations without guarding
    each one.
    """
    result: dict[str, str] = {}
    try:
        # utf-8-sig transparently drops a BOM (e.g. files saved by Notepad),
        # which would otherwise become part of the first key.
        text = path.read_text(encoding="utf-8-sig")
    except FileNotFoundError:
        return result
    except (OSError, UnicodeDecodeError) as exc:
        logger.warning("Ignoring unreadable .env file %s: %s", path, exc)
        return result

    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        key = key.strip()
        parts = key.split(maxsplit=1)
        if len(parts) == 2 and parts[0] == "export":
            key = parts[1]
        value = value.strip()
        # Unwrap only a balanced pair of quotes; stray quote characters are data.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        if key:
            result[key] = value
    return result
|
||
|
||
|
||
def _load_env() -> dict[str, str]:
    """Merge candidate .env files into a single dict.

    Candidates, from most to least specific:

    1. ``.env`` two directories above this module's folder (presumably the
       project root),
    2. ``.env`` one directory above that,
    3. ``~/.env``.

    When a key appears in several files, the most specific file wins, so a
    project-level .env is never shadowed by a home-directory one. Ancestor
    directories that do not exist are skipped. This happens when the module
    sits close to the filesystem root. Missing files contribute nothing.
    """
    here = Path(__file__).resolve()
    candidates: list[Path] = []
    for depth in (2, 3):
        # Path.parents[n] raises IndexError past the filesystem root.
        if depth < len(here.parents):
            candidates.append(here.parents[depth] / ".env")
    candidates.append(Path.home() / ".env")

    merged: dict[str, str] = {}
    # Apply the least specific file first so more specific ones overwrite it.
    for candidate in reversed(candidates):
        merged.update(_read_dotenv(candidate))
    return merged
|
||
|
||
|
||
def _get_api_key() -> str:
    """Return the configured DashScope API key, or ``""`` if there is none.

    The ``DASHSCOPE_API_KEY`` environment variable wins (surrounding
    whitespace is ignored and a blank value counts as unset). Only when it is
    absent are the merged .env files consulted.
    """
    from_environment = os.environ.get("DASHSCOPE_API_KEY", "").strip()
    # `or` short-circuits, so the .env files are read only when needed.
    return from_environment or _load_env().get("DASHSCOPE_API_KEY", "")
|
||
|
||
|
||
def _get_base_url() -> str:
    """Return the DashScope OpenAI-compatible base URL.

    Lookup order: the ``DASHSCOPE_BASE_URL`` environment variable, then the
    merged .env files, then ``_DASHSCOPE_BASE_URL_DEFAULT``. A blank value at
    either configured level counts as unset. An empty ``DASHSCOPE_BASE_URL=``
    line in a .env file therefore falls back to the default instead of
    handing an empty URL to the API client.
    """
    from_environment = os.environ.get("DASHSCOPE_BASE_URL", "").strip()
    if from_environment:
        return from_environment
    from_dotenv = _load_env().get("DASHSCOPE_BASE_URL", "").strip()
    return from_dotenv or _DASHSCOPE_BASE_URL_DEFAULT
|
||
|
||
|
||
def _encode_image_for_api(
    image_path: Path,
    max_side: int = 1024,
) -> tuple[str, int, int]:
    """Downscale image to fit within *max_side* × *max_side*, encode as PNG base64.

    The image is first converted to RGB. Images that carry transparency
    (an alpha band, or a ``transparency`` entry such as a palette/tRNS key)
    are flattened onto a white background. A plain ``convert("RGB")`` would
    simply drop the alpha band and expose whatever colour the transparent
    pixels happen to store (possibly black), which can hide dark label text
    from the model.

    Images whose longest side is already within *max_side* are sent at their
    original size; larger ones are resized with Lanczos filtering, keeping
    the aspect ratio (each side at least 1 px).

    Returns
    -------
    b64 : str
        Base64-encoded PNG of the (possibly resized) image.
    api_w : int
        Width of the image that was actually sent to the API.
    api_h : int
        Height of the image that was actually sent to the API.
    """
    import io

    from PIL import Image

    with Image.open(image_path) as src:
        if "A" in src.getbands() or "transparency" in src.info:
            rgba = src.convert("RGBA")
            backdrop = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
            img = Image.alpha_composite(backdrop, rgba).convert("RGB")
        else:
            img = src.convert("RGB")

    orig_w, orig_h = img.size

    if max(orig_w, orig_h) > max_side:
        scale = max_side / max(orig_w, orig_h)
        api_w = max(1, round(orig_w * scale))
        api_h = max(1, round(orig_h * scale))
        api_img = img.resize((api_w, api_h), Image.LANCZOS)
    else:
        api_w, api_h = orig_w, orig_h
        api_img = img

    buf = io.BytesIO()
    api_img.save(buf, format="PNG")

    b64 = base64.b64encode(buf.getvalue()).decode()
    return b64, api_w, api_h
|
||
|
||
|
||
def _parse_grounding_tokens(text: str, api_w: int, api_h: int) -> list[Region]:
    """Turn every grounding-token match in *text* into a ``Region``.

    A match is an ``<|object_ref_start|>label<|object_ref_end|>`` reference
    followed by ``<|box_start|>(x1,y1),(x2,y2)<|box_end|>`` (see
    ``_GROUNDING_RE``). The numbers are read as a 0–1000 scale of the image
    that was sent to the API and mapped to that image's pixels with
    ``round(value * size / 1000)``. The label is whitespace-trimmed.

    NOTE(review): 0–1000 is the Qwen2-VL grounding convention; Qwen2.5-VL is
    reported to emit absolute pixel coordinates instead — confirm which
    format the deployed model actually produces.
    """
    def to_pixels(raw: str, extent: int) -> int:
        return round(int(raw) * extent / 1000)

    return [
        Region(
            match.group(1).strip(),
            to_pixels(match.group(2), api_w),
            to_pixels(match.group(3), api_h),
            to_pixels(match.group(4), api_w),
            to_pixels(match.group(5), api_h),
        )
        for match in _GROUNDING_RE.finditer(text)
    ]
|
||
|
||
|
||
def _parse_json_regions(text: str) -> list[Region]:
    """Fallback parser: pull bounding boxes out of JSON embedded in *text*.

    Special tokens (``<|...|>``) and Markdown code fences are stripped
    first. Two shapes are then tried, in order:

    1. a single JSON object — the span from the first ``{`` to the last
       ``}`` — describing one region. When the text holds several objects,
       as in an array, this span is not valid JSON and the attempt is
       abandoned;
    2. a JSON array — the span from the first ``[`` to the last ``]`` —
       whose dict items each describe a region. This is the shape the
       default prompt asks for.

    A region dict may store its box under ``bbox``, ``bbox_2d``, ``box``,
    ``coordinates`` or ``bounding_box``, directly or inside a nested dict.
    The first four entries are x1, y1, x2, y2. Each may be an int, a float
    or a numeric string, and is rounded to the nearest integer. Items whose
    box is missing or non-numeric are skipped rather than aborting the parse.

    Coordinates are returned as given, i.e. in whatever pixel space the
    model used (expected: the image that was sent to the API).
    """
    # Drop special tokens and ```json fences so only the JSON payload remains.
    clean = re.sub(r"<\|[^|]+\|>", "", text)
    clean = re.sub(r"```[a-z]*", "", clean).strip("`").strip()

    def _extract_bbox(item: dict) -> list | None:
        """Try multiple known bbox key names, including nested dicts."""
        for key in ("bbox", "bbox_2d", "box", "coordinates", "bounding_box"):
            v = item.get(key)
            if isinstance(v, (list, tuple)) and len(v) >= 4:
                return list(v)
            # e.g. {"bbox": {"bbox_2d": [...]}}
            if isinstance(v, dict):
                inner = _extract_bbox(v)
                if inner:
                    return inner
        # Otherwise search every nested dict, e.g. {"label": {"bbox_2d": [...]}}
        for v in item.values():
            if isinstance(v, dict):
                inner = _extract_bbox(v)
                if inner:
                    return inner
        return None

    def _to_int(value: object) -> int | None:
        """Round a JSON number or numeric string to int; None if not a finite number."""
        try:
            return round(float(value))
        except (TypeError, ValueError, OverflowError):
            # TypeError: lists/dicts/None; ValueError: "abc" or NaN;
            # OverflowError: infinity.
            return None

    def _region_from_dict(item: dict) -> Region | None:
        """Build a Region from one region dict, or None if it has no usable box."""
        bbox = _extract_bbox(item)
        if not bbox or len(bbox) < 4:
            return None
        coords = [_to_int(v) for v in bbox[:4]]
        if None in coords:
            return None
        # Label: first truthy value among common keys; a non-string value
        # (e.g. a nested dict holding the box) falls back to "主内容区"
        # ("main content area").
        raw_label = item.get("label") or item.get("name") or item.get("type") or "主内容区"
        label = raw_label if isinstance(raw_label, str) else "主内容区"
        x1, y1, x2, y2 = coords
        return Region(label, x1, y1, x2, y2)

    # 1. A lone object. With several objects this span is invalid JSON and
    #    the array branch below takes over.
    obj_start, obj_end = clean.find("{"), clean.rfind("}")
    if obj_start != -1 and obj_end > obj_start:
        try:
            obj = json.loads(clean[obj_start : obj_end + 1])
        except json.JSONDecodeError:
            obj = None
        if isinstance(obj, dict):
            region = _region_from_dict(obj)
            if region:
                return [region]

    # 2. An array of region dicts (the default prompt's format).
    arr_start, arr_end = clean.find("["), clean.rfind("]")
    if arr_start == -1 or arr_end <= arr_start:
        return []
    try:
        items = json.loads(clean[arr_start : arr_end + 1])
    except json.JSONDecodeError:
        return []

    regions: list[Region] = []
    for item in items:
        if isinstance(item, dict):
            region = _region_from_dict(item)
            if region:
                regions.append(region)
    return regions
|
||
|
||
|
||
def detect_regions(
    image_path: Path,
    api_key: str | None = None,
    model: str = DEFAULT_MODEL,
    prompt: str = _DEFAULT_PROMPT,
    api_max_side: int = 1024,
) -> tuple[list[Region], str]:
    """Ask Qwen VL for the main regions of *image_path*.

    The model only sees a downscaled copy (longest side at most
    *api_max_side* px), which keeps the request fast and cheap. Every
    returned ``Region`` is nonetheless expressed in the pixel space of the
    **original** file, so ``crop_and_save`` yields full-resolution crops.

    Parameters
    ----------
    image_path:
        Image file to analyse.
    api_key:
        DashScope key. When falsy, ``DASHSCOPE_API_KEY`` is read from the
        environment / .env files.
    model:
        Model name passed to the chat-completions endpoint.
    prompt:
        Text instruction sent together with the image.
    api_max_side:
        Maximum side length (px) of the image sent to the API.
        Increase for very large originals where detection needs more detail.

    Returns
    -------
    regions : list[Region]
        Bounding boxes in **original** image coordinates (empty when nothing
        could be parsed from the response).
    raw_response : str
        Full model text (for debugging).

    Raises
    ------
    RuntimeError
        If no API key is available.
    """
    from openai import OpenAI
    from PIL import Image

    key = api_key or _get_api_key()
    if not key:
        raise RuntimeError(
            "DASHSCOPE_API_KEY not set. "
            "Add it to the project .env or set the environment variable."
        )

    # Size of the full-resolution file: the space the results are mapped to.
    with Image.open(image_path) as original:
        orig_w, orig_h = original.size

    # What the model actually sees; its size defines the space of the
    # coordinates it returns.
    b64, api_w, api_h = _encode_image_for_api(image_path, max_side=api_max_side)
    logger.info(
        "Calling %s on %s orig=%dx%d → api=%dx%d (scale=%.3f)",
        model, image_path.name, orig_w, orig_h, api_w, api_h, api_w / orig_w,
    )

    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/png;base64,{b64}"},
    }
    text_part = {"type": "text", "text": prompt}
    client = OpenAI(api_key=key, base_url=_get_base_url())
    completion = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": [image_part, text_part]}],
    )

    raw = completion.choices[0].message.content or ""
    logger.debug("Qwen VL raw response:\n%s", raw)

    # Grounding tokens first, JSON as the fallback; both yield api-space boxes.
    regions = _parse_grounding_tokens(raw, api_w, api_h)
    if regions:
        logger.info("Parsed %d region(s) from grounding tokens", len(regions))
    else:
        regions = _parse_json_regions(raw)
        if not regions:
            logger.warning("No regions parsed from response:\n%s", raw[:400])
            return [], raw
        logger.info("Parsed %d region(s) from JSON fallback", len(regions))

    # NOTE(review): if the model internally resizes the image (e.g. to
    # multiples of 28 px) and reports coordinates in that space, the boxes
    # will be slightly off after this linear remap — confirm.
    x_ratio = orig_w / api_w
    y_ratio = orig_h / api_h
    original_regions = [
        r._replace(
            x1=round(r.x1 * x_ratio),
            y1=round(r.y1 * y_ratio),
            x2=round(r.x2 * x_ratio),
            y2=round(r.y2 * y_ratio),
        )
        for r in regions
    ]
    logger.info(
        "Coordinates remapped api(%dx%d) → orig(%dx%d)",
        api_w, api_h, orig_w, orig_h,
    )
    return original_regions, raw
|
||
|
||
|
||
def merge_regions(regions: list[Region], label: str = "主内容区") -> Region:
    """Return the smallest box enclosing every region in *regions*.

    The result uses the minimum ``x1``/``y1`` and the maximum ``x2``/``y2``
    of the inputs and carries *label* (default "主内容区", "main content
    area").

    Raises
    ------
    ValueError
        If *regions* is empty.
    """
    if not regions:
        raise ValueError("Cannot merge empty region list")

    boxes = iter(regions)
    first = next(boxes)
    left, top, right, bottom = first.x1, first.y1, first.x2, first.y2
    for box in boxes:
        left = min(left, box.x1)
        top = min(top, box.y1)
        right = max(right, box.x2)
        bottom = max(bottom, box.y2)
    return Region(label, left, top, right, bottom)
|
||
|
||
|
||
def crop_and_save(
    image_path: Path,
    regions: list[Region],
    output_dir: Path,
) -> list[dict]:
    """Crop each region out of *image_path* and save it as a PNG.

    Each box is clamped to the image bounds; regions left with zero (or
    negative) width or height are logged and skipped. Files are named
    ``NN_<label>.png``, where ``NN`` is the region's 1-based position in
    *regions*. In ``<label>``, word characters (including CJK) and ``-`` are
    kept, everything else becomes ``_``, and the result is cut to 40
    characters.

    PNG cannot store every Pillow mode. CMYK (common for print artwork),
    YCbCr and LAB crops are converted to RGB first — or to RGBA when they
    carry an alpha band — instead of failing with ``OSError``. Pillow's
    CMYK→RGB conversion is a plain formula without ICC colour management.

    Returns
    -------
    list of dicts with keys: label, bbox (the clamped ``[x1, y1, x2, y2]``),
    path (``str`` of the written file) — one per saved region.
    """
    from PIL import Image

    # Modes Pillow's PNG writer stores directly; anything else needs converting.
    png_modes = {"1", "L", "LA", "I", "I;16", "I;16B", "P", "RGB", "RGBA"}

    output_dir.mkdir(parents=True, exist_ok=True)
    results: list[dict] = []

    with Image.open(image_path) as img:
        img_w, img_h = img.size
        for i, region in enumerate(regions, start=1):
            # Clamp to image bounds
            x1 = max(0, region.x1)
            y1 = max(0, region.y1)
            x2 = min(img_w, region.x2)
            y2 = min(img_h, region.y2)
            if x2 <= x1 or y2 <= y1:
                logger.warning("Skipping zero-area region: %s", region.label)
                continue
            cropped = img.crop((x1, y1, x2, y2))
            if cropped.mode not in png_modes:
                has_alpha = bool({"A", "a"} & set(cropped.getbands()))
                cropped = cropped.convert("RGBA" if has_alpha else "RGB")
            safe_name = re.sub(r"[^\w\u4e00-\u9fff-]", "_", region.label)[:40]
            out_path = output_dir / f"{i:02d}_{safe_name}.png"
            cropped.save(out_path)
            logger.info("Saved region [%s] → %s", region.label, out_path.name)
            results.append({
                "label": region.label,
                "bbox": [x1, y1, x2, y2],
                "path": str(out_path),
            })

    return results
|