1046 lines
52 KiB
Python
Executable File
1046 lines
52 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Model Evaluation Comparison Tool
|
|
|
|
This script compares evaluation results from two different models and generates
|
|
comprehensive comparison reports including:
|
|
- Overall 2D and 3D metrics comparison
|
|
- Per-class performance comparison
|
|
- Distance-range based 3D metrics comparison
|
|
- Per-case performance comparison
|
|
- Statistical significance tests
|
|
- Visualization plots
|
|
|
|
Usage:
|
|
python eval_tools/compare_models.py \
|
|
--model1 eval_results/model1/evaluation_report.json \
|
|
--model2 eval_results/model2/evaluation_report.json \
|
|
--output-dir comparison_results \
|
|
--model1-name "Model-A" \
|
|
--model2-name "Model-B"
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import re
|
|
import sys
|
|
from pathlib import Path
|
|
import numpy as np
|
|
from collections import defaultdict
|
|
|
|
# Allow importing class_config from the eval_tools root
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
from class_config import REPORT_3D_CLASS_KEYS
|
|
|
|
# Allow importing from the same directory
|
|
sys.path.insert(0, str(Path(__file__).parent))
|
|
try:
|
|
from find_common_matches import recompute_3d_stats_from_common_matches as _recompute_range_stats
|
|
except ImportError:
|
|
_recompute_range_stats = None
|
|
|
|
|
|
class ModelComparator:
|
|
"""Compare evaluation results from two models."""
|
|
|
|
def __init__(self, model1_report, model2_report, model1_name="Model-1", model2_name="Model-2",
|
|
common_matches_data=None,
|
|
model1_detailed_path=None, model2_detailed_path=None):
|
|
"""
|
|
Initialize comparator.
|
|
|
|
Args:
|
|
model1_report: dict, evaluation report for model 1
|
|
model2_report: dict, evaluation report for model 2
|
|
model1_name: str, display name for model 1
|
|
model2_name: str, display name for model 2
|
|
common_matches_data: dict, optional data from find_common_matches.py;
|
|
if provided, 3D comparison will use common matches only
|
|
model1_detailed_path: str, path to model 1 detailed_3d_matches.json (optional
|
|
fallback when common_matches.json lacks per-range data)
|
|
model2_detailed_path: str, path to model 2 detailed_3d_matches.json (optional)
|
|
"""
|
|
self.model1_report = model1_report
|
|
self.model2_report = model2_report
|
|
self.model1_name = model1_name
|
|
self.model2_name = model2_name
|
|
self.common_matches_data = common_matches_data
|
|
self.model1_detailed_path = model1_detailed_path
|
|
self.model2_detailed_path = model2_detailed_path
|
|
|
|
self.comparison_results = {}
|
|
|
|
def compare_2d_metrics(self):
|
|
"""Compare 2D detection metrics."""
|
|
print("\n" + "="*80)
|
|
print("Comparing 2D Detection Metrics")
|
|
print("="*80)
|
|
|
|
comparison = {
|
|
'overall': {},
|
|
'per_class': {}
|
|
}
|
|
|
|
# Overall comparison
|
|
m1_overall = self.model1_report['2d_evaluation']['overall']
|
|
m2_overall = self.model2_report['2d_evaluation']['overall']
|
|
|
|
for metric in ['precision', 'recall', 'f1_score', 'map']:
|
|
comparison['overall'][metric] = {
|
|
self.model1_name: m1_overall[metric],
|
|
self.model2_name: m2_overall[metric],
|
|
'diff': m2_overall[metric] - m1_overall[metric],
|
|
'relative_change_%': ((m2_overall[metric] - m1_overall[metric]) / m1_overall[metric] * 100) if m1_overall[metric] > 0 else 0
|
|
}
|
|
|
|
# Per-class comparison
|
|
m1_classes = self.model1_report['2d_evaluation']['per_class']
|
|
m2_classes = self.model2_report['2d_evaluation']['per_class']
|
|
|
|
for class_name in m1_classes.keys():
|
|
if class_name not in m2_classes:
|
|
continue
|
|
|
|
comparison['per_class'][class_name] = {}
|
|
for metric in ['precision', 'recall', 'f1_score', 'ap']:
|
|
m1_val = m1_classes[class_name][metric]
|
|
m2_val = m2_classes[class_name][metric]
|
|
|
|
comparison['per_class'][class_name][metric] = {
|
|
self.model1_name: m1_val,
|
|
self.model2_name: m2_val,
|
|
'diff': m2_val - m1_val,
|
|
'relative_change_%': ((m2_val - m1_val) / m1_val * 100) if m1_val > 0 else 0
|
|
}
|
|
|
|
self.comparison_results['2d_metrics'] = comparison
|
|
return comparison
|
|
|
|
def _compare_range_bucket(self, m1_range, m2_range):
|
|
"""Compare a single distance-range bucket between two models.
|
|
|
|
m1_range / m2_range are dicts like those inside 3d_evaluation[class][range_key].
|
|
Returns a comparison dict (same shape as used by generate_text_report).
|
|
"""
|
|
if m1_range.get('num_samples', 0) == 0 or m2_range.get('num_samples', 0) == 0:
|
|
return None
|
|
|
|
cmp = {}
|
|
for error_type in ['lateral_error', 'longitudinal_error',
|
|
'longitudinal_relative_error', 'heading_error',
|
|
'heading_error_relaxed']:
|
|
if error_type not in m1_range or error_type not in m2_range:
|
|
continue
|
|
m1_mean = m1_range[error_type]['mean']
|
|
m2_mean = m2_range[error_type]['mean']
|
|
cmp[error_type] = {
|
|
self.model1_name: {
|
|
'mean': m1_mean,
|
|
'std': m1_range[error_type]['std'],
|
|
'samples': m1_range['num_samples']
|
|
},
|
|
self.model2_name: {
|
|
'mean': m2_mean,
|
|
'std': m2_range[error_type]['std'],
|
|
'samples': m2_range['num_samples']
|
|
},
|
|
'diff': m2_mean - m1_mean,
|
|
'relative_change_%': ((m2_mean - m1_mean) / m1_mean * 100) if m1_mean > 0 else 0,
|
|
'improvement': m2_mean < m1_mean
|
|
}
|
|
return cmp if cmp else None
|
|
|
|
def _compare_3d_metrics_common_matches(self):
|
|
"""Compare 3D metrics using only common matches."""
|
|
comparison = {}
|
|
|
|
# Get recomputed stats from common_matches_data
|
|
m1_stats = self.common_matches_data.get('model1_3d_stats', {})
|
|
m2_stats = self.common_matches_data.get('model2_3d_stats', {})
|
|
match_stats = self.common_matches_data.get('match_statistics', {})
|
|
|
|
# Detect whether the stats are in the new per-range format
|
|
# New format: m1_stats[class]['overall'] exists
|
|
# Old format: m1_stats[class]['lateral_error'] directly (no 'overall' key)
|
|
def _has_range_format(stats):
|
|
for cls_data in stats.values():
|
|
return 'overall' in cls_data
|
|
return False
|
|
|
|
use_range_format = _has_range_format(m1_stats) and _has_range_format(m2_stats)
|
|
|
|
# ── Upgrade old-format common_matches.json to per-range format on-the-fly ─────
|
|
if not use_range_format and _recompute_range_stats is not None:
|
|
print(" common_matches.json has legacy flat format — computing per-range stats "
|
|
"from detailed_3d_matches.json...")
|
|
m1_det = self.model1_detailed_path
|
|
m2_det = self.model2_detailed_path
|
|
if m1_det and m2_det and Path(m1_det).exists() and Path(m2_det).exists():
|
|
with open(m1_det, 'r') as f:
|
|
m1_detailed = json.load(f)
|
|
with open(m2_det, 'r') as f:
|
|
m2_detailed = json.load(f)
|
|
raw_common = self.common_matches_data.get('common_matches', {})
|
|
m1_stats = _recompute_range_stats(m1_detailed, raw_common, 'model1')
|
|
m2_stats = _recompute_range_stats(m2_detailed, raw_common, 'model2')
|
|
use_range_format = True
|
|
print(" ✓ Per-range stats computed from detailed match files.")
|
|
else:
|
|
print(" WARNING: detailed_3d_matches.json paths not available or not found; "
|
|
"distance-range sections will be empty.")
|
|
print(f" model1 path: {m1_det}")
|
|
print(f" model2 path: {m2_det}")
|
|
|
|
# Print match statistics summary
|
|
print(f"\nMatch Statistics:")
|
|
print(f" {self.model1_name} Total Matches: {match_stats.get('model1_total', 0):,}")
|
|
print(f" {self.model2_name} Total Matches: {match_stats.get('model2_total', 0):,}")
|
|
print(f" Common Matches: {match_stats.get('common', 0):,} "
|
|
f"({match_stats.get('common_percentage_of_model1', 0):.1f}% of {self.model1_name})")
|
|
print(f" {self.model1_name} Unique: {match_stats.get('model1_unique', 0):,}")
|
|
print(f" {self.model2_name} Unique: {match_stats.get('model2_unique', 0):,}")
|
|
|
|
for class_name in m1_stats.keys():
|
|
if class_name not in m2_stats:
|
|
continue
|
|
|
|
m1_class = m1_stats[class_name]
|
|
m2_class = m2_stats[class_name]
|
|
|
|
if use_range_format:
|
|
# New format: each class is a dict of range_key -> stats_dict
|
|
comparison[class_name] = {
|
|
'common_samples': m1_class.get('overall', {}).get('num_samples', 0),
|
|
'match_info': match_stats.get('per_class', {}).get(class_name, {})
|
|
}
|
|
for range_key in m1_class.keys():
|
|
if range_key not in m2_class:
|
|
continue
|
|
if range_key == 'match_info':
|
|
continue
|
|
bucket_cmp = self._compare_range_bucket(m1_class[range_key],
|
|
m2_class[range_key])
|
|
if bucket_cmp is not None:
|
|
comparison[class_name][range_key] = bucket_cmp
|
|
# Add reversal info to overall if available
|
|
ov1 = m1_class.get('overall', {})
|
|
ov2 = m2_class.get('overall', {})
|
|
if 'reversal_count' in ov1 and 'reversal_count' in ov2:
|
|
comparison[class_name].setdefault('overall', {})
|
|
comparison[class_name]['overall']['reversal_info'] = {
|
|
self.model1_name: {
|
|
'count': ov1['reversal_count'],
|
|
'percentage': ov1.get('reversal_percentage', 0)
|
|
},
|
|
self.model2_name: {
|
|
'count': ov2['reversal_count'],
|
|
'percentage': ov2.get('reversal_percentage', 0)
|
|
}
|
|
}
|
|
else:
|
|
# Legacy format: class stats are flat (no 'overall' key)
|
|
m1_overall = m1_class
|
|
m2_overall = m2_class
|
|
comparison[class_name] = {
|
|
'overall': {},
|
|
'common_samples': m1_overall.get('num_samples', 0),
|
|
'match_info': match_stats.get('per_class', {}).get(class_name, {})
|
|
}
|
|
for error_type in ['lateral_error', 'longitudinal_error', 'heading_error']:
|
|
if error_type not in m1_overall or error_type not in m2_overall:
|
|
continue
|
|
m1_mean = m1_overall[error_type]['mean']
|
|
m2_mean = m2_overall[error_type]['mean']
|
|
comparison[class_name]['overall'][error_type] = {
|
|
self.model1_name: {
|
|
'mean': m1_mean,
|
|
'std': m1_overall[error_type]['std'],
|
|
'samples': m1_overall.get('num_samples', 0)
|
|
},
|
|
self.model2_name: {
|
|
'mean': m2_mean,
|
|
'std': m2_overall[error_type]['std'],
|
|
'samples': m2_overall.get('num_samples', 0)
|
|
},
|
|
'diff': m2_mean - m1_mean,
|
|
'relative_change_%': ((m2_mean - m1_mean) / m1_mean * 100) if m1_mean > 0 else 0,
|
|
'improvement': m2_mean < m1_mean
|
|
}
|
|
for opt_type in ['longitudinal_relative_error', 'heading_error_relaxed']:
|
|
if opt_type in m1_overall and opt_type in m2_overall:
|
|
m1v = m1_overall[opt_type]['mean']
|
|
m2v = m2_overall[opt_type]['mean']
|
|
comparison[class_name]['overall'][opt_type] = {
|
|
self.model1_name: {
|
|
'mean': m1v, 'std': m1_overall[opt_type]['std'],
|
|
'samples': m1_overall.get('num_samples', 0)
|
|
},
|
|
self.model2_name: {
|
|
'mean': m2v, 'std': m2_overall[opt_type]['std'],
|
|
'samples': m2_overall.get('num_samples', 0)
|
|
},
|
|
'diff': m2v - m1v,
|
|
'relative_change_%': ((m2v - m1v) / m1v * 100) if m1v > 0 else 0,
|
|
'improvement': m2v < m1v
|
|
}
|
|
if 'reversal_count' in m1_overall and 'reversal_count' in m2_overall:
|
|
comparison[class_name]['overall']['reversal_info'] = {
|
|
self.model1_name: {
|
|
'count': m1_overall['reversal_count'],
|
|
'percentage': m1_overall.get('reversal_percentage', 0)
|
|
},
|
|
self.model2_name: {
|
|
'count': m2_overall['reversal_count'],
|
|
'percentage': m2_overall.get('reversal_percentage', 0)
|
|
}
|
|
}
|
|
|
|
self.comparison_results['3d_metrics'] = comparison
|
|
self.comparison_results['match_statistics'] = match_stats
|
|
return comparison
|
|
|
|
def compare_3d_metrics(self):
|
|
"""Compare 3D detection metrics."""
|
|
print("\n" + "="*80)
|
|
if self.common_matches_data:
|
|
print("Comparing 3D Detection Metrics (COMMON MATCHES ONLY)")
|
|
else:
|
|
print("Comparing 3D Detection Metrics")
|
|
print("="*80)
|
|
|
|
# If using common matches, use precomputed stats
|
|
if self.common_matches_data:
|
|
return self._compare_3d_metrics_common_matches()
|
|
|
|
comparison = {}
|
|
|
|
m1_3d = self.model1_report.get('3d_evaluation', {})
|
|
m2_3d = self.model2_report.get('3d_evaluation', {})
|
|
|
|
for class_name in m1_3d.keys():
|
|
if class_name not in m2_3d:
|
|
continue
|
|
|
|
comparison[class_name] = {}
|
|
|
|
# Check if distance-range based
|
|
if 'overall' in m1_3d[class_name]:
|
|
# Compare overall and per-range
|
|
for range_key in m1_3d[class_name].keys():
|
|
if range_key not in m2_3d[class_name]:
|
|
continue
|
|
|
|
m1_range = m1_3d[class_name][range_key]
|
|
m2_range = m2_3d[class_name][range_key]
|
|
|
|
if m1_range['num_samples'] == 0 and m2_range['num_samples'] == 0:
|
|
continue
|
|
|
|
comparison[class_name][range_key] = {}
|
|
|
|
for error_type in ['lateral_error', 'longitudinal_error', 'heading_error']:
|
|
if m1_range['num_samples'] > 0 and m2_range['num_samples'] > 0:
|
|
m1_mean = m1_range[error_type]['mean']
|
|
m2_mean = m2_range[error_type]['mean']
|
|
|
|
comparison[class_name][range_key][error_type] = {
|
|
self.model1_name: {
|
|
'mean': m1_mean,
|
|
'std': m1_range[error_type]['std'],
|
|
'samples': m1_range['num_samples']
|
|
},
|
|
self.model2_name: {
|
|
'mean': m2_mean,
|
|
'std': m2_range[error_type]['std'],
|
|
'samples': m2_range['num_samples']
|
|
},
|
|
'diff': m2_mean - m1_mean,
|
|
'relative_change_%': ((m2_mean - m1_mean) / m1_mean * 100) if m1_mean > 0 else 0,
|
|
'improvement': m1_mean > m2_mean # Lower error is better
|
|
}
|
|
if (m1_range['num_samples'] > 0 and m2_range['num_samples'] > 0
|
|
and 'longitudinal_relative_error' in m1_range
|
|
and 'longitudinal_relative_error' in m2_range):
|
|
m1_mean_lr = m1_range['longitudinal_relative_error']['mean']
|
|
m2_mean_lr = m2_range['longitudinal_relative_error']['mean']
|
|
comparison[class_name][range_key]['longitudinal_relative_error'] = {
|
|
self.model1_name: {
|
|
'mean': m1_mean_lr,
|
|
'std': m1_range['longitudinal_relative_error']['std'],
|
|
'samples': m1_range['num_samples']
|
|
},
|
|
self.model2_name: {
|
|
'mean': m2_mean_lr,
|
|
'std': m2_range['longitudinal_relative_error']['std'],
|
|
'samples': m2_range['num_samples']
|
|
},
|
|
'diff': m2_mean_lr - m1_mean_lr,
|
|
'relative_change_%': ((m2_mean_lr - m1_mean_lr) / m1_mean_lr * 100) if m1_mean_lr > 0 else 0,
|
|
'improvement': m2_mean_lr < m1_mean_lr
|
|
}
|
|
if (m1_range['num_samples'] > 0 and m2_range['num_samples'] > 0
|
|
and 'heading_error_relaxed' in m1_range
|
|
and 'heading_error_relaxed' in m2_range):
|
|
m1_mean_hr = m1_range['heading_error_relaxed']['mean']
|
|
m2_mean_hr = m2_range['heading_error_relaxed']['mean']
|
|
comparison[class_name][range_key]['heading_error_relaxed'] = {
|
|
self.model1_name: {
|
|
'mean': m1_mean_hr,
|
|
'std': m1_range['heading_error_relaxed']['std'],
|
|
'samples': m1_range['num_samples']
|
|
},
|
|
self.model2_name: {
|
|
'mean': m2_mean_hr,
|
|
'std': m2_range['heading_error_relaxed']['std'],
|
|
'samples': m2_range['num_samples']
|
|
},
|
|
'diff': m2_mean_hr - m1_mean_hr,
|
|
'relative_change_%': ((m2_mean_hr - m1_mean_hr) / m1_mean_hr * 100) if m1_mean_hr > 0 else 0,
|
|
'improvement': m2_mean_hr < m1_mean_hr
|
|
}
|
|
else:
|
|
# Legacy format
|
|
m1_class = m1_3d[class_name]
|
|
m2_class = m2_3d[class_name]
|
|
|
|
if m1_class['num_samples'] == 0 and m2_class['num_samples'] == 0:
|
|
continue
|
|
|
|
comparison[class_name]['overall'] = {}
|
|
|
|
for error_type in ['lateral_error', 'longitudinal_error', 'heading_error']:
|
|
if m1_class['num_samples'] > 0 and m2_class['num_samples'] > 0:
|
|
m1_mean = m1_class[error_type]['mean']
|
|
m2_mean = m2_class[error_type]['mean']
|
|
|
|
comparison[class_name]['overall'][error_type] = {
|
|
self.model1_name: {
|
|
'mean': m1_mean,
|
|
'std': m1_class[error_type]['std'],
|
|
'samples': m1_class['num_samples']
|
|
},
|
|
self.model2_name: {
|
|
'mean': m2_mean,
|
|
'std': m2_class[error_type]['std'],
|
|
'samples': m2_class['num_samples']
|
|
},
|
|
'diff': m2_mean - m1_mean,
|
|
'relative_change_%': ((m2_mean - m1_mean) / m1_mean * 100) if m1_mean > 0 else 0,
|
|
'improvement': m1_mean > m2_mean
|
|
}
|
|
if (m1_class['num_samples'] > 0 and m2_class['num_samples'] > 0
|
|
and 'longitudinal_relative_error' in m1_class
|
|
and 'longitudinal_relative_error' in m2_class):
|
|
m1_mean_lr = m1_class['longitudinal_relative_error']['mean']
|
|
m2_mean_lr = m2_class['longitudinal_relative_error']['mean']
|
|
comparison[class_name]['overall']['longitudinal_relative_error'] = {
|
|
self.model1_name: {
|
|
'mean': m1_mean_lr,
|
|
'std': m1_class['longitudinal_relative_error']['std'],
|
|
'samples': m1_class['num_samples']
|
|
},
|
|
self.model2_name: {
|
|
'mean': m2_mean_lr,
|
|
'std': m2_class['longitudinal_relative_error']['std'],
|
|
'samples': m2_class['num_samples']
|
|
},
|
|
'diff': m2_mean_lr - m1_mean_lr,
|
|
'relative_change_%': ((m2_mean_lr - m1_mean_lr) / m1_mean_lr * 100) if m1_mean_lr > 0 else 0,
|
|
'improvement': m2_mean_lr < m1_mean_lr
|
|
}
|
|
|
|
self.comparison_results['3d_metrics'] = comparison
|
|
return comparison
|
|
|
|
def compare_per_case(self):
|
|
"""Compare per-case performance."""
|
|
print("\n" + "="*80)
|
|
print("Comparing Per-Case Performance")
|
|
print("="*80)
|
|
|
|
comparison = {
|
|
'2d': {},
|
|
'3d': {}
|
|
}
|
|
|
|
# Get common cases
|
|
m1_cases_2d = set(self.model1_report.get('per_case_2d', {}).keys())
|
|
m2_cases_2d = set(self.model2_report.get('per_case_2d', {}).keys())
|
|
common_cases = m1_cases_2d.intersection(m2_cases_2d)
|
|
|
|
print(f"Found {len(common_cases)} common cases")
|
|
|
|
# 2D per-case comparison
|
|
for case_name in sorted(common_cases):
|
|
m1_case = self.model1_report['per_case_2d'][case_name]
|
|
m2_case = self.model2_report['per_case_2d'][case_name]
|
|
|
|
comparison['2d'][case_name] = {}
|
|
|
|
# Overall metrics
|
|
comparison['2d'][case_name]['overall'] = {}
|
|
for metric in ['precision', 'recall', 'f1_score', 'map']:
|
|
m1_val = m1_case['overall'][metric]
|
|
m2_val = m2_case['overall'][metric]
|
|
|
|
comparison['2d'][case_name]['overall'][metric] = {
|
|
self.model1_name: m1_val,
|
|
self.model2_name: m2_val,
|
|
'diff': m2_val - m1_val,
|
|
'relative_change_%': ((m2_val - m1_val) / m1_val * 100) if m1_val > 0 else 0
|
|
}
|
|
|
|
# 3D per-case comparison
|
|
m1_cases_3d = set(self.model1_report.get('per_case_3d', {}).keys())
|
|
m2_cases_3d = set(self.model2_report.get('per_case_3d', {}).keys())
|
|
common_cases_3d = m1_cases_3d.intersection(m2_cases_3d)
|
|
|
|
for case_name in sorted(common_cases_3d):
|
|
m1_case = self.model1_report['per_case_3d'][case_name]
|
|
m2_case = self.model2_report['per_case_3d'][case_name]
|
|
|
|
comparison['3d'][case_name] = {}
|
|
|
|
# Compare 3D classes
|
|
for class_name in REPORT_3D_CLASS_KEYS:
|
|
if class_name not in m1_case or class_name not in m2_case:
|
|
continue
|
|
|
|
comparison['3d'][case_name][class_name] = {}
|
|
|
|
# Get overall metrics
|
|
if 'overall' in m1_case[class_name]:
|
|
m1_overall = m1_case[class_name]['overall']
|
|
m2_overall = m2_case[class_name]['overall']
|
|
else:
|
|
m1_overall = m1_case[class_name]
|
|
m2_overall = m2_case[class_name]
|
|
|
|
if m1_overall['num_samples'] == 0 or m2_overall['num_samples'] == 0:
|
|
continue
|
|
|
|
for error_type in ['lateral_error', 'longitudinal_error']:
|
|
m1_mean = m1_overall[error_type]['mean']
|
|
m2_mean = m2_overall[error_type]['mean']
|
|
|
|
comparison['3d'][case_name][class_name][error_type] = {
|
|
self.model1_name: m1_mean,
|
|
self.model2_name: m2_mean,
|
|
'diff': m2_mean - m1_mean,
|
|
'improvement': m1_mean > m2_mean
|
|
}
|
|
|
|
self.comparison_results['per_case'] = comparison
|
|
return comparison
|
|
|
|
def generate_summary_stats(self):
|
|
"""Generate summary statistics."""
|
|
print("\n" + "="*80)
|
|
print("Generating Summary Statistics")
|
|
print("="*80)
|
|
|
|
summary = {
|
|
'2d': {
|
|
'ap': {
|
|
'wins': 0, # Number of classes where model2 is better
|
|
'losses': 0,
|
|
'ties': 0
|
|
},
|
|
'f1_score': {
|
|
'wins': 0,
|
|
'losses': 0,
|
|
'ties': 0
|
|
}
|
|
},
|
|
'3d': {
|
|
'lateral': {
|
|
'wins': 0,
|
|
'losses': 0,
|
|
'ties': 0
|
|
},
|
|
'longitudinal': {
|
|
'wins': 0,
|
|
'losses': 0,
|
|
'ties': 0
|
|
},
|
|
'heading': {
|
|
'wins': 0,
|
|
'losses': 0,
|
|
'ties': 0
|
|
}
|
|
}
|
|
}
|
|
|
|
# Count 2D wins/losses based on AP
|
|
if '2d_metrics' in self.comparison_results:
|
|
for class_name, metrics in self.comparison_results['2d_metrics']['per_class'].items():
|
|
if 'ap' in metrics:
|
|
diff = metrics['ap']['diff']
|
|
if abs(diff) < 0.01: # Consider < 1% as tie
|
|
summary['2d']['ap']['ties'] += 1
|
|
elif diff > 0:
|
|
summary['2d']['ap']['wins'] += 1
|
|
else:
|
|
summary['2d']['ap']['losses'] += 1
|
|
|
|
# Count based on F1 Score
|
|
if 'f1_score' in metrics:
|
|
diff = metrics['f1_score']['diff']
|
|
if abs(diff) < 0.01: # Consider < 1% as tie
|
|
summary['2d']['f1_score']['ties'] += 1
|
|
elif diff > 0:
|
|
summary['2d']['f1_score']['wins'] += 1
|
|
else:
|
|
summary['2d']['f1_score']['losses'] += 1
|
|
|
|
# Count 3D wins/losses based on all error types
|
|
if '3d_metrics' in self.comparison_results:
|
|
for class_name, ranges in self.comparison_results['3d_metrics'].items():
|
|
for range_key, metrics in ranges.items():
|
|
# Skip non-metric fields (like 'common_samples', 'match_info')
|
|
if not isinstance(metrics, dict):
|
|
continue
|
|
if 'lateral_error' not in metrics:
|
|
continue
|
|
|
|
# Count lateral error
|
|
if metrics['lateral_error']['improvement']:
|
|
summary['3d']['lateral']['wins'] += 1
|
|
else:
|
|
summary['3d']['lateral']['losses'] += 1
|
|
|
|
# Count longitudinal error
|
|
if 'longitudinal_error' in metrics:
|
|
if metrics['longitudinal_error']['improvement']:
|
|
summary['3d']['longitudinal']['wins'] += 1
|
|
else:
|
|
summary['3d']['longitudinal']['losses'] += 1
|
|
|
|
# Count heading error
|
|
if 'heading_error' in metrics:
|
|
if metrics['heading_error']['improvement']:
|
|
summary['3d']['heading']['wins'] += 1
|
|
else:
|
|
summary['3d']['heading']['losses'] += 1
|
|
|
|
self.comparison_results['summary'] = summary
|
|
return summary
|
|
|
|
def generate_text_report(self, output_file):
|
|
"""Generate human-readable text report."""
|
|
print(f"\nGenerating text report: {output_file}")
|
|
|
|
with open(output_file, 'w') as f:
|
|
f.write("="*80 + "\n")
|
|
f.write("MODEL COMPARISON REPORT\n")
|
|
f.write("="*80 + "\n\n")
|
|
|
|
f.write(f"Model 1: {self.model1_name}\n")
|
|
f.write(f"Model 2: {self.model2_name}\n\n")
|
|
|
|
# 2D Overall Comparison
|
|
if '2d_metrics' in self.comparison_results:
|
|
f.write("\n" + "="*80 + "\n")
|
|
f.write("2D DETECTION METRICS - OVERALL COMPARISON\n")
|
|
f.write("="*80 + "\n\n")
|
|
|
|
overall = self.comparison_results['2d_metrics']['overall']
|
|
f.write(f"{'Metric':<15} {self.model1_name:<12} {self.model2_name:<12} {'Diff':<12} {'Change %':<12}\n")
|
|
f.write("-"*80 + "\n")
|
|
|
|
for metric, values in overall.items():
|
|
f.write(f"{metric.upper():<15} "
|
|
f"{values[self.model1_name]:<12.4f} "
|
|
f"{values[self.model2_name]:<12.4f} "
|
|
f"{values['diff']:>+11.4f} "
|
|
f"{values['relative_change_%']:>+11.2f}%\n")
|
|
|
|
# Per-class 2D comparison - detailed metrics
|
|
f.write("\n" + "="*80 + "\n")
|
|
f.write("2D DETECTION METRICS - PER-CLASS COMPARISON\n")
|
|
f.write("="*80 + "\n\n")
|
|
|
|
# Precision, Recall, F1 Score comparison
|
|
f.write("Precision / Recall / F1 Score:\n")
|
|
f.write(f"{'Class':<15} {'Metric':<12} {self.model1_name:<12} {self.model2_name:<12} {'Diff':<12} {'Change %':<12}\n")
|
|
f.write("-"*100 + "\n")
|
|
|
|
per_class = self.comparison_results['2d_metrics']['per_class']
|
|
for class_name in sorted(per_class.keys()):
|
|
for metric_name in ['precision', 'recall', 'f1_score']:
|
|
metric_data = per_class[class_name][metric_name]
|
|
f.write(f"{class_name:<15} "
|
|
f"{metric_name:<12} "
|
|
f"{metric_data[self.model1_name]:<12.4f} "
|
|
f"{metric_data[self.model2_name]:<12.4f} "
|
|
f"{metric_data['diff']:>+11.4f} "
|
|
f"{metric_data['relative_change_%']:>+11.2f}%\n")
|
|
|
|
# AP comparison
|
|
f.write("\nAverage Precision (AP):\n")
|
|
f.write(f"{'Class':<15} {self.model1_name:<12} {self.model2_name:<12} {'Diff':<12} {'Change %':<12}\n")
|
|
f.write("-"*80 + "\n")
|
|
|
|
for class_name in sorted(per_class.keys()):
|
|
ap_data = per_class[class_name]['ap']
|
|
f.write(f"{class_name:<15} "
|
|
f"{ap_data[self.model1_name]:<12.4f} "
|
|
f"{ap_data[self.model2_name]:<12.4f} "
|
|
f"{ap_data['diff']:>+11.4f} "
|
|
f"{ap_data['relative_change_%']:>+11.2f}%\n")
|
|
|
|
# 3D Comparison
|
|
if '3d_metrics' in self.comparison_results:
|
|
f.write("\n" + "="*80 + "\n")
|
|
f.write("3D DETECTION METRICS COMPARISON\n")
|
|
f.write("="*80 + "\n\n")
|
|
|
|
# First, write table format summary for overall metrics
|
|
f.write("OVERALL 3D METRICS SUMMARY (by class)\n")
|
|
f.write("-"*80 + "\n\n")
|
|
|
|
for class_name, ranges in sorted(self.comparison_results['3d_metrics'].items()):
|
|
if 'overall' not in ranges:
|
|
continue
|
|
|
|
overall = ranges['overall']
|
|
if 'lateral_error' not in overall:
|
|
continue
|
|
|
|
f.write(f"{class_name.upper()}:\n")
|
|
f.write(f"{'Metric':<20} {self.model1_name:<15} {self.model2_name:<15} {'Diff':<15} {'Change %':<12} {'Result':<10}\n")
|
|
f.write("-"*100 + "\n")
|
|
|
|
for error_type, display_name in [('lateral_error', 'Lateral (m)'),
|
|
('longitudinal_error', 'Longitudinal (m)'),
|
|
('longitudinal_relative_error', 'Long Relative'),
|
|
('heading_error', 'Heading (rad)')]:
|
|
if error_type in overall:
|
|
data = overall[error_type]
|
|
m1_str = f"{data[self.model1_name]['mean']:.4f}"
|
|
m2_str = f"{data[self.model2_name]['mean']:.4f}"
|
|
diff_str = f"{data['diff']:+.4f}"
|
|
change_str = f"{data['relative_change_%']:+.2f}%"
|
|
result_str = "✓ BETTER" if data['improvement'] else "✗ WORSE"
|
|
|
|
f.write(f"{display_name:<20} {m1_str:<15} {m2_str:<15} {diff_str:<15} {change_str:<12} {result_str:<10}\n")
|
|
|
|
# Add relaxed heading error if available
|
|
if 'heading_error_relaxed' in overall:
|
|
data = overall['heading_error_relaxed']
|
|
m1_str = f"{data[self.model1_name]['mean']:.4f}"
|
|
m2_str = f"{data[self.model2_name]['mean']:.4f}"
|
|
diff_str = f"{data['diff']:+.4f}"
|
|
change_str = f"{data['relative_change_%']:+.2f}%"
|
|
result_str = "✓ BETTER" if data['improvement'] else "✗ WORSE"
|
|
|
|
f.write(f"{'Heading Relaxed (rad)':<20} {m1_str:<15} {m2_str:<15} {diff_str:<15} {change_str:<12} {result_str:<10}\n")
|
|
|
|
# Add reversal statistics if available
|
|
if 'reversal_info' in overall:
|
|
rev_info = overall['reversal_info']
|
|
m1_rev = f"{rev_info[self.model1_name]['count']} ({rev_info[self.model1_name]['percentage']:.1f}%)"
|
|
m2_rev = f"{rev_info[self.model2_name]['count']} ({rev_info[self.model2_name]['percentage']:.1f}%)"
|
|
f.write(f"{'Reversal Cases':<20} {m1_rev:<15} {m2_rev:<15}\n")
|
|
|
|
f.write(f"Samples: {self.model1_name}={overall['lateral_error'][self.model1_name]['samples']}, "
|
|
f"{self.model2_name}={overall['lateral_error'][self.model2_name]['samples']}\n")
|
|
f.write("\n")
|
|
|
|
# Helper: extract numeric start value from range keys like
|
|
# "long_0-10m", "long_100-999m", "lat_-50--40m", "lat_-10-0m", "overall"
|
|
def _range_sort_key(range_key):
|
|
if range_key == 'overall':
|
|
return float('inf')
|
|
try:
|
|
# Strip prefix (long_ / lat_) and trailing 'm'
|
|
stripped = range_key
|
|
for prefix in ('long_', 'lat_'):
|
|
if stripped.startswith(prefix):
|
|
stripped = stripped[len(prefix):]
|
|
break
|
|
stripped = stripped.rstrip('m')
|
|
# Find the separator dash: the '-' immediately preceded by a digit.
|
|
# This correctly handles negatives like "-50--40" (separator after '0')
|
|
# as well as "-10-0", "0-10", "100-999", etc.
|
|
m = re.search(r'(?<=\d)-', stripped)
|
|
if m:
|
|
return float(stripped[:m.start()])
|
|
return float('inf')
|
|
except (ValueError, IndexError):
|
|
return float('inf')
|
|
|
|
# Helper: write one range block
|
|
def _write_range_block(f, range_key, metrics):
|
|
f.write(f"\n [{range_key}]:\n")
|
|
f.write(f" {'Metric':<22} {self.model1_name:<14} {self.model2_name:<14} {'Diff':<14} {'Change %':<11} Result\n")
|
|
f.write(" " + "-"*82 + "\n")
|
|
for error_type, display_name in [
|
|
('lateral_error', 'Lateral (m)'),
|
|
('longitudinal_error', 'Longitudinal (m)'),
|
|
('longitudinal_relative_error','Long Relative'),
|
|
('heading_error', 'Heading (rad)'),
|
|
('heading_error_relaxed', 'Head Relaxed (rad)'),
|
|
]:
|
|
if error_type not in metrics:
|
|
continue
|
|
data = metrics[error_type]
|
|
m1_str = f"{data[self.model1_name]['mean']:.4f}"
|
|
m2_str = f"{data[self.model2_name]['mean']:.4f}"
|
|
diff_str = f"{data['diff']:+.4f}"
|
|
change_str= f"{data['relative_change_%']:+.2f}%"
|
|
result = "✓" if data['improvement'] else "✗"
|
|
f.write(f" {display_name:<22} {m1_str:<14} {m2_str:<14} {diff_str:<14} {change_str:<11} {result}\n")
|
|
m1_n = metrics.get('lateral_error', metrics.get('longitudinal_error', {})).get(self.model1_name, {}).get('samples', '-')
|
|
m2_n = metrics.get('lateral_error', metrics.get('longitudinal_error', {})).get(self.model2_name, {}).get('samples', '-')
|
|
f.write(f" Samples: {self.model1_name}={m1_n} {self.model2_name}={m2_n}\n")
|
|
|
|
# ── Longitudinal distance ranges ──────────────────────────────
|
|
f.write("\n" + "-"*80 + "\n")
|
|
f.write("3D METRICS BY LONGITUDINAL DISTANCE RANGE (long_*)\n")
|
|
f.write("-"*80 + "\n")
|
|
|
|
for class_name, ranges in sorted(self.comparison_results['3d_metrics'].items()):
|
|
# Collect longitudinal-range entries
|
|
long_items = {k: v for k, v in ranges.items()
|
|
if isinstance(v, dict) and k.startswith('long_') and 'lateral_error' in v}
|
|
if not long_items:
|
|
continue
|
|
|
|
f.write(f"\n{class_name.upper()}:\n")
|
|
f.write("-"*80 + "\n")
|
|
|
|
# Compact summary table: one row per distance range, key metrics only
|
|
col1 = 12
|
|
f.write(f"\n Quick summary (mean errors):\n")
|
|
f.write(f" {'Range':<12} {'Samples':>8} "
|
|
f"{'Lat(m)':>8} {'':>9} "
|
|
f"{'Long(m)':>8} {'':>9} "
|
|
f"{'LongRel':>8} {'':>9} "
|
|
f"{'Head(rad)':>9} {''}\n")
|
|
hdr2 = (f" {'':12} {'':>8} "
|
|
f"{self.model1_name:>8} {self.model2_name:>9} "
|
|
f"{self.model1_name:>8} {self.model2_name:>9} "
|
|
f"{self.model1_name:>8} {self.model2_name:>9} "
|
|
f"{self.model1_name:>9} {self.model2_name}\n")
|
|
f.write(hdr2)
|
|
f.write(" " + "-"*110 + "\n")
|
|
|
|
for rk, metrics in sorted(long_items.items(), key=lambda x: _range_sort_key(x[0])):
|
|
label = rk.replace('long_', '')
|
|
n1 = metrics.get('lateral_error', {}).get(self.model1_name, {}).get('samples', 0)
|
|
lat1 = metrics['lateral_error'][self.model1_name]['mean'] if 'lateral_error' in metrics else float('nan')
|
|
lat2 = metrics['lateral_error'][self.model2_name]['mean'] if 'lateral_error' in metrics else float('nan')
|
|
lon1 = metrics['longitudinal_error'][self.model1_name]['mean'] if 'longitudinal_error' in metrics else float('nan')
|
|
lon2 = metrics['longitudinal_error'][self.model2_name]['mean'] if 'longitudinal_error' in metrics else float('nan')
|
|
lr1 = metrics['longitudinal_relative_error'][self.model1_name]['mean'] if 'longitudinal_relative_error' in metrics else float('nan')
|
|
lr2 = metrics['longitudinal_relative_error'][self.model2_name]['mean'] if 'longitudinal_relative_error' in metrics else float('nan')
|
|
h1 = metrics['heading_error'][self.model1_name]['mean'] if 'heading_error' in metrics else float('nan')
|
|
h2 = metrics['heading_error'][self.model2_name]['mean'] if 'heading_error' in metrics else float('nan')
|
|
lat_mark = "✓" if lat2 < lat1 else "✗"
|
|
lon_mark = "✓" if lon2 < lon1 else "✗"
|
|
lr_mark = "✓" if lr2 < lr1 else "✗"
|
|
head_mark = "✓" if h2 < h1 else "✗"
|
|
f.write(f" {label:<12} {n1:>8} "
|
|
f"{lat1:>8.4f} {lat2:>8.4f}{lat_mark} "
|
|
f"{lon1:>8.4f} {lon2:>8.4f}{lon_mark} "
|
|
f"{lr1:>8.4f} {lr2:>8.4f}{lr_mark} "
|
|
f"{h1:>9.4f} {h2:>8.4f}{head_mark}\n")
|
|
|
|
# Detailed per-range blocks
|
|
f.write(f"\n Detailed breakdown:\n")
|
|
for rk, metrics in sorted(long_items.items(), key=lambda x: _range_sort_key(x[0])):
|
|
_write_range_block(f, rk, metrics)
|
|
|
|
# ── Lateral distance ranges ───────────────────────────────────
|
|
f.write("\n\n" + "-"*80 + "\n")
|
|
f.write("3D METRICS BY LATERAL DISTANCE RANGE (lat_*)\n")
|
|
f.write("-"*80 + "\n")
|
|
|
|
for class_name, ranges in sorted(self.comparison_results['3d_metrics'].items()):
|
|
lat_items = {k: v for k, v in ranges.items()
|
|
if isinstance(v, dict) and k.startswith('lat_') and 'lateral_error' in v}
|
|
if not lat_items:
|
|
continue
|
|
|
|
f.write(f"\n{class_name.upper()}:\n")
|
|
f.write("-"*80 + "\n")
|
|
|
|
f.write(f"\n Quick summary (mean errors):\n")
|
|
f.write(f" {'Range':<14} {'Samples':>8} "
|
|
f"{'Lat(m)':>8} {'':>9} "
|
|
f"{'Long(m)':>8} {'':>9} "
|
|
f"{'LongRel':>8} {'':>9} "
|
|
f"{'Head(rad)':>9} {''}\n")
|
|
hdr2 = (f" {'':14} {'':>8} "
|
|
f"{self.model1_name:>8} {self.model2_name:>9} "
|
|
f"{self.model1_name:>8} {self.model2_name:>9} "
|
|
f"{self.model1_name:>8} {self.model2_name:>9} "
|
|
f"{self.model1_name:>9} {self.model2_name}\n")
|
|
f.write(hdr2)
|
|
f.write(" " + "-"*110 + "\n")
|
|
|
|
for rk, metrics in sorted(lat_items.items(), key=lambda x: _range_sort_key(x[0])):
|
|
label = rk.replace('lat_', '')
|
|
n1 = metrics.get('lateral_error', {}).get(self.model1_name, {}).get('samples', 0)
|
|
lat1 = metrics['lateral_error'][self.model1_name]['mean'] if 'lateral_error' in metrics else float('nan')
|
|
lat2 = metrics['lateral_error'][self.model2_name]['mean'] if 'lateral_error' in metrics else float('nan')
|
|
lon1 = metrics['longitudinal_error'][self.model1_name]['mean'] if 'longitudinal_error' in metrics else float('nan')
|
|
lon2 = metrics['longitudinal_error'][self.model2_name]['mean'] if 'longitudinal_error' in metrics else float('nan')
|
|
lr1 = metrics['longitudinal_relative_error'][self.model1_name]['mean'] if 'longitudinal_relative_error' in metrics else float('nan')
|
|
lr2 = metrics['longitudinal_relative_error'][self.model2_name]['mean'] if 'longitudinal_relative_error' in metrics else float('nan')
|
|
h1 = metrics['heading_error'][self.model1_name]['mean'] if 'heading_error' in metrics else float('nan')
|
|
h2 = metrics['heading_error'][self.model2_name]['mean'] if 'heading_error' in metrics else float('nan')
|
|
lat_mark = "✓" if lat2 < lat1 else "✗"
|
|
lon_mark = "✓" if lon2 < lon1 else "✗"
|
|
lr_mark = "✓" if lr2 < lr1 else "✗"
|
|
head_mark = "✓" if h2 < h1 else "✗"
|
|
f.write(f" {label:<14} {n1:>8} "
|
|
f"{lat1:>8.4f} {lat2:>8.4f}{lat_mark} "
|
|
f"{lon1:>8.4f} {lon2:>8.4f}{lon_mark} "
|
|
f"{lr1:>8.4f} {lr2:>8.4f}{lr_mark} "
|
|
f"{h1:>9.4f} {h2:>8.4f}{head_mark}\n")
|
|
|
|
f.write(f"\n Detailed breakdown:\n")
|
|
for rk, metrics in sorted(lat_items.items(), key=lambda x: _range_sort_key(x[0])):
|
|
_write_range_block(f, rk, metrics)
|
|
|
|
# Summary
|
|
if 'summary' in self.comparison_results:
|
|
f.write("\n" + "="*80 + "\n")
|
|
f.write("SUMMARY\n")
|
|
f.write("="*80 + "\n\n")
|
|
|
|
summary = self.comparison_results['summary']
|
|
|
|
f.write(f"2D Detection (by AP):\n")
|
|
f.write(f" {self.model2_name} wins: {summary['2d']['ap']['wins']}\n")
|
|
f.write(f" {self.model1_name} wins: {summary['2d']['ap']['losses']}\n")
|
|
f.write(f" Ties: {summary['2d']['ap']['ties']}\n\n")
|
|
|
|
f.write(f"2D Detection (by F1 Score):\n")
|
|
f.write(f" {self.model2_name} wins: {summary['2d']['f1_score']['wins']}\n")
|
|
f.write(f" {self.model1_name} wins: {summary['2d']['f1_score']['losses']}\n")
|
|
f.write(f" Ties: {summary['2d']['f1_score']['ties']}\n\n")
|
|
|
|
f.write(f"3D Detection:\n")
|
|
f.write(f" By Lateral Error:\n")
|
|
f.write(f" {self.model2_name} wins: {summary['3d']['lateral']['wins']}\n")
|
|
f.write(f" {self.model1_name} wins: {summary['3d']['lateral']['losses']}\n")
|
|
f.write(f" Ties: {summary['3d']['lateral']['ties']}\n")
|
|
f.write(f" By Longitudinal Error:\n")
|
|
f.write(f" {self.model2_name} wins: {summary['3d']['longitudinal']['wins']}\n")
|
|
f.write(f" {self.model1_name} wins: {summary['3d']['longitudinal']['losses']}\n")
|
|
f.write(f" Ties: {summary['3d']['longitudinal']['ties']}\n")
|
|
f.write(f" By Heading Error:\n")
|
|
f.write(f" {self.model2_name} wins: {summary['3d']['heading']['wins']}\n")
|
|
f.write(f" {self.model1_name} wins: {summary['3d']['heading']['losses']}\n")
|
|
f.write(f" Ties: {summary['3d']['heading']['ties']}\n")
|
|
|
|
print(f"✓ Text report saved to: {output_file}")
|
|
|
|
def generate_json_report(self, output_file):
|
|
"""Generate JSON report."""
|
|
print(f"\nGenerating JSON report: {output_file}")
|
|
|
|
with open(output_file, 'w') as f:
|
|
json.dump(self.comparison_results, f, indent=2)
|
|
|
|
print(f"✓ JSON report saved to: {output_file}")
|
|
|
|
def compare_all(self):
|
|
"""Run all comparisons."""
|
|
self.compare_2d_metrics()
|
|
self.compare_3d_metrics()
|
|
self.compare_per_case()
|
|
self.generate_summary_stats()
|
|
|
|
return self.comparison_results
|
|
|
|
|
|
def main():
|
|
"""Main function."""
|
|
parser = argparse.ArgumentParser(
|
|
description='Compare evaluation results from two models',
|
|
formatter_class=argparse.RawDescriptionHelpFormatter
|
|
)
|
|
|
|
parser.add_argument('--model1', type=str, required=True,
|
|
help='Path to model 1 evaluation report JSON')
|
|
parser.add_argument('--model2', type=str, required=True,
|
|
help='Path to model 2 evaluation report JSON')
|
|
parser.add_argument('--output-dir', type=str, default='comparison_results',
|
|
help='Output directory for comparison results')
|
|
parser.add_argument('--model1-name', type=str, default='Model-1',
|
|
help='Display name for model 1')
|
|
parser.add_argument('--model2-name', type=str, default='Model-2',
|
|
help='Display name for model 2')
|
|
parser.add_argument('--common-matches', type=str, default=None,
|
|
help='Path to common_matches.json from find_common_matches.py. '
|
|
'If provided, 3D comparison will use common matches only.')
|
|
|
|
args = parser.parse_args()
|
|
|
|
# Load reports
|
|
print("="*80)
|
|
print("MODEL COMPARISON TOOL")
|
|
print("="*80)
|
|
print(f"\nLoading model 1: {args.model1}")
|
|
with open(args.model1, 'r') as f:
|
|
model1_report = json.load(f)
|
|
|
|
print(f"Loading model 2: {args.model2}")
|
|
with open(args.model2, 'r') as f:
|
|
model2_report = json.load(f)
|
|
|
|
# Load common matches if provided
|
|
common_matches_data = None
|
|
if args.common_matches:
|
|
print(f"Loading common matches: {args.common_matches}")
|
|
with open(args.common_matches, 'r') as f:
|
|
common_matches_data = json.load(f)
|
|
print("✓ Will use common matches for 3D comparison")
|
|
|
|
# Create output directory
|
|
os.makedirs(args.output_dir, exist_ok=True)
|
|
|
|
# Derive detailed_3d_matches.json paths (same dir as evaluation_report.json)
|
|
model1_detailed = str(Path(args.model1).parent / 'detailed_3d_matches.json')
|
|
model2_detailed = str(Path(args.model2).parent / 'detailed_3d_matches.json')
|
|
|
|
# Compare models
|
|
comparator = ModelComparator(
|
|
model1_report,
|
|
model2_report,
|
|
model1_name=args.model1_name,
|
|
model2_name=args.model2_name,
|
|
common_matches_data=common_matches_data,
|
|
model1_detailed_path=model1_detailed,
|
|
model2_detailed_path=model2_detailed,
|
|
)
|
|
|
|
results = comparator.compare_all()
|
|
|
|
# Generate reports
|
|
text_output = os.path.join(args.output_dir, 'comparison_report.txt')
|
|
json_output = os.path.join(args.output_dir, 'comparison_report.json')
|
|
|
|
comparator.generate_text_report(text_output)
|
|
comparator.generate_json_report(json_output)
|
|
|
|
print("\n" + "="*80)
|
|
print("COMPARISON COMPLETE")
|
|
print("="*80)
|
|
print(f"\nResults saved to: {args.output_dir}/")
|
|
print(f" - Text report: comparison_report.txt")
|
|
print(f" - JSON report: comparison_report.json")
|
|
if common_matches_data:
|
|
print(f"\nNote: 3D metrics comparison is based on common matches only.")
|
|
print(f" Matched by both models: {common_matches_data['match_statistics']['common']:,} samples")
|
|
print("")
|
|
|
|
|
|
if __name__ == '__main__':
|
|
main()
|