375 lines
14 KiB
Python
Executable File
375 lines
14 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Model Evaluation Comparison Tool with Visualization
|
|
|
|
Extended version with plotting capabilities.
|
|
|
|
Usage:
|
|
python eval_tools/compare_models_visualize.py \
|
|
--model1 eval_results/model1/evaluation_report.json \
|
|
--model2 eval_results/model2/evaluation_report.json \
|
|
--output-dir comparison_results \
|
|
--model1-name "mono3d" \
|
|
--model2-name "yolov5s-300w"
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Add parent directory to path
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
from eval_tools.compare_models import ModelComparator
|
|
|
|
try:
|
|
import matplotlib
|
|
matplotlib.use('Agg') # Use non-interactive backend
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
MATPLOTLIB_AVAILABLE = True
|
|
except ImportError:
|
|
MATPLOTLIB_AVAILABLE = False
|
|
print("Warning: matplotlib not available, visualization will be skipped")
|
|
|
|
|
|
class VisualizationComparator(ModelComparator):
    """Extended comparator with visualization capabilities.

    Adds matplotlib charts on top of ``ModelComparator``. Every plotting
    method reads ``self.comparison_results`` (presumably filled in by the
    base class's ``compare_all()`` — TODO confirm) and looks up per-model
    values by ``self.model1_name`` / ``self.model2_name``. When matplotlib
    could not be imported, each public plotting method prints a notice and
    returns without drawing anything.
    """

    # Bar colors for model 1 / model 2, shared by every chart.
    _MODEL1_COLOR = '#3498db'
    _MODEL2_COLOR = '#e74c3c'

    # 3D error metrics that are plotted, in column / subplot order.
    _ERROR_TYPES = ('lateral_error', 'longitudinal_error', 'heading_error')

    @staticmethod
    def _range_sort_key(range_key):
        """Return a key ordering range names: 'overall' first, then by start distance.

        Keys look like ``"0-20m"`` or ``"100-999m"``; the number before the
        first ``-`` is the start distance (floats such as ``"2.5-5m"`` are
        accepted). Keys that cannot be parsed sort last.
        """
        if range_key == 'overall':
            return -1
        try:
            return float(range_key.split('-')[0])
        except (ValueError, IndexError):
            return float('inf')

    @staticmethod
    def _save_figure(fig, output_dir, filename):
        """Lay out *fig*, save it as ``output_dir/filename`` (150 dpi), close it and report."""
        fig.tight_layout()
        output_file = os.path.join(output_dir, filename)
        try:
            fig.savefig(output_file, dpi=150, bbox_inches='tight')
        finally:
            # Close this specific figure even if saving fails, so figures don't pile up.
            plt.close(fig)
        print(f" ✓ Saved: {output_file}")

    def plot_2d_metrics_comparison(self, output_dir):
        """Plot 2D detection metrics for both models.

        Writes ``comparison_2d_overall.png`` (precision / recall / mAP bars)
        when ``comparison_results['2d_metrics']`` has an ``'overall'`` entry,
        and ``comparison_2d_per_class.png`` (per-class AP bars) when it has a
        ``'per_class'`` entry.

        Args:
            output_dir: Directory the PNG files are written to.
        """
        if not MATPLOTLIB_AVAILABLE:
            print("Skipping 2D metrics plot (matplotlib not available)")
            return

        print("\nGenerating 2D metrics comparison plots...")

        comparison = self.comparison_results.get('2d_metrics', {})
        if not comparison:
            return

        # Tolerate results that only carry per-class data instead of raising KeyError.
        overall = comparison.get('overall', {})
        if overall:
            self._plot_2d_overall(overall, output_dir)

        per_class = comparison.get('per_class', {})
        if per_class:
            self._plot_2d_per_class(per_class, output_dir)

    def _plot_2d_overall(self, overall, output_dir):
        """Draw one two-bar chart per overall 2D metric and save the figure."""
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        fig.suptitle('2D Detection Metrics - Overall Comparison', fontsize=16, fontweight='bold')

        model_names = [self.model1_name, self.model2_name]
        metric_specs = [('precision', 'Precision'), ('recall', 'Recall'), ('map', 'mAP')]

        for ax, (metric, metric_name) in zip(axes, metric_specs):
            if metric not in overall:
                # Leave this subplot empty rather than failing the whole figure.
                continue

            values = overall[metric]
            model_values = [values[self.model1_name], values[self.model2_name]]

            bars = ax.bar(model_names, model_values,
                          color=[self._MODEL1_COLOR, self._MODEL2_COLOR])
            ax.set_ylabel(metric_name)
            ax.set_title(f'{metric_name}')
            ax.set_ylim(0, 1.0)
            ax.grid(axis='y', alpha=0.3)

            # Print the exact value just above each bar.
            for bar in bars:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width() / 2., height,
                        f'{height:.3f}',
                        ha='center', va='bottom', fontsize=10)

        self._save_figure(fig, output_dir, 'comparison_2d_overall.png')

    def _plot_2d_per_class(self, per_class, output_dir):
        """Draw grouped per-class AP bars for both models and save the figure."""
        class_names = sorted(per_class)
        m1_aps = [per_class[c]['ap'][self.model1_name] for c in class_names]
        m2_aps = [per_class[c]['ap'][self.model2_name] for c in class_names]

        fig, ax = plt.subplots(figsize=(12, 6))
        x = np.arange(len(class_names))
        width = 0.35

        ax.bar(x - width / 2, m1_aps, width, label=self.model1_name, color=self._MODEL1_COLOR)
        ax.bar(x + width / 2, m2_aps, width, label=self.model2_name, color=self._MODEL2_COLOR)

        ax.set_xlabel('Class', fontsize=12)
        ax.set_ylabel('Average Precision (AP)', fontsize=12)
        ax.set_title('Per-Class AP Comparison', fontsize=14, fontweight='bold')
        ax.set_xticks(x)
        ax.set_xticklabels(class_names, rotation=45, ha='right')
        ax.legend()
        ax.grid(axis='y', alpha=0.3)

        self._save_figure(fig, output_dir, 'comparison_2d_per_class.png')

    def plot_3d_metrics_comparison(self, output_dir):
        """Plot 3D error metrics (mean ± std) per class and distance range.

        Writes one ``comparison_3d_<class_name>.png`` per non-empty class in
        ``comparison_results['3d_metrics']``. Each figure has one subplot per
        error type. Distance ranges are ordered by start distance. The
        ``'overall'`` entry is plotted only when no distance ranges exist.

        Args:
            output_dir: Directory the PNG files are written to.
        """
        if not MATPLOTLIB_AVAILABLE:
            print("Skipping 3D metrics plot (matplotlib not available)")
            return

        print("\nGenerating 3D metrics comparison plots...")

        comparison = self.comparison_results.get('3d_metrics', {})
        if not comparison:
            return

        error_names = ('Lateral Error (m)', 'Longitudinal Error (m)', 'Heading Error (rad)')

        for class_name, ranges in comparison.items():
            if not ranges:
                continue

            # Prefer the per-distance breakdown; fall back to 'overall' alone.
            range_keys = sorted((k for k in ranges if k != 'overall'), key=self._range_sort_key)
            if not range_keys:
                range_keys = ['overall']

            fig, axes = plt.subplots(1, 3, figsize=(18, 5))
            fig.suptitle(f'3D Detection Metrics - {class_name.upper()}', fontsize=16, fontweight='bold')

            for ax, error_type, error_name in zip(axes, self._ERROR_TYPES, error_names):
                self._plot_error_bars(ax, ranges, range_keys, error_type, error_name)

            self._save_figure(fig, output_dir, f'comparison_3d_{class_name}.png')

    def _plot_error_bars(self, ax, ranges, range_keys, error_type, error_name):
        """Draw grouped mean bars with std error bars for one error type on *ax*.

        Ranges that are missing, or that lack *error_type*, are skipped. If no
        range has data, *ax* is left empty.
        """
        m1_values, m2_values, m1_stds, m2_stds, labels = [], [], [], [], []

        for range_key in range_keys:
            if range_key not in ranges:
                continue
            metrics = ranges[range_key]
            if error_type not in metrics:
                continue

            data = metrics[error_type]
            m1_values.append(data[self.model1_name]['mean'])
            m2_values.append(data[self.model2_name]['mean'])
            m1_stds.append(data[self.model1_name]['std'])
            m2_stds.append(data[self.model2_name]['std'])
            labels.append(range_key)

        if not labels:
            return

        x = np.arange(len(labels))
        width = 0.35

        ax.bar(x - width / 2, m1_values, width,
               yerr=m1_stds, label=self.model1_name,
               color=self._MODEL1_COLOR, alpha=0.8, capsize=5)
        ax.bar(x + width / 2, m2_values, width,
               yerr=m2_stds, label=self.model2_name,
               color=self._MODEL2_COLOR, alpha=0.8, capsize=5)

        ax.set_xlabel('Distance Range', fontsize=10)
        ax.set_ylabel(error_name, fontsize=10)
        ax.set_title(error_name.split('(')[0], fontsize=12)
        ax.set_xticks(x)
        ax.set_xticklabels(labels, rotation=45, ha='right')
        ax.legend()
        ax.grid(axis='y', alpha=0.3)

    def plot_improvement_heatmap(self, output_dir):
        """Plot a heatmap of relative 3D error changes between the two models.

        There is one row per (class, distance range) and one column per error
        type. Each cell is the negated ``'relative_change_%'`` of that error,
        so positive values mean a lower error, i.e. an improvement. Colors are
        clipped to [-50, 50]. A metric missing from the results is drawn as a
        blank ``n/a`` cell rather than a fake ``0.0%``. Rows with no metrics at
        all are omitted. Writes ``comparison_3d_improvement_heatmap.png``.

        Args:
            output_dir: Directory the PNG file is written to.
        """
        if not MATPLOTLIB_AVAILABLE:
            print("Skipping improvement heatmap (matplotlib not available)")
            return

        print("\nGenerating improvement heatmap...")

        comparison = self.comparison_results.get('3d_metrics', {})
        if not comparison:
            return

        rows = []
        row_labels = []
        col_labels = ['Lateral', 'Longitudinal', 'Heading']

        for class_name, ranges in sorted(comparison.items()):
            if not ranges:
                continue
            # 'overall' first, then ascending distance ranges.
            for range_key in sorted(ranges, key=self._range_sort_key):
                metrics = ranges[range_key]

                row = []
                for error_type in self._ERROR_TYPES:
                    change = metrics[error_type].get('relative_change_%') if error_type in metrics else None
                    # A negative change % means lower error, so negate it to make
                    # improvements positive. Missing data becomes NaN, not 0.
                    row.append(np.nan if change is None else -change)

                if np.isnan(row).all():
                    continue
                rows.append(row)
                row_labels.append(f"{class_name}\n{range_key}")

        if not rows:
            return

        data_matrix = np.array(rows, dtype=float)

        fig, ax = plt.subplots(figsize=(10, max(6, len(row_labels) * 0.5)))

        # NaN cells are masked by imshow and drawn in the colormap's "bad" color.
        im = ax.imshow(data_matrix, cmap='RdYlGn', aspect='auto', vmin=-50, vmax=50)

        ax.set_xticks(np.arange(len(col_labels)))
        ax.set_yticks(np.arange(len(row_labels)))
        ax.set_xticklabels(col_labels)
        ax.set_yticklabels(row_labels, fontsize=8)

        cbar = plt.colorbar(im, ax=ax)
        cbar.set_label(f'Improvement % ({self.model2_name} vs {self.model1_name})', rotation=270, labelpad=20)

        for i, row in enumerate(data_matrix):
            for j, value in enumerate(row):
                cell_text = 'n/a' if np.isnan(value) else f'{value:.1f}%'
                ax.text(j, i, cell_text, ha="center", va="center", color="black", fontsize=8)

        ax.set_title(f'3D Metrics Improvement Heatmap\n(Positive = {self.model2_name} Better)',
                     fontsize=14, fontweight='bold')

        self._save_figure(fig, output_dir, 'comparison_3d_improvement_heatmap.png')

    def generate_visualizations(self, output_dir):
        """Generate every comparison chart into *output_dir*.

        The directory is created if it does not exist, so this method also
        works when called outside ``main``. Does nothing except print install
        instructions when matplotlib is unavailable.
        """
        if not MATPLOTLIB_AVAILABLE:
            print("\n⚠ Matplotlib not available, skipping visualizations")
            print("Install with: pip install matplotlib")
            return

        os.makedirs(output_dir, exist_ok=True)

        print("\n" + "="*80)
        print("GENERATING VISUALIZATIONS")
        print("="*80)

        self.plot_2d_metrics_comparison(output_dir)
        self.plot_3d_metrics_comparison(output_dir)
        self.plot_improvement_heatmap(output_dir)

        print("\n✓ All visualizations generated")
