96 lines
3.2 KiB
Python
96 lines
3.2 KiB
Python
|
|
"""Batch-level cuboid 3D fit for quaternion_json."""
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import hashlib
|
||
|
|
import json
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
|
||
|
|
|
||
|
|
|
||
|
|
def _task_id_for_image(image_path: Path, batch_dir: Path) -> str:
|
||
|
|
try:
|
||
|
|
rel = image_path.relative_to(batch_dir)
|
||
|
|
stem = rel.as_posix()
|
||
|
|
except ValueError:
|
||
|
|
stem = image_path.name
|
||
|
|
return hashlib.sha256(stem.encode()).hexdigest()[:16]
|
||
|
|
|
||
|
|
|
||
|
|
def _load_ls_cuboid_points(batch_dir: Path, stem: str) -> list[list[float]]:
|
||
|
|
ann_dir = batch_dir / "labels" / "ls_annotations"
|
||
|
|
if not ann_dir.is_dir():
|
||
|
|
return []
|
||
|
|
for p in ann_dir.glob("*.json"):
|
||
|
|
try:
|
||
|
|
data = json.loads(p.read_text(encoding="utf-8"))
|
||
|
|
except (OSError, json.JSONDecodeError):
|
||
|
|
continue
|
||
|
|
img = str(data.get("image") or "")
|
||
|
|
if stem in img or p.stem:
|
||
|
|
regions = data.get("result") or []
|
||
|
|
pts_list = []
|
||
|
|
for r in regions:
|
||
|
|
if r.get("type") != "cuboid":
|
||
|
|
continue
|
||
|
|
pts = list(r.get("points") or [])
|
||
|
|
if len(pts) >= 16:
|
||
|
|
pts_list.append(pts[:16])
|
||
|
|
if pts_list:
|
||
|
|
return pts_list
|
||
|
|
return []
|
||
|
|
|
||
|
|
|
||
|
|
def fit_batch(batch_dir: Path) -> dict[str, Any]:
|
||
|
|
from algorithms.adas_mono3d.fit_cuboid import fit_cuboid_detection
|
||
|
|
|
||
|
|
batch_dir = batch_dir.resolve()
|
||
|
|
qdir = batch_dir / "labels" / "quaternion_json"
|
||
|
|
if not qdir.is_dir():
|
||
|
|
raise ValueError(f"missing {qdir}")
|
||
|
|
|
||
|
|
updated = 0
|
||
|
|
fit_ok = 0
|
||
|
|
total = 0
|
||
|
|
for p in sorted(qdir.glob("*.json")):
|
||
|
|
data = json.loads(p.read_text(encoding="utf-8"))
|
||
|
|
K = data.get("K")
|
||
|
|
if not K:
|
||
|
|
continue
|
||
|
|
stem = data.get("image_stem") or p.stem
|
||
|
|
cuboid_pts_list = _load_ls_cuboid_points(batch_dir, stem)
|
||
|
|
new_dets = []
|
||
|
|
for i, det in enumerate(data.get("detections") or []):
|
||
|
|
det = dict(det)
|
||
|
|
if det.get("fit_ok"):
|
||
|
|
new_dets.append(det)
|
||
|
|
total += 1
|
||
|
|
fit_ok += 1
|
||
|
|
continue
|
||
|
|
class_name = str(det.get("class_name") or "car")
|
||
|
|
points = cuboid_pts_list[i] if i < len(cuboid_pts_list) else None
|
||
|
|
if not points:
|
||
|
|
box = det.get("box2d_xyxy") or []
|
||
|
|
if len(box) >= 4:
|
||
|
|
x1, y1, x2, y2 = box[:4]
|
||
|
|
points = [x1, y1, x2, y1, x1, y2, x2, y2, x1, y1, x2, y1, x1, y2, x2, y2]
|
||
|
|
if points:
|
||
|
|
fitted = fit_cuboid_detection(points, K, class_name)
|
||
|
|
det.update({k: v for k, v in fitted.items() if k != "box2d_xyxy" or "box2d_xyxy" not in det})
|
||
|
|
new_dets.append(det)
|
||
|
|
total += 1
|
||
|
|
if det.get("fit_ok"):
|
||
|
|
fit_ok += 1
|
||
|
|
data["detections"] = new_dets
|
||
|
|
data["num_detections"] = len(new_dets)
|
||
|
|
p.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
|
||
|
|
updated += 1
|
||
|
|
|
||
|
|
return {
|
||
|
|
"updated_files": updated,
|
||
|
|
"detections": total,
|
||
|
|
"fit_ok": fit_ok,
|
||
|
|
"fit_ok_ratio": fit_ok / max(total, 1),
|
||
|
|
}
|