#!/usr/bin/env python3
"""
Find Common Matches Between Two Models

This tool finds the common GT objects that were successfully matched by both
models, enabling fair comparison of 3D prediction quality on the same set of
targets.

Usage:
    python eval_tools/find_common_matches.py \
        --model1-matches eval_results/model1/detailed_3d_matches.json \
        --model2-matches eval_results/model2/detailed_3d_matches.json \
        --output common_matches.json
"""

import argparse
import json
import sys
from pathlib import Path
from collections import defaultdict

import numpy as np


# Counter keys shared by the overall stats dict and every per-class stats dict.
_COUNT_KEYS = (
    'model1_total', 'model2_total', 'common', 'model1_unique', 'model2_unique'
)


def _new_counts():
    """Return a fresh counter dict with every key in _COUNT_KEYS set to 0."""
    return dict.fromkeys(_COUNT_KEYS, 0)


def _pct(part, whole):
    """Return part/whole as a percentage, or 0 when whole is not positive."""
    return (part / whole) * 100 if whole > 0 else 0


def _index_by_gt_id(match_list):
    """Map each match's 'gt_id' to its position in match_list.

    If a GT id occurs more than once, the last occurrence wins.
    """
    return {match['gt_id']: idx for idx, match in enumerate(match_list)}


def _iter_class_lists(matches):
    """Yield (case, frame, class, match_list) for every class entry."""
    for case_name, frames in matches.items():
        for frame_name, classes in frames.items():
            for class_name, match_list in classes.items():
                yield case_name, frame_name, class_name, match_list


def _lookup(matches, case_name, frame_name, class_name):
    """Return matches[case][frame][class], or None if any level is missing."""
    return matches.get(case_name, {}).get(frame_name, {}).get(class_name)


def find_common_matches(model1_matches, model2_matches):
    """
    Find common GT objects matched by both models.

    Every match of either model is counted in the statistics. A match whose
    case, frame or class is absent from the other model counts as unique to
    its own model rather than being left out of the totals.

    Args:
        model1_matches: dict, detailed matches from model 1, shaped
            {case: {frame: {class: [match, ...]}}}; each match has 'gt_id'
        model2_matches: dict, detailed matches from model 2 (same shape)

    Returns:
        tuple: (common_matches, stats)
            common_matches: dict with structure
                {case: {frame: {class: [match_info]}}}, where match_info is
                {'gt_id', 'model1_idx', 'model2_idx'}. Lists are ordered by
                'model1_idx' so the output is reproducible between runs.
                Cases and frames present in both models always get an entry,
                possibly empty; classes appear only when they share a GT id.
            stats: dict with match statistics (overall counters, 'per_class'
                counters and 'common_percentage_of_model1/2')
    """
    common_matches = {}
    stats = _new_counts()
    per_class = stats['per_class'] = {}

    def record(class_name, **deltas):
        # Apply the same increments to the overall and the per-class counters.
        class_stats = per_class.setdefault(class_name, _new_counts())
        for key, delta in deltas.items():
            stats[key] += delta
            class_stats[key] += delta

    # Pass 1: everything model 1 matched, paired with model 2 where possible.
    for case_name, m1_frames in model1_matches.items():
        m2_frames = model2_matches.get(case_name)
        case_common = None
        if m2_frames is not None:
            case_common = common_matches[case_name] = {}

        for frame_name, m1_classes in m1_frames.items():
            m2_classes = None if m2_frames is None else m2_frames.get(frame_name)
            frame_common = None
            if m2_classes is not None:
                frame_common = case_common[frame_name] = {}

            for class_name, m1_list in m1_classes.items():
                m2_list = None if m2_classes is None else m2_classes.get(class_name)
                m1_index = _index_by_gt_id(m1_list)

                if m2_list is None:
                    # Model 2 has nothing here: all of these are model-1-only.
                    record(class_name,
                           model1_total=len(m1_list),
                           model1_unique=len(m1_index))
                    continue

                m2_index = _index_by_gt_id(m2_list)
                # Sort by model-1 position: plain set iteration order is not
                # stable across interpreter runs for string ids.
                shared = sorted(m1_index.keys() & m2_index.keys(),
                                key=m1_index.__getitem__)

                record(class_name,
                       model1_total=len(m1_list),
                       model2_total=len(m2_list),
                       common=len(shared),
                       model1_unique=len(m1_index) - len(shared),
                       model2_unique=len(m2_index) - len(shared))

                if shared:
                    frame_common[class_name] = [
                        {
                            'gt_id': gt_id,
                            'model1_idx': m1_index[gt_id],
                            'model2_idx': m2_index[gt_id],
                        }
                        for gt_id in shared
                    ]

    # Pass 2: model-2 entries that pass 1 never visited (no model-1 counterpart).
    for case_name, frame_name, class_name, m2_list in _iter_class_lists(model2_matches):
        if _lookup(model1_matches, case_name, frame_name, class_name) is None:
            record(class_name,
                   model2_total=len(m2_list),
                   model2_unique=len(_index_by_gt_id(m2_list)))

    # Calculate percentages (overall first, then per class)
    for counts in (stats, *per_class.values()):
        counts['common_percentage_of_model1'] = _pct(counts['common'], counts['model1_total'])
        counts['common_percentage_of_model2'] = _pct(counts['common'], counts['model2_total'])

    return common_matches, stats


