1001 lines
35 KiB
Python
Executable File
1001 lines
35 KiB
Python
Executable File
#!/usr/bin/env python3
"""
Analyze 2D false positives and false negatives for YOLOv5-3D evaluation data.

This tool reuses the existing evaluation data loading pipeline so that ROI GT
filtering, detection parsing, and path probing stay aligned with
``eval_tools/core/eval.py``.

Error categories
----------------
False positives:
- duplicate: overlaps a same-class GT above the match IoU threshold, but that
  GT was already claimed by a higher-confidence detection.
- class_confusion: overlaps a GT of another class above the match IoU
  threshold.
- localization: overlaps a same-class GT, but IoU is below the match
  threshold and above a configurable "near miss" threshold.
- background: does not overlap any GT strongly enough to explain the error.

False negatives:
- class_confusion: a detection of another class overlaps the GT above the
  match IoU threshold.
- low_score: a same-class detection overlaps the GT above the match IoU
  threshold, but its confidence is below the operating threshold.
- localization: an above-threshold same-class detection is close to the GT
  but does not reach match IoU.
- low_score_localization: a below-threshold same-class detection is close to
  the GT but still poorly localized.
- missing: no plausible same-class detection is present.
"""
|
|
|
|
import argparse
|
|
import heapq
|
|
import json
|
|
import sys
|
|
from collections import Counter, defaultdict
|
|
from datetime import datetime
|
|
from functools import partial
|
|
from itertools import count
|
|
from multiprocessing import Pool, cpu_count
|
|
from pathlib import Path
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
REPO_ROOT = Path(__file__).resolve().parents[2]
|
|
if str(REPO_ROOT) not in sys.path:
|
|
sys.path.insert(0, str(REPO_ROOT))
|
|
|
|
from eval_tools.evaluator.evaluator import Evaluator
|
|
from eval_tools.evaluator.matcher import Matcher2D
|
|
from eval_tools.evaluator.parser import GroundTruthParser
|
|
|
|
|
|
# Every class ID declared by the ground-truth parser, in ascending order.
# Returned by ``parse_class_ids`` when no class filter is supplied.
DEFAULT_CLASS_IDS = sorted(GroundTruthParser.CLASS_NAMES)
# Lower-cased class name -> numeric class ID, for case-insensitive lookups.
CLASS_NAME_TO_ID = {
    label.lower(): ident for ident, label in GroundTruthParser.CLASS_NAMES.items()
}
|
|
|
|
|
|
def load_config(config_path):
    """Load configuration from a YAML file and expand ``{timestamp}``.

    The literal ``{timestamp}`` is replaced with the current local time
    (formatted ``%Y%m%d_%H%M%S``) in ``output.save_path``,
    ``dataset.det_path`` and ``dataset.gt_path`` whenever those entries are
    present and hold strings.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The parsed configuration mapping. An empty YAML document yields ``{}``.

    Raises:
        ModuleNotFoundError: If PyYAML is not installed.
        ValueError: If the top level of the YAML document is not a mapping.
    """
    try:
        import yaml
    except ModuleNotFoundError as exc:
        raise ModuleNotFoundError(
            "PyYAML is required when using --config. Please install it in the active environment."
        ) from exc

    # Read as UTF-8 explicitly so the result does not depend on the locale.
    with open(config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    if config is None:
        # An empty (or comment-only) YAML document parses to None.
        config = {}
    if not isinstance(config, dict):
        raise ValueError(
            f"Top-level YAML content must be a mapping, got {type(config).__name__}."
        )

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    def expand(section_name, key):
        # Skip sections/values that are absent, null, or not strings instead
        # of crashing on ``.replace``.
        section = config.get(section_name)
        if isinstance(section, dict) and isinstance(section.get(key), str):
            section[key] = section[key].replace("{timestamp}", timestamp)

    expand("output", "save_path")
    expand("dataset", "det_path")
    expand("dataset", "gt_path")
    return config
|
|
|
|
|
|
def parse_args():
    """Parse the command-line options of the FP/FN analysis tool.

    Options are declared in a table and registered in order, so the help
    output lists them in the order given here.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    option_table = (
        ("--config", dict(type=str, help="Path to YAML evaluation config")),
        ("--det-path", dict(type=str, help="Detection results root directory")),
        ("--gt-path", dict(type=str, help="Ground-truth labels root directory")),
        ("--path-depth", dict(type=int, choices=[1, 2], help="Directory depth")),
        (
            "--det-format",
            dict(
                type=str,
                choices=["auto", "json", "txt"],
                help="Detection file format",
            ),
        ),
        (
            "--gt-format",
            dict(
                type=str,
                choices=["auto", "json", "txt"],
                help="Ground-truth file format",
            ),
        ),
        ("--img-width", dict(type=int, help="Image width")),
        ("--img-height", dict(type=int, help="Image height")),
        (
            "--coord-system",
            dict(
                type=str,
                choices=["camera", "ego"],
                help="Coordinate system used by the parser/evaluator",
            ),
        ),
        (
            "--iou-threshold",
            dict(
                type=float,
                help="IoU threshold used for TP matching and duplicate/confusion checks",
            ),
        ),
        (
            "--conf-threshold",
            dict(
                type=float,
                help="Confidence threshold for the analyzed operating point",
            ),
        ),
        (
            "--near-iou-threshold",
            dict(
                type=float,
                default=0.1,
                help="Near-miss IoU threshold for localization-style FP/FN categorization",
            ),
        ),
        (
            "--classes",
            dict(
                nargs="+",
                default=None,
                help="Optional class filter, e.g. vehicle pedestrian rider or numeric IDs",
            ),
        ),
        (
            "--num-workers",
            dict(
                type=int,
                default=None,
                help="Worker count for scanning and frame analysis (default: auto-detect)",
            ),
        ),
        (
            "--max-frames",
            dict(
                type=int,
                default=None,
                help="Only analyze the first N frames after loading",
            ),
        ),
        (
            "--max-fp-details",
            dict(
                type=int,
                default=1000,
                help="Maximum number of FP examples to keep per error type in the JSON report",
            ),
        ),
        (
            "--max-fn-details",
            dict(
                type=int,
                default=1000,
                help="Maximum number of FN examples to keep per error type in the JSON report",
            ),
        ),
        (
            "--top-k-frames",
            dict(
                type=int,
                default=50,
                help="Number of worst frames to include in the summary",
            ),
        ),
        (
            "--output-dir",
            dict(
                type=str,
                default=None,
                help="Output directory. Defaults to eval_tools/analysis/results/<timestamp>",
            ),
        ),
    )

    cli = argparse.ArgumentParser(
        description="Analyze 2D FP/FN patterns using the existing evaluation data pipeline."
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
|
|
|
|
|
def build_config(args):
    """Build the analysis config from an optional YAML file plus CLI overrides.

    Precedence, highest first: explicit CLI values, values from the YAML file
    given by ``--config``, built-in defaults (path depth 1, ``auto`` formats,
    1920x1080 image, IoU 0.5, confidence 0.5, ``camera`` coordinates).

    Args:
        args: Parsed CLI namespace as returned by ``parse_args``.

    Returns:
        A config dict with ``dataset``, ``image``, ``matching``,
        ``metrics_2d`` and ``metrics_3d`` sections populated.

    Raises:
        ValueError: If no ``--config`` is given and ``--det-path`` or
            ``--gt-path`` is missing, or if the final config lacks
            ``det_path`` / ``gt_path``.
    """
    if args.config:
        # Guard against a loader result of None (e.g. an empty YAML file).
        config = load_config(args.config) or {}
    else:
        if not args.det_path or not args.gt_path:
            raise ValueError(
                "--det-path and --gt-path are required when --config is not provided."
            )
        config = {}

    def section(name):
        # A YAML key without a value (``dataset:``) parses to None, which
        # ``dict.setdefault`` would hand back unchanged; replace it instead.
        current = config.get(name)
        if current is None:
            current = config[name] = {}
        return current

    dataset_cfg = section("dataset")
    image_cfg = section("image")
    matching_cfg = section("matching")
    metrics_2d_cfg = section("metrics_2d")
    metrics_3d_cfg = section("metrics_3d")

    # String options count as "given" only when non-empty; numeric options
    # count whenever they are not None, so an explicit 0 / 0.0 is honoured.
    overrides = (
        (dataset_cfg, "det_path", args.det_path or None),
        (dataset_cfg, "gt_path", args.gt_path or None),
        (dataset_cfg, "path_depth", args.path_depth),
        (dataset_cfg, "det_format", args.det_format or None),
        (dataset_cfg, "gt_format", args.gt_format or None),
        (image_cfg, "width", args.img_width),
        (image_cfg, "height", args.img_height),
        (matching_cfg, "iou_threshold", args.iou_threshold),
        (metrics_2d_cfg, "conf_threshold", args.conf_threshold),
        (metrics_3d_cfg, "coordinate_system", args.coord_system or None),
    )
    for target, key, value in overrides:
        if value is not None:
            target[key] = value

    defaults = (
        (dataset_cfg, "path_depth", 1),
        (dataset_cfg, "det_format", "auto"),
        (dataset_cfg, "gt_format", "auto"),
        (image_cfg, "width", 1920),
        (image_cfg, "height", 1080),
        (matching_cfg, "iou_threshold", 0.5),
        (metrics_2d_cfg, "conf_threshold", 0.5),
        (metrics_3d_cfg, "coordinate_system", "camera"),
    )
    for target, key, value in defaults:
        target.setdefault(key, value)

    if "det_path" not in dataset_cfg or "gt_path" not in dataset_cfg:
        raise ValueError("Both det_path and gt_path must be available in the final config.")

    return config
|
|
|
|
|
|
def parse_class_ids(raw_classes):
    """Convert CLI class tokens (names or numeric IDs) into class IDs.

    Args:
        raw_classes: Iterable of tokens, or a falsy value for "no filter".

    Returns:
        ``DEFAULT_CLASS_IDS`` when no tokens are given, otherwise the sorted,
        de-duplicated list of requested class IDs.

    Raises:
        ValueError: For a name not in ``CLASS_NAME_TO_ID`` or an ID not in
            ``GroundTruthParser.CLASS_NAMES``.
    """
    if not raw_classes:
        return DEFAULT_CLASS_IDS

    selected = set()
    for token in raw_classes:
        text = str(token).strip().lower()
        # Accept an optional leading minus sign in front of the digits.
        digits = text[1:] if text.startswith("-") else text
        if digits.isdigit():
            class_id = int(text)
        elif text in CLASS_NAME_TO_ID:
            class_id = CLASS_NAME_TO_ID[text]
        else:
            raise ValueError(f"Unknown class: {token}")

        if class_id not in GroundTruthParser.CLASS_NAMES:
            raise ValueError(f"Unsupported class ID: {class_id}")
        selected.add(class_id)

    return sorted(selected)
|
|
|
|
|
|
def class_name(class_id):
    """Return the parser's name for ``class_id``, or ``class_<id>`` if unknown."""
    known_names = GroundTruthParser.CLASS_NAMES
    placeholder = f"class_{class_id}"
    return known_names.get(class_id, placeholder)
|
|
|
|
|
|
def object_id(obj, prefix, fallback_idx):
    """Return ``obj["id"]`` as a string, or ``<prefix>_<fallback_idx>`` if it is absent or None."""
    explicit = obj.get("id")
    if explicit is None:
        return f"{prefix}_{fallback_idx}"
    return str(explicit)
|
|
|
|
|
|
def bbox_area(bbox):
    """Return the area of a box whose first four items are read as (x1, y1, x2, y2).

    Boxes that are empty, None, or shorter than four items give ``0.0``;
    negative extents are clamped to zero.
    """
    if not bbox or len(bbox) < 4:
        return 0.0
    width = max(0.0, bbox[2] - bbox[0])
    height = max(0.0, bbox[3] - bbox[1])
    return width * height
|
|
|
|
|
|
def round_float(value, digits=6):
    """Coerce ``value`` to ``float`` and round it to ``digits`` decimal places."""
    as_float = float(value)
    return round(as_float, digits)
|
|
|
|
|
|
def limit_examples_per_type(examples, limit_per_type):
    """Keep at most ``limit_per_type`` examples for each ``error_type``.

    Input order is preserved; examples without an ``error_type`` key are
    grouped under ``"unknown"``. When the limit is None or not positive the
    input object is returned unchanged.
    """
    if limit_per_type is None or limit_per_type <= 0:
        return examples

    taken = {}
    selected = []
    for example in examples:
        kind = example.get("error_type", "unknown")
        already = taken.get(kind, 0)
        if already < limit_per_type:
            selected.append(example)
            taken[kind] = already + 1
    return selected
|
|
|
|
|
|
def build_case_key(pair):
    """Return ``<level1_name>/<case>`` when ``level1_name`` is set and truthy, else ``<case>``."""
    parent = pair.get("level1_name")
    return f"{parent}/{pair['case']}" if parent else pair["case"]
|
|
|
|
|
|
def extract_3d_meta(obj, coord_system):
    """Return ``distance_m`` / ``lateral_m`` taken from ``obj["3d_info"]["center"]``.

    For ``coord_system == "camera"`` the distance is ``center[2]`` and the
    lateral offset ``center[0]``; for any other value they are ``center[0]``
    and ``center[1]``. Both are None when ``3d_info`` is missing/empty or the
    center has fewer than three components.
    """
    info = obj.get("3d_info")
    center = info.get("center", []) if info else []
    if len(center) < 3:
        return {"distance_m": None, "lateral_m": None}

    distance_axis, lateral_axis = (2, 0) if coord_system == "camera" else (0, 1)
    return {
        "distance_m": round_float(center[distance_axis]),
        "lateral_m": round_float(center[lateral_axis]),
    }
|
|
|
|
|
|
def best_gt_overlap(matcher, det, candidates, exclude_class_id=None):
    """Find the candidate GT with the highest IoU against ``det``.

    Args:
        matcher: Object providing ``compute_pair_iou(gt, det)``.
        det: The detection to compare against.
        candidates: Sequence of GT dicts; ``label`` is read only when
            ``exclude_class_id`` is given.
        exclude_class_id: If not None, GTs with this label are skipped.

    Returns:
        ``(iou, index, gt)`` of the best candidate, where ``index`` is the
        position in ``candidates``. ``(0.0, -1, None)`` when no candidate has
        an IoU strictly greater than zero.
    """
    winner = (0.0, -1, None)
    for position, candidate in enumerate(candidates):
        if exclude_class_id is not None and candidate["label"] == exclude_class_id:
            continue
        overlap = matcher.compute_pair_iou(candidate, det)
        # Strict comparison: on ties the earliest candidate is kept.
        if overlap > winner[0]:
            winner = (overlap, position, candidate)
    return winner
|
|
|
|
|
|
def best_det_overlap(
    matcher,
    gt,
    detections,
    label_id=None,
    conf_min=None,
    conf_max=None,
    exclude_label_id=None,
):
    """Find the detection with the highest IoU against ``gt`` among filtered candidates.

    Args:
        matcher: Object providing ``compute_pair_iou(gt, det)``.
        gt: The ground-truth object to compare against.
        detections: Sequence of detection dicts.
        label_id: If not None, only detections with this label are considered.
        conf_min: If not None, detections with confidence below it are skipped.
        conf_max: If not None, detections with confidence at or above it are
            skipped (exclusive upper bound).
        exclude_label_id: If not None, detections with this label are skipped.

    Returns:
        ``(iou, index, det)`` of the best detection, where ``index`` is the
        position in ``detections``. ``(0.0, -1, None)`` when no eligible
        detection has an IoU strictly greater than zero.
    """

    def eligible(det):
        if label_id is not None and det["label"] != label_id:
            return False
        if exclude_label_id is not None and det["label"] == exclude_label_id:
            return False
        # A missing confidence is treated as 0.0.
        score = float(det.get("confidence", 0.0))
        below_floor = conf_min is not None and score < conf_min
        at_or_above_ceiling = conf_max is not None and score >= conf_max
        return not (below_floor or at_or_above_ceiling)

    top_iou, top_idx, top_det = 0.0, -1, None
    for position, det in enumerate(detections):
        if not eligible(det):
            continue
        overlap = matcher.compute_pair_iou(gt, det)
        if overlap > top_iou:
            top_iou, top_idx, top_det = overlap, position, det
    return top_iou, top_idx, top_det
|
|
|
|
|
|
def classify_fp_detail(
    matcher,
    det,
    det_idx,
    class_id,
    gts_same_class,
    matched_gt_indices,
    all_gts,
    case_key,
    frame_name,
    coord_system,
    iou_threshold,
    near_iou_threshold,
):
    """Categorize one unmatched detection and build its report record.

    The category is the first rule that applies:

    1. ``duplicate``: best same-class IoU reaches ``iou_threshold`` and that
       GT's index is in ``matched_gt_indices``.
    2. ``class_confusion``: best IoU with a GT of another class reaches
       ``iou_threshold``.
    3. ``localization``: best same-class IoU reaches ``near_iou_threshold``.
    4. ``background``: none of the above.

    Args:
        matcher: Object providing ``compute_pair_iou(gt, det)``.
        det: The unmatched detection dict.
        det_idx: Index used for the fallback detection ID.
        class_id: Class of the detection being analysed.
        gts_same_class: Same-class GTs; ``matched_gt_indices`` index into it.
        matched_gt_indices: Indices of GTs already claimed by a match.
        all_gts: All GTs of the frame; those labelled ``class_id`` are ignored.
        case_key: Case identifier stored in the record.
        frame_name: Frame identifier stored in the record.
        coord_system: Passed to ``extract_3d_meta`` for the distance fields.
        iou_threshold: Match IoU threshold.
        near_iou_threshold: Near-miss IoU threshold.

    Returns:
        A dict describing the false positive, including ``distance_m`` and
        ``lateral_m`` taken from the detection.
    """
    same_iou, same_idx, same_gt = best_gt_overlap(matcher, det, gts_same_class)
    # NOTE: other_idx is a position in ``all_gts``, not in ``gts_same_class``.
    other_iou, other_idx, other_gt = best_gt_overlap(
        matcher, det, all_gts, exclude_class_id=class_id
    )

    def categorize():
        if same_iou >= iou_threshold and same_idx in matched_gt_indices:
            return "duplicate"
        if other_iou >= iou_threshold:
            return "class_confusion"
        if same_iou >= near_iou_threshold:
            return "localization"
        return "background"

    box = det.get("bbox_2d", [])
    same_gt_id = object_id(same_gt, "gt", same_idx) if same_gt else None
    other_gt_id = object_id(other_gt, "gt", other_idx) if other_gt else None
    other_class = class_name(other_gt["label"]) if other_gt else None

    record = {
        "case_name": case_key,
        "frame_name": frame_name,
        "class_id": class_id,
        "class_name": class_name(class_id),
        "error_type": categorize(),
        "det_id": object_id(det, "det", det_idx),
        "confidence": round_float(det.get("confidence", 0.0)),
        "det_bbox": [round_float(coord) for coord in box],
        "det_bbox_area": round_float(bbox_area(box)),
        "best_same_class_iou": round_float(same_iou),
        "best_same_gt_id": same_gt_id,
        "best_other_class_iou": round_float(other_iou),
        "best_other_gt_id": other_gt_id,
        "best_other_class": other_class,
    }
    record.update(extract_3d_meta(det, coord_system))
    return record
|
|
|
|
|
|
def classify_fn_detail(
    matcher,
    gt,
    gt_idx,
    class_id,
    active_dets,
    all_dets,
    case_key,
    frame_name,
    coord_system,
    iou_threshold,
    conf_threshold,
    near_iou_threshold,
):
    """Categorize one unmatched ground truth and build its report record.

    Three best-overlap searches are run against ``gt``: same-class among
    ``active_dets``, same-class among ``all_dets`` with confidence below
    ``conf_threshold``, and other-class among ``active_dets``. The category is
    the first rule that applies:

    1. ``class_confusion``: other-class active IoU reaches ``iou_threshold``.
    2. ``low_score``: low-score same-class IoU reaches ``iou_threshold``.
    3. ``localization``: active same-class IoU reaches ``near_iou_threshold``.
    4. ``low_score_localization``: low-score same-class IoU reaches
       ``near_iou_threshold``.
    5. ``missing``: none of the above.

    Args:
        matcher: Object providing ``compute_pair_iou(gt, det)``.
        gt: The unmatched ground-truth dict.
        gt_idx: Index used for the fallback GT ID.
        class_id: Class of the ground truth being analysed.
        active_dets: Detections at or above the confidence threshold.
        all_dets: All detections of the frame.
        case_key: Case identifier stored in the record.
        frame_name: Frame identifier stored in the record.
        coord_system: Passed to ``extract_3d_meta`` for the distance fields.
        iou_threshold: Match IoU threshold.
        conf_threshold: Operating confidence threshold.
        near_iou_threshold: Near-miss IoU threshold.

    Returns:
        A dict describing the false negative, including ``distance_m`` and
        ``lateral_m`` taken from the ground truth.
    """
    # Each search result is an (iou, index, detection) triple.
    active_same = best_det_overlap(matcher, gt, active_dets, label_id=class_id)
    low_same = best_det_overlap(
        matcher, gt, all_dets, label_id=class_id, conf_max=conf_threshold
    )
    other_active = best_det_overlap(
        matcher, gt, active_dets, exclude_label_id=class_id
    )

    # Ordered rules: (category, search result, minimum IoU). First hit wins.
    rules = (
        ("class_confusion", other_active, iou_threshold),
        ("low_score", low_same, iou_threshold),
        ("localization", active_same, near_iou_threshold),
        ("low_score_localization", low_same, near_iou_threshold),
    )
    error_type = "missing"
    best_det_iou, best_det_idx, best_det = 0.0, -1, None
    for category, found, minimum_iou in rules:
        if found[0] >= minimum_iou:
            error_type = category
            # The index is a position in whichever list that search scanned
            # (``active_dets`` or ``all_dets``).
            best_det_iou, best_det_idx, best_det = found
            break

    gt_box = gt.get("bbox_2d", [])
    record = {
        "case_name": case_key,
        "frame_name": frame_name,
        "class_id": class_id,
        "class_name": class_name(class_id),
        "error_type": error_type,
        "gt_id": object_id(gt, "gt", gt_idx),
        "gt_bbox": [round_float(coord) for coord in gt_box],
        "gt_bbox_area": round_float(bbox_area(gt_box)),
        "best_same_active_iou": round_float(active_same[0]),
        "best_same_low_score_iou": round_float(low_same[0]),
        "best_other_class_iou": round_float(other_active[0]),
        "best_det_iou": round_float(best_det_iou),
    }

    if best_det is None:
        record["best_det_id"] = None
        record["best_det_class"] = None
        record["best_det_confidence"] = None
        record["best_det_bbox"] = None
    else:
        record["best_det_id"] = object_id(best_det, "det", best_det_idx)
        record["best_det_class"] = class_name(best_det["label"])
        record["best_det_confidence"] = round_float(best_det.get("confidence", 0.0))
        record["best_det_bbox"] = [
            round_float(coord) for coord in best_det.get("bbox_2d", [])
        ]

    record.update(extract_3d_meta(gt, coord_system))
    return record
|
|
|
|
|
|
def rank_fp_detail(detail):
    """Return the sort key of an FP record: (confidence, strongest GT overlap, box area)."""
    same_iou = detail["best_same_class_iou"]
    other_iou = detail["best_other_class_iou"]
    strongest = other_iou if other_iou > same_iou else same_iou
    return (detail["confidence"], strongest, detail["det_bbox_area"])
|
|
|
|
|
|
def rank_fn_detail(detail):
    """Return the sort key of an FN record: (best detection confidence, its IoU, GT box area).

    A missing (None) or zero confidence is ranked as ``0.0``.
    """
    confidence = detail["best_det_confidence"]
    if not confidence:
        confidence = 0.0
    return (confidence, detail["best_det_iou"], detail["gt_bbox_area"])
|
|
|
|
|
|
def push_limited_detail(heap_by_type, seq_counter, detail, limit_per_type, rank_fn):
    """Offer ``detail`` to the bounded min-heap of its ``error_type``.

    Each heap keeps at most ``limit_per_type`` entries of the form
    ``(rank, sequence, detail)``; once full, a new entry replaces the current
    minimum only when its rank is strictly greater. Nothing happens when the
    limit is None or not positive.

    Args:
        heap_by_type: Mapping of error type to heap list; read with
            ``heap_by_type[error_type]``, so a ``defaultdict(list)`` is
            expected for unseen types.
        seq_counter: Iterator yielding increasing integers; used as a
            tie-breaker so the detail dicts themselves are never compared.
        detail: Record to store; its ``error_type`` defaults to ``"unknown"``.
        limit_per_type: Maximum heap size per error type.
        rank_fn: Callable mapping ``detail`` to a comparable rank.
    """
    if limit_per_type is None or limit_per_type <= 0:
        return

    kind = detail.get("error_type", "unknown")
    score = rank_fn(detail)
    # The sequence number is consumed even if the entry ends up discarded.
    candidate = (score, next(seq_counter), detail)
    bucket = heap_by_type[kind]

    if len(bucket) >= limit_per_type:
        if score > bucket[0][0]:
            heapq.heapreplace(bucket, candidate)
        return
    heapq.heappush(bucket, candidate)
|
|
|
|
|
|
def heaps_to_sorted_examples(heap_by_type):
    """Flatten per-type heaps into one list of detail records.

    Error types appear in sorted order; within a type, entries are ordered by
    descending rank (the sort is stable, so equal ranks keep heap order).
    """
    return [
        detail
        for kind in sorted(heap_by_type)
        for _rank, _seq, detail in sorted(
            heap_by_type[kind], key=lambda entry: entry[0], reverse=True
        )
    ]
|
|
|
|
|
|
def analyze_frame_worker(
    pair,
    class_ids,
    coord_system,
    conf_threshold,
    iou_threshold,
    near_iou_threshold,
):
    """Analyze a single frame and return its FP/FN statistics and records.

    Ground truths and detections are parsed through the ``Evaluator`` helpers.
    Detections with confidence at or above ``conf_threshold`` are matched per
    class with ``Matcher2D``. Every unmatched detection is classified as an
    FP and every unmatched GT as an FN.

    Args:
        pair: Frame descriptor; ``frame`` and ``case`` (plus optional
            ``level1_name``) are read here, the rest by the parsers.
        class_ids: Class IDs to analyze.
        coord_system: Coordinate system forwarded to the parsers and to the
            3D metadata extraction.
        conf_threshold: Operating confidence threshold.
        iou_threshold: Match IoU threshold.
        near_iou_threshold: Near-miss IoU threshold.

    Returns:
        A dict with ``case_name``, ``frame_name``, ``totals``,
        ``frame_stats`` (JSON-friendly per-frame summary), ``per_class``
        (counts plus ``Counter`` breakdowns per class), ``fp_details`` and
        ``fn_details``.
    """
    matcher = Matcher2D(iou_threshold=iou_threshold)
    case_key = build_case_key(pair)
    frame_name = pair["frame"]

    gts = Evaluator._parse_ground_truths_for_pair(pair, coord_system)
    all_dets = Evaluator._parse_detections_for_pair(pair, coord_system)
    # Detections without a confidence are treated as confidence 0.0.
    active_dets = [
        candidate
        for candidate in all_dets
        if float(candidate.get("confidence", 0.0)) >= conf_threshold
    ]

    gts_by_class = defaultdict(list)
    for gt in gts:
        gts_by_class[gt["label"]].append(gt)
    active_by_class = defaultdict(list)
    for det in active_dets:
        active_by_class[det["label"]].append(det)

    # Keyword arguments common to every FP and FN classification call.
    shared = dict(
        matcher=matcher,
        case_key=case_key,
        frame_name=frame_name,
        coord_system=coord_system,
        iou_threshold=iou_threshold,
        near_iou_threshold=near_iou_threshold,
    )

    totals = {
        "fp_total": 0,
        "fn_total": 0,
        "tp_total": 0,
        "gt_total": 0,
        "det_total_above_threshold": 0,
    }
    frame_fp_by_type = Counter()
    frame_fn_by_type = Counter()
    per_class = {}
    fp_records = []
    fn_records = []

    for class_id in class_ids:
        label = class_name(class_id)
        gts_same = gts_by_class.get(class_id, [])
        dets_same = active_by_class.get(class_id, [])

        match_result = matcher.match(gts_same, dets_same, class_id)
        matches = match_result["matches"]
        matched_gt_indices = {matched_gt for matched_gt, _, _ in matches}
        unmatched_det_indices = match_result["unmatched_dets"]
        unmatched_gt_indices = match_result["unmatched_gts"]

        tp_count = len(matches)
        fp_count = len(unmatched_det_indices)
        fn_count = len(unmatched_gt_indices)

        totals["tp_total"] += tp_count
        totals["fp_total"] += fp_count
        totals["fn_total"] += fn_count
        totals["gt_total"] += len(gts_same)
        totals["det_total_above_threshold"] += len(dets_same)

        # Unmatched indices refer to the matcher's own ``dets_sorted`` and
        # ``gts_filtered`` lists, not to the input lists.
        class_fp_records = [
            classify_fp_detail(
                det=match_result["dets_sorted"][det_position],
                det_idx=det_position,
                class_id=class_id,
                gts_same_class=match_result["gts_filtered"],
                matched_gt_indices=matched_gt_indices,
                all_gts=gts,
                **shared,
            )
            for det_position in unmatched_det_indices
        ]
        class_fp_by_type = Counter(
            record["error_type"] for record in class_fp_records
        )
        for kind, amount in class_fp_by_type.items():
            frame_fp_by_type[kind] += amount
        fp_records.extend(class_fp_records)

        class_fn_records = [
            classify_fn_detail(
                gt=match_result["gts_filtered"][gt_position],
                gt_idx=gt_position,
                class_id=class_id,
                active_dets=active_dets,
                all_dets=all_dets,
                conf_threshold=conf_threshold,
                **shared,
            )
            for gt_position in unmatched_gt_indices
        ]
        class_fn_by_type = Counter(
            record["error_type"] for record in class_fn_records
        )
        for kind, amount in class_fn_by_type.items():
            frame_fn_by_type[kind] += amount
        fn_records.extend(class_fn_records)

        per_class[label] = {
            "gt_count": len(gts_same),
            "det_count_above_threshold": len(dets_same),
            "tp_count": tp_count,
            "fp_count": fp_count,
            "fn_count": fn_count,
            "fp_by_type": class_fp_by_type,
            "fn_by_type": class_fn_by_type,
        }

    # The per-frame summary only carries the plain counts per class.
    count_keys = (
        "gt_count",
        "det_count_above_threshold",
        "tp_count",
        "fp_count",
        "fn_count",
    )
    frame_stats = {
        "case_name": case_key,
        "frame_name": frame_name,
        "fp_count": totals["fp_total"],
        "fn_count": totals["fn_total"],
        "tp_count": totals["tp_total"],
        "total_errors": totals["fp_total"] + totals["fn_total"],
        "fp_by_type": dict(sorted(frame_fp_by_type.items())),
        "fn_by_type": dict(sorted(frame_fn_by_type.items())),
        "per_class": {
            label: {key: stats[key] for key in count_keys}
            for label, stats in per_class.items()
        },
    }

    return {
        "case_name": case_key,
        "frame_name": frame_name,
        "totals": totals,
        "frame_stats": frame_stats,
        "per_class": per_class,
        "fp_details": fp_records,
        "fn_details": fn_records,
    }
|
|
|
|
|
|
class Analyze2DFPFN:
    """Analyze 2D FP/FN patterns at a fixed operating point."""

    def __init__(
        self,
        config,
        class_ids,
        near_iou_threshold=0.1,
        num_workers=1,
        max_frames=None,
        max_fp_details=1000,
        max_fn_details=1000,
        top_k_frames=50,
    ):
        """Store the analysis settings and create the matcher and evaluator.

        Args:
            config: Config dict; ``matching.iou_threshold``,
                ``metrics_2d.conf_threshold`` and
                ``metrics_3d.coordinate_system`` are read here (defaults 0.5,
                0.5 and ``"camera"``).
            class_ids: Class IDs to analyze.
            near_iou_threshold: Near-miss IoU threshold.
            num_workers: Worker count; None means ``cpu_count() - 1``. The
                value is always clamped to at least 1.
            max_frames: If not None, slice bound applied to the loaded frames.
            max_fp_details: Per-error-type cap on stored FP examples.
            max_fn_details: Per-error-type cap on stored FN examples.
            top_k_frames: Number of worst frames reported in ``top_frames``.
        """
        if num_workers is None:
            resolved_workers = max(1, cpu_count() - 1)
        else:
            resolved_workers = max(1, int(num_workers))

        self.config = config
        self.class_ids = class_ids
        self.near_iou_threshold = near_iou_threshold
        self.max_frames = max_frames
        self.max_fp_details = max_fp_details
        self.max_fn_details = max_fn_details
        self.top_k_frames = top_k_frames
        self.num_workers = resolved_workers

        matching_cfg = config.get("matching", {})
        metrics_2d_cfg = config.get("metrics_2d", {})
        metrics_3d_cfg = config.get("metrics_3d", {})
        self.iou_threshold = float(matching_cfg.get("iou_threshold", 0.5))
        self.conf_threshold = float(metrics_2d_cfg.get("conf_threshold", 0.5))
        self.coord_system = metrics_3d_cfg.get("coordinate_system", "camera")

        self.matcher = Matcher2D(iou_threshold=self.iou_threshold)
        self.evaluator = Evaluator(
            config=config,
            iou_threshold=self.iou_threshold,
            num_workers=self.num_workers,
            save_detailed_matches=False,
        )

    def _init_class_stats(self):
        """Return a zeroed per-class statistics dict keyed by class name."""
        return {
            class_name(class_id): {
                "class_id": class_id,
                "gt_count": 0,
                "det_count_above_threshold": 0,
                "tp_count": 0,
                "fp_count": 0,
                "fn_count": 0,
                "fp_by_type": Counter(),
                "fn_by_type": Counter(),
            }
            for class_id in self.class_ids
        }

    def _merge_frame_result(self, summary, frame_stats, frame_result, fp_heaps, fn_heaps, seq_counter):
        """Fold one frame's result into the running summary, in place.

        Adds the frame totals and per-type / per-class counts to ``summary``,
        appends the frame's stats to ``frame_stats``, and offers every FP/FN
        record to the bounded example heaps.
        """

        def add_counts(target, increments):
            for key, amount in increments.items():
                target[key] += amount

        frame_totals = frame_result["totals"]
        for key in (
            "fp_total",
            "fn_total",
            "tp_total",
            "gt_total",
            "det_total_above_threshold",
        ):
            summary[key] += frame_totals[key]

        per_frame = frame_result["frame_stats"]
        frame_stats.append(per_frame)
        add_counts(summary["fp_by_type"], per_frame["fp_by_type"])
        add_counts(summary["fn_by_type"], per_frame["fn_by_type"])

        for label, class_result in frame_result["per_class"].items():
            accumulated = summary["per_class"][label]
            for key in (
                "gt_count",
                "det_count_above_threshold",
                "tp_count",
                "fp_count",
                "fn_count",
            ):
                accumulated[key] += class_result[key]
            add_counts(accumulated["fp_by_type"], class_result["fp_by_type"])
            add_counts(accumulated["fn_by_type"], class_result["fn_by_type"])

        for record in frame_result["fp_details"]:
            push_limited_detail(
                fp_heaps, seq_counter, record, self.max_fp_details, rank_fp_detail
            )
        for record in frame_result["fn_details"]:
            push_limited_detail(
                fn_heaps, seq_counter, record, self.max_fn_details, rank_fn_detail
            )

    def analyze(self):
        """Load the dataset, analyze every frame, and return the full report.

        Frames are processed in a ``multiprocessing.Pool`` when more than one
        worker and more than one frame are available, otherwise serially.

        Returns:
            A dict with ``metadata``, ``summary``, ``top_frames``,
            ``all_frame_stats``, ``false_positive_examples`` and
            ``false_negative_examples``.
        """
        dataset_cfg = self.config["dataset"]
        image_cfg = self.config["image"]

        self.evaluator.load_data_from_paths(
            det_root=dataset_cfg["det_path"],
            gt_root=dataset_cfg["gt_path"],
            img_width=image_cfg.get("width", 1920),
            img_height=image_cfg.get("height", 1080),
            path_depth=dataset_cfg.get("path_depth", 1),
            det_format=dataset_cfg.get("det_format", "auto"),
            gt_format=dataset_cfg.get("gt_format", "auto"),
        )

        image_pairs = self.evaluator.image_pairs
        if self.max_frames is not None:
            image_pairs = image_pairs[: self.max_frames]
        frame_total = len(image_pairs)

        summary = {
            "num_frames": frame_total,
            "num_cases": len({build_case_key(pair) for pair in image_pairs}),
            "classes": [class_name(class_id) for class_id in self.class_ids],
            "fp_total": 0,
            "fn_total": 0,
            "tp_total": 0,
            "gt_total": 0,
            "det_total_above_threshold": 0,
            "fp_by_type": Counter(),
            "fn_by_type": Counter(),
            "per_class": self._init_class_stats(),
        }

        frame_stats = []
        fp_heaps = defaultdict(list)
        fn_heaps = defaultdict(list)
        seq_counter = count()

        def absorb(frame_result):
            self._merge_frame_result(
                summary, frame_stats, frame_result, fp_heaps, fn_heaps, seq_counter
            )

        worker = partial(
            analyze_frame_worker,
            class_ids=self.class_ids,
            coord_system=self.coord_system,
            conf_threshold=self.conf_threshold,
            iou_threshold=self.iou_threshold,
            near_iou_threshold=self.near_iou_threshold,
        )

        run_parallel = self.num_workers > 1 and frame_total > 1
        if run_parallel:
            # Aim for roughly eight chunks per worker, never below one frame.
            chunksize = max(1, frame_total // max(self.num_workers * 8, 1))
            with Pool(processes=self.num_workers) as pool:
                results = pool.imap(worker, image_pairs, chunksize=chunksize)
                for frame_result in tqdm(results, total=frame_total, desc="Analyzing 2D FP/FN"):
                    absorb(frame_result)
        else:
            for pair in tqdm(image_pairs, desc="Analyzing 2D FP/FN"):
                absorb(worker(pair))

        # Replace the Counters with plain dicts sorted by error type.
        for class_stats in summary["per_class"].values():
            for key in ("fp_by_type", "fn_by_type"):
                class_stats[key] = dict(sorted(class_stats[key].items()))
        for key in ("fp_by_type", "fn_by_type"):
            summary[key] = dict(sorted(summary[key].items()))

        # Worst frames first: most errors, then most FNs, then most FPs.
        frame_stats.sort(
            key=lambda stats: (
                stats["total_errors"],
                stats["fn_count"],
                stats["fp_count"],
            ),
            reverse=True,
        )

        metadata = {
            "created_at": datetime.now().isoformat(timespec="seconds"),
            "det_path": dataset_cfg["det_path"],
            "gt_path": dataset_cfg["gt_path"],
            "path_depth": dataset_cfg.get("path_depth", 1),
            "det_format": dataset_cfg.get("det_format", "auto"),
            "gt_format": dataset_cfg.get("gt_format", "auto"),
            "image_width": image_cfg.get("width", 1920),
            "image_height": image_cfg.get("height", 1080),
            "coord_system": self.coord_system,
            "iou_threshold": self.iou_threshold,
            "conf_threshold": self.conf_threshold,
            "near_iou_threshold": self.near_iou_threshold,
            "max_frames": self.max_frames,
            "num_workers": self.num_workers,
            "max_fp_details_per_type": self.max_fp_details,
            "max_fn_details_per_type": self.max_fn_details,
        }

        return {
            "metadata": metadata,
            "summary": summary,
            "top_frames": frame_stats[: self.top_k_frames],
            "all_frame_stats": frame_stats,
            "false_positive_examples": heaps_to_sorted_examples(fp_heaps),
            "false_negative_examples": heaps_to_sorted_examples(fn_heaps),
        }
|
|
|
|
|
|
def write_markdown_report(report, output_path):
    """Write a compact human-readable Markdown summary of ``report``.

    The file contains, in order: configuration table, overall counts, FP and
    FN counts by error type, a per-class table (sorted by class name) and the
    list of top frames.

    Args:
        report: Report dict; ``metadata``, ``summary`` and ``top_frames`` are
            read.
        output_path: Destination file, opened for writing in text mode.
    """
    metadata = report["metadata"]
    summary = report["summary"]

    def count_table(title, counts):
        # One "error type -> count" section, terminated by a blank line.
        yield f"## {title}\n\n"
        yield "| Error Type | Count |\n"
        yield "| --- | ---: |\n"
        for error_type, occurrences in counts.items():
            yield f"| `{error_type}` | {occurrences} |\n"
        yield "\n"

    def render():
        # Lazily yields the document piece by piece, in output order.
        yield "# 2D FP/FN Analysis Report\n\n"

        yield "## Configuration\n\n"
        yield "| Item | Value |\n"
        yield "| --- | --- |\n"
        yield f"| Detection path | `{metadata['det_path']}` |\n"
        yield f"| Ground-truth path | `{metadata['gt_path']}` |\n"
        yield f"| Coordinate system | `{metadata['coord_system']}` |\n"
        yield f"| IoU threshold | `{metadata['iou_threshold']:.3f}` |\n"
        yield f"| Confidence threshold | `{metadata['conf_threshold']:.3f}` |\n"
        yield f"| Near-miss IoU | `{metadata['near_iou_threshold']:.3f}` |\n"
        yield f"| Frames analyzed | `{summary['num_frames']}` |\n"
        yield f"| Cases analyzed | `{summary['num_cases']}` |\n"
        yield f"| Classes | `{', '.join(summary['classes'])}` |\n"
        yield f"| Num workers | `{metadata['num_workers']}` |\n\n"

        yield "## Overall Counts\n\n"
        yield "| Metric | Value |\n"
        yield "| --- | ---: |\n"
        yield f"| GT total | {summary['gt_total']} |\n"
        yield f"| Det total @ threshold | {summary['det_total_above_threshold']} |\n"
        yield f"| TP total | {summary['tp_total']} |\n"
        yield f"| FP total | {summary['fp_total']} |\n"
        yield f"| FN total | {summary['fn_total']} |\n\n"

        yield from count_table("FP By Type", summary["fp_by_type"])
        yield from count_table("FN By Type", summary["fn_by_type"])

        yield "## Per-Class Summary\n\n"
        yield "| Class | GT | Det | TP | FP | FN |\n"
        yield "| --- | ---: | ---: | ---: | ---: | ---: |\n"
        for label, stats in sorted(report["summary"]["per_class"].items()):
            yield (
                f"| `{label}` | {stats['gt_count']} | "
                f"{stats['det_count_above_threshold']} | {stats['tp_count']} | "
                f"{stats['fp_count']} | {stats['fn_count']} |\n"
            )
        yield "\n"

        yield "## Top Frames\n\n"
        yield "| Case / Frame | Errors | FN | FP | TP |\n"
        yield "| --- | ---: | ---: | ---: | ---: |\n"
        for frame in report["top_frames"]:
            yield (
                f"| `{frame['case_name']}/{frame['frame_name']}` | "
                f"{frame['total_errors']} | {frame['fn_count']} | "
                f"{frame['fp_count']} | {frame['tp_count']} |\n"
            )

    with open(output_path, "w") as file:
        file.writelines(render())
|
|
|
|
|
|
def ensure_output_dir(path_str):
    """Resolve the output directory, create it if needed, and return it.

    Args:
        path_str: Directory to use. When falsy, a timestamped directory under
            ``REPO_ROOT/eval_tools/analysis/results`` is used instead.

    Returns:
        The output directory as a ``Path``; parents are created as needed and
        an existing directory is accepted.
    """
    if not path_str:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        target = REPO_ROOT / "eval_tools" / "analysis" / "results" / stamp
    else:
        target = Path(path_str)
    target.mkdir(parents=True, exist_ok=True)
    return target
|
|
|
|
|
|
def main():
    """Run the analysis from the command line and write both reports.

    Parses the CLI, builds the config, runs ``Analyze2DFPFN.analyze`` and then
    writes ``analysis_report.json`` and ``analysis_report.md`` into the output
    directory, printing both paths.
    """
    args = parse_args()

    analyzer = Analyze2DFPFN(
        config=build_config(args),
        class_ids=parse_class_ids(args.classes),
        near_iou_threshold=args.near_iou_threshold,
        num_workers=args.num_workers,
        max_frames=args.max_frames,
        max_fp_details=args.max_fp_details,
        max_fn_details=args.max_fn_details,
        top_k_frames=args.top_k_frames,
    )
    report = analyzer.analyze()

    # The output directory is only created once the analysis has finished.
    output_dir = ensure_output_dir(args.output_dir)
    json_path = output_dir / "analysis_report.json"
    md_path = output_dir / "analysis_report.md"

    with json_path.open("w") as handle:
        json.dump(report, handle, indent=2)
    write_markdown_report(report, md_path)

    print(f"\nJSON report saved to: {json_path}")
    print(f"Markdown report saved to: {md_path}")


if __name__ == "__main__":
    main()
|