314 lines
13 KiB
Python
Executable File
314 lines
13 KiB
Python
Executable File
r"""
Extract bad heading error cases from detailed_3d_matches.json.

This tool extracts and analyzes cases with large heading errors for visualization.

Usage:
    python eval_tools/extract_bad_heading_cases.py \
        --input eval_results_common_match_comparison/yolov5s-300w/20260203_210259/detailed_3d_matches.json \
        --threshold 1.5 \
        --top-k 100 \
        --output bad_heading_cases.json
"""
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
from collections import defaultdict
|
|
import numpy as np
|
|
|
|
# Allow importing class_config from the eval_tools root
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
from class_config import CLASSES_3D as _CLASSES_3D, CLASS_NAMES as _CLASS_NAMES
|
|
# Class *names* corresponding to the indices in CLASSES_3D; extract_bad_cases
# skips any class key in the input that is not in this set.
_CLASSES_3D_NAMES = {_CLASS_NAMES[idx] for idx in _CLASSES_3D}
|
|
|
|
def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing a list allows calling
            this from tests or other scripts.

    Returns:
        argparse.Namespace with the parsed options.

    Exits via ``parser.error`` (status 2) when ``--top-k`` is negative or
    ``--min-distance`` is larger than ``--max-distance``.
    """
    parser = argparse.ArgumentParser(description='Extract bad heading error cases')
    parser.add_argument('--input', type=str, required=True,
                        help='Path to detailed_3d_matches.json')
    parser.add_argument('--threshold', type=float, default=1.5,
                        help='Heading error threshold in radians (default: 1.5 ≈ 86°)')
    parser.add_argument('--top-k', type=int, default=None,
                        help='Only extract top K worst cases (default: all)')
    parser.add_argument('--classes', nargs='+', default=None,
                        help='Filter by classes (e.g., vehicle pedestrian bicycle rider)')
    parser.add_argument('--min-distance', type=float, default=None,
                        help='Minimum distance in meters')
    parser.add_argument('--max-distance', type=float, default=None,
                        help='Maximum distance in meters')
    parser.add_argument('--min-confidence', type=float, default=None,
                        help='Minimum detection confidence')
    parser.add_argument('--reversal-only', action='store_true',
                        help='Only extract reversal errors (error > π - 0.1)')
    parser.add_argument('--output', type=str, default='bad_heading_cases.json',
                        help='Output JSON file path')
    parser.add_argument('--stats', action='store_true',
                        help='Print statistics')
    args = parser.parse_args(argv)

    # A negative K would turn the later `cases[:K]` slice into "drop the last
    # |K| cases", which is never what the user meant.
    if args.top_k is not None and args.top_k < 0:
        parser.error(f'--top-k must be >= 0, got {args.top_k}')

    # An inverted distance window silently filters out everything.
    if (args.min_distance is not None and args.max_distance is not None
            and args.min_distance > args.max_distance):
        parser.error(f'--min-distance ({args.min_distance}) must not exceed '
                     f'--max-distance ({args.max_distance})')

    return args
|
|
|
|
|
|
def calculate_distance_3d(x, y, z):
    """Return the Euclidean norm of the point (x, y, z).

    Callers pass a ground-truth 3D center, so this is its distance from the
    coordinate origin (presumably the ego/sensor frame — confirm upstream).
    """
    squared_norm = sum(component ** 2 for component in (x, y, z))
    return np.sqrt(squared_norm)
|
|
|
|
|
|
def is_reversal_error(error, threshold=0.1):
    """Return True when the heading *error* (radians) exceeds ``π - threshold``.

    Such an error is (close to) a 180° flip: the prediction faces the opposite
    way from the ground truth. The comparison is strict, so an error of exactly
    ``π - threshold`` does not count.
    """
    lower_bound = np.pi - threshold
    return lower_bound < error
|
|
|
|
|
|
def extract_bad_cases(data, args):
    """Extract bad heading error cases based on criteria.

    The input is traversed as::

        {case_id: {frame_id: {class_name: [match, ...], ...}, ...}, ...}

    Frames that are not dicts and class entries that are not lists are
    skipped. Only classes in ``_CLASSES_3D_NAMES`` (and in ``args.classes``
    when given) are considered.

    Args:
        data: Loaded JSON data from detailed_3d_matches.json
        args: Parsed command line arguments. Reads ``threshold``, ``classes``,
            ``min_distance``, ``max_distance``, ``min_confidence``,
            ``reversal_only`` and ``top_k``.

    Returns:
        Tuple ``(bad_cases, stats, total_processed)``:

        - ``bad_cases``: list of case dicts sorted by heading error
          (descending), cut to the first ``args.top_k`` when it is set.
        - ``stats``: per-class ``{'count', 'reversal_count', 'total_error'}``
          over every case that passed the filters, i.e. *before* the top-k cut.
        - ``total_processed``: number of matches inspected.
    """
    bad_cases = []
    stats = defaultdict(lambda: {'count': 0, 'reversal_count': 0, 'total_error': 0})
    total_processed = 0

    # Iterate through all matches
    for case_id, case_data in data.items():
        for frame_id, frame_data in case_data.items():
            if not isinstance(frame_data, dict):
                continue

            for class_name, class_data in frame_data.items():
                if class_name not in _CLASSES_3D_NAMES:
                    continue

                # Filter by class if specified
                if args.classes and class_name not in args.classes:
                    continue

                # class_data is already the list of matches
                matches = class_data if isinstance(class_data, list) else []

                for match in matches:
                    total_processed += 1

                    # heading_error is in errors['heading'], not directly in match
                    errors = match.get('errors', {})
                    if not isinstance(errors, dict):
                        errors = {}
                    heading_error = errors.get('heading', 0)

                    # Cheapest filter first. Written as "not >=" so that a NaN
                    # heading error is rejected: `nan < threshold` is False,
                    # which used to let NaN cases through (and broke the sort).
                    if not heading_error >= args.threshold:
                        continue

                    # Get 3D center coordinates and distance
                    gt_center = match.get('gt_center_3d', [0, 0, 0])
                    det_center = match.get('det_center_3d', [0, 0, 0])
                    distance = calculate_distance_3d(*gt_center)
                    confidence = match.get('confidence', 0)

                    # Compare with None explicitly so that a bound of 0 is
                    # honoured instead of being treated as "not set".
                    if args.min_distance is not None and distance < args.min_distance:
                        continue
                    if args.max_distance is not None and distance > args.max_distance:
                        continue
                    if args.min_confidence is not None and confidence < args.min_confidence:
                        continue

                    reversal = is_reversal_error(heading_error)
                    if args.reversal_only and not reversal:
                        continue

                    # Update statistics (counted before any top-k cut)
                    class_stats = stats[class_name]
                    class_stats['count'] += 1
                    class_stats['total_error'] += heading_error
                    if reversal:
                        class_stats['reversal_count'] += 1

                    gt_rotation = match.get('gt_rotation', 0)
                    det_rotation = match.get('det_rotation', 0)

                    bad_cases.append({
                        'case_id': case_id,
                        'frame_id': frame_id,
                        'class': class_name,
                        'heading_error': float(heading_error),
                        'heading_error_deg': float(np.degrees(heading_error)),
                        'gt_rotation': float(gt_rotation),
                        'gt_rotation_deg': float(np.degrees(gt_rotation)),
                        'det_rotation': float(det_rotation),
                        'det_rotation_deg': float(np.degrees(det_rotation)),
                        'is_reversal': reversal,
                        'distance': float(distance),
                        'confidence': float(confidence),
                        'iou': float(match.get('iou', 0)),
                        'gt_center': [float(x) for x in gt_center],
                        'det_center': [float(x) for x in det_center],
                        'lateral_error': float(errors.get('lateral', 0)),
                        'longitudinal_error': float(errors.get('longitudinal', 0)),
                        'gt_bbox_2d': match.get('gt_bbox', [0, 0, 0, 0]),
                        'det_bbox_2d': match.get('det_bbox', [0, 0, 0, 0]),
                    })

    # Sort by heading error (descending)
    bad_cases.sort(key=lambda case: case['heading_error'], reverse=True)

    # Limit to top-k if specified; `is not None` so that --top-k 0 yields no
    # cases, and clamp so a negative K cannot turn into "drop the last K".
    if args.top_k is not None:
        bad_cases = bad_cases[:max(args.top_k, 0)]

    return bad_cases, stats, total_processed
|
|
|
|
|
|
def print_statistics(bad_cases, stats, total_processed, args):
    """Print detailed statistics about extracted cases.

    Args:
        bad_cases: Case dicts as produced by ``extract_bad_cases`` (possibly
            already cut to the top-k). Reads the 'heading_error',
            'is_reversal', 'distance' and 'confidence' keys.
        stats: Mapping class name -> {'count', 'reversal_count',
            'total_error'}, accumulated over every matching case, i.e.
            before any top-k cut.
        total_processed: Number of matches inspected.
        args: Parsed CLI arguments; reads ``input``, ``threshold``, ``top_k``.
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 80

    # The per-class counts are taken before the top-k cut, so their sum is the
    # real number of cases found (len(bad_cases) is only what gets written).
    # max() keeps the figure sane if a caller passes stats that do not cover
    # bad_cases.
    total_found = max(sum(stat['count'] for stat in stats.values()), len(bad_cases))
    # total_processed is 0 when the input had no 3D matches (e.g. --stats on an
    # empty file); avoid dividing by zero.
    found_pct = 100 * total_found / total_processed if total_processed else 0.0

    print("\n" + heavy_rule)
    print("BAD HEADING ERROR CASES EXTRACTION SUMMARY")
    print(heavy_rule)

    print(f"\nInput file: {args.input}")
    print(f"Threshold: {args.threshold:.2f} rad ({np.degrees(args.threshold):.1f}°)")
    print(f"Total processed: {total_processed:,}")
    print(f"Bad cases found: {total_found:,} ({found_pct:.2f}%)")

    if args.top_k is not None:
        print(f"Output limited to: Top {args.top_k}")

    print("\n" + light_rule)
    print("STATISTICS BY CLASS")
    print(light_rule)
    print(f"{'Class':<15} {'Count':<10} {'Reversal':<12} {'Rev %':<10} {'Avg Error':<12}")
    print(light_rule)

    for class_name in sorted(stats):
        stat = stats[class_name]
        count = stat['count']
        rev_count = stat['reversal_count']
        avg_error = stat['total_error'] / count if count > 0 else 0
        rev_pct = 100 * rev_count / count if count > 0 else 0
        print(f"{class_name:<15} {count:<10,} {rev_count:<12,} {rev_pct:<10.1f} {avg_error:<12.3f}")

    print(light_rule)

    # The distributions below describe the cases that are written out
    # (i.e. after the top-k cut).
    if bad_cases:
        # Error distribution
        errors = [c['heading_error'] for c in bad_cases]
        err_min, err_max = min(errors), max(errors)
        err_mean, err_median, err_std = np.mean(errors), np.median(errors), np.std(errors)
        print("\nERROR DISTRIBUTION:")
        print(f" Min: {err_min:.3f} rad ({np.degrees(err_min):.1f}°)")
        print(f" Max: {err_max:.3f} rad ({np.degrees(err_max):.1f}°)")
        print(f" Mean: {err_mean:.3f} rad ({np.degrees(err_mean):.1f}°)")
        print(f" Median: {err_median:.3f} rad ({np.degrees(err_median):.1f}°)")
        print(f" Std: {err_std:.3f} rad ({np.degrees(err_std):.1f}°)")

        # Reversal statistics (3.04 ≈ π - 0.1, the default reversal margin)
        reversal_count = sum(1 for c in bad_cases if c['is_reversal'])
        print(f"\n Reversal errors (>3.04 rad): {reversal_count} ({100*reversal_count/len(bad_cases):.1f}%)")

        # Distance distribution
        distances = [c['distance'] for c in bad_cases]
        print("\nDISTANCE DISTRIBUTION:")
        print(f" Min: {min(distances):.1f} m")
        print(f" Max: {max(distances):.1f} m")
        print(f" Mean: {np.mean(distances):.1f} m")
        print(f" Median: {np.median(distances):.1f} m")

        # Confidence distribution
        confidences = [c['confidence'] for c in bad_cases]
        print("\nCONFIDENCE DISTRIBUTION:")
        print(f" Min: {min(confidences):.3f}")
        print(f" Max: {max(confidences):.3f}")
        print(f" Mean: {np.mean(confidences):.3f}")
        print(f" Median: {np.median(confidences):.3f}")

    print("\n" + heavy_rule)
