Files
yolov26_3d/eval_tools/analysis/visualize_3d_badcases.py
2026-06-24 09:35:46 +08:00

724 lines
24 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Visualize 3D bad cases from analyze_3d_badcases.py results.
This script focuses on matched 3D samples and renders:
- full-frame overlays with GT / active detections / highlighted badcases
- per-example panels with crop, simple BEV, and a metrics sidebar
- an index.json for downstream browsing
"""
from __future__ import annotations
import argparse
import json
import math
import sys
from collections import Counter, defaultdict
from pathlib import Path
import cv2
import numpy as np
# Make the repository root importable so the `eval_tools.*` imports below
# resolve when this file is executed directly as a script. This file lives at
# <root>/eval_tools/analysis/<this file>, so parents[2] is <root>.
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    # Prepend so this checkout's packages are found before any other copy on sys.path.
    sys.path.insert(0, str(REPO_ROOT))
from eval_tools.analysis.analyze_2d_fp_fn import build_config, class_name, parse_class_ids
from eval_tools.evaluator.evaluator import Evaluator
# Drawing colours for OpenCV, in BGR channel order (e.g. target_gt=(40, 40, 255)
# is the red "TargetGT" named in the overlay legend).
BOX_COLORS = dict(
    gt_all=(80, 220, 80),  # every selected-class GT box (green)
    det_all=(150, 150, 150),  # every active detection (gray)
    target_gt=(40, 40, 255),  # GT box of a highlighted badcase (red)
    target_det=(0, 215, 255),  # detection box of a highlighted badcase (orange)
    title_bg=(30, 30, 30),  # background of the blended header bar
    bev_gt=(40, 220, 80),  # GT centre marker in the BEV panel
    bev_det=(0, 215, 255),  # detection centre marker in the BEV panel
)
def parse_args():
    """Define and parse the command-line interface of this script.

    Only ``--analysis-report`` and ``--config`` are required. Options declared
    without an explicit ``default`` (data paths, file formats, image size,
    coordinate system, thresholds) are left as ``None`` when omitted; ``main``
    hands the resulting namespace to ``build_config``.

    Returns:
        argparse.Namespace holding one attribute per option.
    """
    parser = argparse.ArgumentParser(
        description="Visualize 3D bad cases from analysis_report.json."
    )
    file_formats = ["auto", "json", "txt"]
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = [
        # --- inputs -------------------------------------------------------
        ("--analysis-report", dict(
            type=str,
            required=True,
            help="Path to analysis_report.json generated by analyze_3d_badcases.py",
        )),
        ("--config", dict(type=str, required=True, help="Path to YAML evaluation config")),
        ("--det-path", dict(type=str, help="Detection results root directory")),
        ("--gt-path", dict(type=str, help="Ground-truth labels root directory")),
        ("--path-depth", dict(type=int, choices=[1, 2], help="Directory depth")),
        ("--det-format", dict(type=str, choices=file_formats, help="Detection file format")),
        ("--gt-format", dict(type=str, choices=file_formats, help="Ground-truth file format")),
        ("--img-width", dict(type=int, help="Image width")),
        ("--img-height", dict(type=int, help="Image height")),
        ("--coord-system", dict(
            type=str,
            choices=["camera", "ego"],
            help="Coordinate system used by the parser/evaluator",
        )),
        ("--iou-threshold", dict(type=float, help="IoU threshold used for evaluator loading")),
        ("--conf-threshold", dict(
            type=float,
            help="Confidence threshold for active detections shown on overlays",
        )),
        # --- example selection --------------------------------------------
        ("--metrics", dict(
            nargs="+",
            default=["longitudinal_error"],
            help="Metrics to visualize from badcase_examples.",
        )),
        ("--classes", dict(
            nargs="+",
            default=None,
            help="Optional class filter, e.g. car suv pedestrian",
        )),
        ("--min-confidence", dict(type=float, default=None, help="Minimum confidence for badcase examples.")),
        ("--max-confidence", dict(type=float, default=None, help="Maximum confidence for badcase examples.")),
        ("--min-iou", dict(type=float, default=None, help="Minimum IoU for badcase examples.")),
        ("--max-iou", dict(type=float, default=None, help="Maximum IoU for badcase examples.")),
        ("--top-k", dict(
            type=int,
            default=200,
            help="Maximum number of examples to visualize after filtering",
        )),
        ("--top-k-per-distance-bin", dict(
            type=int,
            default=0,
            help="Optional cap per longitudinal distance bin before applying --top-k. 0 disables bin-wise capping.",
        )),
        ("--dedup-frame", dict(
            action="store_true",
            help="Keep at most one example per case/frame/class/metric combination",
        )),
        ("--group-by-distance-bin", dict(
            action="store_true",
            help="Render and save outputs separately for each longitudinal distance bin.",
        )),
        # --- rendering / output -------------------------------------------
        ("--line-thickness", dict(
            type=int,
            default=2,
            help="Base line thickness for non-highlight boxes",
        )),
        ("--crop-scale", dict(
            type=float,
            default=1.8,
            help="Expand crop window around GT/det union box by this factor",
        )),
        ("--jpeg-quality", dict(
            type=int,
            default=92,
            help="JPEG quality for saved visualizations",
        )),
        ("--output-dir", dict(
            type=str,
            default=None,
            help="Output directory. Defaults to sibling 3d_vis_<report_name>",
        )),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def normalize_token_set(values):
    """Return the set of stripped, lower-cased tokens in *values*, or ``None`` if *values* is empty.

    Every entry is converted with ``str()`` first; entries that are blank after
    stripping are dropped. A non-empty input consisting only of blank entries
    therefore yields an empty set, not ``None``.
    """
    if values:
        tokens = set()
        for raw in values:
            token = str(raw).strip()
            if token:
                tokens.add(token.lower())
        return tokens
    return None
def filter_examples(report, args):
    """Select, rank and cap badcase examples from an analysis report.

    Examples are collected from ``report["badcase_examples"][metric]`` for each
    distinct metric in ``args.metrics``. They are then filtered by metric name,
    by class name (``args.classes``) and by optional confidence and IoU bounds.
    The survivors are sorted by ``(metric_value_display, confidence, iou)``,
    highest first. After that, the list is optionally de-duplicated per
    case/frame/class/metric (``args.dedup_frame``), capped per ``distance_bin``
    (``args.top_k_per_distance_bin``) and finally truncated to ``args.top_k``.

    Args:
        report: Parsed analysis_report.json content.
        args: Namespace from ``parse_args()``.

    Returns:
        List of example dicts in ranked order.
    """
    def sort_value(item, key):
        # A present-but-null field ranks as 0.0 instead of crashing on
        # float(None); missing keys already ranked as 0.0.
        value = item.get(key)
        return 0.0 if value is None else float(value)

    def in_range(value, lower, upper):
        # Once either bound is set, examples lacking the field are rejected.
        if lower is None and upper is None:
            return True
        if value is None:
            return False
        value = float(value)
        if lower is not None and value < lower:
            return False
        if upper is not None and value > upper:
            return False
        return True

    badcase_examples = report.get("badcase_examples") or {}
    pools = []
    # dict.fromkeys drops repeated --metrics entries (keeping first-seen order),
    # so the same examples are not pulled, ranked and rendered twice.
    for metric_name in dict.fromkeys(args.metrics or []):
        pools.extend(badcase_examples.get(metric_name) or [])

    class_filter = normalize_token_set(args.classes)
    metric_filter = normalize_token_set(args.metrics)
    filtered = []
    for item in pools:
        if metric_filter and str(item.get("metric_name", "")).lower() not in metric_filter:
            continue
        if class_filter and str(item.get("class_name", "")).lower() not in class_filter:
            continue
        if not in_range(item.get("confidence"), args.min_confidence, args.max_confidence):
            continue
        if not in_range(item.get("iou"), args.min_iou, args.max_iou):
            continue
        filtered.append(item)

    # Largest displayed metric value first; confidence and IoU break ties.
    filtered.sort(
        key=lambda item: (
            sort_value(item, "metric_value_display"),
            sort_value(item, "confidence"),
            sort_value(item, "iou"),
        ),
        reverse=True,
    )

    if args.dedup_frame:
        # Keep only the best-ranked example per case/frame/class/metric.
        deduped = []
        seen = set()
        for item in filtered:
            key = (
                item.get("case_name"),
                item.get("frame_name"),
                item.get("class_name"),
                item.get("metric_name"),
            )
            if key in seen:
                continue
            seen.add(key)
            deduped.append(item)
        filtered = deduped

    if args.top_k_per_distance_bin and args.top_k_per_distance_bin > 0:
        # Per-bin cap is applied before the global top-k so each bin gets a share.
        kept = []
        counts = Counter()
        for item in filtered:
            distance_bin = item.get("distance_bin") or "unbucketed"
            if counts[distance_bin] >= args.top_k_per_distance_bin:
                continue
            kept.append(item)
            counts[distance_bin] += 1
        filtered = kept

    if args.top_k is not None:
        filtered = filtered[: args.top_k]
    return filtered
def bbox_to_int(bbox):
    """Round every coordinate of *bbox* to the nearest integer pixel.

    Each value goes through ``float()`` and then Python's ``round`` (ties go
    to the even neighbour), so anything ``float()`` accepts works.
    """
    return list(map(lambda coord: int(round(float(coord))), bbox))
def draw_box(image, bbox, color, label=None, thickness=2):
    """Draw an anti-aliased rectangle on *image* in place, optionally with a label tag.

    The tag is a filled rectangle in *color* placed just above the box's
    top-left corner (clamped at the image top) carrying *label* in white.
    """
    left, top, right, bottom = bbox_to_int(bbox)
    cv2.rectangle(image, (left, top), (right, bottom), color, thickness, cv2.LINE_AA)
    if not label:
        return
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.55
    (text_w, text_h), _baseline = cv2.getTextSize(label, font, font_scale, 1)
    tag_top = max(0, top - text_h - 8)
    tag_bottom = tag_top + text_h + 8
    cv2.rectangle(image, (left, tag_top), (left + text_w + 8, tag_bottom), color, -1)
    cv2.putText(
        image,
        label,
        (left + 4, tag_top + text_h + 2),
        font,
        font_scale,
        (255, 255, 255),
        1,
        cv2.LINE_AA,
    )
def add_header(image, text):
    """Blend a dark 42px title bar over the top of *image* (in place) and write *text* on it."""
    width = image.shape[1]
    shaded = image.copy()
    cv2.rectangle(shaded, (0, 0), (width, 42), BOX_COLORS["title_bg"], -1)
    # 55% bar + 45% original, written back into `image` (dst argument).
    cv2.addWeighted(shaded, 0.55, image, 0.45, 0, image)
    white = (255, 255, 255)
    cv2.putText(image, text, (10, 28), cv2.FONT_HERSHEY_SIMPLEX, 0.7, white, 2, cv2.LINE_AA)
def make_crop(image, boxes, scale=1.8):
    """Crop *image* around the union of *boxes*, expanded by *scale*.

    The crop window is centred on the union of all non-empty
    ``[x1, y1, x2, y2]`` boxes. Each side is ``scale`` times the union's
    size, but never smaller than 32 px, and the window is clipped to the
    image bounds.

    Args:
        image: Array indexed as ``[row, col, ...]``.
        boxes: Iterable of boxes; ``None``/empty entries are ignored.
        scale: Expansion factor applied to the union width and height.

    Returns:
        ``(crop, (offset_x, offset_y))``, where the offset is the crop's
        top-left corner in *image* coordinates. The crop is always a copy.
        A full-image copy with offset ``(0, 0)`` is returned instead when
        there is no usable box, or when the window lies entirely outside
        the image.
    """
    h, w = image.shape[:2]
    valid = [bbox for bbox in boxes if bbox]
    if not valid:
        return image.copy(), (0, 0)
    x1 = min(float(b[0]) for b in valid)
    y1 = min(float(b[1]) for b in valid)
    x2 = max(float(b[2]) for b in valid)
    y2 = max(float(b[3]) for b in valid)
    cx = 0.5 * (x1 + x2)
    cy = 0.5 * (y1 + y2)
    bw = max(32.0, (x2 - x1) * scale)
    bh = max(32.0, (y2 - y1) * scale)
    crop_x1 = max(0, int(round(cx - bw / 2)))
    crop_y1 = max(0, int(round(cy - bh / 2)))
    crop_x2 = min(w, int(round(cx + bw / 2)))
    crop_y2 = min(h, int(round(cy + bh / 2)))
    if crop_x2 <= crop_x1 or crop_y2 <= crop_y1:
        # Window is fully outside the frame. An empty slice cannot be resized
        # or stacked into the combined panel, so fall back to the whole image.
        return image.copy(), (0, 0)
    return image[crop_y1:crop_y2, crop_x1:crop_x2].copy(), (crop_x1, crop_y1)
def shift_box(box, off_x, off_y):
    """Translate an ``[x1, y1, x2, y2]`` box by ``(-off_x, -off_y)``.

    Returns ``None`` for a missing or empty box. Otherwise returns a new list
    of four floats; elements beyond the fourth are ignored.
    """
    if not box:
        return None
    left, top, right, bottom = (float(box[i]) for i in range(4))
    return [left - off_x, top - off_y, right - off_x, bottom - off_y]
def draw_crop_panel(image, example, crop_scale):
    """Render the zoomed crop panel for one badcase example.

    Crops *image* around the example's GT and detection boxes (``make_crop``)
    and redraws both boxes in crop coordinates. It then adds a title bar with
    the class, the metric name and the displayed metric value.

    Args:
        image: Source frame. It is only read here; drawing happens on the crop.
        example: Badcase record (``gt_bbox``, ``det_bbox``, ``class_name``,
            ``metric_name``, ``metric_value_display``, ``metric_unit``,
            ``confidence``).
        crop_scale: Expansion factor forwarded to ``make_crop``.

    Returns:
        The annotated crop image.
    """
    gt_bbox = example.get("gt_bbox")
    det_bbox = example.get("det_bbox")
    label_name = example.get("class_name", "")
    # `or 0.0` also covers fields present but null in the report JSON;
    # float(example.get(key, 0.0)) raises TypeError on those.
    confidence = float(example.get("confidence") or 0.0)
    metric_value = float(example.get("metric_value_display") or 0.0)
    metric_unit = example.get("metric_unit") or ""

    crop, (off_x, off_y) = make_crop(image, [gt_bbox, det_bbox], scale=crop_scale)
    gt_local = shift_box(gt_bbox, off_x, off_y)
    det_local = shift_box(det_bbox, off_x, off_y)
    if gt_local:
        draw_box(
            crop,
            gt_local,
            BOX_COLORS["target_gt"],
            label=f"GT {label_name}",
            thickness=3,
        )
    if det_local:
        draw_box(
            crop,
            det_local,
            BOX_COLORS["target_det"],
            label=f"Det {label_name} {confidence:.2f}",
            thickness=3,
        )
    add_header(
        crop,
        (
            f"crop | {label_name} | {example.get('metric_name', '')}="
            f"{metric_value:.3f}{metric_unit}"
        ),
    )
    return crop
def create_bev_panel(example, coord_system="camera", width=480, height=320, max_depth_m=100.0, max_lateral_m=30.0):
    """Render a simple bird's-eye view (BEV) of the example's GT and detection centres.

    Lateral offset ``[-max_lateral_m, +max_lateral_m]`` maps to columns
    ``0 .. width-1``. Depth ``[0, max_depth_m]`` maps to rows ``height-1 .. 0``,
    so near objects are at the bottom. Grid lines are drawn every 10 m and
    the zero-lateral line is emphasised. Centres outside that window are
    clamped onto the panel border so they remain visible.

    Axes read from ``gt_center_3d`` / ``det_center_3d``:
      * ``coord_system == "camera"``: lateral = x, depth = z.
      * otherwise: lateral = y, depth = x.
        NOTE(review): +y is plotted to the right; if the ego frame is
        y-left this view is mirrored -- confirm against the parser.

    A heading arrow (``draw_heading_arrow``, angle 0 = up) is drawn from
    ``gt_rotation_rad`` / ``det_rotation_rad``, but only when that field is
    present and not ``None``. NOTE(review): confirm this angle convention
    matches the rotation stored in the report.

    Returns:
        ``(height, width, 3)`` uint8 image.
    """
    panel = np.full((height, width, 3), 245, dtype=np.uint8)

    def to_col(lateral):
        return int(round((lateral + max_lateral_m) / (2.0 * max_lateral_m) * (width - 1)))

    def to_row(depth):
        return int(round((1.0 - depth / max_depth_m) * (height - 1)))

    def project(point3d):
        if not point3d or len(point3d) < 3:
            return None
        if coord_system == "camera":
            lateral, depth = float(point3d[0]), float(point3d[2])
        else:
            lateral, depth = float(point3d[1]), float(point3d[0])
        # Clamp both axes (previously only depth) so off-range centres stay on the panel.
        lateral = max(-max_lateral_m, min(lateral, max_lateral_m))
        depth = max(0.0, min(depth, max_depth_m))
        return to_col(lateral), to_row(depth)

    for depth in range(0, int(max_depth_m) + 1, 10):
        y = to_row(depth)
        cv2.line(panel, (0, y), (width - 1, y), (225, 225, 225), 1, cv2.LINE_AA)
        cv2.putText(panel, f"{depth}m", (6, max(14, y - 4)), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (90, 90, 90), 1, cv2.LINE_AA)
    for lat in range(-int(max_lateral_m), int(max_lateral_m) + 1, 10):
        x = to_col(lat)
        cv2.line(panel, (x, 0), (x, height - 1), (232, 232, 232), 1, cv2.LINE_AA)
    center_x = to_col(0.0)
    cv2.line(panel, (center_x, 0), (center_x, height - 1), (180, 180, 180), 2, cv2.LINE_AA)

    gt_pt = project(example.get("gt_center_3d"))
    det_pt = project(example.get("det_center_3d"))
    markers = (
        (gt_pt, "gt_rotation_rad", BOX_COLORS["bev_gt"]),
        (det_pt, "det_rotation_rad", BOX_COLORS["bev_det"]),
    )
    for point, rotation_key, color in markers:
        if point is None:
            continue
        cv2.circle(panel, point, 7, color, -1, cv2.LINE_AA)
        rotation = example.get(rotation_key)
        # No rotation -> no arrow, rather than a fabricated 0-rad heading.
        if rotation is not None:
            draw_heading_arrow(panel, point, float(rotation), color)
    if gt_pt and det_pt:
        cv2.line(panel, gt_pt, det_pt, (80, 80, 80), 2, cv2.LINE_AA)
    add_header(panel, "simple BEV | GT=green | Det=orange")
    return panel
def draw_heading_arrow(canvas, anchor, rotation_rad, color, length_px=28):
    """Draw a ``length_px``-long arrow from *anchor* pointing along *rotation_rad* (in place).

    Angle 0 points straight up the canvas. Positive angles turn clockwise in
    image coordinates (x to the right, y downward).
    """
    tip = (
        int(round(anchor[0] + length_px * math.sin(rotation_rad))),
        int(round(anchor[1] - length_px * math.cos(rotation_rad))),
    )
    cv2.arrowedLine(canvas, anchor, tip, color, 2, cv2.LINE_AA, tipLength=0.25)
def add_sidebar(panel, example):
    """Append a 420px-wide dark text column listing the example's metadata to *panel*.

    Float values are printed with three decimals, because full float reprs
    overflow the fixed-width sidebar and get clipped. Other values (ints,
    strings, bools, ``None``) are printed as-is.

    Returns:
        *panel* horizontally stacked with the sidebar (same height as *panel*).
    """
    h = panel.shape[0]
    sidebar = np.full((h, 420, 3), 28, dtype=np.uint8)

    def field(key):
        value = example.get(key)
        return f"{value:.3f}" if isinstance(value, float) else str(value)

    lines = [
        f"case: {field('case_name')}",
        f"frame: {field('frame_name')}",
        f"class: {field('class_name')}",
        f"metric: {field('metric_name')}",
        f"metric_display: {field('metric_value_display')} {example.get('metric_unit', '')}",
        f"conf: {field('confidence')}",
        f"iou: {field('iou')}",
        f"distance_z_m: {field('distance_longitudinal_m')}",
        f"distance_x_m: {field('distance_lateral_m')}",
        f"distance_bin: {field('distance_bin')}",
        f"lateral_bin: {field('lateral_bin')}",
        f"lat_err_m: {field('lateral_error_m')}",
        f"long_err_m: {field('longitudinal_error_m')}",
        f"long_rel_err: {field('longitudinal_relative_error')}",
        f"heading_deg: {field('heading_error_deg')}",
        f"heading_relaxed_deg: {field('heading_error_relaxed_deg')}",
        f"is_reversal: {field('is_reversal')}",
        f"gt_id: {field('gt_id')}",
    ]
    # Baselines start at y=36 and advance 28px per line; lines past the panel height are clipped.
    for row, line in enumerate(lines):
        cv2.putText(
            sidebar,
            line,
            (12, 36 + 28 * row),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.56,
            (235, 235, 235),
            1,
            cv2.LINE_AA,
        )
    return np.hstack([panel, sidebar])
def resize_to_height(image, target_height):
    """Scale *image* to *target_height* rows, preserving aspect ratio.

    The very same object is returned when the height already matches.
    Otherwise the new width is rounded and kept at least 1 pixel.
    """
    src_h, src_w = image.shape[:2]
    if src_h != target_height:
        ratio = target_height / max(src_h, 1)
        new_width = max(1, int(round(src_w * ratio)))
        return cv2.resize(image, (new_width, target_height))
    return image
def combine_panels(full_image, crop_image, bev_image, example):
    """Place the full frame, crop and BEV side by side at their tallest height, then append the metrics sidebar."""
    parts = (full_image, crop_image, bev_image)
    common_height = max(part.shape[0] for part in parts)
    row = np.hstack([resize_to_height(part, common_height) for part in parts])
    return add_sidebar(row, example)
def find_pair_map(config):
    """Load all GT/detection file pairs through the project Evaluator and index them.

    Args:
        config: Evaluation config dict. ``config["dataset"]`` must provide
            ``det_path`` and ``gt_path``. Everything else used here has a
            fallback: ``path_depth`` (1), ``det_format``/``gt_format``
            ("auto"), the ``image`` section's ``width``/``height``
            (1920/1080) and ``matching.iou_threshold`` (0.5).

    Returns:
        ``(pair_map, evaluator)``. ``pair_map`` maps ``(case_key, frame)``
        to the evaluator's pair dict. ``case_key`` is
        ``"<level1_name>/<case>"`` when the pair has a non-empty
        ``level1_name``, else ``"<case>"``. A later pair with the same key
        replaces an earlier one.
    """
    matching_cfg = config.get("matching") or {}
    evaluator = Evaluator(
        config=config,
        iou_threshold=float(matching_cfg.get("iou_threshold", 0.5)),
        num_workers=1,
        save_detailed_matches=False,
    )
    dataset_cfg = config["dataset"]
    # Width and height both have defaults, so a config without an `image`
    # section should not fail with KeyError here.
    image_cfg = config.get("image") or {}
    evaluator.load_data_from_paths(
        det_root=dataset_cfg["det_path"],
        gt_root=dataset_cfg["gt_path"],
        img_width=image_cfg.get("width", 1920),
        img_height=image_cfg.get("height", 1080),
        path_depth=dataset_cfg.get("path_depth", 1),
        det_format=dataset_cfg.get("det_format", "auto"),
        gt_format=dataset_cfg.get("gt_format", "auto"),
    )
    pair_map = {}
    for pair in evaluator.image_pairs:
        level1_name = pair.get("level1_name")
        case_key = f"{level1_name}/{pair['case']}" if level1_name else pair["case"]
        pair_map[(case_key, pair["frame"])] = pair
    return pair_map, evaluator
def find_image_path(pair):
    """Locate the image belonging to *pair*'s ground-truth file.

    The image is expected at ``<case_dir>/images/<stem>.<ext>``, where
    ``case_dir`` is two levels above ``pair["gt_file"]`` and ``stem`` is the
    GT file's stem.

    Lookup order:
      1. exact ``<stem>.png``, ``.jpg``, ``.jpeg``, ``.bmp`` (in that order);
      2. any other file in ``images/`` whose stem equals ``stem``. Files with
         a common image extension (any letter case) come first, then the
         rest, each group in alphabetical order.

    Returns:
        ``Path`` of the image, or ``None`` if nothing matches.
    """
    gt_file = Path(pair["gt_file"])
    images_dir = gt_file.parent.parent / "images"
    stem = gt_file.stem
    for suffix in (".png", ".jpg", ".jpeg", ".bmp"):
        candidate = images_dir / f"{stem}{suffix}"
        if candidate.exists():
            return candidate
    if not images_dir.is_dir():
        return None
    # Compare stems directly instead of globbing f"{stem}.*". The glob's hit
    # order depends on the filesystem, and characters such as "[" in the
    # stem would be interpreted as a pattern.
    image_suffixes = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
    matches = sorted(
        (p for p in images_dir.iterdir() if p.is_file() and p.stem == stem),
        key=lambda p: (p.suffix.lower() not in image_suffixes, p.name),
    )
    return matches[0] if matches else None
def render_frame_overlay(image, gts, active_dets, frame_examples, class_ids, line_thickness):
    """Draw GT boxes, active detections and highlighted badcases on a copy of *image*.

    Args:
        image: Frame to annotate. It is not modified; drawing happens on a copy.
        gts: GT dicts with ``label`` (class id) and ``bbox_2d``.
        active_dets: Detection dicts with ``label``, ``bbox_2d`` and an
            optional ``confidence``.
        frame_examples: Badcase records for this frame. Each record's
            ``gt_bbox`` / ``det_bbox`` is drawn thicker, numbered ``#1, #2, ...``.
        class_ids: Only GTs/detections whose ``label`` is in this collection are drawn.
        line_thickness: Thickness of the generic boxes; highlighted boxes
            use ``max(3, line_thickness + 1)``.

    Returns:
        The annotated copy, topped with a legend header.
    """
    canvas = image.copy()
    selected_class_ids = set(class_ids)
    highlight_thickness = max(3, line_thickness + 1)
    for gt in gts:
        if gt["label"] not in selected_class_ids:
            continue
        draw_box(
            canvas,
            gt["bbox_2d"],
            BOX_COLORS["gt_all"],
            label=f"GT {class_name(gt['label'])}",
            thickness=line_thickness,
        )
    for det in active_dets:
        if det["label"] not in selected_class_ids:
            continue
        # `or 0.0` also covers a confidence that is present but null.
        conf = float(det.get("confidence") or 0.0)
        draw_box(
            canvas,
            det["bbox_2d"],
            BOX_COLORS["det_all"],
            label=f"Det {class_name(det['label'])} {conf:.2f}",
            thickness=line_thickness,
        )
    for idx, example in enumerate(frame_examples, 1):
        example_class = example.get("class_name", "")
        if example.get("gt_bbox"):
            draw_box(
                canvas,
                example["gt_bbox"],
                BOX_COLORS["target_gt"],
                label=f"GT#{idx} {example_class}",
                thickness=highlight_thickness,
            )
        if example.get("det_bbox"):
            confidence = float(example.get("confidence") or 0.0)
            metric_value = float(example.get("metric_value_display") or 0.0)
            label = (
                f"Det#{idx} {example_class} "
                f"{confidence:.2f} "
                f"{example.get('metric_name', '')}={metric_value:.2f}"
            )
            draw_box(
                canvas,
                example["det_bbox"],
                BOX_COLORS["target_det"],
                label=label,
                thickness=highlight_thickness,
            )
    headline = (
        f"3D badcase visualization | examples={len(frame_examples)} | "
        f"GT=green Det=gray TargetGT=red TargetDet=orange"
    )
    add_header(canvas, headline)
    return canvas
def ensure_dir(path):
    """Create directory *path* (with any missing parents) unless it already exists, then return *path*.

    Raises ``FileExistsError`` if *path* exists but is not a directory.
    """
    if not path.is_dir():
        path.mkdir(parents=True, exist_ok=True)
    return path
def resolve_output_dirs(output_dir, distance_bin=None, group_by_distance_bin=False):
    """Return ``(base_dir, frame_dir, example_dir)`` for one output group, creating the two subdirectories.

    Without grouping, everything lives directly under *output_dir*. With
    grouping, each distance bin gets ``output_dir/distance_bins/<bin>``, where
    a missing bin is filed as ``unbucketed`` and the name is sanitized.
    """
    base_dir = output_dir
    if group_by_distance_bin:
        bin_name = sanitize_token(distance_bin or "unbucketed")
        base_dir = output_dir / "distance_bins" / bin_name
    return base_dir, ensure_dir(base_dir / "frames"), ensure_dir(base_dir / "examples")
def sanitize_token(value):
    """Make *value* usable as a single file-name component.

    Both path separators (forward slash and backslash) become ``__`` and
    spaces become ``_``; all other characters are kept. Non-strings are
    converted with ``str()`` first.
    """
    return str(value).translate(str.maketrans({"/": "__", "\\": "__", " ": "_"}))
def default_output_dir(report_path):
    """Return ``<report's directory>/3d_vis_<report file stem>`` for *report_path* (``str`` or ``Path``)."""
    path = Path(report_path)
    return path.parent.joinpath(f"3d_vis_{path.stem}")