# Default distance ranges matching eval metrics_3d config
DEFAULT_LONG_RANGES = [
    (0, 10), (10, 20), (20, 30), (30, 40), (40, 50),
    (50, 60), (60, 70), (70, 80), (80, 90), (90, 100),
    (100, 999)
]

DEFAULT_LAT_RANGES = [
    (-50, -40), (-40, -30), (-30, -20), (-20, -10), (-10, 0),
    (0, 10), (10, 20), (20, 30), (30, 40), (40, 50)
]

# Error fields every match must carry, and those recorded only when present.
_REQUIRED_ERROR_FIELDS = ('lateral', 'longitudinal', 'heading')
_OPTIONAL_ERROR_FIELDS = ('longitudinal_relative', 'heading_relaxed', 'is_reversal')


def _range_key_long(lo, hi):
    """Return the result key for a longitudinal range, e.g. 'long_0-10m'."""
    return f'long_{lo}-{hi}m'


def _range_key_lat(lo, hi):
    """Return the result key for a lateral range, e.g. 'lat_-10-0m'."""
    return f'lat_{lo}-{hi}m'


def _summarize(values):
    """Return mean/std/median/min/max of a non-empty sequence as floats."""
    arr = np.array(values)
    return {
        'mean': float(np.mean(arr)),
        'std': float(np.std(arr)),
        'median': float(np.median(arr)),
        'min': float(np.min(arr)),
        'max': float(np.max(arr)),
    }


def _make_stats(data_dict):
    """Compute mean/std/median/min/max for each list in data_dict.

    The 'samples' entry is copied through unchanged. Empty lists and
    non-list values are omitted from the result.
    """
    result = {}
    for key, values in data_dict.items():
        if key in ('samples',):
            result[key] = values
        elif isinstance(values, list) and len(values) > 0:
            result[key] = _summarize(values)
    return result


def _empty_bucket():
    """Return a new accumulator: one list per error field plus a counter."""
    return {
        'lateral': [],
        'longitudinal': [],
        'longitudinal_relative': [],
        'heading': [],
        'heading_relaxed': [],
        'is_reversal': [],
        'samples': 0
    }


def _record_errors(bucket, errs):
    """Append one match's error values to bucket and bump its sample count."""
    for field in _REQUIRED_ERROR_FIELDS:
        bucket[field].append(errs[field])
    for field in _OPTIONAL_ERROR_FIELDS:
        if field in errs:
            bucket[field].append(errs[field])
    bucket['samples'] += 1


def _find_range_key(value, ranges, make_key):
    """Return the key of the first half-open range [lo, hi) containing value.

    Returns None when value is None or falls outside every range.
    """
    if value is None:
        return None
    for lo, hi in ranges:
        if lo <= value < hi:
            return make_key(lo, hi)
    return None


def _finalize_class_stats(data):
    """Convert a bucket dict to stats dict, adding optional fields.

    Expects a bucket holding at least one sample; callers check
    data['samples'] > 0 first.
    """
    entry = {
        'num_samples': data['samples'],
        'lateral_error': _summarize(data['lateral']),
        'longitudinal_error': _summarize(data['longitudinal']),
        'heading_error': _summarize(data['heading']),
    }
    if data['longitudinal_relative']:
        entry['longitudinal_relative_error'] = _summarize(data['longitudinal_relative'])
    if data['heading_relaxed']:
        entry['heading_error_relaxed'] = _summarize(data['heading_relaxed'])
    if data['is_reversal']:
        count = int(sum(data['is_reversal']))
        entry['reversal_count'] = count
        entry['reversal_percentage'] = (
            float(count / data['samples'] * 100) if data['samples'] > 0 else 0.0
        )
    return entry


