Files
yolov26_3d/eval_tools/heading_analysis/extract_bad_heading_cases.py
2026-06-24 09:35:46 +08:00

314 lines
13 KiB
Python
Executable File

"""
Extract bad heading error cases from detailed_3d_matches.json
This tool extracts and analyzes cases with large heading errors for visualization.
Usage:
python eval_tools/heading_analysis/extract_bad_heading_cases.py \
--input eval_results_common_match_comparison/yolov5s-300w/20260203_210259/detailed_3d_matches.json \
--threshold 1.5 \
--top-k 100 \
--output bad_heading_cases.json
"""
import argparse
import json
import sys
from pathlib import Path
from collections import defaultdict
import numpy as np
# Allow importing class_config from the eval_tools root
sys.path.insert(0, str(Path(__file__).parent.parent))
from class_config import CLASSES_3D as _CLASSES_3D, CLASS_NAMES as _CLASS_NAMES
# Names of the classes whose indices are listed in CLASSES_3D; matches of any
# other class are ignored during extraction.
_CLASSES_3D_NAMES = {_CLASS_NAMES[i] for i in _CLASSES_3D}
def parse_args(argv=None):
    """Parse the command-line options of the bad-heading extraction tool.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and programmatic use).

    Returns:
        ``argparse.Namespace`` holding the parsed options.

    Raises:
        SystemExit: On invalid arguments (via ``parser.error``), including a
            negative ``--top-k``.
    """
    parser = argparse.ArgumentParser(description='Extract bad heading error cases')
    parser.add_argument('--input', type=str, required=True,
                        help='Path to detailed_3d_matches.json')
    parser.add_argument('--threshold', type=float, default=1.5,
                        help='Heading error threshold in radians (default: 1.5 ≈ 86°)')
    parser.add_argument('--top-k', type=int, default=None,
                        help='Only extract top K worst cases (default: all)')
    parser.add_argument('--classes', nargs='+', default=None,
                        help='Filter by classes (e.g., vehicle pedestrian bicycle rider)')
    parser.add_argument('--min-distance', type=float, default=None,
                        help='Minimum distance in meters')
    parser.add_argument('--max-distance', type=float, default=None,
                        help='Maximum distance in meters')
    parser.add_argument('--min-confidence', type=float, default=None,
                        help='Minimum detection confidence')
    parser.add_argument('--reversal-only', action='store_true',
                        help='Only extract reversal errors (error > π - 0.1)')
    parser.add_argument('--output', type=str, default='bad_heading_cases.json',
                        help='Output JSON file path')
    parser.add_argument('--stats', action='store_true',
                        help='Print statistics')
    args = parser.parse_args(argv)
    # A negative K would otherwise slice off the tail of the sorted list
    # instead of keeping the K worst cases.
    if args.top_k is not None and args.top_k < 0:
        parser.error('--top-k must be a non-negative integer')
    return args
def calculate_distance_3d(x, y, z):
    """Return the Euclidean norm of the point (x, y, z), i.e. its distance from the origin."""
    squared_norm = np.square(x) + np.square(y) + np.square(z)
    return np.sqrt(squared_norm)
def is_reversal_error(error, threshold=0.1):
    """Tell whether a heading error lies strictly within ``threshold`` rad below π (≈ 180°).

    Works element-wise when ``error`` is a NumPy array.
    """
    reversal_cutoff = np.pi - threshold
    return error > reversal_cutoff
