单目3D初始代码
This commit is contained in:
155
eval_tools/evaluator/matcher.py
Executable file
155
eval_tools/evaluator/matcher.py
Executable file
@@ -0,0 +1,155 @@
|
||||
"""
|
||||
2D matching module for ground truth and detection boxes.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Matcher2D:
|
||||
"""Match 2D bounding boxes between ground truth and detections."""
|
||||
|
||||
def __init__(self, iou_threshold=0.5):
|
||||
"""
|
||||
Initialize matcher.
|
||||
|
||||
Args:
|
||||
iou_threshold: float, IoU threshold for matching
|
||||
"""
|
||||
self.iou_threshold = iou_threshold
|
||||
|
||||
def compute_iou(self, box1, box2):
|
||||
"""
|
||||
Compute IoU between two boxes.
|
||||
|
||||
Args:
|
||||
box1: [x1, y1, x2, y2]
|
||||
box2: [x1, y1, x2, y2]
|
||||
|
||||
Returns:
|
||||
float, IoU value
|
||||
"""
|
||||
x1 = max(box1[0], box2[0])
|
||||
y1 = max(box1[1], box2[1])
|
||||
x2 = min(box1[2], box2[2])
|
||||
y2 = min(box1[3], box2[3])
|
||||
|
||||
if x2 < x1 or y2 < y1:
|
||||
return 0.0
|
||||
|
||||
intersection = (x2 - x1) * (y2 - y1)
|
||||
area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
||||
area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
||||
union = area1 + area2 - intersection
|
||||
|
||||
return intersection / union if union > 0 else 0.0
|
||||
|
||||
@staticmethod
|
||||
def _normalize_roi_id(roi_id):
|
||||
"""Normalize ROI identifiers like 'roi0'/'0' to plain numeric strings."""
|
||||
if roi_id is None:
|
||||
return None
|
||||
|
||||
roi_id_str = str(roi_id).strip().lower()
|
||||
if roi_id_str.startswith('roi'):
|
||||
roi_id_str = roi_id_str[3:]
|
||||
return roi_id_str or None
|
||||
|
||||
def compute_pair_iou(self, gt, det):
|
||||
"""Compute IoU for one GT/detection pair, honoring ROI-specific GT boxes when present."""
|
||||
gt_boxes_by_roi = gt.get('bbox_2d_by_roi')
|
||||
det_box = det.get('bbox_2d')
|
||||
if det_box is None:
|
||||
return 0.0
|
||||
|
||||
if not gt_boxes_by_roi:
|
||||
gt_box = gt.get('bbox_2d')
|
||||
return self.compute_iou(gt_box, det_box) if gt_box is not None else 0.0
|
||||
|
||||
det_roi_id = self._normalize_roi_id(det.get('roi_id'))
|
||||
if det_roi_id is not None:
|
||||
gt_box = gt_boxes_by_roi.get(det_roi_id)
|
||||
return self.compute_iou(gt_box, det_box) if gt_box is not None else 0.0
|
||||
|
||||
return max((self.compute_iou(gt_box, det_box) for gt_box in gt_boxes_by_roi.values()), default=0.0)
|
||||
|
||||
def compute_iou_matrix(self, gts, dets):
|
||||
"""
|
||||
Compute IoU matrix between all GT and detection pairs.
|
||||
|
||||
Args:
|
||||
gts: list of ground truth dicts
|
||||
dets: list of detection dicts
|
||||
|
||||
Returns:
|
||||
numpy array of shape (len(gts), len(dets))
|
||||
"""
|
||||
if len(gts) == 0 or len(dets) == 0:
|
||||
return np.zeros((len(gts), len(dets)))
|
||||
|
||||
iou_matrix = np.zeros((len(gts), len(dets)))
|
||||
|
||||
for i, gt in enumerate(gts):
|
||||
for j, det in enumerate(dets):
|
||||
iou_matrix[i, j] = self.compute_pair_iou(gt, det)
|
||||
|
||||
return iou_matrix
|
||||
|
||||
def match(self, gts, dets, class_id):
|
||||
"""
|
||||
Match detections to ground truths using greedy algorithm.
|
||||
|
||||
Args:
|
||||
gts: list of ground truth dicts for a specific class
|
||||
dets: list of detection dicts for a specific class
|
||||
class_id: int, class ID to match
|
||||
|
||||
Returns:
|
||||
dict with keys:
|
||||
- matches: list of (gt_idx, det_idx, iou) tuples
|
||||
- unmatched_gts: list of unmatched GT indices
|
||||
- unmatched_dets: list of unmatched detection indices
|
||||
"""
|
||||
# Filter by class
|
||||
gts_filtered = [gt for gt in gts if gt['label'] == class_id]
|
||||
dets_filtered = [det for det in dets if det['label'] == class_id]
|
||||
|
||||
# Sort detections by confidence (highest first)
|
||||
det_indices = np.argsort([-det['confidence'] for det in dets_filtered])
|
||||
dets_sorted = [dets_filtered[i] for i in det_indices]
|
||||
|
||||
# Compute IoU matrix
|
||||
iou_matrix = self.compute_iou_matrix(gts_filtered, dets_sorted)
|
||||
|
||||
matches = []
|
||||
matched_gt_indices = set()
|
||||
matched_det_indices = set()
|
||||
|
||||
# Greedy matching: for each detection (sorted by confidence), find best GT
|
||||
for det_idx in range(len(dets_sorted)):
|
||||
best_iou = 0.0
|
||||
best_gt_idx = -1
|
||||
|
||||
for gt_idx in range(len(gts_filtered)):
|
||||
if gt_idx in matched_gt_indices:
|
||||
continue
|
||||
|
||||
iou = iou_matrix[gt_idx, det_idx]
|
||||
if iou >= self.iou_threshold and iou > best_iou:
|
||||
best_iou = iou
|
||||
best_gt_idx = gt_idx
|
||||
|
||||
if best_gt_idx >= 0:
|
||||
matches.append((best_gt_idx, det_idx, best_iou))
|
||||
matched_gt_indices.add(best_gt_idx)
|
||||
matched_det_indices.add(det_idx)
|
||||
|
||||
# Find unmatched GTs and detections
|
||||
unmatched_gts = [i for i in range(len(gts_filtered)) if i not in matched_gt_indices]
|
||||
unmatched_dets = [i for i in range(len(dets_sorted)) if i not in matched_det_indices]
|
||||
|
||||
return {
|
||||
'matches': matches,
|
||||
'unmatched_gts': unmatched_gts,
|
||||
'unmatched_dets': unmatched_dets,
|
||||
'gts_filtered': gts_filtered,
|
||||
'dets_sorted': dets_sorted
|
||||
}
|
||||
Reference in New Issue
Block a user