# Source: ZLD_POC/segment_ai_regions.py
# Snapshot: 2026-04-15 17:18:49 +08:00 (193 lines, 6.0 KiB, Python)

#!/usr/bin/env python3
from __future__ import annotations
import json
import re
from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
# Directory that holds the input artifacts and receives the generated outputs.
WORKDIR = Path("/Users/icemilk/Workspace/zld_POC")

# Inputs: the text-block JSON read by load_blocks() and the image that
# draw_regions() annotates.
TEXT_BLOCKS = WORKDIR.joinpath("【2026-04-09】端午-背标-天问.text_blocks.json")
IMAGE_PATH = WORKDIR.joinpath("1.jpg")

# Outputs: the annotated overlay image and the region list serialized as JSON.
OUT_IMAGE = WORKDIR.joinpath("【2026-04-09】端午-背标-天问.region_overlay.png")
OUT_JSON = WORKDIR.joinpath("【2026-04-09】端午-背标-天问.regions.json")

# Page dimensions used by to_px() to scale "*_pt" coordinates to image pixels
# (the names suggest typographic points).
PAGE_WIDTH_PT = 1363.4
PAGE_HEIGHT_PT = 942.06
def load_blocks() -> list[dict]:
    """Read the text-block file and return its decoded JSON content.

    Returns:
        The result of ``json.loads`` applied to the UTF-8 text of
        ``TEXT_BLOCKS`` (annotated as a list of dicts).
    """
    raw_json = TEXT_BLOCKS.read_text(encoding="utf-8")
    return json.loads(raw_json)
def overlaps(a: tuple[float, float, float, float], b: tuple[float, float, float, float]) -> bool:
    """Tell whether two axis-aligned boxes intersect.

    Both boxes are ``(x0, y0, x1, y1)`` tuples. Boxes that merely share an
    edge or a corner count as overlapping, because separation is tested with
    strict ``<``.

    Args:
        a: First box.
        b: Second box.

    Returns:
        ``False`` when the boxes are strictly separated on either axis,
        ``True`` otherwise.
    """
    a_left, a_top, a_right, a_bottom = a
    b_left, b_top, b_right, b_bottom = b

    # Separated horizontally: one box ends before the other begins.
    if a_right < b_left or b_right < a_left:
        return False
    # Separated vertically.
    if a_bottom < b_top or b_bottom < a_top:
        return False
    return True
def expanded_box(block: dict, pad_x: float = 24.0, pad_y: float = 18.0) -> tuple[float, float, float, float]:
    """Return a block's bounding box grown by a padding on every side.

    Args:
        block: Mapping with ``"x0_pt"``, ``"top_pt"``, ``"x1_pt"`` and
            ``"bottom_pt"`` entries.
        pad_x: Amount subtracted from the left edge and added to the right.
        pad_y: Amount subtracted from the top edge and added to the bottom.

    Returns:
        ``(left, top, right, bottom)`` of the padded box.
    """
    left = block["x0_pt"] - pad_x
    top = block["top_pt"] - pad_y
    right = block["x1_pt"] + pad_x
    bottom = block["bottom_pt"] + pad_y
    return (left, top, right, bottom)
def region_bbox(blocks: list[dict], margin_x: float = 20.0, margin_y: float = 14.0) -> dict:
    """Compute the margin-padded bounding box enclosing all given blocks.

    Args:
        blocks: Non-empty list of mappings with ``"x0_pt"``, ``"top_pt"``,
            ``"x1_pt"`` and ``"bottom_pt"`` entries.
        margin_x: Horizontal margin added on both sides.
        margin_y: Vertical margin added on both sides.

    Returns:
        Dict with keys ``"x0_pt"``, ``"top_pt"``, ``"x1_pt"``, ``"bottom_pt"``.
        The left and top edges are clamped so they never go below 0; the
        right and bottom edges are not clamped.

    Raises:
        ValueError: If ``blocks`` is empty (raised by ``min``).
    """
    lefts = [blk["x0_pt"] for blk in blocks]
    left = min(lefts) - margin_x

    tops = [blk["top_pt"] for blk in blocks]
    top = min(tops) - margin_y

    rights = [blk["x1_pt"] for blk in blocks]
    right = max(rights) + margin_x

    bottoms = [blk["bottom_pt"] for blk in blocks]
    bottom = max(bottoms) + margin_y

    return {
        "x0_pt": max(0, left),
        "top_pt": max(0, top),
        "x1_pt": right,
        "bottom_pt": bottom,
    }