def _iter_class_matches(data, class_filter):
    """Yield ``(case_id, frame_id, class_name, match)`` for each match of an accepted class.

    Walks the ``case -> frame -> class -> [matches]`` nesting of the loaded
    JSON. Levels whose value is not the expected container type are skipped,
    and only classes in ``_CLASSES_3D_NAMES`` (and in ``class_filter`` when it
    is non-empty) are kept.
    """
    for case_id, case_data in data.items():
        # The frame level already tolerated non-dict values; do the same one
        # level up instead of crashing on ``.items()``.
        if not isinstance(case_data, dict):
            continue
        for frame_id, frame_data in case_data.items():
            if not isinstance(frame_data, dict):
                continue
            for class_name, class_data in frame_data.items():
                if class_name not in _CLASSES_3D_NAMES:
                    continue
                # Filter by class if specified
                if class_filter and class_name not in class_filter:
                    continue
                # class_data is expected to already be the list of matches
                if not isinstance(class_data, list):
                    continue
                for match in class_data:
                    yield case_id, frame_id, class_name, match


def extract_bad_cases(data, args):
    """Extract bad heading error cases based on criteria.

    Args:
        data: Loaded JSON data from detailed_3d_matches.json, nested as
            ``{case_id: {frame_id: {class_name: [match, ...]}}}``.
        args: Parsed command-line arguments; ``threshold``, ``classes``,
            ``min_distance``, ``max_distance``, ``min_confidence``,
            ``reversal_only`` and ``top_k`` are read.

    Returns:
        Tuple ``(bad_cases, stats, total_processed)``:
          * ``bad_cases`` -- case dicts sorted by heading error (worst first),
            truncated to ``args.top_k`` when that is a positive integer.
          * ``stats`` -- per-class ``{'count', 'reversal_count', 'total_error'}``
            accumulated over every case that passed the filters, i.e. before
            the top-k truncation.
          * ``total_processed`` -- number of matches examined for the accepted
            classes.
    """
    bad_cases = []
    stats = defaultdict(lambda: {'count': 0, 'reversal_count': 0, 'total_error': 0})
    total_processed = 0

    for case_id, frame_id, class_name, match in _iter_class_matches(data, args.classes):
        total_processed += 1
        if not isinstance(match, dict):
            continue

        # heading_error is in errors['heading'], not directly in match
        errors = match.get('errors', {})
        if not isinstance(errors, dict):
            errors = {}
        heading_error = float(errors.get('heading', 0))

        # NaN compares False against any threshold, so without the explicit
        # finiteness check it would pass the filter and scramble the sort.
        if not np.isfinite(heading_error) or heading_error < args.threshold:
            continue

        gt_center = match.get('gt_center_3d', [0, 0, 0])
        distance = float(calculate_distance_3d(*gt_center))
        confidence = float(match.get('confidence', 0))
        # Plain bool: a numpy.bool_ here would make json.dump fail later.
        reversal = bool(is_reversal_error(heading_error))

        # Explicit ``is not None`` so that a bound of 0 is honoured instead of
        # being treated as "no bound".
        if args.min_distance is not None and distance < args.min_distance:
            continue
        if args.max_distance is not None and distance > args.max_distance:
            continue
        if args.min_confidence is not None and confidence < args.min_confidence:
            continue
        if args.reversal_only and not reversal:
            continue

        class_stats = stats[class_name]
        class_stats['count'] += 1
        class_stats['total_error'] += heading_error
        if reversal:
            class_stats['reversal_count'] += 1

        gt_rotation = float(match.get('gt_rotation', 0))
        det_rotation = float(match.get('det_rotation', 0))
        det_center = match.get('det_center_3d', [0, 0, 0])
        bad_cases.append({
            'case_id': case_id,
            'frame_id': frame_id,
            'class': class_name,
            'heading_error': heading_error,
            'heading_error_deg': float(np.degrees(heading_error)),
            'gt_rotation': gt_rotation,
            'gt_rotation_deg': float(np.degrees(gt_rotation)),
            'det_rotation': det_rotation,
            'det_rotation_deg': float(np.degrees(det_rotation)),
            'is_reversal': reversal,
            'distance': distance,
            'confidence': confidence,
            'iou': float(match.get('iou', 0)),
            'gt_center': [float(v) for v in gt_center],
            'det_center': [float(v) for v in det_center],
            'lateral_error': float(errors.get('lateral', 0)),
            'longitudinal_error': float(errors.get('longitudinal', 0)),
            'gt_bbox_2d': match.get('gt_bbox', [0, 0, 0, 0]),
            'det_bbox_2d': match.get('det_bbox', [0, 0, 0, 0]),
        })

    # Worst heading error first.
    bad_cases.sort(key=lambda case: case['heading_error'], reverse=True)

    # Only a positive K truncates; None or 0 keeps every case (as before), and
    # a negative K no longer silently drops the tail of the list.
    if args.top_k is not None and args.top_k > 0:
        bad_cases = bad_cases[:args.top_k]

    return bad_cases, stats, total_processed
