440 lines
17 KiB
Python
440 lines
17 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
Find Common Matches Between Two Models
|
||
|
|
|
||
|
|
This tool finds the common GT objects that were successfully matched by both models,
|
||
|
|
enabling fair comparison of 3D prediction quality on the same set of targets.
|
||
|
|
|
||
|
|
Usage:
|
||
|
|
python eval_tools/find_common_matches.py \
|
||
|
|
--model1-matches eval_results/model1/detailed_3d_matches.json \
|
||
|
|
--model2-matches eval_results/model2/detailed_3d_matches.json \
|
||
|
|
--output common_matches.json
|
||
|
|
"""
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import json
|
||
|
|
import sys
|
||
|
|
from pathlib import Path
|
||
|
|
from collections import defaultdict
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
|
||
|
|
def _ordered_union(first, second):
    """Return the keys of *first*, then the keys of *second* not already in *first*.

    Insertion order is preserved so iteration (and therefore the output
    dicts) follows the order of the input files.
    """
    return list(first) + [key for key in second if key not in first]


def _empty_match_counts():
    """Return a fresh, zeroed counter dict for one level of match statistics."""
    return {
        'model1_total': 0,
        'model2_total': 0,
        'common': 0,
        'model1_unique': 0,
        'model2_unique': 0,
    }


def _add_common_percentages(counts):
    """Add ``common_percentage_of_model1/2`` to *counts* in place (0 when a total is 0)."""
    for model in ('model1', 'model2'):
        total = counts[f'{model}_total']
        counts[f'common_percentage_of_{model}'] = (
            (counts['common'] / total) * 100 if total > 0 else 0
        )


def find_common_matches(model1_matches, model2_matches):
    """
    Find common GT objects matched by both models.

    A GT object is "common" when the same ``gt_id`` appears in both models'
    match lists for the same (case, frame, class). Matches from a case,
    frame or class that only one model contains are never common, but they
    still count toward that model's ``*_total`` and ``*_unique`` statistics.

    Args:
        model1_matches: dict, detailed matches from model 1, shaped
            {case: {frame: {class: [match, ...]}}}; each match must have a
            'gt_id' key.
        model2_matches: dict, detailed matches from model 2, same shape.

    Returns:
        tuple: (common_matches, stats)
            common_matches: dict with structure {case: {frame: {class: [match_info]}}}.
                Each match_info is {'gt_id', 'model1_idx', 'model2_idx'}. The
                indices point into the respective model's match list. A case or
                frame entry exists whenever both models contain it. A class entry
                exists only when at least one GT is common. Lists are ordered
                by model1_idx.
            stats: dict with overall counts ('model1_total', 'model2_total',
                'common', 'model1_unique', 'model2_unique'), the same counts
                under 'per_class', and 'common_percentage_of_model1/2' at both
                levels.
    """
    common_matches = {}
    stats = _empty_match_counts()
    stats['per_class'] = {}

    # Walk the union of keys at every level so that matches present in only
    # one model are still counted as that model's total/unique matches.
    for case_name in _ordered_union(model1_matches, model2_matches):
        m1_frames = model1_matches.get(case_name, {})
        m2_frames = model2_matches.get(case_name, {})
        case_in_both = case_name in model1_matches and case_name in model2_matches
        if case_in_both:
            common_matches[case_name] = {}

        for frame_name in _ordered_union(m1_frames, m2_frames):
            m1_classes = m1_frames.get(frame_name, {})
            m2_classes = m2_frames.get(frame_name, {})
            frame_in_both = (case_in_both
                             and frame_name in m1_frames
                             and frame_name in m2_frames)
            if frame_in_both:
                common_matches[case_name][frame_name] = {}

            for class_name in _ordered_union(m1_classes, m2_classes):
                m1_list = m1_classes.get(class_name, [])
                m2_list = m2_classes.get(class_name, [])

                # gt_id -> index in the match list (a repeated gt_id keeps the
                # index of its last occurrence).
                m1_gt_ids = {m['gt_id']: i for i, m in enumerate(m1_list)}
                m2_gt_ids = {m['gt_id']: i for i, m in enumerate(m2_list)}
                common_gt_ids = m1_gt_ids.keys() & m2_gt_ids.keys()
                n_common = len(common_gt_ids)

                if class_name not in stats['per_class']:
                    stats['per_class'][class_name] = _empty_match_counts()
                class_counts = stats['per_class'][class_name]

                for counts in (stats, class_counts):
                    counts['model1_total'] += len(m1_list)
                    counts['model2_total'] += len(m2_list)
                    counts['common'] += n_common
                    counts['model1_unique'] += len(m1_gt_ids) - n_common
                    counts['model2_unique'] += len(m2_gt_ids) - n_common

                # Sort by model-1 index: set iteration order (e.g. of str ids)
                # is not stable across runs, which made the JSON output flaky.
                common_list = [
                    {
                        'gt_id': gt_id,
                        'model1_idx': m1_gt_ids[gt_id],
                        'model2_idx': m2_gt_ids[gt_id],
                    }
                    for gt_id in sorted(common_gt_ids, key=m1_gt_ids.__getitem__)
                ]

                # A non-empty common list implies the class (and hence its
                # frame and case) exists in both models, so the parents exist.
                if common_list:
                    common_matches[case_name][frame_name][class_name] = common_list

    _add_common_percentages(stats)
    for class_counts in stats['per_class'].values():
        _add_common_percentages(class_counts)

    return common_matches, stats
|
||
|
|
|
||
|
|
|
||
|
|
# Default distance ranges matching eval metrics_3d config.
# Each (lo, hi) pair is tested as ``lo <= value < hi`` by
# recompute_3d_stats_from_common_matches.
#
# Longitudinal: ten 10 m bins covering 0-100 m, plus one far bin [100, 999).
DEFAULT_LONG_RANGES = [(lo, lo + 10) for lo in range(0, 100, 10)] + [(100, 999)]

# Lateral: ten 10 m bins covering -50 m to +50 m.
DEFAULT_LAT_RANGES = [(lo, lo + 10) for lo in range(-50, 50, 10)]
|
||
|
|
|
||
|
|
|
||
|
|
def _range_key_long(lo, hi):
|
||
|
|
return f'long_{lo}-{hi}m'
|
||
|
|
|
||
|
|
|
||
|
|
def _range_key_lat(lo, hi):
|
||
|
|
return f'lat_{lo}-{hi}m'
|
||
|
|
|
||
|
|
|
||
|
|
def _make_stats(data_dict):
|
||
|
|
"""Compute mean/std/median/min/max for each list in data_dict."""
|
||
|
|
result = {}
|
||
|
|
for key, values in data_dict.items():
|
||
|
|
if key in ('samples',):
|
||
|
|
result[key] = values
|
||
|
|
elif isinstance(values, list) and len(values) > 0:
|
||
|
|
arr = np.array(values)
|
||
|
|
result[key] = {
|
||
|
|
'mean': float(np.mean(arr)),
|
||
|
|
'std': float(np.std(arr)),
|
||
|
|
'median': float(np.median(arr)),
|
||
|
|
'min': float(np.min(arr)),
|
||
|
|
'max': float(np.max(arr)),
|
||
|
|
}
|
||
|
|
return result
|
||
|
|
|
||
|
|
|
||
|
|
def _empty_bucket():
|
||
|
|
return {
|
||
|
|
'lateral': [], 'longitudinal': [], 'longitudinal_relative': [],
|
||
|
|
'heading': [], 'heading_relaxed': [], 'is_reversal': [], 'samples': 0
|
||
|
|
}
|
||
|
|
|
||
|
|
|
||
|
|
def _finalize_class_stats(data):
    """Turn one accumulator bucket (as built by ``_empty_bucket``) into a stats dict.

    The result always has 'num_samples', 'lateral_error',
    'longitudinal_error' and 'heading_error'. The last three are
    ``_make_stats`` summaries, so their lists must be non-empty.

    Optional entries are added only when the corresponding list is non-empty:
    - 'longitudinal_relative_error'
    - 'heading_error_relaxed'
    - 'reversal_count' and 'reversal_percentage'

    The reversal percentage is relative to the bucket's 'samples' count.
    """
    entry = {'num_samples': data['samples']}

    required = (
        ('lateral', 'lateral_error'),
        ('longitudinal', 'longitudinal_error'),
        ('heading', 'heading_error'),
    )
    for source_key, target_key in required:
        entry[target_key] = _make_stats({source_key: data[source_key]})[source_key]

    optional = (
        ('longitudinal_relative', 'longitudinal_relative_error'),
        ('heading_relaxed', 'heading_error_relaxed'),
    )
    for source_key, target_key in optional:
        if data[source_key]:
            entry[target_key] = _make_stats({source_key: data[source_key]})[source_key]

    flags = data['is_reversal']
    if flags:
        n_reversed = int(sum(flags))
        n_samples = data['samples']
        entry['reversal_count'] = n_reversed
        entry['reversal_percentage'] = (
            float(n_reversed / n_samples * 100) if n_samples > 0 else 0.0
        )
    return entry
|
||
|
|
|
||
|
|
|
||
|
|
def _append_errors(bucket, errs):
    """Append one match's error values to *bucket* and bump its sample count.

    'lateral', 'longitudinal' and 'heading' are required keys of *errs*.
    The relative/relaxed/reversal values are recorded only when present.
    """
    bucket['lateral'].append(errs['lateral'])
    bucket['longitudinal'].append(errs['longitudinal'])
    bucket['heading'].append(errs['heading'])
    for optional_key in ('longitudinal_relative', 'heading_relaxed', 'is_reversal'):
        if optional_key in errs:
            bucket[optional_key].append(errs[optional_key])
    bucket['samples'] += 1


def _find_range(value, ranges):
    """Return the first ``(lo, hi)`` in *ranges* with ``lo <= value < hi``.

    Returns None when *value* is None or falls outside every range.
    """
    if value is None:
        return None
    for lo, hi in ranges:
        if lo <= value < hi:
            return lo, hi
    return None


def recompute_3d_stats_from_common_matches(matches_data, common_matches, model_name,
                                           long_ranges=None, lat_ranges=None):
    """
    Recompute 3D statistics based on common matches only.

    Args:
        matches_data: one model's detailed matches,
            {case: {frame: {class: [match, ...]}}}. Each referenced match has
            an 'errors' dict and optionally a 'distance' dict with
            'longitudinal'/'lateral' values.
        common_matches: output of ``find_common_matches``.
        model_name: 'model1' or 'model2'; selects the ``<model_name>_idx``
            field of each common-match entry.
        long_ranges: list of (lo, hi) longitudinal bins; defaults to
            DEFAULT_LONG_RANGES.
        lat_ranges: list of (lo, hi) lateral bins; defaults to
            DEFAULT_LAT_RANGES.

    Returns a dict with structure:
    {
        class_name: {
            'overall': { ... },
            'long_0-10m': { ... },
            ...
            'lat_-10-0m': { ... },
            ...
        }
    }

    'overall' is always present (``{'num_samples': 0}`` when empty). Range keys
    appear only for bins that received at least one sample. Matches without a
    distance value, or outside every bin, count only toward 'overall'.
    """
    if long_ranges is None:
        long_ranges = DEFAULT_LONG_RANGES
    if lat_ranges is None:
        lat_ranges = DEFAULT_LAT_RANGES

    idx_key = f'{model_name}_idx'

    overall = {}  # class -> _empty_bucket()
    by_long = {}  # class -> range_key -> _empty_bucket()
    by_lat = {}   # class -> range_key -> _empty_bucket()

    for case_name, frames in common_matches.items():
        for frame_name, classes in frames.items():
            for class_name, common_list in classes.items():
                if class_name not in overall:
                    overall[class_name] = _empty_bucket()
                    by_long[class_name] = {
                        _range_key_long(lo, hi): _empty_bucket()
                        for lo, hi in long_ranges
                    }
                    by_lat[class_name] = {
                        _range_key_lat(lo, hi): _empty_bucket()
                        for lo, hi in lat_ranges
                    }

                for match_info in common_list:
                    match = matches_data[case_name][frame_name][class_name][match_info[idx_key]]
                    errs = match['errors']
                    # Tolerate a missing *or null* 'distance' entry.
                    dist = match.get('distance') or {}

                    _append_errors(overall[class_name], errs)

                    long_bin = _find_range(dist.get('longitudinal'), long_ranges)
                    if long_bin is not None:
                        _append_errors(by_long[class_name][_range_key_long(*long_bin)], errs)

                    lat_bin = _find_range(dist.get('lateral'), lat_ranges)
                    if lat_bin is not None:
                        _append_errors(by_lat[class_name][_range_key_lat(*lat_bin)], errs)

    result = {}
    for class_name, class_bucket in overall.items():
        class_result = {}
        if class_bucket['samples'] > 0:
            class_result['overall'] = _finalize_class_stats(class_bucket)
        else:
            class_result['overall'] = {'num_samples': 0}

        # Longitudinal bins first, then lateral, each in configured order.
        for range_buckets in (by_long[class_name], by_lat[class_name]):
            for range_key, range_bucket in range_buckets.items():
                if range_bucket['samples'] > 0:
                    class_result[range_key] = _finalize_class_stats(range_bucket)

        result[class_name] = class_result

    return result
|
||
|
|
|
||
|
|
|
||
|
|
def print_statistics(stats, model1_name='model1', model2_name='model2'):
    """Print the dict returned by ``find_common_matches`` as a readable report.

    The report has an overall summary followed by a per-class table, with
    classes sorted by name. Model names are cut to 10 characters in the
    table's column headers.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("COMMON MATCH STATISTICS")
    print(banner)

    print("\nOverall:")
    print(" {} Total Matches: {:,}".format(model1_name, stats['model1_total']))
    print(" {} Total Matches: {:,}".format(model2_name, stats['model2_total']))
    print(" Common Matches: {:,} ({:.1f}% of {})".format(
        stats['common'], stats['common_percentage_of_model1'], model1_name))
    print(" {} Unique: {:,} ({:.1f}%)".format(
        model1_name, stats['model1_unique'], 100 - stats['common_percentage_of_model1']))
    print(" {} Unique: {:,} ({:.1f}%)".format(
        model2_name, stats['model2_unique'], 100 - stats['common_percentage_of_model2']))

    print("\nPer-Class Statistics:")
    # Keep column headers narrow when the model names are long.
    name1 = model1_name[:10]
    name2 = model2_name[:10]
    header_fmt = "{:<15} {:>10} {:>10} {:>10} {:>10} {:>12} {:>12}"
    print(header_fmt.format('Class', name1, name2, 'Common', 'Common%',
                            name1 + ' Uniq', name2 + ' Uniq'))
    print("-" * 80)

    row_fmt = "{:<15} {:>10,} {:>10,} {:>10,} {:>9.1f}% {:>12,} {:>12,}"
    per_class = stats['per_class']
    for name in sorted(per_class):
        row = per_class[name]
        print(row_fmt.format(
            name,
            row['model1_total'],
            row['model2_total'],
            row['common'],
            row['common_percentage_of_model1'],
            row['model1_unique'],
            row['model2_unique'],
        ))