|
|
|
|
|
|
def _load_report(path):
    """Load and return the JSON evaluation report at *path*.

    JSON is read as UTF-8. If the file cannot be opened, decoded or parsed,
    the program exits with status 1 and a one-line error message instead of
    a traceback.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (OSError, UnicodeDecodeError, json.JSONDecodeError) as err:
        sys.exit(f"Error: could not load evaluation report '{path}': {err}")


def main():
    """Command-line entry point: compare two model evaluation reports.

    Loads the two JSON reports and creates ``--output-dir`` if needed. Writes
    ``comparison_report.txt`` and ``comparison_report.json`` there and,
    unless ``--no-plots`` is given, the ``comparison_*.png`` charts.
    """
    parser = argparse.ArgumentParser(
        description='Compare evaluation results from two models with visualization',
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    parser.add_argument('--model1', type=str, required=True,
                        help='Path to model 1 evaluation report JSON')
    parser.add_argument('--model2', type=str, required=True,
                        help='Path to model 2 evaluation report JSON')
    parser.add_argument('--output-dir', type=str, default='comparison_results',
                        help='Output directory for comparison results')
    parser.add_argument('--model1-name', type=str, default='Model-1',
                        help='Display name for model 1')
    parser.add_argument('--model2-name', type=str, default='Model-2',
                        help='Display name for model 2')
    parser.add_argument('--no-plots', action='store_true',
                        help='Skip visualization generation')

    args = parser.parse_args()

    # Per-model results are looked up by display name (e.g. values[model1_name]),
    # so identical names would silently make both models share one entry.
    if args.model1_name == args.model2_name:
        parser.error('--model1-name and --model2-name must be different')

    # Load reports
    print("="*80)
    print("MODEL COMPARISON TOOL (with Visualization)")
    print("="*80)
    print(f"\nLoading model 1: {args.model1}")
    model1_report = _load_report(args.model1)

    print(f"Loading model 2: {args.model2}")
    model2_report = _load_report(args.model2)

    os.makedirs(args.output_dir, exist_ok=True)

    comparator = VisualizationComparator(
        model1_report,
        model2_report,
        model1_name=args.model1_name,
        model2_name=args.model2_name
    )

    # Run the comparison; the report/plot methods below read its results.
    comparator.compare_all()

    text_name = 'comparison_report.txt'
    json_name = 'comparison_report.json'
    comparator.generate_text_report(os.path.join(args.output_dir, text_name))
    comparator.generate_json_report(os.path.join(args.output_dir, json_name))

    if not args.no_plots:
        comparator.generate_visualizations(args.output_dir)

    print("\n" + "="*80)
    print("COMPARISON COMPLETE")
    print("="*80)
    print(f"\nResults saved to: {args.output_dir}/")
    print(f" - Text report: {text_name}")
    print(f" - JSON report: {json_name}")
    if not args.no_plots and MATPLOTLIB_AVAILABLE:
        print(" - Visualization plots: comparison_*.png")
    print("")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()
|