|
|
def main():
    """Command-line entry point.

    Loads the JSON file named by ``--input`` and extracts the bad heading
    cases. It then prints statistics (when ``--stats`` is given or at least
    one case was found), writes the result JSON to ``--output``, and prints
    the five worst cases.

    Exits with status 1 if the input is not an existing file, is not valid
    JSON, or its top-level JSON value is not an object.
    """
    args = parse_args()

    # Load input JSON
    input_path = Path(args.input)
    # is_file() rather than exists(): a directory passes exists() and would
    # then fail inside open() with an unhelpful IsADirectoryError.
    if not input_path.is_file():
        print(f"Error: Input file not found: {input_path}", file=sys.stderr)
        sys.exit(1)

    print(f"Loading data from {input_path}...")
    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except json.JSONDecodeError as err:
        print(f"Error: Failed to parse JSON from {input_path}: {err}", file=sys.stderr)
        sys.exit(1)

    # extract_bad_cases() walks data.items(); reject other JSON shapes up front
    # instead of failing with an AttributeError deep in the traversal.
    if not isinstance(data, dict):
        print(f"Error: Expected a JSON object at the top level of {input_path}, "
              f"got {type(data).__name__}", file=sys.stderr)
        sys.exit(1)

    print(f"Loaded {len(data)} cases")

    # Extract bad cases (a match is kept when heading_error >= threshold)
    print(f"\nExtracting bad cases with heading_error >= {args.threshold:.2f} rad...")
    bad_cases, stats, total_processed = extract_bad_cases(data, args)

    # Print statistics
    if args.stats or bad_cases:
        print_statistics(bad_cases, stats, total_processed, args)

    # Save output
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    output_data = {
        'metadata': {
            'source': str(input_path),
            'threshold': args.threshold,
            'threshold_degrees': float(np.degrees(args.threshold)),
            'total_processed': total_processed,
            'total_extracted': len(bad_cases),
            'filters': {
                'classes': args.classes,
                'min_distance': args.min_distance,
                'max_distance': args.max_distance,
                'min_confidence': args.min_confidence,
                'reversal_only': args.reversal_only,
                'top_k': args.top_k
            }
        },
        'statistics': {
            class_name: {
                'count': stat['count'],
                'reversal_count': stat['reversal_count'],
                'reversal_percentage': 100 * stat['reversal_count'] / stat['count'] if stat['count'] > 0 else 0,
                'avg_error': stat['total_error'] / stat['count'] if stat['count'] > 0 else 0
            }
            for class_name, stat in stats.items()
        },
        'cases': bad_cases
    }

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=2)

    print(f"\nExtracted {len(bad_cases)} bad cases")
    print(f"Output saved to: {output_path}")

    # Show top 5 worst cases
    if bad_cases:
        print("\nTop 5 Worst Cases:")
        print("-"*80)
        print(f"{'No':<5} {'Class':<12} {'Error (rad)':<12} {'Error (°)':<12} {'Distance':<12} {'Reversal':<10}")
        print("-"*80)
        for i, case in enumerate(bad_cases[:5], 1):
            print(f"{i:<5} {case['class']:<12} {case['heading_error']:<12.3f} "
                  f"{case['heading_error_deg']:<12.1f} {case['distance']:<12.1f} "
                  f"{'✓' if case['is_reversal'] else '':<10}")
        print("-"*80)


if __name__ == '__main__':
    main()
|