def _print_value_summary(title, values, spec, unit=""):
    """Print Min/Max/Mean/Median of the non-empty list ``values`` under ``title``.

    ``spec`` is the format spec applied to each number (e.g. ``'.1f'``);
    ``unit``, when given, is appended after a single space.
    """
    suffix = f" {unit}" if unit else ""
    print(f"\n{title}:")
    summary = (
        ("Min", min(values)),
        ("Max", max(values)),
        ("Mean", np.mean(values)),
        ("Median", np.median(values)),
    )
    for label, value in summary:
        print(f" {label}: {value:{spec}}{suffix}")


def print_statistics(bad_cases, stats, total_processed, args):
    """Print detailed statistics about extracted cases.

    Args:
        bad_cases: Case dicts (possibly already truncated to ``args.top_k``);
            each must have ``heading_error``, ``is_reversal``, ``distance``
            and ``confidence`` keys.
        stats: Mapping ``class_name -> {'count', 'reversal_count',
            'total_error'}``. In this script the counts are accumulated before
            any top-k truncation.
        total_processed: Number of matches examined. May be 0.
        args: Parsed CLI namespace; ``input``, ``threshold`` and ``top_k``
            are read.
    """
    rule = "=" * 80
    thin_rule = "-" * 80

    # ``bad_cases`` may be cut to top-k, so count found cases from the
    # per-class stats; fall back to len(bad_cases) if stats is smaller.
    total_found = max(sum(stat['count'] for stat in stats.values()), len(bad_cases))
    # Guard against an empty input (e.g. --stats with no matches at all).
    found_pct = 100 * total_found / total_processed if total_processed else 0.0

    print("\n" + rule)
    print("BAD HEADING ERROR CASES EXTRACTION SUMMARY")
    print(rule)
    print(f"\nInput file: {args.input}")
    print(f"Threshold: {args.threshold:.2f} rad ({np.degrees(args.threshold):.1f}°)")
    print(f"Total processed: {total_processed:,}")
    print(f"Bad cases found: {total_found:,} ({found_pct:.2f}%)")
    if args.top_k:
        print(f"Output limited to: Top {args.top_k}")

    print("\n" + thin_rule)
    print("STATISTICS BY CLASS")
    print(thin_rule)
    print(f"{'Class':<15} {'Count':<10} {'Reversal':<12} {'Rev %':<10} {'Avg Error':<12}")
    print(thin_rule)
    for class_name in sorted(stats):
        stat = stats[class_name]
        count = stat['count']
        rev_count = stat['reversal_count']
        avg_error = stat['total_error'] / count if count > 0 else 0
        rev_pct = 100 * rev_count / count if count > 0 else 0
        print(f"{class_name:<15} {count:<10,} {rev_count:<12,} {rev_pct:<10.1f} {avg_error:<12.3f}")
    print(thin_rule)

    if bad_cases:
        errors = [c['heading_error'] for c in bad_cases]
        print("\nERROR DISTRIBUTION:")
        error_summary = (
            ("Min", min(errors)),
            ("Max", max(errors)),
            ("Mean", np.mean(errors)),
            ("Median", np.median(errors)),
            ("Std", np.std(errors)),
        )
        for label, value in error_summary:
            print(f" {label}: {value:.3f} rad ({np.degrees(value):.1f}°)")

        # 3.04 rad is π - 0.1, the default reversal cutoff.
        reversal_count = sum(1 for c in bad_cases if c['is_reversal'])
        print(f"\n Reversal errors (>3.04 rad): {reversal_count} "
              f"({100 * reversal_count / len(bad_cases):.1f}%)")

        _print_value_summary("DISTANCE DISTRIBUTION",
                             [c['distance'] for c in bad_cases], '.1f', 'm')
        _print_value_summary("CONFIDENCE DISTRIBUTION",
                             [c['confidence'] for c in bad_cases], '.3f')

    print("\n" + rule)