|
||
|
|
|
||
|
|
|
||
|
|
def _load_json(path, label):
    """Print a progress message and return the parsed JSON document at *path*.

    Args:
        path: path to a JSON file.
        label: text inserted into the progress message (e.g. 'model 1').
    """
    print(f"Loading {label} matches from: {path}")
    # JSON text is UTF-8 (RFC 8259); don't rely on the platform locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _print_mean_row(label, m1_entry, m2_entry):
    """Print one comparison row: both models' mean error, the difference and the change.

    Args:
        label: row label (left-aligned to 20 characters).
        m1_entry, m2_entry: stats dicts with a 'mean' key for model 1 / model 2.
    """
    m1_mean = m1_entry['mean']
    m2_mean = m2_entry['mean']
    diff = m2_mean - m1_mean
    # Relative change is undefined for a non-positive baseline; report 0.
    change_pct = (diff / m1_mean * 100) if m1_mean > 0 else 0
    print(f"{label:<20} {m1_mean:>15.4f} {m2_mean:>15.4f} {diff:>+12.4f} {change_pct:>+9.2f}%")


def _print_3d_comparison(model1_stats, model2_stats, model1_name, model2_name):
    """Print a per-class table comparing the two models' 'overall' 3D error stats.

    Only classes present in both stats dicts are shown, sorted by name.
    """
    print("\n" + "="*80)
    print("3D STATISTICS COMPARISON (COMMON MATCHES ONLY)")
    print("="*80)

    for class_name in sorted(model1_stats.keys()):
        if class_name not in model2_stats:
            continue

        # New format: class stats are nested under 'overall'
        m1 = model1_stats[class_name].get('overall', model1_stats[class_name])
        m2 = model2_stats[class_name].get('overall', model2_stats[class_name])

        print(f"\n{class_name.upper()} (n={m1.get('num_samples', 0):,}):")
        print(f"{'Metric':<20} {model1_name:>15} {model2_name:>15} {'Diff':>12} {'Change %':>10}")
        print("-" * 80)

        for error_type in ['lateral_error', 'longitudinal_error', 'heading_error']:
            if error_type in m1 and error_type in m2:
                error_name = error_type.replace('_', ' ').title()
                _print_mean_row(error_name, m1[error_type], m2[error_type])

        # Print relaxed heading error if available
        if 'heading_error_relaxed' in m1 and 'heading_error_relaxed' in m2:
            _print_mean_row('Heading Error (Rlx)',
                            m1['heading_error_relaxed'], m2['heading_error_relaxed'])

        # Print reversal statistics if available
        if 'reversal_count' in m1 and 'reversal_count' in m2:
            m1_count = m1['reversal_count']
            m1_pct = m1.get('reversal_percentage', 0)
            m2_count = m2['reversal_count']
            m2_pct = m2.get('reversal_percentage', 0)
            print(f"{'Reversals':<20} {m1_count:>11} ({m1_pct:>5.1f}%) {m2_count:>11} ({m2_pct:>5.1f}%)")


