#!/usr/bin/env python3
"""
Model Evaluation Comparison Tool with Visualization

Extended version with plotting capabilities.

Usage:
    python eval_tools/compare_models_visualize.py \
        --model1 eval_results/model1/evaluation_report.json \
        --model2 eval_results/model2/evaluation_report.json \
        --output-dir comparison_results \
        --model1-name "mono3d" \
        --model2-name "yolov5s-300w"
"""

import argparse
import json
import os
import sys
from pathlib import Path

# Add parent directory to path so that `eval_tools` is importable when this
# file is executed directly as a script.
sys.path.insert(0, str(Path(__file__).parent.parent))

from eval_tools.compare_models import ModelComparator

try:
    import matplotlib
    matplotlib.use('Agg')  # Use non-interactive backend
    import matplotlib.pyplot as plt
    import numpy as np
    MATPLOTLIB_AVAILABLE = True
except ImportError:
    MATPLOTLIB_AVAILABLE = False
    print("Warning: matplotlib not available, visualization will be skipped")


# Bar colours used for model 1 / model 2 in every chart.
_MODEL1_COLOR = '#3498db'
_MODEL2_COLOR = '#e74c3c'

# (key in the comparison dict, axis title) for the overall 2D chart.
_OVERALL_2D_METRICS = (
    ('precision', 'Precision'),
    ('recall', 'Recall'),
    ('map', 'mAP'),
)

# (key in the comparison dict, axis label) for the per-class 3D error charts.
_ERROR_PANELS = (
    ('lateral_error', 'Lateral Error (m)'),
    ('longitudinal_error', 'Longitudinal Error (m)'),
    ('heading_error', 'Heading Error (rad)'),
)


def _range_sort_key(range_key):
    """Return a sort key that orders distance-range labels by start distance.

    ``'overall'`` sorts first (-1). Labels such as ``"0-20m"`` or
    ``"100-999m"`` sort by the integer before the first ``'-'``. Labels whose
    prefix is not an integer sort last (``inf``).
    """
    if range_key == 'overall':
        return -1
    try:
        return int(range_key.split('-')[0])
    except (ValueError, IndexError):
        return float('inf')


def _save_figure(fig, output_dir, filename):
    """Lay out ``fig``, write it to ``output_dir/filename`` and close it.

    The figure is closed even when saving fails, so a failed write does not
    leave the figure registered with pyplot. The output directory is created
    if it does not exist yet.

    Returns:
        The path the figure was written to.
    """
    output_file = os.path.join(output_dir, filename)
    try:
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        fig.tight_layout()
        fig.savefig(output_file, dpi=150, bbox_inches='tight')
    finally:
        plt.close(fig)
    print(f" ✓ Saved: {output_file}")
    return output_file


