Files
HSAP/platform/as_platform/audit/preview.py

383 lines
13 KiB
Python
Raw Normal View History

"""审核单关联的送标/回传数据:解析范围、列举图像、渲染 GT 叠加。"""
from __future__ import annotations
import hashlib
import io
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Sequence
import yaml
from PIL import Image, ImageDraw, ImageFont
from as_platform.data.batch import IMG_EXTS
from as_platform.data.core import load_wf, proj_root, resolve_pack_dir
# Candidate file extensions used when probing for an image by exact file name:
# every entry of IMG_EXTS lower-cased, followed by the upper-cased form of each
# entry that is entirely lowercase.
IMAGE_EXTS = (
    *(ext.lower() for ext in IMG_EXTS),
    *(ext.upper() for ext in IMG_EXTS if ext.islower()),
)
@dataclass(frozen=True)
class ImageRef:
    """Immutable reference to one image and, optionally, its label file."""

    # Path of the image file.
    image_path: Path
    # Path of the matching label file, or None when no label was found.
    label_path: Path | None
    # Name of the batch the image was collected from.
    batch: str
    # Location tag of the batch, as supplied by the collector.
    location: str
    # Split name the image was collected under.
    split: str

    @property
    def id(self) -> str:
        """Return a 16-hex-character identifier derived from both paths.

        The value is the first 16 hex digits of the SHA-256 of
        ``"<image_path>|<label_path>"``; a missing label path contributes an
        empty string.
        """
        label_part = self.label_path or ""
        digest = hashlib.sha256(f"{self.image_path}|{label_part}".encode())
        return digest.hexdigest()[:16]
def _find_image(images_dir: Path, stem: str) -> Path | None:
    """Return the first existing file ``images_dir/<stem><ext>``, or None.

    Extensions are tried in the order they appear in ``IMAGE_EXTS``.
    """
    candidates = (images_dir / f"{stem}{ext}" for ext in IMAGE_EXTS)
    return next((path for path in candidates if path.is_file()), None)
def _parse_yolo_line(line: str) -> dict[str, Any] | None:
    """Parse one whitespace-separated YOLO label line.

    The first five fields are read as ``class_id cx cy w h``. Any remaining
    fields are consumed in groups of three as ``(x, y, v)`` keypoints; a
    trailing incomplete group is ignored.

    Args:
        line: Raw text of a single label line.

    Returns:
        A dict with ``class_id`` (int), ``bbox`` (``(cx, cy, w, h)`` floats)
        and ``keypoints`` (list of float triples), or None when the line has
        fewer than five fields or any consumed field is not a valid number.
    """
    fields = line.split()
    if len(fields) < 5:
        return None
    extra = fields[5:]
    try:
        class_id = int(float(fields[0]))
        cx, cy, w, h = map(float, fields[1:5])
        # Keypoints are parsed inside the same guard so that one malformed
        # value rejects this line instead of raising out of the file parser.
        keypoints: list[tuple[float, float, float]] = [
            (float(extra[i]), float(extra[i + 1]), float(extra[i + 2]))
            for i in range(0, len(extra) - 2, 3)
        ]
    except (ValueError, OverflowError):
        # ValueError: non-numeric text or int(nan); OverflowError: int(inf).
        return None
    return {"class_id": class_id, "bbox": (cx, cy, w, h), "keypoints": keypoints}
def parse_label_file(label_path: Path) -> list[dict[str, Any]]:
    """Parse every valid annotation line of a YOLO label file.

    Returns an empty list when ``label_path`` is not an existing file. The
    file is decoded as UTF-8 with undecodable bytes ignored, and lines that
    ``_parse_yolo_line`` rejects (returns None for) are skipped.
    """
    if not label_path.is_file():
        return []
    text = label_path.read_text(encoding="utf-8", errors="ignore")
    parsed_lines = (_parse_yolo_line(raw) for raw in text.splitlines())
    return [ann for ann in parsed_lines if ann is not None]
def _load_font(size: int) -> ImageFont.ImageFont:
    """Load a TrueType font at ``size`` from a fixed list of system paths.

    The candidate files are tried in order; a path that does not exist, or
    that ``ImageFont.truetype`` fails to load, is skipped. When none can be
    used the result of ``ImageFont.load_default()`` is returned.
    """
    candidate_paths = (
        "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
        "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
    )
    for font_path in candidate_paths:
        if not Path(font_path).exists():
            continue
        try:
            return ImageFont.truetype(font_path, size)
        except Exception:
            # Best effort: an unreadable font file just moves on to the next.
            continue
    return ImageFont.load_default()