def classify(region: dict) -> str:
    """Return the classification of a region, i.e. its stored ``"label"`` value.

    Raises:
        KeyError: If ``region`` has no ``"label"`` entry.
    """
    label = region["label"]
    return label
def to_px(x_pt: float, y_pt: float, img_w: int, img_h: int) -> tuple[int, int]:
    """Map a page coordinate to a pixel coordinate in an image.

    Each axis is scaled by the ratio of the image size to the page size
    (``PAGE_WIDTH_PT`` / ``PAGE_HEIGHT_PT``) and rounded with ``round``.

    Args:
        x_pt: Horizontal page coordinate.
        y_pt: Vertical page coordinate.
        img_w: Image width in pixels.
        img_h: Image height in pixels.

    Returns:
        ``(x_px, y_px)`` as integers.
    """
    x_px = round(x_pt / PAGE_WIDTH_PT * img_w)
    y_px = round(y_pt / PAGE_HEIGHT_PT * img_h)
    return (x_px, y_px)
def match_any(text: str, patterns: list[str]) -> bool:
    """Report whether at least one pattern occurs in ``text`` as a substring.

    Args:
        text: String to search in.
        patterns: Candidate substrings, checked in order.

    Returns:
        ``True`` on the first pattern found in ``text``; ``False`` if none is
        found or ``patterns`` is empty.
    """
    for needle in patterns:
        if needle in text:
            return True
    return False
def semantic_groups(blocks: list[dict]) -> list[tuple[str, list[dict]]]:
    """Partition text blocks into labelled groups using an ordered rule table.

    Rules are tried in table order. Each rule claims every still-unclaimed
    block it matches, so a block always lands in the first rule that accepts
    it. Rules test a block's ``"top_pt"`` / ``"x0_pt"`` position and/or
    keywords in its ``"text"``.

    Args:
        blocks: Mappings with at least ``"text"``, ``"top_pt"`` and
            ``"x0_pt"`` entries. The input list is not modified.

    Returns:
        ``(label, blocks)`` pairs in rule order, omitting rules that matched
        nothing. Blocks claimed by no rule whose text contains a
        non-whitespace character are appended last under ``"unassigned"``.
    """
    basic_keys = ["品名", "成品尺寸", "材质", "工艺", "盒型"]
    rules_keys = ["日期", "设计比例", "字体大小规范", "常规内容最小高度", "净含量最小高度", "条形码"]
    workflow_keys = ["签稿流程", "设计师", "品控", "安冬梅"]
    nutrition_keys = ["营养成分表"]
    date_keys = ["生产日期", "保质期到期日"]

    # Order matters: an earlier rule takes a block away from all later ones.
    rules = (
        (
            "header_basic",
            lambda blk: blk["top_pt"] < 140 and match_any(blk["text"], basic_keys),
        ),
        (
            "header_rules",
            lambda blk: blk["top_pt"] < 140 and match_any(blk["text"], rules_keys),
        ),
        (
            "workflow_notes",
            lambda blk: blk["x0_pt"] > 1180 or match_any(blk["text"], workflow_keys),
        ),
        (
            "version_info",
            lambda blk: "版本号" in blk["text"],
        ),
        (
            "upper_main",
            lambda blk: (
                250 <= blk["top_pt"] <= 540
                and blk["x0_pt"] < 820
                and not match_any(blk["text"], nutrition_keys)
            ),
        ),
        (
            "cooking_box",
            lambda blk: 560 <= blk["top_pt"] <= 650 and 500 <= blk["x0_pt"] <= 680,
        ),
        (
            "seal_mark",
            lambda blk: 560 <= blk["top_pt"] <= 650 and 680 < blk["x0_pt"] <= 760,
        ),
        (
            "nutrition_table",
            lambda blk: 520 <= blk["top_pt"] <= 670 and blk["x0_pt"] < 960,
        ),
        (
            "lower_left_details",
            lambda blk: 590 <= blk["top_pt"] <= 705 and blk["x0_pt"] < 520,
        ),
        (
            "date_box",
            lambda blk: match_any(blk["text"], date_keys) and blk["x0_pt"] > 650,
        ),
        (
            "bottom_title",
            lambda blk: blk["top_pt"] > 705 and blk["x0_pt"] < 980,
        ),
    )

    result: list[tuple[str, list[dict]]] = []
    pool = list(blocks)
    for name, accepts in rules:
        # Split the unclaimed pool into this rule's hits and the rest.
        hits: list[dict] = []
        rest: list[dict] = []
        for blk in pool:
            (hits if accepts(blk) else rest).append(blk)
        if hits:
            result.append((name, hits))
            pool = rest

    # Keep any leftovers visible so missed areas can be inspected;
    # whitespace-only blocks are dropped.
    visible_leftovers = [blk for blk in pool if re.search(r"\S", blk["text"])]
    if visible_leftovers:
        result.append(("unassigned", visible_leftovers))
    return result