class VisualizationComparator(ModelComparator):
    """Extended comparator with visualization capabilities.

    All plotting methods read ``self.comparison_results`` and the two display
    names ``self.model1_name`` / ``self.model2_name``; each is a no-op (with a
    console message) when matplotlib/numpy could not be imported.
    """

    def plot_2d_metrics_comparison(self, output_dir):
        """Plot 2D metrics comparison.

        Writes ``comparison_2d_overall.png`` (precision / recall / mAP bars)
        when ``comparison_results['2d_metrics']['overall']`` is non-empty, and
        ``comparison_2d_per_class.png`` (grouped AP bars) when ``'per_class'``
        is non-empty. Either section may be absent; the other is still drawn.

        Args:
            output_dir: Directory the PNG files are written to.
        """
        if not MATPLOTLIB_AVAILABLE:
            print("Skipping 2D metrics plot (matplotlib not available)")
            return

        print("\nGenerating 2D metrics comparison plots...")

        comparison = self.comparison_results.get('2d_metrics', {})
        if not comparison:
            return

        # A report may carry only one of the two sections; a missing
        # 'overall' must not prevent the per-class chart from being drawn.
        overall = comparison.get('overall') or {}
        if overall:
            self._plot_2d_overall(overall, output_dir)

        per_class = comparison.get('per_class') or {}
        if per_class:
            self._plot_2d_per_class(per_class, output_dir)

    def _plot_2d_overall(self, overall, output_dir):
        """Draw one bar panel per overall metric and save the figure."""
        model_names = [self.model1_name, self.model2_name]

        # Gather the data before creating the figure so a malformed report
        # fails before any figure has been allocated. The panel index is kept
        # so each metric stays in its fixed column even if another is absent.
        panels = []
        for idx, (metric, metric_name) in enumerate(_OVERALL_2D_METRICS):
            if metric not in overall:
                continue
            values = overall[metric]
            panels.append(
                (idx, metric_name, [values[name] for name in model_names])
            )

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        fig.suptitle('2D Detection Metrics - Overall Comparison',
                     fontsize=16, fontweight='bold')

        for idx, metric_name, model_values in panels:
            ax = axes[idx]
            bars = ax.bar(model_names, model_values,
                          color=[_MODEL1_COLOR, _MODEL2_COLOR])
            ax.set_ylabel(metric_name)
            ax.set_title(f'{metric_name}')
            ax.set_ylim(0, 1.0)
            ax.grid(axis='y', alpha=0.3)

            # Add value labels on bars
            for bar in bars:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width() / 2., height,
                        f'{height:.3f}',
                        ha='center', va='bottom', fontsize=10)

        _save_figure(fig, output_dir, 'comparison_2d_overall.png')

    def _plot_2d_per_class(self, per_class, output_dir):
        """Draw grouped per-class AP bars for both models and save them."""
        class_names = sorted(per_class)
        m1_aps = [per_class[c]['ap'][self.model1_name] for c in class_names]
        m2_aps = [per_class[c]['ap'][self.model2_name] for c in class_names]

        fig, ax = plt.subplots(figsize=(12, 6))
        x = np.arange(len(class_names))
        width = 0.35

        ax.bar(x - width / 2, m1_aps, width,
               label=self.model1_name, color=_MODEL1_COLOR)
        ax.bar(x + width / 2, m2_aps, width,
               label=self.model2_name, color=_MODEL2_COLOR)

        ax.set_xlabel('Class', fontsize=12)
        ax.set_ylabel('Average Precision (AP)', fontsize=12)
        ax.set_title('Per-Class AP Comparison', fontsize=14, fontweight='bold')
        ax.set_xticks(x)
        ax.set_xticklabels(class_names, rotation=45, ha='right')
        ax.legend()
        ax.grid(axis='y', alpha=0.3)

        _save_figure(fig, output_dir, 'comparison_2d_per_class.png')

    def _collect_error_series(self, ranges, range_keys, error_type):
        """Collect per-range mean/std of one error type for both models.

        Ranges that are missing, or that lack ``error_type``, are skipped.

        Returns:
            ``(labels, m1_means, m2_means, m1_stds, m2_stds)`` as parallel
            lists, one entry per range that had data.
        """
        labels, m1_means, m2_means, m1_stds, m2_stds = [], [], [], [], []
        for range_key in range_keys:
            metrics = ranges.get(range_key)
            if metrics is None or error_type not in metrics:
                continue
            data = metrics[error_type]
            m1_means.append(data[self.model1_name]['mean'])
            m2_means.append(data[self.model2_name]['mean'])
            m1_stds.append(data[self.model1_name]['std'])
            m2_stds.append(data[self.model2_name]['std'])
            labels.append(range_key)
        return labels, m1_means, m2_means, m1_stds, m2_stds

    def plot_3d_metrics_comparison(self, output_dir):
        """Plot 3D metrics comparison.

        For every class in ``comparison_results['3d_metrics']`` writes
        ``comparison_3d_<class>.png`` with three panels (lateral,
        longitudinal and heading error). Bars show the mean per distance
        range with the standard deviation as error bars. When a class has no
        distance ranges, its ``'overall'`` entry is plotted instead.

        Args:
            output_dir: Directory the PNG files are written to.
        """
        if not MATPLOTLIB_AVAILABLE:
            print("Skipping 3D metrics plot (matplotlib not available)")
            return

        print("\nGenerating 3D metrics comparison plots...")

        comparison = self.comparison_results.get('3d_metrics', {})
        if not comparison:
            return

        # For each class with distance ranges
        for class_name, ranges in comparison.items():
            if not ranges:
                continue

            # Distance ranges sorted by start distance; fall back to the
            # aggregate entry when the class has no per-range breakdown.
            range_keys = sorted((k for k in ranges if k != 'overall'),
                                key=_range_sort_key)
            if not range_keys:
                range_keys = ['overall']

            # Gather all series before allocating the figure.
            all_series = [
                self._collect_error_series(ranges, range_keys, error_type)
                for error_type, _ in _ERROR_PANELS
            ]

            fig, axes = plt.subplots(1, 3, figsize=(18, 5))
            fig.suptitle(f'3D Detection Metrics - {class_name.upper()}',
                         fontsize=16, fontweight='bold')

            for ax, (_, error_name), series in zip(axes, _ERROR_PANELS,
                                                   all_series):
                labels, m1_values, m2_values, m1_stds, m2_stds = series
                if not labels:
                    # No data for this error type: leave the panel empty.
                    continue

                x = np.arange(len(labels))
                width = 0.35

                ax.bar(x - width / 2, m1_values, width, yerr=m1_stds,
                       label=self.model1_name, color=_MODEL1_COLOR,
                       alpha=0.8, capsize=5)
                ax.bar(x + width / 2, m2_values, width, yerr=m2_stds,
                       label=self.model2_name, color=_MODEL2_COLOR,
                       alpha=0.8, capsize=5)

                ax.set_xlabel('Distance Range', fontsize=10)
                ax.set_ylabel(error_name, fontsize=10)
                # Panel title is the axis label without its unit suffix.
                ax.set_title(error_name.split('(')[0], fontsize=12)
                ax.set_xticks(x)
                ax.set_xticklabels(labels, rotation=45, ha='right')
                ax.legend()
                ax.grid(axis='y', alpha=0.3)

            _save_figure(fig, output_dir, f'comparison_3d_{class_name}.png')

    def plot_improvement_heatmap(self, output_dir):
        """Plot improvement heatmap for 3D metrics.

        Builds one row per (class, range) and one column per error type from
        the ``'relative_change_%'`` values in
        ``comparison_results['3d_metrics']``, with the sign flipped so that a
        positive cell means a lower error. Rows whose cells are all zero are
        omitted. Writes ``comparison_3d_improvement_heatmap.png``.

        Args:
            output_dir: Directory the PNG file is written to.
        """
        if not MATPLOTLIB_AVAILABLE:
            print("Skipping improvement heatmap (matplotlib not available)")
            return

        print("\nGenerating improvement heatmap...")

        comparison = self.comparison_results.get('3d_metrics', {})
        if not comparison:
            return

        # Collect improvement data
        col_labels = ['Lateral', 'Longitudinal', 'Heading']
        error_types = [error_type for error_type, _ in _ERROR_PANELS]
        rows = []
        row_labels = []

        for class_name, ranges in sorted(comparison.items()):
            # Same guard as plot_3d_metrics_comparison: a class with an empty
            # (or null) entry contributes no rows instead of crashing.
            if not ranges:
                continue

            # Sort ranges: overall first, then by distance
            for range_key in sorted(ranges, key=_range_sort_key):
                metrics = ranges[range_key]
                # Negative change % means improvement (lower error), so the
                # sign is flipped; a missing error type counts as 0.
                row = [
                    -metrics[error_type]['relative_change_%']
                    if error_type in metrics else 0
                    for error_type in error_types
                ]
                if any(value != 0 for value in row):
                    rows.append(row)
                    row_labels.append(f"{class_name}\n{range_key}")

        if not rows:
            return

        data_matrix = np.array(rows)

        fig, ax = plt.subplots(figsize=(10, max(6, len(row_labels) * 0.5)))

        # Create heatmap; the colour scale is clipped to +/-50 %.
        im = ax.imshow(data_matrix, cmap='RdYlGn', aspect='auto',
                       vmin=-50, vmax=50)

        # Set ticks
        ax.set_xticks(np.arange(len(col_labels)))
        ax.set_yticks(np.arange(len(row_labels)))
        ax.set_xticklabels(col_labels)
        ax.set_yticklabels(row_labels, fontsize=8)

        # Add colorbar
        cbar = plt.colorbar(im, ax=ax)
        cbar.set_label(
            f'Improvement % ({self.model2_name} vs {self.model1_name})',
            rotation=270, labelpad=20)

        # Add text annotations (x is the column index, y the row index).
        for (row_idx, col_idx), value in np.ndenumerate(data_matrix):
            ax.text(col_idx, row_idx, f'{value:.1f}%',
                    ha="center", va="center", color="black", fontsize=8)

        ax.set_title(
            f'3D Metrics Improvement Heatmap\n(Positive = {self.model2_name} Better)',
            fontsize=14, fontweight='bold')

        _save_figure(fig, output_dir, 'comparison_3d_improvement_heatmap.png')

    def generate_visualizations(self, output_dir):
        """Generate all visualizations.

        Runs the 2D chart, the per-class 3D charts and the improvement
        heatmap in that order, writing every PNG into ``output_dir``.
        """
        if not MATPLOTLIB_AVAILABLE:
            print("\n⚠ Matplotlib not available, skipping visualizations")
            print("Install with: pip install matplotlib")
            return

        print("\n" + "="*80)
        print("GENERATING VISUALIZATIONS")
        print("="*80)

        self.plot_2d_metrics_comparison(output_dir)
        self.plot_3d_metrics_comparison(output_dir)
        self.plot_improvement_heatmap(output_dir)

        print("\n✓ All visualizations generated")