def _summarize_stats(stats):
    """Convert raw per-class accumulators into the JSON-ready statistics mapping."""
    summary = {}
    for class_name, stat in stats.items():
        count = stat['count']
        summary[class_name] = {
            'count': count,
            'reversal_count': stat['reversal_count'],
            'reversal_percentage': 100 * stat['reversal_count'] / count if count > 0 else 0,
            'avg_error': stat['total_error'] / count if count > 0 else 0,
        }
    return summary


def _print_top_cases(bad_cases, limit=5):
    """Print a compact table of the first ``limit`` entries of the worst-first ``bad_cases``."""
    print(f"\nTop {limit} Worst Cases:")
    print("-" * 80)
    print(f"{'No':<5} {'Class':<12} {'Error (rad)':<12} {'Error (°)':<12} {'Distance':<12} {'Reversal':<10}")
    print("-" * 80)
    for rank, case in enumerate(bad_cases[:limit], 1):
        # Explicit marker so the Reversal column is actually populated.
        reversal_mark = 'yes' if case['is_reversal'] else 'no'
        print(f"{rank:<5} {case['class']:<12} {case['heading_error']:<12.3f} "
              f"{case['heading_error_deg']:<12.1f} {case['distance']:<12.1f} "
              f"{reversal_mark:<10}")
    print("-" * 80)


def main():
    """CLI entry point: load matches, extract bad-heading cases, print a report and save JSON.

    Exits with status 1 if the input file is missing, is not valid JSON, or
    does not contain a JSON object at the top level.
    """
    args = parse_args()

    # Load input JSON
    input_path = Path(args.input)
    if not input_path.exists():
        print(f"Error: Input file not found: {input_path}")
        sys.exit(1)
    print(f"Loading data from {input_path}...")
    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except json.JSONDecodeError as err:
        print(f"Error: Could not parse JSON in {input_path}: {err}")
        sys.exit(1)
    # extract_bad_cases iterates data.items(), so a top-level object is required.
    if not isinstance(data, dict):
        print(f"Error: Expected a JSON object at the top level of {input_path}, "
              f"got {type(data).__name__}")
        sys.exit(1)
    print(f"Loaded {len(data)} cases")

    # Extract bad cases (the filter keeps errors >= threshold)
    print(f"\nExtracting bad cases with heading_error >= {args.threshold:.2f} rad...")
    bad_cases, stats, total_processed = extract_bad_cases(data, args)

    # Statistics are printed on request, or automatically when something was found.
    if args.stats or bad_cases:
        print_statistics(bad_cases, stats, total_processed, args)

    # Save output
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_data = {
        'metadata': {
            'source': str(input_path),
            'threshold': args.threshold,
            'threshold_degrees': float(np.degrees(args.threshold)),
            'total_processed': total_processed,
            'total_extracted': len(bad_cases),
            'filters': {
                'classes': args.classes,
                'min_distance': args.min_distance,
                'max_distance': args.max_distance,
                'min_confidence': args.min_confidence,
                'reversal_only': args.reversal_only,
                'top_k': args.top_k,
            },
        },
        'statistics': _summarize_stats(stats),
        'cases': bad_cases,
    }
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=2)
    print(f"\nExtracted {len(bad_cases)} bad cases")
    print(f"Output saved to: {output_path}")

    if bad_cases:
        _print_top_cases(bad_cases)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()