"""
2D matching module for ground truth and detection boxes.
"""

import numpy as np


class Matcher2D:
    """Match 2D bounding boxes between ground truth and detections.

    Boxes are ``[x1, y1, x2, y2]`` sequences. Ground-truth and detection
    records are dicts; the keys read by this class are ``label``,
    ``bbox_2d``, ``bbox_2d_by_roi`` (ground truth), and ``confidence`` /
    ``roi_id`` (detections).
    """

    def __init__(self, iou_threshold=0.5):
        """
        Initialize matcher.

        Args:
            iou_threshold: float, minimum IoU for a GT/detection pair to be
                accepted as a match (the comparison is ``>=``)
        """
        self.iou_threshold = iou_threshold

    def compute_iou(self, box1, box2):
        """
        Compute IoU between two boxes.

        Args:
            box1: [x1, y1, x2, y2]
            box2: [x1, y1, x2, y2]

        Returns:
            float, IoU value; 0.0 when the boxes do not overlap or when the
            union area is not positive
        """
        # Corners of the intersection rectangle.
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])

        if x2 < x1 or y2 < y1:
            return 0.0

        intersection = (x2 - x1) * (y2 - y1)
        area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
        area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
        union = area1 + area2 - intersection

        # Guard against degenerate (zero-area) boxes.
        return intersection / union if union > 0 else 0.0

    @staticmethod
    def _normalize_roi_id(roi_id):
        """Normalize ROI identifiers like 'roi0'/'0' to plain numeric strings.

        The value is stringified, stripped, lower-cased, and a leading
        ``'roi'`` prefix is removed. Returns None for None input or when
        nothing is left after normalization.
        """
        if roi_id is None:
            return None

        roi_id_str = str(roi_id).strip().lower()
        if roi_id_str.startswith('roi'):
            roi_id_str = roi_id_str[3:]
        return roi_id_str or None

    def compute_pair_iou(self, gt, det):
        """Compute IoU for one GT/detection pair, honoring ROI-specific GT boxes when present.

        Args:
            gt: ground truth dict; reads ``bbox_2d_by_roi`` (mapping of ROI id
                to box) and, when that is missing/empty, ``bbox_2d``
            det: detection dict; reads ``bbox_2d`` and ``roi_id``

        Returns:
            float, IoU value. 0.0 when the detection has no box, when the GT
            has no usable box, or when the detection names an ROI for which
            the GT has no box. When the detection carries no ROI id, the best
            IoU over all per-ROI GT boxes is returned.
        """
        gt_boxes_by_roi = gt.get('bbox_2d_by_roi')
        det_box = det.get('bbox_2d')
        if det_box is None:
            return 0.0

        if not gt_boxes_by_roi:
            gt_box = gt.get('bbox_2d')
            return self.compute_iou(gt_box, det_box) if gt_box is not None else 0.0

        det_roi_id = self._normalize_roi_id(det.get('roi_id'))
        if det_roi_id is not None:
            # Exact lookup first (keys already stored in normalized form).
            gt_box = gt_boxes_by_roi.get(det_roi_id)
            if gt_box is None:
                # Keys such as 'roi0' or the int 0 would otherwise never match
                # the normalized detection id, so compare normalized keys too.
                for roi_key, candidate in gt_boxes_by_roi.items():
                    if candidate is not None and self._normalize_roi_id(roi_key) == det_roi_id:
                        gt_box = candidate
                        break
            return self.compute_iou(gt_box, det_box) if gt_box is not None else 0.0

        # No ROI on the detection: take the best overlap over all ROI boxes,
        # skipping ROIs whose box is None.
        return max(
            (
                self.compute_iou(gt_box, det_box)
                for gt_box in gt_boxes_by_roi.values()
                if gt_box is not None
            ),
            default=0.0,
        )

    def compute_iou_matrix(self, gts, dets):
        """
        Compute IoU matrix between all GT and detection pairs.

        Args:
            gts: list of ground truth dicts
            dets: list of detection dicts

        Returns:
            numpy array of shape (len(gts), len(dets)); all zeros (possibly
            with a zero-length axis) when either list is empty
        """
        iou_matrix = np.zeros((len(gts), len(dets)))

        for i, gt in enumerate(gts):
            for j, det in enumerate(dets):
                iou_matrix[i, j] = self.compute_pair_iou(gt, det)

        return iou_matrix

    def match(self, gts, dets, class_id):
        """
        Match detections to ground truths using greedy algorithm.

        Detections are visited in descending confidence order; each takes the
        still-unmatched GT with the highest IoU, provided that IoU is at
        least ``self.iou_threshold``.

        Args:
            gts: list of ground truth dicts (each must have 'label')
            dets: list of detection dicts (each must have 'label' and
                'confidence')
            class_id: int, class ID to match; records with another label are
                ignored

        Returns:
            dict with keys:
                - matches: list of (gt_idx, det_idx, iou) tuples
                - unmatched_gts: list of unmatched GT indices
                - unmatched_dets: list of unmatched detection indices
                - gts_filtered: GTs of the requested class
                - dets_sorted: detections of the requested class, sorted by
                  descending confidence
            All GT indices refer to ``gts_filtered`` and all detection
            indices refer to ``dets_sorted``, not to the input lists.
        """
        # Filter by class
        gts_filtered = [gt for gt in gts if gt['label'] == class_id]
        dets_filtered = [det for det in dets if det['label'] == class_id]

        # Sort detections by confidence (highest first). A stable sort keeps
        # equal-confidence detections in input order so results are
        # reproducible.
        det_order = np.argsort(
            [-det['confidence'] for det in dets_filtered], kind='stable'
        )
        dets_sorted = [dets_filtered[i] for i in det_order]

        # Compute IoU matrix
        iou_matrix = self.compute_iou_matrix(gts_filtered, dets_sorted)

        matches = []
        matched_gt_indices = set()
        matched_det_indices = set()

        # Greedy matching: for each detection (sorted by confidence), find best GT
        for det_idx, gt_ious in enumerate(iou_matrix.T):
            best_iou = 0.0
            best_gt_idx = -1

            for gt_idx, iou in enumerate(gt_ious):
                if gt_idx in matched_gt_indices:
                    continue

                # Strict '>' keeps the first GT on IoU ties.
                if iou >= self.iou_threshold and iou > best_iou:
                    best_iou = iou
                    best_gt_idx = gt_idx

            if best_gt_idx >= 0:
                matches.append((best_gt_idx, det_idx, best_iou))
                matched_gt_indices.add(best_gt_idx)
                matched_det_indices.add(det_idx)

        # Find unmatched GTs and detections
        unmatched_gts = [
            i for i in range(len(gts_filtered)) if i not in matched_gt_indices
        ]
        unmatched_dets = [
            i for i in range(len(dets_sorted)) if i not in matched_det_indices
        ]

        return {
            'matches': matches,
            'unmatched_gts': unmatched_gts,
            'unmatched_dets': unmatched_dets,
            'gts_filtered': gts_filtered,
            'dets_sorted': dets_sorted,
        }
|