def _load_report(path):
    """Read a JSON evaluation report from ``path``.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's locale encoding.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def main():
    """Main function.

    Parses the command line, loads both evaluation reports, writes the text
    and JSON comparison reports into the output directory and, unless
    ``--no-plots`` is given, the visualization PNGs as well.
    """
    parser = argparse.ArgumentParser(
        description='Compare evaluation results from two models with visualization',
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    parser.add_argument('--model1', type=str, required=True,
                        help='Path to model 1 evaluation report JSON')
    parser.add_argument('--model2', type=str, required=True,
                        help='Path to model 2 evaluation report JSON')
    parser.add_argument('--output-dir', type=str, default='comparison_results',
                        help='Output directory for comparison results')
    parser.add_argument('--model1-name', type=str, default='Model-1',
                        help='Display name for model 1')
    parser.add_argument('--model2-name', type=str, default='Model-2',
                        help='Display name for model 2')
    parser.add_argument('--no-plots', action='store_true',
                        help='Skip visualization generation')

    args = parser.parse_args()

    # Load reports
    print("="*80)
    print("MODEL COMPARISON TOOL (with Visualization)")
    print("="*80)

    print(f"\nLoading model 1: {args.model1}")
    model1_report = _load_report(args.model1)

    print(f"Loading model 2: {args.model2}")
    model2_report = _load_report(args.model2)

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Compare models
    comparator = VisualizationComparator(
        model1_report,
        model2_report,
        model1_name=args.model1_name,
        model2_name=args.model2_name
    )

    # Called for its effect on the comparator; the reports and plots below
    # are produced from the comparator itself.
    comparator.compare_all()

    # Generate reports
    text_output = os.path.join(args.output_dir, 'comparison_report.txt')
    json_output = os.path.join(args.output_dir, 'comparison_report.json')

    comparator.generate_text_report(text_output)
    comparator.generate_json_report(json_output)

    # Generate visualizations
    if not args.no_plots:
        comparator.generate_visualizations(args.output_dir)

    print("\n" + "="*80)
    print("COMPARISON COMPLETE")
    print("="*80)
    print(f"\nResults saved to: {args.output_dir}/")
    print(" - Text report: comparison_report.txt")
    print(" - JSON report: comparison_report.json")
    if not args.no_plots and MATPLOTLIB_AVAILABLE:
        print(" - Visualization plots: comparison_*.png")
    print("")


if __name__ == '__main__':
    main()