单目3D初始代码

This commit is contained in:
zhao.zhu
2026-06-24 09:35:46 +08:00
commit 04a5895b6b
1153 changed files with 340700 additions and 0 deletions

155
eval_tools/evaluator/matcher.py Executable file
View File

@@ -0,0 +1,155 @@
"""
2D matching module for ground truth and detection boxes.
"""
import numpy as np
class Matcher2D:
"""Match 2D bounding boxes between ground truth and detections."""
def __init__(self, iou_threshold=0.5):
"""
Initialize matcher.
Args:
iou_threshold: float, IoU threshold for matching
"""
self.iou_threshold = iou_threshold
def compute_iou(self, box1, box2):
"""
Compute IoU between two boxes.
Args:
box1: [x1, y1, x2, y2]
box2: [x1, y1, x2, y2]
Returns:
float, IoU value
"""
x1 = max(box1[0], box2[0])
y1 = max(box1[1], box2[1])
x2 = min(box1[2], box2[2])
y2 = min(box1[3], box2[3])
if x2 < x1 or y2 < y1:
return 0.0
intersection = (x2 - x1) * (y2 - y1)
area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
union = area1 + area2 - intersection
return intersection / union if union > 0 else 0.0
@staticmethod
def _normalize_roi_id(roi_id):
"""Normalize ROI identifiers like 'roi0'/'0' to plain numeric strings."""
if roi_id is None:
return None
roi_id_str = str(roi_id).strip().lower()
if roi_id_str.startswith('roi'):
roi_id_str = roi_id_str[3:]
return roi_id_str or None
def compute_pair_iou(self, gt, det):
"""Compute IoU for one GT/detection pair, honoring ROI-specific GT boxes when present."""
gt_boxes_by_roi = gt.get('bbox_2d_by_roi')
det_box = det.get('bbox_2d')
if det_box is None:
return 0.0
if not gt_boxes_by_roi:
gt_box = gt.get('bbox_2d')
return self.compute_iou(gt_box, det_box) if gt_box is not None else 0.0
det_roi_id = self._normalize_roi_id(det.get('roi_id'))
if det_roi_id is not None:
gt_box = gt_boxes_by_roi.get(det_roi_id)
return self.compute_iou(gt_box, det_box) if gt_box is not None else 0.0
return max((self.compute_iou(gt_box, det_box) for gt_box in gt_boxes_by_roi.values()), default=0.0)
def compute_iou_matrix(self, gts, dets):
"""
Compute IoU matrix between all GT and detection pairs.
Args:
gts: list of ground truth dicts
dets: list of detection dicts
Returns:
numpy array of shape (len(gts), len(dets))
"""
if len(gts) == 0 or len(dets) == 0:
return np.zeros((len(gts), len(dets)))
iou_matrix = np.zeros((len(gts), len(dets)))
for i, gt in enumerate(gts):
for j, det in enumerate(dets):
iou_matrix[i, j] = self.compute_pair_iou(gt, det)
return iou_matrix
def match(self, gts, dets, class_id):
"""
Match detections to ground truths using greedy algorithm.
Args:
gts: list of ground truth dicts for a specific class
dets: list of detection dicts for a specific class
class_id: int, class ID to match
Returns:
dict with keys:
- matches: list of (gt_idx, det_idx, iou) tuples
- unmatched_gts: list of unmatched GT indices
- unmatched_dets: list of unmatched detection indices
"""
# Filter by class
gts_filtered = [gt for gt in gts if gt['label'] == class_id]
dets_filtered = [det for det in dets if det['label'] == class_id]
# Sort detections by confidence (highest first)
det_indices = np.argsort([-det['confidence'] for det in dets_filtered])
dets_sorted = [dets_filtered[i] for i in det_indices]
# Compute IoU matrix
iou_matrix = self.compute_iou_matrix(gts_filtered, dets_sorted)
matches = []
matched_gt_indices = set()
matched_det_indices = set()
# Greedy matching: for each detection (sorted by confidence), find best GT
for det_idx in range(len(dets_sorted)):
best_iou = 0.0
best_gt_idx = -1
for gt_idx in range(len(gts_filtered)):
if gt_idx in matched_gt_indices:
continue
iou = iou_matrix[gt_idx, det_idx]
if iou >= self.iou_threshold and iou > best_iou:
best_iou = iou
best_gt_idx = gt_idx
if best_gt_idx >= 0:
matches.append((best_gt_idx, det_idx, best_iou))
matched_gt_indices.add(best_gt_idx)
matched_det_indices.add(det_idx)
# Find unmatched GTs and detections
unmatched_gts = [i for i in range(len(gts_filtered)) if i not in matched_gt_indices]
unmatched_dets = [i for i in range(len(dets_sorted)) if i not in matched_det_indices]
return {
'matches': matches,
'unmatched_gts': unmatched_gts,
'unmatched_dets': unmatched_dets,
'gts_filtered': gts_filtered,
'dets_sorted': dets_sorted
}