def _yolo_bbox_to_xyxy(bbox: tuple[float, float, float, float], width: int, height: int) -> tuple[int, int, int, int]:
    """Convert a normalised ``(cx, cy, w, h)`` box to clamped pixel corners.

    Args:
        bbox: Centre x/y and width/height, as fractions of the image size.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        ``(x1, y1, x2, y2)`` with every coordinate truncated to an int and
        clamped to ``[0, width - 1]`` / ``[0, height - 1]``, ordered so that
        ``x1 <= x2`` and ``y1 <= y2`` even when the label carries a negative
        width or height.
    """
    cx, cy, bw, bh = bbox
    half_w, half_h = bw / 2.0, bh / 2.0

    def to_pixel(fraction: float, size: int) -> int:
        # Truncate toward zero first, then clamp into the image.
        return max(0, min(size - 1, int(fraction * size)))

    xa, xb = to_pixel(cx - half_w, width), to_pixel(cx + half_w, width)
    ya, yb = to_pixel(cy - half_h, height), to_pixel(cy + half_h, height)
    # Pillow's rectangle drawing rejects boxes whose second corner precedes
    # the first, so normalise the corner order here.
    return (min(xa, xb), min(ya, yb), max(xa, xb), max(ya, yb))
def render_overlay(
    image_path: Path,
    label_path: Path | None,
    class_names: dict[int, str],
    *,
    max_size: int | None = None,
) -> bytes:
    """Draw the label file's boxes and keypoints over an image.

    Args:
        image_path: Image to open; it is converted to RGB.
        label_path: YOLO label file to draw, or None to draw nothing.
        class_names: Class id -> display name; ids without an entry are shown
            as ``class_<id>``.
        max_size: When set and the longer image side exceeds it, the image is
            shrunk in place with ``thumbnail`` before drawing.

    Returns:
        The annotated image encoded as JPEG (quality 88).
    """
    # One colour per class id, cycling when there are more classes than entries.
    palette = (
        (220, 20, 60),
        (30, 144, 255),
        (50, 205, 50),
        (255, 165, 0),
        (186, 85, 211),
        (0, 206, 209),
    )
    with Image.open(image_path) as source:
        base = source.convert("RGB")
    if max_size and max(base.size) > max_size:
        base.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
    canvas = base.copy()
    pen = ImageDraw.Draw(canvas)
    width, height = canvas.size
    font = _load_font(max(12, min(18, width // 40)))
    # Stroke width and keypoint radius scale with the image width.
    outline_px = max(2, width // 320)
    dot_radius = max(2, width // 400)
    annotations = parse_label_file(label_path) if label_path else []
    for ann in annotations:
        class_id = ann["class_id"]
        color = palette[class_id % len(palette)]
        left, top, right, bottom = _yolo_bbox_to_xyxy(ann["bbox"], width, height)
        pen.rectangle((left, top, right, bottom), outline=color, width=outline_px)
        caption = class_names.get(class_id, f"class_{class_id}")
        # The caption sits 16 px above the box, kept inside the top edge.
        pen.text((left + 2, max(0, top - 16)), caption, fill=color, font=font)
        for kx, ky, kv in ann.get("keypoints") or []:
            # Keypoints whose third value is zero or negative are not drawn.
            if kv <= 0:
                continue
            px, py = int(kx * width), int(ky * height)
            pen.ellipse(
                (px - dot_radius, py - dot_radius, px + dot_radius, py + dot_radius),
                outline=color,
                fill=color,
            )
    encoded = io.BytesIO()
    canvas.save(encoded, format="JPEG", quality=88)
    return encoded.getvalue()
def _collect_from_split(batch_dir: Path, split: str, *, batch: str, location: str) -> list[ImageRef]:
    """Collect image references from one split of a batch directory.

    Images are read from ``batch_dir/images/<split>`` and labels from
    ``batch_dir/labels/<split>``; an empty ``split`` means the ``images`` and
    ``labels`` directories themselves, reported under the split name
    ``"root"``.

    Args:
        batch_dir: Root directory of the batch.
        split: Sub-directory name, or ``""`` for the flat layout.
        batch: Batch name stored on every returned reference.
        location: Location tag stored on every returned reference.

    Returns:
        One ``ImageRef`` per image file, in sorted path order, each paired
        with the ``*.txt`` label of the same stem when one exists. Empty when
        the images directory does not exist.
    """
    images_dir = batch_dir / "images"
    labels_dir = batch_dir / "labels"
    if split:
        images_dir = images_dir / split
        labels_dir = labels_dir / split
    refs: list[ImageRef] = []
    if not images_dir.is_dir():
        return refs
    split_name = split or "root"
    label_by_stem: dict[str, Path] = {}
    if labels_dir.is_dir():
        label_by_stem = {lp.stem: lp for lp in labels_dir.glob("*.txt")}
    # Built once instead of once per directory entry.
    allowed_suffixes = {ext.lower() for ext in IMG_EXTS}
    seen_stems: set[str] = set()
    for image in sorted(images_dir.iterdir()):
        if not image.is_file() or image.suffix.lower() not in allowed_suffixes:
            continue
        stem = image.stem
        seen_stems.add(stem)
        refs.append(
            ImageRef(
                image_path=image.resolve(),
                # NOTE(review): stored as returned by glob(), whereas the
                # fallback loop below stores a resolved path — confirm whether
                # the difference is intentional (it feeds ImageRef.id).
                label_path=label_by_stem.get(stem),
                batch=batch,
                location=location,
                split=split_name,
            )
        )
    # Fallback for labels whose image was not picked up by the scan above;
    # the image is looked up by exact file name via _find_image.
    for stem, label in sorted(label_by_stem.items()):
        if stem in seen_stems:
            continue
        image = _find_image(images_dir, stem)
        if image:
            refs.append(
                ImageRef(
                    image_path=image.resolve(),
                    label_path=label.resolve(),
                    batch=batch,
                    location=location,
                    split=split_name,
                )
            )
    return refs
def collect_batch_images(batch_dir: Path, *, batch: str, location: str) -> list[ImageRef]:
    """Collect the images of a batch across the train/val/test and flat layouts.

    Returns an empty list when ``batch_dir`` is not a directory. Otherwise the
    splits ``train``, ``val``, ``test`` and the flat layout are scanned in that
    order and the result is sorted by ``(batch, split, file name)``.
    """
    if not batch_dir.is_dir():
        return []
    by_image_path: dict[str, ImageRef] = {}
    for split in ("train", "val", "test", ""):
        for ref in _collect_from_split(batch_dir, split, batch=batch, location=location):
            # De-duplicate on the image path (the flat layout may overlap with
            # a split); the reference collected last wins.
            by_image_path[str(ref.image_path)] = ref
    unique_refs = list(by_image_path.values())
    unique_refs.sort(key=lambda ref: (ref.batch, ref.split, ref.image_path.name))
    return unique_refs
def _dms_task_cfg(root: Path, wf: dict, task: str) -> tuple[dict, dict]:
    """Load the DMS registry and return it together with one task's config.

    The registry is the YAML file at ``root / wf["projects"]["dms"]["registry"]``.

    Args:
        root: Directory the registry path is relative to.
        wf: Workflow config holding the registry's relative path.
        task: Key to look up under the registry's ``tasks`` mapping.

    Returns:
        ``(registry, registry["tasks"][task])``.

    Raises:
        ValueError: If ``task`` is not listed under ``tasks`` — including when
            the registry file is empty or its ``tasks`` entry is empty.
    """
    registry_path = root / wf["projects"]["dms"]["registry"]
    # safe_load returns None for an empty document; treat that (and an empty
    # ``tasks:`` key) as "no tasks" so the caller gets the ValueError below
    # rather than an AttributeError/TypeError.
    registry = yaml.safe_load(registry_path.read_text(encoding="utf-8")) or {}
    tasks = registry.get("tasks") or {}
    if task not in tasks:
        raise ValueError(f"未知 task: {task}")
    return registry, tasks[task]
def _scope_class_names(tcfg: dict[str, Any]) -> dict[int, str]:
    """Build the class-id -> name mapping from a task config's ``names`` entry.

    A dict has its keys converted with ``int``; any other iterable is numbered
    from 0 in order. A missing or empty entry yields an empty mapping.
    """
    names = tcfg.get("names") or {}
    if isinstance(names, dict):
        return {int(key): value for key, value in names.items()}
    return dict(enumerate(names))


def _scope_batch_dirs(parent: Path, location: str, skip_names: tuple[str, ...] = ()) -> list[dict[str, Any]]:
    """List the sub-directories of ``parent`` as batch entries, sorted by path.

    Entries whose name starts with ``.`` or appears in ``skip_names`` are left
    out. Returns an empty list when ``parent`` is not a directory.
    """
    if not parent.is_dir():
        return []
    return [
        {"path": child, "batch": child.name, "location": location}
        for child in sorted(parent.iterdir())
        if child.is_dir() and child.name not in skip_names and not child.name.startswith(".")
    ]


def resolve_approval_scope(action: str, params: dict[str, Any]) -> dict[str, Any]:
    """Resolve the data directories and class names behind an approval ticket.

    Args:
        action: Ticket action. Supported values are ``build_dms``,
            ``register_batch``, ``delivery_ingest``, ``train_dms``,
            ``promote_dms`` and ``eval_dms``.
        params: Ticket parameters; ``None`` or an empty dict is treated as no
            parameters.

    Returns:
        A dict with the keys ``project``, ``task``, ``pack``, ``scope_label``,
        ``class_names`` (class id -> name) and ``batches`` (a list of dicts
        with ``path``, ``batch`` and ``location``).

    Raises:
        ValueError: If a required parameter (``task`` or ``data_path``) is
            missing, the task is unknown, or ``action`` is not supported.
    """
    p = params or {}
    wf = load_wf()

    if action in ("build_dms", "register_batch"):
        task = p.get("task")
        if not task:
            raise ValueError("缺少 task 参数")
        root = proj_root(wf, "dms")
        reg, tcfg = _dms_task_cfg(root, wf, task)
        src_sub = (reg.get("ingest") or {}).get("sources_subdir", "sources")
        pack = p.get("pack") or "dms_v2"
        batch_name = p.get("batch")
        location = p.get("location", "inbox")
        if action == "build_dms":
            # For build_dms the flags override any explicit ``location``:
            # all_sources forces the pack sources, a named batch the inbox.
            if p.get("all_sources"):
                location = "sources"
            elif batch_name:
                location = "inbox"
        if location == "inbox":
            parent = root / "inbox" / task
            batch_location = "inbox"
            skip_names: tuple[str, ...] = ()
        else:
            # Any location other than "inbox" is read from the pack sources.
            pack_dir = resolve_pack_dir("dms", root, wf, pack)
            parent = pack_dir / tcfg.get("task_dir", task) / src_sub
            batch_location = "sources"
            skip_names = ("_ingested", "_merged")
        if batch_name:
            # NOTE(review): the batch name comes straight from params and is
            # joined onto the base directory without validation — confirm
            # callers reject values such as "..".
            batches = [{"path": parent / batch_name, "batch": batch_name, "location": batch_location}]
        else:
            batches = _scope_batch_dirs(parent, batch_location, skip_names)
        all_sources_suffix = " · 全部 sources" if location == "sources" and not batch_name else ""
        return {
            "project": "dms",
            "task": task,
            "pack": pack,
            "scope_label": f"DMS · {task} · {pack}" + all_sources_suffix,
            "class_names": _scope_class_names(tcfg),
            "batches": batches,
        }

    if action == "delivery_ingest":
        data_path = (p.get("data_path") or "").strip()
        if not data_path:
            raise ValueError("缺少 data_path 参数")
        src = Path(data_path)
        project = p.get("project") or "dms"
        task = p.get("task") or ""
        # The batch name defaults to the last component of the data path.
        batch_name = p.get("batch_name") or src.name
        scope_label = f"数据送标入湖 · {project}"
        if task:
            scope_label += f" · {task}"
        scope_label += f" · {batch_name}"
        return {
            "project": project,
            "task": task or None,
            "pack": None,
            "scope_label": scope_label,
            "class_names": {},
            "batches": [{"path": src, "batch": batch_name, "location": "delivery"}],
        }

    if action in ("train_dms", "promote_dms", "eval_dms"):
        task = p.get("task")
        if not task:
            raise ValueError("缺少 task 参数")
        root = proj_root(wf, "dms")
        _, tcfg = _dms_task_cfg(root, wf, task)
        pack = p.get("pack") or "dms_v2"
        pack_dir = resolve_pack_dir("dms", root, wf, pack)
        task_dir = pack_dir / tcfg.get("task_dir", task)
        if action == "promote_dms":
            label = "模型晋级"
        elif action == "eval_dms":
            label = "评估"
        else:
            label = "训练"
        return {
            "project": "dms",
            "task": task,
            "pack": pack,
            # The closing parenthesis was missing from this label.
            "scope_label": f"DMS · {task} · {pack} · pack 数据({label})",
            "class_names": _scope_class_names(tcfg),
            "batches": [{"path": task_dir, "batch": f"{pack}/{task}", "location": "pack"}],
        }

    raise ValueError(f"暂不支持预览的动作: {action}")
def list_scope_images(scope: dict[str, Any], *, offset: int = 0, limit: int = 60) -> dict[str, Any]:
    """Return one page of the images covered by a resolved scope.

    Args:
        scope: Scope dict whose ``batches`` entries each carry a ``path`` and
            optionally ``batch`` (defaults to the directory name) and
            ``location`` (defaults to ``""``).
        offset: Index of the first item to return; negative values are
            treated as 0.
        limit: Maximum number of items to return; negative values are treated
            as 0.

    Returns:
        ``{"total", "offset", "limit", "items"}`` where ``total`` counts the
        unique images across all batches, ``offset``/``limit`` are the
        effective values used, and each item describes one image (id, batch,
        location, split, filename, label presence and box count).
    """
    # A negative offset or limit would make the slice below count from the
    # end of the list, which is never what a pager means.
    offset = max(0, offset)
    limit = max(0, limit)
    by_image_path: dict[str, ImageRef] = {}
    for entry in scope.get("batches") or []:
        batch_dir = Path(entry["path"])
        batch_refs = collect_batch_images(
            batch_dir,
            batch=entry.get("batch", batch_dir.name),
            location=entry.get("location", ""),
        )
        # De-duplicate across batches on the image path; the last one wins.
        for ref in batch_refs:
            by_image_path[str(ref.image_path)] = ref
    ordered = sorted(by_image_path.values(), key=lambda r: (r.batch, r.split, r.image_path.name))
    items = []
    for ref in ordered[offset : offset + limit]:
        # Checked once so has_label and missing_label can never disagree.
        has_label = ref.label_path is not None and ref.label_path.is_file()
        anns = parse_label_file(ref.label_path) if has_label else []
        items.append(
            {
                "id": ref.id,
                "batch": ref.batch,
                "location": ref.location,
                "split": ref.split,
                "filename": ref.image_path.name,
                "has_label": has_label,
                "box_count": len(anns),
                "missing_label": not has_label,
            }
        )
    return {"total": len(ordered), "offset": offset, "limit": limit, "items": items}
def find_image_ref(scope: dict[str, Any], image_id: str) -> ImageRef | None:
    """Find the image whose ``id`` equals ``image_id`` within a scope.

    This is a linear search that re-collects each batch in turn, which is
    acceptable because review scopes contain a limited number of batches.
    Returns None when no image matches.
    """
    for entry in scope.get("batches") or []:
        batch_dir = Path(entry["path"])
        candidates = collect_batch_images(
            batch_dir,
            batch=entry.get("batch", batch_dir.name),
            location=entry.get("location", ""),
        )
        match = next((ref for ref in candidates if ref.id == image_id), None)
        if match is not None:
            return match
    return None
def image_to_item(ref: ImageRef) -> dict[str, Any]:
    """Describe one image reference, including its parsed annotations.

    Returns:
        A dict with the reference's ``id``, ``batch``, ``location``, ``split``
        and ``filename``, the flags ``has_label`` / ``missing_label``, the
        ``box_count``, and ``annotations`` — one entry per parsed label line
        with ``class_id``, ``bbox``, ``keypoints`` and a ``class_name`` that is
        always None here.
    """
    # Checked once so has_label and missing_label can never disagree.
    has_label = ref.label_path is not None and ref.label_path.is_file()
    anns = parse_label_file(ref.label_path) if has_label else []
    return {
        "id": ref.id,
        "batch": ref.batch,
        "location": ref.location,
        "split": ref.split,
        "filename": ref.image_path.name,
        "has_label": has_label,
        "box_count": len(anns),
        "missing_label": not has_label,
        "annotations": [
            {
                "class_id": ann["class_id"],
                # NOTE(review): left as None; presumably the caller fills in
                # the display name from the scope's class_names — confirm.
                "class_name": None,
                "bbox": ann["bbox"],
                "keypoints": ann.get("keypoints") or [],
            }
            for ann in anns
        ],
    }