def recompute_3d_stats_from_common_matches(matches_data, common_matches, model_name,
                                           long_ranges=None, lat_ranges=None):
    """
    Recompute 3D statistics based on common matches only.

    Args:
        matches_data: dict, detailed matches of one model,
            {case: {frame: {class: [match]}}}; each match has an 'errors'
            dict and optionally a 'distance' dict with 'longitudinal' and
            'lateral' entries
        common_matches: dict as returned by find_common_matches
        model_name: prefix of the index key to read from each common-match
            entry ('model1' reads 'model1_idx', 'model2' reads 'model2_idx')
        long_ranges: list of (lo, hi) longitudinal ranges, half-open
            [lo, hi); defaults to DEFAULT_LONG_RANGES
        lat_ranges: list of (lo, hi) lateral ranges, half-open [lo, hi);
            defaults to DEFAULT_LAT_RANGES

    Returns a dict with structure:
    {
        class_name: {
            'overall': { ... },
            'long_0-10m': { ... },
            ...
            'lat_-10-0m': { ... },
            ...
        }
    }
    Range keys appear only for ranges that received at least one sample. A
    match whose distance is missing or outside every range contributes to
    'overall' only.
    """
    if long_ranges is None:
        long_ranges = DEFAULT_LONG_RANGES
    if lat_ranges is None:
        lat_ranges = DEFAULT_LAT_RANGES

    idx_key = f'{model_name}_idx'

    overall = {}   # class -> bucket
    by_long = {}   # class -> range_key -> bucket
    by_lat = {}    # class -> range_key -> bucket

    for case_name, frames in common_matches.items():
        for frame_name, classes in frames.items():
            for class_name, common_list in classes.items():
                if class_name not in overall:
                    overall[class_name] = _empty_bucket()
                    by_long[class_name] = {
                        _range_key_long(lo, hi): _empty_bucket()
                        for lo, hi in long_ranges
                    }
                    by_lat[class_name] = {
                        _range_key_lat(lo, hi): _empty_bucket()
                        for lo, hi in lat_ranges
                    }
                if not common_list:
                    continue

                class_matches = matches_data[case_name][frame_name][class_name]
                for match_info in common_list:
                    match = class_matches[match_info[idx_key]]
                    errs = match['errors']
                    dist = match.get('distance', {})

                    _record_errors(overall[class_name], errs)

                    # Longitudinal range bucket
                    long_key = _find_range_key(
                        dist.get('longitudinal', None), long_ranges, _range_key_long)
                    if long_key is not None:
                        _record_errors(by_long[class_name][long_key], errs)

                    # Lateral range bucket
                    lat_key = _find_range_key(
                        dist.get('lateral', None), lat_ranges, _range_key_lat)
                    if lat_key is not None:
                        _record_errors(by_lat[class_name][lat_key], errs)

    # Build result
    result = {}
    for class_name, bucket in overall.items():
        class_result = result[class_name] = {}

        if bucket['samples'] > 0:
            class_result['overall'] = _finalize_class_stats(bucket)
        else:
            class_result['overall'] = {'num_samples': 0}

        # Longitudinal ranges first, then lateral ranges; empty ones skipped.
        for ranged in (by_long[class_name], by_lat[class_name]):
            for range_key, range_bucket in ranged.items():
                if range_bucket['samples'] > 0:
                    class_result[range_key] = _finalize_class_stats(range_bucket)

    return result


def print_statistics(stats, model1_name='model1', model2_name='model2'):
    """Print match statistics in a readable format.

    Args:
        stats: statistics dict as returned by find_common_matches
        model1_name: display name for model 1
        model2_name: display name for model 2
    """
    # Share of each model's matches the other model missed. Computed from
    # the counters so a model with zero matches shows 0%, not 100%.
    m1_unique_pct = _pct(stats['model1_unique'], stats['model1_total'])
    m2_unique_pct = _pct(stats['model2_unique'], stats['model2_total'])

    print("\n" + "="*80)
    print("COMMON MATCH STATISTICS")
    print("="*80)

    print("\nOverall:")
    print(f" {model1_name} Total Matches: {stats['model1_total']:,}")
    print(f" {model2_name} Total Matches: {stats['model2_total']:,}")
    print(f" Common Matches: {stats['common']:,} ({stats['common_percentage_of_model1']:.1f}% of {model1_name})")
    print(f" {model1_name} Unique: {stats['model1_unique']:,} ({m1_unique_pct:.1f}%)")
    print(f" {model2_name} Unique: {stats['model2_unique']:,} ({m2_unique_pct:.1f}%)")

    print("\nPer-Class Statistics:")
    # Truncate model names if too long for column headers
    m1_short = model1_name[:10]
    m2_short = model2_name[:10]
    print(f"{'Class':<15} {m1_short:>10} {m2_short:>10} {'Common':>10} {'Common%':>10} {m1_short+' Uniq':>12} {m2_short+' Uniq':>12}")
    print("-" * 80)

    for class_name, class_stats in sorted(stats['per_class'].items()):
        print(f"{class_name:<15} {class_stats['model1_total']:>10,} {class_stats['model2_total']:>10,} "
              f"{class_stats['common']:>10,} {class_stats['common_percentage_of_model1']:>9.1f}% "
              f"{class_stats['model1_unique']:>12,} {class_stats['model2_unique']:>12,}")