def main():
    """Parse CLI arguments, find common matches, save the result JSON and print a comparison."""
    parser = argparse.ArgumentParser(
        description='Find common matches between two model evaluation results',
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument('--model1-matches', type=str, required=True,
                        help='Path to model 1 detailed_3d_matches.json file')
    parser.add_argument('--model2-matches', type=str, required=True,
                        help='Path to model 2 detailed_3d_matches.json file')
    parser.add_argument('--output', type=str, default='common_matches.json',
                        help='Output path for common matches JSON file')
    parser.add_argument('--model1-name', type=str, default='model1',
                        help='Name for model 1 (default: model1)')
    parser.add_argument('--model2-name', type=str, default='model2',
                        help='Name for model 2 (default: model2)')
    args = parser.parse_args()

    # Load detailed matches
    model1_matches = _load_json(args.model1_matches, 'model 1')
    model2_matches = _load_json(args.model2_matches, 'model 2')

    # Find common matches
    print("\nFinding common matches...")
    common_matches, stats = find_common_matches(model1_matches, model2_matches)
    print_statistics(stats, args.model1_name, args.model2_name)

    # Recompute 3D stats for common matches. The 'model1'/'model2' literals
    # select the model1_idx / model2_idx fields of each common-match entry;
    # they are not the user-facing display names.
    print("\nRecomputing 3D statistics for common matches...")
    model1_stats = recompute_3d_stats_from_common_matches(model1_matches, common_matches, 'model1')
    model2_stats = recompute_3d_stats_from_common_matches(model2_matches, common_matches, 'model2')

    output_data = {
        'match_statistics': stats,
        'common_matches': common_matches,
        'model1_3d_stats': model1_stats,
        'model2_3d_stats': model2_stats,
        'model_names': {
            'model1': args.model1_name,
            'model2': args.model2_name
        }
    }

    # Save output
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=2)
    print(f"\n✓ Common matches saved to: {output_path}")

    _print_3d_comparison(model1_stats, model2_stats, args.model1_name, args.model2_name)


if __name__ == '__main__':
    main()
|