Files
yolov26_3d/eval_tools/model_comparison/find_common_matches.py
2026-06-24 09:35:46 +08:00

440 lines
17 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Find Common Matches Between Two Models
This tool finds the common GT objects that were successfully matched by both models,
enabling fair comparison of 3D prediction quality on the same set of targets.
Usage:
python eval_tools/model_comparison/find_common_matches.py \
--model1-matches eval_results/model1/detailed_3d_matches.json \
--model2-matches eval_results/model2/detailed_3d_matches.json \
--output common_matches.json
"""
import argparse
import json
import sys
from pathlib import Path
from collections import defaultdict
import numpy as np
def _new_match_counts():
    """Return a fresh, zeroed counter dict used for overall and per-class match stats."""
    return {
        'model1_total': 0,
        'model2_total': 0,
        'common': 0,
        'model1_unique': 0,
        'model2_unique': 0,
    }


def _add_common_percentages(counts):
    """Add ``common_percentage_of_model1`` / ``common_percentage_of_model2`` to *counts* in place.

    Each value is ``common / modelN_total * 100``, or 0 when that model has no matches
    (avoids a division by zero).
    """
    for model_key in ('model1', 'model2'):
        total = counts[f'{model_key}_total']
        counts[f'common_percentage_of_{model_key}'] = (counts['common'] / total) * 100 if total > 0 else 0


def _ordered_union(first, second):
    """Return the keys of *first*, followed by the keys of *second* not already in *first*."""
    return list(first) + [key for key in second if key not in first]


def find_common_matches(model1_matches, model2_matches):
    """
    Find common GT objects matched by both models.

    Every case/frame/class present in EITHER input is visited. Matches that only one model
    produced are counted in that model's totals and "unique" counts. This includes whole
    cases, frames or classes for which the other model has no entry. Only GT IDs matched
    by both models are recorded in ``common_matches``.

    Args:
        model1_matches: dict ``{case: {frame: {class: [match, ...]}}}`` from model 1;
            every match dict must contain a ``'gt_id'`` key.
        model2_matches: dict with the same structure from model 2.

    Returns:
        tuple: (common_matches, stats)
            common_matches: dict ``{case: {frame: {class: [match_info]}}}`` where each
                match_info is ``{'gt_id', 'model1_idx', 'model2_idx'}``. The two indices
                point into the respective per-class match lists. Entries follow model-1
                order. Case and frame keys are created only when they exist in both
                inputs; class keys only when at least one GT ID is shared.
            stats: dict with overall counts (``model1_total``, ``model2_total``,
                ``common``, ``model1_unique``, ``model2_unique``), a ``per_class`` dict
                holding the same counters per class, and
                ``common_percentage_of_model1/2`` at both levels.
    """
    common_matches = {}
    stats = _new_match_counts()
    stats['per_class'] = {}

    for case_name in _ordered_union(model1_matches, model2_matches):
        frames1 = model1_matches.get(case_name) or {}
        frames2 = model2_matches.get(case_name) or {}
        case_shared = case_name in model1_matches and case_name in model2_matches
        if case_shared:
            common_matches[case_name] = {}

        for frame_name in _ordered_union(frames1, frames2):
            classes1 = frames1.get(frame_name) or {}
            classes2 = frames2.get(frame_name) or {}
            frame_shared = case_shared and frame_name in frames1 and frame_name in frames2
            if frame_shared:
                common_matches[case_name][frame_name] = {}

            for class_name in _ordered_union(classes1, classes2):
                m1_list = classes1.get(class_name) or []
                m2_list = classes2.get(class_name) or []

                # GT ID -> index into that model's match list (last index wins on duplicates)
                m1_gt_ids = {m['gt_id']: i for i, m in enumerate(m1_list)}
                m2_gt_ids = {m['gt_id']: i for i, m in enumerate(m2_list)}

                # Walk model-1 order instead of a set so the output is deterministic.
                common_list = [
                    {'gt_id': gt_id, 'model1_idx': idx1, 'model2_idx': m2_gt_ids[gt_id]}
                    for gt_id, idx1 in m1_gt_ids.items()
                    if gt_id in m2_gt_ids
                ]
                n_common = len(common_list)

                class_counts = stats['per_class'].setdefault(class_name, _new_match_counts())
                for counts in (stats, class_counts):
                    counts['model1_total'] += len(m1_list)
                    counts['model2_total'] += len(m2_list)
                    counts['common'] += n_common
                    counts['model1_unique'] += len(m1_gt_ids) - n_common
                    counts['model2_unique'] += len(m2_gt_ids) - n_common

                # A non-empty common_list implies case and frame are shared, so the
                # parent dicts created above exist.
                if common_list:
                    common_matches[case_name][frame_name][class_name] = common_list

    _add_common_percentages(stats)
    for class_counts in stats['per_class'].values():
        _add_common_percentages(class_counts)
    return common_matches, stats
# Default distance ranges matching eval metrics_3d config.
# Longitudinal: ten 10 m bins from 0 to 100 m, then one wide far bin (100, 999).
DEFAULT_LONG_RANGES = [(start, start + 10) for start in range(0, 100, 10)] + [(100, 999)]
# Lateral: ten 10 m bins from -50 m to 50 m.
DEFAULT_LAT_RANGES = [(start, start + 10) for start in range(-50, 50, 10)]
def _range_key_long(lo, hi):
return f'long_{lo}-{hi}m'
def _range_key_lat(lo, hi):
return f'lat_{lo}-{hi}m'
def _make_stats(data_dict):
"""Compute mean/std/median/min/max for each list in data_dict."""
result = {}
for key, values in data_dict.items():
if key in ('samples',):
result[key] = values
elif isinstance(values, list) and len(values) > 0:
arr = np.array(values)
result[key] = {
'mean': float(np.mean(arr)),
'std': float(np.std(arr)),
'median': float(np.median(arr)),
'min': float(np.min(arr)),
'max': float(np.max(arr)),
}
return result
def _empty_bucket():
return {
'lateral': [], 'longitudinal': [], 'longitudinal_relative': [],
'heading': [], 'heading_relaxed': [], 'is_reversal': [], 'samples': 0
}
def _finalize_class_stats(data):
    """Turn one accumulated bucket (as built by ``_empty_bucket``) into a stats dict.

    The result always holds ``num_samples`` and the ``lateral_error``,
    ``longitudinal_error`` and ``heading_error`` summaries. Three groups of fields are
    added only when the bucket collected values for them:
      - ``longitudinal_relative_error``;
      - ``heading_error_relaxed``;
      - ``reversal_count`` / ``reversal_percentage``.
    """
    def summarise(key, values):
        # _make_stats works on dicts: wrap the single series and unwrap its summary.
        return _make_stats({key: values})[key]

    entry = {'num_samples': data['samples']}
    required_series = (
        ('lateral', 'lateral_error'),
        ('longitudinal', 'longitudinal_error'),
        ('heading', 'heading_error'),
    )
    for source_key, target_key in required_series:
        entry[target_key] = summarise(source_key, data[source_key])

    optional_series = (
        ('longitudinal_relative', 'longitudinal_relative_error'),
        ('heading_relaxed', 'heading_error_relaxed'),
    )
    for source_key, target_key in optional_series:
        if data[source_key]:
            entry[target_key] = summarise('v', data[source_key])

    reversal_flags = data['is_reversal']
    if reversal_flags:
        n_reversed = int(sum(reversal_flags))
        n_total = data['samples']
        entry['reversal_count'] = n_reversed
        # NOTE(review): the denominator is every sample in the bucket, which can exceed
        # len(is_reversal) if some samples recorded no flag — confirm this is intended.
        entry['reversal_percentage'] = float(n_reversed / n_total * 100) if n_total > 0 else 0.0
    return entry
def _fill_bucket(bucket, errs):
    """Append one match's error values to *bucket* and increment its sample count.

    ``lateral``, ``longitudinal`` and ``heading`` must be present in *errs*. The
    ``longitudinal_relative``, ``heading_relaxed`` and ``is_reversal`` values are
    recorded only when present.
    """
    bucket['lateral'].append(errs['lateral'])
    bucket['longitudinal'].append(errs['longitudinal'])
    bucket['heading'].append(errs['heading'])
    for optional_key in ('longitudinal_relative', 'heading_relaxed', 'is_reversal'):
        if optional_key in errs:
            bucket[optional_key].append(errs[optional_key])
    bucket['samples'] += 1


def _find_range(value, ranges):
    """Return the first ``(lo, hi)`` in *ranges* with ``lo <= value < hi``.

    Returns None when *value* is None or falls outside every range.
    """
    if value is None:
        return None
    for lo, hi in ranges:
        if lo <= value < hi:
            return lo, hi
    return None


def recompute_3d_stats_from_common_matches(matches_data, common_matches, model_name,
                                           long_ranges=None, lat_ranges=None):
    """
    Recompute 3D error statistics using only the common (shared) matches.

    Args:
        matches_data: one model's detailed matches ``{case: {frame: {class: [match, ...]}}}``.
            Each match needs an ``'errors'`` dict with ``lateral``, ``longitudinal`` and
            ``heading``. That dict may also hold ``longitudinal_relative``,
            ``heading_relaxed`` and ``is_reversal``. A match may additionally carry a
            ``'distance'`` dict whose ``longitudinal`` / ``lateral`` values select the
            range buckets.
        common_matches: the ``common_matches`` output of ``find_common_matches``.
        model_name: ``'model1'`` or ``'model2'``. It selects the ``'<model_name>_idx'``
            field of each common entry, i.e. which model *matches_data* belongs to.
        long_ranges: half-open ``(lo, hi)`` longitudinal ranges; defaults to
            ``DEFAULT_LONG_RANGES``.
        lat_ranges: half-open ``(lo, hi)`` lateral ranges; defaults to
            ``DEFAULT_LAT_RANGES``.

    Returns a dict with structure:
        {
            class_name: {
                'overall': { ... },
                'long_0-10m': { ... },
                ...
                'lat_-10-0m': { ... },
                ...
            }
        }

    A match is counted in the first range containing its distance. It is left out of
    range buckets when the distance is missing or outside all ranges. Range entries
    appear only if they received at least one sample.
    """
    if long_ranges is None:
        long_ranges = DEFAULT_LONG_RANGES
    if lat_ranges is None:
        lat_ranges = DEFAULT_LAT_RANGES

    overall = {}   # class -> bucket
    by_long = {}   # class -> longitudinal range key -> bucket
    by_lat = {}    # class -> lateral range key -> bucket
    idx_key = f'{model_name}_idx'

    for case_name, frames in common_matches.items():
        for frame_name, classes in frames.items():
            for class_name, common_list in classes.items():
                if class_name not in overall:
                    overall[class_name] = _empty_bucket()
                    by_long[class_name] = {
                        _range_key_long(lo, hi): _empty_bucket()
                        for lo, hi in long_ranges
                    }
                    by_lat[class_name] = {
                        _range_key_lat(lo, hi): _empty_bucket()
                        for lo, hi in lat_ranges
                    }

                for match_info in common_list:
                    match = matches_data[case_name][frame_name][class_name][match_info[idx_key]]
                    errs = match['errors']
                    # Tolerate an explicit null 'distance' as well as a missing one.
                    dist = match.get('distance') or {}

                    _fill_bucket(overall[class_name], errs)

                    long_range = _find_range(dist.get('longitudinal'), long_ranges)
                    if long_range is not None:
                        _fill_bucket(by_long[class_name][_range_key_long(*long_range)], errs)

                    lat_range = _find_range(dist.get('lateral'), lat_ranges)
                    if lat_range is not None:
                        _fill_bucket(by_lat[class_name][_range_key_lat(*lat_range)], errs)

    result = {}
    for class_name, overall_bucket in overall.items():
        class_result = {}
        if overall_bucket['samples'] > 0:
            class_result['overall'] = _finalize_class_stats(overall_bucket)
        else:
            class_result['overall'] = {'num_samples': 0}
        # Longitudinal ranges first, then lateral, each in configured order.
        for range_buckets in (by_long[class_name], by_lat[class_name]):
            for range_key, bucket in range_buckets.items():
                if bucket['samples'] > 0:
                    class_result[range_key] = _finalize_class_stats(bucket)
        result[class_name] = class_result
    return result
def print_statistics(stats, model1_name='model1', model2_name='model2'):
    """Print match statistics in a readable format.

    Args:
        stats: the stats dict returned by ``find_common_matches``: overall counters,
            ``common_percentage_of_model1/2`` and a ``per_class`` dict with the same fields.
        model1_name: display name for model 1.
        model2_name: display name for model 2.
    """
    def _share(part, total):
        # Percentage of `part` in `total`; 0 when the model has no matches at all.
        return part / total * 100 if total > 0 else 0.0

    print("\n" + "=" * 80)
    print("COMMON MATCH STATISTICS")
    print("=" * 80)
    print("\nOverall:")
    print(f" {model1_name} Total Matches: {stats['model1_total']:,}")
    print(f" {model2_name} Total Matches: {stats['model2_total']:,}")
    print(f" Common Matches: {stats['common']:,} ({stats['common_percentage_of_model1']:.1f}% of {model1_name})")
    # Unique shares come from the unique counts themselves. "100 - common%" would report
    # 100% unique for a model with zero matches.
    m1_unique_pct = _share(stats['model1_unique'], stats['model1_total'])
    m2_unique_pct = _share(stats['model2_unique'], stats['model2_total'])
    print(f" {model1_name} Unique: {stats['model1_unique']:,} ({m1_unique_pct:.1f}%)")
    print(f" {model2_name} Unique: {stats['model2_unique']:,} ({m2_unique_pct:.1f}%)")

    print("\nPer-Class Statistics:")
    # Truncate model names so every header fits its column width.
    m1_short = model1_name[:10]
    m2_short = model2_name[:10]
    m1_uniq = f"{model1_name[:7]} Uniq"   # at most 12 chars, matching the 12-wide column
    m2_uniq = f"{model2_name[:7]} Uniq"
    header = (f"{'Class':<15} {m1_short:>10} {m2_short:>10} {'Common':>10} {'Common%':>10} "
              f"{m1_uniq:>12} {m2_uniq:>12}")
    print(header)
    print("-" * len(header))
    for class_name, class_stats in sorted(stats['per_class'].items()):
        print(f"{class_name:<15} {class_stats['model1_total']:>10,} {class_stats['model2_total']:>10,} "
              f"{class_stats['common']:>10,} {class_stats['common_percentage_of_model1']:>9.1f}% "
              f"{class_stats['model1_unique']:>12,} {class_stats['model2_unique']:>12,}")
def _load_json(path, label):
    """Announce and load the UTF-8 JSON file at *path*; exit with a message if it cannot be read.

    *label* (e.g. ``'model 1'``) is used only in the progress and error messages.
    """
    print(f"Loading {label} matches from: {path}")
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError) as err:
        sys.exit(f"Error: could not load {label} matches from {path}: {err}")


def _print_metric_row(label, m1_mean, m2_mean):
    """Print one comparison row: both means, their difference, and the change relative to model 1."""
    diff = m2_mean - m1_mean
    # Relative change is undefined for a non-positive baseline; report 0 in that case.
    change_pct = (diff / m1_mean * 100) if m1_mean > 0 else 0
    print(f"{label:<20} {m1_mean:>15.4f} {m2_mean:>15.4f} {diff:>+12.4f} {change_pct:>+9.2f}%")


def _print_3d_comparison(model1_stats, model2_stats, model1_name, model2_name):
    """Print per-class mean-error comparison tables for classes present in both stats dicts."""
    print("\n" + "=" * 80)
    print("3D STATISTICS COMPARISON (COMMON MATCHES ONLY)")
    print("=" * 80)
    for class_name in sorted(model1_stats):
        if class_name not in model2_stats:
            continue
        # Class stats are nested under 'overall'; fall back to the class dict itself
        # for the layout without that nesting.
        m1 = model1_stats[class_name].get('overall', model1_stats[class_name])
        m2 = model2_stats[class_name].get('overall', model2_stats[class_name])
        print(f"\n{class_name.upper()} (n={m1.get('num_samples', 0):,}):")
        print(f"{'Metric':<20} {model1_name:>15} {model2_name:>15} {'Diff':>12} {'Change %':>10}")
        print("-" * 80)
        for error_type in ('lateral_error', 'longitudinal_error', 'heading_error'):
            if error_type in m1 and error_type in m2:
                _print_metric_row(error_type.replace('_', ' ').title(),
                                  m1[error_type]['mean'], m2[error_type]['mean'])
        # Relaxed heading error, if both models report it
        if 'heading_error_relaxed' in m1 and 'heading_error_relaxed' in m2:
            _print_metric_row('Heading Error (Rlx)',
                              m1['heading_error_relaxed']['mean'],
                              m2['heading_error_relaxed']['mean'])
        # Reversal statistics, if both models report them
        if 'reversal_count' in m1 and 'reversal_count' in m2:
            m1_count = m1['reversal_count']
            m1_pct = m1.get('reversal_percentage', 0)
            m2_count = m2['reversal_count']
            m2_pct = m2.get('reversal_percentage', 0)
            print(f"{'Reversals':<20} {m1_count:>11} ({m1_pct:>5.1f}%) {m2_count:>11} ({m2_pct:>5.1f}%)")


def main():
    """Command-line entry point.

    Steps:
      1. Load both detailed-match files.
      2. Find the common matches and print the match statistics.
      3. Recompute 3D stats on the common set.
      4. Write everything to JSON and print a side-by-side comparison.
    """
    parser = argparse.ArgumentParser(
        description='Find common matches between two model evaluation results',
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument('--model1-matches', type=str, required=True,
                        help='Path to model 1 detailed_3d_matches.json file')
    parser.add_argument('--model2-matches', type=str, required=True,
                        help='Path to model 2 detailed_3d_matches.json file')
    parser.add_argument('--output', type=str, default='common_matches.json',
                        help='Output path for common matches JSON file')
    parser.add_argument('--model1-name', type=str, default='model1',
                        help='Name for model 1 (default: model1)')
    parser.add_argument('--model2-name', type=str, default='model2',
                        help='Name for model 2 (default: model2)')
    args = parser.parse_args()

    model1_matches = _load_json(args.model1_matches, 'model 1')
    model2_matches = _load_json(args.model2_matches, 'model 2')

    print("\nFinding common matches...")
    common_matches, stats = find_common_matches(model1_matches, model2_matches)
    print_statistics(stats, args.model1_name, args.model2_name)

    print("\nRecomputing 3D statistics for common matches...")
    # The third argument selects the '<key>_idx' field of each common entry, so it must be
    # the internal key ('model1'/'model2'), not the user-facing display name.
    model1_stats = recompute_3d_stats_from_common_matches(model1_matches, common_matches, 'model1')
    model2_stats = recompute_3d_stats_from_common_matches(model2_matches, common_matches, 'model2')

    output_data = {
        'match_statistics': stats,
        'common_matches': common_matches,
        'model1_3d_stats': model1_stats,
        'model2_3d_stats': model2_stats,
        'model_names': {
            'model1': args.model1_name,
            'model2': args.model2_name
        }
    }
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=2)
    print(f"\n✓ Common matches saved to: {output_path}")

    _print_3d_comparison(model1_stats, model2_stats, args.model1_name, args.model2_name)


if __name__ == '__main__':
    main()