Files
yolov26_3d/eval_tools/model_comparison/compare_models_visualize.py
2026-06-24 09:35:46 +08:00

375 lines
14 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Model Evaluation Comparison Tool with Visualization
Extended version with plotting capabilities.
Usage:
python eval_tools/compare_models_visualize.py \
--model1 eval_results/model1/evaluation_report.json \
--model2 eval_results/model2/evaluation_report.json \
--output-dir comparison_results \
--model1-name "mono3d" \
--model2-name "yolov5s-300w"
"""
import argparse
import json
import os
import sys
from pathlib import Path
# Add parent directory to path
# NOTE(review): this inserts the parent of this file's directory. The import
# below needs the directory that *contains* the `eval_tools` package to be on
# sys.path — confirm this still matches where the file lives in the repo.
sys.path.insert(0, str(Path(__file__).parent.parent))
from eval_tools.compare_models import ModelComparator

# Plotting dependencies are optional. When either import fails the module
# still loads; MATPLOTLIB_AVAILABLE is set to False and every plotting method
# checks that flag before touching `plt` or `np`.
try:
    import matplotlib
    matplotlib.use('Agg') # Use non-interactive backend
    import matplotlib.pyplot as plt
    import numpy as np
    MATPLOTLIB_AVAILABLE = True
except ImportError:
    MATPLOTLIB_AVAILABLE = False
    print("Warning: matplotlib not available, visualization will be skipped")
class VisualizationComparator(ModelComparator):
    """Extended comparator with visualization capabilities.

    Every plotting method reads ``self.comparison_results`` and the display
    names ``self.model1_name`` / ``self.model2_name``.  These attributes are
    not set in this class; presumably the base class populates them (e.g. via
    ``compare_all()``) -- verify against ``ModelComparator``.

    All plots are written as PNG files into the ``output_dir`` passed to each
    method.  When matplotlib/numpy could not be imported
    (``MATPLOTLIB_AVAILABLE`` is False) the methods print a notice and return
    without plotting.
    """

    # Bar colours for model 1 and model 2, in that order.
    _MODEL_COLORS = ('#3498db', '#e74c3c')

    # (key in the 'overall' dict, axis label) for the 2D overall chart.
    _OVERALL_2D_METRICS = (
        ('precision', 'Precision'),
        ('recall', 'Recall'),
        ('map', 'mAP'),
    )

    # (key in a per-range metrics dict, axis label) for the 3D error charts.
    _ERROR_PANELS = (
        ('lateral_error', 'Lateral Error (m)'),
        ('longitudinal_error', 'Longitudinal Error (m)'),
        ('heading_error', 'Heading Error (rad)'),
    )

    @staticmethod
    def _range_sort_key(range_key):
        """Return a sort key that orders distance ranges by starting distance.

        ``'overall'`` sorts first (-1).  Keys shaped like ``"0-20m"`` or
        ``"100-999m"`` sort by the integer before the first ``-``.  Anything
        that does not start with an integer sorts last (``inf``).
        """
        if range_key == 'overall':
            return -1
        try:
            return int(range_key.split('-')[0])
        except (ValueError, IndexError):
            return float('inf')

    @staticmethod
    def _save_figure(fig, output_dir, filename):
        """Lay out and save ``fig`` as ``output_dir/filename``, then close it.

        The figure is closed even if saving fails, so a failed save does not
        leave the figure open in pyplot's global registry.
        """
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, filename)
        try:
            fig.tight_layout()
            fig.savefig(output_file, dpi=150, bbox_inches='tight')
        finally:
            plt.close(fig)
        print(f" ✓ Saved: {output_file}")

    def _draw_grouped_bars(self, ax, labels, m1_values, m2_values,
                           m1_err=None, m2_err=None, **bar_kwargs):
        """Draw side-by-side bars for both models on ``ax``.

        ``labels`` become the x tick labels; ``m1_err`` / ``m2_err`` are
        optional error-bar sizes; ``bar_kwargs`` is forwarded to ``ax.bar``.
        """
        x = np.arange(len(labels))
        width = 0.35
        color1, color2 = self._MODEL_COLORS
        ax.bar(x - width / 2, m1_values, width, yerr=m1_err,
               label=self.model1_name, color=color1, **bar_kwargs)
        ax.bar(x + width / 2, m2_values, width, yerr=m2_err,
               label=self.model2_name, color=color2, **bar_kwargs)
        ax.set_xticks(x)
        ax.set_xticklabels(labels, rotation=45, ha='right')
        ax.legend()
        ax.grid(axis='y', alpha=0.3)

    @staticmethod
    def _improvement_pct(metric):
        """Return the improvement percentage for one error metric.

        The stored ``'relative_change_%'`` is negated so that a reduction in
        error reads as a positive improvement.  Returns NaN when the metric
        entry or its ``'relative_change_%'`` value is missing/None, so that
        "no data" stays distinguishable from a genuine 0% change.
        """
        if not metric:
            return float('nan')
        change = metric.get('relative_change_%')
        if change is None:
            return float('nan')
        # `0.0 - change` instead of `-change` avoids a negative zero, which
        # would otherwise be rendered as "-0.0%".
        return 0.0 - change

    def plot_2d_metrics_comparison(self, output_dir):
        """Plot 2D metrics comparison.

        Reads ``self.comparison_results['2d_metrics']`` and writes up to two
        files into ``output_dir``:

        * ``comparison_2d_overall.png`` -- one bar chart per available
          overall metric (precision / recall / mAP).
        * ``comparison_2d_per_class.png`` -- grouped per-class AP bars.

        Either chart is skipped when its section is absent or empty.
        """
        if not MATPLOTLIB_AVAILABLE:
            print("Skipping 2D metrics plot (matplotlib not available)")
            return
        print("\nGenerating 2D metrics comparison plots...")
        comparison = self.comparison_results.get('2d_metrics', {})
        if not comparison:
            return
        # .get(): a report may carry per-class data without an 'overall' entry.
        self._plot_2d_overall(comparison.get('overall') or {}, output_dir)
        self._plot_2d_per_class(comparison.get('per_class') or {}, output_dir)

    def _plot_2d_overall(self, overall, output_dir):
        """Draw the overall precision/recall/mAP bar charts."""
        if not any(key in overall for key, _ in self._OVERALL_2D_METRICS):
            return  # nothing to draw; do not save an empty figure
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        fig.suptitle('2D Detection Metrics - Overall Comparison',
                     fontsize=16, fontweight='bold')
        model_names = [self.model1_name, self.model2_name]
        for ax, (metric, metric_name) in zip(axes, self._OVERALL_2D_METRICS):
            if metric not in overall:
                # Hide the panel rather than leaving a blank set of axes.
                ax.set_visible(False)
                continue
            values = overall[metric]
            model_values = [values[name] for name in model_names]
            bars = ax.bar(model_names, model_values,
                          color=list(self._MODEL_COLORS))
            ax.set_ylabel(metric_name)
            ax.set_title(f'{metric_name}')
            ax.set_ylim(0, 1.0)
            ax.grid(axis='y', alpha=0.3)
            # Add value labels on bars
            for bar in bars:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width() / 2., height,
                        f'{height:.3f}',
                        ha='center', va='bottom', fontsize=10)
        self._save_figure(fig, output_dir, 'comparison_2d_overall.png')

    def _plot_2d_per_class(self, per_class, output_dir):
        """Draw the grouped per-class AP bar chart."""
        if not per_class:
            return
        class_names = sorted(per_class)
        m1_aps = [per_class[c]['ap'][self.model1_name] for c in class_names]
        m2_aps = [per_class[c]['ap'][self.model2_name] for c in class_names]
        fig, ax = plt.subplots(figsize=(12, 6))
        self._draw_grouped_bars(ax, class_names, m1_aps, m2_aps)
        ax.set_xlabel('Class', fontsize=12)
        ax.set_ylabel('Average Precision (AP)', fontsize=12)
        ax.set_title('Per-Class AP Comparison', fontsize=14, fontweight='bold')
        self._save_figure(fig, output_dir, 'comparison_2d_per_class.png')

    def plot_3d_metrics_comparison(self, output_dir):
        """Plot 3D metrics comparison.

        Reads ``self.comparison_results['3d_metrics']`` (class name ->
        range key -> metrics) and writes one
        ``comparison_3d_<class_name>.png`` per class into ``output_dir``.
        Each figure has one panel per error type (lateral, longitudinal,
        heading) showing mean +/- std for both models across the distance
        ranges.  If a class has distance ranges, only those are plotted;
        otherwise its ``'overall'`` entry is used.  Classes without any
        error data are skipped.
        """
        if not MATPLOTLIB_AVAILABLE:
            print("Skipping 3D metrics plot (matplotlib not available)")
            return
        print("\nGenerating 3D metrics comparison plots...")
        comparison = self.comparison_results.get('3d_metrics', {})
        if not comparison:
            return
        # For each class with distance ranges
        for class_name, ranges in comparison.items():
            if not ranges:
                continue
            # Distance ranges sorted by starting distance; fall back to the
            # aggregate entry when the class has no per-range breakdown.
            range_keys = sorted((k for k in ranges if k != 'overall'),
                                key=self._range_sort_key)
            if not range_keys:
                range_keys = ['overall']

            # Collect the data per panel first so a class with nothing to
            # show never produces an empty figure.
            panels = []
            for error_type, error_name in self._ERROR_PANELS:
                entries = [(key, ranges[key][error_type])
                           for key in range_keys
                           if error_type in ranges[key]]
                panels.append((error_name, entries))
            if not any(entries for _, entries in panels):
                continue

            # Create subplots for lateral, longitudinal, and heading errors
            fig, axes = plt.subplots(1, 3, figsize=(18, 5))
            fig.suptitle(f'3D Detection Metrics - {class_name.upper()}',
                         fontsize=16, fontweight='bold')
            for ax, (error_name, entries) in zip(axes, panels):
                if not entries:
                    ax.set_visible(False)
                    continue
                labels = [key for key, _ in entries]
                m1_stats = [data[self.model1_name] for _, data in entries]
                m2_stats = [data[self.model2_name] for _, data in entries]
                self._draw_grouped_bars(
                    ax, labels,
                    [s['mean'] for s in m1_stats],
                    [s['mean'] for s in m2_stats],
                    m1_err=[s['std'] for s in m1_stats],
                    m2_err=[s['std'] for s in m2_stats],
                    alpha=0.8, capsize=5,
                )
                ax.set_xlabel('Distance Range', fontsize=10)
                ax.set_ylabel(error_name, fontsize=10)
                # Panel title is the label without its "(unit)" suffix.
                ax.set_title(error_name.split('(')[0], fontsize=12)
            self._save_figure(fig, output_dir,
                              f'comparison_3d_{class_name}.png')

    def plot_improvement_heatmap(self, output_dir):
        """Plot improvement heatmap for 3D metrics.

        Builds one row per (class, range) pair and one column per error type
        from ``self.comparison_results['3d_metrics']``.  Each cell is the
        negated ``'relative_change_%'``, so a positive value means the error
        went down.  Rows without at least one non-zero value are dropped.
        Missing metrics are shown as ``N/A`` rather than ``0.0%``.  The result
        is written to ``comparison_3d_improvement_heatmap.png`` in
        ``output_dir``.
        """
        if not MATPLOTLIB_AVAILABLE:
            print("Skipping improvement heatmap (matplotlib not available)")
            return
        print("\nGenerating improvement heatmap...")
        comparison = self.comparison_results.get('3d_metrics', {})
        if not comparison:
            return
        # Collect improvement data
        rows = []
        row_labels = []
        col_labels = ['Lateral', 'Longitudinal', 'Heading']
        for class_name, ranges in sorted(comparison.items()):
            if not ranges:
                continue
            # Sort ranges: overall first, then by distance
            for range_key in sorted(ranges, key=self._range_sort_key):
                metrics = ranges[range_key]
                row = [self._improvement_pct(metrics.get(error_type))
                       for error_type, _ in self._ERROR_PANELS]
                # Keep the row only if it has a real, non-zero change.
                if any(not np.isnan(v) and v != 0 for v in row):
                    rows.append(row)
                    row_labels.append(f"{class_name}\n{range_key}")
        if not rows:
            return
        data_matrix = np.array(rows, dtype=float)
        fig, ax = plt.subplots(figsize=(10, max(6, len(row_labels) * 0.5)))
        # Create heatmap; colours saturate at +/-50 %.  NaN cells are left
        # uncoloured by imshow.
        im = ax.imshow(data_matrix, cmap='RdYlGn', aspect='auto',
                       vmin=-50, vmax=50)
        # Set ticks
        ax.set_xticks(np.arange(len(col_labels)))
        ax.set_yticks(np.arange(len(row_labels)))
        ax.set_xticklabels(col_labels)
        ax.set_yticklabels(row_labels, fontsize=8)
        # Add colorbar
        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label(
            f'Improvement % ({self.model2_name} vs {self.model1_name})',
            rotation=270, labelpad=20)
        # Add text annotations
        for (i, j), value in np.ndenumerate(data_matrix):
            cell_text = 'N/A' if np.isnan(value) else f'{value:.1f}%'
            ax.text(j, i, cell_text,
                    ha="center", va="center", color="black", fontsize=8)
        ax.set_title(
            f'3D Metrics Improvement Heatmap\n'
            f'(Positive = {self.model2_name} Better)',
            fontsize=14, fontweight='bold')
        self._save_figure(fig, output_dir,
                          'comparison_3d_improvement_heatmap.png')

    def generate_visualizations(self, output_dir):
        """Generate all visualizations.

        Runs the 2D comparison, 3D comparison and improvement heatmap in
        turn, writing their PNG files into ``output_dir``.  Prints install
        instructions and returns when matplotlib is unavailable.
        """
        if not MATPLOTLIB_AVAILABLE:
            print("\n⚠ Matplotlib not available, skipping visualizations")
            print("Install with: pip install matplotlib")
            return
        print("\n" + "=" * 80)
        print("GENERATING VISUALIZATIONS")
        print("=" * 80)
        self.plot_2d_metrics_comparison(output_dir)
        self.plot_3d_metrics_comparison(output_dir)
        self.plot_improvement_heatmap(output_dir)
        print("\n✓ All visualizations generated")
def _load_json_report(path):
    """Read and parse one evaluation report JSON file.

    The file is decoded as UTF-8 explicitly, so the result does not depend on
    the platform's default encoding.  File-system and JSON parse errors
    propagate to the caller unchanged.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def main():
    """Main function.

    Parses the command line, loads the two evaluation reports, runs the
    comparison, writes ``comparison_report.txt`` and
    ``comparison_report.json`` into the output directory and, unless
    ``--no-plots`` is given, generates the visualization PNGs there as well.
    """
    parser = argparse.ArgumentParser(
        description='Compare evaluation results from two models with visualization',
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument('--model1', type=str, required=True,
                        help='Path to model 1 evaluation report JSON')
    parser.add_argument('--model2', type=str, required=True,
                        help='Path to model 2 evaluation report JSON')
    parser.add_argument('--output-dir', type=str, default='comparison_results',
                        help='Output directory for comparison results')
    parser.add_argument('--model1-name', type=str, default='Model-1',
                        help='Display name for model 1')
    parser.add_argument('--model2-name', type=str, default='Model-2',
                        help='Display name for model 2')
    parser.add_argument('--no-plots', action='store_true',
                        help='Skip visualization generation')
    args = parser.parse_args()

    # Load reports
    print("=" * 80)
    print("MODEL COMPARISON TOOL (with Visualization)")
    print("=" * 80)
    print(f"\nLoading model 1: {args.model1}")
    model1_report = _load_json_report(args.model1)
    print(f"Loading model 2: {args.model2}")
    model2_report = _load_json_report(args.model2)

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Compare models
    comparator = VisualizationComparator(
        model1_report,
        model2_report,
        model1_name=args.model1_name,
        model2_name=args.model2_name
    )
    # The return value is not needed here; the report and plot methods are
    # called on the comparator itself.
    comparator.compare_all()

    # Generate reports
    text_output = os.path.join(args.output_dir, 'comparison_report.txt')
    json_output = os.path.join(args.output_dir, 'comparison_report.json')
    comparator.generate_text_report(text_output)
    comparator.generate_json_report(json_output)

    # Generate visualizations
    if not args.no_plots:
        comparator.generate_visualizations(args.output_dir)

    print("\n" + "=" * 80)
    print("COMPARISON COMPLETE")
    print("=" * 80)
    print(f"\nResults saved to: {args.output_dir}/")
    print(" - Text report: comparison_report.txt")
    print(" - JSON report: comparison_report.json")
    # Only advertise plots when they were both requested and possible.
    if not args.no_plots and MATPLOTLIB_AVAILABLE:
        print(" - Visualization plots: comparison_*.png")
    print("")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()