def build_regions(blocks: list[dict]) -> list[dict]:
    """Turn text blocks into a list of region records, one per semantic group.

    Args:
        blocks: Text blocks as accepted by ``semantic_groups``.

    Returns:
        One dict per group, in group order, with the keys ``"region_id"``
        (1-based), ``"label"``, ``"bbox"`` (from ``region_bbox``),
        ``"block_count"`` and ``"sample_text"``.
    """

    def describe(region_id: int, label: str, members: list[dict]) -> dict:
        # Build the record for one group.
        bbox = region_bbox(members)
        # Preview text: the first four blocks ordered top-to-bottom, then
        # left-to-right, joined by spaces and capped at 120 characters.
        ordered = sorted(members, key=lambda blk: (blk["top_pt"], blk["x0_pt"]))
        preview = " ".join(blk["text"] for blk in ordered[:4])
        return {
            "region_id": region_id,
            "label": label,
            "bbox": bbox,
            "block_count": len(members),
            "sample_text": preview[:120],
        }

    return [
        describe(number, label, members)
        for number, (label, members) in enumerate(semantic_groups(blocks), start=1)
    ]
def draw_regions(regions: list[dict]) -> None:
    """Draw every region's outline and tag onto the page image and save it.

    The image at ``IMAGE_PATH`` is loaded and converted to RGBA. Each region
    gets a 5-pixel outline in a colour cycled from a fixed palette, plus a
    ``"R<region_id> <label>"`` tag on a translucent white background near its
    top-left corner. The result is written to ``OUT_IMAGE``.

    Args:
        regions: Region records with ``"region_id"``, ``"label"`` and a
            ``"bbox"`` dict holding ``"x0_pt"``, ``"top_pt"``, ``"x1_pt"``
            and ``"bottom_pt"``.
    """
    # convert() returns an independent image, so the source file handle can be
    # released as soon as the conversion is done.
    with Image.open(IMAGE_PATH) as source:
        image = source.convert("RGBA")
    draw = ImageDraw.Draw(image, "RGBA")
    colors = [
        (255, 99, 71, 255),
        (65, 105, 225, 255),
        (50, 205, 50, 255),
        (255, 165, 0, 255),
        (148, 0, 211, 255),
        (0, 191, 255, 255),
        (220, 20, 60, 255),
        (46, 139, 87, 255),
    ]
    font = ImageFont.load_default()
    for i, region in enumerate(regions):
        color = colors[i % len(colors)]
        bbox = region["bbox"]
        x0, y0 = to_px(bbox["x0_pt"], bbox["top_pt"], image.width, image.height)
        x1, y1 = to_px(bbox["x1_pt"], bbox["bottom_pt"], image.width, image.height)
        draw.rectangle([x0, y0, x1, y1], outline=color[:3], width=5)

        tag = f"R{region['region_id']} {region['label']}"
        # Keep the tag at least 8 px away from the image's top-left border.
        tx0 = max(8, x0 + 8)
        ty0 = max(8, y0 + 8)
        # textbbox() returns absolute (left, top, right, bottom) coordinates
        # that already include the (tx0, ty0) anchor. Use right/bottom directly
        # rather than adding them to the anchor again, which would oversize
        # the background.
        _, _, text_right, text_bottom = draw.textbbox((tx0, ty0), tag, font=font)
        draw.rectangle(
            [tx0 - 4, ty0 - 2, text_right + 4, text_bottom + 2],
            fill=(255, 255, 255, 220),
        )
        draw.text((tx0, ty0), tag, fill=(0, 0, 0, 255), font=font)
    image.save(OUT_IMAGE)
def main() -> None:
    """Run the pipeline: load blocks, build regions, write JSON, draw the overlay.

    Writes the region list to ``OUT_JSON`` (UTF-8, non-ASCII kept as-is,
    2-space indent), renders the overlay via ``draw_regions``, then prints
    both output paths and the region count.
    """
    regions = build_regions(load_blocks())

    payload = json.dumps(regions, ensure_ascii=False, indent=2)
    OUT_JSON.write_text(payload, encoding="utf-8")
    draw_regions(regions)

    for line in (OUT_IMAGE, OUT_JSON, f"regions={len(regions)}"):
        print(line)
# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()