def _print_metric_row(label, m1_mean, m2_mean):
    """Print one comparison row: both means, their difference and change %."""
    diff = m2_mean - m1_mean
    change_pct = (diff / m1_mean * 100) if m1_mean > 0 else 0
    print(f"{label:<20} {m1_mean:>15.4f} {m2_mean:>15.4f} {diff:>+12.4f} {change_pct:>+9.2f}%")


def _print_3d_comparison(model1_stats, model2_stats, model1_name, model2_name):
    """Print a per-class table comparing both models' overall 3D errors."""
    print("\n" + "="*80)
    print("3D STATISTICS COMPARISON (COMMON MATCHES ONLY)")
    print("="*80)

    for class_name in sorted(model1_stats.keys()):
        if class_name not in model2_stats:
            continue

        # New format: class stats are nested under 'overall'
        m1 = model1_stats[class_name].get('overall', model1_stats[class_name])
        m2 = model2_stats[class_name].get('overall', model2_stats[class_name])

        print(f"\n{class_name.upper()} (n={m1.get('num_samples', 0):,}):")
        print(f"{'Metric':<20} {model1_name:>15} {model2_name:>15} {'Diff':>12} {'Change %':>10}")
        print("-" * 80)

        for error_type in ['lateral_error', 'longitudinal_error', 'heading_error']:
            if error_type not in m1 or error_type not in m2:
                continue
            _print_metric_row(error_type.replace('_', ' ').title(),
                              m1[error_type]['mean'], m2[error_type]['mean'])

        # Print relaxed heading error if available
        if 'heading_error_relaxed' in m1 and 'heading_error_relaxed' in m2:
            _print_metric_row('Heading Error (Rlx)',
                              m1['heading_error_relaxed']['mean'],
                              m2['heading_error_relaxed']['mean'])

        # Print reversal statistics if available
        if 'reversal_count' in m1 and 'reversal_count' in m2:
            m1_count = m1['reversal_count']
            m1_pct = m1.get('reversal_percentage', 0)
            m2_count = m2['reversal_count']
            m2_pct = m2.get('reversal_percentage', 0)
            print(f"{'Reversals':<20} {m1_count:>11} ({m1_pct:>5.1f}%) {m2_count:>11} ({m2_pct:>5.1f}%)")


def main():
    """Parse CLI arguments, compute common matches and write/print results."""
    parser = argparse.ArgumentParser(
        description='Find common matches between two model evaluation results',
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument('--model1-matches', type=str, required=True,
                        help='Path to model 1 detailed_3d_matches.json file')
    parser.add_argument('--model2-matches', type=str, required=True,
                        help='Path to model 2 detailed_3d_matches.json file')
    parser.add_argument('--output', type=str, default='common_matches.json',
                        help='Output path for common matches JSON file')
    parser.add_argument('--model1-name', type=str, default='model1',
                        help='Name for model 1 (default: model1)')
    parser.add_argument('--model2-name', type=str, default='model2',
                        help='Name for model 2 (default: model2)')

    args = parser.parse_args()

    # Load detailed matches (explicit UTF-8 instead of the locale default)
    print(f"Loading model 1 matches from: {args.model1_matches}")
    with open(args.model1_matches, 'r', encoding='utf-8') as f:
        model1_matches = json.load(f)

    print(f"Loading model 2 matches from: {args.model2_matches}")
    with open(args.model2_matches, 'r', encoding='utf-8') as f:
        model2_matches = json.load(f)

    # Find common matches
    print("\nFinding common matches...")
    common_matches, stats = find_common_matches(model1_matches, model2_matches)

    # Print statistics
    print_statistics(stats, args.model1_name, args.model2_name)

    # Recompute 3D stats for common matches. The literal 'model1'/'model2'
    # select the index keys stored in common_matches, not the display names.
    print("\nRecomputing 3D statistics for common matches...")
    model1_stats = recompute_3d_stats_from_common_matches(model1_matches, common_matches, 'model1')
    model2_stats = recompute_3d_stats_from_common_matches(model2_matches, common_matches, 'model2')

    # Prepare output
    output_data = {
        'match_statistics': stats,
        'common_matches': common_matches,
        'model1_3d_stats': model1_stats,
        'model2_3d_stats': model2_stats,
        'model_names': {
            'model1': args.model1_name,
            'model2': args.model2_name
        }
    }

    # Save output
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=2)

    print(f"\n✓ Common matches saved to: {output_path}")

    # Print 3D stats comparison
    _print_3d_comparison(model1_stats, model2_stats, args.model1_name, args.model2_name)


if __name__ == '__main__':
    main()