def main():
    """CLI entry point: filter badcase examples, render visualizations and write ``index.json``.

    Steps:
      1. Load the analysis report and build the evaluation config from the CLI args.
      2. Filter and rank examples (``filter_examples``), then group them per
         case/frame, additionally split by distance bin when
         ``--group-by-distance-bin`` is set.
      3. For each group, load the frame image and its GT/detections through
         the Evaluator. Save one full-frame overlay plus one combined panel
         per example. Frames whose pair or image cannot be found are skipped
         with a warning.
      4. Write ``index.json`` into the output directory, describing everything saved.
    """
    args = parse_args()
    with open(args.analysis_report, "r", encoding="utf-8") as file:
        report = json.load(file)
    config = build_config(args)
    class_ids = parse_class_ids(args.classes) if args.classes else parse_class_ids(report["metadata"]["classes"])
    filtered_examples = filter_examples(report, args)
    if not filtered_examples:
        print("No examples matched the current filters.")
        return
    pair_map, evaluator = find_pair_map(config)
    output_dir = Path(args.output_dir) if args.output_dir else default_output_dir(args.analysis_report)
    # index.json is written here even if every frame below is skipped; in that
    # case no frames/examples subdirectory would have created output_dir.
    ensure_dir(output_dir)
    jpeg_params = [int(cv2.IMWRITE_JPEG_QUALITY), int(args.jpeg_quality)]

    # One group per (case, frame[, distance bin]) so each frame image is loaded once.
    by_frame = defaultdict(list)
    for item in filtered_examples:
        group_distance_bin = item.get("distance_bin") if args.group_by_distance_bin else None
        by_frame[(item["case_name"], item["frame_name"], group_distance_bin)].append(item)

    index = {
        "analysis_report": str(Path(args.analysis_report).resolve()),
        "num_examples": len(filtered_examples),
        "num_frames": len(by_frame),
        "metrics": args.metrics,
        "classes": [class_name(cid) for cid in class_ids],
        "group_by_distance_bin": bool(args.group_by_distance_bin),
        "top_k_per_distance_bin": int(args.top_k_per_distance_bin),
        "distance_bins": {},
        "frames": [],
    }
    # Prefer the 3D threshold, falling back to the 2D one, then 0.5.
    conf_threshold = float(
        config.get("metrics_3d", {}).get(
            "conf_threshold",
            config.get("metrics_2d", {}).get("conf_threshold", 0.5),
        )
    )
    saved_frame_dirs = set()
    saved_example_dirs = set()
    for frame_idx, ((case_name, frame_name, distance_bin), frame_examples) in enumerate(by_frame.items(), 1):
        pair = pair_map.get((case_name, frame_name))
        if pair is None:
            print(f"Warning: failed to locate pair for {case_name}/{frame_name}, skipping")
            continue
        image_path = find_image_path(pair)
        if image_path is None or not image_path.exists():
            print(f"Warning: image not found for {case_name}/{frame_name}, skipping")
            continue
        image = cv2.imread(str(image_path))
        if image is None:
            print(f"Warning: failed to read image: {image_path}")
            continue
        gts = Evaluator._parse_ground_truths_for_pair(pair, evaluator.coord_system)
        dets = Evaluator._parse_detections_for_pair(pair, evaluator.coord_system)
        # `or 0.0` also covers a confidence that is present but null.
        active_dets = [det for det in dets if float(det.get("confidence") or 0.0) >= conf_threshold]
        frame_overlay = render_frame_overlay(
            image,
            gts,
            active_dets,
            frame_examples,
            class_ids,
            line_thickness=args.line_thickness,
        )
        _base_dir, frame_dir, example_dir = resolve_output_dirs(
            output_dir,
            distance_bin=distance_bin,
            group_by_distance_bin=args.group_by_distance_bin,
        )
        saved_frame_dirs.add(str(frame_dir))
        saved_example_dirs.add(str(example_dir))
        frame_rel = frame_dir.relative_to(output_dir) / (
            f"{frame_idx:04d}_{sanitize_token(case_name)}_{sanitize_token(frame_name)}.jpg"
        )
        frame_path = output_dir / frame_rel
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(str(frame_path), frame_overlay, jpeg_params):
            print(f"Warning: failed to write frame overlay: {frame_path}")
        frame_entry = {
            "case_name": case_name,
            "frame_name": frame_name,
            "distance_bin": distance_bin,
            "image_path": str(image_path),
            "frame_visualization": str(frame_rel),
            "num_examples": len(frame_examples),
            "examples": [],
        }
        # When grouping, a missing bin uses the same "unbucketed" name as its
        # output directory (resolve_output_dirs); otherwise all frames are "all".
        if args.group_by_distance_bin:
            distance_key = distance_bin or "unbucketed"
        else:
            distance_key = "all"
        bin_stats = index["distance_bins"].setdefault(distance_key, {"num_frames": 0, "num_examples": 0})
        bin_stats["num_frames"] += 1
        bin_stats["num_examples"] += len(frame_examples)
        for ex_idx, example in enumerate(frame_examples, 1):
            crop_image = draw_crop_panel(image.copy(), example, crop_scale=args.crop_scale)
            bev_image = create_bev_panel(example, coord_system=evaluator.coord_system)
            panel = combine_panels(frame_overlay.copy(), crop_image, bev_image, example)
            rel = example_dir.relative_to(output_dir) / (
                f"{frame_idx:04d}_{ex_idx:02d}_"
                f"{sanitize_token(case_name)}_{sanitize_token(frame_name)}_"
                f"{sanitize_token(example['class_name'])}_{sanitize_token(example['metric_name'])}.jpg"
            )
            panel_path = output_dir / rel
            if not cv2.imwrite(str(panel_path), panel, jpeg_params):
                print(f"Warning: failed to write example panel: {panel_path}")
            example_record = dict(example)
            example_record["visualization"] = str(rel)
            frame_entry["examples"].append(example_record)
        index["frames"].append(frame_entry)

    index_path = output_dir / "index.json"
    with open(index_path, "w", encoding="utf-8") as file:
        json.dump(index, file, indent=2)
    print(f"Saved visualization index to: {index_path}")
    print(f"Saved frame overlays to: {', '.join(sorted(saved_frame_dirs)) or '(none)'}")
    print(f"Saved example panels to: {', '.join(sorted(saved_example_dirs)) or '(none)'}")
if __name__ == "__main__":
    # Script entry point; importing this module does not trigger a run.
    main()